package de.unima.ki.anyburl.algorithm;

import de.unima.ki.anyburl.Settings;
import de.unima.ki.anyburl.data.Triple;
import de.unima.ki.anyburl.data.TripleSet;
import de.unima.ki.anyburl.structure.Rule;
import de.unima.ki.anyburl.structure.ScoreTree;
import de.unima.ki.anyburl.structure.compare.RuleConfidenceComparator;
import de.unima.ki.anyburl.threads.Predictor;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:de/unima/ki/anyburl/algorithm/RuleEngine.class */
/**
 * Applies learned rules to a set of query triples and writes ranked head/tail candidates.
 *
 * <p>All state is static. {@link #applyRulesARX} fills a shared task queue and starts
 * {@link Predictor} worker threads; {@link #getNextPredictionTask()} hands out queued triples
 * under the class lock (presumably called by the workers). Ranked candidates are written through
 * a shared {@link PrintWriter}.
 */
public class RuleEngine {
    private static final double EPSILON = 1.0E-4d;

    /** Triples still waiting to be predicted; accessed via the synchronized task getter. */
    private static LinkedList<Triple> predictionTasks = new LinkedList<>();

    /** Number of tasks handed out so far; drives periodic progress output and flushing. */
    private static int predictionsMade = 0;

    /** Writer that receives the ranked candidates; set by {@link #applyRulesARX}. */
    private static PrintWriter predictionsWriter = null;

    /**
     * When &gt; 0, only this many test triples are predicted; the remaining test triples are
     * moved into the filter set passed as the fourth argument of {@link #applyRulesARX}.
     */
    private static int DEBUG_TESTSET_SUBSET = 0;

    /**
     * Materializes every rule with at most two body atoms and collects the produced triples.
     *
     * @param linkedList rules to materialize
     * @param tripleSet triples the rules are evaluated against
     * @param tripleSet2 receives every triple produced by a rule
     */
    public static void materializeRules(LinkedList<Rule> linkedList, TripleSet tripleSet, TripleSet tripleSet2) {
        int total = linkedList.size();
        // Report roughly every 1% of progress; the step must never be 0 (fewer than 100 rules).
        int reportStep = Math.max(1, total / 100);
        int processed = 0;
        for (Rule rule : linkedList) {
            processed++;
            if (processed % reportStep == 0) {
                // Floating-point division: the old int/int expression only ever yielded 0 or 100.
                System.out.println("* " + (100.0d * processed / total) + "% of all rules materialized");
            }
            if (rule.bodysize() <= 2) {
                TripleSet materialized = rule.materialize(tripleSet);
                if (materialized != null) {
                    tripleSet2.addTripleSet(materialized);
                }
            }
        }
    }

