/*
 * Decompiled with CFR 0.152.
 */
package unbbayes.datamining.preprocessor.imbalanceddataset;

import java.util.Arrays;
import unbbayes.datamining.datamanipulation.Instance;
import unbbayes.datamining.datamanipulation.InstanceSet;
import unbbayes.datamining.evaluation.batchEvaluation.PreprocessorParameters;
import unbbayes.datamining.preprocessor.imbalanceddataset.Batch;
import unbbayes.datamining.preprocessor.imbalanceddataset.ClusterBasedUtils;
import unbbayes.datamining.preprocessor.imbalanceddataset.RandomOversampling;
import unbbayes.datamining.preprocessor.imbalanceddataset.RandomUndersampling;
import unbbayes.datamining.preprocessor.imbalanceddataset.Simplesampling;
import unbbayes.datamining.preprocessor.imbalanceddataset.Smote;
import unbbayes.datamining.preprocessor.imbalanceddataset.Utils;

/**
 * Cluster-based oversampling preprocessor for imbalanced data sets.
 *
 * <p>The instances of each class are grouped into clusters supplied by a {@link ClusterBasedUtils}
 * (see {@link #setClusterInfo(ClusterBasedUtils)}). {@link #run()} optionally oversamples every
 * cluster of the negative class up to the weighted size of its biggest cluster. It then oversamples
 * the clusters of every other class so that each of those classes approaches
 * {@code majorityWeight * ratio / (10 - ratio)}, split evenly across its clusters. Clusters that
 * are already at least that big are left untouched.
 *
 * <p>{@link #runOversampling(double)} and {@link #runUndersampling(double)} instead apply one
 * fixed proportion to every cluster of the positive or negative class, respectively.
 *
 * <p>Sampling is delegated to {@link Simplesampling} when {@code useSimplesampling} is set, and to
 * {@link RandomOversampling} / {@link RandomUndersampling} otherwise.
 */
public class ClusterBasedOversampling extends Batch {
    /** Instances of the current instance set; refreshed by {@code initialize()}. */
    protected Instance[] instances;
    /** Number of instances in the current instance set; refreshed by {@code initialize()}. */
    protected int numInstances;
    /**
     * Number of classes, taken from the length of the per-class cluster-count array in
     * {@code initializeClusters()}. This field is declared in this class, so it hides any
     * same-named superclass field and must be assigned here.
     */
    protected int numClasses;
    /** Index of the counter (weight) attribute in each instance's {@code data}. */
    protected int counterIndex;
    /** Index of the class attribute in each instance's {@code data}. */
    protected int classIndex;
    /** {@code clusters[class][cluster]} holds the instance indices belonging to that cluster. */
    protected int[][][] clusters;
    /** {@code numClusters[class]} is the number of clusters built for that class. */
    protected int[] numClusters;
    /** {@code clustersSize[class][cluster]} is the size reported for that cluster. */
    protected double[][] clustersSize;
    /** Not used in this class; kept for subclasses and package code. */
    protected int[] clustersClass;
    /** Instances marked for removal by {@link #runUndersampling(double)}. */
    protected boolean[] deleteIndex;
    /** Not used in this class; kept for subclasses and package code. */
    protected boolean[] classClusterized;
    private ClusterBasedUtils clusterBasedUtils;
    private boolean overMajority;
    private RandomOversampling oversampling;
    private RandomUndersampling undersampling;
    private Simplesampling simplesampling;
    /** Constructed in {@code initialize()} but not otherwise used by this class. */
    private Smote smote;

    // Feature flags handed to Batch: this preprocessor uses the ratio and k parameters only.
    private static final boolean USE_RATIO = true;
    private static final boolean USE_K = true;
    private static final boolean USE_OVER_THRESH = false;
    private static final boolean USE_POS_THRESH = false;
    private static final boolean USE_NEG_THRESH = false;
    private static final boolean USE_CLEANING = false;

