package uk.ac.cam.ch.wwmm.oscarMEMM.memm;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Pattern;
import nu.xom.Attribute;
import nu.xom.Builder;
import nu.xom.Document;
import nu.xom.Element;
import nu.xom.Nodes;
import nu.xom.ParsingException;
import nu.xom.ValidityException;
import opennlp.maxent.GIS;
import opennlp.model.Event;
import opennlp.model.EventCollectorAsStream;
import opennlp.model.MaxentModel;
import opennlp.model.TwoPassDataIndexer;
import org.apache.commons.io.IOUtils;
import org.apache.xpath.XPath;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import uk.ac.cam.ch.wwmm.oscar.chemnamedict.core.ChemNameDictRegistry;
import uk.ac.cam.ch.wwmm.oscar.document.NamedEntity;
import uk.ac.cam.ch.wwmm.oscar.document.Token;
import uk.ac.cam.ch.wwmm.oscar.document.TokenSequence;
import uk.ac.cam.ch.wwmm.oscar.document.XOMBasedProcessingDocumentFactory;
import uk.ac.cam.ch.wwmm.oscar.exceptions.DataFormatException;
import uk.ac.cam.ch.wwmm.oscar.tools.StringTools;
import uk.ac.cam.ch.wwmm.oscar.types.BioTag;
import uk.ac.cam.ch.wwmm.oscar.types.BioType;
import uk.ac.cam.ch.wwmm.oscar.types.NamedEntityType;
import uk.ac.cam.ch.wwmm.oscar.xmltools.XOMTools;
import uk.ac.cam.ch.wwmm.oscarMEMM.memm.data.MutableMEMMModel;
import uk.ac.cam.ch.wwmm.oscarMEMM.memm.gis.SimpleEventCollector;
import uk.ac.cam.ch.wwmm.oscarMEMM.memm.rescorer.MEMMOutputRescorer;
import uk.ac.cam.ch.wwmm.oscarMEMM.memm.rescorer.MEMMOutputRescorerTrainer;
import uk.ac.cam.ch.wwmm.oscarMEMM.models.TrainingDataExtractor;
import uk.ac.cam.ch.wwmm.oscarrecogniser.extractedtrainingdata.ExtractedTrainingData;
import uk.ac.cam.ch.wwmm.oscarrecogniser.tokenanalysis.NGramBuilder;
import uk.ac.cam.ch.wwmm.oscartokeniser.HyphenTokeniser;
import uk.ac.cam.ch.wwmm.oscartokeniser.Tokeniser;

/* loaded from: input_file:uk/ac/cam/ch/wwmm/oscarMEMM/memm/MEMMTrainer.class */
/**
 * Trains an MEMM named-entity model from annotated XML training documents.
 *
 * <p>Every token contributes one maxent training event: its features, labelled with the
 * token's {@link BioType}. Events are grouped by the {@link BioType} of the preceding token.
 * {@link #finishTraining()} then trains one GIS model per preceding tag and stores it in the
 * {@link MutableMEMMModel} returned by {@link #getModel()}.
 *
 * <p>Optional refinements, selected by private flags:
 * <ul>
 *   <li>split training: before each fold's events are collected, the extracted training
 *       data is rebuilt from the files outside that fold;</li>
 *   <li>cross-validated feature selection: features whose removal improved held-out
 *       predictions are dropped ({@code trainOnSbFilesWithCVFS});</li>
 *   <li>training of an output rescorer on held-out folds
 *       ({@code trainOnSbFilesWithRescore}).</li>
 * </ul>
 *
 * <p>The trainer holds mutable training state and performs no synchronization.
 */
public final class MEMMTrainer {

    private static final Logger LOG = LoggerFactory.getLogger(MEMMTrainer.class);

    /**
     * Number of folds used by split training. Each fold's events come from files that were
     * excluded when that fold's extracted training data was built.
     */
    private static final int SPLIT_TRAIN_FOLDS = 2;

    /** Number of folds used for cross-validated feature selection and rescorer training. */
    private static final int CV_FOLDS = 3;

