package edu.northwestern.at.morphadorner.corpuslinguistics.postagger.transitionmatrix;

import edu.northwestern.at.utils.CompoundKey;
import edu.northwestern.at.utils.Formatters;
import edu.northwestern.at.utils.IsCloseableObject;
import edu.northwestern.at.utils.Map2D;
import edu.northwestern.at.utils.Map2DFactory;
import edu.northwestern.at.utils.Map3D;
import edu.northwestern.at.utils.Map3DFactory;
import edu.northwestern.at.utils.MapFactory;
import edu.northwestern.at.utils.StringUtils;
import edu.northwestern.at.utils.UnicodeReader;
import edu.northwestern.at.utils.logger.DummyLogger;
import edu.northwestern.at.utils.logger.Logger;
import edu.northwestern.at.utils.logger.UsesLogger;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStreamWriter;
import java.io.Reader;
import java.io.Writer;
import java.net.URL;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.regex.Pattern;
import java.util.zip.GZIPInputStream;

/* loaded from: input_file:edu/northwestern/at/morphadorner/corpuslinguistics/postagger/transitionmatrix/TransitionMatrix.class */
/**
 * Stores unigram, bigram and trigram counts and derives probability
 * estimates and interpolation weights from them.
 *
 * <p>Probabilities are relative frequencies:
 * <ul>
 *   <li>unigram {@code count(c) / totalUnigrams}</li>
 *   <li>bigram {@code count(b, c) / count(b)}</li>
 *   <li>trigram {@code count(a, b, c) / count(a, b)}</li>
 * </ul>
 * They are computed lazily and recomputed after any count changes.
 */
public class TransitionMatrix extends IsCloseableObject implements UsesLogger {
    /** Debug flag; not read anywhere in this class. */
    protected static boolean debug = true;

    /** Index of unigram statistics in {@link #totalNGrams} and {@link #uniqueNGrams}. */
    protected static final int UNIGRAM = 0;

    /** Index of bigram statistics in {@link #totalNGrams} and {@link #uniqueNGrams}. */
    protected static final int BIGRAM = 1;

    /** Index of trigram statistics in {@link #totalNGrams} and {@link #uniqueNGrams}. */
    protected static final int TRIGRAM = 2;

    /** Fixed unigram component of the bigram interpolation weights. */
    private static final double FIXED_BIGRAM_UNIGRAM_WEIGHT = 0.03d;

    /** Fixed bigram component of the bigram interpolation weights. */
    private static final double FIXED_BIGRAM_BIGRAM_WEIGHT = 0.97d;

    /** Unigram counts keyed by item. */
    protected Map<String, Integer> unigramCountMap = MapFactory.createNewMap();

    /** Bigram counts keyed by (first, second). */
    protected Map2D<String, String, Integer> bigramCountMap = Map2DFactory.createNewMap2D();

    /** Trigram counts keyed by (first, second, third). */
    protected Map3D<String, String, String, Integer> trigramCountMap = Map3DFactory.createNewMap3D();

    /** Unigram probabilities; filled by {@link #computeTrigramWeights()}. */
    protected Map<String, Double> unigramProbMap = MapFactory.createNewMap();

    /** Conditional bigram probabilities; filled by {@link #computeTrigramWeights()}. */
    protected Map2D<String, String, Double> bigramProbMap = Map2DFactory.createNewMap2D();

    /** Conditional trigram probabilities; filled by {@link #computeTrigramWeights()}. */
    protected Map3D<String, String, String, Double> trigramProbMap = Map3DFactory.createNewMap3D();

    /** Sum of all increments per ngram order, indexed by UNIGRAM/BIGRAM/TRIGRAM. */
    protected int[] totalNGrams = {0, 0, 0};

    /** Number of distinct ngrams per order, indexed by UNIGRAM/BIGRAM/TRIGRAM. */
    protected int[] uniqueNGrams = {0, 0, 0};

    /** Sum of all unigram increments; this class keeps it equal to totalNGrams[UNIGRAM]. */
    protected int totalWords = 0;

    /** True when the probability maps and weights reflect the current counts. */
    protected boolean haveProbabilities = false;