    /**
     * Creates the preprocessor without cluster information.
     *
     * @param instanceSet the data set to preprocess
     * @param parameters batch-evaluation parameters, may be {@code null}
     */
    public ClusterBasedOversampling(InstanceSet instanceSet, PreprocessorParameters parameters) {
        super(USE_RATIO, USE_K, USE_OVER_THRESH, USE_POS_THRESH, USE_NEG_THRESH, USE_CLEANING,
                instanceSet, parameters);
        this.preprocessorName = "Cluster Based Oversampling";
    }

    /**
     * Creates the preprocessor with cluster information and batch parameters.
     *
     * @param clusterBasedUtils source of the per-class clusters
     * @param instanceSet the data set to preprocess
     * @param parameters batch-evaluation parameters, may be {@code null}
     * @throws Exception propagated from {@link #setClusterInfo(ClusterBasedUtils)}
     */
    public ClusterBasedOversampling(ClusterBasedUtils clusterBasedUtils, InstanceSet instanceSet,
            PreprocessorParameters parameters) throws Exception {
        this(instanceSet, parameters);
        this.setClusterInfo(clusterBasedUtils);
    }

    /**
     * Creates the preprocessor with cluster information and no batch parameters.
     *
     * @param clusterBasedUtils source of the per-class clusters
     * @param instanceSet the data set to preprocess
     * @throws Exception propagated from {@link #setClusterInfo(ClusterBasedUtils)}
     */
    public ClusterBasedOversampling(ClusterBasedUtils clusterBasedUtils, InstanceSet instanceSet)
            throws Exception {
        this(instanceSet, null);
        this.setClusterInfo(clusterBasedUtils);
    }

    /**
     * Creates the preprocessor without cluster information. Call
     * {@link #setClusterInfo(ClusterBasedUtils)} before running it.
     *
     * @param instanceSet the data set to preprocess
     * @throws Exception propagated from {@link #setClusterInfo(ClusterBasedUtils)}
     */
    public ClusterBasedOversampling(InstanceSet instanceSet) throws Exception {
        this(instanceSet, null);
        this.setClusterInfo(null);
    }

    /**
     * Sets the object that supplies the per-class clusters.
     *
     * @param clusterBasedUtils cluster source; {@code null} clears it
     * @throws Exception declared for API compatibility
     */
    public void setClusterInfo(ClusterBasedUtils clusterBasedUtils) throws Exception {
        this.clusterBasedUtils = clusterBasedUtils;
    }

    /**
     * Loads the clusters, their counts and sizes (for the current {@code k}), and derives
     * {@link #numClasses} from the per-class cluster counts.
     *
     * @throws IllegalStateException if no cluster information was provided
     * @throws Exception propagated from {@link ClusterBasedUtils}
     */
    private void initializeClusters() throws Exception {
        if (this.clusterBasedUtils == null) {
            throw new IllegalStateException(
                    "No cluster information: call setClusterInfo(ClusterBasedUtils) before running");
        }
        this.clusters = this.clusterBasedUtils.getClustersByClass(this.k);
        this.numClusters = this.clusterBasedUtils.getNumClustersByClass(this.k);
        this.clustersSize = this.clusterBasedUtils.getClustersSizeByClass(this.k);
        // numClusters is indexed by class value, so its length is the number of classes.
        this.numClasses = this.numClusters.length;
    }

    /**
     * Replaces the data set and rebuilds every cached view of it.
     *
     * @param instanceSet the new data set
     */
    public void setInstanceSet(InstanceSet instanceSet) {
        this.instanceSet = instanceSet;
        this.initialize();
    }

    /**
     * Refreshes the cached instance arrays and attribute indices from {@code instanceSet}.
     * It also resets the deletion mask and rebuilds the sampler objects.
     */
    private void initialize() {
        this.smote = new Smote(this.instanceSet);
        this.smote.setInstanceSet(this.instanceSet);
        this.numInstances = this.instanceSet.numInstances;
        this.instances = this.instanceSet.instances;
        // These fields hide any superclass fields of the same name; without this they stay 0 and
        // attribute 0 would be read as both the class and the weight.
        this.classIndex = this.instanceSet.classIndex;
        this.counterIndex = this.instanceSet.counterIndex;
        // A fresh boolean[] is already all false.
        this.deleteIndex = new boolean[this.instanceSet.numInstances];
        if (this.useSimplesampling) {
            this.simplesampling = new Simplesampling(this.instanceSet);
            this.simplesampling.setInstanceSet(this.instanceSet);
        } else {
            this.oversampling = new RandomOversampling(this.instanceSet);
            this.oversampling.setInstanceSet(this.instanceSet);
            this.undersampling = new RandomUndersampling(this.instanceSet);
            this.undersampling.setInstanceSet(this.instanceSet);
        }
    }

