package cc.mallet.topics;

import cc.mallet.types.Alphabet;
import cc.mallet.types.AlphabetFactory;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelSequence;
import cc.mallet.util.CommandOption;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Randoms;
import gnu.trove.TIntDoubleHashMap;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.PrintStream;
import java.io.Serializable;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.regex.Pattern;
import java.util.zip.GZIPOutputStream;

/* loaded from: input_file:cc/mallet/topics/WeightedTopicModel.class */
public class WeightedTopicModel implements Serializable {
    private static Logger logger;
    static CommandOption.String inputFile;
    static CommandOption.String weightsFile;
    static CommandOption.String evaluatorFilename;
    static CommandOption.String stateFile;
    static CommandOption.Integer numTopicsOption;
    static CommandOption.Integer numEpochsOption;
    static CommandOption.Integer numIterationsOption;
    static CommandOption.Integer randomSeedOption;
    static CommandOption.Double alphaOption;
    static CommandOption.Double betaOption;
    public static Pattern sourceWordPattern;
    public static Pattern targetWordPattern;
    protected Alphabet alphabet;
    protected LabelAlphabet topicAlphabet;
    protected int numTopics;
    protected int numTypes;
    protected double alpha;
    protected double alphaSum;
    protected double beta;
    protected double betaSum;
    protected int[] oneDocTopicCounts;
    protected int[][] typeTopicCounts;
    protected int[] tokensPerTopic;
    protected TIntDoubleHashMap[] typeTypeWeights;
    protected double[][] logTypeTopicWeights;
    protected double[][] typeTopicWeights;
    protected double[] totalTopicWeights;
    protected Randoms random;
    protected double[] logCountRatioCache;
    static final /* synthetic */ boolean $assertionsDisabled;
    public int showTopicsInterval = 50;
    public int wordsPerTopic = 10;
    protected boolean printLogLikelihood = false;
    protected ArrayList<TopicAssignment> data = new ArrayList<>();
    protected NumberFormat formatter = NumberFormat.getInstance();

    public WeightedTopicModel(int i, double d, double d2, Randoms randoms) {
        this.topicAlphabet = AlphabetFactory.labelAlphabetOfSize(i);
        this.numTopics = this.topicAlphabet.size();
        this.alphaSum = d;
        this.alpha = d / this.numTopics;
        this.beta = d2;
        this.random = randoms;
        this.oneDocTopicCounts = new int[this.numTopics];
        this.tokensPerTopic = new int[this.numTopics];
        this.formatter.setMaximumFractionDigits(5);
        logger.info("Weighted LDA: " + this.numTopics + " topics");
    }

    public Alphabet getAlphabet() {
        return this.alphabet;
    }

    public LabelAlphabet getTopicAlphabet() {
        return this.topicAlphabet;
    }

    public int getNumTopics() {
        return this.numTopics;
    }

    public ArrayList<TopicAssignment> getData() {
        return this.data;
    }

    public void setTopicDisplay(int i, int i2) {
        this.showTopicsInterval = i;
        this.wordsPerTopic = i2;
    }

    public void setRandomSeed(int i) {
        this.random = new Randoms(i);
    }

    public int[][] getTypeTopicCounts() {
        return this.typeTopicCounts;
    }

    public int[] getTopicTotals() {
        return this.tokensPerTopic;
    }

    public void addInstances(InstanceList instanceList) {
        this.alphabet = instanceList.getDataAlphabet();
        this.numTypes = this.alphabet.size();
        this.betaSum = this.beta * this.numTypes;
        this.typeTopicCounts = new int[this.numTypes][this.numTopics];
        this.typeTopicWeights = new double[this.numTypes][this.numTopics];
        this.totalTopicWeights = new double[this.numTopics];
        for (int i = 0; i < this.numTypes; i++) {
            Arrays.fill(this.typeTopicWeights[i], this.beta);
        }
        Arrays.fill(this.totalTopicWeights, this.betaSum);
        int i2 = 0;
        Iterator<Instance> it = instanceList.iterator();
        while (it.hasNext()) {
            Instance next = it.next();
            i2++;
            this.data.add(new TopicAssignment(next, new LabelSequence(this.topicAlphabet, new int[((FeatureSequence) next.getData()).size()])));
        }
    }

