/*
 * Decompiled with CFR 0.152.
 */
package de.unima.ki.anyburl.algorithm;

import de.unima.ki.anyburl.Settings;
import de.unima.ki.anyburl.data.Triple;
import de.unima.ki.anyburl.data.TripleSet;
import de.unima.ki.anyburl.structure.Rule;
import de.unima.ki.anyburl.structure.ScoreTree;
import de.unima.ki.anyburl.structure.compare.RuleConfidenceComparator;
import de.unima.ki.anyburl.threads.Predictor;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class RuleEngine {
    /** Value 1.0E-4. Not read in this class; applyRulesARX assigns the same literal to ScoreTree.EPSILON. */
    private static final double EPSILON = 1.0E-4;
    /** Queue of test triples still waiting for prediction; filled by applyRulesARX, drained by getNextPredictionTask. */
    private static LinkedList<Triple> predictionTasks = new LinkedList();
    /** Number of getNextPredictionTask calls since the last reset; drives the progress output every 100 calls. */
    private static int predictionsMade = 0;
    /** Writer receiving the ranked candidates; set, flushed and closed by applyRulesARX. */
    private static PrintWriter predictionsWriter = null;
    /** When greater than 0, applyRulesARX predicts only this many test triples and adds the rest to the validation set. */
    private static int DEBUG_TESTSET_SUBSET = 0;

    /**
     * Materializes every rule whose {@code bodysize()} is at most 2 against the
     * training set and adds the resulting triples to {@code materializedSet}.
     * Rules for which {@code materialize} returns {@code null} are skipped.
     *
     * <p>Progress is printed roughly once per percent of processed rules. For
     * lists with fewer than 100 rules a line is printed after every rule; the
     * previous implementation divided by zero in that case.
     *
     * @param rules           the rules to materialize
     * @param trainingSet     the triples the rules are materialized against
     * @param materializedSet receives the triples produced by each rule
     */
    public static void materializeRules(LinkedList<Rule> rules, TripleSet trainingSet, TripleSet materializedSet) {
        final int total = rules.size();
        // total / 100 is 0 for short lists; clamp so the modulo below never divides by zero.
        final int reportInterval = Math.max(1, total / 100);
        int ruleCounter = 0;
        for (Rule rule : rules) {
            ++ruleCounter;
            if (ruleCounter % reportInterval == 0) {
                System.out.println("* " + 100.0 * ((double)ruleCounter / (double)total) + "% of all rules materialized");
            }
            if (rule.bodysize() > 2) {
                continue;
            }
            TripleSet materializedRule = rule.materialize(trainingSet);
            if (materializedRule == null) {
                continue;
            }
            materializedSet.addTripleSet(materializedRule);
        }
    }

    /**
     * Applies the given rules to all triples of the test set and writes the
     * ranked candidates to {@code resultsWriter}.
     *
     * <p>The rules are indexed per target relation, the test triples are queued
     * as prediction tasks, and {@code Settings.WORKER_THREADS} {@link Predictor}
     * threads are started. The method blocks until none of them is alive any
     * more, then flushes and closes the writer and resets the task counter.
     *
     * <p>If the calling thread is interrupted while waiting, the method keeps
     * waiting (closing the writer under running workers would lose results) and
     * restores the interrupt status before it returns.
     *
     * <p>Debug mode ({@code DEBUG_TESTSET_SUBSET > 0}): only the first
     * {@code DEBUG_TESTSET_SUBSET} test triples are predicted; the remaining test
     * triples are added to {@code validationSet}, i.e. the caller's validation
     * set is modified.
     *
     * @param rules         rules to apply
     * @param testSet       triples for which head and tail are predicted
     * @param trainingSet   triples handed to the predictors as training data
     * @param validationSet triples handed to the predictors for filtering
     * @param k             assigned to both {@code ScoreTree.LOWER_BOUND} and
     *                      {@code ScoreTree.UPPER_BOUND} and passed to the predictors
     * @param resultsWriter receives the output; it is closed by this method
     */
    public static void applyRulesARX(LinkedList<Rule> rules, TripleSet testSet, TripleSet trainingSet, TripleSet validationSet, int k, PrintWriter resultsWriter) {
        if (DEBUG_TESTSET_SUBSET > 0) {
            System.out.println("* debugging mode, choosing small fraction of testset");
            // Clamp the subset size so a debug value larger than the test set cannot index past its end.
            final int subsetSize = Math.min(DEBUG_TESTSET_SUBSET, testSet.getTriples().size());
            TripleSet testSetReduced = new TripleSet();
            for (int i = 0; i < subsetSize; ++i) {
                testSetReduced.addTriple(testSet.getTriples().get(i));
            }
            // Everything beyond the subset is moved to the validation set.
            for (int i = subsetSize; i < testSet.getTriples().size(); ++i) {
                validationSet.addTriple(testSet.getTriples().get(i));
            }
            testSet = testSetReduced;
        }
        System.out.println("* applying rules");
        HashMap<String, ArrayList<Rule>> relation2Rules4Prediction = RuleEngine.createOrderedRuleIndex(rules);
        System.out.println("* set up index structure covering rules for prediction for " + relation2Rules4Prediction.size() + " relations");
        ScoreTree.UPPER_BOUND = ScoreTree.LOWER_BOUND = k;
        ScoreTree.EPSILON = 1.0E-4;
        predictionTasks.addAll(testSet.getTriples());
        predictionsWriter = resultsWriter;
        Thread[] predictors = new Thread[Settings.WORKER_THREADS];
        System.out.print("* creating worker threads ");
        for (int threadCounter = 0; threadCounter < Settings.WORKER_THREADS; ++threadCounter) {
            System.out.print("#" + threadCounter + " ");
            predictors[threadCounter] = new Predictor(testSet, trainingSet, validationSet, k, relation2Rules4Prediction);
            predictors[threadCounter].start();
        }
        System.out.println();
        // Poll until every worker has terminated. An interrupt is remembered rather than
        // swallowed: we still have to wait for the workers before the writer may be closed.
        boolean interrupted = false;
        while (RuleEngine.alive(predictors)) {
            try {
                Thread.sleep(500L);
            }
            catch (InterruptedException e) {
                interrupted = true;
            }
        }
        predictionsWriter.flush();
        predictionsWriter.close();
        System.out.println("* done with rule application");
        predictionsMade = 0;
        if (interrupted) {
            // Hand the interrupt back to the caller now that the output is complete.
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Reports whether at least one of the given threads is still running.
     *
     * @param threads the threads to inspect
     * @return {@code true} as soon as one thread reports {@link Thread#isAlive()},
     *         {@code false} if none does (including for an empty array)
     */
    private static boolean alive(Thread[] threads) {
        for (Thread worker : threads) {
            if (worker.isAlive()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Hands out the next queued prediction task.
     *
     * <p>The call counter is incremented on every call, also when the queue is
     * already empty. On every 100th call the predictions writer is flushed and,
     * if a task was available, a progress line is printed.
     *
     * @return the next triple to predict, or {@code null} if the queue is empty
     */
    public static synchronized Triple getNextPredictionTask() {
        final Triple next = predictionTasks.poll();
        predictionsMade++;
        final boolean progressTick = predictionsMade % 100 == 0;
        if (progressTick && next != null) {
            System.out.println("* (#" + predictionsMade + ") trying to guess the tail and head of " + next.toString());
        }
        if (progressTick) {
            predictionsWriter.flush();
        }
        return next;
    }

    /**
     * Predicts tail and head candidates for one test triple and writes them out.
     *
     * <p>The tail direction is computed first, then the head direction, each
     * with its own fresh {@link ScoreTree}. If {@code Settings.PATH_EXPLANATION}
     * is set, both trees are additionally written as explanation. The ranked
     * candidates always go to the predictions writer.
     *
     * @param testSet                   test triples
     * @param trainingSet               training triples the rules are evaluated on
     * @param validationSet             validation triples used for filtering
     * @param k                         number of candidates written per direction
     * @param relation2Rules4Prediction rules indexed by target relation
     * @param triple                    the test triple to predict
     */
    public static void predictMax(TripleSet testSet, TripleSet trainingSet, TripleSet validationSet, int k, HashMap<String, ArrayList<Rule>> relation2Rules4Prediction, Triple triple) {
        final ScoreTree kTailTree = new ScoreTree();
        final LinkedHashMap<String, Double> kTailCandidates = RuleEngine.predictMax(testSet, trainingSet, validationSet, k, relation2Rules4Prediction, triple, false, kTailTree);

        final ScoreTree kHeadTree = new ScoreTree();
        final LinkedHashMap<String, Double> kHeadCandidates = RuleEngine.predictMax(testSet, trainingSet, validationSet, k, relation2Rules4Prediction, triple, true, kHeadTree);

        if (Settings.PATH_EXPLANATION != null) {
            RuleEngine.writeTopKExplanation(triple, testSet, kHeadCandidates, kHeadTree, kTailCandidates, kTailTree, k);
        }
        RuleEngine.writeTopKCandidates(triple, testSet, kHeadCandidates, kTailCandidates, predictionsWriter, k);
    }

    /**
     * Older variant of the max-aggregation prediction for one direction.
     *
     * <p>The rules of the triple's relation are walked in index order with a lag
     * of one: in each step the candidates of the <em>previous</em> rule are
     * computed, filtered and pooled. The pool is handed to {@code kTree} only
     * when the previous rule's applied confidence is strictly greater than the
     * current rule's, so rules of equal confidence contribute one joint pool.
     * The walk stops once {@code kTree.fine()} holds at such a confidence drop.
     * After the loop the last rule is processed unless the tree is already fine.
     *
     * @param testSet            test triples (used for filtering)
     * @param trainingSet        training triples the rules are evaluated on
     * @param filterSet          additional triples used for filtering
     * @param k                  not used by this method
     * @param relation2Rules     rules indexed by target relation
     * @param triple             the triple to predict
     * @param predictHeadNotTail {@code true} to predict heads, {@code false} for tails
     * @param kTree              score tree that receives the candidates
     * @return the candidates as filled in by {@code kTree.getAsLinkedList}
     */
    public static LinkedHashMap<String, Double> predictMaxOLD(TripleSet testSet, TripleSet trainingSet, TripleSet filterSet, int k, HashMap<String, ArrayList<Rule>> relation2Rules, Triple triple, boolean predictHeadNotTail, ScoreTree kTree) {
        final String relation = triple.getRelation();
        final String head = triple.getHead();
        final String tail = triple.getTail();
        if (relation2Rules.containsKey(relation)) {
            final HashSet<String> pooled = new HashSet<String>();
            Rule lagging = null;
            for (Rule current : relation2Rules.get(relation)) {
                if (lagging != null) {
                    final HashSet<String> raw;
                    if (predictHeadNotTail) {
                        raw = lagging.computeHeadResults(tail, trainingSet);
                    } else {
                        raw = lagging.computeTailResults(head, trainingSet);
                    }
                    pooled.addAll(RuleEngine.getFilteredEntities(trainingSet, filterSet, testSet, triple, raw, !predictHeadNotTail));
                    // Only a strict confidence drop closes the current pool.
                    if (lagging.getAppliedConfidence() > current.getAppliedConfidence()) {
                        if (kTree.fine()) {
                            break;
                        }
                        if (!pooled.isEmpty()) {
                            // The rule is attached only when explanations are requested.
                            kTree.addValues(lagging.getAppliedConfidence(), pooled, Settings.PATH_EXPLANATION != null ? lagging : null);
                            pooled.clear();
                        }
                    }
                }
                lagging = current;
            }
            // The loop always lags one rule behind; handle the final one here.
            if (!kTree.fine() && lagging != null) {
                final HashSet<String> raw;
                if (predictHeadNotTail) {
                    raw = lagging.computeHeadResults(tail, trainingSet);
                } else {
                    raw = lagging.computeTailResults(head, trainingSet);
                }
                pooled.addAll(RuleEngine.getFilteredEntities(trainingSet, filterSet, testSet, triple, raw, !predictHeadNotTail));
                kTree.addValues(lagging.getAppliedConfidence(), pooled, Settings.PATH_EXPLANATION != null ? lagging : null);
                pooled.clear();
            }
        }
        final LinkedHashMap<String, Double> kCandidates = new LinkedHashMap<String, Double>();
        // Second argument: the known entity of the triple (tail when predicting heads, head otherwise).
        kTree.getAsLinkedList(kCandidates, predictHeadNotTail ? tail : head);
        return kCandidates;
    }

    /**
     * Max-aggregation prediction for one direction of one triple.
     *
     * <p>The rules indexed for the triple's relation are applied in index order.
     * For each rule the candidates are computed and filtered; the walk ends as
     * soon as {@code kTree.fine()} holds. Otherwise non-empty candidate sets are
     * added to {@code kTree} with the rule's applied confidence. The rule itself
     * is attached only if {@code Settings.PATH_EXPLANATION} is set.
     *
     * @param testSet            test triples (used for filtering)
     * @param trainingSet        training triples the rules are evaluated on
     * @param validationSet      validation triples (used for filtering)
     * @param k                  not used by this method
     * @param relation2Rules     rules indexed by target relation
     * @param triple             the triple to predict
     * @param predictHeadNotTail {@code true} to predict heads, {@code false} for tails
     * @param kTree              score tree that receives the candidates
     * @return the candidates as filled in by {@code kTree.getAsLinkedList}
     */
    public static LinkedHashMap<String, Double> predictMax(TripleSet testSet, TripleSet trainingSet, TripleSet validationSet, int k, HashMap<String, ArrayList<Rule>> relation2Rules, Triple triple, boolean predictHeadNotTail, ScoreTree kTree) {
        final String relation = triple.getRelation();
        final String head = triple.getHead();
        final String tail = triple.getTail();
        if (relation2Rules.containsKey(relation)) {
            final HashSet<String> pending = new HashSet<String>();
            for (Rule rule : relation2Rules.get(relation)) {
                final HashSet<String> raw;
                if (predictHeadNotTail) {
                    raw = rule.computeHeadResults(tail, trainingSet);
                } else {
                    raw = rule.computeTailResults(head, trainingSet);
                }
                pending.addAll(RuleEngine.getFilteredEntities(trainingSet, validationSet, testSet, triple, raw, !predictHeadNotTail));
                // NOTE(review): the candidates of the rule that triggers this break are computed
                // and then discarded; checking fine() first would save that work if the
                // compute*Results calls have no side effects — confirm before reordering.
                if (kTree.fine()) {
                    break;
                }
                if (pending.isEmpty()) {
                    continue;
                }
                kTree.addValues(rule.getAppliedConfidence(), pending, Settings.PATH_EXPLANATION != null ? rule : null);
                pending.clear();
            }
        }
        final LinkedHashMap<String, Double> kCandidates = new LinkedHashMap<String, Double>();
        // Second argument: the known entity of the triple (tail when predicting heads, head otherwise).
        kTree.getAsLinkedList(kCandidates, predictHeadNotTail ? tail : head);
        return kCandidates;
    }

    /**
     * Predicts head and tail candidates for one triple by aggregating over all
     * rules that produce a candidate, then writes the top candidates.
     *
     * <p>For every rule of the triple's relation the filtered tail and head
     * candidates are collected together with the rules that produced them.
     * Scores are computed only when {@code Settings.AGGREGATION_ID == 3};
     * otherwise both candidate maps stay empty. Afterwards the placeholder
     * {@code "me_myself_i"} is replaced (by the head among tail candidates, by
     * the tail among head candidates) and both maps are sorted by score.
     *
     * @param testSet        test triples (used for filtering and output)
     * @param trainingSet    training triples the rules are evaluated on
     * @param validationSet  validation triples (used for filtering)
     * @param k              number of candidates written per direction
     * @param relation2Rules rules indexed by target relation
     * @param triple         the triple to predict
     */
    public static void predictNoisyOr(TripleSet testSet, TripleSet trainingSet, TripleSet validationSet, int k, HashMap<String, ArrayList<Rule>> relation2Rules, Triple triple) {
        final String relation = triple.getRelation();
        final String head = triple.getHead();
        final String tail = triple.getTail();
        final HashMap<String, ArrayList<Rule>> explainedTailCandidates = new HashMap<String, ArrayList<Rule>>();
        final HashMap<String, ArrayList<Rule>> explainedHeadCandidates = new HashMap<String, ArrayList<Rule>>();
        if (relation2Rules.containsKey(relation)) {
            for (Rule rule : relation2Rules.get(relation)) {
                // Tail direction: remember which rules vote for which tail.
                final HashSet<String> rawTails = rule.computeTailResults(head, trainingSet);
                final HashSet<String> keptTails = RuleEngine.getFilteredEntities(trainingSet, validationSet, testSet, triple, rawTails, true);
                for (String candidate : keptTails) {
                    ArrayList<Rule> support = explainedTailCandidates.get(candidate);
                    if (support == null) {
                        support = new ArrayList<Rule>();
                        explainedTailCandidates.put(candidate, support);
                    }
                    support.add(rule);
                }
                // Head direction: same bookkeeping for heads.
                final HashSet<String> rawHeads = rule.computeHeadResults(tail, trainingSet);
                final HashSet<String> keptHeads = RuleEngine.getFilteredEntities(trainingSet, validationSet, testSet, triple, rawHeads, false);
                for (String candidate : keptHeads) {
                    ArrayList<Rule> support = explainedHeadCandidates.get(candidate);
                    if (support == null) {
                        support = new ArrayList<Rule>();
                        explainedHeadCandidates.put(candidate, support);
                    }
                    support.add(rule);
                }
            }
        }
        final LinkedHashMap<String, Double> kTailCandidates = new LinkedHashMap<String, Double>();
        final LinkedHashMap<String, Double> kHeadCandidates = new LinkedHashMap<String, Double>();
        if (Settings.AGGREGATION_ID == 3) {
            RuleEngine.computeNoisyOr(explainedTailCandidates, kTailCandidates);
            RuleEngine.computeNoisyOr(explainedHeadCandidates, kHeadCandidates);
        }
        RuleEngine.replaceMyselfByEntity(kTailCandidates, head);
        RuleEngine.replaceMyselfByEntity(kHeadCandidates, tail);
        RuleEngine.sortByValue(kTailCandidates);
        RuleEngine.sortByValue(kHeadCandidates);
        RuleEngine.writeTopKCandidates(triple, testSet, kHeadCandidates, kTailCandidates, predictionsWriter, k);
    }

    /**
     * Scores every candidate with {@code -sum(log(1 - confidence))} over the
     * rules that produced it and stores the score in {@code kCandidates}.
     *
     * <p>At most {@code Settings.AGGREGATION_MAX_NUM_RULES_PER_CANDIDATE} rules
     * are used per candidate, taken from the front of the candidate's rule list.
     * A negative limit means no limit. A limit of 0 is never reached by the
     * counter, so it also results in all rules being used.
     *
     * @param allCandidates candidate entity mapped to the rules that produced it
     * @param kCandidates   receives one score per candidate
     */
    private static void computeNoisyOr(HashMap<String, ArrayList<Rule>> allCandidates, LinkedHashMap<String, Double> kCandidates) {
        for (Map.Entry<String, ArrayList<Rule>> explained : allCandidates.entrySet()) {
            final ArrayList<Rule> supporting = explained.getValue();
            final int cap = Settings.AGGREGATION_MAX_NUM_RULES_PER_CANDIDATE;
            final int limit;
            if (cap < 0 || supporting.size() < cap) {
                limit = supporting.size();
            } else {
                limit = cap;
            }
            double logSum = 0.0;
            int used = 0;
            for (Rule r : supporting) {
                logSum += Math.log(1.0 - r.getAppliedConfidence());
                used++;
                if (used == limit) {
                    break;
                }
            }
            kCandidates.put(explained.getKey(), -1.0 * logSum);
        }
    }

    /**
     * Replaces the placeholder key {@code "me_myself_i"} by a concrete entity.
     *
     * <p>If the placeholder is present, its entry is removed and its score is
     * stored under {@code replacement}; an existing score of {@code replacement}
     * is overwritten. Without the placeholder the map is left untouched.
     *
     * @param candidates  candidate scores, modified in place
     * @param replacement the entity that takes over the placeholder's score
     */
    public static void replaceMyselfByEntity(LinkedHashMap<String, Double> candidates, String replacement) {
        final String placeholder = "me_myself_i";
        if (!candidates.containsKey(placeholder)) {
            return;
        }
        final double placeholderScore = candidates.get(placeholder);
        candidates.remove(placeholder);
        candidates.put(replacement, placeholderScore);
    }

    /**
     * Builds the per-relation rule index used for prediction.
     *
     * <p>Rules below {@code Settings.THRESHOLD_CORRECT_PREDICTIONS} (correctly
     * predicted) or below {@code Settings.THRESHOLD_CONFIDENCE} (confidence) are
     * left out. The remaining rules are grouped by target relation and every
     * group is sorted with a {@link RuleConfidenceComparator}.
     *
     * @param rules all available rules
     * @return target relation mapped to its sorted list of rules
     */
    public static HashMap<String, ArrayList<Rule>> createOrderedRuleIndex(LinkedList<Rule> rules) {
        final HashMap<String, ArrayList<Rule>> relation2Rules = new HashMap<String, ArrayList<Rule>>();
        long indexed = 0L;
        for (Rule rule : rules) {
            if (Settings.THRESHOLD_CORRECT_PREDICTIONS > rule.getCorrectlyPredicted() || Settings.THRESHOLD_CONFIDENCE > rule.getConfidence()) {
                continue;
            }
            final String relation = rule.getTargetRelation();
            ArrayList<Rule> bucket = relation2Rules.get(relation);
            if (bucket == null) {
                bucket = new ArrayList<Rule>();
                relation2Rules.put(relation, bucket);
            }
            bucket.add(rule);
            // The counter is checked before it is incremented for the current rule.
            if (indexed > 1L && indexed % 100000L == 0L) {
                System.out.println("* indexed " + indexed + " rules for prediction");
            }
            indexed++;
        }
        for (ArrayList<Rule> bucket : relation2Rules.values()) {
            bucket.trimToSize();
            Collections.sort(bucket, new RuleConfidenceComparator());
        }
        System.out.println("* indexed and sorted " + indexed + " rules for using them to make predictions");
        return relation2Rules;
    }

    /**
     * Filters candidate entities for one prediction direction.
     *
     * <p>A candidate is kept if the triple it completes is contained in none of
     * the validation, training and test sets. It is also kept if that triple is
     * in the test set and the candidate equals the corresponding entity of
     * {@code t} itself (the correct answer must not be filtered away).
     *
     * @param trainingSet       training triples
     * @param validationSet     validation triples
     * @param testSet           test triples
     * @param t                 the triple being predicted
     * @param candidateEntities unfiltered candidates
     * @param tailNotHead       {@code true} if the candidates are tails for
     *                          {@code t}'s head, {@code false} if they are heads
     *                          for {@code t}'s tail
     * @return a new set with the candidates that passed the filter
     */
    private static HashSet<String> getFilteredEntities(TripleSet trainingSet, TripleSet validationSet, TripleSet testSet, Triple t, Set<String> candidateEntities, boolean tailNotHead) {
        final HashSet<String> filteredEntities = new HashSet<String>();
        for (String entity : candidateEntities) {
            if (tailNotHead) {
                final boolean alreadyKnown = validationSet.isTrue(t.getHead(), t.getRelation(), entity)
                        || trainingSet.isTrue(t.getHead(), t.getRelation(), entity)
                        || testSet.isTrue(t.getHead(), t.getRelation(), entity);
                if (!alreadyKnown) {
                    filteredEntities.add(entity);
                }
                if (testSet.isTrue(t.getHead(), t.getRelation(), entity) && entity.equals(t.getTail())) {
                    filteredEntities.add(entity);
                }
            } else {
                final boolean alreadyKnown = validationSet.isTrue(entity, t.getRelation(), t.getTail())
                        || trainingSet.isTrue(entity, t.getRelation(), t.getTail())
                        || testSet.isTrue(entity, t.getRelation(), t.getTail());
                if (!alreadyKnown) {
                    filteredEntities.add(entity);
                }
                if (testSet.isTrue(entity, t.getRelation(), t.getTail()) && entity.equals(t.getHead())) {
                    filteredEntities.add(entity);
                }
            }
        }
        return filteredEntities;
    }

    /**
     * Writes the triple followed by one {@code "Heads: "} and one
     * {@code "Tails: "} line with tab-separated candidate/score pairs, then
     * flushes the writer.
     *
     * <p>A candidate is written if it is the triple's own head (or tail), or if
     * the triple it completes is not contained in the test set. Each line stops
     * once {@code k} candidates have been written.
     *
     * @param t               the predicted triple
     * @param testSet         test triples used to skip other known answers
     * @param kHeadCandidates ranked head candidates with scores
     * @param kTailCandidates ranked tail candidates with scores
     * @param writer          output target
     * @param k               maximum number of candidates per line
     */
    private static synchronized void writeTopKCandidates(Triple t, TripleSet testSet, LinkedHashMap<String, Double> kHeadCandidates, LinkedHashMap<String, Double> kTailCandidates, PrintWriter writer, int k) {
        writer.println(t);
        writeRankedLine(writer, "Heads: ", kHeadCandidates, t, testSet, true, k);
        writeRankedLine(writer, "Tails: ", kTailCandidates, t, testSet, false, k);
        writer.flush();
    }

    /** Writes one labelled candidate line; {@code headSide} selects which position of {@code t} the candidates fill. */
    private static void writeRankedLine(PrintWriter writer, String label, LinkedHashMap<String, Double> ranked, Triple t, TripleSet testSet, boolean headSide, int k) {
        writer.print(label);
        int written = 0;
        for (Map.Entry<String, Double> candidate : ranked.entrySet()) {
            final String entity = candidate.getKey();
            final boolean keep;
            if (headSide) {
                keep = t.getHead().equals(entity) || !testSet.isTrue(entity, t.getRelation(), t.getTail());
            } else {
                keep = t.getTail().equals(entity) || !testSet.isTrue(t.getHead(), t.getRelation(), entity);
            }
            if (keep) {
                writer.print(entity + "\t" + candidate.getValue() + "\t");
                written++;
            }
            // Checked after every entry, whether or not it was written.
            if (written == k) {
                break;
            }
        }
        writer.println();
    }

    /**
     * Writes the triple and both score trees to {@code Settings.EXPLANATION_WRITER}
     * and flushes it.
     *
     * <p>Output order: the triple, the line {@code "Heads:"}, the head tree, the
     * line {@code "Tails:"}, the tail tree. The parameters {@code testSet},
     * {@code kHeadCandidates}, {@code kTailCandidates} and {@code k} are not used
     * by this method.
     *
     * @param t        the predicted triple
     * @param headTree score tree of the head prediction
     * @param tailTree score tree of the tail prediction
     */
    private static synchronized void writeTopKExplanation(Triple t, TripleSet testSet, LinkedHashMap<String, Double> kHeadCandidates, ScoreTree headTree, LinkedHashMap<String, Double> kTailCandidates, ScoreTree tailTree, int k) {
        Settings.EXPLANATION_WRITER.println(t);
        Settings.EXPLANATION_WRITER.println("Heads:");
        Settings.EXPLANATION_WRITER.println(headTree);
        Settings.EXPLANATION_WRITER.println("Tails:");
        Settings.EXPLANATION_WRITER.println(tailTree);
        Settings.EXPLANATION_WRITER.flush();
    }

    /**
     * Reorders the map in place so that iteration goes from the highest to the
     * lowest value.
     *
     * <p>Entries with equal values (including {@code 0.0} and {@code -0.0})
     * keep their previous relative order. {@code NaN} values are placed after
     * all other values; the previous comparator returned 0 for every comparison
     * involving {@code NaN}, which violates the {@link Comparator} contract and
     * gives an unpredictable order.
     *
     * <p>The entries are sorted on a private copy of the map, because
     * {@link Map.Entry} objects obtained from {@code m} must not be used after
     * {@code m} has been cleared.
     *
     * @param m the map to reorder; must not contain {@code null} values
     * @throws NullPointerException if a value is {@code null} and has to be compared
     */
    public static void sortByValue(LinkedHashMap<String, Double> m) {
        final LinkedHashMap<String, Double> snapshot = new LinkedHashMap<String, Double>(m);
        final List<Map.Entry<String, Double>> entries = new ArrayList<Map.Entry<String, Double>>(snapshot.entrySet());
        Collections.sort(entries, new Comparator<Map.Entry<String, Double>>(){

            @Override
            public int compare(Map.Entry<String, Double> lhs, Map.Entry<String, Double> rhs) {
                final double left = lhs.getValue();
                final double right = rhs.getValue();
                if (left < right) {
                    return 1;
                }
                if (left > right) {
                    return -1;
                }
                if (left == right) {
                    return 0;
                }
                // Only reachable when at least one side is NaN: NaN sorts last, NaNs tie.
                final boolean leftNaN = Double.isNaN(left);
                final boolean rightNaN = Double.isNaN(right);
                if (leftNaN == rightNaN) {
                    return 0;
                }
                return leftNaN ? 1 : -1;
            }
        });
        m.clear();
        for (Map.Entry<String, Double> entry : entries) {
            m.put(entry.getKey(), entry.getValue());
        }
    }
}