    /**
     * Runs the preprocessor. It works in two steps:
     * <ol>
     *   <li>If over-majority is enabled, it oversamples the negative-class clusters up to the
     *       biggest one.</li>
     *   <li>It then oversamples every other class towards
     *       {@code majorityWeight * ratio / (10 - ratio)}.</li>
     * </ol>
     *
     * @throws Exception propagated from cluster initialization
     */
    protected void run() throws Exception {
        this.initialize();
        this.initializeClusters();
        if (this.overMajority) {
            this.equalizeMajorityClusters();
        }
        // The majority pass may have modified the instance set; refresh the cached views.
        this.initialize();
        double majorityClassSize = this.classWeight(this.negativeClass);
        double targetClassSize =
                majorityClassSize * (double) this.ratio / (double) (10 - this.ratio);
        for (int classValue = 0; classValue < this.numClasses; ++classValue) {
            if (classValue != this.negativeClass) {
                this.oversampleClassTo(classValue, targetClassSize);
            }
        }
    }

    /** Oversamples every negative-class cluster up to the size of its biggest cluster. */
    private void equalizeMajorityClusters() {
        int majority = this.negativeClass;
        int clusterCount = this.numClusters[majority];
        double[] sizes = this.clustersSize[majority];
        int biggestClusterIndex = 0;
        double biggestClusterSize = 0.0;
        for (int clusterID = 0; clusterID < clusterCount; ++clusterID) {
            if (sizes[clusterID] > biggestClusterSize) {
                biggestClusterSize = sizes[clusterID];
                biggestClusterIndex = clusterID;
            }
        }
        for (int clusterID = 0; clusterID < clusterCount; ++clusterID) {
            // An empty cluster has nothing to replicate and would yield an infinite proportion.
            if (clusterID != biggestClusterIndex && sizes[clusterID] > 0.0) {
                this.oversampleCluster(clusterID, biggestClusterSize / sizes[clusterID], majority);
            }
        }
    }

    /**
     * Oversamples the clusters of {@code classValue} so that the class approaches
     * {@code targetClassSize}, split evenly across its clusters. Clusters already at or above
     * their share are left untouched.
     */
    private void oversampleClassTo(int classValue, double targetClassSize) {
        int clusterCount = this.numClusters[classValue];
        double targetClusterSize = targetClassSize / (double) clusterCount;
        for (int clusterID = 0; clusterID < clusterCount; ++clusterID) {
            double clusterSize = this.clustersSize[classValue][clusterID];
            if (clusterSize <= 0.0) {
                continue; // nothing to replicate; avoids an infinite proportion
            }
            double proportion = targetClusterSize / clusterSize;
            if (proportion > 1.0) {
                this.oversampleCluster(clusterID, proportion, classValue);
            }
        }
    }

    /** Sums the counter (weight) attribute over all instances of {@code classValue}. */
    private double classWeight(int classValue) {
        double total = 0.0;
        for (int inst = 0; inst < this.instanceSet.numInstances; ++inst) {
            Instance instance = this.instances[inst];
            if (instance.data[this.classIndex] == (float) classValue) {
                total += (double) instance.data[this.counterIndex];
            }
        }
        return total;
    }

