/*
 * Decompiled with CFR 0.152.
 */
package unbbayes.datamining.classifiers;

import java.io.Serializable;
import java.util.Arrays;
import unbbayes.datamining.classifiers.DistributionClassifier;
import unbbayes.datamining.classifiers.neuralnetwork.ActivationFunction;
import unbbayes.datamining.classifiers.neuralnetwork.HiddenNeuron;
import unbbayes.datamining.classifiers.neuralnetwork.MeanSquaredError;
import unbbayes.datamining.classifiers.neuralnetwork.OutputNeuron;
import unbbayes.datamining.classifiers.neuralnetwork.Sigmoid;
import unbbayes.datamining.classifiers.neuralnetwork.Tanh;
import unbbayes.datamining.datamanipulation.Attribute;
import unbbayes.datamining.datamanipulation.Instance;
import unbbayes.datamining.datamanipulation.InstanceSet;
import unbbayes.datamining.datamanipulation.Stats;
import unbbayes.datamining.datamanipulation.Utils;

public class NeuralNetwork
extends DistributionClassifier
implements Serializable {
    private static final long serialVersionUID = 0L;

    /** Hidden layer size sentinel: buildClassifier derives the size from the number of attributes. */
    public static final int AUTO_HIDDEN_LAYER_SIZE = -1;
    /** Value of minimumErrorVariation that disables the error-variation early stop in buildClassifier. */
    public static final int NO_ERROR_VARIATION_STOP_CRITERION = -2;

    /** Activation function code selecting {@link Sigmoid}. */
    public static final int SIGMOID = 0;
    /** Activation function code selecting {@link Tanh}. */
    public static final int TANH = 1;

    /** Normalization code: numeric inputs are fed to the network unchanged. */
    public static final int NO_NORMALIZATION = 0;
    /** Normalization code selecting {@link LinearNormalization}. */
    public static final int LINEAR_NORMALIZATION = 1;
    /** Normalization code selecting {@link Mean0StdDeviation1Normalization}. */
    public static final int MEAN_0_STANDARD_DEVIATION_1_NORMALIZATION = 2;

    // Learning rate handed to the neurons' updateWeights; rewritten every epoch when decay is on.
    private transient float learningRate;
    // Learning rate given to the constructor; the base value for the decay and for toString.
    private float originalLearningRate;
    // Momentum passed to every neuron constructor.
    private float momentum;
    // Requested hidden layer size, or AUTO_HIDDEN_LAYER_SIZE.
    private transient int hiddenLayerSize;
    // Activation function instance shared by all neurons; created in the constructor.
    private transient ActivationFunction activationFunction = null;
    // Maximum number of epochs; overwritten with the epoch index when training stops early.
    private int trainingTime;
    // Early-stop threshold compared against the percentage change of the epoch error.
    private transient float minimumErrorVariation;
    // When true, the learning rate is recomputed from originalLearningRate at every epoch.
    private boolean learningRateDecay = false;
    // Optional sink that receives the mean error of each epoch; may be null.
    private transient MeanSquaredError meanSquaredError;
    // Training set of the current/last buildClassifier call; also read by the normalization classes.
    private transient InstanceSet instanceSet;
    // Input buffer, refilled by inputLayerSetUp for every instance.
    private float[] inputLayer;
    private HiddenNeuron[] hiddenLayer;
    // One neuron for a numeric class, otherwise one neuron per class value.
    private OutputNeuron[] outputLayer;
    // SIGMOID or TANH, as passed to the constructor.
    private int activationFunctionType;
    // Number of attributes of the training set, class attribute included.
    private int numAttributes;
    // One of the *_NORMALIZATION codes.
    private int numericalInputNormalization = 0;
    // True when the class attribute is numeric.
    private boolean numericOutput;
    // Per-attribute maximum: class attribute (numeric output) and inputs under linear normalization.
    private float[] highestValue;
    // Per-attribute minimum, filled alongside highestValue.
    private float[] lowestValue;
    // Offset into inputLayer of each non-class attribute, in attribute order.
    private int[] inputLayerIndexes;
    private Attribute[] attributeVector;
    private int classIndex;
    // Per-attribute mean, allocated and filled by Mean0StdDeviation1Normalization.
    private float[] attributeMean;
    // Per-attribute standard deviation, allocated and filled by Mean0StdDeviation1Normalization.
    private float[] attributeStandardDeviation;
    // Strategy applied to numeric input values in inputLayerSetUp.
    private INormalization normalizationFunction;
    // Steepness parameter handed to the activation function constructor.
    private float activationFunctionSteep;
    // Number of values of the class attribute.
    private transient int numClasses;
    // Length of inputLayer: one slot per numeric attribute, numValues() slots per nominal one.
    private int inputLayerSize;
    // Target output vectors, indexed by class value and then by output neuron.
    private transient float[][] expectedOutput;
    // Scratch buffer holding the hidden layer outputs of the last forward pass.
    private float[] hiddenLayerOutput;

    /**
     * Creates an untrained network with the full set of training options.
     *
     * @param learningRate                initial learning rate
     * @param learningRateDecay           whether the learning rate is reduced as epochs advance
     * @param momentum                    momentum passed to every neuron
     * @param hiddenLayerSize             number of hidden neurons, or {@link #AUTO_HIDDEN_LAYER_SIZE}
     * @param activationFunction          {@link #SIGMOID} or {@link #TANH}
     * @param trainingTime                maximum number of training epochs
     * @param numericalInputNormalization one of the {@code *_NORMALIZATION} codes
     * @param activationFunctionSteep     steepness handed to the activation function
     * @param minimumErrorVariation       early-stop threshold, or
     *                                    {@link #NO_ERROR_VARIATION_STOP_CRITERION} to disable it
     * @throws IllegalArgumentException if {@code activationFunction} is not a known code
     */
    public NeuralNetwork(float learningRate, boolean learningRateDecay, float momentum, int hiddenLayerSize, int activationFunction, int trainingTime, int numericalInputNormalization, float activationFunctionSteep, float minimumErrorVariation) {
        this.learningRate = learningRate;
        this.originalLearningRate = learningRate;
        this.learningRateDecay = learningRateDecay;
        this.momentum = momentum;
        this.hiddenLayerSize = hiddenLayerSize;
        this.trainingTime = trainingTime;
        this.numericalInputNormalization = numericalInputNormalization;
        this.minimumErrorVariation = minimumErrorVariation;
        this.activationFunctionType = activationFunction;
        this.activationFunctionSteep = activationFunctionSteep;

        // Reject unknown codes here: a null activation function would otherwise only
        // surface as a NullPointerException somewhere inside training.
        switch (activationFunction) {
            case SIGMOID:
                this.activationFunction = new Sigmoid(activationFunctionSteep);
                break;
            case TANH:
                this.activationFunction = new Tanh(activationFunctionSteep);
                break;
            default:
                throw new IllegalArgumentException("Unknown activation function type: " + activationFunction);
        }
    }

    /**
     * Convenience constructor: no learning rate decay, no input normalization,
     * activation steepness 1 and no error-variation stop criterion.
     *
     * @param learningRate       learning rate
     * @param momentum           momentum passed to every neuron
     * @param hiddenLayerSize    number of hidden neurons, or {@link #AUTO_HIDDEN_LAYER_SIZE}
     * @param activationFunction {@link #SIGMOID} or {@link #TANH}
     * @param trainingTime       maximum number of training epochs
     */
    public NeuralNetwork(float learningRate, float momentum, int hiddenLayerSize, int activationFunction, int trainingTime) {
        this(learningRate,
                false,
                momentum,
                hiddenLayerSize,
                activationFunction,
                trainingTime,
                NO_NORMALIZATION,
                1.0f,
                (float) NO_ERROR_VARIATION_STOP_CRITERION);
    }

    /**
     * Builds the network for the given training set and trains it.
     *
     * <p>Training presents every instance once per unit of its weight in each epoch and
     * runs for at most {@code trainingTime} epochs. When an error-variation threshold is
     * configured, training stops as soon as the percentage change of the mean epoch error
     * drops below it, and {@code trainingTime} is overwritten with the epoch index reached.
     *
     * <p>All layout state (input offsets, input size, layers) is rebuilt from scratch, so
     * the method may be called more than once on the same object.
     *
     * @param instanceSet training data; its class attribute defines the network output
     * @throws IllegalArgumentException if the configured normalization code is unknown
     * @throws Exception if the data set or the normalization statistics cannot be read
     */
    public void buildClassifier(InstanceSet instanceSet) throws Exception {
        this.instanceSet = instanceSet;
        numAttributes = instanceSet.numAttributes();
        numericOutput = instanceSet.getClassAttribute().isNumeric();
        attributeVector = instanceSet.getAttributes();
        classIndex = instanceSet.getClassIndex();
        numClasses = attributeVector[classIndex].numValues();

        highestValue = new float[numAttributes];
        lowestValue = new float[numAttributes];
        if (numericOutput) {
            // Range of the class attribute, needed to map targets into the activation interval.
            Stats stats = instanceSet.getAttributeStats(classIndex).getNumericStats();
            highestValue[classIndex] = stats.getMax();
            lowestValue[classIndex] = stats.getMin();
        }

        buildExpectedOutput();
        normalizationFunction = createNormalizationFunction();
        layOutInputLayer();
        createNeuronLayers();
        train(instanceSet);
    }

    /** Fills expectedOutput with one target vector per class value. */
    private void buildExpectedOutput() {
        expectedOutput = new float[numClasses][];
        if (numericOutput) {
            float[] distinctValues = attributeVector[classIndex].getDistinticNumericValues();
            for (int c = 0; c < numClasses; c++) {
                expectedOutput[c] = new float[] {
                    activationFunction.normalizeToFunctionInterval(
                            distinctValues[c], highestValue[classIndex], lowestValue[classIndex])
                };
            }
        } else {
            // One output per class: 1 at the class position, 0 elsewhere.
            for (int c = 0; c < numClasses; c++) {
                expectedOutput[c] = new float[numClasses];
                expectedOutput[c][c] = 1.0f;
            }
        }
    }

    /** Instantiates the normalization strategy selected by numericalInputNormalization. */
    private INormalization createNormalizationFunction() throws Exception {
        switch (numericalInputNormalization) {
            case NO_NORMALIZATION:
                return new NoNormalization();
            case LINEAR_NORMALIZATION:
                return new LinearNormalization();
            case MEAN_0_STANDARD_DEVIATION_1_NORMALIZATION:
                return new Mean0StdDeviation1Normalization();
            default:
                throw new IllegalArgumentException(
                        "Unknown numerical input normalization: " + numericalInputNormalization);
        }
    }

    /** Computes the input-layer offset of every non-class attribute and allocates the input buffer. */
    private void layOutInputLayer() {
        inputLayerIndexes = new int[numAttributes - 1];
        int offset = 0;
        int slot = 0;
        for (int i = 0; i < numAttributes; i++) {
            if (i == classIndex) {
                continue;
            }
            inputLayerIndexes[slot++] = offset;
            Attribute att = attributeVector[i];
            // A numeric attribute takes one input; a nominal one takes one input per value.
            offset += att.isNumeric() ? 1 : att.numValues();
        }
        // Assigned, not accumulated, so retraining does not grow the input layer.
        inputLayerSize = offset;
        inputLayer = new float[inputLayerSize];
    }

    /** Creates fresh hidden and output neurons sized for the current data set. */
    private void createNeuronLayers() {
        // Keep the configured value untouched so that AUTO is re-evaluated on every build.
        int hiddenCount = hiddenLayerSize;
        if (hiddenCount == AUTO_HIDDEN_LAYER_SIZE) {
            hiddenCount = Math.max((numAttributes + 1) / 2, 3);
        }
        hiddenLayer = new HiddenNeuron[hiddenCount];
        for (int h = 0; h < hiddenLayer.length; h++) {
            hiddenLayer[h] = new HiddenNeuron(activationFunction, inputLayerSize, momentum);
        }
        hiddenLayerOutput = new float[hiddenCount];

        outputLayer = new OutputNeuron[numericOutput ? 1 : numClasses];
        for (int o = 0; o < outputLayer.length; o++) {
            outputLayer[o] = new OutputNeuron(activationFunction, hiddenLayer.length, momentum);
        }
    }

    /** Runs the epoch loop, reporting the mean error of each epoch and applying the stop criterion. */
    private void train(InstanceSet instanceSet) throws Exception {
        long numInstances = instanceSet.numInstances();
        long numWeightedInstances = instanceSet.numWeightedInstances();
        float quadraticError = 0.0f;

        for (int epoch = 0; epoch < trainingTime; epoch++) {
            float oldQuadraticError = quadraticError;
            quadraticError = 0.0f;

            if (learningRateDecay) {
                // Divide by epoch + 1: dividing by the zero-based epoch made the first
                // epoch train with an infinite learning rate.
                learningRate = originalLearningRate / (epoch + 1);
            }

            for (int inst = 0; inst < numInstances; inst++) {
                Instance instance = instanceSet.getInstance(inst);
                float instanceWeight = instance.getWeight();
                // An instance of weight w is presented ceil(w) times.
                for (int rep = 0; rep < instanceWeight; rep++) {
                    quadraticError += learn(instance);
                }
            }
            quadraticError /= (float) numWeightedInstances;

            if (meanSquaredError != null) {
                meanSquaredError.setMeanSquaredError(epoch, quadraticError);
            }

            // On the first epoch oldQuadraticError is 0, so the ratio is infinite or NaN
            // and the comparison is false: the criterion can only fire from epoch 1 on.
            if (minimumErrorVariation != NO_ERROR_VARIATION_STOP_CRITERION
                    && Math.abs((oldQuadraticError - quadraticError) * 100.0f / oldQuadraticError) < minimumErrorVariation) {
                trainingTime = epoch;
                break;
            }
        }
    }

    /**
     * Presents one instance to the network: forward pass, error computation and
     * weight update of both layers.
     *
     * @param instance training instance; its class value selects the target vector
     * @return half the sum of the squared error terms of the output neurons
     */
    private float learn(Instance instance) {
        inputLayerSetUp(instance);

        // Forward pass through the hidden layer.
        for (int h = 0; h < hiddenLayer.length; h++) {
            HiddenNeuron neuron = hiddenLayer[h];
            neuron.calculateOutputValue(inputLayer);
            hiddenLayerOutput[h] = neuron.outputValue;
        }

        // Forward pass through the output layer, accumulating the squared errors.
        float squaredErrorSum = 0.0f;
        for (int o = 0; o < outputLayer.length; o++) {
            OutputNeuron neuron = outputLayer[o];
            neuron.calculateOutputValue(hiddenLayerOutput);
            float target = expectedOutput[instance.getClassValue()][o];
            float error = neuron.calculateErrorTerm(target);
            squaredErrorSum += error * error;
        }

        // NOTE(review): output weights are updated before the hidden error terms are
        // derived from the output layer; confirm that HiddenNeuron.calculateErrorTerm
        // is meant to see the updated weights.
        for (OutputNeuron neuron : outputLayer) {
            neuron.updateWeights(learningRate, hiddenLayerOutput);
        }
        for (int h = 0; h < hiddenLayer.length; h++) {
            hiddenLayer[h].calculateErrorTerm(outputLayer, h);
        }
        for (HiddenNeuron neuron : hiddenLayer) {
            neuron.updateWeights(learningRate, inputLayer);
        }

        return squaredErrorSum / 2.0f;
    }

    /**
     * Encodes an instance into the input buffer.
     *
     * <p>Every input starts at -1. A numeric attribute writes its normalized value into
     * its single slot; a nominal attribute sets the slot of its current value to 1.
     * Missing values leave their slots at -1. The class attribute is skipped.
     *
     * @param instance instance to encode
     */
    private void inputLayerSetUp(Instance instance) {
        Arrays.fill(inputLayer, -1.0f);

        int slot = 0;
        for (int attIndex = 0; attIndex < numAttributes; attIndex++) {
            if (attIndex == classIndex) {
                continue;
            }
            if (!instance.isMissing(attIndex)) {
                int base = inputLayerIndexes[slot];
                Attribute attribute = attributeVector[attIndex];
                if (attribute.isNumeric()) {
                    float raw = instance.getValue(attribute);
                    inputLayer[base] = normalizationFunction.normalize(raw, attIndex);
                } else {
                    int valueOffset = (int) instance.getValue(attIndex);
                    inputLayer[base + valueOffset] = 1.0f;
                }
            }
            // The slot advances for missing values too, keeping offsets aligned.
            slot++;
        }
    }

    /**
     * Runs a forward pass for the given instance and returns the output layer values.
     *
     * <p>For a nominal class the result holds one raw activation per class value. For a
     * numeric class it holds a single value mapped back from the activation function's
     * interval ([0, 1] for sigmoid, [-1, 1] otherwise) to the range of the class attribute.
     *
     * @param instance instance to evaluate
     * @return output of every output neuron
     * @throws IllegalStateException if the network has not been trained yet
     * @throws Exception declared by the overridden method
     */
    public float[] distributionForInstance(Instance instance) throws Exception {
        if (inputLayer == null || hiddenLayer == null || outputLayer == null) {
            throw new IllegalStateException("The network must be built with buildClassifier before it is used");
        }

        // inputLayerSetUp resets the buffer itself, so no separate fill is needed here.
        inputLayerSetUp(instance);

        for (int h = 0; h < hiddenLayer.length; h++) {
            hiddenLayer[h].calculateOutputValue(inputLayer);
            hiddenLayerOutput[h] = hiddenLayer[h].outputValue;
        }

        float[] distribution = new float[outputLayer.length];
        for (int o = 0; o < outputLayer.length; o++) {
            distribution[o] = outputLayer[o].calculateOutputValue(hiddenLayerOutput);
        }

        if (numericOutput) {
            // Lower bound of the activation interval the outputs are mapped back from.
            float functionMin = (activationFunctionType == SIGMOID) ? 0.0f : -1.0f;
            for (int o = 0; o < distribution.length; o++) {
                distribution[o] = Utils.normalize(distribution[o], 1.0f, functionMin, highestValue[classIndex], lowestValue[classIndex]);
            }
        }
        return distribution;
    }

    /**
     * Registers the object that receives the mean error of each training epoch.
     *
     * @param meanSquaredError receiver of the per-epoch error; when null, buildClassifier
     *                         skips the reporting
     */
    public void setMeanSquaredErrorOutput(MeanSquaredError meanSquaredError) {
        this.meanSquaredError = meanSquaredError;
    }

    /**
     * Returns the attributes of the data set the network was built on.
     *
     * @return the internal attribute array (not a copy); null before buildClassifier is called
     */
    public Attribute[] getAttributeVector() {
        return attributeVector;
    }

    /**
     * Returns the index of the class attribute recorded by buildClassifier.
     *
     * @return class attribute index
     */
    public int getClassIndex() {
        return classIndex;
    }

    /**
     * Describes the network configuration, one setting per line.
     *
     * <p>Safe to call on an untrained network: the hidden layer size then falls back to
     * the configured value and the class attribute is reported as not trained.
     *
     * @return multi-line description of the training parameters
     */
    public String toString() {
        // Before buildClassifier runs there are no layers and no attributes yet.
        int hiddenSize = (hiddenLayer != null) ? hiddenLayer.length : hiddenLayerSize;

        String activationFunctionName = "";
        if (activationFunctionType == SIGMOID) {
            activationFunctionName = "Sigmoid";
        } else if (activationFunctionType == TANH) {
            activationFunctionName = "Tanh";
        }

        String normalizationName = "";
        if (numericalInputNormalization == NO_NORMALIZATION) {
            normalizationName = "No normalization";
        } else if (numericalInputNormalization == LINEAR_NORMALIZATION) {
            normalizationName = "Linear normalization";
        } else if (numericalInputNormalization == MEAN_0_STANDARD_DEVIATION_1_NORMALIZATION) {
            normalizationName = "Mean 0 and standard deviation 1 normalization";
        }

        String classAttributeName = (attributeVector != null)
                ? attributeVector[classIndex].getAttributeName()
                : "(not trained)";

        StringBuilder description = new StringBuilder();
        description.append("Learning Rate: ").append(originalLearningRate).append("\n");
        description.append("Momentum: ").append(momentum).append("\n");
        description.append("Hidden Layer Size: ").append(hiddenSize).append("\n");
        description.append("Training Time: ").append(trainingTime).append("\n");
        description.append("Activation Function: ").append(activationFunctionName).append("\n");
        description.append("Learning Rate Decay: ").append(learningRateDecay).append("\n");
        description.append("Numerical Input Normalization: ").append(normalizationName).append("\n");
        description.append("Activation Function Steep: ").append(activationFunctionSteep).append("\n");
        description.append("Class Attribute: ").append(classAttributeName);
        return description.toString();
    }

    /** Strategy that maps a raw numeric attribute value to the value fed to the input layer. */
    public interface INormalization {
        /**
         * Normalizes one value of one attribute.
         *
         * @param data           raw attribute value
         * @param attributeIndex index of the attribute the value belongs to
         * @return value to place in the input layer
         */
        float normalize(float data, int attributeIndex);
    }

    public class LinearNormalization
    implements INormalization,
    Serializable {
        private static final long serialVersionUID = 0L;

        public LinearNormalization() {
            int att = 0;
            while (att < NeuralNetwork.this.numAttributes) {
                if (att != NeuralNetwork.this.classIndex && NeuralNetwork.this.instanceSet.getAttribute(att).isNumeric()) {
                    ((NeuralNetwork)NeuralNetwork.this).highestValue[att] = Float.MIN_VALUE;
                    ((NeuralNetwork)NeuralNetwork.this).lowestValue[att] = Float.MAX_VALUE;
                    int numInstances = ((NeuralNetwork)NeuralNetwork.this).instanceSet.numInstances;
                    int inst = 0;
                    while (inst < numInstances) {
                        float value = ((NeuralNetwork)NeuralNetwork.this).instanceSet.getInstance((int)inst).data[att];
                        ((NeuralNetwork)NeuralNetwork.this).highestValue[att] = Math.max(NeuralNetwork.this.highestValue[att], value);
                        ((NeuralNetwork)NeuralNetwork.this).lowestValue[att] = Math.min(NeuralNetwork.this.lowestValue[att], value);
                        ++inst;
                    }
                }
                ++att;
            }
        }

        public float normalize(float data, int attributeIndex) {
            return Utils.normalize(data, NeuralNetwork.this.highestValue[attributeIndex], NeuralNetwork.this.lowestValue[attributeIndex], 1.0f, -1.0f);
        }
    }

    public class Mean0StdDeviation1Normalization
    implements INormalization,
    Serializable {
        private static final long serialVersionUID = 0L;

        public Mean0StdDeviation1Normalization() throws Exception {
            NeuralNetwork.this.attributeMean = new float[NeuralNetwork.this.numAttributes];
            NeuralNetwork.this.attributeStandardDeviation = new float[NeuralNetwork.this.numAttributes];
            int i = 0;
            while (i < NeuralNetwork.this.numAttributes) {
                Attribute att = NeuralNetwork.this.instanceSet.getAttribute(i);
                if (att.isNumeric() && i != NeuralNetwork.this.classIndex) {
                    ((NeuralNetwork)NeuralNetwork.this).attributeMean[i] = (float)Utils.mean(NeuralNetwork.this.instanceSet, i);
                    ((NeuralNetwork)NeuralNetwork.this).attributeStandardDeviation[i] = (float)Utils.standardDeviation(NeuralNetwork.this.instanceSet, i, NeuralNetwork.this.attributeMean[i]);
                }
                ++i;
            }
        }

        public float normalize(float data, int attributeIndex) {
            return (data - NeuralNetwork.this.attributeMean[attributeIndex]) / NeuralNetwork.this.attributeStandardDeviation[attributeIndex];
        }
    }

    public class NoNormalization
    implements INormalization,
    Serializable {
        private static final long serialVersionUID = 0L;

        public float normalize(float data, int attributeIndex) {
            return data;
        }
    }
}

