package com.jujutsu.tsne;

import com.jujutsu.tsne.TSne;
import com.jujutsu.utils.EjmlOps;
import com.jujutsu.utils.MatrixOps;
import java.io.BufferedInputStream;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.Arrays;
import org.ejml.data.D1Matrix64F;
import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.CommonOps;

/* loaded from: input_file:com/jujutsu/tsne/FastTSne.class */
/**
 * t-SNE (t-distributed Stochastic Neighbor Embedding) implemented on top of EJML dense matrices.
 *
 * <p>The optimisation loop in {@link #tsne(TSneConfiguration)} uses early exaggeration of the
 * input affinities, momentum, and per-coordinate adaptive gains. Hyper-parameters that are not
 * read from the {@link TSneConfiguration} are the private constants declared below.
 *
 * <p>{@link #abort()} may be called from another thread: the {@code abort} flag is
 * {@code volatile} and is checked once per gradient iteration. Nothing in this class resets the
 * flag, so once an instance has been aborted, later calls to {@link #tsne} on that instance
 * perform no gradient iterations.
 */
public class FastTSne implements TSne {
    /** Convergence tolerance on |H - log(perplexity)| for the per-point beta search in {@link #x2p}. */
    private static final double PERPLEXITY_TOLERANCE = 1.0E-5d;
    /** Maximum number of bisection steps per point in {@link #x2p}. */
    private static final int MAX_BETA_SEARCH_STEPS = 50;
    /** Factor applied to the input affinities P during the early-exaggeration phase. */
    private static final double EARLY_EXAGGERATION = 4.0d;
    /** Iteration at which the early exaggeration is removed from P. */
    private static final int EXAGGERATION_STOP_ITER = 100;
    /** Lower bound applied to the entries of P and Q so that P / Q and log(P / Q) stay finite. */
    private static final double PROBABILITY_FLOOR = 1.0E-12d;
    /** Momentum used for iterations before {@link #MOMENTUM_SWITCH_ITER}. */
    private static final double INITIAL_MOMENTUM = 0.5d;
    /** Momentum used from {@link #MOMENTUM_SWITCH_ITER} onwards. */
    private static final double FINAL_MOMENTUM = 0.8d;
    /** First iteration that uses {@link #FINAL_MOMENTUM}. */
    private static final int MOMENTUM_SWITCH_ITER = 20;
    /** Additive gain increase where (dY > 0) differs from (iY > 0). */
    private static final double GAIN_INCREMENT = 0.2d;
    /** Multiplicative gain decay where (dY > 0) equals (iY > 0). */
    private static final double GAIN_DECAY = 0.8d;
    /** Gains are clamped from below to this value. */
    private static final double MIN_GAIN = 0.01d;
    /** Step size (eta) applied to the gain-scaled gradient. */
    private static final double LEARNING_RATE = 500.0d;

    MatrixOps mo = new MatrixOps();
    protected volatile boolean abort = false;

    /**
     * Reads a {@code rows x cols} matrix of doubles from a binary file, row by row.
     *
     * <p>Values are read with {@link DataInputStream#readDouble()} (8-byte big-endian IEEE 754
     * doubles) starting at the first byte of the file; there is no header. Any bytes after the last
     * requested value are ignored. The stream is always closed; if closing fails after a read error,
     * the close failure is attached to the read error as a suppressed exception.
     *
     * @param rows number of rows to read
     * @param cols number of values per row
     * @param path path of the file to read
     * @return the matrix that was read
     * @throws FileNotFoundException if the file cannot be opened
     * @throws IOException if reading fails, including {@link java.io.EOFException} when the file
     *     holds fewer than {@code rows * cols} doubles
     */
    public static double[][] readBinaryDoubleMatrix(int rows, int cols, String path)
            throws FileNotFoundException, IOException {
        double[][] matrix = new double[rows][cols];
        try (DataInputStream in =
                new DataInputStream(new BufferedInputStream(new FileInputStream(new File(path))))) {
            for (double[] row : matrix) {
                for (int c = 0; c < row.length; c++) {
                    row[c] = in.readDouble();
                }
            }
        }
        return matrix;
    }

