package com.jujutsu.utils;

import com.github.fommil.netlib.BLAS;
import com.github.fommil.netlib.LAPACK;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.ejml.data.DenseMatrix64F;
import org.jblas.ComplexDoubleMatrix;
import org.jblas.DoubleMatrix;
import org.jblas.Eigen;
import org.netlib.util.intW;

/* loaded from: input_file:com/jujutsu/utils/BlasOps.class */
/**
 * Element-wise and linear-algebra helpers for jblas {@link DoubleMatrix} values. Matrix inversion
 * goes straight to netlib-java LAPACK ({@code dgetrf}/{@code dgetri}).
 */
public class BlasOps {

    /**
     * Pairs an eigenvalue with its eigenvector. The natural ordering is by eigenvalue, largest
     * first, so sorting a list of beans puts the dominant components at the front.
     */
    /* loaded from: input_file:com/jujutsu/utils/BlasOps$PCABean.class */
    static class PCABean implements Comparable<PCABean> {
        double eigenValue;
        ComplexDoubleMatrix vector;

        public PCABean(double value, ComplexDoubleMatrix eigenVector) {
            this.eigenValue = value;
            this.vector = eigenVector;
        }

        /** Descending order by eigenvalue (note the swapped arguments to {@code Double.compare}). */
        @Override // java.lang.Comparable
        public int compareTo(PCABean other) {
            return Double.compare(other.eigenValue, this.eigenValue);
        }

        @Override
        public String toString() {
            return "PCABean [eigenValue=" + this.eigenValue + ", vector=" + this.vector + "]";
        }
    }

    /**
     * Ad-hoc smoke test of the LAPACK inversion path. It builds an {@code n x n} diagonal matrix
     * with 2.0 on the diagonal, where {@code n = (int) Math.sqrt(i)}, and inverts it in place with
     * {@code dgetrf}/{@code dgetri}. It then multiplies the original by the inverse with
     * {@code dgemm} and prints the start matrix, the inverse, and the {@link #blasInvert} result.
     *
     * @param i number of matrix elements; the side length is {@code (int) Math.sqrt(i)}
     * @throws IllegalArgumentException if {@code i < 1} or LAPACK reports a failure
     */
    void benchmark(int i) {
        int n = (int) Math.sqrt(i);
        if (n < 1) {
            throw new IllegalArgumentException("benchmark needs at least one element, got i=" + i);
        }
        double[] matrix = new double[n * n];
        // Fill the whole diagonal (index k*n + k). The old hard-coded indices 0/4/8 were the
        // diagonal only for n == 3, which left the matrix singular (or out of bounds) otherwise.
        for (int k = 0; k < n; k++) {
            matrix[k * n + k] = 2.0d;
        }
        System.out.println("Start is: " + new DoubleMatrix(matrix).reshape(n, n));
        double[] original = Arrays.copyOf(matrix, matrix.length);
        int[] pivots = new int[n];
        intW info = new intW(0);
        LAPACK lapack = LAPACK.getInstance();

        // Workspace query: with lwork == -1, dgetri only reports the optimal size in query[0].
        double[] query = new double[1];
        lapack.dgetri(n, matrix, n, pivots, query, -1, info);
        // dgetri requires lwork >= n, so never allocate less than that.
        double[] work = new double[Math.max(n, (int) query[0])];

        lapack.dgetrf(n, n, matrix, n, pivots, info);
        checkInfo("dgetrf", info);
        lapack.dgetri(n, matrix, n, pivots, work, work.length, info);
        checkInfo("dgetri", info);

        // original * inverse. The product is computed as part of the benchmark but not printed.
        double[] product = new double[n * n];
        BLAS.getInstance().dgemm("N", "N", n, n, n, 1.0d, original, n, matrix, n, 0.0d, product, n);

        System.out.println("Result is: " + new DoubleMatrix(matrix).reshape(n, n));
        // Reshape first: a bare DoubleMatrix(double[]) is an (n*n) x 1 column vector, not n x n.
        System.out.println("BlasInvert Result is: " + blasInvert(new DoubleMatrix(matrix).reshape(n, n)));
    }

