/*
 * Decompiled with CFR 0.152.
 */
package unbbayes.datamining.evaluation.batchEvaluation;

import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.Writer;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.ArrayList;
import java.util.Locale;
import java.util.ResourceBundle;
import unbbayes.datamining.datamanipulation.ArffLoader;
import unbbayes.datamining.datamanipulation.InstanceSet;
import unbbayes.datamining.datamanipulation.Loader;
import unbbayes.datamining.datamanipulation.TxtLoader;
import unbbayes.datamining.datamanipulation.Utils;
import unbbayes.datamining.evaluation.Folds;
import unbbayes.datamining.evaluation.ROCAnalysis;
import unbbayes.datamining.evaluation.batchEvaluation.Classifiers;
import unbbayes.datamining.evaluation.batchEvaluation.FoldEvaluation;
import unbbayes.datamining.evaluation.batchEvaluation.Indexes;
import unbbayes.datamining.evaluation.batchEvaluation.PreprocessorParameters;
import unbbayes.datamining.evaluation.batchEvaluation.model.Datasets;
import unbbayes.datamining.evaluation.batchEvaluation.model.Evaluations;
import unbbayes.datamining.evaluation.batchEvaluation.model.Preprocessors;
import unbbayes.datamining.gui.evaluation.batchEvaluation.controllers.LogsTabController;

/**
 * Runs a batch cross-validation experiment over every active dataset and writes the results
 * (AUC table, ROC curves and ROC convex hull for two-class data, confusion matrices for
 * multi-class data) as text files.
 *
 * <p>For each dataset the files are written to a {@code results} directory created inside the
 * directory that contains the dataset file. A "finished" message per dataset is sent to the
 * {@link LogsTabController}.
 */
public class RunScript {
    /** Number of fraction digits used when writing the ROC convex hull. */
    private static final int HULL_FRACTION_DIGITS = 10;

    private int numRounds = 1;
    private int numFolds = 10;
    private Datasets datasets;
    private Preprocessors preprocessors;
    private Evaluations evaluations;
    private boolean computeAUC;
    private boolean buildROC;
    /** Directory (with trailing separator) of the most recently loaded dataset; set by openFile. */
    private String inputFilePath;
    /** Output directory (with trailing separator) of the dataset currently being processed. */
    private String outputFilePath;
    /** Output directory name relative to each dataset's directory; never modified after construction. */
    private final String outputDirectoryName;
    private String aucFileName;
    private String rocFileNameExtension;
    private String hullFileName;
    /** Number of fraction digits used when writing ROC points. */
    private final int numberFractionDigits = 2;
    private LogsTabController logsWindowController;
    private ResourceBundle resource;

    /**
     * Creates a runner for the given experiment configuration.
     *
     * @param dataset              datasets to evaluate (only the active ones are used)
     * @param preprocessors        preprocessors; only the active ones are applied
     * @param evaluations          evaluation settings; also decides whether AUC and ROC output is written
     * @param logsWindowController receives one progress message per finished dataset
     * @param resource             bundle providing the {@code runScriptFinished} message
     */
    public RunScript(Datasets dataset, Preprocessors preprocessors, Evaluations evaluations, LogsTabController logsWindowController, ResourceBundle resource) {
        this.datasets = dataset;
        this.preprocessors = preprocessors;
        this.evaluations = evaluations;
        this.logsWindowController = logsWindowController;
        this.resource = resource;
        this.computeAUC = evaluations.isComputeAUC();
        this.buildROC = evaluations.isBuildROC();
        // Platform separator rather than a hard-coded "\\" so a real sub-directory is used on every OS.
        this.outputDirectoryName = "results" + File.separator;
        this.outputFilePath = this.outputDirectoryName;
        this.aucFileName = "auc.txt";
        this.rocFileNameExtension = " - roc.txt";
        this.hullFileName = "hull.txt";
    }

