package ch.swissTPH.amalid.application;

import cern.colt.matrix.impl.AbstractFormatter;
import ch.swissTPH.amalid.data.CompressedData;
import ch.swissTPH.amalid.data.Data;
import ch.swissTPH.amalid.f77Optimization.Uncmin_f77;
import ch.swissTPH.amalid.population.Population;
import ch.swissTPH.amalid.population.PopulationFactory;
import ch.swissTPH.amalid.util.Center;
import java.util.Arrays;
import java.util.Date;
import org.apache.commons.math.ConvergenceException;
import org.apache.commons.math.optimization.ConvergenceChecker;
import org.apache.commons.math.optimization.CostException;
import org.apache.commons.math.optimization.DirectSearchOptimizer;
import org.apache.commons.math.optimization.MultiDirectional;
import org.apache.commons.math.optimization.NelderMead;
import org.apache.commons.math.optimization.PointCostPair;
import org.apache.commons.math.stat.StatUtils;

/* loaded from: input_file:main/main.jar:ch/swissTPH/amalid/application/Optimizer.class */
/**
 * Fits a model configuration to a dataset and keeps the best point found.
 *
 * <p>A single fit uses one of three back ends, chosen from global settings in {@link Center}:
 * the EM algorithm ({@code Center.isUseEM()}), the UncMin Fortran port
 * ({@code Center.useUncmin()}), or a commons-math direct-search optimizer (otherwise).
 * Each fit writes its own log file named {@code <prefix>Fitting[<configuration>]}.
 */
public class Optimizer {
    /**
     * Placeholder cost for a point that has not been evaluated yet. It is presumably larger
     * than any real cost, so any fit result compares lower.
     */
    private static final double UNEVALUATED_COST = 1.0E98d;

    /** Two successive fits whose costs differ by less than this are treated as the same optimum. */
    private static final double SAME_OPTIMUM_TOLERANCE = 0.001d;

    /** Maximum number of cost evaluations handed to the direct-search optimizer. */
    private static final int MAX_DIRECT_SEARCH_EVALUATIONS = 50000;

    /** Log file of the current (or most recent) fit; each fit replaces it with a new one. */
    private OutFile outFile;
    /** Prefix prepended to every log file name. */
    private final String outFilePrefix;
    /** Best point found by the last call to {@link #minimize} or {@link #twoStepMinimize}. */
    private PointCostPair optimum = new PointCostPair(defaultInits(), UNEVALUATED_COST);
    /** Log header used for fits against the full (uncompressed) dataset. */
    private final String accurateComment = "\nFitting to dataset\n";

    /**
     * Convergence test used by the direct-search optimizers.
     *
     * <p>Reports convergence when two conditions hold:
     * <ul>
     *   <li>the variance of the vertex costs, measured around the first vertex's cost, is
     *       below 0.06; and</li>
     *   <li>the second vertex's cost exceeds the first by less than 0.001.</li>
     * </ul>
     * NOTE(review): this assumes the vertices arrive sorted by ascending cost and that there
     * are at least two of them. Confirm this against DirectSearchOptimizer.
     */
    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:main/main.jar:ch/swissTPH/amalid/application/Optimizer$Checker.class */
    public class Checker implements ConvergenceChecker {
        public Checker() {
        }

        @Override // org.apache.commons.math.optimization.ConvergenceChecker
        public boolean converged(PointCostPair[] pointCostPairArr) {
            double[] costs = new double[pointCostPairArr.length];
            for (int i = 0; i < costs.length; i++) {
                costs[i] = pointCostPairArr[i].getCost();
            }
            return StatUtils.variance(costs, costs[0]) < 0.06d && costs[1] - costs[0] < 0.001d;
        }
    }

    /**
     * Creates an optimizer whose log files are named with the given prefix.
     *
     * @param str prefix prepended to every log file name
     */
    public Optimizer(String str) {
        this.outFilePrefix = str;
    }

    /**
     * Fits configuration {@code iArr} to the full dataset and stores the result in
     * {@link #getOptimum()}.
     *
     * @param iArr configuration to fit
     * @param data dataset
     * @param dArr starting point, or {@code null} to start from {@link #defaultInits()}
     */
    public void minimize(int[] iArr, Data data, double[] dArr) {
        double[] start = dArr == null ? defaultInits() : dArr;
        // Evaluate the starting point once. repeatedFit does not use this cost; it starts
        // from its own placeholder cost.
        this.optimum = new PointCostPair(start, PopulationFactory.getPopulation(iArr, data).getAic(start, false));
        this.optimum = repeatedFit(iArr, data, this.optimum.getPoint(), 0.1d, false, this.accurateComment, Center.getMaxFittingIter());
    }

