/*
 * Decompiled with CFR 0.152.
 */
package de.unima.ki.anyburl.structure;

import de.unima.ki.anyburl.Settings;
import de.unima.ki.anyburl.data.Triple;
import de.unima.ki.anyburl.data.TripleSet;
import de.unima.ki.anyburl.structure.Atom;
import de.unima.ki.anyburl.structure.Rule;
import de.unima.ki.anyburl.structure.RuleUntyped;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Set;

/**
 * Base class for acyclic rules.
 *
 * <p>The head has one variable side ("X" or "Y") and one constant side
 * ({@code head.getConstant()}). The body is a chain of atoms that starts at the head
 * variable (body index 0) and ends in the last atom. The last atom either carries a
 * constant or an unbound variable (see {@link #getUnboundVariable()}).
 */
public abstract class RuleAcyclic extends Rule {

    /**
     * Creates an acyclic rule from its untyped representation.
     *
     * @param r the untyped rule, passed on unchanged to {@link Rule}
     */
    public RuleAcyclic(RuleUntyped r) {
        super(r);
    }

    /**
     * Predicts right-hand (tail) entities for a given left-hand (head) entity.
     *
     * <p>For an X-rule {@code r(X, c)} the result is {@code {c}} if the body can be grounded
     * with {@code X = head}, otherwise it is empty. It is also empty when {@code head} equals
     * {@code c}. For any other rule, the head's left side is a constant. If that constant
     * equals {@code head}, every Y-value produced by the body is returned.
     *
     * @param head the left-hand entity of the query
     * @param ts   the triples used to ground the body
     * @return a new, possibly empty, set of predicted tail entities
     */
    @Override
    public HashSet<String> computeTailResults(String head, TripleSet ts) {
        HashSet<String> resultSet = new HashSet<String>();
        if (this.isXRule()) {
            String constant = this.head.getRight();
            if (constant.equals(head)) {
                // Would predict the reflexive triple r(c, c).
                return resultSet;
            }
            HashSet<String> previousValues = new HashSet<String>();
            previousValues.add(head);
            previousValues.add(constant);
            if (this.isBodyTrueAcyclic("X", head, 0, previousValues, ts)) {
                resultSet.add(constant);
            }
        } else if (this.head.getLeft().equals(head)) {
            this.computeValuesReversed("Y", resultSet, ts);
        }
        return resultSet;
    }

    /**
     * Predicts left-hand (head) entities for a given right-hand (tail) entity.
     *
     * <p>For a Y-rule {@code r(c, Y)} the result is {@code {c}} if the body can be grounded
     * with {@code Y = tail}, otherwise it is empty. It is also empty when {@code tail} equals
     * {@code c}. For an X-rule {@code r(X, c)} with {@code tail == c}, every X-value produced
     * by the body is returned. In all other cases the result is empty.
     *
     * @param tail the right-hand entity of the query
     * @param ts   the triples used to ground the body
     * @return a new, possibly empty, set of predicted head entities
     */
    @Override
    public HashSet<String> computeHeadResults(String tail, TripleSet ts) {
        HashSet<String> resultSet = new HashSet<String>();
        if (this.isYRule()) {
            String constant = this.head.getLeft();
            if (constant.equals(tail)) {
                // Would predict the reflexive triple r(c, c).
                return resultSet;
            }
            HashSet<String> previousValues = new HashSet<String>();
            previousValues.add(tail);
            previousValues.add(constant);
            if (this.isBodyTrueAcyclic("Y", tail, 0, previousValues, ts)) {
                resultSet.add(constant);
            }
        } else if (this.isXRule() && this.head.getRight().equals(tail)) {
            this.computeValuesReversed("X", resultSet, ts);
        }
        return resultSet;
    }

    /**
     * Computes and stores {@code predicted}, {@code correctlyPredicted} and
     * {@code confidence} for this rule over {@code triples}.
     *
     * <p>Candidate values for the head variable are enumerated with
     * {@link #computeValuesReversed}. Each candidate counts as one prediction. It also counts
     * as a correct prediction if the resulting head triple is contained in {@code triples}.
     *
     * <p>NOTE(review): when no candidate is found the confidence is 0/0 = NaN. This is
     * unchanged from the original; callers presumably filter such rules, so confirm that.
     *
     * @param triples the triples used both to ground the body and to check the predictions
     */
    @Override
    public void computeScores(TripleSet triples) {
        boolean xRule = this.isXRule();
        HashSet<String> values = new HashSet<String>();
        this.computeValuesReversed(xRule ? "X" : "Y", values, triples);
        int correctlyPredicted = 0;
        for (String value : values) {
            boolean holds = xRule
                    ? triples.isTrue(value, this.head.getRelation(), this.head.getRight())
                    : triples.isTrue(this.head.getLeft(), this.head.getRelation(), value);
            if (holds) {
                ++correctlyPredicted;
            }
        }
        int predicted = values.size();
        this.predicted = predicted;
        this.correctlyPredicted = correctlyPredicted;
        this.confidence = (double) correctlyPredicted / (double) predicted;
    }

