/*
 * Decompiled with CFR 0.152.
 */
package unbbayes.datamining.evaluation.batchEvaluation;

import unbbayes.datamining.classifiers.Classifier;
import unbbayes.datamining.datamanipulation.Instance;
import unbbayes.datamining.datamanipulation.InstanceSet;
import unbbayes.datamining.datamanipulation.Utils;
import unbbayes.datamining.evaluation.ROCAnalysis;
import unbbayes.datamining.evaluation.batchEvaluation.Classifiers;
import unbbayes.datamining.evaluation.batchEvaluation.Indexes;
import unbbayes.datamining.evaluation.batchEvaluation.InitializePreprocessors;
import unbbayes.datamining.evaluation.batchEvaluation.PreprocessorParameters;
import unbbayes.datamining.evaluation.batchEvaluation.model.Evaluations;

/**
 * Evaluates every combination of preprocessor batch iteration and classifier over the
 * train/test splits of a cross-validation, accumulating the results.
 *
 * <p>Each call to {@link #run(InstanceSet, InstanceSet, int)} handles one train/test split:
 * <ul>
 *   <li>for every batch iteration provided by {@link InitializePreprocessors} it copies the
 *       training set and applies the preprocessor;</li>
 *   <li>it then trains each classifier from {@link Classifiers} and evaluates it on the test set.</li>
 * </ul>
 * Actual/predicted class pairs (with instance weights) are recorded in {@link Indexes}. For
 * datasets that are not multi-class, per-split ROC points and AUC values are also stored. After
 * {@code numFolds * numRounds} calls to {@code run}, those values are averaged.
 */
public class FoldEvaluation {
    /** Averaged ROC points per [batch iteration][classifier]; null until {@link #average()} runs. */
    private float[][][][] rocPointsAvg;
    /** ROC point dispersion per [batch iteration][classifier]; null until {@link #average()} runs. */
    private float[][][][] rocPointsStdDev;
    /** Raw ROC points per [batch iteration][classifier][split]. */
    private float[][][][][] rocPointsProbsTemp;
    /** Aggregated AUC per [batch iteration][classifier]; null until {@link #average()} runs. */
    private double[][][] auc;
    /** Raw AUC (in percent) per [batch iteration][classifier][split]. */
    private double[][][] aucTemp;
    private String[] samplingName;
    private String[] samplingParameters;
    private int numBatchIterations;
    private final int numClassifiers;
    /** Index of the class treated as positive; set on every call to {@code run}. */
    private int positiveClass;
    private InitializePreprocessors preprocessor;
    private Indexes indexes;
    /** Number of completed calls to {@code run}; also the split index being filled. */
    private int pos;
    /** Total number of splits expected: {@code numFolds * numRounds}. */
    private final int numRoundsTotal;
    /** Batch iteration currently being evaluated. */
    private int samplingID;
    private boolean multiClass;
    private float[][] classDistribution;
    private boolean buildROC;
    private final boolean computeAUC;
    private final PreprocessorParameters[] preprocessors;

    /**
     * Creates an evaluator for {@code numFolds * numRounds} train/test splits.
     *
     * @param numFolds      number of folds per round
     * @param numRounds     number of cross-validation rounds
     * @param preprocessors preprocessor configurations to evaluate
     * @param evaluations   source of the buildROC / computeAUC flags
     */
    public FoldEvaluation(int numFolds, int numRounds, PreprocessorParameters[] preprocessors, Evaluations evaluations) {
        this.preprocessors = preprocessors;
        this.buildROC = evaluations.isBuildROC();
        this.computeAUC = evaluations.isComputeAUC();
        this.numClassifiers = Classifiers.getNumClassifiers();
        this.numRoundsTotal = numFolds * numRounds;
        this.pos = 0;
    }

