package org.bitbucket.efsmtool.inference;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.log4j.Logger;
import org.bitbucket.efsmtool.app.Configuration;
import org.bitbucket.efsmtool.tracedata.SimpleTraceElement;
import org.bitbucket.efsmtool.tracedata.TraceElement;
import org.bitbucket.efsmtool.tracedata.TraceSet;
import org.bitbucket.efsmtool.tracedata.types.BooleanVariableAssignment;
import org.bitbucket.efsmtool.tracedata.types.DoubleVariableAssignment;
import org.bitbucket.efsmtool.tracedata.types.IntegerVariableAssignment;
import org.bitbucket.efsmtool.tracedata.types.StringVariableAssignment;
import org.bitbucket.efsmtool.tracedata.types.VariableAssignment;
import weka.classifiers.Classifier;
import weka.classifiers.bayes.NaiveBayes;
import weka.classifiers.meta.AdaBoostM1;
import weka.classifiers.meta.AdditiveRegression;
import weka.classifiers.rules.JRip;
import weka.classifiers.rules.M5Rules;
import weka.classifiers.trees.J48;
import weka.classifiers.trees.M5P;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.RemoveUseless;
import weka.filters.unsupervised.attribute.StringToNominal;
import weka.filters.unsupervised.attribute.StringToWordVector;

/**
 * Builds one WEKA classifier per function (trace element name). Each classifier is trained on the
 * variable values recorded at a trace element and predicts the name of the element that follows it,
 * which is stored in the {@code CLASS_NEXT} string attribute.
 *
 * <p>When an evaluation trace set is supplied, instances derived from its positive traces are moved
 * out of the training data and collected in {@link #testingData} instead.
 */
public class BaseClassifierInference {
    private static final Logger LOGGER = Logger.getLogger(BaseClassifierInference.class.getName());

    /** Name of the class attribute that holds the name of the following trace element. */
    private static final String FINAL_NAME = "CLASS_NEXT";

    /** Printed when a classifier rejects {@code Configuration.WEKA_OPTIONS}. */
    private static final String INVALID_OPTIONS_MESSAGE = "Invalid WEKA options - running with default settings.";

    protected TraceSet traces;
    protected TraceSet evalSet;
    /** Classifier per function name; functions whose classifier could not be built are absent. */
    protected Map<String, Classifier> classifiers;
    protected HashMap<TraceElement, Instance> elementsToInstances;
    protected Map<Instance, TraceElement> instancesToElements;
    /** Instances taken from {@link #evalSet}, keyed by function name. */
    protected Map<String, Instances> testingData;
    protected Configuration.Data inferenceAlgorithm;

    /**
     * Returns the mapping from each trace element (with a successor) to the WEKA instance built for it.
     *
     * @return the live element-to-instance map
     */
    public HashMap<TraceElement, Instance> getElementsToInstances() {
        return this.elementsToInstances;
    }

    /**
     * Trains one classifier per function on the positive traces of {@code traceSet}, with an empty
     * evaluation set.
     *
     * @param traceSet training traces
     * @param data     which WEKA algorithm to use
     */
    public BaseClassifierInference(TraceSet traceSet, Configuration.Data data) {
        LOGGER.setLevel(Configuration.LOGGING);
        this.traces = traceSet;
        this.inferenceAlgorithm = data;
        this.elementsToInstances = new HashMap<>();
        this.instancesToElements = new HashMap<>();
        this.testingData = new HashMap<>();
        this.evalSet = new TraceSet();
        Map<String, Set<TraceElement>> elementsByFunction = new HashMap<>();
        mapTracesToFunctions(traceSet, elementsByFunction);
        Map<String, Instances> trainingData = new HashMap<>();
        addTrainingData(elementsByFunction, trainingData);
        buildClassifiers(trainingData);
    }

    /**
     * Builds instances from both trace sets. Instances whose element occurs in a positive trace of
     * {@code traceSet2} are moved into {@link #testingData}; the remainder are used for training.
     *
     * @param traceSet  training traces
     * @param traceSet2 evaluation traces
     * @param data      which WEKA algorithm to use
     */
    public BaseClassifierInference(TraceSet traceSet, TraceSet traceSet2, Configuration.Data data) {
        // Apply the configured log level here too, consistent with the single-trace-set constructor.
        LOGGER.setLevel(Configuration.LOGGING);
        this.traces = traceSet;
        this.evalSet = traceSet2;
        this.inferenceAlgorithm = data;
        this.elementsToInstances = new HashMap<>();
        this.instancesToElements = new HashMap<>();
        this.testingData = new HashMap<>();
        Map<String, Set<TraceElement>> elementsByFunction = new HashMap<>();
        mapTracesToFunctions(traceSet, elementsByFunction);
        mapTracesToFunctions(traceSet2, elementsByFunction);
        Map<String, Instances> trainingData = new HashMap<>();
        addTrainingData(elementsByFunction, trainingData);
        buildClassifiers(trainingData);
    }