    /** Placeholder feature given to tokens that would otherwise have no features at all. */
    private static final String EMPTY_FEATURE = "EMPTY";

    /** Substituted for one feature at a time to measure the effect of removing it. */
    private static final String IGNORED_FEATURE = "IGNORETHIS";

    /**
     * Entity text that begins with a capital letter followed by two lowercase letters and
     * later contains whitespace. When {@code nameTypes} is enabled, such reactions and
     * compounds are relabelled as named types.
     */
    private static final Pattern NAMED_ENTITY_TEXT = Pattern.compile("[A-Z]\\p{Ll}\\p{Ll}.*\\s.*");

    private final MutableMEMMModel model;

    /** Per previous tag: the summed held-out information-gain score of each feature. */
    private Map<BioType, Map<String, Double>> featureCVScores;

    private boolean retrain = true;
    private boolean splitTrain = true;
    private boolean featureSel = true;
    private boolean simpleRescore = true;
    private boolean nameTypes = false;

    /** Training events, keyed by the tag of the token preceding the event's token. */
    Map<BioType, List<Event>> evsByPrev = new HashMap<>();

    /** Per previous tag: features stripped before an event is recorded; {@code null} until computed. */
    private Map<BioType, Set<String>> perniciousFeatures = null;

    private int trainingCycles = 100;
    private int featureCutOff = 1;

    /**
     * Creates a trainer whose model knows every name in the given registry.
     *
     * @param chemNameDictRegistry source of the chemical names passed to the model
     */
    public MEMMTrainer(ChemNameDictRegistry chemNameDictRegistry) {
        this.model = new MutableMEMMModel(Collections.unmodifiableSet(chemNameDictRegistry.getAllNames()));
    }

    /**
     * Records one training event and files it under the preceding token's tag.
     *
     * @param features features of the current token; mutated (pernicious features removed,
     *                 placeholder added when empty)
     * @param tag      the current token's tag (the event's outcome)
     * @param prevTag  the preceding token's tag (selects the event list)
     */
    private void train(FeatureList features, BioType tag, BioType prevTag) {
        if (perniciousFeatures != null && perniciousFeatures.containsKey(prevTag)) {
            features.removeFeatures(perniciousFeatures.get(prevTag));
        }
        if (features.getFeatureCount() == 0) {
            // Never record a featureless event.
            features.addFeature(EMPTY_FEATURE);
        }
        model.getTagSet().add(tag);
        Event event = new Event(tag.toString(), features.toArray());
        List<Event> events = evsByPrev.get(prevTag);
        if (events == null) {
            events = new ArrayList<>();
            evsByPrev.put(prevTag, events);
        }
        events.add(event);
    }

    /**
     * Records one event per token of the sentence. The first token's previous tag is {@code O}.
     */
    private void trainOnSentence(TokenSequence tokenSequence) {
        List<FeatureList> featureLists = FeatureExtractor.extractFeatures(
                tokenSequence, model.getNGram(), model.getChemNameDictNames());
        List<Token> tokens = tokenSequence.getTokens();
        BioType prevTag = new BioType(BioTag.O);
        for (int i = 0; i < tokens.size(); i++) {
            BioType tag = tokens.get(i).getBioType();
            train(featureLists.get(i), tag, prevTag);
            prevTag = tag;
        }
    }

    /** Parses {@code file} as XML and records its training events; the stream is always closed. */
    private void trainOnFile(File file) throws DataFormatException, IOException {
        LOG.debug("Train on: " + file + "... ");
        try (InputStream in = new FileInputStream(file)) {
            trainOnStream(in);
        }
    }

    /**
     * Parses the stream as XML and records its training events.
     *
     * @throws DataFormatException if the stream is not well-formed XML (the parse error is
     *                             kept as the cause)
     */
    private void trainOnStream(InputStream inputStream) throws DataFormatException, IOException {
        long start = System.currentTimeMillis();
        try {
            trainOnDoc(new Builder().build(inputStream));
            LOG.debug("Time: {}", Long.valueOf(System.currentTimeMillis() - start));
        } catch (ParsingException e) {
            throw new DataFormatException("incorrect formatting of training resource", e);
        }
    }

