package uk.ac.cam.ch.wwmm.oscarrecogniser.tokenanalysis;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.regex.Pattern;
import org.apache.xpath.XPath;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import uk.ac.cam.ch.wwmm.oscar.chemnamedict.core.ChemNameDictRegistry;
import uk.ac.cam.ch.wwmm.oscar.terms.TermSets;
import uk.ac.cam.ch.wwmm.oscarrecogniser.extractedtrainingdata.ExtractedTrainingData;

/* loaded from: input_file:uk/ac/cam/ch/wwmm/oscarrecogniser/tokenanalysis/NGramBuilder.class */
public class NGramBuilder {
    /** Factor each log-probability difference is multiplied by before toNGram() rounds it into a short. */
    static final double SCALE = 500.0d;
    /** Symbol inventory; a character's position in this string is its index into every count and probability table. */
    static final String ALPHABET = "$^S0%<>&'()*+,-./:;=?@[]abcdefghijklmnopqrstuvwxyz|~";
    // 1- to 4-gram occurrence counts gathered from chemWords; allocated in train() and set back to null once LP4C exists.
    private int[] C1C;
    private int[][] C2C;
    private int[][][] C3C;
    private int[][][][] C4C;
    // 1- to 4-gram occurrence counts gathered from englishWords; same lifecycle as the C*C tables.
    private int[] E1C;
    private int[][] E2C;
    private int[][][] E3C;
    private int[][][][] E4C;
    // Results of calcLP4() for the chemical (LP4C) and English (LP4E) counts; populated by train().
    private double[][][][] LP4C;
    private double[][][][] LP4E;
    // NOTE(review): extraChemical, extraEnglish and extraOnly are never assigned in this class, so they keep
    // their defaults (null / false) - confirm whether setters were lost or the fields are obsolete.
    private Collection<String> extraChemical;
    private Collection<String> extraEnglish;
    private boolean extraOnly;
    // Training tokens accepted by readCollection(), in the order they were read (duplicates are kept).
    private List<String> chemWords;
    private List<String> englishWords;
    // NGram.parseWord() forms of the training words, filled by train(); nothing in this class reads them back.
    private Set<String> chemSet;
    private Set<String> engSet;
    // Optional extracted training data; null is allowed and makes the builder skip that source.
    private ExtractedTrainingData etd;
    // Names handed to the constructor; read as chemical training data and used to filter the user dictionary words.
    private Set<String> registryNames;
    private static final Logger LOG = LoggerFactory.getLogger(NGramBuilder.class);
    // Separator used to break a training entry into individual tokens.
    private static final Pattern matchWhiteSpace = Pattern.compile("\\s+");
    // A token is only used for training when it contains two lower-case ASCII letters next to each other.
    private static final Pattern matchTwoOrMoreAdjacentLetters = Pattern.compile(".*[a-z][a-z].*");

    /**
     * Creates a builder and immediately collects its training vocabulary via
     * {@code readTrainingData()}. No n-gram statistics are computed until {@code train()} runs.
     *
     * @param extractedTrainingData additional training words, or {@code null} to skip that source
     * @param set names treated as chemical training data and used to filter the user dictionary words
     */
    NGramBuilder(ExtractedTrainingData extractedTrainingData, Set<String> set) {
        this.etd = extractedTrainingData;
        this.registryNames = set;
        // Typed instantiation instead of the raw ArrayList the original used.
        this.chemWords = new ArrayList<String>();
        this.englishWords = new ArrayList<String>();
        readTrainingData();
    }

    /**
     * Creates a builder without extracted training data, using an unmodifiable view of all names
     * known to the default {@link ChemNameDictRegistry}.
     */
    NGramBuilder() {
        this(null, Collections.unmodifiableSet(ChemNameDictRegistry.getDefaultInstance().getAllNames()));
    }

