/*
 * Decompiled with CFR 0.152.
 */
package unbbayes.datamining.evaluation;

import java.util.Date;
import java.util.Random;
import java.util.ResourceBundle;
import unbbayes.datamining.datamanipulation.InstanceSet;

/**
 * Splits an {@link InstanceSet} into folds for cross-validation.
 *
 * <p>Fold {@code i} (0-based) is a contiguous slice of the instance set. With {@code n}
 * instances and {@code k} folds, every fold holds {@code n / k} instances and the first
 * {@code n % k} folds hold one extra instance. The test set for a fold is that slice. The
 * training set is every other instance, kept in the original order.
 */
public class Folds {
    private static final ResourceBundle resource = ResourceBundle.getBundle("unbbayes.datamining.classifiers.resources.ClassifiersResource");
    /** Number of folds the instance set is split into. */
    int numFolds;
    // NOTE(review): currentFold, train, temp, test and eval are never read or written in this
    // class. They are kept because package-private fields may be used elsewhere in the package.
    int currentFold;
    /** The instance set the folds are taken from (stratified or shuffled by the constructor). */
    InstanceSet instanceSet;
    InstanceSet train;
    InstanceSet temp;
    InstanceSet test;
    InstanceSet eval;

    /**
     * Prepares {@code instanceSet} for cross-validation with {@code numFolds} folds.
     *
     * <p>If the class attribute is nominal, the set is stratified with
     * {@code instanceSet.stratify(numFolds)}. Otherwise it is shuffled with a time-seeded
     * {@link Random}. Both calls act on the given set itself, so the caller's instance order
     * is presumably changed.
     *
     * @param instanceSet the instances to split; must not be compacted
     * @param numFolds the number of folds
     * @throws IllegalArgumentException if {@code instanceSet} is compacted
     * @throws Exception declared for callers; any checked exception would come from the
     *     {@code InstanceSet} calls made here
     */
    public Folds(InstanceSet instanceSet, int numFolds) throws Exception {
        this.numFolds = numFolds;
        this.instanceSet = instanceSet;
        if (instanceSet.isCompacted()) {
            throw new IllegalArgumentException("cross validation only works with non compacted instanceSet!");
        }
        if (instanceSet.getClassAttribute().isNominal()) {
            instanceSet.stratify(numFolds);
        } else {
            instanceSet.randomize(new Random(System.currentTimeMillis()));
        }
    }

    /**
     * Returns the training set for fold {@code foldID}: every instance outside that fold.
     *
     * @param foldID the 0-based fold index, in {@code [0, numFolds)}
     * @return a new instance set with the training instances
     * @throws IllegalArgumentException if the fold configuration or {@code foldID} is invalid
     */
    public InstanceSet getTrain(int foldID) {
        return Folds.trainCV(this.instanceSet, this.numFolds, foldID);
    }

    /**
     * Returns the test set for fold {@code foldID}: the instances inside that fold.
     *
     * @param foldID the 0-based fold index, in {@code [0, numFolds)}
     * @return a new instance set with the test instances
     * @throws IllegalArgumentException if the fold configuration or {@code foldID} is invalid
     */
    public InstanceSet getTest(int foldID) {
        return Folds.testCV(this.instanceSet, this.numFolds, foldID);
    }

    /**
     * Builds the training set for one cross-validation fold. It contains every instance of
     * {@code instanceSet} except those in fold {@code currentFold}, in their original order.
     *
     * @param instanceSet the full instance set
     * @param numFolds the number of folds; at least 2 and at most the number of instances
     * @param currentFold the 0-based fold to leave out, in {@code [0, numFolds)}
     * @return a new instance set with the training instances
     * @throws IllegalArgumentException if {@code numFolds < 2}, if {@code numFolds} exceeds
     *     the number of instances, or if {@code currentFold} is out of range
     */
    public static InstanceSet trainCV(InstanceSet instanceSet, int numFolds, int currentFold) {
        int numInstances = instanceSet.numInstances();
        checkFoldArguments(numInstances, numFolds, currentFold);
        int first = foldStart(numInstances, numFolds, currentFold);
        int size = foldSize(numInstances, numFolds, currentFold);
        int afterFold = first + size;
        InstanceSet train = new InstanceSet(instanceSet, numInstances - size);
        // Copy everything before the held-out fold, then everything after it.
        instanceSet.copyInstancesTo(train, 0, first);
        instanceSet.copyInstancesTo(train, afterFold, numInstances - afterFold);
        return train;
    }

    /**
     * Builds the test set for one cross-validation fold: the contiguous slice of
     * {@code instanceSet} that makes up fold {@code currentFold}.
     *
     * @param instanceSet the full instance set
     * @param numFolds the number of folds; at least 2 and at most the number of instances
     * @param currentFold the 0-based fold to extract, in {@code [0, numFolds)}
     * @return a new instance set with the test instances
     * @throws IllegalArgumentException if {@code numFolds < 2}, if {@code numFolds} exceeds
     *     the number of instances, or if {@code currentFold} is out of range
     */
    public static InstanceSet testCV(InstanceSet instanceSet, int numFolds, int currentFold) {
        int numInstances = instanceSet.numInstances();
        checkFoldArguments(numInstances, numFolds, currentFold);
        int first = foldStart(numInstances, numFolds, currentFold);
        int size = foldSize(numInstances, numFolds, currentFold);
        InstanceSet test = new InstanceSet(instanceSet, size);
        instanceSet.copyInstancesTo(test, first, size);
        return test;
    }

    /** Validates the fold configuration shared by {@link #trainCV} and {@link #testCV}. */
    private static void checkFoldArguments(int numInstances, int numFolds, int currentFold) {
        if (numFolds < 2) {
            throw new IllegalArgumentException(resource.getString("folds2"));
        }
        if (numFolds > numInstances) {
            throw new IllegalArgumentException(resource.getString("moreFolds"));
        }
        // An out-of-range fold index would produce a slice beyond the end of the set.
        if (currentFold < 0 || currentFold >= numFolds) {
            throw new IllegalArgumentException("fold index " + currentFold + " is outside [0, " + numFolds + ")");
        }
    }

    /** Returns the number of instances in fold {@code currentFold}. */
    private static int foldSize(int numInstances, int numFolds, int currentFold) {
        int baseSize = numInstances / numFolds;
        // The first (numInstances % numFolds) folds absorb the remainder, one instance each.
        return currentFold < numInstances % numFolds ? baseSize + 1 : baseSize;
    }

    /** Returns the index of the first instance of fold {@code currentFold}. */
    private static int foldStart(int numInstances, int numFolds, int currentFold) {
        // Each earlier fold that got an extra instance shifts this fold's start by one.
        return currentFold * (numInstances / numFolds) + Math.min(currentFold, numInstances % numFolds);
    }
}