    /**
     * Counts predictions that this rule shares with another rule.
     *
     * <p>For every candidate head triple produced by this rule, the candidate counts if
     * {@code that} rule yields a non-empty triple explanation for it. It additionally counts
     * as correct if the triple is contained in {@code triples}.
     *
     * @param that    the other rule
     * @param triples the triples used for grounding and checking
     * @return {@code {predictedBoth, correctlyPredictedBoth}}
     */
    @Override
    public int[] computeScores(Rule that, TripleSet triples) {
        boolean xRule = this.isXRule();
        String constant = xRule ? this.getHead().getRight() : this.getHead().getLeft();
        HashSet<String> values = new HashSet<String>();
        this.computeValuesReversed(xRule ? "X" : "Y", values, triples);
        int predictedBoth = 0;
        int correctlyPredictedBoth = 0;
        for (String value : values) {
            String xvalue = xRule ? value : constant;
            String yvalue = xRule ? constant : value;
            HashSet<Triple> explanation = that.getTripleExplanation(xvalue, yvalue, new HashSet<Triple>(), triples);
            if (explanation == null || explanation.isEmpty()) {
                continue;
            }
            ++predictedBoth;
            if (triples.isTrue(xvalue, this.head.getRelation(), yvalue)) {
                ++correctlyPredictedBoth;
            }
        }
        return new int[] {predictedBoth, correctlyPredictedBoth};
    }

    /**
     * Checks whether the body can be grounded starting from the head variable's value.
     *
     * <p>For an X-rule, {@code leftValue} is used as the value of X. Otherwise
     * {@code rightValue} is used as the value of Y.
     *
     * @param leftValue  value bound to X (used for X-rules)
     * @param rightValue value bound to Y (used for other rules)
     * @param forbidden  a triple that must not be used by any body atom, or {@code null}
     * @param ts         the triples used to ground the body
     * @return whether a grounding of the whole body exists
     */
    @Override
    public boolean isPredictedX(String leftValue, String rightValue, Triple forbidden, TripleSet ts) {
        boolean xRule = this.isXRule();
        String variable = xRule ? "X" : "Y";
        String value = xRule ? leftValue : rightValue;
        HashSet<String> previousValues = new HashSet<String>();
        previousValues.add(value);
        if (forbidden == null) {
            return this.isBodyTrueAcyclic(variable, value, 0, previousValues, ts);
        }
        return this.isBodyTrueAcyclicX(variable, value, 0, forbidden, previousValues, ts);
    }

    /**
     * Checks whether body atoms {@code bodyIndex..end} can be grounded, given that
     * {@code variable} is bound to {@code value}.
     *
     * <p>Values already in {@code previousValues} are not reused for later variables. The
     * head constant is the one exception: it may still appear as the last atom's constant.
     * On success, the values bound along the successful path remain in
     * {@code previousValues}.
     *
     * @param variable       the variable bound at this step
     * @param value          its value
     * @param bodyIndex      index of the body atom to ground next
     * @param previousValues values already used in this grounding (mutated)
     * @param triples        the triples used to ground the body
     * @return whether the remaining body can be grounded
     */
    protected boolean isBodyTrueAcyclic(String variable, String value, int bodyIndex, HashSet<String> previousValues, TripleSet triples) {
        return this.isBodyTrueAcyclicX(variable, value, bodyIndex, null, previousValues, triples);
    }