    /** Throws if a LAPACK routine reported a non-zero {@code info} code. */
    private static void checkInfo(String routine, intW info) {
        if (info.val != 0) {
            throw new IllegalArgumentException(routine + " failed with info=" + info.val);
        }
    }

    /**
     * Inverts a square matrix via LU factorisation ({@code dgetrf}) followed by {@code dgetri}.
     * The input is not modified.
     *
     * @param doubleMatrix square matrix to invert
     * @return a new matrix holding the inverse
     * @throws IllegalArgumentException if the matrix is not square, or LAPACK reports a failure
     *     (for example a singular matrix)
     */
    public static DoubleMatrix blasInvert(DoubleMatrix doubleMatrix) {
        if (doubleMatrix.rows != doubleMatrix.columns) {
            throw new IllegalArgumentException(
                    "Matrix must be square but is " + doubleMatrix.rows + "x" + doubleMatrix.columns);
        }
        int n = doubleMatrix.columns;
        if (n == 0) {
            return new DoubleMatrix(0, 0);
        }
        // toArray() copies the column-major data, so LAPACK works in place on the copy.
        double[] data = doubleMatrix.toArray();
        int[] pivots = new int[n];
        double[] work = new double[n];
        intW info = new intW(0);
        LAPACK lapack = LAPACK.getInstance();
        lapack.dgetrf(n, n, data, n, pivots, info);
        checkInfo("dgetrf", info);
        lapack.dgetri(n, data, n, pivots, work, work.length, info);
        checkInfo("dgetri", info);
        return new DoubleMatrix(data).reshape(n, n);
    }

