package org.bitbucket.efsmtool.evaluation.kfolds;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.Callable;
import org.apache.log4j.Logger;
import org.bitbucket.efsmtool.app.Configuration;
import org.bitbucket.efsmtool.inference.BaseClassifierInference;
import org.bitbucket.efsmtool.inference.efsm.EDSMRedBlueMerger;
import org.bitbucket.efsmtool.inference.efsm.scoring.BasicScorer;
import org.bitbucket.efsmtool.inference.efsm.scoring.KTailsScorer;
import org.bitbucket.efsmtool.inference.efsm.scoring.Scorer;
import org.bitbucket.efsmtool.model.WekaGuardMachineDecorator;
import org.bitbucket.efsmtool.model.prefixtree.TracePrefixTreeGenerator;
import org.bitbucket.efsmtool.model.walk.MachineAnalysis;
import org.bitbucket.efsmtool.tracedata.TraceElement;

/* loaded from: input_file:org/bitbucket/efsmtool/evaluation/kfolds/Experiment.class */
/**
 * A k-fold cross-validation experiment for a single inference configuration.
 *
 * <p>The positive traces are partitioned into {@code folds} sets. For each fold, a model is
 * learned from the remaining folds and scored against the held-out fold (positives) and the
 * negative traces. The per-fold scores are averaged into a single {@link Result}, which is
 * appended to {@link #getCurrentResults()} and returned from {@link #call()}.
 */
public class Experiment implements Callable<List<Result>> {
    private static final Logger LOGGER = Logger.getLogger(Experiment.class.getName());
    protected final Random rand;
    /** Positive traces (defensive copy of the constructor argument). */
    protected final Set<List<TraceElement>> trace = new HashSet<>();
    /** Negative traces (defensive copy of the constructor argument). */
    protected final Set<List<TraceElement>> negTrace;
    /** Traces that are partitioned into folds; currently the same set instance as {@link #trace}. */
    protected final Set<List<TraceElement>> eval;
    protected final int folds;
    protected final int seed;
    protected final int tail;
    protected final Configuration.Data algo;
    protected final String name;
    protected final boolean data;
    protected final boolean ktail;
    protected final List<Result> results;

    /**
     * Creates an experiment.
     *
     * @param name      experiment name, used in logs and in the produced {@link Result}
     * @param rand      random source used to shuffle the folds
     * @param traces    positive traces; copied
     * @param negTraces negative traces; copied
     * @param folds     number of cross-validation folds
     * @param algo      classifier configuration handed to {@link BaseClassifierInference}
     * @param seed      seed value recorded in logs and in the {@link Result}
     * @param tail      k value handed to the merger and to the scorer
     * @param data      flag handed to the merger; also enables per-fold debug logging in {@link #score}
     * @param ktail     if {@code true} use {@link KTailsScorer}, otherwise {@link BasicScorer}
     */
    public Experiment(String name, Random rand, Set<List<TraceElement>> traces, Set<List<TraceElement>> negTraces, int folds, Configuration.Data algo, int seed, int tail, boolean data, boolean ktail) {
        this.ktail = ktail;
        this.rand = rand;
        this.trace.addAll(traces);
        this.folds = folds;
        this.algo = algo;
        this.name = name;
        this.seed = seed;
        this.tail = tail;
        this.data = data;
        this.results = new ArrayList<>();
        this.negTrace = new HashSet<>();
        this.negTrace.addAll(negTraces);
        this.eval = this.trace;
    }

    /** Returns the (live, mutable) list of results produced so far. */
    public List<Result> getCurrentResults() {
        return this.results;
    }

    public Random getRand() {
        return this.rand;
    }

    /** Returns the (live, mutable) set of positive traces. */
    public Set<List<TraceElement>> getTrace() {
        return this.trace;
    }

    public int getFolds() {
        return this.folds;
    }

    public int getSeed() {
        return this.seed;
    }

    public int getTail() {
        return this.tail;
    }

    public Configuration.Data getAlgo() {
        return this.algo;
    }

    public String getName() {
        return this.name;
    }

    public boolean isData() {
        return this.data;
    }