    /**
     * Records training events for every sentence of {@code document}.
     *
     * <p>The document is modified in place: {@code cmlPile} elements and {@code CPR} entity
     * markup are removed, and entity types may be relabelled when {@code nameTypes} is set.
     */
    private void trainOnDoc(Document document) {
        stripIgnoredMarkup(document);
        if (nameTypes) {
            relabelNamedTypes(document);
        }
        Iterator<TokenSequence> sentences = tokenSequences(document);
        while (sentences.hasNext()) {
            trainOnSentence(sentences.next());
        }
    }

    /** Detaches {@code cmlPile} elements and unwraps {@code ne} elements of type CPR (keeping their text). */
    private static void stripIgnoredMarkup(Document document) {
        Nodes cmlPiles = document.query("//cmlPile");
        for (int i = 0; i < cmlPiles.size(); i++) {
            cmlPiles.get(i).detach();
        }
        Nodes cprEntities = document.query("//ne[@type='CPR']");
        for (int i = 0; i < cprEntities.size(); i++) {
            XOMTools.removeElementPreservingText((Element) cprEntities.get(i));
        }
    }

    /**
     * Relabels entities whose text matches {@link #NAMED_ENTITY_TEXT}: reactions become NRN
     * and compounds become NCM. All other compounds become CM. Entities without a type are
     * left alone.
     */
    private void relabelNamedTypes(Document document) {
        Nodes entities = document.query("//ne");
        for (int i = 0; i < entities.size(); i++) {
            Element entity = (Element) entities.get(i);
            String type = entity.getAttributeValue("type");
            String text = entity.getValue();
            String newType = null;
            if (NamedEntityType.REACTION.getName().equals(type)) {
                if (NAMED_ENTITY_TEXT.matcher(text).matches()) {
                    newType = "NRN";
                }
            } else if (NamedEntityType.COMPOUND.getName().equals(type)) {
                newType = NAMED_ENTITY_TEXT.matcher(text).matches() ? "NCM" : "CM";
            }
            if (newType != null) {
                entity.addAttribute(new Attribute("type", newType));
                LOG.debug("{}: {}", newType, text);
            }
        }
    }

    /** Tokenises {@code document} with the default tokeniser and returns its sentences. */
    private static Iterator<TokenSequence> tokenSequences(Document document) {
        return XOMBasedProcessingDocumentFactory.getInstance()
                .makeTokenisedDocument(Tokeniser.getDefaultInstance(), document, true, false)
                .getTokenSequences()
                .iterator();
    }

    /**
     * Rebuilds the model's extracted training data from {@code files}. The hyphen tokeniser is
     * reinitialised before and after the rebuild.
     */
    private void extractTrainingData(Collection<File> files) throws DataFormatException, IOException {
        HyphenTokeniser.reinitialise();
        model.setExtractedTrainingData(
                new ExtractedTrainingData(new TrainingDataExtractor(filesToDocs(files)).toXML()));
        HyphenTokeniser.reinitialise();
    }

    /**
     * Trains on all files. When {@code retrain} is set, the extracted training data is first
     * built from all files.
     */
    private void trainOnSbFilesNosplit(List<File> files) throws DataFormatException, IOException {
        if (retrain) {
            extractTrainingData(files);
        }
        for (File file : files) {
            trainOnFile(file);
        }
        finishTraining();
    }

    /**
     * Parses each file into a XOM document. Files that cannot be read or parsed are logged
     * and skipped.
     */
    @Deprecated
    private Collection<Document> filesToDocs(Collection<File> files) {
        List<Document> docs = new ArrayList<>();
        for (File file : files) {
            try {
                docs.add(new Builder().build(file));
            } catch (IOException | ParsingException e) {
                // Best effort: one bad file should not abort extraction from the others.
                LOG.warn("Skipping training file that could not be parsed: " + file, e);
            }
        }
        return docs;
    }