    /**
     * Same as {@link #isBodyTrueAcyclic}, but no body atom may be grounded by the
     * {@code forbidden} triple.
     *
     * <p>The check applies to every atom, including the last one. This means the forbidden
     * triple is also honored for single-atom bodies. {@code forbidden} may be {@code null},
     * in which case no triple is excluded.
     */
    private boolean isBodyTrueAcyclicX(String variable, String value, int bodyIndex, Triple forbidden, HashSet<String> previousValues, TripleSet triples) {
        Atom atom = this.body.get(bodyIndex);
        String relation = atom.getRelation();
        // true if the bound variable is the atom's left (subject) side
        boolean headNotTail = atom.getLeft().equals(variable);
        if (bodyIndex == this.body.size() - 1) {
            boolean constant = headNotTail ? atom.isRightC() : atom.isLeftC();
            if (constant) {
                String constantValue = headNotTail ? atom.getRight() : atom.getLeft();
                if (previousValues.contains(constantValue) && !constantValue.equals(this.head.getConstant())) {
                    return false;
                }
                if (isForbidden(forbidden, headNotTail, value, relation, constantValue)) {
                    return false;
                }
                return headNotTail
                        ? triples.isTrue(value, relation, constantValue)
                        : triples.isTrue(constantValue, relation, value);
            }
            for (String candidate : triples.getEntities(relation, value, headNotTail)) {
                if (!previousValues.contains(candidate)
                        && !isForbidden(forbidden, headNotTail, value, relation, candidate)) {
                    return true;
                }
            }
            return false;
        }
        Set<String> results = triples.getEntities(relation, value, headNotTail);
        String nextVariable = headNotTail ? atom.getRight() : atom.getLeft();
        for (String nextValue : results) {
            if (isForbidden(forbidden, headNotTail, value, relation, nextValue) || previousValues.contains(nextValue)) {
                continue;
            }
            previousValues.add(nextValue);
            if (this.isBodyTrueAcyclicX(nextVariable, nextValue, bodyIndex + 1, forbidden, previousValues, triples)) {
                return true;
            }
            // backtrack
            previousValues.remove(nextValue);
        }
        return false;
    }

    /** Returns whether {@code forbidden} is non-null and equals the body triple being grounded. */
    private static boolean isForbidden(Triple forbidden, boolean headNotTail, String value, String relation, String otherValue) {
        return forbidden != null && forbidden.equals(headNotTail, value, relation, otherValue);
    }

    /**
     * Enumerates body groundings backwards, from the last atom to atom 0, and adds the
     * value reached after atom 0 to {@code targetValues}.
     *
     * <p>Outside application mode, enumeration stops once {@code Settings.SAMPLE_SIZE}
     * values are collected. It also stops once
     * {@code Settings.BEAM_SAMPLING_MAX_BODY_GROUNDING_ATTEMPTS} last-atom groundings have
     * been tried. In application mode, reaching {@code Settings.DISCRIMINATION_BOUND} values
     * clears {@code targetValues} and stops enumeration.
     *
     * <p>NOTE(review): {@code targetVariable} is passed through but never inspected.
     *
     * @param targetVariable name of the head variable (unused, see note)
     * @param targetValues   receives the collected values (mutated, may be cleared)
     * @param ts             the triples used to ground the body
     */
    public void computeValuesReversed(String targetVariable, HashSet<String> targetValues, TripleSet ts) {
        int atomIndex = this.body.size() - 1;
        Atom lastAtom = this.body.get(atomIndex);
        String unboundVariable = this.getUnboundVariable();
        if (unboundVariable == null) {
            // The last atom carries a constant: start from the entities linked to it.
            boolean nextVarIsLeft = !lastAtom.isLeftC();
            String constant = lastAtom.getLR(!nextVarIsLeft);
            String nextVariable = lastAtom.getLR(nextVarIsLeft);
            Set<String> values = ts.getEntities(lastAtom.getRelation(), constant, !nextVarIsLeft);
            HashSet<String> previousValues = new HashSet<String>();
            previousValues.add(constant);
            previousValues.add(this.head.getConstant());
            int attempts = 0;
            for (String value : values) {
                this.forwardReversed(nextVariable, value, atomIndex - 1, targetVariable, targetValues, ts, previousValues);
                if (stopReverseEnumeration(targetValues, ++attempts)) {
                    return;
                }
            }
        } else {
            // The last atom ends in an unbound variable: start from every triple of its relation.
            boolean nextVarIsLeft = !lastAtom.getLeft().equals(unboundVariable);
            String nextVariable = lastAtom.getLR(nextVarIsLeft);
            ArrayList<Triple> triples = ts.getTriplesByRelation(lastAtom.getRelation());
            int attempts = 0;
            for (Triple t : triples) {
                ++attempts;
                String value = t.getValue(nextVarIsLeft);
                HashSet<String> previousValues = new HashSet<String>();
                previousValues.add(t.getValue(!nextVarIsLeft));
                previousValues.add(this.head.getConstant());
                this.forwardReversed(nextVariable, value, atomIndex - 1, targetVariable, targetValues, ts, previousValues);
                if (stopReverseEnumeration(targetValues, attempts)) {
                    return;
                }
            }
        }
    }