    /**
     * Counts 1- to 4-grams over the collected English and chemical words and converts both sets of
     * counts into log-probability tables ({@code LP4E} and {@code LP4C}).
     *
     * <p>The count tables are only needed while this method runs; they are released at the end
     * because the 4-gram table alone holds {@code ALPHABET.length()^4} entries.</p>
     */
    private void train() {
        int length = ALPHABET.length();
        this.C1C = new int[length];
        this.C2C = new int[length][length];
        this.C3C = new int[length][length][length];
        this.C4C = new int[length][length][length][length];
        this.E1C = new int[length];
        this.E2C = new int[length][length];
        this.E3C = new int[length][length][length];
        this.E4C = new int[length][length][length][length];

        // English words are processed before chemical words, as in the original.
        this.engSet = new HashSet<String>();
        for (String englishWord : this.englishWords) {
            addEngNGrams(englishWord);
            this.engSet.add(NGram.parseWord(englishWord));
        }
        this.chemSet = new HashSet<String>();
        for (String chemicalWord : this.chemWords) {
            addChemNGrams(chemicalWord);
            this.chemSet.add(NGram.parseWord(chemicalWord));
        }

        this.LP4C = calcLP4(this.C1C, this.C2C, this.C3C, this.C4C);
        this.LP4E = calcLP4(this.E1C, this.E2C, this.E3C, this.E4C);

        // Drop the raw counts so they can be garbage collected.
        this.C1C = null;
        this.C2C = null;
        this.C3C = null;
        this.C4C = null;
        this.E1C = null;
        this.E2C = null;
        this.E3C = null;
        this.E4C = null;
        LOG.debug("nGrams initialised");
    }

    /**
     * @return the log-probability table built from the chemical counts, or {@code null} if
     *         {@code train()} has not run yet
     */
    double[][][][] getLP4C() {
        return LP4C;
    }

    /**
     * @return the log-probability table built from the English counts, or {@code null} if
     *         {@code train()} has not run yet
     */
    double[][][][] getLP4E() {
        return LP4E;
    }

    /**
     * @return the builder's own list of English training tokens (not a copy)
     */
    public List<String> getEnglishWords() {
        return englishWords;
    }

    /**
     * @return the builder's own list of chemical training tokens (not a copy)
     */
    public List<String> getChemicalWords() {
        return chemWords;
    }

    /**
     * Fills the word lists from every training source. When {@code extraOnly} is set, only the
     * extra collections are read; otherwise the standard sources are read first and the extra
     * collections last.
     */
    private void readTrainingData() {
        if (this.extraOnly) {
            readExtraTrainingData();
            return;
        }
        readStopWordsTrainingData();
        if (this.etd != null) {
            // Extracted data is optional.
            readExtractedTrainingData();
        }
        readChemNameDictTrainingData();
        readElementsTrainingData();
        readUdwTrainingData();
        readAseTrainingData();
        readExtraTrainingData();
    }

    /**
     * Splits every entry of {@code collection} on whitespace and records each token that contains
     * two adjacent lower-case letters as a training word.
     *
     * @param collection entries to read; {@code null} is treated as "nothing to read"
     * @param z {@code true} to record tokens as chemical words, {@code false} for English words
     */
    private void readCollection(Collection<String> collection, boolean z) {
        if (collection != null) {
            for (String entry : collection) {
                for (String token : matchWhiteSpace.split(entry)) {
                    if (!matchTwoOrMoreAdjacentLetters.matcher(token).matches()) {
                        continue;
                    }
                    if (z) {
                        addChemical(token);
                    } else {
                        addEnglish(token);
                    }
                }
            }
        }
    }

    /** Reads the default term sets' stop words as English training words. */
    private void readStopWordsTrainingData() {
        final boolean chemical = false;
        TermSets terms = TermSets.getDefaultInstance();
        readCollection(terms.getStopWords(), chemical);
    }

    /** Reads the names supplied at construction time as chemical training words. */
    private void readChemNameDictTrainingData() {
        final boolean chemical = true;
        readCollection(registryNames, chemical);
    }

    /** Reads the default term sets' element entries as chemical training words. */
    private void readElementsTrainingData() {
        final boolean chemical = true;
        TermSets terms = TermSets.getDefaultInstance();
        readCollection(terms.getElements(), chemical);
    }