    public void readTypeTypeWeights(File file) throws Exception {
        this.typeTypeWeights = new TIntDoubleHashMap[this.numTypes];
        logger.info("num types: " + this.numTypes);
        for (int i = 0; i < this.numTypes; i++) {
            this.typeTypeWeights[i] = new TIntDoubleHashMap();
            this.typeTypeWeights[i].put(i, 1.0d);
        }
        BufferedReader bufferedReader = new BufferedReader(new FileReader(file));
        while (true) {
            String readLine = bufferedReader.readLine();
            if (readLine == null) {
                return;
            }
            String[] split = readLine.split("\t");
            double d = 0.0d;
            for (int i2 = 1; i2 < split.length; i2 += 2) {
                d += Double.parseDouble(split[i2]);
            }
            int lookupIndex = this.alphabet.lookupIndex(split[0]);
            this.typeTypeWeights[lookupIndex].put(lookupIndex, Double.parseDouble(split[1]) / d);
            for (int i3 = 2; i3 < split.length; i3 += 2) {
                this.typeTypeWeights[lookupIndex].put(this.alphabet.lookupIndex(split[i3]), Double.parseDouble(split[i3 + 1]) / d);
            }
        }
    }

    public void sample(int i, boolean z, int i2) throws IOException {
        int i3 = 1;
        while (i3 <= i) {
            long currentTimeMillis = System.currentTimeMillis();
            for (int i4 = 0; i4 < this.data.size(); i4++) {
                FeatureSequence featureSequence = (FeatureSequence) this.data.get(i4).instance.getData();
                LabelSequence labelSequence = this.data.get(i4).topicSequence;
                sampleTopicsForOneDoc(featureSequence, labelSequence, z && i3 == 1, false);
                for (int i5 = 1; i5 < i2; i5++) {
                    sampleTopicsForOneDoc(featureSequence, labelSequence, false, false);
                }
            }
            logger.info(i3 + "\t" + (System.currentTimeMillis() - currentTimeMillis) + "ms\t");
            if (this.showTopicsInterval != 0 && i3 % this.showTopicsInterval == 0) {
                logger.info("<" + i3 + ">\n" + topWords(this.wordsPerTopic));
            }
            i3++;
        }
    }

    protected void sampleTopicsForOneDoc(FeatureSequence featureSequence, FeatureSequence featureSequence2, boolean z, boolean z2) {
        int[] features = featureSequence2.getFeatures();
        int length = featureSequence.getLength();
        int[] iArr = new int[this.numTopics];
        if (!z) {
            for (int i = 0; i < length; i++) {
                int i2 = features[i];
                iArr[i2] = iArr[i2] + 1;
            }
        }
        double[] dArr = new double[this.numTopics];
        for (int i3 = 0; i3 < length; i3++) {
            int indexAtPosition = featureSequence.getIndexAtPosition(i3);
            int i4 = features[i3];
            TIntDoubleHashMap tIntDoubleHashMap = this.typeTypeWeights[indexAtPosition];
            int[] keys = tIntDoubleHashMap.keys();
            int[] iArr2 = this.typeTopicCounts[indexAtPosition];
            double[] dArr2 = this.typeTopicWeights[indexAtPosition];
            if (!z) {
                iArr[i4] = iArr[i4] - 1;
                int[] iArr3 = this.tokensPerTopic;
                iArr3[i4] = iArr3[i4] - 1;
                if (!$assertionsDisabled && this.tokensPerTopic[i4] < 0) {
                    throw new AssertionError();
                }
                iArr2[i4] = iArr2[i4] - 1;
                int i5 = iArr2[i4];
                for (int i6 : keys) {
                    double d = tIntDoubleHashMap.get(i6);
                    double[] dArr3 = this.typeTopicWeights[i6];
                    dArr3[i4] = dArr3[i4] - d;
                    double[] dArr4 = this.totalTopicWeights;
                    dArr4[i4] = dArr4[i4] - d;
                }
            }
            double d2 = 0.0d;
            for (int i7 = 0; i7 < this.numTopics; i7++) {
                double d3 = (this.alpha + iArr[i7]) * (dArr2[i7] / this.totalTopicWeights[i7]);
                d2 += d3;
                dArr[i7] = d3;
                if (z2 && indexAtPosition == 68) {
                    System.out.println(indexAtPosition + "\t" + i7 + "\t" + iArr[i7] + "\t" + iArr2[i7] + "\t" + dArr2[i7] + "\t" + this.tokensPerTopic[i7] + "\t" + d2);
                }
            }
            double nextUniform = this.random.nextUniform() * d2;
            if (z2) {
                System.out.println("sample " + nextUniform + " / " + d2);
            }
            int i8 = -1;
            while (nextUniform > 0.0d) {
                i8++;
                nextUniform -= dArr[i8];
            }
            if (z2 || i8 == -1) {
            }
            features[i3] = i8;
            int i9 = i8;
            iArr[i9] = iArr[i9] + 1;
            int[] iArr4 = this.tokensPerTopic;
            int i10 = i8;
            iArr4[i10] = iArr4[i10] + 1;
            int i11 = i8;
            iArr2[i11] = iArr2[i11] + 1;
            int i12 = iArr2[i8];
            for (int i13 : keys) {
                double d4 = tIntDoubleHashMap.get(i13);
                double[] dArr5 = this.typeTopicWeights[i13];
                int i14 = i8;
                dArr5[i14] = dArr5[i14] + d4;
                double[] dArr6 = this.totalTopicWeights;
                int i15 = i8;
                dArr6[i15] = dArr6[i15] + d4;
            }
        }
    }