    /**
     * Loads, cross-validates and saves the results of every active dataset, logging a message
     * after each one.
     *
     * @throws Exception if a dataset cannot be loaded or evaluated, or its results cannot be written
     */
    public void run() throws Exception {
        int numActiveDatasets = this.datasets.getNumActiveData();
        for (int i = 0; i < numActiveDatasets; ++i) {
            // Loading the dataset also records its directory in inputFilePath.
            InstanceSet instanceSet = this.getInstanceSet(i);
            // Always derived from the immutable directory name, so calling run() again does not
            // keep prefixing the previous dataset's output path.
            this.outputFilePath = this.inputFilePath + this.outputDirectoryName;
            this.createPath(this.outputFilePath);
            FoldEvaluation testFold = this.getEvaluatedProbabilities(instanceSet);
            this.saveResults(instanceSet, testFold);
            String finished = this.resource.getString("runScriptFinished");
            this.logsWindowController.insertData("Dataset: " + this.datasets.getDatasetName(i) + " " + finished);
        }
    }

    /**
     * Runs {@code numRounds} rounds of {@code numFolds}-fold cross-validation and accumulates the
     * per-fold results in a single {@link FoldEvaluation}.
     *
     * @throws IllegalArgumentException if {@code instanceSet} is compacted
     */
    private FoldEvaluation getEvaluatedProbabilities(InstanceSet instanceSet) throws Exception {
        if (instanceSet.isCompacted()) {
            throw new IllegalArgumentException("cross validation works only on non compacted instanceSet!");
        }
        int positiveClass = instanceSet.getPositiveClass();
        PreprocessorParameters[] activePreprocessors = this.preprocessors.getActivePreprocessors();
        FoldEvaluation testFold = new FoldEvaluation(this.numFolds, this.numRounds, activePreprocessors, this.evaluations);
        for (int round = 0; round < this.numRounds; ++round) {
            // A new Folds object per round; whether the split differs between rounds depends on Folds.
            Folds folds = new Folds(instanceSet, this.numFolds);
            for (int fold = 0; fold < this.numFolds; ++fold) {
                testFold.run(folds.getTrain(fold), folds.getTest(fold), positiveClass);
            }
        }
        return testFold;
    }

    /** Loads the {@code i}-th active dataset using its configured counter and class indexes. */
    private InstanceSet getInstanceSet(int i) throws Exception {
        String fileName = this.datasets.getDatasetFullName(i);
        int counterIndex = this.datasets.getCounterIndex(i);
        int classIndex = this.datasets.getClassIndex(i);
        return this.openFile(fileName, counterIndex, classIndex);
    }

    /** Loads {@code fileName} and sets its class attribute to {@code classIndex}. */
    private InstanceSet openFile(String fileName, int counterIndex, int classIndex) throws Exception {
        InstanceSet instanceSet = this.openFile(fileName, counterIndex);
        instanceSet.setClassIndex(classIndex);
        return instanceSet;
    }

    /**
     * Loads an {@code .arff} or {@code .txt} file (extension matched case-insensitively) and
     * records the file's directory in {@code inputFilePath}.
     *
     * @throws Exception if the extension is not supported, or if loading fails
     */
    private InstanceSet openFile(String fileName, int counterIndex) throws Exception {
        File file = new File(fileName);
        Loader loader = null;
        if (fileName.regionMatches(true, fileName.length() - 5, ".arff", 0, 5)) {
            loader = new ArffLoader(file, -1);
        } else if (fileName.regionMatches(true, fileName.length() - 4, ".txt", 0, 4)) {
            loader = new TxtLoader(file, -1);
        }
        // Checked before the loader is used; otherwise unsupported files end in a NullPointerException.
        if (loader == null) {
            throw new Exception("Couldn't open training instanceSet " + fileName);
        }
        loader.setCounterAttribute(counterIndex);
        // Keep calling getInstance() until it returns false (presumably: no instances left to read).
        while (loader.getInstance()) {
            // intentionally empty
        }
        this.inputFilePath = withTrailingSeparator(parentDirectoryOf(file));
        return loader.getInstanceSet();
    }