    /**
     * Returns the trained classifiers keyed by function name.
     *
     * @return the live classifier map
     */
    public Map<String, Classifier> getClassifiers() {
        return this.classifiers;
    }

    /** Groups the elements of every positive trace in {@code traceSet} by function name into {@code map}. */
    private void mapTracesToFunctions(TraceSet traceSet, Map<String, Set<TraceElement>> map) {
        for (List<TraceElement> trace : traceSet.getPos()) {
            mapFunctionsToTraceElements(trace, map);
        }
    }

    /**
     * Builds a dataset per function, moves evaluation instances to {@link #testingData}, and merges the
     * rest into {@code map2}.
     *
     * @param map  trace elements grouped by function name
     * @param map2 accumulator of training datasets keyed by function name (mutated)
     * @return {@code map2}
     */
    private Map<String, Instances> addTrainingData(Map<String, Set<TraceElement>> map, Map<String, Instances> map2) {
        for (Map.Entry<String, Set<TraceElement>> entry : map.entrySet()) {
            String function = entry.getKey();
            Instances built = buildInstances(entry.getValue());
            transferToTestingData(built, function);
            Instances existing = map2.get(function);
            if (existing != null) {
                built.addAll(existing);
            }
            map2.put(function, built);
        }
        return map2;
    }

    /**
     * Removes from {@code instances} every instance whose source element occurs in a positive trace of
     * {@link #evalSet}, and appends it to the testing dataset for {@code str}.
     */
    private void transferToTestingData(Instances instances, String str) {
        Set<Instance> moved = new HashSet<>();
        // Index loop: the original used indexOf() per match, which made this quadratic.
        for (int i = 0; i < instances.size(); i++) {
            Instance instance = instances.get(i);
            if (!contains(this.evalSet.getPos(), this.instancesToElements.get(instance))) {
                continue;
            }
            Instances testing = this.testingData.get(str);
            if (testing == null) {
                // Copy constructor keeps the header (attributes) and copies just this one instance.
                this.testingData.put(str, new Instances(instances, i, 1));
            } else {
                testing.add(instance);
            }
            moved.add(instance);
        }
        instances.removeAll(moved);
    }

