package org.bitbucket.kienerj.higgsbosonclassifier;

import java.io.File;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import org.slf4j.ext.XLogger;
import org.slf4j.ext.XLoggerFactory;
import org.slf4j.profiler.Profiler;
import weka.Run;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.lazy.kstar.KStarConstants;
import weka.classifiers.meta.FilteredClassifier;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;
import weka.core.WekaPackageManager;
import weka.core.converters.CSVLoader;
import weka.filters.unsupervised.attribute.Remove;

/* loaded from: input_file:org/bitbucket/kienerj/higgsbosonclassifier/App.class */
/**
 * Command-line driver that trains a Weka classifier on {@code training.csv}, picks the decision
 * threshold that maximises the approximate median significance (AMS) over cross-validated
 * predictions, and writes ranked predictions for {@code test.csv} to a {@code .sub} file.
 *
 * <p>The first program argument is the classifier scheme name; any further arguments are passed to
 * the classifier as options.
 *
 * <p>Throughout this class, class index {@code 0} is treated as the signal class: the test data
 * declares its label values as {@code "s"}, {@code "b"} in that order and the output maps
 * {@code 0} to {@code "s"}. NOTE(review): for the training data this relies on the CSV loader
 * assigning index 0 to {@code "s"} -- confirm against the actual file.
 */
public class App {
    private static final XLogger logger = XLoggerFactory.getXLogger("HiggsBosonClassifier");

    /** Seed of the random generator used to shuffle the data before cross-validation. */
    private static final long SEED = 1;

    /** Number of cross-validation folds (and worker threads) used for the threshold search. */
    private static final int NUM_FOLDS = 3;

    /** Index of the class whose probability is thresholded and ranked (treated as signal). */
    private static final int SIGNAL_CLASS_INDEX = 0;

    /** Constant regularisation term added to the background weight in the AMS formula. */
    private static final double AMS_BACKGROUND_REGULARISATION = 10.0d;

    /**
     * Loads {@code training.csv} from the working directory and marks the last attribute as the
     * class attribute.
     *
     * @return the loaded training instances
     * @throws Exception if the file cannot be read or parsed
     */
    private static Instances getTrainingData() throws Exception {
        logger.info("Loading training data");
        CSVLoader loader = new CSVLoader();
        loader.setSource(new File("training.csv"));
        Instances data = loader.getDataSet();
        data.setClassIndex(data.numAttributes() - 1);
        return data;
    }

    /**
     * Searches for the probability threshold that maximises the AMS on cross-validated predictions.
     *
     * <p>A shuffled, stratified copy of {@code instances} is handed to one {@code ClassifierBuilder}
     * task per fold. For every test instance of every fold the probability of class index 0, the
     * actual class value and the value of the second-to-last attribute (used as the event weight)
     * are collected. The instances are then visited from highest to lowest probability while the
     * signal and background weights are accumulated; the probability at which the AMS peaks is
     * returned.
     *
     * @param classifier the classifier passed to every fold task
     * @param instances the labelled training data; it is copied, not modified
     * @return the probability with the highest AMS, or {@code 1.0} if no prefix improves on the
     *     AMS of an empty selection
     * @throws Exception if a fold task fails or a prediction cannot be computed
     */
    private static double findThreshold(Classifier classifier, Instances instances) throws Exception {
        logger.info("Creating training and test split");
        Instances shuffled = new Instances(instances);
        Random random = new Random(SEED);
        shuffled.randomize(random);
        shuffled.stratify(NUM_FOLDS);

        // Row 0: probability of class index 0, row 1: actual class value, row 2: weight.
        double[][] results = new double[3][shuffled.numInstances()];
        int filled = 0;

        // NOTE(review): every fold task receives the same Random and the same classifier object.
        // Whether that is safe/reproducible depends on ClassifierBuilder, which is not visible
        // here -- confirm it copies the classifier and tolerates a shared generator.
        ExecutorService pool = Executors.newFixedThreadPool(NUM_FOLDS);
        try {
            List<Future<?>> pendingFolds = new ArrayList<>(NUM_FOLDS);
            for (int fold = 0; fold < NUM_FOLDS; fold++) {
                pendingFolds.add(pool.submit(new ClassifierBuilder(NUM_FOLDS, fold, random, shuffled, classifier)));
            }
            pool.shutdown();

            // Results are consumed in submission order, i.e. fold 0 first.
            for (Future<?> pendingFold : pendingFolds) {
                logger.info("Getting probabilities for test split");
                ValidationRun run = (ValidationRun) pendingFold.get();
                Instances testFold = run.getTestData();
                Classifier foldClassifier = run.getClassifier();
                int weightIndex = testFold.numAttributes() - 2;
                for (int row = 0; row < testFold.numInstances(); row++) {
                    Instance instance = testFold.instance(row);
                    results[0][filled] = foldClassifier.distributionForInstance(instance)[SIGNAL_CLASS_INDEX];
                    results[1][filled] = instance.classValue();
                    results[2][filled] = instance.value(weightIndex);
                    filled++;
                }
            }
        } finally {
            // No-op once all folds have completed; on failure it stops the remaining workers so
            // their non-daemon threads do not keep the JVM alive.
            pool.shutdownNow();
        }

        logger.info("Sorting " + results[0].length + " probabilities");
        int[] ascending = Utils.sort(results[0]);

        double signalWeight = 0.0d;
        double backgroundWeight = 0.0d;
        double bestThreshold = 1.0d;
        double bestAms = approximateMedianSignificance(0.0d, 0.0d);
        logger.info("Initial AMS " + bestAms);

        // Walk from the most to the least signal-like instance, growing the selected region.
        for (int rank = ascending.length - 1; rank >= 0; rank--) {
            int index = ascending[rank];
            if (results[1][index] == SIGNAL_CLASS_INDEX) {
                signalWeight += results[2][index];
            } else {
                backgroundWeight += results[2][index];
            }
            double ams = approximateMedianSignificance(signalWeight, backgroundWeight);
            if (ams > bestAms) {
                bestAms = ams;
                bestThreshold = results[0][index];
            }
        }
        logger.info("Maximum AMS " + bestAms + " found for threshold " + bestThreshold);
        return bestThreshold;
    }

