/*
 * Decompiled with CFR 0.152.
 */
package unbbayes.datamining.classifiers;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.ResourceBundle;
import unbbayes.datamining.classifiers.DistributionClassifier;
import unbbayes.datamining.datamanipulation.Attribute;
import unbbayes.datamining.datamanipulation.AttributeStats;
import unbbayes.datamining.datamanipulation.Instance;
import unbbayes.datamining.datamanipulation.InstanceSet;
import unbbayes.datamining.datamanipulation.Utils;
import unbbayes.prs.Edge;
import unbbayes.prs.bn.PotentialTable;
import unbbayes.prs.bn.ProbabilisticNetwork;
import unbbayes.prs.bn.ProbabilisticNode;
import unbbayes.prs.exception.InvalidParentException;

/**
 * Naive Bayes classifier over an {@link InstanceSet}.
 *
 * <p>Nominal attributes are modelled by Laplace-smoothed conditional tables
 * P(value | class). Numeric attributes are modelled by a per-class mean and
 * standard deviation evaluated through {@code Utils.getProbability}
 * (NOTE(review): presumably a normal density discretised by the attribute
 * precision — confirm). Class priors are the Laplace-smoothed weighted class
 * counts, or the inherited {@code originalDistribution} when that is set.
 *
 * <p>Alongside the model, a {@link ProbabilisticNetwork} is built in which the
 * class node is the parent of every attribute node; it is exposed through
 * {@link #getProbabilisticNetwork()}.
 *
 * <p>Missing attribute values are expected to be encoded as {@code Float.NaN}
 * in {@code Instance.data}; they are skipped during classification.
 */