    /**
     * Distributes {@code files} round-robin over {@code folds} folds. File {@code i} goes to
     * {@code heldOut.get(i % folds)} and to {@code rest.get(f)} for every other fold {@code f}.
     * Both output lists receive {@code folds} new, order-preserving sub-lists.
     */
    private static void splitIntoFolds(List<File> files, int folds,
                                       List<List<File>> heldOut, List<List<File>> rest) {
        for (int fold = 0; fold < folds; fold++) {
            heldOut.add(new ArrayList<File>());
            rest.add(new ArrayList<File>());
        }
        for (int i = 0; i < files.size(); i++) {
            File file = files.get(i);
            int home = i % folds;
            for (int fold = 0; fold < folds; fold++) {
                (fold == home ? heldOut : rest).get(fold).add(file);
            }
        }
    }

    /**
     * Trains on {@code files}.
     *
     * <p>With split training, events for each fold are collected while the model's extracted
     * training data comes from the other fold. This keeps the extracted data from seeing the
     * files being trained on. After training, the extracted data is rebuilt from all files.
     * Without split training this delegates to {@link #trainOnSbFilesNosplit(List)}.
     */
    private void trainOnSbFiles(List<File> files) throws DataFormatException, IOException {
        if (!splitTrain) {
            trainOnSbFilesNosplit(files);
            return;
        }
        List<List<File>> heldOut = new ArrayList<>();
        List<List<File>> rest = new ArrayList<>();
        splitIntoFolds(files, SPLIT_TRAIN_FOLDS, heldOut, rest);
        for (int fold = 0; fold < SPLIT_TRAIN_FOLDS; fold++) {
            if (retrain) {
                extractTrainingData(rest.get(fold));
            }
            for (File file : heldOut.get(fold)) {
                trainOnFile(file);
            }
        }
        finishTraining();
        if (retrain) {
            extractTrainingData(files);
        }
    }

    /**
     * Trains with cross-validated feature selection.
     *
     * <p>For each of {@link #CV_FOLDS} folds, the method trains on the other folds and then
     * scores every feature on the held-out files. Features with a negative total score are
     * marked pernicious, and the final model is trained on all files without them.
     */
    private void trainOnSbFilesWithCVFS(List<File> files) throws DataFormatException, IOException {
        List<List<File>> heldOut = new ArrayList<>();
        List<List<File>> rest = new ArrayList<>();
        splitIntoFolds(files, CV_FOLDS, heldOut, rest);
        for (int fold = 0; fold < CV_FOLDS; fold++) {
            trainOnSbFiles(rest.get(fold));
            evsByPrev.clear();
            for (File file : heldOut.get(fold)) {
                cvFeatures(file);
            }
        }
        findPerniciousFeatures();
        trainOnSbFiles(files);
    }

    /**
     * Trains an output rescorer on held-out folds, installs it in the model, then trains the
     * final model on all files.
     *
     * @param files             training files
     * @param memmModel         model handed to the rescorer trainer
     * @param rescorerParameter passed unchanged to the {@link MEMMOutputRescorerTrainer}
     *                          constructor (NOTE(review): its meaning is not visible here)
     */
    private void trainOnSbFilesWithRescore(List<File> files, MEMMModel memmModel, double rescorerParameter)
            throws Exception {
        MEMMOutputRescorerTrainer rescorerTrainer = new MEMMOutputRescorerTrainer(memmModel, rescorerParameter);
        List<List<File>> heldOut = new ArrayList<>();
        List<List<File>> rest = new ArrayList<>();
        splitIntoFolds(files, CV_FOLDS, heldOut, rest);
        for (int fold = 0; fold < CV_FOLDS; fold++) {
            if (simpleRescore) {
                trainOnSbFiles(rest.get(fold));
            } else {
                trainOnSbFilesWithCVFS(rest.get(fold));
            }
            for (File file : heldOut.get(fold)) {
                rescorerTrainer.trainOnFile(file, memmModel);
            }
            evsByPrev.clear();
            if (!simpleRescore) {
                // Either map may still be null if no held-out token was ever scored.
                if (featureCVScores != null) {
                    featureCVScores.clear();
                }
                if (perniciousFeatures != null) {
                    perniciousFeatures.clear();
                }
            }
        }
        rescorerTrainer.finishTraining();
        MEMMOutputRescorer rescorer = new MEMMOutputRescorer();
        rescorer.readElement(rescorerTrainer.writeElement());
        model.setRescorer(rescorer);
        if (simpleRescore) {
            trainOnSbFiles(files);
        } else {
            trainOnSbFilesWithCVFS(files);
        }
    }

