package com.jujutsu.tsne;

import com.jujutsu.tsne.TSne;
import com.jujutsu.utils.BlasOps;
import com.jujutsu.utils.MatrixOps;
import java.util.Arrays;
import org.jblas.DoubleMatrix;

/**
 * t-SNE implementation whose gradient-descent loop runs on jblas {@link DoubleMatrix} operations.
 *
 * <p>The per-point P-value search ({@link #x2p}) works on plain {@code double[][]} arrays via
 * {@link MatrixOps}; the embedding optimisation in {@link #tsne} uses {@link BlasOps} and jblas.
 * A running {@link #tsne} call checks the {@link #abort} flag once per iteration and stops when
 * it is set (see {@link #abort()}). Progress is reported on {@code System.out}.
 */
public class BlasTSne implements TSne {
    /** Bisection tolerance on |H - log(perplexity)| used when computing P-values. */
    private static final double PERPLEXITY_TOLERANCE = 1.0E-5d;
    /** Maximum number of bisection steps per point in {@link #x2p}. */
    private static final int MAX_BETA_SEARCH_STEPS = 50;
    /** Lower bound applied to P and Q to keep log(P/Q) finite. */
    private static final double PROBABILITY_FLOOR = 1.0E-12d;
    /** Factor P is multiplied by at the start and divided by at {@link #EXAGGERATION_STOP_ITER}. */
    private static final double EARLY_EXAGGERATION = 4.0d;
    private static final int EXAGGERATION_STOP_ITER = 100;
    private static final double INITIAL_MOMENTUM = 0.5d;
    private static final double FINAL_MOMENTUM = 0.8d;
    /** Iterations before this index use {@link #INITIAL_MOMENTUM}, later ones {@link #FINAL_MOMENTUM}. */
    private static final int MOMENTUM_SWITCH_ITER = 20;
    private static final double LEARNING_RATE = 500.0d;
    private static final double GAIN_INCREMENT = 0.2d;
    private static final double GAIN_DECAY = 0.8d;
    private static final double MIN_GAIN = 0.01d;

    MatrixOps mo = new MatrixOps();
    protected volatile boolean abort = false;