    /**
     * Reads the default term sets' user dictionary words as English training words, leaving out
     * any word that is also a registry name or one of the extracted chemical words.
     *
     * <p>A {@code null} registry-name set or missing extracted data simply means "nothing to
     * exclude" from that source, matching how {@code readCollection} tolerates {@code null}.</p>
     */
    private void readUdwTrainingData() {
        // Looked up once rather than once per dictionary word.
        Collection<String> extractedChemical = this.etd == null ? null : this.etd.getChemicalWords();
        Set<String> plainWords = new HashSet<String>();
        for (String word : TermSets.getDefaultInstance().getUsrDictWords()) {
            boolean inRegistry = this.registryNames != null && this.registryNames.contains(word);
            boolean inExtracted = extractedChemical != null && extractedChemical.contains(word);
            if (!inRegistry && !inExtracted) {
                plainWords.add(word);
            }
        }
        readCollection(plainWords, false);
    }

    /**
     * Reads the extracted training data: its chemical words first, then its non-chemical words.
     * Callers must make sure {@code etd} is not {@code null}.
     */
    private void readExtractedTrainingData() {
        ExtractedTrainingData data = this.etd;
        readCollection(data.getChemicalWords(), true);
        readCollection(data.getNonChemicalWords(), false);
    }

    /**
     * Reads the default term sets' "chemAses" entries as chemical words and the "nonChemAses"
     * entries as English words.
     */
    private void readAseTrainingData() {
        TermSets terms = TermSets.getDefaultInstance();
        readCollection(terms.getChemAses(), true);
        readCollection(terms.getNonChemAses(), false);
    }

    /** Reads the optional extra collections: chemical entries first, then English entries. */
    private void readExtraTrainingData() {
        final boolean chemical = true;
        readCollection(extraChemical, chemical);
        readCollection(extraEnglish, !chemical);
    }

    /**
     * Converts raw n-gram counts into a table of smoothed, natural-log 4-gram probabilities.
     *
     * <p>The orders are built bottom-up: unigram probabilities first, then bigram, trigram and
     * 4-gram probabilities. Every conditional row is a discounted relative frequency plus a
     * back-off weight times the next-lower-order probability; a row without any observation is a
     * plain copy of the lower-order row.</p>
     *
     * <p>All arrays are expected to be {@code ALPHABET.length()} long in every dimension, which is
     * how {@code train()} allocates them.</p>
     *
     * @param iArr unigram counts, indexed by symbol
     * @param iArr2 bigram counts, indexed [previous][current]
     * @param iArr3 trigram counts, last index being the current symbol
     * @param iArr4 4-gram counts, last index being the current symbol
     * @return table with the same indexing as {@code iArr4} holding {@code Math.log} of the
     *         smoothed probability of the last symbol given the three before it
     */
    private double[][][][] calcLP4(int[] iArr, int[][] iArr2, int[][][] iArr3, int[][][][] iArr4) {
        final int size = ALPHABET.length();

        // ---- Unigrams: seen symbols get count / (total + seen); the mass seen / (total + seen)
        // ---- is split evenly over the symbols that never occurred.
        int totalCount = 0;
        int seenSymbols = 0;
        int unseenSymbols = 0;
        for (int a = 0; a < size; a++) {
            totalCount += iArr[a];
            if (iArr[a] > 0) {
                seenSymbols++;
            } else {
                unseenSymbols++;
            }
        }
        final double[] p1 = new double[size];
        for (int a = 0; a < size; a++) {
            if (iArr[a] > 0) {
                p1[a] = (1.0d * iArr[a]) / (1.0d * (totalCount + seenSymbols));
            } else {
                p1[a] = (1.0d * seenSymbols) / ((1.0d * unseenSymbols) * (totalCount + seenSymbols));
            }
        }

        // ---- Bigrams, backing off to unigrams.
        final int[] tally2 = new int[5];
        for (int a = 0; a < size; a++) {
            tallyCountsOfCounts(iArr2[a], size, tally2);
        }
        final double[] discounts2 = estimateDiscounts(tally2);
        final double[][] p2 = new double[size][size];
        for (int a = 0; a < size; a++) {
            smoothRow(iArr2[a], discounts2, p1, p2[a]);
        }

        // ---- Trigrams, backing off to the bigram row of the two most recent symbols.
        final int[] tally3 = new int[5];
        for (int a = 0; a < size; a++) {
            for (int b = 0; b < size; b++) {
                tallyCountsOfCounts(iArr3[a][b], size, tally3);
            }
        }
        final double[] discounts3 = estimateDiscounts(tally3);
        final double[][][] p3 = new double[size][size][size];
        for (int a = 0; a < size; a++) {
            for (int b = 0; b < size; b++) {
                smoothRow(iArr3[a][b], discounts3, p2[b], p3[a][b]);
            }
        }

        // ---- 4-grams, backing off to the trigram row of the three most recent symbols.
        final int[] tally4 = new int[5];
        for (int a = 0; a < size; a++) {
            for (int b = 0; b < size; b++) {
                for (int c = 0; c < size; c++) {
                    tallyCountsOfCounts(iArr4[a][b][c], size, tally4);
                }
            }
        }
        final double[] discounts4 = estimateDiscounts(tally4);
        final double[][][][] p4 = new double[size][size][size][size];
        for (int a = 0; a < size; a++) {
            for (int b = 0; b < size; b++) {
                for (int c = 0; c < size; c++) {
                    smoothRow(iArr4[a][b][c], discounts4, p3[b][c], p4[a][b][c]);
                }
            }
        }

        // ---- Convert the finished table to natural logarithms in place.
        for (int a = 0; a < size; a++) {
            for (int b = 0; b < size; b++) {
                for (int c = 0; c < size; c++) {
                    final double[] row = p4[a][b][c];
                    for (int d = 0; d < size; d++) {
                        row[d] = Math.log(row[d]);
                    }
                }
            }
        }
        return p4;
    }