    /**
     * Fits configuration {@code iArr} in two steps and stores the result in
     * {@link #getOptimum()}. The first step is a single fit to a {@link CompressedData} view
     * of {@code data}. The second step refines that result by repeated fits to the full
     * dataset.
     *
     * @param iArr configuration to fit
     * @param data dataset
     * @param dArr starting point, or {@code null} to start from {@link #defaultInits()}
     */
    public void twoStepMinimize(int[] iArr, Data data, double[] dArr) {
        double[] start = dArr == null ? defaultInits() : dArr;
        this.optimum = fit(iArr, new CompressedData(data), start, 0.01d, false, "\nFitting to compressed dataset\n");
        this.optimum = repeatedFit(iArr, data, this.optimum.getPoint(), 0.1d, false, this.accurateComment, Center.getMaxFittingIter());
    }

    /**
     * Runs single fits repeatedly, each starting from the best point so far. It stops when
     * two successive results agree to within {@link #SAME_OPTIMUM_TOLERANCE}, or after
     * {@code i} fits.
     *
     * @param iArr configuration to fit
     * @param data dataset
     * @param dArr starting point
     * @param d    simplex step size (used only by the direct-search back end)
     * @param z    {@code true} selects Nelder-Mead and {@code false} selects multi-directional
     *             search (used only by the direct-search back end)
     * @param str  header written to each fit's log file
     * @param i    maximum number of fits
     * @return the best point found; if {@code i <= 0}, {@code dArr} paired with a
     *         placeholder cost
     */
    public PointCostPair repeatedFit(int[] iArr, Data data, double[] dArr, double d, boolean z, String str, int i) {
        PointCostPair best = new PointCostPair(dArr, UNEVALUATED_COST);
        for (int attempt = 0; attempt < i; attempt++) {
            PointCostPair result = fit(iArr, data, best.getPoint(), d, z, str);
            // NOTE(review): fit() has already closed this.outFile. Confirm that
            // OutFile.append still records text after close().
            if (Math.abs(result.getCost() - best.getCost()) < SAME_OPTIMUM_TOLERANCE) {
                this.outFile.append("\nEncountered Known Maximum. Success!\n");
                return new PointCostPair(result.getPoint(), result.getCost());
            }
            if (result.getCost() < best.getCost()) {
                best = new PointCostPair(result.getPoint(), result.getCost());
                this.outFile.append("\nMaximum different from last one. Restart recommended.");
            }
        }
        return best;
    }

    /** Runs a single fit with the back end selected by the global settings in {@link Center}. */
    private PointCostPair fit(int[] iArr, Data data, double[] dArr, double d, boolean z, String str) {
        if (Center.isUseEM()) {
            return emFit(iArr, data, dArr, str);
        }
        if (Center.useUncmin()) {
            return uncminFit(iArr, data, dArr, str);
        }
        return directSearchFit(iArr, data, dArr, d, z, str);
    }

    /** Runs a single fit with the EM algorithm. A CostException terminates the JVM. */
    private PointCostPair emFit(int[] iArr, Data data, double[] dArr, String str) {
        Population population = PopulationFactory.getPopulation(iArr, data);
        population.getAic(dArr, false);
        openLog(iArr, str + "using EM-Algorithm\n");
        double[] start = copyParams(dArr);
        PointCostPair result = null;
        try {
            result = new EMAlgorithm(population, start).minimize();
        } catch (CostException e) {
            e.printStackTrace();
            System.exit(1);
        }
        return finishLog(population, iArr, result);
    }

