/*
 * Decompiled with CFR 0.152.
 */
package unbbayes.datamining.evaluation;

import unbbayes.datamining.classifiers.Classifier;
import unbbayes.datamining.classifiers.DistributionClassifier;
import unbbayes.datamining.datamanipulation.InstanceSet;
import unbbayes.datamining.evaluation.Evaluation;
import unbbayes.datamining.evaluation.Folds;
import unbbayes.datamining.evaluation.ITrainingMode;
import unbbayes.datamining.evaluation.batchEvaluation.FoldEvaluation;
import unbbayes.datamining.evaluation.batchEvaluation.PreprocessorParameters;
import unbbayes.datamining.evaluation.batchEvaluation.model.Evaluations;

/**
 * Static helpers for k-fold cross-validation of data-mining classifiers.
 *
 * <p>Implements {@link ITrainingMode}; every method declared in this class is static.
 */
public class CrossValidation implements ITrainingMode {

    /**
     * Runs k-fold cross-validation of {@code classifier} on {@code instanceSet}.
     *
     * <p>For each fold the classifier is retrained (via {@code buildClassifier}) on that fold's
     * training split and then evaluated on the fold's test split. All folds are accumulated
     * into a single {@link Evaluation}, which is returned to the caller. Earlier versions
     * computed this evaluation and discarded it, which left the results unreachable.
     *
     * @param instanceSet the data to split into folds
     * @param classifier  the classifier to train; it is retrained once per fold
     * @param numFolds    the number of folds passed to {@link Folds}
     * @return the evaluation accumulated over all {@code numFolds} folds
     * @throws Exception propagated from {@link Folds}, {@code buildClassifier} or
     *                   {@code evaluateModel}
     */
    public static Evaluation crossValidateModel(InstanceSet instanceSet, Classifier classifier, int numFolds) throws Exception {
        Evaluation evaluation = new Evaluation(instanceSet);
        Folds folds = new Folds(instanceSet, numFolds);
        for (int fold = 0; fold < numFolds; fold++) {
            classifier.buildClassifier(folds.getTrain(fold));
            evaluation.evaluateModel(classifier, folds.getTest(fold));
        }
        return evaluation;
    }

    /**
     * Computes out-of-fold probabilities for {@code positiveClass} by k-fold cross-validation.
     *
     * <p>For each fold the classifier is retrained on the fold's training split. Then
     * {@code distributionForInstance(...)[positiveClass]} is recorded for every instance of
     * the fold's test split.
     *
     * <p>The returned array has {@code instanceSet.numInstances} entries. It is filled in the
     * order the test instances are visited: fold 0's test set first, then fold 1's, and so on.
     * Any entries not reached keep the default value {@code 0f}.
     * NOTE(review): whether this order matches {@code instanceSet}'s own instance order
     * depends on how {@link Folds} assigns instances to folds. Confirm this before pairing
     * these values with labels taken from {@code instanceSet}.
     *
     * @param classifier    the classifier to train; must be a {@link DistributionClassifier}
     * @param instanceSet   the data to split into folds
     * @param positiveClass index into each class distribution whose probability is recorded
     * @param numFolds      the number of folds passed to {@link Folds}
     * @return the recorded probabilities, one slot per instance of {@code instanceSet}
     * @throws ClassCastException if {@code classifier} is not a {@link DistributionClassifier}.
     *                            It is thrown up front, before any fold is trained.
     * @throws Exception          propagated from {@link Folds}, training or classification
     */
    public static float[] getEvaluatedProbabilities(Classifier classifier, InstanceSet instanceSet, int positiveClass, int numFolds) throws Exception {
        // Cast once, before any (possibly expensive) training. Previously the cast happened
        // per instance, so a wrong classifier type was only detected after a full training pass.
        DistributionClassifier distributionClassifier = (DistributionClassifier) classifier;
        Folds folds = new Folds(instanceSet, numFolds);
        float[] probs = new float[instanceSet.numInstances];
        int next = 0; // next free slot in probs, shared across folds
        for (int fold = 0; fold < numFolds; fold++) {
            InstanceSet train = folds.getTrain(fold);
            classifier.buildClassifier(train);
            InstanceSet test = folds.getTest(fold);
            int numTestInstances = test.numInstances();
            for (int i = 0; i < numTestInstances; i++) {
                float[] dist = distributionClassifier.distributionForInstance(test.getInstance(i));
                probs[next++] = dist[positiveClass];
            }
        }
        return probs;
    }