    /** Adds to tally[k], for k = 1..4, the number of the first {@code size} entries of {@code counts} equal to k. */
    private static void tallyCountsOfCounts(int[] counts, int size, int[] tally) {
        for (int i = 0; i < size; i++) {
            final int count = counts[i];
            if (count >= 1 && count <= 4) {
                tally[count]++;
            }
        }
    }

    /**
     * Derives the three discounts (for counts of 1, of 2, and of 3 or more) from the counts-of-counts.
     *
     * <p>NOTE(review): this resembles the count-dependent discounts of Chen and Goodman's modified
     * Kneser-Ney smoothing, but all three values use the form {@code 1 - 2*Y*n(k+1)/n(k)}, whereas
     * that scheme uses {@code 2 - 3*Y*...} and {@code 3 - 4*Y*...} for the second and third.
     * Confirm the intent before changing anything: any change alters every trained model.</p>
     */
    private static double[] estimateDiscounts(int[] tally) {
        final int n1 = tally[1];
        final int n2 = tally[2];
        final int n3 = tally[3];
        final int n4 = tally[4];
        final double y = (1.0d * n1) / (1.0d * (n1 + (2 * n2)));
        return new double[] {
            1.0d - (((2.0d * y) * n2) / (1.0d * n1)),
            1.0d - (((2.0d * y) * n3) / (1.0d * n2)),
            1.0d - (((2.0d * y) * n4) / (1.0d * n3))
        };
    }

    /** Fills {@code out} with the discounted relative frequencies of {@code counts} interpolated with {@code lower}. */
    private static void smoothRow(int[] counts, double[] discounts, double[] lower, double[] out) {
        final int size = out.length;
        int ones = 0;
        int twos = 0;
        int threeOrMore = 0;
        double rowTotal = 0.0d;
        for (int i = 0; i < size; i++) {
            final int count = counts[i];
            rowTotal += count;
            if (count == 1) {
                ones++;
            } else if (count == 2) {
                twos++;
            } else if (count >= 3) {
                threeOrMore++;
            }
        }
        if (rowTotal > 0.0d) {
            // Share of the row's mass removed by discounting; it is handed to the lower-order row.
            final double backoff =
                (((discounts[0] * ones) + (discounts[1] * twos)) + (discounts[2] * threeOrMore)) / rowTotal;
            for (int i = 0; i < size; i++) {
                final int count = counts[i];
                final double discount =
                    count == 0 ? 0.0d : count == 1 ? discounts[0] : count == 2 ? discounts[1] : discounts[2];
                out[i] = ((count - discount) / rowTotal) + (backoff * lower[i]);
            }
        } else {
            // Context never observed: fall back entirely to the lower-order distribution.
            for (int i = 0; i < size; i++) {
                out[i] = lower[i];
            }
        }
    }

    /** Appends {@code str} to the chemical training words. */
    private void addChemical(String str) {
        chemWords.add(str);
    }

    /** Appends {@code str} to the English training words. */
    private void addEnglish(String str) {
        englishWords.add(str);
    }