    /**
     * Computes {@code sqrt(2 * ((s + b + 10) * ln(1 + s / (b + 10)) - s))}.
     *
     * @param d the accumulated signal weight {@code s}
     * @param d2 the accumulated background weight {@code b}
     * @return the approximate median significance for the given weights
     */
    private static double approximateMedianSignificance(double d, double d2) {
        double regularisedBackground = d2 + AMS_BACKGROUND_REGULARISATION;
        double logTerm = Math.log(1.0d + (d / regularisedBackground));
        return Math.sqrt(2.0d * ((((d + d2) + AMS_BACKGROUND_REGULARISATION) * logTerm) - d));
    }

    /**
     * Loads {@code test.csv} from the working directory and appends a numeric {@code Weight}
     * attribute and a nominal {@code Label} attribute (values {@code "s"}, {@code "b"}), then
     * marks the appended label as the class attribute.
     *
     * @return the test instances with the two extra attributes appended
     * @throws Exception if the file cannot be read or parsed
     */
    private static Instances getTestData() throws Exception {
        logger.info("Loading test data");
        CSVLoader loader = new CSVLoader();
        loader.setSource(new File("test.csv"));
        Instances data = loader.getDataSet();
        data.insertAttributeAt(new Attribute("Weight"), data.numAttributes());
        List<String> labels = new ArrayList<>(2);
        labels.add("s");
        labels.add("b");
        data.insertAttributeAt(new Attribute("Label", labels), data.numAttributes());
        data.setClassIndex(data.numAttributes() - 1);
        return data;
    }

    /**
     * Instantiates the classifier named by {@code strArr[0]} with the remaining elements as
     * options and wraps it in a {@link FilteredClassifier} that removes the first attribute and
     * the second-to-last attribute (the one read as weight during the threshold search).
     *
     * <p>If the name matches more than one scheme, the candidates are logged and the JVM exits
     * with status 1.
     *
     * @param strArr scheme name followed by classifier options; must not be empty
     * @param instances data set whose attribute count determines which attributes are removed
     * @return the filtered, not yet trained classifier
     * @throws IllegalArgumentException if no scheme matches the given name
     * @throws Exception if the classifier cannot be instantiated or the filter rejects the indices
     */
    private static Classifier getClassifier(String[] strArr, Instances instances) throws Exception {
        String schemeName = strArr[0];
        String[] options = new String[strArr.length - 1];
        if (options.length > 0) {
            System.arraycopy(strArr, 1, options, 0, options.length);
        }

        List<String> matches = Run.findSchemeMatch(schemeName, true);
        if (matches.isEmpty()) {
            throw new IllegalArgumentException("No scheme name matches '" + schemeName + "'");
        }
        if (matches.size() > 1) {
            for (String match : matches) {
                logger.info(match);
            }
            logger.info("More than one scheme name matches -- exiting");
            System.exit(1);
        }

        Classifier base = AbstractClassifier.forName(matches.get(0), options);
        FilteredClassifier filtered = new FilteredClassifier();
        filtered.setClassifier(base);
        Remove remove = new Remove();
        // Remove takes 1-based indices: attribute 1 and attribute (numAttributes - 1).
        remove.setAttributeIndices("1, " + (instances.numAttributes() - 1));
        filtered.setFilter(remove);
        return filtered;
    }