    /**
     * Evaluates all preprocessor batch iterations and classifiers on one train/test split.
     *
     * @param originalTrain training set; copied before each preprocessor is applied
     * @param test          test set used for evaluation
     * @param positiveClass index of the class considered positive for ROC/AUC and probabilities
     * @throws Exception propagated from the preprocessor setup, classifiers or ROC analysis
     */
    public void run(InstanceSet originalTrain, InstanceSet test, int positiveClass) throws Exception {
        // Must be stored: classifyEvaluate/evaluateClassifier read this.positiveClass.
        this.positiveClass = positiveClass;
        this.multiClass = originalTrain.isMultiClass();
        this.preprocessor = new InitializePreprocessors(originalTrain, this.preprocessors);
        this.numBatchIterations = this.preprocessor.getTotalNumBatchIterations();
        this.initialize(originalTrain);
        // NOTE(review): this overrides evaluations.isBuildROC(). average() always averages the
        // stored ROC points, so ROC construction is forced on here. Confirm this is intended.
        this.buildROC = true;
        for (this.samplingID = 0; this.samplingID < this.numBatchIterations; ++this.samplingID) {
            InstanceSet train = new InstanceSet(originalTrain);
            int[] batchID = this.preprocessor.getNextBatchID();
            this.runAux(train, test, batchID);
        }
        ++this.pos;
        if (!this.multiClass && this.pos == this.numRoundsTotal) {
            this.average();
        }
    }

    /**
     * Applies one preprocessor batch iteration to {@code train}, then trains and evaluates all
     * classifiers. Results are stored under the current batch iteration.
     *
     * @param train   training set to preprocess in place
     * @param test    test set used for evaluation
     * @param batchID batch identifier passed to the preprocessor
     * @return always 0 (kept for interface compatibility)
     * @throws Exception propagated from classifier training/evaluation
     */
    public int runAux(InstanceSet train, InstanceSet test, int[] batchID) throws Exception {
        try {
            this.preprocessor.applyPreprocessor(train, batchID);
        } catch (Exception e) {
            // Best effort: a failed preprocessor is reported and the training set is evaluated
            // as it stands (possibly partially preprocessed).
            e.printStackTrace();
        }
        this.samplingName[this.samplingID] = this.preprocessor.getPreprocessorName();
        this.samplingParameters[this.samplingID] = this.preprocessor.getPreprocessorParameters();
        this.classifyEvaluate(train, test);
        this.classDistribution[this.samplingID] = train.getClassDistribution();
        return 0;
    }

    /**
     * Trains and evaluates every classifier. For datasets that are not multi-class, this also
     * stores the ROC points and AUC for the current batch iteration and split.
     */
    private void classifyEvaluate(InstanceSet train, InstanceSet test) throws Exception {
        for (int classfID = 0; classfID < this.numClassifiers; ++classfID) {
            float[] probs = this.classifyEvaluate(train, test, classfID);
            if (this.multiClass) {
                continue;
            }
            if (this.buildROC) {
                this.rocPointsProbsTemp[this.samplingID][classfID][this.pos] =
                        ROCAnalysis.computeROCPoints(probs, test, this.positiveClass);
            }
            if (this.computeAUC) {
                // Stored as a percentage.
                this.aucTemp[this.samplingID][classfID][this.pos] =
                        ROCAnalysis.computeAUC(probs, test, this.positiveClass) * 100.0;
            }
        }
    }

    /**
     * Builds classifier {@code classfID} on {@code train} and returns its positive-class
     * scores for each test instance.
     */
    private float[] classifyEvaluate(InstanceSet train, InstanceSet test, int classfID) throws Exception {
        float[] distribution = train.getClassDistribution(false);
        Classifier classifier = Classifiers.newClassifier(classfID);
        Classifiers.buildClassifier(train, classifier, distribution, this.positiveClass);
        return this.evaluateClassifier(classifier, test, this.samplingID, classfID);
    }

    /**
     * Scores every test instance. For each instance it records the (actual, predicted, weight)
     * triple in {@link #indexes}; the predicted class is the arg-max of the classifier's
     * distribution.
     *
     * @return the classifier's output for {@link #positiveClass} on each test instance, in order
     */
    private float[] evaluateClassifier(Classifier classifier, InstanceSet test, int samplingID, int classfID) throws Exception {
        int classIndex = test.classIndex;
        int counterIndex = test.counterIndex;
        int numInstances = test.numInstances();
        float[] probs = new float[numInstances];
        for (int inst = 0; inst < numInstances; ++inst) {
            Instance instance = test.getInstance(inst);
            float[] dist = classifier.distributionForInstance(instance);
            probs[inst] = dist[this.positiveClass];
            int actualClass = (int) instance.data[classIndex];
            int predictedClass = Utils.maxIndex(dist);
            float weight = instance.data[counterIndex];
            this.indexes.insert(samplingID, classfID, this.pos, actualClass, predictedClass, weight);
        }
        return probs;
    }