    /** Adds the n-grams of {@code str} to the English count tables. */
    private void addEngNGrams(String str) {
        addWordNGrams(str, E1C, E2C, E3C, E4C);
    }

    /** Adds the n-grams of {@code str} to the chemical count tables. */
    private void addChemNGrams(String str) {
        addWordNGrams(str, C1C, C2C, C3C, C4C);
    }

    /**
     * Increments the 1- to 4-gram counters for every position of one word.
     *
     * <p>The word is first passed through {@code NGram.parseWord}; results shorter than two
     * characters are ignored. The remainder is wrapped with {@code NGram.addStartAndEnd} and
     * walked left to right while remembering the three preceding symbol indices.</p>
     *
     * @param str the training word
     * @param iArr unigram counters
     * @param iArr2 bigram counters
     * @param iArr3 trigram counters
     * @param iArr4 4-gram counters
     */
    private void addWordNGrams(String str, int[] iArr, int[][] iArr2, int[][][] iArr3, int[][][][] iArr4) {
        final String normalised = NGram.parseWord(str);
        if (normalised.length() < 2) {
            return;
        }
        final String padded = NGram.addStartAndEnd(normalised);

        // Symbol indices one, two and three positions behind the current one. They start at 0,
        // and a slot is only used once enough characters have been consumed to fill it.
        int back3 = 0;
        int back2 = 0;
        int back1 = 0;
        int current = 0;
        for (int pos = 0, end = padded.length(); pos < end; pos++) {
            back3 = back2;
            back2 = back1;
            back1 = current;
            current = ALPHABET.indexOf(padded.charAt(pos));

            iArr[current]++;
            if (pos == 0) {
                continue;
            }
            iArr2[back1][current]++;
            if (pos == 1) {
                continue;
            }
            iArr3[back2][back1][current]++;
            if (pos == 2) {
                continue;
            }
            iArr4[back3][back2][back1][current]++;
        }
    }

    /**
     * Trains a model without extracted training data, using all names of the default
     * {@link ChemNameDictRegistry}.
     *
     * @return the freshly trained model
     */
    public static NGram buildModel() {
        Set<String> allNames = Collections.unmodifiableSet(ChemNameDictRegistry.getDefaultInstance().getAllNames());
        return buildModel(null, allNames);
    }

    /**
     * Trains a model from the given sources.
     *
     * @param extractedTrainingData additional training words, may be {@code null}
     * @param set names to use as chemical training data
     * @return the freshly trained model
     */
    public static NGram buildModel(ExtractedTrainingData extractedTrainingData, Set<String> set) {
        final NGramBuilder builder = new NGramBuilder(extractedTrainingData, set);
        builder.train();
        return builder.toNGram();
    }

    /**
     * Packs the trained tables into an {@link NGram}.
     *
     * <p>For every 4-gram the difference {@code LP4C - LP4E} is multiplied by {@link #SCALE},
     * clamped to the {@code short} range (with a warning when clamping happens), rounded, and
     * stored in a flat array ordered with the first index slowest and the last index fastest.</p>
     *
     * @return a model wrapping the packed scores
     */
    private NGram toNGram() {
        final int size = this.LP4C.length;
        final short[] packed = new short[size * size * size * size];
        // The loops visit entries in exactly the flat array's order, so a running index suffices.
        int next = 0;
        for (int a = 0; a < size; a++) {
            for (int b = 0; b < size; b++) {
                for (int c = 0; c < size; c++) {
                    final double[] chemRow = this.LP4C[a][b][c];
                    final double[] engRow = this.LP4E[a][b][c];
                    for (int d = 0; d < size; d++) {
                        double scaled = SCALE * (chemRow[d] - engRow[d]);
                        if (scaled > 32767.0d) {
                            LOG.warn("Warning: upper bound exceeded - {}", scaled);
                            scaled = 32767.0d;
                        } else if (scaled < -32768.0d) {
                            LOG.warn("Warning: lower bound exceeded - {}", scaled);
                            scaled = -32768.0d;
                        }
                        packed[next++] = (short) Math.round(scaled);
                    }
                }
            }
        }
        return new NGram(packed);
    }