    /** Interpolation weights {unigram, bigram}; null until first computed. */
    protected double[] bigramWeights = null;

    /** Interpolation weights {unigram, bigram, trigram}; null until first computed. */
    protected double[] trigramWeights = null;

    /** Logger used by {@link #displayNGramCounts()}. */
    protected Logger logger = new DummyLogger();

    @Override // edu.northwestern.at.utils.logger.UsesLogger
    public Logger getLogger() {
        return this.logger;
    }

    @Override // edu.northwestern.at.utils.logger.UsesLogger
    public void setLogger(Logger logger) {
        this.logger = logger;
    }

    /**
     * Updates the unique/total statistics for one ngram order and invalidates
     * cached probabilities.
     *
     * @param current   existing count, or null if the ngram is new
     * @param increment amount to add
     * @param order     UNIGRAM, BIGRAM or TRIGRAM
     * @return the new count for the ngram
     */
    private int recordIncrement(Integer current, int increment, int order) {
        if (current == null) {
            this.uniqueNGrams[order]++;
        }
        this.totalNGrams[order] += increment;
        this.haveProbabilities = false;
        return (current == null) ? increment : current.intValue() + increment;
    }

    /**
     * Adds {@code increment} to the count of a unigram.
     *
     * @param item      the unigram
     * @param increment amount to add
     */
    public void incrementCount(String item, int increment) {
        Integer current = this.unigramCountMap.get(item);
        int updated = recordIncrement(current, increment, UNIGRAM);
        this.unigramCountMap.put(item, Integer.valueOf(updated));
        this.totalWords += increment;
    }

    /**
     * Adds {@code increment} to the count of a bigram.
     *
     * @param first     first element
     * @param second    second element
     * @param increment amount to add
     */
    public void incrementCount(String first, String second, int increment) {
        Integer current = this.bigramCountMap.get(first, second);
        int updated = recordIncrement(current, increment, BIGRAM);
        this.bigramCountMap.put(first, second, Integer.valueOf(updated));
    }

    /**
     * Adds {@code increment} to the count of a trigram.
     *
     * @param first     first element
     * @param second    second element
     * @param third     third element
     * @param increment amount to add
     */
    public void incrementCount(String first, String second, String third, int increment) {
        Integer current = this.trigramCountMap.get(first, second, third);
        int updated = recordIncrement(current, increment, TRIGRAM);
        this.trigramCountMap.put(first, second, third, Integer.valueOf(updated));
    }

    /**
     * Returns {@code count / total} as a floating-point ratio, or 0 when total is not positive.
     *
     * @param count numerator
     * @param total denominator
     * @return the ratio, or 0.0 if {@code total <= 0}
     */
    public double safelyDivideCount(int count, int total) {
        // Divide in floating point; integer division would truncate every ratio below 1 to 0.
        return (total > 0) ? (double) count / (double) total : 0.0d;
    }

    /**
     * Returns {@code (count - 1) / (total - 1)} as a floating-point ratio, or 0
     * when {@code total <= 1}.
     *
     * @param count numerator before subtracting one
     * @param total denominator before subtracting one
     * @return the smoothed ratio, or 0.0 if {@code total <= 1}
     */
    public double safelyDivideSmoothedCount(int count, int total) {
        // Subtract in double to avoid both integer truncation and int overflow.
        return (total > 1) ? (count - 1.0d) / (total - 1.0d) : 0.0d;
    }

    /** Recomputes interpolation weights and the probability maps from the current counts. */
    public void calculateProbabilities() {
        computeBigramWeights();
        computeTrigramWeights();
        this.haveProbabilities = true;
    }

    /** Converts a compound key's components to their string forms. */
    private static String[] keyToStrings(CompoundKey key) {
        Comparable[] values = key.getKeyValues();
        String[] parts = new String[values.length];
        for (int k = 0; k < values.length; k++) {
            parts[k] = values[k].toString();
        }
        return parts;
    }