    public String topWords(int i) {
        StringBuilder sb = new StringBuilder();
        IDSorter[] iDSorterArr = new IDSorter[this.numTypes];
        for (int i2 = 0; i2 < this.numTopics; i2++) {
            for (int i3 = 0; i3 < this.numTypes; i3++) {
                iDSorterArr[i3] = new IDSorter(i3, this.typeTopicCounts[i3][i2]);
            }
            Arrays.sort(iDSorterArr);
            sb.append(i2 + "\t" + this.tokensPerTopic[i2] + "\t" + this.formatter.format(this.totalTopicWeights[i2]));
            for (int i4 = 0; i4 < i; i4++) {
                sb.append(this.alphabet.lookupObject(iDSorterArr[i4].getID()) + " ");
            }
            sb.append("\n");
        }
        return sb.toString();
    }

    /* JADX WARN: Type inference failed for: r0v12, types: [int[], int[][]] */
    public MarginalProbEstimator getEstimator() {
        int bitCount = Integer.bitCount(this.numTopics) == 1 ? Integer.bitCount(this.numTopics - 1) : Integer.bitCount((Integer.highestOneBit(this.numTopics) * 2) - 1);
        ?? r0 = new int[this.numTypes];
        for (int i = 0; i < this.numTypes; i++) {
            int[] iArr = this.typeTopicCounts[i];
            int i2 = 0;
            for (int i3 = 0; i3 < this.numTopics; i3++) {
                if (iArr[i3] > 0) {
                    i2++;
                }
            }
            int[] iArr2 = new int[i2];
            for (int i4 = 0; i4 < this.numTopics; i4++) {
                if (iArr[i4] > 0) {
                    int i5 = (iArr[i4] << bitCount) + i4;
                    int i6 = 0;
                    while (iArr2[i6] > i5) {
                        i6++;
                    }
                    while (i6 < iArr2.length && i5 > iArr2[i6]) {
                        int i7 = iArr2[i6];
                        iArr2[i6] = i5;
                        i5 = i7;
                        i6++;
                    }
                }
            }
            r0[i] = iArr2;
        }
        double[] dArr = new double[this.numTopics];
        Arrays.fill(dArr, this.alpha);
        return new MarginalProbEstimator(this.numTopics, dArr, this.alphaSum, this.beta, r0, this.tokensPerTopic);
    }

    public void printState(File file) throws IOException {
        PrintStream printStream = new PrintStream(new GZIPOutputStream(new BufferedOutputStream(new FileOutputStream(file))));
        printState(printStream);
        printStream.close();
    }