    /**
     * Aggregates the per-split ROC points and AUC values for every batch iteration and
     * classifier.
     */
    private void average() {
        // Multi-dimensional creation already allocates the [numClassifiers] rows.
        this.rocPointsAvg = new float[this.numBatchIterations][this.numClassifiers][][];
        this.rocPointsStdDev = new float[this.numBatchIterations][this.numClassifiers][][];
        this.auc = new double[this.numBatchIterations][this.numClassifiers][];
        for (int i = 0; i < this.numBatchIterations; ++i) {
            for (int j = 0; j < this.numClassifiers; ++j) {
                float[][][] rocPointsAux = ROCAnalysis.averageROCPoints(this.rocPointsProbsTemp[i][j]);
                this.rocPointsAvg[i][j] = rocPointsAux[0];
                this.rocPointsStdDev[i][j] = rocPointsAux[1];
                if (this.numRoundsTotal > 1) {
                    this.auc[i][j] = Utils.computeMeanStdDev(this.aucTemp[i][j]);
                } else {
                    // A single split has no spread: {value, 0.0}.
                    this.auc[i][j] = new double[] {this.aucTemp[i][j][0], 0.0};
                }
            }
        }
    }

    /** Allocates the accumulation arrays on the first call to {@code run}; later calls reuse them. */
    private void initialize(InstanceSet instanceSet) throws Exception {
        if (this.rocPointsProbsTemp == null) {
            this.rocPointsProbsTemp = new float[this.numBatchIterations][this.numClassifiers][this.numRoundsTotal][][];
            this.aucTemp = new double[this.numBatchIterations][this.numClassifiers][this.numRoundsTotal];
            this.samplingName = new String[this.numBatchIterations];
            this.samplingParameters = new String[this.numBatchIterations];
            this.indexes = new Indexes(instanceSet, this.numBatchIterations, this.numClassifiers, this.numRoundsTotal);
            this.classDistribution = new float[this.numBatchIterations][];
        }
    }

    /** Returns element 0 of {@code ROCAnalysis.averageROCPoints} for this pair; requires {@code average()} to have run. */
    public float[][] getRocPointsAvg(int samplingID, int classifier) {
        return this.rocPointsAvg[samplingID][classifier];
    }

    /** Returns element 1 of {@code ROCAnalysis.averageROCPoints} for this pair; requires {@code average()} to have run. */
    public float[][] getRocPointsStdDev(int samplingID, int classifier) {
        return this.rocPointsStdDev[samplingID][classifier];
    }

    /**
     * Returns the aggregated AUC (in percent) for this pair.
     *
     * <p>With a single split this is {@code {auc, 0.0}}. Otherwise it is whatever
     * {@code Utils.computeMeanStdDev} returns — presumably {mean, stdDev}; TODO confirm.
     * Requires {@code average()} to have run.
     */
    public double[] getAuc(int samplingID, int classifier) {
        return this.auc[samplingID][classifier];
    }

    /** Returns the preprocessor driver created by the most recent {@code run}. */
    public InitializePreprocessors getSamplings() {
        return this.preprocessor;
    }

    /** Returns the preprocessor name recorded for the given batch iteration. */
    public String getSamplingName(int samplingID) {
        return this.samplingName[samplingID];
    }

    /** Returns the preprocessor parameter string recorded for the given batch iteration. */
    public String getSamplingParameters(int samplingID) {
        return this.samplingParameters[samplingID];
    }

    public int getNumBatchIterations() {
        return this.numBatchIterations;
    }

    public int getNumClassifiers() {
        return this.numClassifiers;
    }

    public Indexes getIndexes() {
        return this.indexes;
    }

    /** Returns the training-set class count after preprocessing, truncated to int. */
    public int getClassDistribution(int samplingID, int classID) {
        return (int) this.classDistribution[samplingID][classID];
    }
}

