package de.unima.ki.arch.main;

import de.unima.ki.arch.config.Config;
import de.unima.ki.arch.data.Triple;
import de.unima.ki.arch.data.TripleSet;
import de.unima.ki.arch.experiments.helper.HitsAtK;
import de.unima.ki.arch.io.AMIEReader;
import de.unima.ki.arch.logic.AmieReasoner;
import de.unima.ki.arch.logic.HornClause;
import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.TreeSet;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.StreamSupport;

/* loaded from: input_file:de/unima/ki/arch/main/Main.class */
public class Main {
    /** Number of evaluated test triples between two progress messages in {@link #evaluate()}. */
    private static final int PROGRESS_INTERVAL = 100;

    /** Filter-set size above which {@link #countMaxFilterSize()} reports an individual test triple. */
    private static final int LARGE_FILTER_THRESHOLD = 3000;

    /** Rule-based reasoner created in {@link #prepareReasoner()}. */
    private static AmieReasoner reasoner;
    /** Triples whose heads and tails are predicted and scored in {@link #evaluate()}. */
    private static TripleSet testSet;
    /** Training triples; also used as the reasoner's knowledge base. */
    private static TripleSet trainingSet;
    /** Validation triples; registered as a filter set in {@link #evaluate()}. */
    private static TripleSet validationSet;
    /** Facts handed to the reasoner; assigned the training set in {@link #readDataSets()}. */
    private static TripleSet knowledgeBase;
    // NOTE(review): never assigned or read in this class; presumably reserved for random
    // sampling in sampleGetsSkipped(). Confirm before removing.
    private static Random random;

    /**
     * Entry point: loads the data sets, loads the rules into a reasoner, and evaluates the
     * test set.
     *
     * @param strArr command-line arguments (not used)
     */
    public static void main(String[] strArr) {
        readDataSets();
        prepareReasoner();
        evaluate();
    }

    /**
     * Reads the rules from {@code Config.RULE_FILEPATH}, prints rule statistics, and builds the
     * reasoner on top of {@link #knowledgeBase}. Must run after {@link #readDataSets()}.
     */
    private static void prepareReasoner() {
        AMIEReader ruleReader = new AMIEReader();
        Map<String, List<HornClause>> rules = ruleReader.readRules(new File(Config.RULE_FILEPATH));
        ruleReader.printRuleStatistics();
        reasoner = new AmieReasoner(rules, knowledgeBase);
    }

    /**
     * Loads the test, training, and validation triple sets from the paths in {@code Config} and
     * uses the training set as the knowledge base.
     */
    private static void readDataSets() {
        testSet = new TripleSet(Config.TEST_FILEPATH);
        trainingSet = new TripleSet(Config.TRAIN_FILEPATH);
        validationSet = new TripleSet(Config.VALIDATION_FILEPATH);
        knowledgeBase = trainingSet;
    }

    /**
     * Scores head and tail predictions for every test triple and writes the result to
     * {@code Config.HITS_FILEPATH} (cutoff {@code Config.K}).
     *
     * <p>Training, validation, and test triples are registered as filter sets on the
     * {@link HitsAtK} instance. Candidates are computed in a parallel stream. The updates to the
     * shared {@code HitsAtK} are serialized because nothing here shows it is thread-safe.
     * A failure to write the result file is reported on stderr and does not abort the program.
     */
    private static void evaluate() {
        HitsAtK hitsAtK = new HitsAtK();
        hitsAtK.addFilterTripleSet(trainingSet);
        hitsAtK.addFilterTripleSet(validationSet);
        hitsAtK.addFilterTripleSet(testSet);

        // Loop-invariant: the number of test triples does not change during evaluation.
        final int total = testSet.getTriples().size();
        AtomicInteger evaluated = new AtomicInteger(0);

        StreamSupport.stream(testSet.getTriples().spliterator(), true).forEach(triple -> {
            if (sampleGetsSkipped()) {
                return;
            }
            // Candidate generation runs in parallel, outside the critical section.
            HashMap<String, Double> headCandidates = reasoner.getHeadCandidates(triple);
            HashMap<String, Double> tailCandidates = reasoner.getTailCandidates(triple);
            synchronized (hitsAtK) {
                hitsAtK.evaluateHead(headCandidates, triple);
                hitsAtK.evaluateTail(tailCandidates, triple);
            }
            // Count only after the triple is fully scored, so the message reports finished work.
            int done = evaluated.incrementAndGet();
            if (done % PROGRESS_INTERVAL == 0) {
                System.out.println(done + "/" + total + " (" + (int) (100L * done / total) + "%) evaluated.");
            }
        });

        try {
            hitsAtK.storeHits(Config.HITS_FILEPATH, Config.K);
        } catch (FileNotFoundException e) {
            System.err.println("Error. Could not write Hits result file:");
            e.printStackTrace();
        }
    }

    /**
     * Decides whether a test triple is left out of the evaluation.
     *
     * @return always {@code false}; every test triple is evaluated
     */
    private static boolean sampleGetsSkipped() {
        return false;
    }

    /**
     * Diagnostic helper (not called from this class). For each test triple it collects the
     * distinct head entities for (relation, tail) and the distinct tail entities for
     * (head, relation) across the test, training, and validation sets. It prints every triple
     * whose head or tail set exceeds {@link #LARGE_FILTER_THRESHOLD}, and finally the largest
     * set size seen.
     */
    private static void countMaxFilterSize() {
        List<TripleSet> filterSets = new ArrayList<>(3);
        filterSets.add(testSet);
        filterSets.add(trainingSet);
        filterSets.add(validationSet);

        int maxSize = 0;
        int position = 0;
        Iterator<Triple> it = testSet.iterator();
        while (it.hasNext()) {
            Triple triple = it.next();
            position++;
            // NOTE(review): element type assumed to be String (entities are String-keyed in the
            // reasoner's candidate maps); only the number of distinct entities is used here.
            TreeSet<String> heads = new TreeSet<>();
            TreeSet<String> tails = new TreeSet<>();
            for (TripleSet tripleSet : filterSets) {
                heads.addAll(tripleSet.getHeadEntities(triple.getRelation(), triple.getTail()));
                tails.addAll(tripleSet.getTailEntities(triple.getRelation(), triple.getHead()));
            }
            // Report head and tail independently so a triple large on both sides is shown twice.
            if (heads.size() > LARGE_FILTER_THRESHOLD) {
                System.out.println(heads.size() + " subj (#" + position + ")" + triple);
            }
            if (tails.size() > LARGE_FILTER_THRESHOLD) {
                System.out.println(tails.size() + " obj (#" + position + ")" + triple);
            }
            maxSize = Math.max(maxSize, Math.max(heads.size(), tails.size()));
        }
        System.out.println("Max: " + maxSize);
    }
}