    /**
     * Fills the probability maps from each positive-count trigram (a, b, c) and
     * derives {unigram, bigram, trigram} weights.
     *
     * <p>For each trigram it computes P(c), P(c|b) and P(c|a,b). The trigram's
     * count is credited to whichever estimate is largest; ties favor the lower
     * order. The three totals are then normalized to sum to 1, or left at 0 if
     * there are no trigrams.
     */
    protected void computeTrigramWeights() {
        double unigramWeight = 0.0d;
        double bigramWeight = 0.0d;
        double trigramWeight = 0.0d;
        int totalUnigrams = this.totalNGrams[UNIGRAM];

        Iterator<CompoundKey> keys = this.trigramCountMap.iterator();
        while (keys.hasNext()) {
            String[] parts = keyToStrings(keys.next());
            String first = parts[0];
            String second = parts[1];
            String third = parts[2];

            int trigramCount = getCount(first, second, third);
            if (trigramCount <= 0) {
                continue;
            }

            double unigramProb = safelyDivideCount(getCount(third), totalUnigrams);
            double bigramProb = safelyDivideCount(getCount(second, third), getCount(second));
            double trigramProb = safelyDivideCount(trigramCount, getCount(first, second));

            this.unigramProbMap.put(third, Double.valueOf(unigramProb));
            this.bigramProbMap.put(second, third, Double.valueOf(bigramProb));
            this.trigramProbMap.put(first, second, third, Double.valueOf(trigramProb));

            double best = Math.max(Math.max(unigramProb, bigramProb), trigramProb);
            if (best == unigramProb) {
                unigramWeight += trigramCount;
            } else if (best == bigramProb) {
                bigramWeight += trigramCount;
            } else {
                trigramWeight += trigramCount;
            }
        }

        double sum = unigramWeight + bigramWeight + trigramWeight;
        if (sum > 0.0d) {
            unigramWeight /= sum;
            bigramWeight /= sum;
            trigramWeight /= sum;
        }
        this.trigramWeights = new double[]{unigramWeight, bigramWeight, trigramWeight};
    }

    /**
     * Sets the {unigram, bigram} interpolation weights.
     *
     * <p>The weights are fixed constants. They are not estimated from the
     * counts, so no pass over the bigram map is made.
     */
    protected void computeBigramWeights() {
        this.bigramWeights = new double[]{FIXED_BIGRAM_UNIGRAM_WEIGHT, FIXED_BIGRAM_BIGRAM_WEIGHT};
    }

    /**
     * Returns the count of a unigram.
     *
     * @param item the unigram
     * @return its count, or 0 if unseen
     */
    public int getCount(String item) {
        Integer count = this.unigramCountMap.get(item);
        return (count == null) ? 0 : count.intValue();
    }

    /**
     * Returns the count of a bigram.
     *
     * @param first  first element
     * @param second second element
     * @return its count, or 0 if unseen
     */
    public int getCount(String first, String second) {
        Integer count = this.bigramCountMap.get(first, second);
        return (count == null) ? 0 : count.intValue();
    }

    /**
     * Returns the count of a trigram.
     *
     * @param first  first element
     * @param second second element
     * @param third  third element
     * @return its count, or 0 if unseen
     */
    public int getCount(String first, String second, String third) {
        Integer count = this.trigramCountMap.get(first, second, third);
        return (count == null) ? 0 : count.intValue();
    }

    /** Recomputes probabilities if counts changed since the last computation. */
    private void ensureProbabilities() {
        if (!this.haveProbabilities) {
            calculateProbabilities();
        }
    }

    /**
     * Returns the unigram probability of {@code item}.
     *
     * <p>Within this class the value is only set for items that appear as the
     * last element of a positive-count trigram. It is 0 otherwise.
     *
     * @param item the unigram
     * @return its probability, or 0.0 if none was computed
     */
    public double getProbability(String item) {
        ensureProbabilities();
        Double probability = this.unigramProbMap.get(item);
        return (probability == null) ? 0.0d : probability.doubleValue();
    }

    /**
     * Returns P(second | first).
     *
     * <p>Within this class the value is only set for bigrams that form the last
     * two elements of a positive-count trigram. It is 0 otherwise.
     *
     * @param first  conditioning element
     * @param second predicted element
     * @return the probability, or 0.0 if none was computed
     */
    public double getProbability(String first, String second) {
        ensureProbabilities();
        Double probability = this.bigramProbMap.get(first, second);
        return (probability == null) ? 0.0d : probability.doubleValue();
    }