    /**
     * Inverts an EJML matrix by delegating to {@link #blasInvert(DoubleMatrix)}.
     *
     * @param denseMatrix64F square matrix to invert; it is not modified
     * @return a new EJML matrix holding the inverse
     */
    public static DenseMatrix64F blasInvertDense(DenseMatrix64F denseMatrix64F) {
        int rows = denseMatrix64F.numRows;
        int cols = denseMatrix64F.numCols;
        // EJML stores data row-major, and its backing array may be longer than rows*cols. jblas
        // is column-major. Copy element-wise instead of reinterpreting getData(); reinterpreting
        // it would invert the transpose.
        DoubleMatrix copy = new DoubleMatrix(rows, cols);
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                copy.put(r, c, denseMatrix64F.get(r, c));
            }
        }
        return new DenseMatrix64F(blasInvert(copy).toArray2());
    }

    /** Returns a copy of the matrix with every element squared. */
    public static DoubleMatrix square(DoubleMatrix doubleMatrix) {
        DoubleMatrix result = doubleMatrix.dup();
        for (int k = 0; k < result.getLength(); k++) {
            double value = result.get(k);
            result.put(k, value * value);
        }
        return result;
    }

    /** Returns a copy holding the element-wise reciprocal {@code 1 / x}; zeros become infinite. */
    public static DoubleMatrix scalarInverse(DoubleMatrix doubleMatrix) {
        DoubleMatrix result = doubleMatrix.dup();
        for (int k = 0; k < result.getLength(); k++) {
            result.put(k, 1.0d / result.get(k));
        }
        return result;
    }

    /**
     * Writes {@code d} in place at each {@code (iArr[k], iArr2[k])} position. Iteration is driven
     * by {@code iArr}, so {@code iArr2} must be at least as long.
     */
    public static void assignAtIndex(DoubleMatrix doubleMatrix, int[] iArr, int[] iArr2, double d) {
        for (int k = 0; k < iArr.length; k++) {
            doubleMatrix.put(iArr[k], iArr2[k], d);
        }
    }

    /**
     * Compares two matrices element by element using {@link Double#compare}, so NaN equals NaN
     * and {@code 0.0} differs from {@code -0.0}.
     *
     * @return a {@code rows x columns} array that is true where the elements are equal
     * @throws IllegalArgumentException if the dimensions differ
     */
    public static boolean[][] equal(DoubleMatrix doubleMatrix, DoubleMatrix doubleMatrix2) {
        if (doubleMatrix.length != doubleMatrix2.length) {
            throw new IllegalArgumentException("Dimensions does not match");
        }
        if (doubleMatrix.columns != doubleMatrix2.columns) {
            throw new IllegalArgumentException("Dimensions does not match");
        }
        boolean[][] result = new boolean[doubleMatrix.rows][doubleMatrix.columns];
        for (int r = 0; r < doubleMatrix.rows; r++) {
            for (int c = 0; c < doubleMatrix.columns; c++) {
                result[r][c] = Double.compare(doubleMatrix.get(r, c), doubleMatrix2.get(r, c)) == 0;
            }
        }
        return result;
    }

    /** In place: replaces every element strictly less than {@code d} with {@code d2}. */
    public static void assignAllLessThan(DoubleMatrix doubleMatrix, double d, double d2) {
        for (int k = 0; k < doubleMatrix.length; k++) {
            if (doubleMatrix.get(k) < d) {
                doubleMatrix.put(k, d2);
            }
        }
    }

    /** Returns a copy holding the element-wise natural logarithm. */
    public static DoubleMatrix log(DoubleMatrix doubleMatrix) {
        DoubleMatrix result = doubleMatrix.dup();
        for (int k = 0; k < result.length; k++) {
            result.put(k, Math.log(doubleMatrix.get(k)));
        }
        return result;
    }

    /**
     * Returns a copy holding the element-wise natural logarithm.
     *
     * @param z when true, infinite results (for example {@code log(0)}) are replaced with 0
     */
    public static DoubleMatrix log(DoubleMatrix doubleMatrix, boolean z) {
        DoubleMatrix result = doubleMatrix.dup();
        for (int k = 0; k < result.length; k++) {
            double value = Math.log(doubleMatrix.get(k));
            result.put(k, z && Double.isInfinite(value) ? 0.0d : value);
        }
        return result;
    }

    /** Returns a copy in which every NaN is replaced with {@code d}. */
    public static DoubleMatrix replaceNaN(DoubleMatrix doubleMatrix, double d) {
        DoubleMatrix result = doubleMatrix.dup();
        for (int k = 0; k < doubleMatrix.length; k++) {
            if (Double.isNaN(doubleMatrix.get(k))) {
                result.put(k, d);
            }
        }
        return result;
    }

    /**
     * Repeats the matrix {@code i} times vertically and {@code i2} times horizontally.
     *
     * @return a new {@code (rows * i) x (columns * i2)} matrix
     */
    public static DoubleMatrix tile(DoubleMatrix doubleMatrix, int i, int i2) {
        DoubleMatrix tiled = new DoubleMatrix(doubleMatrix.rows * i, doubleMatrix.columns * i2);
        // Each target cell maps back to the source cell at (r mod rows, c mod columns). The loops
        // do not run when the source has no rows or columns, so the modulo never divides by zero.
        for (int r = 0; r < tiled.rows; r++) {
            for (int c = 0; c < tiled.columns; c++) {
                tiled.put(r, c, doubleMatrix.get(r % doubleMatrix.rows, c % doubleMatrix.columns));
            }
        }
        return tiled;
    }

    /**
     * Element-wise {@code x > d} using {@link Double#compare}. NaN therefore counts as bigger than
     * any non-NaN {@code d}, and {@code 0.0} as bigger than {@code -0.0}.
     */
    public static boolean[][] biggerThan(DoubleMatrix doubleMatrix, double d) {
        boolean[][] result = new boolean[doubleMatrix.rows][doubleMatrix.columns];
        for (int r = 0; r < doubleMatrix.rows; r++) {
            for (int c = 0; c < doubleMatrix.columns; c++) {
                // Double.compare only guarantees a positive value, not exactly 1.
                result[r][c] = Double.compare(doubleMatrix.get(r, c), d) > 0;
            }
        }
        return result;
    }

    /**
     * Converts a boolean mask to a 0/1 matrix (true becomes 1.0, false becomes 0.0). Despite the
     * name, this computes no absolute value. The column count is taken from the first row.
     */
    public static DoubleMatrix abs(boolean[][] zArr) {
        if (zArr.length == 0) {
            return new DoubleMatrix(0, 0);
        }
        int cols = zArr[0].length;
        DoubleMatrix result = new DoubleMatrix(zArr.length, cols);
        for (int r = 0; r < zArr.length; r++) {
            for (int c = 0; c < cols; c++) {
                result.put(r, c, zArr[r][c] ? 1.0d : 0.0d);
            }
        }
        return result;
    }

    /** Returns a copy holding the element-wise {@code e^x}. */
    public static DoubleMatrix exp(DoubleMatrix doubleMatrix) {
        DoubleMatrix result = doubleMatrix.dup();
        for (int k = 0; k < result.length; k++) {
            result.put(k, Math.exp(doubleMatrix.get(k)));
        }
        return result;
    }

    /** Returns true if any element is NaN. */
    public static boolean containsNaNs(DoubleMatrix doubleMatrix) {
        for (int k = 0; k < doubleMatrix.length; k++) {
            if (Double.isNaN(doubleMatrix.get(k))) {
                return true;
            }
        }
        return false;
    }

    /**
     * Element-wise sign: 1.0 for {@code x >= 0} (including 0), otherwise -1.0. Because
     * {@code NaN >= 0} is false, NaN maps to -1.0.
     */
    public static DoubleMatrix sign(DoubleMatrix doubleMatrix) {
        DoubleMatrix result = new DoubleMatrix(doubleMatrix.rows, doubleMatrix.columns);
        for (int k = 0; k < doubleMatrix.length; k++) {
            result.put(k, doubleMatrix.get(k) >= 0.0d ? 1.0d : -1.0d);
        }
        return result;
    }

    /** Returns a copy holding the element-wise square root; negative inputs yield NaN. */
    public static DoubleMatrix sqrt(DoubleMatrix doubleMatrix) {
        DoubleMatrix result = doubleMatrix.dup();
        for (int k = 0; k < result.length; k++) {
            result.put(k, Math.sqrt(doubleMatrix.get(k)));
        }
        return result;
    }

    /**
     * Projects the data onto its top {@code i} eigenvectors. It decomposes
     * {@code X * X^T / columns} (the data is not mean-centred here), sorts the eigenpairs by the
     * real part of the eigenvalue in descending order, and keeps the real parts of the top
     * {@code i} eigenvectors.
     *
     * @param doubleMatrix data matrix {@code X}
     * @param i number of components to keep
     * @return the {@code i x X.columns} projection {@code W * X}, where the rows of {@code W} are
     *     the selected eigenvectors
     * @throws IllegalArgumentException if {@code i} is negative or exceeds the number of eigenpairs
     */
    public static DoubleMatrix pca(DoubleMatrix doubleMatrix, int i) {
        DoubleMatrix scatter = doubleMatrix.mmul(doubleMatrix.transpose()).div(doubleMatrix.columns);
        // One decomposition yields both the eigenvectors (columns of [0]) and the matching
        // eigenvalues (diagonal of [1]), so the pairs are guaranteed to correspond.
        ComplexDoubleMatrix[] decomposition = Eigen.eigenvectors(scatter);
        ComplexDoubleMatrix eigenvectors = decomposition[0];
        ComplexDoubleMatrix eigenvalues = decomposition[1];
        if (i < 0 || i > eigenvectors.columns) {
            throw new IllegalArgumentException(
                    "Requested " + i + " components but only " + eigenvectors.columns + " are available");
        }
        List<PCABean> components = new ArrayList<PCABean>(eigenvectors.columns);
        for (int c = 0; c < eigenvectors.columns; c++) {
            components.add(new PCABean(eigenvalues.get(c, c).real(), eigenvectors.getColumn(c)));
        }
        Collections.sort(components);
        DoubleMatrix projection = new DoubleMatrix(i, scatter.rows);
        for (int r = 0; r < i; r++) {
            projection.putRow(r, components.get(r).vector.getReal());
        }
        return projection.mmul(doubleMatrix);
    }
}