    /**
     * Runs a single fit with a commons-math direct-search optimizer.
     *
     * <p>The optimizer is Nelder-Mead when {@code z} is true and multi-directional search
     * otherwise. A ConvergenceException terminates the JVM. A CostException is rethrown as
     * an {@link IllegalStateException}; previously it was printed and then led to a
     * NullPointerException.
     */
    private PointCostPair directSearchFit(int[] iArr, Data data, double[] dArr, double d, boolean z, String str) {
        DirectSearchOptimizer optimizer = z ? new NelderMead() : new MultiDirectional(30.0d, 0.5d);
        Population population = PopulationFactory.getPopulation(iArr, data);
        population.getAic(dArr, false);
        openLog(iArr, str);
        double[] vertexA = copyParams(dArr);
        // Second simplex vertex: each coordinate is shifted by d with a random sign.
        double[] vertexB = new double[vertexA.length];
        for (int k = 0; k < vertexB.length; k++) {
            vertexB[k] = dArr[k] + (Math.random() < 0.5 ? -d : d);
        }
        PointCostPair result = null;
        try {
            result = optimizer.minimize(population, MAX_DIRECT_SEARCH_EVALUATIONS, new Checker(), vertexA, vertexB, 1, System.currentTimeMillis());
        } catch (ConvergenceException e) {
            e.printStackTrace();
            System.err.println("The above error probably happened because the maximum nr. of iterations was exceed. Try restarting from the best point found so far.");
            System.exit(1);
        } catch (CostException e) {
            throw new IllegalStateException("Cost function evaluation failed for configuration " + Arrays.toString(iArr), e);
        }
        return finishLog(population, iArr, result);
    }

    /** Runs a single fit with the UncMin (Fortran port) minimizer, working on log-scaled parameters. */
    private PointCostPair uncminFit(int[] iArr, Data data, double[] dArr, String str) {
        Population population = PopulationFactory.getPopulation(iArr, data);
        population.getAic(dArr, false);
        openLog(iArr, str + "using UncMin-Algorithm\n");
        int n = Center.getNrParameters();
        // Arrays passed to Uncmin_f77 are 1-based here (index 0 is unused). Parameters are
        // optimized on a log scale, so every starting value must be positive.
        double[] start = new double[n + 1];
        for (int k = 1; k <= n; k++) {
            start[k] = Math.log(dArr[k - 1]);
        }
        double[] xpls = new double[n + 1];
        double[] fpls = new double[n + 1];
        Uncmin_f77.optif0_f77(n, start, population, xpls, fpls, new double[n + 1], new int[n + 1], new double[n + 1][n + 1], new double[n + 1]);
        double[] point = new double[n];
        for (int k = 1; k <= n; k++) {
            point[k - 1] = Math.exp(xpls[k]);
        }
        return finishLog(population, iArr, new PointCostPair(point, fpls[1]));
    }

    /**
     * Opens a new log file for configuration {@code iArr}. It writes the current date, the
     * configuration string from {@link Center}, and {@code header}.
     */
    private void openLog(int[] iArr, String header) {
        this.outFile = new OutFile(this.outFilePrefix + "Fitting" + Arrays.toString(iArr));
        this.outFile.append(AbstractFormatter.DEFAULT_ROW_SEPARATOR + new Date().toString() + AbstractFormatter.DEFAULT_ROW_SEPARATOR + Center.getConfigString());
        this.outFile.append(header);
    }

    /**
     * Re-evaluates {@code result} on {@code population}, writes it to the log, closes the
     * log and returns {@code result} unchanged.
     */
    private PointCostPair finishLog(Population population, int[] iArr, PointCostPair result) {
        // Re-evaluate at the optimum, presumably so that Center's parameter getters below
        // reflect this point.
        population.getAic(result.getPoint(), false);
        this.outFile.append("\nConfiguration: " + Arrays.toString(iArr));
        this.outFile.append("\nTransformed parameters: " + Arrays.toString(Center.getTransformedParams()));
        this.outFile.append("\nReal parameters: " + Arrays.toString(Center.getRealParams()));
        this.outFile.append("\nValue: " + result.getCost());
        this.outFile.append(AbstractFormatter.DEFAULT_ROW_SEPARATOR + new Date().toString());
        this.outFile.close();
        return result;
    }

    /**
     * Copies the first {@code Center.getNrParameters()} entries of {@code params}.
     * Throws an IndexOutOfBoundsException if {@code params} is shorter than that.
     */
    private static double[] copyParams(double[] params) {
        double[] copy = new double[Center.getNrParameters()];
        System.arraycopy(params, 0, copy, 0, copy.length);
        return copy;
    }

    /** Returns 100 starting values, all equal to 1.0. */
    private static double[] defaultInits() {
        double[] inits = new double[100];
        Arrays.fill(inits, 1.0d);
        return inits;
    }

    /** Returns the best point found by the last minimization, or the default placeholder. */
    public PointCostPair getOptimum() {
        return this.optimum;
    }
}