    /**
     * Runs t-SNE on the data described by {@code config}.
     *
     * <p>If PCA is enabled and {@code 0 < initialDims < input dimensionality}, the input is first
     * reduced to {@code initialDims} columns. The embedding is initialised from
     * {@code MatrixOps.rnorm(n, outputDims)} and optimised for at most {@code maxIter} iterations,
     * stopping early if {@link #abort()} has been called. Progress is printed to stdout.
     *
     * @param config source of the input matrix and the run parameters
     * @return the final embedding, one row per input point
     * @throws IllegalArgumentException if the input matrix is {@code null} or has no rows
     */
    @Override
    public double[][] tsne(TSneConfiguration config) {
        double[][] x = config.getXin();
        if (x == null || x.length == 0) {
            throw new IllegalArgumentException("Input data must contain at least one row");
        }
        int outputDims = config.getOutputDims();
        int initialDims = config.getInitialDims();
        double perplexity = config.getPerplexity();
        int maxIter = config.getMaxIter();
        boolean usePca = config.usePca();

        System.out.println("X:Shape is = " + x.length + " x " + x[0].length);
        System.out.println("Running " + getClass().getSimpleName() + ".");
        long lastReport = System.currentTimeMillis();

        if (usePca && x[0].length > initialDims && initialDims > 0) {
            x = new PrincipalComponentAnalysis().pca(x, initialDims);
            System.out.println("X:Shape after PCA is = " + x.length + " x " + x[0].length);
        }
        int n = x.length;

        // Embedding and optimiser state.
        DenseMatrix64F y = new DenseMatrix64F(MatrixOps.rnorm(n, outputDims));
        DenseMatrix64F dist = new DenseMatrix64F(y.numRows, y.numRows);
        DenseMatrix64F dy = new DenseMatrix64F(MatrixOps.fillMatrix(n, outputDims, 0.0d));
        DenseMatrix64F iy = new DenseMatrix64F(MatrixOps.fillMatrix(n, outputDims, 0.0d));
        DenseMatrix64F gains = new DenseMatrix64F(MatrixOps.fillMatrix(n, outputDims, 1.0d));
        DenseMatrix64F signDiffers = new DenseMatrix64F(n, outputDims);
        DenseMatrix64F signAgrees = new DenseMatrix64F(n, outputDims);
        DenseMatrix64F gainsUp = new DenseMatrix64F(gains.numRows, gains.numCols);
        DenseMatrix64F gainsDown = new DenseMatrix64F(gains.numRows, gains.numCols);
        DenseMatrix64F step = new DenseMatrix64F(gains.numRows, dy.numCols);

        // Input affinities: symmetrise, normalise to sum 1, exaggerate, and floor.
        DenseMatrix64F p = new DenseMatrix64F(x2p(x, PERPLEXITY_TOLERANCE, perplexity).P);
        DenseMatrix64F pTransposed = new DenseMatrix64F(p.numRows, p.numCols);
        CommonOps.transpose(p, pTransposed);
        CommonOps.addEquals(p, pTransposed);
        CommonOps.divide(p, CommonOps.elementSum(p));
        EjmlOps.replaceNaN(p, Double.MIN_VALUE);
        CommonOps.scale(EARLY_EXAGGERATION, p);
        EjmlOps.maximize(p, PROBABILITY_FLOOR);
        System.out.println("Y:Shape is = " + y.getNumRows() + " x " + y.getNumCols());

        // Per-iteration workspaces.
        DenseMatrix64F sqed = new DenseMatrix64F(y.numRows, y.numCols);
        DenseMatrix64F sumY = new DenseMatrix64F(1, y.numRows);
        DenseMatrix64F num = new DenseMatrix64F(y.numRows, y.numRows);
        DenseMatrix64F q = new DenseMatrix64F(p.numRows, p.numCols);
        DenseMatrix64F pq = new DenseMatrix64F(p.numRows, p.numCols);
        DenseMatrix64F rowSumDiag = new DenseMatrix64F(MatrixOps.fillMatrix(p.numRows, p.numCols, 0.0d));
        DenseMatrix64F klTerms = new DenseMatrix64F(p.numRows, p.numCols);

        for (int iter = 0; iter < maxIter && !abort; iter++) {
            // Squared distances dist_ij = |y_i|^2 + |y_j|^2 - 2 y_i.y_j (assuming EjmlOps.addRowVector
            // adds sumY to every row), then the Student-t kernel num = 1 / (1 + dist) with a zero diagonal.
            CommonOps.elementPower(y, 2.0d, sqed);
            CommonOps.sumRows(sqed, sumY);
            // multTransB overwrites dist. multAddTransB would add onto the kernel left over from the
            // previous iteration and corrupt the distances.
            CommonOps.multTransB(-2.0d, y, y, dist);
            EjmlOps.addRowVector(dist, sumY);
            CommonOps.transpose(dist);
            EjmlOps.addRowVector(dist, sumY);
            CommonOps.add(dist, 1.0d);
            CommonOps.divide(1.0d, dist);
            num.set((D1Matrix64F) dist);
            EjmlOps.assignAtIndex(num, MatrixOps.range(n), MatrixOps.range(n), 0.0d);

            // Low-dimensional affinities Q = num / sum(num), floored.
            CommonOps.divide(num, CommonOps.elementSum(num), q);
            EjmlOps.maximize(q, PROBABILITY_FLOOR);

            // Gradient dY = 4 * (diag(rowSums(L)) - L) * Y with L = (P - Q) .* num,
            // so row i equals 4 * sum_j L_ij (y_i - y_j).
            CommonOps.subtract(p, q, pq);
            CommonOps.elementMult(pq, num);
            DenseMatrix64F rowSums = CommonOps.sumRows(pq, null);
            double[] diag = new double[rowSums.numRows];
            for (int r = 0; r < diag.length; r++) {
                diag[r] = rowSums.get(r, 0);
            }
            EjmlOps.setDiag(rowSumDiag, diag);
            CommonOps.subtract(rowSumDiag, pq, pq);
            CommonOps.mult(pq, y, dy);
            CommonOps.scale(4.0d, dy);

            // Adaptive gains: +GAIN_INCREMENT where (dY > 0) != (iY > 0), *GAIN_DECAY elsewhere.
            boolean[][] sameSign = MatrixOps.equal(EjmlOps.biggerThan(dy, 0.0d), EjmlOps.biggerThan(iy, 0.0d));
            EjmlOps.setData(signDiffers, MatrixOps.abs(MatrixOps.negate(sameSign)));
            EjmlOps.setData(signAgrees, MatrixOps.abs(sameSign));
            gainsUp.set((D1Matrix64F) gains);
            gainsDown.set((D1Matrix64F) gains);
            CommonOps.add(gainsUp, GAIN_INCREMENT);
            CommonOps.scale(GAIN_DECAY, gainsDown);
            CommonOps.elementMult(gainsUp, signDiffers);
            CommonOps.elementMult(gainsDown, signAgrees);
            CommonOps.add(gainsUp, gainsDown, gains);
            EjmlOps.assignAllLessThan(gains, MIN_GAIN, MIN_GAIN);

            // Momentum update iY = momentum * iY - eta * (gains .* dY); Y += iY; subtract column means.
            double momentum = iter < MOMENTUM_SWITCH_ITER ? INITIAL_MOMENTUM : FINAL_MOMENTUM;
            CommonOps.scale(momentum, iy);
            CommonOps.elementMult(gains, dy, step);
            CommonOps.scale(LEARNING_RATE, step);
            CommonOps.subtractEquals(iy, step);
            CommonOps.addEquals(y, iy);
            CommonOps.subtractEquals(y, EjmlOps.tile(EjmlOps.colMean(y, 0), n, 1));

            // Progress: the timer is reset at every multiple of 10, so each report covers 10 iterations.
            if (iter % 100 == 0) {
                // KL(P || Q) = sum(P .* log(P ./ Q)), computed on the current (possibly exaggerated) P.
                DenseMatrix64F ratio = new DenseMatrix64F(p);
                CommonOps.elementDiv(ratio, q);
                CommonOps.elementLog(ratio, klTerms);
                EjmlOps.replaceNaN(klTerms, Double.MIN_VALUE);
                CommonOps.elementMult(klTerms, p);
                EjmlOps.replaceNaN(klTerms, Double.MIN_VALUE);
                double cost = CommonOps.elementSum(klTerms);
                System.out.printf("Iteration %d: error is %f (10 iterations in %4.2f seconds)\n",
                        iter, cost, (System.currentTimeMillis() - lastReport) / 1000.0d);
                if (cost < 0.0d) {
                    System.err.println("Warning: Error is negative, this is usually a very bad sign!");
                }
                lastReport = System.currentTimeMillis();
            } else if (iter % 10 == 0) {
                System.out.printf("Iteration %d: (10 iterations in %4.2f seconds)\n",
                        iter, (System.currentTimeMillis() - lastReport) / 1000.0d);
                lastReport = System.currentTimeMillis();
            }

            if (iter == EXAGGERATION_STOP_ITER) {
                CommonOps.divide(p, EARLY_EXAGGERATION);
            }
        }
        return EjmlOps.extractDoubleArray(y);
    }

