% Simple leave-one-out cross-validation
classdef janis_validatorLeaveOneOut < janis_validator
    % JANIS_VALIDATORLEAVEONEOUT Leave-one-out cross-validation.
    %   Runs one round per pattern. In round r, pattern r is held out as
    %   the test set and all other patterns are used for training. The
    %   outcome of every round is accumulated in the classperf object
    %   stored in classPerformance.

    methods
        % Constructor
        function obj = janis_validatorLeaveOneOut(classCollection, classifier, patternParser)
            % Create a leave-one-out validator.
            %   classCollection - must be a janis_classCollection; its
            %                     patternCount sets the number of rounds.
            %   classifier      - must be a janis_classifier.
            %   patternParser   - forwarded to the janis_validator
            %                     superclass constructor.
            % Throws MException 'janis:WrongClass' when classCollection or
            % classifier is of the wrong class.
            obj = obj@janis_validator(classCollection, classifier, patternParser);
            if ~isa(classCollection, 'janis_classCollection')
                throw(MException('janis:WrongClass','This is not a classCollection-Class!'));
            elseif ~isa(classifier, 'janis_classifier')
                throw(MException('janis:WrongClass','This is not a classifier-Class!'));
            end
            % Every pattern is held out exactly once, so there is one round
            % per pattern.
            obj.rounds = classCollection.patternCount;
        end

        % leave one out cross validation
        function results = runValidation(jv)
            % Run all leave-one-out rounds.
            %   For each round r, pattern r is held out; computeForSubjects
            %   trains and tests, then actualRound is set to r, the 'cycle'
            %   event is fired, and the round's LastCorrectRate is appended
            %   to roundPerformances.
            % Returns the classperf object holding all rounds.

            % Clear per-run accumulators so that calling runValidation
            % again does not mix results of a previous run with the fresh
            % classperf object created below.
            jv.roundPerformances = [];
            jv.predicted = [];
            jv.decFactors = [];

            jv.classPerformance = classperf(jv.patternParser.getLabels);
            for r = 1:jv.rounds
                % Logical row mask selecting only pattern r as test data.
                testIdx = (1:jv.rounds) == r;
                jv = jv.computeForSubjects(testIdx);
                jv.actualRound = r;
                notify(jv, 'cycle');
                jv.roundPerformances = [jv.roundPerformances; jv.classPerformance.LastCorrectRate];
            end
            results = jv.classPerformance;
        end

        function clStruct = getClassifierForSubject(jv, subject)
            % Return the classifier output for the fold that holds out SUBJECT.
            %   subject - index (1..rounds) of the held-out pattern.
            % The remaining patterns and their labels are fetched from the
            % pattern parser and passed to the classifier's getClassifier.
            trainIdx = (1:jv.rounds) ~= subject;
            [patternMat, trainLabels] = jv.patternParser.get(trainIdx);
            clStruct = jv.classifier.getClassifier(patternMat, trainLabels);
        end

        function jv = computeForSubjects(jv, testIdx)
            % Train on the non-test patterns, test on the held-out ones and
            % record the outcome.
            %   testIdx - mask marking test patterns (runValidation passes a
            %             logical row vector).
            % Side effects on jv: replaces classifier with the trained one,
            % updates classPerformance via classperf, appends the transposed
            % predictions to predicted, and appends the decision factors
            % (with the true test labels stored in field 'label') to
            % decFactors.
            % NOTE(review): assumes patternParser.get(mask) returns training
            % data for the selected patterns and test data for the rest -
            % confirm against janis pattern parser.
            [patternMat, trainLabels, testMat, testLabels] = jv.patternParser.get(testIdx == 0);
            jv.classifier = jv.classifier.train(patternMat, trainLabels);
            [prediction, decFactors] = jv.classifier.test(testMat);
            jv.classPerformance = classperf(jv.classPerformance, prediction, testIdx);
            jv.predicted = [jv.predicted; prediction'];
            decFactors.label = testLabels;
            jv.decFactors = [jv.decFactors decFactors];
        end
    end
end