    /**
     * Decides whether {@link #computeValuesReversed} should stop after {@code attempts}
     * last-atom groundings. In application mode, hitting the discrimination bound clears
     * {@code targetValues} as a side effect.
     */
    private static boolean stopReverseEnumeration(HashSet<String> targetValues, int attempts) {
        if (!Rule.APPLICATION_MODE) {
            return targetValues.size() >= Settings.SAMPLE_SIZE
                    || attempts >= Settings.BEAM_SAMPLING_MAX_BODY_GROUNDING_ATTEMPTS;
        }
        if (targetValues.size() >= Settings.DISCRIMINATION_BOUND) {
            targetValues.clear();
            return true;
        }
        return false;
    }

    /**
     * Sampling variant of {@link #computeValuesReversed}.
     *
     * <p>Each round draws a random grounding of the last atom, then follows one random path
     * back to atom 0. If that path completes, the value it reaches is added to
     * {@code targetValues}. Enumeration returns in three cases:
     * <ul>
     *   <li>immediately, if the last atom has fewer than
     *       {@code Settings.AC_MIN_NUM_OF_LAST_ATOM_GROUNDINGS} groundings;</li>
     *   <li>after more than {@code Settings.SAMPLE_SIZE} draws;</li>
     *   <li>when no random draw is possible.</li>
     * </ul>
     *
     * @param targetVariable name of the head variable (passed through, not inspected)
     * @param targetValues   receives the sampled values (mutated)
     * @param ts             the triples used to ground the body
     */
    public void beamValuesReversed(String targetVariable, HashSet<String> targetValues, TripleSet ts) {
        int atomIndex = this.body.size() - 1;
        Atom lastAtom = this.body.get(atomIndex);
        if (this.getGroundingsLastAtom(ts) < Settings.AC_MIN_NUM_OF_LAST_ATOM_GROUNDINGS) {
            return;
        }
        String unboundVariable = this.getUnboundVariable();
        int draws = 0;
        if (unboundVariable == null) {
            boolean nextVarIsLeft = !lastAtom.isLeftC();
            String constant = lastAtom.getLR(!nextVarIsLeft);
            String nextVariable = lastAtom.getLR(nextVarIsLeft);
            String value;
            while ((value = ts.getRandomEntity(lastAtom.getRelation(), constant, !nextVarIsLeft)) != null) {
                ++draws;
                HashSet<String> previousValues = new HashSet<String>();
                previousValues.add(constant);
                previousValues.add(this.head.getConstant());
                String targetValue = this.beamForwardReversed(nextVariable, value, atomIndex - 1, targetVariable, ts, previousValues);
                if (targetValue != null) {
                    targetValues.add(targetValue);
                }
                if (draws > Settings.SAMPLE_SIZE) {
                    return;
                }
            }
        } else {
            boolean nextVarIsLeft = !lastAtom.getLeft().equals(unboundVariable);
            String nextVariable = lastAtom.getLR(nextVarIsLeft);
            Triple t;
            while ((t = ts.getRandomTripleByRelation(lastAtom.getRelation())) != null) {
                ++draws;
                String value = t.getValue(nextVarIsLeft);
                HashSet<String> previousValues = new HashSet<String>();
                previousValues.add(t.getValue(!nextVarIsLeft));
                previousValues.add(this.head.getConstant());
                String targetValue = this.beamForwardReversed(nextVariable, value, atomIndex - 1, targetVariable, ts, previousValues);
                if (targetValue != null) {
                    targetValues.add(targetValue);
                }
                if (draws > Settings.SAMPLE_SIZE) {
                    return;
                }
            }
        }
    }

    /**
     * Recursively follows all groundings of body atoms {@code bodyIndex..0}.
     *
     * <p>Each step starts from {@code value} bound to {@code variable}. Values already used
     * on the current path are skipped. The value reached after atom 0 is added to
     * {@code targetValues}. {@code previousValues} is not modified; each level works on its
     * own copy.
     */
    private void forwardReversed(String variable, String value, int bodyIndex, String targetVariable, HashSet<String> targetValues, TripleSet ts, HashSet<String> previousValues) {
        if (previousValues.contains(value)) {
            return;
        }
        if (bodyIndex < 0) {
            targetValues.add(value);
            return;
        }
        if (!Rule.APPLICATION_MODE && targetValues.size() >= Settings.SAMPLE_SIZE) {
            return;
        }
        HashSet<String> currentValues = new HashSet<String>();
        currentValues.add(value);
        currentValues.addAll(previousValues);
        Atom atom = this.body.get(bodyIndex);
        boolean nextVarIsLeft = !atom.getLeft().equals(variable);
        String nextVariable = atom.getLR(nextVarIsLeft);
        // Copied via addAll into a default-sized HashSet, as before. The resulting iteration
        // order decides which values are collected before the sample cutoff.
        HashSet<String> nextValues = new HashSet<String>();
        nextValues.addAll(ts.getEntities(atom.getRelation(), value, !nextVarIsLeft));
        for (String nextValue : nextValues) {
            this.forwardReversed(nextVariable, nextValue, bodyIndex - 1, targetVariable, targetValues, ts, currentValues);
        }
    }