    /**
     * Embeds the configured input rows into {@code getOutputDims()} dimensions.
     *
     * <p>If PCA is enabled and the input has more than {@code getInitialDims()} (> 0) columns, the
     * input is first reduced with {@link PrincipalComponentAnalysis}. The embedding starts from
     * {@link DoubleMatrix#randn} and is optimised with momentum and adaptive per-coordinate gains
     * for at most {@code getMaxIter()} iterations, or until {@link #abort()} is called.
     *
     * @param tSneConfiguration source of the input matrix and all run parameters
     * @return the embedding, one row per input row and {@code getOutputDims()} columns
     * @throws IllegalArgumentException if the input matrix is {@code null} or has no rows
     */
    @Override // com.jujutsu.tsne.TSne
    public double[][] tsne(TSneConfiguration tSneConfiguration) {
        double[][] xin = tSneConfiguration.getXin();
        int outputDims = tSneConfiguration.getOutputDims();
        int initialDims = tSneConfiguration.getInitialDims();
        double perplexity = tSneConfiguration.getPerplexity();
        int maxIter = tSneConfiguration.getMaxIter();
        boolean usePca = tSneConfiguration.usePca();
        if (xin == null || xin.length == 0) {
            throw new IllegalArgumentException("t-SNE input matrix must contain at least one row");
        }
        System.out.println("X:Shape is = " + xin.length + " x " + xin[0].length);
        System.out.println("Running " + getClass().getSimpleName() + ".");
        if (usePca && xin[0].length > initialDims && initialDims > 0) {
            xin = new PrincipalComponentAnalysis().pca(xin, initialDims);
            System.out.println("X:Shape after PCA is = " + xin.length + " x " + xin[0].length);
        }

        int n = xin.length;
        DoubleMatrix y = DoubleMatrix.randn(n, outputDims);      // current embedding
        DoubleMatrix iY = DoubleMatrix.zeros(n, outputDims);     // velocity (momentum term)
        DoubleMatrix gains = DoubleMatrix.ones(n, outputDims);   // per-coordinate step gains

        // Symmetrise the conditional P-values, normalise to sum 1, exaggerate, and floor.
        DoubleMatrix pConditional = new DoubleMatrix(x2p(xin, PERPLEXITY_TOLERANCE, perplexity).P);
        DoubleMatrix pSymmetric = pConditional.add(pConditional.transpose());
        DoubleMatrix p = pSymmetric.div(pSymmetric.sum())
                .mul(EARLY_EXAGGERATION)
                .max(PROBABILITY_FLOOR);
        System.out.println("Y:Shape is = " + y.rows + " x " + y.columns);

        for (int iter = 0; iter < maxIter && !this.abort; iter++) {
            // num = inverse of (1 + ||y_i||^2 + ||y_j||^2 - 2 y_i.y_j), i.e. the heavy-tailed
            // kernel on pairwise embedding distances (assuming the BlasOps helpers are
            // element-wise), with the diagonal zeroed.
            DoubleMatrix sumY = BlasOps.square(y).rowSums().transpose();
            DoubleMatrix num = BlasOps.scalarInverse(y.mmul(y.transpose())
                    .mul(-2.0d)
                    .addRowVector(sumY)
                    .transpose()
                    .addRowVector(sumY)
                    .add(1.0d));
            BlasOps.assignAtIndex(num, MatrixOps.range(n), MatrixOps.range(n), 0.0d);
            DoubleMatrix q = num.div(num.sum()).max(PROBABILITY_FLOOR);

            // Gradient: dY = 4 * (diag(rowSums(L)) - L) * Y with L = (P - Q) .* num.
            DoubleMatrix l = p.sub(q).mul(num);
            DoubleMatrix dY = DoubleMatrix.diag(l.rowSums()).sub(l).mmul(y).mul(4.0d);

            double momentum = iter < MOMENTUM_SWITCH_ITER ? INITIAL_MOMENTUM : FINAL_MOMENTUM;

            // Adaptive gains: +0.2 where (dY > 0) and (iY > 0) disagree, *0.8 where they agree.
            // jblas add/mul return new matrices, so `gains` is read unchanged by both terms.
            DoubleMatrix disagree = BlasOps.abs(MatrixOps.negate(MatrixOps.equal(
                    BlasOps.biggerThan(dY, 0.0d), BlasOps.biggerThan(iY, 0.0d))));
            DoubleMatrix agree = BlasOps.abs(MatrixOps.equal(
                    BlasOps.biggerThan(dY, 0.0d), BlasOps.biggerThan(iY, 0.0d)));
            gains = gains.add(GAIN_INCREMENT).mul(disagree)
                    .add(gains.mul(GAIN_DECAY).mul(agree));
            BlasOps.assignAllLessThan(gains, MIN_GAIN, MIN_GAIN);   // clamps gains in place

            // Momentum update, then re-centre the embedding on the origin.
            iY = iY.mul(momentum).sub(gains.mul(dY).mul(LEARNING_RATE));
            DoubleMatrix moved = y.add(iY);
            y = moved.sub(BlasOps.tile(moved.columnMeans(), n, 1));

            if (iter % 100 == 0) {
                // KL(P || Q), treating 0 * log(0) terms (NaN) as 0.
                double cost = p.mul(BlasOps.replaceNaN(BlasOps.log(p.div(q)), 0.0d)).sum();
                System.out.println("Iteration " + iter + ": error is " + cost);
            } else if (iter % 10 == 0) {
                System.out.println("Iteration " + iter);
            }

            if (iter == EXAGGERATION_STOP_ITER) {
                p = p.div(EARLY_EXAGGERATION);   // end of early exaggeration
            }
        }
        return y.toArray2();
    }