    /** Returns whether any trace in {@code set} contains {@code traceElement}. */
    private boolean contains(Set<List<TraceElement>> set, TraceElement traceElement) {
        for (List<TraceElement> trace : set) {
            if (trace.contains(traceElement)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Trains one classifier per dataset. A function whose classifier cannot be built (for example
     * because it has a single possible successor) is logged and left without a classifier.
     */
    private void buildClassifiers(Map<String, Instances> map) {
        this.classifiers = new HashMap<>();
        for (Map.Entry<String, Instances> entry : map.entrySet()) {
            String function = entry.getKey();
            try {
                Classifier classifier = makeClassifier(this.inferenceAlgorithm);
                classifier.buildClassifier(entry.getValue());
                this.classifiers.put(function, classifier);
            } catch (Exception e) {
                LOGGER.info(function + " only has one possible following function, no need for classifier.");
                // Keep the real cause available; the info message above is not the only possible reason.
                LOGGER.debug("Classifier construction failed for " + function, e);
            }
        }
    }

    /**
     * Creates an untrained classifier for {@code data}, applying {@code Configuration.WEKA_OPTIONS} when
     * present (falling back to defaults if they are rejected).
     *
     * @throws IllegalArgumentException if {@code data} is not a supported algorithm
     */
    private Classifier makeClassifier(Configuration.Data data) {
        switch (data) {
            case J48: {
                J48 j48 = createJ48();
                if (Configuration.WEKA_OPTIONS.length > 0) {
                    try {
                        j48.setOptions(Configuration.WEKA_OPTIONS);
                    } catch (Exception e) {
                        System.err.println(INVALID_OPTIONS_MESSAGE);
                    }
                }
                return j48;
            }
            case AdaBoostDiscrete: {
                AdaBoostM1 adaBoostM1 = new AdaBoostM1();
                adaBoostM1.setClassifier(createJ48());
                adaBoostM1.setSeed(Configuration.SEED);
                if (Configuration.WEKA_OPTIONS.length > 0) {
                    try {
                        adaBoostM1.setOptions(Configuration.WEKA_OPTIONS);
                    } catch (Exception e) {
                        System.err.println(INVALID_OPTIONS_MESSAGE);
                    }
                }
                return adaBoostM1;
            }
            case NaiveBayes: {
                NaiveBayes naiveBayes = new NaiveBayes();
                if (Configuration.WEKA_OPTIONS.length > 0) {
                    try {
                        naiveBayes.setOptions(Configuration.WEKA_OPTIONS);
                    } catch (Exception e) {
                        System.err.println(INVALID_OPTIONS_MESSAGE);
                    }
                }
                return naiveBayes;
            }
            case M5: {
                M5P m5p = new M5P();
                if (Configuration.WEKA_OPTIONS.length > 0) {
                    try {
                        m5p.setOptions(Configuration.WEKA_OPTIONS);
                    } catch (Exception e) {
                        System.err.println(INVALID_OPTIONS_MESSAGE);
                    }
                }
                return m5p;
            }
            case M5Rules: {
                M5Rules m5Rules = new M5Rules();
                if (Configuration.WEKA_OPTIONS.length > 0) {
                    try {
                        m5Rules.setOptions(Configuration.WEKA_OPTIONS);
                    } catch (Exception e) {
                        System.err.println(INVALID_OPTIONS_MESSAGE);
                    }
                }
                return m5Rules;
            }
            case AdditiveRegression: {
                AdditiveRegression additiveRegression = new AdditiveRegression();
                if (Configuration.WEKA_OPTIONS.length > 0) {
                    try {
                        additiveRegression.setOptions(Configuration.WEKA_OPTIONS);
                    } catch (Exception e) {
                        System.err.println(INVALID_OPTIONS_MESSAGE);
                    }
                }
                return additiveRegression;
            }
            case JRIP: {
                JRip jRip = new JRip();
                jRip.setSeed(Configuration.SEED);
                if (Configuration.WEKA_OPTIONS.length > 0) {
                    try {
                        jRip.setOptions(Configuration.WEKA_OPTIONS);
                    } catch (Exception e) {
                        System.err.println(INVALID_OPTIONS_MESSAGE);
                    }
                }
                return jRip;
            }
            default:
                // Previously returned null, which surfaced later as an unexplained NullPointerException.
                throw new IllegalArgumentException("Unsupported inference algorithm: " + data);
        }
    }

    /** Creates a J48 tree with Laplace smoothing, reduced-error pruning and the configured seed. */
    private J48 createJ48() {
        J48 j48 = new J48();
        j48.setUseLaplace(true);
        j48.setReducedErrorPruning(true);
        j48.setSeed(Configuration.SEED);
        return j48;
    }

    /**
     * Returns a comma-separated list of 1-based indices of the string attributes in {@code instances},
     * in the form expected by WEKA range options; empty if there are none.
     *
     * @param excludeLast if true, the last attribute is not considered
     */
    private String getStringAttributeIndices(Instances instances, boolean excludeLast) {
        int limit = excludeLast ? instances.numAttributes() - 1 : instances.numAttributes();
        StringBuilder indices = new StringBuilder();
        for (int i = 0; i < limit; i++) {
            if (instances.attribute(i).isString()) {
                if (indices.length() > 0) {
                    indices.append(',');
                }
                indices.append(i + 1);
            }
        }
        return indices.toString();
    }

    /**
     * Runs {@code instances} through {@code filter}, whose input format must already be set.
     *
     * @return the filtered copy
     * @throws Exception if WEKA filtering fails
     */
    public Instances applyFilter(Filter filter, Instances instances) throws Exception {
        return Filter.useFilter(instances, filter);
    }

    /**
     * Creates a {@link StringToNominal} filter over the attribute range {@code str}.
     *
     * @param str WEKA range string, e.g. {@code "1,3"}
     */
    protected StringToNominal createStringToNominalFilter(String str) {
        StringToNominal stringToNominal = new StringToNominal();
        try {
            stringToNominal.setOptions(new String[]{"-R", str});
        } catch (Exception e) {
            e.printStackTrace();
        }
        return stringToNominal;
    }

    /**
     * Converts trace elements of one function into a WEKA dataset. Each element with a successor
     * becomes one instance whose last attribute ({@code CLASS_NEXT}) is the successor's name. String
     * attributes are converted to nominal ones when possible. The element/instance mappings are
     * recorded for every instance of the returned dataset.
     *
     * @param set non-empty set of elements; the attribute schema is taken from one of them, which must
     *            be a {@link SimpleTraceElement}
     * @return the dataset, with its class index set to the last attribute
     * @throws IllegalArgumentException if {@code set} is empty
     */
    public Instances buildInstances(Set<TraceElement> set) {
        if (set.isEmpty()) {
            throw new IllegalArgumentException("Cannot build instances from an empty set of trace elements");
        }
        SimpleTraceElement template = (SimpleTraceElement) set.iterator().next();
        ArrayList<Attribute> attributes = buildAttributeList(template);
        Instances instances = new Instances(template.getName(), attributes, set.size());
        // sources.get(i) is the element that produced instance i.
        List<TraceElement> sources = new ArrayList<>();
        for (TraceElement element : set) {
            TraceElement next = element.getNext();
            if (next == null) {
                // Without a successor there is no class label, so the element is skipped.
                continue;
            }
            // Size by the schema, not by this element's variable count, so every row matches the header.
            DenseInstance instance = new DenseInstance(attributes.size());
            convertToInstance(element.getData(), next, instance, attributes);
            instances.add(instance);
            sources.add(element);
        }
        Instances result = toClassifiableInstances(instances);
        recordMapping(sources, result);
        return result;
    }

    /**
     * Converts string attributes to nominal ones when there are any. If the conversion fails, the
     * unfiltered dataset is returned. Either way, the class index is the last attribute.
     */
    private Instances toClassifiableInstances(Instances instances) {
        String stringAttributeIndices = getStringAttributeIndices(instances, false);
        if (!stringAttributeIndices.isEmpty()) {
            try {
                return filterStringToNominal(instances, stringAttributeIndices);
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
        // filterStringToNominal clears the class index before filtering, so (re)set it here.
        instances.setClassIndex(instances.numAttributes() - 1);
        return instances;
    }

    /** Records the element/instance pairing for each instance of {@code result}, by position. */
    private void recordMapping(List<TraceElement> sources, Instances result) {
        for (int i = 0; i < result.size(); i++) {
            TraceElement element = sources.get(i);
            assert element != null;
            Instance instance = result.get(i);
            this.elementsToInstances.put(element, instance);
            this.instancesToElements.put(instance, element);
        }
    }

    /**
     * Converts the string attributes listed in {@code str} to nominal attributes. The class index of
     * {@code instances} is cleared before filtering. On the filtered copy it is set to the last
     * attribute.
     *
     * @param instances dataset to convert (its class index is modified)
     * @param str       WEKA range string of the attributes to convert
     * @return the converted copy
     * @throws Exception if WEKA filtering fails
     */
    protected Instances filterStringToNominal(Instances instances, String str) throws Exception {
        instances.setClassIndex(-1);
        StringToNominal stringToNominal = createStringToNominalFilter(str);
        stringToNominal.setInputFormat(instances);
        Instances filtered = applyFilter(stringToNominal, instances);
        filtered.setClassIndex(filtered.numAttributes() - 1);
        return filtered;
    }

    /**
     * Expands the string attributes (excluding the last attribute) into word-vector attributes. Then it
     * applies the filter from {@link #createReorderFilter()}.
     *
     * @throws Exception if WEKA filtering fails
     */
    protected Instances filterStringToWordSequence(Instances instances) throws Exception {
        Filter stringToWord = createStringToWordFilter(getStringAttributeIndices(instances, true));
        stringToWord.setInputFormat(instances);
        Instances words = applyFilter(stringToWord, instances);
        Filter reorder = createReorderFilter();
        reorder.setInputFormat(words);
        return applyFilter(reorder, words);
    }

    /** Creates a {@link StringToWordVector} over the attribute range {@code str}. */
    private Filter createStringToWordFilter(String str) {
        StringToWordVector stringToWordVector = new StringToWordVector();
        try {
            stringToWordVector.setOptions(new String[]{"-R", str});
        } catch (Exception e) {
            e.printStackTrace();
        }
        return stringToWordVector;
    }

    /**
     * Creates the filter applied after word-vector expansion.
     *
     * <p>NOTE(review): the name and the {@code "last-first"} range suggest WEKA's {@code Reorder} filter
     * was intended, but this builds another {@link StringToWordVector}. Confirm the intent before relying
     * on {@link #filterStringToWordSequence(Instances)}.
     */
    private Filter createReorderFilter() {
        StringToWordVector stringToWordVector = new StringToWordVector();
        try {
            stringToWordVector.setOptions(new String[]{"-R", "last-first"});
        } catch (Exception e) {
            e.printStackTrace();
        }
        return stringToWordVector;
    }

    /**
     * Builds the attribute schema for an element. Double, integer and boolean variables become numeric
     * attributes, and all other variables become string attributes. A final string attribute
     * {@code CLASS_NEXT} is appended.
     *
     * @param simpleTraceElement element whose variables define the schema
     * @return the attribute list, class attribute last
     */
    public static ArrayList<Attribute> buildAttributeList(SimpleTraceElement simpleTraceElement) {
        ArrayList<Attribute> attributes = new ArrayList<>();
        for (VariableAssignment<?> variable : simpleTraceElement.getData()) {
            String name = variable.getName();
            boolean numeric = variable instanceof DoubleVariableAssignment
                    || variable instanceof IntegerVariableAssignment
                    || variable instanceof BooleanVariableAssignment;
            // A null value list makes WEKA create a string attribute.
            attributes.add(numeric ? new Attribute(name) : new Attribute(name, (List<String>) null));
        }
        attributes.add(new Attribute(FINAL_NAME, (List<String>) null));
        return attributes;
    }

    /**
     * Writes the non-null variable values of {@code set} into {@code instance}, and writes the name of
     * {@code traceElement} into the class attribute. Booleans are encoded as 1.0/0.0. Unset values stay
     * at the instance's initial value.
     */
    private static void convertToInstance(Set<VariableAssignment<?>> set, TraceElement traceElement, Instance instance, List<Attribute> list) {
        for (VariableAssignment<?> variable : set) {
            if (variable.isNull()) {
                continue;
            }
            if (variable instanceof DoubleVariableAssignment) {
                instance.setValue(findAttribute(variable.getName(), list), ((DoubleVariableAssignment) variable).getValue().doubleValue());
            } else if (variable instanceof IntegerVariableAssignment) {
                // Integer variables get numeric attributes in buildAttributeList but were previously never written.
                instance.setValue(findAttribute(variable.getName(), list), ((IntegerVariableAssignment) variable).getValue().doubleValue());
            } else if (variable instanceof StringVariableAssignment) {
                instance.setValue(findAttribute(variable.getName(), list), ((StringVariableAssignment) variable).getValue());
            } else if (variable instanceof BooleanVariableAssignment) {
                instance.setValue(findAttribute(variable.getName(), list), ((BooleanVariableAssignment) variable).getValue().booleanValue() ? 1.0d : 0.0d);
            }
        }
        instance.setValue(findAttribute(FINAL_NAME, list), traceElement.getName());
    }

    /** Returns the first attribute in {@code list} named {@code str}, or {@code null} if none. */
    private static Attribute findAttribute(String str, List<Attribute> list) {
        for (Attribute attribute : list) {
            if (attribute.name().equals(str)) {
                return attribute;
            }
        }
        return null;
    }

    /** Adds each element of {@code list} to the set in {@code map} keyed by the element's name. */
    private void mapFunctionsToTraceElements(List<TraceElement> list, Map<String, Set<TraceElement>> map) {
        for (TraceElement element : list) {
            String name = element.getName();
            Set<TraceElement> group = map.get(name);
            if (group == null) {
                group = new HashSet<>();
                map.put(name, group);
            }
            group.add(element);
        }
    }

    /**
     * Returns the difference between the two highest class probabilities predicted for
     * {@code instance}. It returns 0 if there are fewer than two classes or if prediction fails.
     *
     * @param classifier trained classifier
     * @param instance   instance to score
     * @return the prediction margin
     */
    public static double computeMargin(Classifier classifier, Instance instance) {
        try {
            // Sort a copy so the classifier's returned array is never reordered in place.
            double[] distribution = classifier.distributionForInstance(instance).clone();
            if (distribution.length < 2) {
                return 0.0d;
            }
            Arrays.sort(distribution);
            int last = distribution.length - 1;
            return distribution[last] - distribution[last - 1];
        } catch (Exception e) {
            e.printStackTrace();
            return 0.0d;
        }
    }

    /**
     * Builds a dataset named {@code str} containing (copies of) the given instances. The attribute
     * schema is taken from one of them, and the class index is set to the last attribute.
     *
     * @param set non-empty set of instances sharing a schema
     * @param str dataset name
     * @return the new dataset
     * @throws IllegalArgumentException if {@code set} is empty
     */
    public static Instances makeInstances(Set<Instance> set, String str) {
        if (set.isEmpty()) {
            throw new IllegalArgumentException("Cannot build a dataset from an empty set of instances");
        }
        Instance template = set.iterator().next();
        ArrayList<Attribute> attributes = new ArrayList<>(template.numAttributes());
        for (int i = 0; i < template.numAttributes(); i++) {
            attributes.add(template.attribute(i));
        }
        Instances instances = new Instances(str, attributes, set.size());
        for (Instance instance : set) {
            instances.add(instance);
        }
        instances.setClassIndex(instances.numAttributes() - 1);
        return instances;
    }
}