    /**
     * Undersamples every negative-class cluster by {@code proportion}. It then removes the
     * instances whose weight dropped below 1.
     *
     * @param proportion proportion handed to the undersampler for each cluster
     * @throws Exception propagated from cluster initialization or instance removal
     */
    public void runUndersampling(double proportion) throws Exception {
        this.initialize();
        this.initializeClusters();
        int clusterCount = this.numClusters[this.negativeClass];
        for (int clusterID = 0; clusterID < clusterCount; ++clusterID) {
            this.undersampleCluster(clusterID, proportion);
        }
        for (int inst = 0; inst < this.numInstances; ++inst) {
            float weight = this.instanceSet.instances[inst].data[this.counterIndex];
            if (weight < 1.0f) {
                this.deleteIndex[inst] = true;
            }
        }
        // NOTE(review): both calls below receive the same mask. If each of them removes the marked
        // instances, the second one applies a stale mask to an already-shrunk set. Confirm
        // against Utils.removeMarkedInstances / InstanceSet.removeInstances.
        Utils.removeMarkedInstances(this.instanceSet, this.deleteIndex);
        this.instanceSet.removeInstances(this.deleteIndex);
        this.instances = this.instanceSet.instances;
        this.numInstances = this.instanceSet.numInstances;
    }

    /**
     * Oversamples every positive-class cluster by {@code proportion}.
     *
     * @param proportion proportion handed to the oversampler for each cluster
     * @throws Exception propagated from cluster initialization
     */
    public void runOversampling(double proportion) throws Exception {
        this.initialize();
        this.initializeClusters();
        int clusterCount = this.numClusters[this.positiveClass];
        for (int clusterID = 0; clusterID < clusterCount; ++clusterID) {
            this.oversampleCluster(clusterID, proportion, this.positiveClass);
        }
    }

    /** Hands one cluster of {@code classValue} to the configured oversampler. */
    private void oversampleCluster(int clusterIndex, double proportion, int classValue) {
        int[] cluster = this.clusters[classValue][clusterIndex];
        if (this.useSimplesampling) {
            this.simplesampling.setProportion(proportion);
            this.simplesampling.run(cluster);
        } else {
            this.oversampling.setProportion(proportion);
            this.oversampling.run(cluster);
        }
    }

    /**
     * Returns the indices of the instances in the given cluster that are not marked for deletion
     * and whose class equals {@code classValue}.
     */
    private int[] chooseInstancesIDs(int clusterIndex, int classValue) {
        int[] cluster = this.clusters[classValue][clusterIndex];
        // At most every member of the cluster is chosen, so size the buffer to the cluster.
        int[] chosen = new int[cluster.length];
        int counter = 0;
        for (int inst : cluster) {
            if (this.deleteIndex[inst]) {
                continue;
            }
            int instanceClass = (int) this.instanceSet.instances[inst].data[this.classIndex];
            if (instanceClass == classValue) {
                chosen[counter++] = inst;
            }
        }
        return Arrays.copyOf(chosen, counter);
    }

    /**
     * Undersamples one negative-class cluster by {@code proportion}. Instances are marked
     * ({@code setRemove(false)}) rather than removed.
     */
    private void undersampleCluster(int clusterIndex, double proportion) {
        int[] instancesIDs = this.chooseInstancesIDs(clusterIndex, this.negativeClass);
        if (this.useSimplesampling) {
            this.simplesampling.setProportion(proportion);
            this.simplesampling.setRemove(false);
            this.simplesampling.run(instancesIDs);
        } else {
            this.undersampling.setProportion(proportion);
            this.undersampling.setRemove(false);
            this.undersampling.run(instancesIDs);
        }
    }

    /**
     * Enables or disables the majority-cluster equalization step of {@link #run()}.
     *
     * @param overMajority {@code true} to oversample negative-class clusters up to the biggest one
     */
    public void setOverMajority(boolean overMajority) {
        this.overMajority = overMajority;
    }

    /**
     * Prepares this preprocessor for a batch run. It enables the majority-cluster equalization
     * step and installs {@code instanceSet}.
     *
     * @param instanceSet the data set for the batch run
     * @throws Exception declared for compatibility with the batch framework
     */
    protected void initializeBatch(InstanceSet instanceSet) throws Exception {
        this.setOverMajority(true);
        this.setInstanceSet(instanceSet);
    }
}