    /**
     * Computes the normalised kernel row {@code exp(-beta * D) / sum} for one point, together with
     * its entropy.
     *
     * <p>If every {@code exp(-beta * D)} underflows to 0, the sum is 0 and H/P become NaN/Inf;
     * no special handling is done here.
     *
     * @param distances distances from one point to the others ({@link #x2p} passes a 1 x (n-1) row)
     * @param beta      precision applied to the distances
     * @return an {@link TSne.R} with {@code H} = natural-log entropy of the normalised row and
     *     {@code P} = the normalised row (same shape as {@code distances})
     */
    public TSne.R Hbeta(double[][] distances, double beta) {
        DoubleMatrix d = new DoubleMatrix(distances);
        DoubleMatrix unnormalised = BlasOps.exp(d.mul(-beta));
        double sumP = unnormalised.sum();
        // H = log(sum P) + beta * sum(D .* P) / sum(P) is the entropy of P / sum(P).
        double entropy = Math.log(sumP) + ((beta * d.mul(unnormalised).sum()) / sumP);
        TSne.R r = new TSne.R();
        r.H = entropy;
        r.P = unnormalised.div(sumP).toArray2();
        return r;
    }

    /**
     * Computes the conditional P-values: for each point, bisects its precision beta until the
     * entropy of its kernel row matches {@code log(perplexity)}.
     *
     * @param X          input rows
     * @param tol        tolerance on |H - log(perplexity)| at which the bisection stops
     * @param perplexity target perplexity
     * @return an {@link TSne.R} with {@code P} = n x n matrix (zero diagonal, row i holding point
     *     i's normalised kernel over the other points) and {@code beta} = per-point precisions
     */
    public TSne.R x2p(double[][] X, double tol, double perplexity) {
        int n = X.length;
        double[][] sumX = MatrixOps.sum(MatrixOps.square(X), 1);
        // Pairwise squared distances: -2 X X^T + ||x_i||^2 (column) + ||x_j||^2 (row).
        double[][] distances = MatrixOps.addRowVector(
                MatrixOps.addColumnVector(
                        this.mo.transpose(MatrixOps.scalarMult(MatrixOps.times(X, this.mo.transpose(X)), -2.0d)),
                        sumX),
                this.mo.transpose(sumX));
        double[][] p = new double[n][n];
        double[] beta = new double[n];
        Arrays.fill(beta, 1.0d);
        double logU = Math.log(perplexity);
        System.out.println("Starting x2p...");
        for (int i = 0; i < n; i++) {
            if (i % 500 == 0) {
                System.out.println("Computing P-values for point " + i + " of " + n + "...");
            }
            // Indices of every point except i.
            int[] others = MatrixOps.concatenate(MatrixOps.range(0, i), MatrixOps.range(i + 1, n));
            double[][] rowDistances = MatrixOps.getValuesFromRow(distances, i, others);

            double betaMin = Double.NEGATIVE_INFINITY;
            double betaMax = Double.POSITIVE_INFINITY;
            TSne.R hbeta = Hbeta(rowDistances, beta[i]);
            double[][] thisP = hbeta.P;
            double hDiff = hbeta.H - logU;

            // Bounded bisection on beta[i]. A NaN hDiff fails the comparison and ends the search.
            for (int tries = 0; Math.abs(hDiff) > tol && tries < MAX_BETA_SEARCH_STEPS; tries++) {
                if (hDiff > 0.0d) {
                    // Entropy too high: increase beta (double until an upper bound is known).
                    betaMin = beta[i];
                    beta[i] = Double.isInfinite(betaMax) ? beta[i] * 2.0d : (beta[i] + betaMax) / 2.0d;
                } else {
                    // Entropy too low: decrease beta (halve until a lower bound is known).
                    betaMax = beta[i];
                    beta[i] = Double.isInfinite(betaMin) ? beta[i] / 2.0d : (beta[i] + betaMin) / 2.0d;
                }
                hbeta = Hbeta(rowDistances, beta[i]);
                thisP = hbeta.P;
                hDiff = hbeta.H - logU;
            }
            MatrixOps.assignValuesToRow(p, i, others, thisP[0]);
        }
        TSne.R r = new TSne.R();
        r.P = p;
        r.beta = beta;
        System.out.println("Mean value of sigma: " + MatrixOps.mean(MatrixOps.sqrt(MatrixOps.scalarInverse(beta))));
        return r;
    }

    /**
     * Requests that a running {@link #tsne} loop stop before its next iteration.
     *
     * <p>The flag is never reset, so later {@link #tsne} calls on this instance skip the
     * optimisation loop and return the random initial embedding.
     */
    @Override // com.jujutsu.tsne.TSne
    public void abort() {
        this.abort = true;
    }
}