    /**
     * Trains one GIS model per previous tag from the collected events and stores it in the
     * model.
     *
     * <p>If training with {@code featureCutOff} throws, the failure is logged and training is
     * retried with a cutoff of 1.
     */
    private void finishTraining() throws IOException {
        model.makeEntityTypesAndZeroProbs();
        for (Map.Entry<BioType, List<Event>> entry : evsByPrev.entrySet()) {
            BioType prevTag = entry.getKey();
            LOG.debug("tag: {}", prevTag);
            List<Event> events = entry.getValue();
            if (featureSel) {
                events = new FeatureSelector().selectFeatures(events);
            }
            if (events.size() == 1) {
                // Duplicate a lone event, presumably so GIS has more than one event to index
                // (NOTE(review): confirm). Work on a copy so evsByPrev is not altered.
                events = new ArrayList<>(events);
                events.add(events.get(0));
            }
            try {
                model.putGISModel(prevTag, GIS.trainModel(trainingCycles,
                        new TwoPassDataIndexer(new EventCollectorAsStream(new SimpleEventCollector(events)), featureCutOff)));
            } catch (Exception e) {
                LOG.warn("GIS training for previous tag " + prevTag + " failed with cutoff "
                        + featureCutOff + "; retrying with cutoff 1", e);
                model.putGISModel(prevTag, GIS.trainModel(trainingCycles,
                        new TwoPassDataIndexer(new EventCollectorAsStream(new SimpleEventCollector(events)), 1)));
            }
        }
    }

    /**
     * Evaluates {@code maxentModel} on {@code features}. Outcomes the model does not produce
     * keep the model's zero probabilities.
     */
    private Map<BioType, Double> runGIS(MaxentModel maxentModel, String[] features) {
        Map<BioType, Double> probs = new HashMap<>();
        probs.putAll(model.getZeroProbs());
        double[] outcomeProbs = maxentModel.eval(features);
        for (int i = 0; i < outcomeProbs.length; i++) {
            probs.put(BioType.fromString(maxentModel.getOutcome(i)), Double.valueOf(outcomeProbs[i]));
        }
        return probs;
    }

    /**
     * For every known tag that has a model, returns the outcome probabilities given
     * {@code featureList}, keyed by that (previous) tag.
     */
    private Map<BioType, Map<BioType, Double>> calcResults(FeatureList featureList) {
        Map<BioType, Map<BioType, Double>> results = new HashMap<>();
        String[] features = featureList.toArray();
        for (BioType prevTag : model.getTagSet()) {
            MaxentModel maxentModel = model.getMaxentModelByPrev(prevTag);
            if (maxentModel != null) {
                results.put(prevTag, runGIS(maxentModel, features));
            }
        }
        return results;
    }

    /**
     * Adds the feature scores of every sentence in a held-out file to {@code featureCVScores}.
     *
     * @throws DataFormatException if the file is not well-formed XML
     */
    private void cvFeatures(File file) throws IOException, DataFormatException {
        long start = System.currentTimeMillis();
        LOG.debug("Cross-Validate features on: " + file + "... ");
        try {
            Document document = new Builder().build(file);
            stripIgnoredMarkup(document);
            Iterator<TokenSequence> sentences = tokenSequences(document);
            while (sentences.hasNext()) {
                cvFeatures(sentences.next());
            }
            LOG.debug("time: {}", Long.valueOf(System.currentTimeMillis() - start));
        } catch (ParsingException e) {
            throw new DataFormatException("malformed scrapbook file: " + file.getName(), e);
        }
    }

    /** Returns {@code -log2(probs[outcome])}, the information needed to predict that outcome. */
    private static double infoLoss(double[] probs, int outcome) {
        return (-Math.log(probs[outcome])) / Math.log(2.0d);
    }