    /**
     * Computes the Gaussian conditional distribution and its entropy for one row of distances.
     *
     * <p>With {@code P = exp(-beta * D)} and {@code sumP = sum(P)}, this returns
     * {@code H = log(sumP) + beta * sum(D .* P) / sumP} (natural log) and the normalised
     * {@code P / sumP}.
     *
     * @param D distances from one point to the others (presumably a single row; callers pass the
     *     output of {@code MatrixOps.getValuesFromRow})
     * @param beta precision of the Gaussian kernel
     * @return an {@link TSne.R} with {@code H} and {@code P} set
     */
    public TSne.R Hbeta(double[][] D, double beta) {
        DenseMatrix64F p = new DenseMatrix64F(D);
        CommonOps.scale(-beta, p);
        CommonOps.elementExp(p, p);
        double sumP = CommonOps.elementSum(p);
        DenseMatrix64F weighted = new DenseMatrix64F(D);
        CommonOps.elementMult(weighted, p);
        double entropy = Math.log(sumP) + beta * CommonOps.elementSum(weighted) / sumP;
        CommonOps.scale(1.0d / sumP, p);
        TSne.R result = new TSne.R();
        result.H = entropy;
        result.P = EjmlOps.extractDoubleArray(p);
        return result;
    }

    /**
     * Computes conditional input affinities whose per-point entropy matches {@code log(perplexity)}.
     *
     * <p>For every point i, {@code beta[i]} (starting at 1.0) is adjusted by doubling/halving and
     * then bisection until {@code |H - log(perplexity)| <= tol} or {@value #MAX_BETA_SEARCH_STEPS}
     * steps have been taken. Row i of the result holds the affinities to all other points; the
     * diagonal stays 0.
     *
     * @param X input data, one row per point
     * @param tol tolerance on the entropy difference
     * @param perplexity target perplexity
     * @return an {@link TSne.R} with {@code P} (n x n) and {@code beta} (length n) set
     */
    public TSne.R x2p(double[][] X, double tol, double perplexity) {
        int n = X.length;
        // Pairwise squared distances: -2 X X^T plus the squared row norms along both axes.
        double[][] sumX = MatrixOps.sum(MatrixOps.square(X), 1);
        double[][] dist = MatrixOps.addRowVector(
                MatrixOps.addColumnVector(
                        mo.transpose(MatrixOps.scalarMult(MatrixOps.times(X, mo.transpose(X)), -2.0d)),
                        sumX),
                mo.transpose(sumX));
        double[][] p = MatrixOps.fillMatrix(n, n, 0.0d);
        // A 1-D vector of ones; building an n x n matrix just to take its first row wastes O(n^2) memory.
        double[] beta = new double[n];
        Arrays.fill(beta, 1.0d);
        double logU = Math.log(perplexity);
        System.out.println("Starting x2p...");
        for (int i = 0; i < n; i++) {
            if (i % 500 == 0) {
                System.out.println("Computing P-values for point " + i + " of " + n + "...");
            }
            double betaMin = Double.NEGATIVE_INFINITY;
            double betaMax = Double.POSITIVE_INFINITY;
            double[][] di = MatrixOps.getValuesFromRow(dist, i,
                    MatrixOps.concatenate(MatrixOps.range(0, i), MatrixOps.range(i + 1, n)));
            TSne.R hb = Hbeta(di, beta[i]);
            double[][] thisP = hb.P;
            double hDiff = hb.H - logU;
            int tries = 0;
            // Bounded search: the loop ends on convergence or after MAX_BETA_SEARCH_STEPS steps.
            while (Math.abs(hDiff) > tol && tries < MAX_BETA_SEARCH_STEPS) {
                if (hDiff > 0.0d) {
                    // Entropy too high: increase beta.
                    betaMin = beta[i];
                    if (Double.isInfinite(betaMax)) {
                        beta[i] = beta[i] * 2.0d;
                    } else {
                        beta[i] = (beta[i] + betaMax) / 2.0d;
                    }
                } else {
                    // Entropy too low: decrease beta.
                    betaMax = beta[i];
                    if (Double.isInfinite(betaMin)) {
                        beta[i] = beta[i] / 2.0d;
                    } else {
                        beta[i] = (beta[i] + betaMin) / 2.0d;
                    }
                }
                hb = Hbeta(di, beta[i]);
                thisP = hb.P;
                hDiff = hb.H - logU;
                tries++;
            }
            MatrixOps.assignValuesToRow(p, i,
                    MatrixOps.concatenate(MatrixOps.range(0, i), MatrixOps.range(i + 1, n)), thisP[0]);
        }
        TSne.R result = new TSne.R();
        result.P = p;
        result.beta = beta;
        System.out.println("Mean value of sigma: " + MatrixOps.mean(MatrixOps.sqrt(MatrixOps.scalarInverse(beta))));
        return result;
    }

    /** Requests that a running {@link #tsne} call stop after its current iteration. */
    @Override
    public void abort() {
        abort = true;
    }
}
