% Leave One Out PER GROUP Cross Validation, Matched:
%    Assumes that the n-th subject of each group is paired with the n-th
%    subject of every other group, so all groups must contain the same
%    number of subjects.
%    In round n, the n-th subject of each group forms the test set; the
%    remaining subjects form the training set.
classdef janis_validatorMatched < janis_validator
    % Matched leave-one-out-per-group validator: in round r the r-th
    % subject of every group is held out for testing, all other subjects
    % are used for training.

    methods

        % Constructor.
        %   classCollection - janis_classCollection whose groups are matched
        %                     subject-by-subject.
        %   classifier      - janis_classifier trained/tested in each round.
        %   patternParser   - forwarded to the janis_validator constructor.
        % Throws 'janis:WrongClass' if an argument has the wrong type and
        % 'janis:WrongMatching' if the groups differ in subject count or the
        % pattern count is not divisible by the number of groups.
        function obj = janis_validatorMatched(classCollection, classifier, patternParser)
            obj = obj@janis_validator(classCollection, classifier, patternParser);
            if ~isa(classCollection, 'janis_classCollection')
                throw(MException('janis:WrongClass','This is not a classCollection-Class!'));
            end
            % Matching needs EVERY group to have the same subject count as
            % the first one (not just groups 1 and 2), because the test mask
            % is replicated over all groups.
            firstGroupSize = classCollection.get(1).size;
            for k = 2:classCollection.size
                if classCollection.get(k).size ~= firstGroupSize
                    throw(MException('janis:WrongMatching','class should be matched! But cannot, because subject count differs'));
                end
            end
            if ~isa(classifier, 'janis_classifier')
                throw(MException('janis:WrongClass','This is not a classifier-Class!'));
            end
            if mod(classCollection.patternCount, classCollection.size) ~= 0
                throw(MException('janis:WrongMatching','Uhm, there seems to be something wrong with the number of subjects...'));
            end
            % One round per matched tuple, i.e. the number of subjects in
            % each group (patternCount divided by the number of groups).
            obj.rounds = classCollection.patternCount / classCollection.size;
        end

        % Runs all rounds of the matched validation.
        % Each round trains on every subject except the r-th of each group,
        % tests on the held-out ones, fires the 'cycle' event and appends
        % the round's LastCorrectRate to roundPerformances.
        % Returns the accumulated classperf object.
        function results = runValidation(jv)
            jv.classPerformance = classperf(jv.patternParser.getLabels);
            for r = 1:jv.rounds
                testIdx = jv.buildTestMask(r);
                jv = jv.computeForSubjects(testIdx);
                jv.actualRound = r;
                notify(jv, 'cycle');
                jv.roundPerformances = [jv.roundPerformances; jv.classPerformance.LastCorrectRate];
            end
            results = jv.classPerformance;
        end

        % Trains and returns a classifier structure on all subjects except
        % the given (1-based) subject index of every group.
        function clStruct = getClassifierForSubject(jv, subject)
            % Same mask as used by runValidation, so this works for any
            % number of groups (previously hard-coded to two).
            testIdx = jv.buildTestMask(subject);
            [patternMat, trainLabels] = jv.patternParser.get(~testIdx);
            clStruct = jv.classifier.getClassifier(patternMat, trainLabels);
        end

        % Executes one validation round for the logical test mask testIdx:
        % informs 2nd-level processors which subjects are left out, trains
        % on ~testIdx, tests on the held-out patterns and updates
        % classPerformance, predicted and decFactors.
        function jv = computeForSubjects(jv, testIdx)
            % One held-out subject per group, so sum(testIdx) is the number
            % of groups contributing a test subject.
            nHeldOut = sum(testIdx);
            leaveOuts = cell(nHeldOut, 1);
            if nHeldOut > 0
                % First group's slice of the mask; it selects the held-out
                % member index, which is the same in every group.
                memberMask = testIdx(1:length(testIdx)/nHeldOut);
                for k = 1:nHeldOut
                    leaveOuts{k} = jv.patternParser.classCollection.classes{k}.members{memberMask}.description;
                end
            end
            % numel (not size) so that every processor is visited,
            % regardless of whether the cell is a row or a column.
            for h = 1:numel(jv.patternParser.intraValidationProcessors)
                if isa(jv.patternParser.intraValidationProcessors{h}, 'janis_pro_2ndLevel')
                    jv.patternParser.intraValidationProcessors{h}.setLeaveOut(leaveOuts);
                end
            end
            [patternMat, trainLabels, testMat, testLabels] = jv.patternParser.get(~testIdx);
            jv.classifier = jv.classifier.train(patternMat, trainLabels);
            [prediction, decFactors] = jv.classifier.test(testMat);
            % testIdx tells classperf which samples this round evaluated.
            jv.classPerformance = classperf(jv.classPerformance, prediction, testIdx);
            jv.predicted = [jv.predicted; prediction'];
            decFactors.label = testLabels;
            jv.decFactors = [jv.decFactors decFactors];
        end

        % Writes each column of the parsed pattern matrix to
        % <destFolder>/subject<k>.txt as plain ASCII (save -ASCII).
        function jv = exportCSV(jv, destFolder)
            % '~' placeholders keep nargout at 4, so get() behaves as before.
            [patternMat, ~, ~, ~] = jv.patternParser.get(1:jv.patternParser.size);
            for k = 1:size(patternMat, 2)
                toSave = patternMat(:, k);
                save(fullfile(destFolder, ['subject' num2str(k) '.txt']), 'toSave', '-ASCII');
            end
        end

        % Copies this validator via the superclass copy implementation.
        function new = copy(this)
            new = copy@janis_validator(this);
        end

    end

    methods (Access = private)

        % Logical test mask for round/subject r: true at position r inside
        % each group's block of jv.rounds entries, replicated for every
        % group (presumably patterns are ordered group by group - confirm
        % against the patternParser).
        function testIdx = buildTestMask(jv, r)
            roundMask = (1:jv.rounds) == r;
            testIdx = repmat(roundMask, 1, jv.patternParser.classCollection.size);
        end

    end
end