    /**
     * Runs the cross-validation and appends the averaged {@link Result} to the result list.
     *
     * <p>A fold whose inference fails, or whose inference yields no model, is logged and left
     * out of the average. If the thread is interrupted, the remaining folds are skipped and the
     * interrupt status is left set for the caller. If no fold succeeds, the averaged metrics are
     * NaN and the duration is 0 (see {@link #calculateMeans}).
     *
     * @return the live result list of this experiment
     */
    @Override // java.util.concurrent.Callable
    public List<Result> call() {
        LOGGER.info("Running experiment for:" + this.name + "," + this.algo.toString() + "," + this.seed + "," + this.data);
        List<Set<List<TraceElement>>> partitions = computeFolds(this.folds);
        Collections.shuffle(partitions, this.rand);
        List<Score> foldScores = new ArrayList<>();
        for (int testIndex = 0; testIndex < this.folds; testIndex++) {
            if (Thread.currentThread().isInterrupted()) {
                // Cancellation requested (e.g. Future.cancel(true)); stop instead of failing every remaining fold.
                LOGGER.warn("Experiment " + this.name + " interrupted; skipping remaining folds");
                break;
            }
            Set<List<TraceElement>> testFold = partitions.get(testIndex);
            Set<List<TraceElement>> trainingTraces = new HashSet<>();
            for (int other = 0; other < this.folds; other++) {
                if (other != testIndex) {
                    trainingTraces.addAll(partitions.get(other));
                }
            }
            // nanoTime is monotonic, unlike currentTimeMillis, so elapsed time cannot go negative.
            long start = System.nanoTime();
            try {
                WekaGuardMachineDecorator model = learnModel(trainingTraces, testFold, this.algo);
                long elapsedMillis = (System.nanoTime() - start) / 1_000_000L;
                if (model == null) {
                    LOGGER.warn("No model inferred for fold " + testIndex + " of " + this.name + "; skipping fold");
                    continue;
                }
                Score score = score(model, testFold, this.negTrace);
                score.setDuration(elapsedMillis);
                foldScores.add(score);
            } catch (Exception e) {
                // Keep the stack trace; e.toString() alone loses where the failure happened.
                LOGGER.error(e.toString(), e);
            }
        }
        if (foldScores.isEmpty()) {
            LOGGER.warn("No fold completed for " + this.name + "; reporting NaN metrics");
        }
        Score means = calculateMeans(foldScores);
        Result result = new Result(this.name, this.algo.toString(), means.getSensitivity(), means.getSpecificity(), means.getBCR(), means.getDuration(), this.seed, this.tail, this.data);
        this.results.add(result);
        LOGGER.info("Results for:" + this.name + "," + this.algo.toString() + "," + this.seed + "," + this.data + "\n" + result);
        return this.results;
    }

    /**
     * Averages sensitivity, specificity, BCR and duration over the given per-fold scores.
     *
     * <p>The mean duration uses integer division (truncated). For an empty list, the three ratio
     * metrics are NaN and the duration is 0. Previously, an empty list threw
     * {@link ArithmeticException} from the long division.
     *
     * @param list per-fold scores; may be empty
     * @return a new {@link Score} holding the means
     */
    protected Score calculateMeans(List<Score> list) {
        Score mean = new Score();
        if (list.isEmpty()) {
            mean.setSensitivity(Double.NaN);
            mean.setSpecificity(Double.NaN);
            mean.setBCR(Double.NaN);
            mean.setDuration(0L);
            return mean;
        }
        double sensitivitySum = 0.0d;
        double specificitySum = 0.0d;
        double bcrSum = 0.0d;
        double durationSum = 0.0d;
        for (Score fold : list) {
            sensitivitySum += fold.getSensitivity();
            specificitySum += fold.getSpecificity();
            bcrSum += fold.getBCR();
            durationSum += fold.getDuration();
        }
        int count = list.size();
        mean.setSensitivity(sensitivitySum / count);
        mean.setSpecificity(specificitySum / count);
        mean.setBCR(bcrSum / count);
        mean.setDuration(((long) durationSum) / count);
        return mean;
    }