    /**
     * Runs {@code numRounds} rounds of {@code numFolds}-fold cross-validation through one
     * {@link FoldEvaluation}.
     *
     * <p>A new {@link Folds} split is built at the start of every round. Each fold's
     * training/test pair is then passed to {@link FoldEvaluation#run}.
     * NOTE(review): the {@code FoldEvaluation} is created with a literal {@code 1} as its
     * first argument, the same as in the single train/test overload, even though
     * {@code numFolds} folds run per round. Confirm this is what {@code FoldEvaluation} expects.
     *
     * @param instanceSet   the data to split; must not be compacted
     * @param numFolds      the number of folds per round
     * @param numRounds     how many times the whole cross-validation is repeated
     * @param positiveClass the class index forwarded to {@code FoldEvaluation.run}
     * @param preprocessors forwarded to the {@code FoldEvaluation} constructor
     * @param evaluations   forwarded to the {@code FoldEvaluation} constructor
     * @return the {@code FoldEvaluation} after all rounds and folds have been run
     * @throws IllegalArgumentException if {@code instanceSet.isCompacted()} is true
     * @throws Exception                propagated from {@link Folds} or {@code FoldEvaluation}
     */
    public static FoldEvaluation getEvaluatedProbabilities(InstanceSet instanceSet, int numFolds, int numRounds, int positiveClass, PreprocessorParameters[] preprocessors, Evaluations evaluations) throws Exception {
        if (instanceSet.isCompacted()) {
            throw new IllegalArgumentException("cross validation works only on non compacted instanceSet!");
        }
        FoldEvaluation testFold = new FoldEvaluation(1, numRounds, preprocessors, evaluations);
        for (int round = 0; round < numRounds; round++) {
            Folds folds = new Folds(instanceSet, numFolds);
            for (int fold = 0; fold < numFolds; fold++) {
                InstanceSet train = folds.getTrain(fold);
                InstanceSet test = folds.getTest(fold);
                testFold.run(train, test, positiveClass);
            }
        }
        return testFold;
    }

    /**
     * Runs a fixed training/test pair through one {@link FoldEvaluation}, {@code numRounds} times.
     *
     * <p>Unlike the fold-based overload, no splitting happens here. The same {@code train} and
     * {@code test} sets are passed to {@code FoldEvaluation.run} in every round.
     *
     * @param train         the training set used in every round
     * @param test          the test set used in every round
     * @param numRounds     how many times {@code run} is invoked
     * @param positiveClass the class index forwarded to {@code FoldEvaluation.run}
     * @param preprocessors forwarded to the {@code FoldEvaluation} constructor
     * @param evaluations   forwarded to the {@code FoldEvaluation} constructor
     * @return the {@code FoldEvaluation} after all rounds have been run
     * @throws Exception propagated from {@code FoldEvaluation}
     */
    public static FoldEvaluation getEvaluatedProbabilities(InstanceSet train, InstanceSet test, int numRounds, int positiveClass, PreprocessorParameters[] preprocessors, Evaluations evaluations) throws Exception {
        FoldEvaluation testFold = new FoldEvaluation(1, numRounds, preprocessors, evaluations);
        for (int round = 0; round < numRounds; round++) {
            testFold.run(train, test, positiveClass);
        }
        return testFold;
    }
}