    /**
     * Scores each feature of each token by ablation. The score is the information loss with
     * the feature replaced by a placeholder, minus the loss with all features present. Scores
     * are summed into {@code featureCVScores} under the previous token's tag. NaN scores
     * count as 0.
     */
    private void cvFeatures(TokenSequence tokenSequence) {
        if (featureCVScores == null) {
            featureCVScores = new HashMap<>();
        }
        List<FeatureList> featureLists = FeatureExtractor.extractFeatures(
                tokenSequence, model.getNGram(), model.getChemNameDictNames());
        List<Token> tokens = tokenSequence.getTokens();
        BioType prevTag = new BioType(BioTag.O);
        for (int i = 0; i < tokens.size(); i++) {
            BioType tag = tokens.get(i).getBioType();
            BioType contextTag = prevTag;
            // Advance unconditionally, as trainOnSentence does. Otherwise a previous tag with no
            // model would be used as the context for the rest of the sentence.
            prevTag = tag;

            MaxentModel maxentModel = model.getMaxentModelByPrev(contextTag);
            if (maxentModel == null) {
                continue;
            }
            Map<String, Double> scores = featureCVScores.get(contextTag);
            if (scores == null) {
                scores = new HashMap<>();
                featureCVScores.put(contextTag, scores);
            }
            int outcome = maxentModel.getIndex(tag.toString());
            if (outcome == -1) {
                continue;
            }
            FeatureList featureList = featureLists.get(i);
            if (featureList.getFeatureCount() == 0) {
                continue;
            }
            String[] context = featureList.toArray();
            String[] ablated = featureList.toArray();
            double baseline = infoLoss(maxentModel.eval(context), outcome);
            for (int f = 0; f < featureList.getFeatureCount(); f++) {
                ablated[f] = IGNORED_FEATURE;
                double gain = infoLoss(maxentModel.eval(ablated), outcome) - baseline;
                if (Double.isNaN(gain)) {
                    gain = 0.0d;
                }
                String feature = featureList.getFeature(f);
                Double previous = scores.get(feature);
                scores.put(feature, Double.valueOf(gain + (previous == null ? 0.0d : previous.doubleValue())));
                ablated[f] = context[f];
            }
        }
    }

    /**
     * Marks as pernicious every feature with a negative cross-validation score, meaning the
     * held-out prediction was better without it. Results are stored per previous tag. If no
     * scores were collected, the result is an empty map.
     */
    private void findPerniciousFeatures() {
        perniciousFeatures = new HashMap<>();
        if (featureCVScores == null) {
            return;
        }
        for (Map.Entry<BioType, Map<String, Double>> entry : featureCVScores.entrySet()) {
            BioType prevTag = entry.getKey();
            Map<String, Double> scores = entry.getValue();
            Set<String> pernicious = new HashSet<>();
            perniciousFeatures.put(prevTag, pernicious);
            for (String feature : StringTools.getSortedKeyList(scores)) {
                double score = scores.get(feature).doubleValue();
                if (score < 0.0d) {
                    LOG.debug("Removing:\t" + prevTag + "\t" + feature + "\t" + score);
                    pernicious.add(feature);
                }
            }
        }
    }

    /** Rescores {@code list} using the model's rescorer and chemical-name dictionary names. */
    public void rescore(List<NamedEntity> list) {
        model.getRescorer().rescore(list, model.getChemNameDictNames());
    }

    /** Returns the model being trained. */
    public MEMMModel getModel() {
        return model;
    }

    /**
     * Trains the model on already-parsed documents. Steps:
     * <ol>
     *   <li>Build the extracted training data and the n-gram model from {@code list}.</li>
     *   <li>Record events from a copy of each document; the inputs are not modified.</li>
     *   <li>Train the GIS models.</li>
     * </ol>
     */
    public void trainOnDocs(List<Document> list) throws IOException {
        model.setExtractedTrainingData(new ExtractedTrainingData(new TrainingDataExtractor(list).toXML()));
        model.nGram = NGramBuilder.buildOrDeserialiseModel(model.etd, model.chemNameDictNames);
        for (Document document : list) {
            trainOnDoc((Document) document.copy());
        }
        finishTraining();
    }
}