    /**
     * Writes the results appropriate to the dataset: confusion matrices for multi-class data, and
     * otherwise the AUC table and/or ROC curves plus hull, as enabled in the evaluation settings.
     */
    private void saveResults(InstanceSet instanceSet, FoldEvaluation testFold) throws IOException {
        if (instanceSet.isMultiClass()) {
            this.saveMultiClassResults(instanceSet, testFold);
            return;
        }
        if (this.computeAUC) {
            this.saveAUCResults(instanceSet, testFold);
        }
        if (this.buildROC) {
            this.saveROCResults(instanceSet, testFold);
            this.saveHullResults(instanceSet, testFold);
        }
    }

    /** Writes one confusion-matrix file per (sampling, classifier) pair. */
    private void saveMultiClassResults(InstanceSet instanceSet, FoldEvaluation testFold) throws IOException {
        int numSamplings = testFold.getNumBatchIterations();
        int numClassifiers = testFold.getNumClassifiers();
        Indexes indexes = testFold.getIndexes();
        for (int samplingID = 0; samplingID < numSamplings; ++samplingID) {
            for (int classfID = 0; classfID < numClassifiers; ++classfID) {
                this.saveMultiClassResults(instanceSet, testFold, indexes, samplingID, classfID);
            }
        }
    }

    /**
     * Writes {@code "<classifier> - <sampling> - confusionMatrix.txt"}, which contains a
     * tab-separated confusion matrix labelled with the class values, followed by a per-class
     * "Class frequency" table.
     */
    private void saveMultiClassResults(InstanceSet instanceSet, FoldEvaluation testFold, Indexes indexes, int samplingID, int classfierID) throws IOException {
        int numClasses = instanceSet.numClasses();
        int classIndex = instanceSet.classIndex;
        String classifierName = Classifiers.getClassifierName(classfierID);
        String samplingName = sanitizeFileName(testFold.getSamplingName(samplingID));
        // outputFilePath already ends with a separator.
        File output = new File(this.outputFilePath + classifierName + " - " + samplingName + " - confusionMatrix.txt");
        PrintWriter writer = openWriter(output);
        try {
            // Header row: one column per class value.
            for (int c = 0; c < numClasses; ++c) {
                writer.print("\t");
                writer.print(instanceSet.attributes[classIndex].value(c));
            }
            writer.println();
            int[][] confusionMatrix = indexes.getConfusionMatrix(samplingID, classfierID);
            for (int row = 0; row < numClasses; ++row) {
                writer.print(instanceSet.attributes[classIndex].value(row));
                for (int col = 0; col < numClasses; ++col) {
                    writer.print("\t");
                    writer.print(confusionMatrix[row][col]);
                }
                writer.println();
            }
            writer.println();
            writer.println();
            writer.println("Class frequency");
            writer.println("\tTraining set\tEvaluation set");
            for (int c = 0; c < numClasses; ++c) {
                writer.print(instanceSet.attributes[classIndex].value(c));
                writer.print("\t");
                writer.print(testFold.getClassDistribution(samplingID, c));
                // End each row so that every class gets its own line.
                writer.println();
            }
            checkWritten(writer, output);
        } finally {
            writer.close();
        }
    }