    /**
     * Returns P(third | first, second).
     *
     * @param first  first conditioning element
     * @param second second conditioning element
     * @param third  predicted element
     * @return the probability, or 0.0 if none was computed
     */
    public double getProbability(String first, String second, String third) {
        ensureProbabilities();
        Double probability = this.trigramProbMap.get(first, second, third);
        return (probability == null) ? 0.0d : probability.doubleValue();
    }

    /** Returns the row (first-position) keys of the trigram count map. */
    public Set<String> rowKeySet() {
        return this.trigramCountMap.rowKeySet();
    }

    /** Returns the column (second-position) keys of the trigram count map. */
    public Set<String> columnKeySet() {
        return this.trigramCountMap.columnKeySet();
    }

    /** Returns the slice (third-position) keys of the trigram count map. */
    public Set<String> sliceKeySet() {
        return this.trigramCountMap.sliceKeySet();
    }

    /** Returns the sum of all unigram increments. */
    public int getTotalWordCount() {
        return this.totalWords;
    }

    /**
     * Loads counts from a URL, optionally gzip-compressed.
     *
     * @param url       source of the matrix text
     * @param gzipped   whether the content is gzip-compressed
     * @param encoding  character encoding passed to UnicodeReader
     * @param separator field separator character
     * @throws IOException if the stream cannot be opened or read
     */
    public void loadTransitionMatrix(URL url, boolean gzipped, String encoding, char separator)
            throws IOException {
        InputStream raw = url.openStream();
        Reader reader = null;
        try {
            InputStream source = gzipped ? new GZIPInputStream(raw) : raw;
            reader = new UnicodeReader(source, encoding);
        } finally {
            // Only close here when wrapping failed; otherwise the reader owns the stream.
            if (reader == null) {
                raw.close();
            }
        }
        loadTransitionMatrix(reader, separator);
    }

    /**
     * Loads counts from an uncompressed URL.
     *
     * @param url       source of the matrix text
     * @param encoding  character encoding passed to UnicodeReader
     * @param separator field separator character
     * @throws IOException if the stream cannot be opened or read
     */
    public void loadTransitionMatrix(URL url, String encoding, char separator) throws IOException {
        loadTransitionMatrix(url, false, encoding, separator);
    }

    /**
     * Reads counts line by line, then closes the reader and computes probabilities.
     *
     * <p>Each line is split on {@code separator}, taken literally rather than as
     * a regex. Lines with 2, 3 or 4 fields are unigram, bigram or trigram counts,
     * with the count in the last field. Any other line is ignored.
     *
     * <p>NOTE(review): the total/unique counters are reset here, but the count
     * maps are not cleared. Loading into a non-empty matrix therefore merges
     * counts while the statistics only reflect new keys.
     *
     * @param reader    source text; closed by this method
     * @param separator field separator character
     * @throws IOException if reading fails
     * @throws NumberFormatException if a count field is not an integer
     */
    public void loadTransitionMatrix(Reader reader, char separator) throws IOException {
        // Quote the separator so characters such as '|' or '.' are not treated as regex syntax.
        String separatorPattern = Pattern.quote(String.valueOf(separator));

        Arrays.fill(this.totalNGrams, 0);
        Arrays.fill(this.uniqueNGrams, 0);
        this.totalWords = 0;

        BufferedReader in = new BufferedReader(reader);
        try {
            String line;
            while ((line = in.readLine()) != null) {
                String[] fields = line.split(separatorPattern);
                switch (fields.length) {
                    case 2:
                        incrementCount(fields[0], Integer.parseInt(fields[1]));
                        break;
                    case 3:
                        incrementCount(fields[0], fields[1], Integer.parseInt(fields[2]));
                        break;
                    case 4:
                        incrementCount(fields[0], fields[1], fields[2], Integer.parseInt(fields[3]));
                        break;
                    default:
                        // Not an ngram line; ignore.
                        break;
                }
            }
        } finally {
            in.close();
        }
        calculateProbabilities();
    }