    /**
     * Builds an identifier for the current training vocabulary: the English word hash and the
     * chemical word hash joined by an underscore.
     *
     * @return the fingerprint, also used as part of the serialised model's file name
     */
    public String calculateSourceDataFingerprint() {
        final int englishHash = calculateOrderInvariantHashCode(this.englishWords);
        final int chemicalHash = calculateOrderInvariantHashCode(this.chemWords);
        return englishHash + "_" + chemicalHash;
    }

    /**
     * XORs the {@code hashCode()} of every element, so the result does not depend on list order.
     *
     * <p>Because XOR cancels itself, an element that occurs an even number of times contributes
     * nothing to the result.</p>
     *
     * @param list the strings to hash; must not contain {@code null}
     * @return the combined hash, {@code 0} for an empty list
     */
    private int calculateOrderInvariantHashCode(List<String> list) {
        int combined = 0;
        for (String word : list) {
            combined ^= word.hashCode();
        }
        return combined;
    }

    /**
     * Builds the resource path of the serialised model for a fingerprint.
     *
     * @param str the training-data fingerprint
     * @return the path {@code .../ngram-model<str>.dat.gz} inside this class's package directory
     */
    private static String getModelFileLocation(String str) {
        StringBuilder location = new StringBuilder("uk/ac/cam/ch/wwmm/oscarrecogniser/tokenanalysis/ngram-model");
        location.append(str);
        location.append(".dat.gz");
        return location.toString();
    }

    /**
     * Loads the serialised model that belongs to the given fingerprint.
     *
     * @param str the training-data fingerprint
     * @return the model returned by {@code NGram.loadModel}
     * @throws IOException if {@code NGram.loadModel} fails
     */
    static NGram deserialiseModel(String str) throws IOException {
        final String location = getModelFileLocation(str);
        return NGram.loadModel(location);
    }

    /**
     * Loads or trains a model without extracted training data, using all names of the default
     * {@link ChemNameDictRegistry}.
     *
     * @return the loaded or freshly trained model
     */
    public static NGram buildOrDeserialiseModel() {
        return buildOrDeserialiseModel(null, Collections.unmodifiableSet(ChemNameDictRegistry.getDefaultInstance().getAllNames()));
    }

    /**
     * Returns the serialised model matching the training data's fingerprint if it can be loaded,
     * and trains a new model otherwise.
     *
     * @param extractedTrainingData additional training words, may be {@code null}
     * @param set names to use as chemical training data
     * @return the loaded or freshly trained model
     */
    public static NGram buildOrDeserialiseModel(ExtractedTrainingData extractedTrainingData, Set<String> set) {
        NGramBuilder builder = new NGramBuilder(extractedTrainingData, set);
        String fingerprint = builder.calculateSourceDataFingerprint();
        try {
            return deserialiseModel(fingerprint);
        } catch (IOException e) {
            // Falling back is intentional, but record why so an unexpectedly slow start-up can be diagnosed.
            LOG.debug("Could not load serialised n-gram model for fingerprint " + fingerprint
                    + "; training a new model", e);
            builder.train();
            return builder.toNGram();
        }
    }

    /**
     * Command-line entry point: trains a model from the "chempapers" extracted training data plus
     * the default registry names and writes it below {@code src/main/resources/}, named after the
     * training data's fingerprint.
     *
     * @param strArr ignored
     * @throws IOException if the training data cannot be loaded or the model cannot be written
     */
    public static void main(String[] strArr) throws IOException {
        ChemNameDictRegistry defaultInstance = ChemNameDictRegistry.getDefaultInstance();
        ExtractedTrainingData trainingData = ExtractedTrainingData.loadExtractedTrainingData("chempapers");
        System.out.println("building ngrams...");
        NGramBuilder builder = new NGramBuilder(trainingData, Collections.unmodifiableSet(defaultInstance.getAllNames()));
        builder.train();
        NGram model = builder.toNGram();
        System.out.println("serialising data...");
        File target = new File("src/main/resources/" + getModelFileLocation(builder.calculateSourceDataFingerprint()));
        FileOutputStream out = new FileOutputStream(target);
        try {
            model.saveData(out);
        } finally {
            // Whether saveData closes the stream is not visible here; closing a
            // FileOutputStream a second time has no effect, so always release it.
            out.close();
        }
        System.out.println("...done!");
    }
}