    /**
     * Writes one ROC file per (classifier, sampling) pair into a per-classifier sub-directory.
     * Each line holds the averaged FP and TP rates followed by their standard deviations,
     * tab-separated and with ',' as the decimal separator.
     */
    private void saveROCResults(InstanceSet instanceSet, FoldEvaluation testFold) throws IOException {
        int numClassifiers = Classifiers.getNumClassifiers();
        int numSamplings = testFold.getNumBatchIterations();
        for (int classifierID = 0; classifierID < numClassifiers; ++classifierID) {
            String filePath = this.outputFilePath + Classifiers.getClassifierName(classifierID);
            this.createPath(filePath);
            for (int sampleID = 0; sampleID < numSamplings; ++sampleID) {
                String sampleName = testFold.getSamplingName(sampleID);
                // Tabs are replaced in the file name only (as the multi-class output does); the
                // header line inside the file keeps the original name.
                File output = new File(filePath + File.separator + sampleID + sanitizeFileName(sampleName) + this.rocFileNameExtension);

                // Rows are {fpAvg, tpAvg, fpStdDev, tpStdDev}, sorted together so each deviation
                // stays attached to its point.
                float[][] rocPointsAvg = testFold.getRocPointsAvg(sampleID, classifierID);
                float[][] rocPointsStdDev = testFold.getRocPointsStdDev(sampleID, classifierID);
                int numROCPoints = rocPointsAvg.length;
                float[][] rocPoints = new float[numROCPoints][4];
                for (int n = 0; n < numROCPoints; ++n) {
                    rocPoints[n][0] = rocPointsAvg[n][0];
                    rocPoints[n][1] = rocPointsAvg[n][1];
                    rocPoints[n][2] = rocPointsStdDev[n][0];
                    rocPoints[n][3] = rocPointsStdDev[n][1];
                }
                Utils.sort(rocPoints);

                PrintWriter writer = openWriter(output);
                try {
                    writer.println();
                    writer.print(sampleName + "\t");
                    writer.println(testFold.getSamplingParameters(sampleID));
                    for (int k = 0; k < numROCPoints; ++k) {
                        String fpAvg = this.toComma(rocPoints[k][0], this.numberFractionDigits);
                        String tpAvg = this.toComma(rocPoints[k][1], this.numberFractionDigits);
                        String fpStdDev = this.toComma(rocPoints[k][2], this.numberFractionDigits);
                        String tpStdDev = this.toComma(rocPoints[k][3], this.numberFractionDigits);
                        writer.print(fpAvg + "\t" + tpAvg + "\t");
                        writer.println(fpStdDev + "\t" + tpStdDev);
                    }
                    checkWritten(writer, output);
                } finally {
                    writer.close();
                }
            }
        }
    }

    /**
     * Computes the convex hull of the averaged ROC points of all classifiers and samplings, and
     * writes it (sorted with {@code ROCAnalysis.sort}) to the hull file.
     */
    private void saveHullResults(InstanceSet instanceSet, FoldEvaluation testFold) throws IOException {
        int numClassifiers = Classifiers.getNumClassifiers();
        int numSamplings = testFold.getNumBatchIterations();
        ArrayList<float[]> hullPoints = new ArrayList<float[]>();
        for (int classID = 0; classID < numClassifiers; ++classID) {
            for (int sampleID = 0; sampleID < numSamplings; ++sampleID) {
                for (float[] point : testFold.getRocPointsAvg(sampleID, classID)) {
                    hullPoints.add(point);
                }
            }
        }
        hullPoints = ROCAnalysis.computeConvexHull(hullPoints);
        ROCAnalysis.sort(hullPoints);
        File output = new File(this.outputFilePath + this.hullFileName);
        PrintWriter writer = openWriter(output);
        try {
            writer.println();
            writer.println("Hull");
            // Local precision: the ROC-point precision field must not be overwritten here, or
            // later datasets' ROC files would silently switch to 10 digits.
            for (float[] point : hullPoints) {
                String fp = this.toComma(point[0], HULL_FRACTION_DIGITS);
                String tp = this.toComma(point[1], HULL_FRACTION_DIGITS);
                writer.println(fp + "\t" + tp);
            }
            checkWritten(writer, output);
        } finally {
            writer.close();
        }
    }