    /** Formats one row of per-order counts, each right-aligned in 10 columns. */
    private static String formatCountRow(int[] counts) {
        StringBuilder row = new StringBuilder();
        for (int order = UNIGRAM; order <= TRIGRAM; order++) {
            row.append(StringUtils.lpad(Formatters.formatIntegerWithCommas(counts[order]), 10));
        }
        return row.toString();
    }

    /** Logs the total and unique ngram counts at debug level. */
    public void displayNGramCounts() {
        this.logger.logDebug("");
        this.logger.logDebug("Transition matrix total ngram counts");
        this.logger.logDebug("");
        this.logger.logDebug("   Unigram    Bigram   Trigram");
        this.logger.logDebug("");
        this.logger.logDebug(formatCountRow(this.totalNGrams));
        this.logger.logDebug("");
        this.logger.logDebug("Transition matrix unique ngram counts");
        this.logger.logDebug("");
        this.logger.logDebug("   Unigram    Bigram   Trigram");
        this.logger.logDebug("");
        this.logger.logDebug(formatCountRow(this.uniqueNGrams));
    }

    /**
     * Logs the ngram counts and writes the matrix to a file, overwriting it.
     *
     * @param fileName  output file path
     * @param encoding  character encoding name
     * @param separator field separator character
     * @throws IOException if the file cannot be created, the encoding is unsupported, or writing fails
     */
    public void saveTransitionMatrix(String fileName, String encoding, char separator)
            throws IOException {
        displayNGramCounts();
        FileOutputStream out = new FileOutputStream(fileName, false);
        Writer writer;
        try {
            writer = new OutputStreamWriter(out, encoding);
        } catch (IOException e) {
            // An unsupported encoding must not leak the already-opened file.
            out.close();
            throw e;
        }
        saveTransitionMatrix(writer, separator);
    }

    /**
     * Writes every positive-count ngram as one line of separator-joined fields
     * ending with the count, then closes the writer.
     *
     * <p>Output is grouped by sorted unigram key, then sorted bigram column key,
     * then sorted trigram slice key. Only ngrams whose leading elements appear
     * among those keys are reached.
     *
     * @param writer    destination; closed by this method
     * @param separator field separator character
     * @throws IOException if writing fails
     */
    public void saveTransitionMatrix(Writer writer, char separator) throws IOException {
        BufferedWriter out = new BufferedWriter(writer);
        try {
            String[] firsts = (String[]) this.unigramCountMap.keySet().toArray(new String[0]);
            Arrays.sort(firsts);
            String[] seconds = (String[]) this.bigramCountMap.columnKeySet().toArray(new String[0]);
            Arrays.sort(seconds);
            String[] thirds = (String[]) this.trigramCountMap.sliceKeySet().toArray(new String[0]);
            Arrays.sort(thirds);

            for (String first : firsts) {
                int unigramCount = getCount(first);
                if (unigramCount > 0) {
                    out.write(first + separator + unigramCount);
                    out.newLine();
                }
                for (String second : seconds) {
                    int bigramCount = getCount(first, second);
                    if (bigramCount > 0) {
                        out.write(first + separator + second + separator + bigramCount);
                        out.newLine();
                    }
                    for (String third : thirds) {
                        int trigramCount = getCount(first, second, third);
                        if (trigramCount > 0) {
                            out.write(first + separator + second + separator + third + separator + trigramCount);
                            out.newLine();
                        }
                    }
                }
            }
        } finally {
            // BufferedWriter.close() flushes buffered output before closing the target.
            out.close();
        }
    }

    /**
     * Returns the {unigram, bigram} interpolation weights, recomputing them if
     * the counts changed.
     */
    public double[] getBigramWeights() {
        if (!this.haveProbabilities || this.bigramWeights == null) {
            calculateProbabilities();
        }
        return this.bigramWeights;
    }

    /**
     * Returns the {unigram, bigram, trigram} interpolation weights, recomputing
     * them if the counts changed.
     */
    public double[] getTrigramWeights() {
        if (!this.haveProbabilities || this.trigramWeights == null) {
            calculateProbabilities();
        }
        return this.trigramWeights;
    }
}