    /**
     * Scores a model by walking it over held-out positive traces and negative traces.
     *
     * <p>A positive trace accepted by {@code MachineAnalysis.walk(trace, true)} counts as a true
     * positive, otherwise a false negative. A negative trace accepted counts as a false positive,
     * otherwise a true negative. Sensitivity is tp/(tp+fn) and specificity is tn/(tn+fp); each is
     * NaN when its denominator is 0. BCR is not set here.
     *
     * @param wekaGuardMachineDecorator the model to score; must not be null
     * @param set  held-out positive traces
     * @param set2 negative traces
     * @return a new {@link Score} with sensitivity and specificity set
     */
    protected Score score(WekaGuardMachineDecorator wekaGuardMachineDecorator, Set<List<TraceElement>> set, Set<List<TraceElement>> set2) {
        double tp = 0.0d;
        double fn = 0.0d;
        double tn = 0.0d;
        double fp = 0.0d;
        MachineAnalysis analysis = wekaGuardMachineDecorator.getAnalysis();
        for (List<TraceElement> positive : set) {
            if (analysis.walk(positive, true)) {
                tp++;
            } else {
                fn++;
            }
        }
        for (List<TraceElement> negative : set2) {
            if (analysis.walk(negative, true)) {
                fp++;
            } else {
                tn++;
            }
        }
        Score score = new Score();
        score.setSensitivity(tp / (tp + fn));
        score.setSpecificity(tn / (tn + fp));
        if (this.data) {
            LOGGER.debug("tp:" + tp + ", tn:" + tn + ", fp:" + fp + ", fn:" + fn + ", RESULT: " + score);
        }
        return score;
    }

    /**
     * Infers a model from the training traces with {@link EDSMRedBlueMerger}.
     *
     * @param set  training traces, given to the prefix-tree merger and the classifier inference
     * @param set2 held-out traces; a copy is given to {@link BaseClassifierInference}
     * @param data classifier configuration
     * @return the inferred model, or {@code null} if inference was interrupted; in that case the
     *         thread's interrupt status is restored
     * @throws InterruptedException declared for subclasses and callers; not thrown by this implementation
     */
    protected WekaGuardMachineDecorator learnModel(Set<List<TraceElement>> set, Set<List<TraceElement>> set2, Configuration.Data data) throws InterruptedException {
        Set<List<TraceElement>> classifierTraces = new HashSet<>();
        classifierTraces.addAll(set2);
        EDSMRedBlueMerger merger = new EDSMRedBlueMerger(new TracePrefixTreeGenerator(), set, this.tail, new BaseClassifierInference(set, classifierTraces, data), this.data, getScorer());
        // NOTE(review): negatives are added AFTER the set was handed to BaseClassifierInference.
        // Whether the classifier sees them depends on whether it copies its argument; confirm intent.
        classifierTraces.addAll(this.negTrace);
        try {
            if (this.negTrace.size() < 10) {
                LOGGER.warn("NEGATIVE SAMPLE VERY SMALL");
            }
            return (WekaGuardMachineDecorator) merger.infer();
        } catch (InterruptedException e) {
            // Restore the interrupt flag so the executor/caller can still observe cancellation.
            Thread.currentThread().interrupt();
            LOGGER.error("INFERENCE ERROR", e);
            return null;
        }
    }

    /**
     * Distributes the traces in {@link #eval} round-robin over {@code i} new sets.
     *
     * <p>The assignment follows the iteration order of {@link #eval}. When there are fewer traces
     * than folds, some sets stay empty.
     *
     * @param i number of folds
     * @return a mutable list of {@code i} mutable sets
     */
    protected List<Set<List<TraceElement>>> computeFolds(int i) {
        List<Set<List<TraceElement>>> partitions = new ArrayList<>();
        for (int f = 0; f < i; f++) {
            partitions.add(new HashSet<List<TraceElement>>());
        }
        int next = 0;
        for (List<TraceElement> t : this.eval) {
            if (next == i) {
                next = 0;
            }
            partitions.get(next++).add(t);
        }
        return partitions;
    }

    /** Returns a {@link KTailsScorer} when {@link #ktail} is set, otherwise a {@link BasicScorer}; both use {@link #tail}. */
    protected Scorer getScorer() {
        return this.ktail ? new KTailsScorer(this.tail) : new BasicScorer(this.tail);
    }
}