public class NaiveBayes
extends DistributionClassifier
implements Serializable {
    private static final long serialVersionUID = 1L;
    private static final ResourceBundle resource = ResourceBundle.getBundle("unbbayes.datamining.classifiers.resources.ClassifiersResource");
    /** Type code in {@code InstanceSet.attributeType} for nominal attributes. */
    private static final byte NOMINAL = 1;
    /** Floor applied to every numeric likelihood so a single factor cannot zero a class score. */
    private static final double MIN_LIKELIHOOD = 1.0E-75;

    /** Indexed [class][nominal non-class attribute][value]: smoothed P(value | class). */
    private float[][][] nominalCounts;
    /** Smoothed class prior probabilities, indexed by class value. */
    private float[] priors;
    /** Indexed [numeric attribute][class]: per-class standard deviation. */
    private double[][] stdDevPerClass;
    /** Indexed [numeric attribute][class]: per-class mean. */
    private double[][] meanPerClass;
    private InstanceSet instanceSet;
    /** Class node of {@link #net}; parent of every attribute node. */
    private ProbabilisticNode classAtt;
    /** Horizontal position for the next attribute node in the network layout. */
    private int width = 50;
    private ProbabilisticNetwork net = new ProbabilisticNetwork("NaiveBayes");
    private int numAttributes;
    private int numClasses;
    private int classIndex;
    private Attribute[] attributes;
    private byte[] attributeType;
    private boolean normalize = true;
    /** Per numeric attribute precision passed to {@code Utils.getProbability}. */
    private float[] precision;

    /**
     * Trains the classifier on {@code instanceSet} and rebuilds the probabilistic network.
     *
     * @param instanceSet training data; its class attribute is assumed to be nominal
     * @throws Exception propagated from the statistics helpers or network construction
     */
    public void buildClassifier(InstanceSet instanceSet) throws Exception {
        this.instanceSet = instanceSet;
        this.numClasses = instanceSet.numClasses();
        this.attributes = instanceSet.attributes;
        this.attributeType = instanceSet.attributeType;
        this.classIndex = instanceSet.classIndex;
        this.numAttributes = instanceSet.numAttributes;
        // Start from an empty network so that training twice does not add duplicate nodes.
        this.net = new ProbabilisticNetwork("NaiveBayes");
        this.width = 50;

        int numNominalAttributes = instanceSet.numNominalAttributes;
        int numNumericAttributes = this.numAttributes - numNominalAttributes;
        this.nominalCounts = Utils.computeNominalDistributions(instanceSet);
        this.precision = instanceSet.computePrecision();
        this.computeNumericStatistics(instanceSet, numNumericAttributes);

        AttributeStats[] attributeStats = instanceSet.getAttributeStats();
        float[] classCounts = this.originalDistribution == null
                ? attributeStats[this.classIndex].getNominalCountsWeighted()
                : this.originalDistribution;
        // Smoothing works in place, so copy first to avoid mutating the source array.
        this.priors = classCounts.clone();

        // NOTE(review): assumes the class attribute is nominal, is counted in
        // numNominalAttributes and has no entry in nominalCounts — confirm.
        this.smoothNominalCounts(numNominalAttributes - 1);
        this.smoothPriors();

        this.createProbabilisticNodeClass();
        int nominalIndex = 0;
        int numericIndex = 0;
        for (int att = 0; att < this.numAttributes; att++) {
            if (att == this.classIndex) {
                continue;
            }
            if (this.attributeType[att] == NOMINAL) {
                this.createProbabilisticNode(att, nominalIndex++);
            } else {
                this.createProbabilisticNode(att, numericIndex++);
            }
        }
    }

    /** Fills {@link #stdDevPerClass} and {@link #meanPerClass} for every non-nominal attribute. */
    private void computeNumericStatistics(InstanceSet instanceSet, int numNumericAttributes) throws Exception {
        this.stdDevPerClass = new double[numNumericAttributes][];
        this.meanPerClass = new double[numNumericAttributes][];
        byte[] types = instanceSet.attributeType;
        int numericIndex = 0;
        for (int att = 0; att < this.numAttributes; att++) {
            if (types[att] == NOMINAL) {
                continue;
            }
            if (instanceSet.numInstances > 1) {
                ArrayList<double[]> stdDevMeanPerClass = Utils.stdDevMeanPerClass(instanceSet, att);
                this.stdDevPerClass[numericIndex] = stdDevMeanPerClass.get(0);
                this.meanPerClass[numericIndex] = stdDevMeanPerClass.get(1);
            } else {
                // Too few instances for statistics: use zeros, one entry per class.
                this.stdDevPerClass[numericIndex] = new double[this.numClasses];
                this.meanPerClass[numericIndex] = new double[this.numClasses];
            }
            numericIndex++;
        }
    }

    /**
     * Applies Laplace smoothing in place: P(v | c) = (count(v, c) + 1) / (total(c) + |values|).
     *
     * @param numNominalAttributes number of nominal non-class attributes held in nominalCounts
     */
    private void smoothNominalCounts(int numNominalAttributes) {
        for (int att = 0; att < numNominalAttributes; att++) {
            for (int k = 0; k < this.numClasses; k++) {
                float[] counts = this.nominalCounts[k][att];
                // The denominator is fixed per (class, attribute), not per value.
                float denominator = Utils.sum(counts) + counts.length;
                for (int i = 0; i < counts.length; i++) {
                    counts[i] = (counts[i] + 1.0f) / denominator;
                }
            }
        }
    }

    /** Applies Laplace smoothing in place to {@link #priors}: (count + 1) / (total + numClasses). */
    private void smoothPriors() {
        float denominator = Utils.sum(this.priors) + this.numClasses;
        for (int k = 0; k < this.numClasses; k++) {
            this.priors[k] = (float) ((this.priors[k] + 1.0f) / (double) denominator);
        }
    }

    /** Creates the class node, with the priors as its table, and adds it to {@link #net}. */
    private void createProbabilisticNodeClass() {
        Attribute att = this.attributes[this.classIndex];
        ProbabilisticNode node = new ProbabilisticNode();
        node.setDescription(att.getAttributeName());
        node.setName(att.getAttributeName());
        int numValues = att.numValues();
        for (int i = 0; i < numValues; i++) {
            node.appendState(att.value(i));
        }
        if (this.numAttributes == 1) {
            node.setPosition(50.0, 30.0);
        } else {
            node.setPosition(50 + (this.numAttributes - 2) * 50, 30.0);
        }
        PotentialTable tab = node.getPotentialTable();
        tab.addVariable(node);
        for (int i = 0; i < numValues; i++) {
            tab.setValue(i, this.priors[i]);
        }
        this.net.addNode(node);
        this.classAtt = node;
    }

    /**
     * Creates the node for attribute {@code att} as a child of the class node.
     *
     * @param att        index of the attribute in the instance set
     * @param statsIndex index into nominalCounts (nominal attributes) or into the
     *                   mean/stddev arrays (numeric attributes)
     * @throws InvalidParentException propagated from adding the class-to-attribute edge
     */
    private void createProbabilisticNode(int att, int statsIndex) throws InvalidParentException {
        Attribute attribute = this.attributes[att];
        ProbabilisticNode node = new ProbabilisticNode();
        node.setDescription(attribute.getAttributeName());
        node.setName(attribute.getAttributeName());
        int numValues = attribute.numValues();
        for (int i = 0; i < numValues; i++) {
            node.appendState(attribute.value(i));
        }
        PotentialTable tab = node.getPotentialTable();
        tab.addVariable(node);
        node.setPosition(this.width, 100.0);
        this.width += 100;
        this.net.addNode(node);
        this.net.addEdge(new Edge(this.classAtt, node));
        if (this.attributeType[att] == NOMINAL) {
            // The table has exactly two dimensions: this attribute's value, then the class.
            int[] coord = new int[2];
            for (int k = 0; k < this.numClasses; k++) {
                for (int i = 0; i < numValues; i++) {
                    coord[0] = i;
                    coord[1] = k;
                    tab.setValue(coord, this.nominalCounts[k][statsIndex][i]);
                }
            }
        } else {
            node.setMean(this.meanPerClass[statsIndex]);
            node.setStandardDeviation(this.stdDevPerClass[statsIndex]);
        }
    }

    /** Returns true when {@code value} encodes a missing value (NaN never equals itself). */
    private static boolean isMissing(float value) {
        return Float.isNaN(value);
    }

    /**
     * Computes the unnormalised score prior(c) * prod P(x_a | c) for every class.
     * Missing attribute values contribute no factor, but the per-type statistic
     * indices still advance past them so later attributes use the right tables.
     */
    private double[] computeClassScores(float[] inst) {
        double[] scores = new double[this.numClasses];
        for (int k = 0; k < this.numClasses; k++) {
            double score = this.priors[k];
            int nominalIndex = 0;
            int numericIndex = 0;
            for (int att = 0; att < this.numAttributes; att++) {
                if (att == this.classIndex) {
                    continue;
                }
                float value = inst[att];
                if (this.attributeType[att] == NOMINAL) {
                    if (!isMissing(value)) {
                        score *= (double) this.nominalCounts[k][nominalIndex][(int) value];
                    }
                    nominalIndex++;
                } else {
                    if (!isMissing(value)) {
                        double likelihood = Utils.getProbability(value,
                                this.meanPerClass[numericIndex][k],
                                this.stdDevPerClass[numericIndex][k],
                                this.precision[numericIndex]);
                        score *= Math.max(MIN_LIKELIHOOD, likelihood);
                    }
                    numericIndex++;
                }
            }
            scores[k] = score;
        }
        return scores;
    }

    /**
     * Returns the class distribution for {@code instance}.
     *
     * <p>When normalisation is on (the default), the scores are divided by their
     * sum. If the sum is not positive, an all-zero array is returned. When
     * normalisation is off, the raw scores are multiplied by 1e61 and returned.
     *
     * @param instance instance to classify; its class value is ignored
     * @return one value per class
     */
    public float[] distributionForInstance(Instance instance) throws Exception {
        double[] probsAux = this.computeClassScores(instance.data);
        double maxProb = -1.0;
        for (double p : probsAux) {
            if (p > maxProb) {
                maxProb = p;
            }
        }
        if (this.normalize) {
            // Rescale very small scores to keep precision before normalising.
            if (maxProb < 1.0E-75) {
                for (int k = 0; k < probsAux.length; k++) {
                    probsAux[k] *= 1.0E75;
                }
            }
        } else {
            for (int k = 0; k < probsAux.length; k++) {
                probsAux[k] *= 1.0E61;
            }
        }
        float[] probs = new float[this.numClasses];
        if (this.normalize) {
            double sum = Utils.sum(probsAux);
            if (sum > 0.0) {
                for (int k = 0; k < probsAux.length; k++) {
                    probs[k] = (float) (probsAux[k] / sum);
                }
            }
        } else {
            for (int k = 0; k < probsAux.length; k++) {
                probs[k] = (float) probsAux[k];
            }
        }
        return probs;
    }

    /**
     * Returns the raw per-class scores (prior times likelihoods) for {@code instance},
     * without any rescaling or normalisation.
     */
    public double[] distributionForInstanceEma(Instance instance) throws Exception {
        return this.computeClassScores(instance.data);
    }

    /** Renders the priors and, per class, each attribute's distribution or mean/stddev. */
    @Override
    public String toString() {
        if (this.instanceSet == null) {
            return this.nullInstancesString();
        }
        try {
            StringBuilder text = new StringBuilder("Naive Bayes");
            for (int i = 0; i < this.numClasses; i++) {
                text.append("\n\n").append(resource.getString("class")).append(' ')
                        .append(this.instanceSet.getClassAttribute().value(i))
                        .append(": P(C) = ").append(Utils.doubleToString(this.priors[i], 10, 8))
                        .append("\n\n");
                int nominalIndex = 0;
                int numericIndex = 0;
                for (int k = 0; k < this.numAttributes; k++) {
                    if (k == this.classIndex) {
                        continue;
                    }
                    Attribute attribute = this.attributes[k];
                    text.append(resource.getString("attribute")).append(' ')
                            .append(attribute.getAttributeName()).append('\n');
                    if (this.attributeType[k] == NOMINAL) {
                        int numValues = attribute.numValues();
                        for (int j = 0; j < numValues; j++) {
                            text.append(attribute.value(j)).append('\t');
                        }
                        text.append('\n');
                        // nominalCounts is indexed by nominal-attribute position, not by k.
                        float[] distribution = this.nominalCounts[i][nominalIndex++];
                        for (int j = 0; j < numValues; j++) {
                            text.append(Utils.doubleToString(distribution[j], 10, 8)).append('\t');
                        }
                    } else {
                        text.append("mean = ")
                                .append(Utils.doubleToString(this.meanPerClass[numericIndex][i], 10, 8))
                                .append("\tstdDev = ")
                                .append(Utils.doubleToString(this.stdDevPerClass[numericIndex][i], 10, 8));
                        numericIndex++;
                    }
                    text.append("\n\n");
                }
            }
            return text.toString();
        }
        catch (Exception e) {
            return resource.getString("exception5");
        }
    }

    /** Renders the model using numeric labels when no instance set is attached. */
    private String nullInstancesString() {
        if (this.priors == null) {
            // Nothing has been trained yet.
            return resource.getString("exception5");
        }
        try {
            StringBuilder text = new StringBuilder("Naive Bayes");
            for (int i = 0; i < this.priors.length; i++) {
                text.append("\n\n").append(resource.getString("class")).append(' ').append(i)
                        .append(": P(C) = ").append(Utils.doubleToString(this.priors[i], 10, 8))
                        .append("\n\n");
                if (this.nominalCounts == null) {
                    continue;
                }
                for (int attIndex = 0; attIndex < this.nominalCounts[i].length; attIndex++) {
                    float[] distribution = this.nominalCounts[i][attIndex];
                    text.append(resource.getString("attribute")).append(' ').append(attIndex).append('\n');
                    for (int j = 0; j < distribution.length; j++) {
                        text.append(j).append('\t');
                    }
                    text.append('\n');
                    for (int j = 0; j < distribution.length; j++) {
                        text.append(Utils.doubleToString(distribution[j], 10, 8)).append('\t');
                    }
                    text.append("\n\n");
                }
            }
            return text.toString();
        }
        catch (Exception e) {
            return resource.getString("exception5");
        }
    }

    /** Returns the network built by the most recent {@link #buildClassifier} call. */
    public ProbabilisticNetwork getProbabilisticNetwork() {
        return this.net;
    }

    /** Sets whether {@link #distributionForInstance} normalises scores to sum to 1. */
    public void setNormalize(boolean normalize) {
        this.normalize = normalize;
    }
}