    /**
     * Predicts heads and tails for every triple of {@code testSet} using worker threads and writes
     * the top-k candidates to {@code writer}. The writer is flushed and closed when all workers
     * have finished.
     *
     * @param rules learned rules; filtered and indexed by {@link #createOrderedRuleIndex}
     * @param testSet triples to predict
     * @param trainingSet triples the rules are evaluated against (also used as a filter)
     * @param validationSet additional filter set; in debug mode it also receives the test triples
     *     that are not predicted
     * @param k number of candidates written per side; also used as the score tree bound
     * @param writer destination for the ranked candidates (closed by this method)
     */
    public static void applyRulesARX(LinkedList<Rule> rules, TripleSet testSet, TripleSet trainingSet, TripleSet validationSet, int k, PrintWriter writer) {
        if (DEBUG_TESTSET_SUBSET > 0) {
            System.out.println("* debugging mode, choosing small fraction of testset");
            TripleSet subset = new TripleSet();
            int available = testSet.getTriples().size();
            // Clamp so a subset size larger than the test set does not index out of bounds.
            int cut = Math.min(DEBUG_TESTSET_SUBSET, available);
            for (int idx = 0; idx < cut; idx++) {
                subset.addTriple(testSet.getTriples().get(idx));
            }
            for (int idx = cut; idx < available; idx++) {
                validationSet.addTriple(testSet.getTriples().get(idx));
            }
            testSet = subset;
        }
        System.out.println("* applying rules");
        HashMap<String, ArrayList<Rule>> ruleIndex = createOrderedRuleIndex(rules);
        System.out.println("* set up index structure covering rules for prediction for " + ruleIndex.size() + " relations");
        ScoreTree.LOWER_BOUND = k;
        ScoreTree.UPPER_BOUND = ScoreTree.LOWER_BOUND;
        ScoreTree.EPSILON = EPSILON;
        predictionTasks.addAll(testSet.getTriples());
        predictionsWriter = writer;
        Thread[] workers = new Thread[Settings.WORKER_THREADS];
        System.out.print("* creating worker threads ");
        for (int w = 0; w < workers.length; w++) {
            System.out.print("#" + w + " ");
            workers[w] = new Predictor(testSet, trainingSet, validationSet, k, ruleIndex);
            workers[w].start();
        }
        System.out.println();
        boolean interrupted = awaitTermination(workers);
        predictionsWriter.flush();
        predictionsWriter.close();
        System.out.println("* done with rule application");
        predictionsMade = 0;
        // Restore the interrupt only after the writer is closed: writers backed by interruptible
        // channels would otherwise fail while flushing.
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Blocks until every given thread has terminated. An interrupt does not abort the wait (the
     * workers still own the shared writer), but it is reported to the caller.
     *
     * @return {@code true} if the calling thread was interrupted while waiting
     */
    private static boolean awaitTermination(Thread[] threads) {
        boolean interrupted = false;
        for (Thread thread : threads) {
            while (thread.isAlive()) {
                try {
                    thread.join();
                } catch (InterruptedException e) {
                    interrupted = true;
                }
            }
        }
        return interrupted;
    }

    /**
     * Removes and returns the next triple to predict, or {@code null} when the queue is empty.
     * Every 100th call logs progress and flushes the prediction writer.
     */
    public static synchronized Triple getNextPredictionTask() {
        predictionsMade++;
        Triple task = predictionTasks.poll();
        if (predictionsMade % 100 == 0) {
            if (task != null) {
                System.out.println("* (#" + predictionsMade + ") trying to guess the tail and head of " + task.toString());
            }
            predictionsWriter.flush();
        }
        return task;
    }

    /**
     * Ranks tail and head candidates for {@code triple} with max-aggregation, writes the top-k of
     * both sides, and writes explanations when {@code Settings.PATH_EXPLANATION} is set.
     */
    public static void predictMax(TripleSet testSet, TripleSet trainingSet, TripleSet validationSet, int k, HashMap<String, ArrayList<Rule>> ruleIndex, Triple triple) {
        ScoreTree tailTree = new ScoreTree();
        LinkedHashMap<String, Double> tailRanking = predictMax(testSet, trainingSet, validationSet, k, ruleIndex, triple, false, tailTree);
        ScoreTree headTree = new ScoreTree();
        LinkedHashMap<String, Double> headRanking = predictMax(testSet, trainingSet, validationSet, k, ruleIndex, triple, true, headTree);
        if (Settings.PATH_EXPLANATION != null) {
            writeTopKExplanation(triple, testSet, headRanking, headTree, tailRanking, tailTree, k);
        }
        writeTopKCandidates(triple, testSet, headRanking, tailRanking, predictionsWriter, k);
    }

    /**
     * Older max-aggregation variant that groups rules of equal applied confidence and adds each
     * group's candidates to {@code scoreTree} in one call.
     *
     * <p>Previously a rule whose confidence equalled its predecessor's was never evaluated (the
     * loop skipped advancing to it); all rules of a confidence group now contribute.
     *
     * @param predictHead {@code true} to rank heads for the fixed tail, {@code false} for tails
     * @return candidates in the order produced by {@code scoreTree.getAsLinkedList}
     */
    public static LinkedHashMap<String, Double> predictMaxOLD(TripleSet testSet, TripleSet trainingSet, TripleSet validationSet, int k, HashMap<String, ArrayList<Rule>> ruleIndex, Triple triple, boolean predictHead, ScoreTree scoreTree) {
        ArrayList<Rule> rules = ruleIndex.get(triple.getRelation());
        if (rules != null) {
            HashSet<String> pending = new HashSet<>();
            Rule previous = null;
            for (Rule current : rules) {
                if (previous != null) {
                    pending.addAll(filteredCandidates(previous, triple, predictHead, testSet, trainingSet, validationSet));
                    // Flush only when confidence drops so equal-confidence rules form one group.
                    if (previous.getAppliedConfidence() > current.getAppliedConfidence()) {
                        if (scoreTree.fine()) {
                            break;
                        }
                        if (!pending.isEmpty()) {
                            scoreTree.addValues(previous.getAppliedConfidence(), pending, explanationRule(previous));
                            pending.clear();
                        }
                    }
                }
                previous = current;
            }
            if (!scoreTree.fine() && previous != null) {
                pending.addAll(filteredCandidates(previous, triple, predictHead, testSet, trainingSet, validationSet));
                scoreTree.addValues(previous.getAppliedConfidence(), pending, explanationRule(previous));
                pending.clear();
            }
        }
        LinkedHashMap<String, Double> ranking = new LinkedHashMap<>();
        scoreTree.getAsLinkedList(ranking, predictHead ? triple.getTail() : triple.getHead());
        return ranking;
    }

    /**
     * Max-aggregation: feeds each rule's filtered candidates (in rule-index order) into
     * {@code scoreTree} until the tree reports {@code fine()}.
     *
     * @param predictHead {@code true} to rank heads for the fixed tail, {@code false} for tails
     * @return candidates in the order produced by {@code scoreTree.getAsLinkedList}
     */
    public static LinkedHashMap<String, Double> predictMax(TripleSet testSet, TripleSet trainingSet, TripleSet validationSet, int k, HashMap<String, ArrayList<Rule>> ruleIndex, Triple triple, boolean predictHead, ScoreTree scoreTree) {
        ArrayList<Rule> rules = ruleIndex.get(triple.getRelation());
        if (rules != null) {
            for (Rule rule : rules) {
                // Check before evaluating: once the tree is complete, any further results would be
                // thrown away, so evaluating the rule would be wasted work.
                if (scoreTree.fine()) {
                    break;
                }
                HashSet<String> candidates = filteredCandidates(rule, triple, predictHead, testSet, trainingSet, validationSet);
                if (!candidates.isEmpty()) {
                    scoreTree.addValues(rule.getAppliedConfidence(), candidates, explanationRule(rule));
                }
            }
        }
        LinkedHashMap<String, Double> ranking = new LinkedHashMap<>();
        scoreTree.getAsLinkedList(ranking, predictHead ? triple.getTail() : triple.getHead());
        return ranking;
    }

    /**
     * Noisy-or aggregation: every candidate is scored from all rules predicting it, and the top-k
     * heads and tails are written.
     */
    public static void predictNoisyOr(TripleSet testSet, TripleSet trainingSet, TripleSet validationSet, int k, HashMap<String, ArrayList<Rule>> ruleIndex, Triple triple) {
        HashMap<String, ArrayList<Rule>> tailSupport = new HashMap<>();
        HashMap<String, ArrayList<Rule>> headSupport = new HashMap<>();
        ArrayList<Rule> rules = ruleIndex.get(triple.getRelation());
        if (rules != null) {
            for (Rule rule : rules) {
                addSupport(tailSupport, filteredCandidates(rule, triple, false, testSet, trainingSet, validationSet), rule);
                addSupport(headSupport, filteredCandidates(rule, triple, true, testSet, trainingSet, validationSet), rule);
            }
        }
        LinkedHashMap<String, Double> tailScores = new LinkedHashMap<>();
        LinkedHashMap<String, Double> headScores = new LinkedHashMap<>();
        // Previously missing: without these calls both maps stayed empty and nothing was ranked.
        computeNoisyOr(tailSupport, tailScores);
        computeNoisyOr(headSupport, headScores);
        replaceMyselfByEntity(tailScores, triple.getHead());
        replaceMyselfByEntity(headScores, triple.getTail());
        sortByValue(tailScores);
        sortByValue(headScores);
        writeTopKCandidates(triple, testSet, headScores, tailScores, predictionsWriter, k);
    }

    /** Appends {@code rule} to the supporting-rule list of every candidate. */
    private static void addSupport(HashMap<String, ArrayList<Rule>> support, Set<String> candidates, Rule rule) {
        for (String candidate : candidates) {
            ArrayList<Rule> supporting = support.get(candidate);
            if (supporting == null) {
                supporting = new ArrayList<>();
                support.put(candidate, supporting);
            }
            supporting.add(rule);
        }
    }

    /**
     * Scores each candidate as {@code -sum(log(1 - confidence))} over its supporting rules. This is
     * a monotone transform of the noisy-or probability {@code 1 - prod(1 - confidence)}.
     *
     * <p>Only the first {@code Settings.AGGREGATION_MAX_NUM_RULES_PER_CANDIDATE} rules of each
     * list are used; a limit of zero or below means all rules.
     */
    private static void computeNoisyOr(HashMap<String, ArrayList<Rule>> hashMap, LinkedHashMap<String, Double> linkedHashMap) {
        int limit = Settings.AGGREGATION_MAX_NUM_RULES_PER_CANDIDATE;
        for (Map.Entry<String, ArrayList<Rule>> entry : hashMap.entrySet()) {
            ArrayList<Rule> supporting = entry.getValue();
            int used = (limit <= 0 || supporting.size() < limit) ? supporting.size() : limit;
            double logNoneApplies = 0.0d;
            for (int r = 0; r < used; r++) {
                logNoneApplies += Math.log(1.0d - supporting.get(r).getAppliedConfidence());
            }
            linkedHashMap.put(entry.getKey(), Double.valueOf((-1.0d) * logNoneApplies));
        }
    }

    /**
     * Re-keys the score stored under {@code Settings.REWRITE_REFLEXIV_TOKEN} (if any) to
     * {@code str}, overwriting any score already stored for {@code str}.
     */
    public static void replaceMyselfByEntity(LinkedHashMap<String, Double> linkedHashMap, String str) {
        Double reflexiveScore = linkedHashMap.remove(Settings.REWRITE_REFLEXIV_TOKEN);
        if (reflexiveScore != null) {
            linkedHashMap.put(str, reflexiveScore);
        }
    }

    /**
     * Groups rules that pass the correct-prediction and confidence thresholds by target relation
     * and sorts each group with {@link RuleConfidenceComparator}.
     */
    public static HashMap<String, ArrayList<Rule>> createOrderedRuleIndex(LinkedList<Rule> linkedList) {
        HashMap<String, ArrayList<Rule>> index = new HashMap<>();
        long indexed = 0;
        for (Rule rule : linkedList) {
            if (Settings.THRESHOLD_CORRECT_PREDICTIONS <= rule.getCorrectlyPredicted() && Settings.THRESHOLD_CONFIDENCE <= rule.getConfidence()) {
                String targetRelation = rule.getTargetRelation();
                ArrayList<Rule> group = index.get(targetRelation);
                if (group == null) {
                    group = new ArrayList<>();
                    index.put(targetRelation, group);
                }
                group.add(rule);
                if (indexed % 100000 == 0 && indexed > 1) {
                    System.out.println("* indexed " + indexed + " rules for prediction");
                }
                indexed++;
            }
        }
        RuleConfidenceComparator byConfidence = new RuleConfidenceComparator();
        for (ArrayList<Rule> group : index.values()) {
            group.trimToSize();
            Collections.sort(group, byConfidence);
        }
        System.out.println("* indexed and sorted " + indexed + " rules for using them to make predictions");
        return index;
    }

    /** Returns {@code rule} when path explanations are enabled, otherwise {@code null}. */
    private static Rule explanationRule(Rule rule) {
        return Settings.PATH_EXPLANATION != null ? rule : null;
    }

    /**
     * Evaluates {@code rule} for the open side of {@code triple} against {@code trainingSet} and
     * filters the result with {@link #getFilteredEntities}.
     */
    private static HashSet<String> filteredCandidates(Rule rule, Triple triple, boolean predictHead, TripleSet testSet, TripleSet trainingSet, TripleSet validationSet) {
        Set<String> raw = predictHead
                ? rule.computeHeadResults(triple.getTail(), trainingSet)
                : rule.computeTailResults(triple.getHead(), trainingSet);
        return getFilteredEntities(trainingSet, validationSet, testSet, triple, raw, !predictHead);
    }

    /**
     * Applies filtered ranking to candidates. A candidate is dropped when the completed triple is
     * already known in any of the three sets. The exception is the query's own answer, which is
     * kept when the completed triple is in {@code testSet}.
     *
     * @param set raw candidates
     * @param z {@code true} if candidates are tails (head fixed), {@code false} if they are heads
     */
    private static HashSet<String> getFilteredEntities(TripleSet trainingSet, TripleSet validationSet, TripleSet testSet, Triple triple, Set<String> set, boolean z) {
        String relation = triple.getRelation();
        String answer = z ? triple.getTail() : triple.getHead();
        HashSet<String> kept = new HashSet<>();
        for (String candidate : set) {
            String head = z ? triple.getHead() : candidate;
            String tail = z ? candidate : triple.getTail();
            boolean inTest = testSet.isTrue(head, relation, tail);
            boolean known = inTest || validationSet.isTrue(head, relation, tail) || trainingSet.isTrue(head, relation, tail);
            if (!known || (inTest && candidate.equals(answer))) {
                kept.add(candidate);
            }
        }
        return kept;
    }

    /**
     * Writes the query triple followed by a "Heads:" and a "Tails:" line of tab-separated
     * candidate/score pairs. Candidates completing a triple known in {@code tripleSet} are skipped,
     * unless they are the query's own answer.
     */
    private static synchronized void writeTopKCandidates(Triple triple, TripleSet tripleSet, LinkedHashMap<String, Double> linkedHashMap, LinkedHashMap<String, Double> linkedHashMap2, PrintWriter printWriter, int i) {
        printWriter.println(triple);
        printWriter.print("Heads: ");
        int written = 0;
        for (Map.Entry<String, Double> entry : linkedHashMap.entrySet()) {
            String candidate = entry.getKey();
            if (triple.getHead().equals(candidate) || !tripleSet.isTrue(candidate, triple.getRelation(), triple.getTail())) {
                printWriter.print(candidate + "\t" + entry.getValue() + "\t");
                written++;
            }
            if (written == i) {
                break;
            }
        }
        printWriter.println();
        printWriter.print("Tails: ");
        written = 0;
        for (Map.Entry<String, Double> entry : linkedHashMap2.entrySet()) {
            String candidate = entry.getKey();
            if (triple.getTail().equals(candidate) || !tripleSet.isTrue(triple.getHead(), triple.getRelation(), candidate)) {
                printWriter.print(candidate + "\t" + entry.getValue() + "\t");
                written++;
            }
            if (written == i) {
                break;
            }
        }
        printWriter.println();
        printWriter.flush();
    }

    /**
     * Writes the query triple and the string form of the head and tail score trees to
     * {@code Settings.EXPLANATION_WRITER}. The ranking maps, triple set and {@code i} are
     * currently unused.
     */
    private static synchronized void writeTopKExplanation(Triple triple, TripleSet tripleSet, LinkedHashMap<String, Double> linkedHashMap, ScoreTree scoreTree, LinkedHashMap<String, Double> linkedHashMap2, ScoreTree scoreTree2, int i) {
        Settings.EXPLANATION_WRITER.println(triple);
        Settings.EXPLANATION_WRITER.println("Heads:");
        Settings.EXPLANATION_WRITER.println(scoreTree);
        Settings.EXPLANATION_WRITER.println("Tails:");
        Settings.EXPLANATION_WRITER.println(scoreTree2);
        Settings.EXPLANATION_WRITER.flush();
    }

    /**
     * Reorders the map in place by descending value. The sort is stable, so ties keep their
     * previous order. {@link Double#compare} is used so NaN scores cannot break the comparator
     * contract.
     */
    public static void sortByValue(LinkedHashMap<String, Double> linkedHashMap) {
        ArrayList<Map.Entry<String, Double>> entries = new ArrayList<>(linkedHashMap.entrySet());
        Collections.sort(entries, new Comparator<Map.Entry<String, Double>>() {
            @Override
            public int compare(Map.Entry<String, Double> left, Map.Entry<String, Double> right) {
                return Double.compare(right.getValue(), left.getValue());
            }
        });
        linkedHashMap.clear();
        for (Map.Entry<String, Double> entry : entries) {
            linkedHashMap.put(entry.getKey(), entry.getValue());
        }
    }
}