    /**
     * Follows one random path through body atoms {@code bodyIndex..0}.
     *
     * @return the value reached after atom 0, or {@code null} in either of two cases: the
     *         walk revisits a value already in {@code previousValues} (which is mutated), or
     *         it hits an atom with no grounding
     */
    private String beamForwardReversed(String variable, String value, int bodyIndex, String targetVariable, TripleSet ts, HashSet<String> previousValues) {
        if (previousValues.contains(value)) {
            return null;
        }
        if (bodyIndex < 0) {
            return value;
        }
        previousValues.add(value);
        Atom atom = this.body.get(bodyIndex);
        boolean nextVarIsLeft = !atom.getLeft().equals(variable);
        String nextVariable = atom.getLR(nextVarIsLeft);
        String nextValue = ts.getRandomEntity(atom.getRelation(), value, !nextVarIsLeft);
        if (nextValue == null) {
            return null;
        }
        return this.beamForwardReversed(nextVariable, nextValue, bodyIndex - 1, targetVariable, ts, previousValues);
    }

    /**
     * Returns the unbound variable of the last body atom, or {@code null} if the last atom
     * carries a constant instead.
     */
    protected abstract String getUnboundVariable();

    /** Acyclic rules are never refined further. */
    @Override
    public boolean isRefinable() {
        return false;
    }

    /** Returns a random prediction of this rule that is contained in {@code ts}, or {@code null}. */
    @Override
    public Triple getRandomValidPrediction(TripleSet ts) {
        return this.pickRandom(this.getPredictions(ts, 1));
    }

    /** Returns a random prediction of this rule that is not contained in {@code ts}, or {@code null}. */
    @Override
    public Triple getRandomInvalidPrediction(TripleSet ts) {
        return this.pickRandom(this.getPredictions(ts, -1));
    }

    /** Picks a uniformly random element using the rule's {@code rand}, or {@code null} if there is none. */
    private Triple pickRandom(ArrayList<Triple> candidates) {
        if (candidates == null || candidates.isEmpty()) {
            return null;
        }
        return candidates.get(rand.nextInt(candidates.size()));
    }

    /** Returns all head triples this rule predicts over {@code ts}. */
    @Override
    public ArrayList<Triple> getPredictions(TripleSet ts) {
        return this.getPredictions(ts, 0);
    }

    /**
     * Materializes the head triples this rule predicts over {@code ts}.
     *
     * @param ts    the triples used for grounding (and for the validity filter)
     * @param valid {@code 1} keeps only triples contained in {@code ts}; {@code -1} keeps
     *              only triples not contained in {@code ts}; any other value keeps all
     * @return the predicted triples
     */
    protected ArrayList<Triple> getPredictions(TripleSet ts, int valid) {
        ArrayList<Triple> materialized = new ArrayList<Triple>();
        boolean xRule = this.isXRule();
        HashSet<String> resultSet = xRule
                ? this.computeHeadResults(this.getHead().getRight(), ts)
                : this.computeTailResults(this.getHead().getLeft(), ts);
        for (String value : resultSet) {
            Triple t = xRule
                    ? new Triple(value, this.getTargetRelation(), this.getHead().getRight())
                    : new Triple(this.getHead().getLeft(), this.getTargetRelation(), value);
            if (valid == 1 && !ts.isTrue(t)) {
                continue;
            }
            if (valid == -1 && ts.isTrue(t)) {
                continue;
            }
            materialized.add(t);
        }
        return materialized;
    }

    /** Returns the number of groundings of the last body atom in {@code ts}. */
    public abstract int getGroundingsLastAtom(TripleSet ts);

    /**
     * Replaces the head with a copy and calls {@code body.detach()} (its semantics are
     * defined elsewhere). If the head's right side is "X", the method also renames "X" to
     * "Y" in the head and in every body atom.
     */
    public void detachAndPolish() {
        this.head = this.head.createCopy();
        this.body.detach();
        if (this.head.getRight().equals("X")) {
            this.head.setRight("Y");
            for (int i = 0; i < this.bodysize(); i++) {
                this.getBodyAtom(i).replace("X", "Y");
            }
        }
    }
}