    /**
     * Scores every instance and produces one {@code {id, rank, class}} row per instance.
     *
     * <p>The id is the value of the first attribute truncated to {@code int}. The rank is the
     * 1-based position of the instance when ordered by ascending probability of class index 0.
     * The class is {@code 0} when that probability is at least {@code d}, otherwise {@code 1}.
     *
     * @param classifier a trained classifier
     * @param instances the instances to score
     * @param d the decision threshold on the probability of class index 0
     * @return an array with one three-element row per instance, in instance order
     * @throws Exception if a prediction cannot be computed
     */
    private static int[][] getPredictions(Classifier classifier, Instances instances, double d) throws Exception {
        logger.info("Getting predictions");
        int count = instances.numInstances();
        double[] signalProbability = new double[count];
        for (int row = 0; row < count; row++) {
            signalProbability[row] = classifier.distributionForInstance(instances.instance(row))[SIGNAL_CLASS_INDEX];
        }

        // Invert the sort permutation: rankOf[i] is the 1-based ascending rank of instance i.
        int[] ascending = Utils.sort(signalProbability);
        int[] rankOf = new int[ascending.length];
        for (int position = 0; position < ascending.length; position++) {
            rankOf[ascending[position]] = position + 1;
        }

        int[][] predictions = new int[count][3];
        for (int row = 0; row < count; row++) {
            predictions[row][0] = (int) instances.instance(row).value(0);
            predictions[row][1] = rankOf[row];
            predictions[row][2] = signalProbability[row] >= d ? 0 : 1;
        }
        return predictions;
    }

    /**
     * Writes the predictions as CSV with the header {@code EventId,RankOrder,Class} to
     * {@code WEKA_WEIGHTED_3CV_OPT_THRESH<str>.sub} in the working directory.
     *
     * @param iArr rows of {@code {id, rank, class}}; class {@code 0} is written as {@code "s"},
     *     anything else as {@code "b"}
     * @param str suffix appended to the fixed file name prefix
     * @throws Exception if the file cannot be created or written
     */
    private static void outputPredictions(int[][] iArr, String str) throws Exception {
        File target = new File("WEKA_WEIGHTED_3CV_OPT_THRESH" + str + ".sub");
        // try-with-resources guarantees the file handle is released even if writing fails.
        try (PrintWriter writer = new PrintWriter(target)) {
            logger.info("Saving predictions");
            writer.println("EventId,RankOrder,Class");
            for (int[] row : iArr) {
                writer.println(row[0] + "," + row[1] + "," + (row[2] == 0 ? "s" : "b"));
            }
            // PrintWriter swallows I/O errors; surface them instead of leaving a truncated file.
            if (writer.checkError()) {
                throw new java.io.IOException("Failed to write predictions to " + target);
            }
        }
    }

    /**
     * Runs the full pipeline: load Weka packages, load the training data, build the classifier
     * from the command line, find the AMS-optimal threshold, train on the full training set and
     * write the predictions for the test data.
     *
     * @param strArr classifier scheme name followed by optional classifier options
     * @throws IllegalArgumentException if no scheme name is given
     * @throws Exception if any pipeline step fails
     */
    public static void main(String[] strArr) throws Exception {
        if (strArr == null || strArr.length == 0) {
            throw new IllegalArgumentException("Usage: App <classifier scheme> [classifier options...]");
        }
        Profiler profiler = new Profiler("main");
        profiler.setLogger(logger);

        // The output file suffix is all arguments joined with underscores.
        StringBuilder runLabel = new StringBuilder(strArr[0]);
        for (int i = 1; i < strArr.length; i++) {
            runLabel.append('_').append(strArr[i]);
        }

        profiler.start("Load Packages");
        WekaPackageManager.loadPackages(false, true, false);

        profiler.start("Get Training Data");
        Instances trainingData = getTrainingData();

        profiler.start("Get Classifier");
        Classifier classifier = getClassifier(strArr, trainingData);

        profiler.start("Find Threshold");
        double threshold = findThreshold(AbstractClassifier.makeCopy(classifier), trainingData);

        profiler.start("Classification");
        Classifier trained = ClassifierBuilder.buildClassifier(AbstractClassifier.makeCopy(classifier), trainingData);
        Instances testData = getTestData();
        outputPredictions(getPredictions(trained, testData, threshold), runLabel.toString());

        profiler.stop().log();
    }
}