    /**
     * Formats {@code f} with ',' as the decimal separator and at least
     * {@code numberFractionDigits} fraction digits.
     *
     * <p>All other symbols and the pattern come from the default locale. The maximum number of
     * fraction digits is the default pattern's, unless the minimum exceeds it.
     * NOTE(review): the grouping separator also comes from the default locale and may itself be
     * ','; this is harmless only while values stay below 1000 (presumably true for ROC rates).
     */
    private String toComma(float f, int numberFractionDigits) {
        DecimalFormatSymbols dfs = new DecimalFormatSymbols();
        dfs.setDecimalSeparator(',');
        DecimalFormat format = new DecimalFormat();
        format.setDecimalFormatSymbols(dfs);
        format.setMinimumFractionDigits(numberFractionDigits);
        return format.format(f);
    }

    /**
     * Ensures that {@code path} exists as a directory, creating missing parent directories too.
     *
     * @throws IOException if the directory does not exist and cannot be created
     */
    private void createPath(String path) throws IOException {
        File dir = new File(path);
        // The second isDirectory() check tolerates the directory being created concurrently.
        if (!dir.isDirectory() && !dir.mkdirs() && !dir.isDirectory()) {
            throw new IOException("Could not create output directory " + dir.getAbsolutePath());
        }
    }

    /**
     * Writes the AUC table: one line per (classifier, sampling) pair with the AUC mean and
     * standard deviation (two decimals, French locale, i.e. ',' separator) and the sampling
     * parameters.
     */
    private void saveAUCResults(InstanceSet instanceSet, FoldEvaluation testFold) throws IOException {
        int numClassifiers = testFold.getNumClassifiers();
        int numSamplings = testFold.getNumBatchIterations();
        File output = new File(this.outputFilePath + this.aucFileName);
        PrintWriter writer = openWriter(output);
        try {
            for (int i = 0; i < numClassifiers; ++i) {
                String classifierName = Classifiers.getClassifierName(i);
                for (int sampleID = 0; sampleID < numSamplings; ++sampleID) {
                    writer.print(classifierName);
                    writer.print("\t" + testFold.getSamplingName(sampleID));
                    double value = testFold.getAuc(sampleID, i)[0];
                    double stdDev = testFold.getAuc(sampleID, i)[1];
                    String stringValue = String.format(Locale.FRANCE, "%.2f", value);
                    String stringStdDev = String.format(Locale.FRANCE, "%.2f", stdDev);
                    writer.print("\t" + stringValue + " (" + stringStdDev + ")");
                    writer.print("\t" + testFold.getSamplingParameters(sampleID));
                    writer.println();
                }
            }
            checkWritten(writer, output);
        } finally {
            writer.close();
        }
    }

    /** Opens an auto-flushing writer on {@code output}, replacing any existing file. */
    private static PrintWriter openWriter(File output) throws IOException {
        return new PrintWriter(new FileWriter(output), true);
    }

    /**
     * Reports write failures that {@link PrintWriter} would otherwise swallow.
     *
     * @throws IOException if any write to {@code writer} has failed
     */
    private static void checkWritten(PrintWriter writer, File output) throws IOException {
        if (writer.checkError()) {
            throw new IOException("Error while writing " + output.getPath());
        }
    }

    /** Replaces tabs, which are not valid in Windows file names, with spaces. */
    private static String sanitizeFileName(String name) {
        return String.valueOf(name).replace('\t', ' ');
    }

    /** Returns the path of the directory containing {@code file}, or "" if there is none. */
    private static String parentDirectoryOf(File file) {
        File parent = file.getAbsoluteFile().getParentFile();
        return parent == null ? "" : parent.getPath();
    }

    /** Appends the platform separator unless {@code path} is empty or already ends with it. */
    private static String withTrailingSeparator(String path) {
        if (path.length() == 0 || path.endsWith(File.separator)) {
            return path;
        }
        return path + File.separator;
    }
}