    public void printState(PrintStream printStream) {
        printStream.println("#doc source pos typeindex type topic");
        for (int i = 0; i < this.data.size(); i++) {
            FeatureSequence featureSequence = (FeatureSequence) this.data.get(i).instance.getData();
            LabelSequence labelSequence = this.data.get(i).topicSequence;
            StringBuilder sb = new StringBuilder();
            for (int i2 = 0; i2 < labelSequence.getLength(); i2++) {
                int indexAtPosition = featureSequence.getIndexAtPosition(i2);
                int indexAtPosition2 = labelSequence.getIndexAtPosition(i2);
                sb.append(i);
                sb.append(' ');
                sb.append("NA");
                sb.append(' ');
                sb.append(i2);
                sb.append(' ');
                sb.append(indexAtPosition);
                sb.append(' ');
                sb.append(this.alphabet.lookupObject(indexAtPosition));
                sb.append(' ');
                sb.append(indexAtPosition2);
                sb.append("\n");
            }
            printStream.print(sb.toString());
        }
    }

    public static void main(String[] strArr) throws Exception {
        CommandOption.setSummary(WeightedTopicModel.class, "Train topics with weights between word types encoded in the prior");
        CommandOption.process(WeightedTopicModel.class, strArr);
        InstanceList load = InstanceList.load(new File(inputFile.value));
        WeightedTopicModel weightedTopicModel = new WeightedTopicModel(numTopicsOption.value, alphaOption.value, betaOption.value, randomSeedOption.value != 0 ? new Randoms(randomSeedOption.value) : new Randoms());
        weightedTopicModel.addInstances(load);
        weightedTopicModel.readTypeTypeWeights(new File(weightsFile.value));
        int i = 1;
        while (i <= numEpochsOption.value) {
            weightedTopicModel.sample(numIterationsOption.value, i == 1, 1);
            if (stateFile.wasInvoked()) {
                weightedTopicModel.printState(new File(stateFile.value + "." + i));
            }
            if (evaluatorFilename.wasInvoked()) {
                try {
                    ObjectOutputStream objectOutputStream = new ObjectOutputStream(new FileOutputStream(evaluatorFilename.value + "." + i));
                    objectOutputStream.writeObject(weightedTopicModel.getEstimator());
                    objectOutputStream.close();
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
            i++;
        }
    }

    static {
        $assertionsDisabled = !WeightedTopicModel.class.desiredAssertionStatus();
        logger = MalletLogger.getLogger(WeightedTopicModel.class.getName());
        inputFile = new CommandOption.String(WeightedTopicModel.class, "input", "FILENAME", true, null, "The filename from which to read the list of training instances.  Use - for stdin.  The instances must be FeatureSequence or FeatureSequenceWithBigrams, not FeatureVector", null);
        weightsFile = new CommandOption.String(WeightedTopicModel.class, "weights-filename", "FILENAME", true, null, "The filename for the word-word weights file.", null);
        evaluatorFilename = new CommandOption.String(WeightedTopicModel.class, "evaluator-filename", "FILENAME", true, null, "A held-out likelihood evaluator for new documents.  By default this is null, indicating that no file will be written.", null);
        stateFile = new CommandOption.String(WeightedTopicModel.class, "state-filename", "FILENAME", true, null, "The filename in which to write the Gibbs sampling state after at the end of the iterations.  By default this is null, indicating that no file will be written.", null);
        numTopicsOption = new CommandOption.Integer(WeightedTopicModel.class, "num-topics", "INTEGER", true, 10, "The number of topics to fit.", null);
        numEpochsOption = new CommandOption.Integer(WeightedTopicModel.class, "num-epochs", "INTEGER", true, 1, "The number of cycles of training. Evaluators and state files will be saved after each epoch.", null);
        numIterationsOption = new CommandOption.Integer(WeightedTopicModel.class, "num-iterations", "INTEGER", true, 1000, "The number of iterations of Gibbs sampling PER EPOCH.", null);
        randomSeedOption = new CommandOption.Integer(WeightedTopicModel.class, "random-seed", "INTEGER", true, 0, "The random seed for the Gibbs sampler.  Default is 0, which will use the clock.", null);
        alphaOption = new CommandOption.Double(WeightedTopicModel.class, "alpha", "DECIMAL", true, 50.0d, "Alpha parameter: smoothing over topic distribution.", null);
        betaOption = new CommandOption.Double(WeightedTopicModel.class, "beta", "DECIMAL", true, 0.01d, "Beta parameter: smoothing over topic distribution.", null);
        sourceWordPattern = Pattern.compile("(.*) \\((\\d+)\\)");
        targetWordPattern = Pattern.compile("  (\\d+)\t(\\d+)\t([\\d\\.]+)\t(.*)");
    }
}
