/*
 * Decompiled with CFR 0.152.
 */
package unbbayes.datamining.datamanipulation;

import java.util.ArrayList;
import unbbayes.datamining.classifiers.decisiontree.Leaf;
import unbbayes.datamining.classifiers.decisiontree.Node;
import unbbayes.datamining.classifiers.decisiontree.NominalNode;
import unbbayes.datamining.classifiers.decisiontree.NumericNode;
import unbbayes.datamining.datamanipulation.Attribute;
import unbbayes.datamining.datamanipulation.ClassifierUtils;
import unbbayes.datamining.datamanipulation.Options;
import unbbayes.datamining.datamanipulation.Utils;

/**
 * Error-based (pessimistic) post-pruning of decision trees in the style of C4.5.
 *
 * <p>{@link #pruneTree(Node, Attribute)} prunes a copy of the given tree built by
 * {@code cloneTree}. Only nodes of that copy are modified. A subtree is replaced by a single
 * {@link Leaf} whenever the pessimistic error estimate of that leaf is lower than the summed
 * estimates of the subtree's branches. The confidence factor comes from
 * {@code Options.getInstance().getConfidenceLevel()}.
 *
 * <p>Not thread-safe: {@code pruneTree} stores the class attribute in an instance field that is
 * read during the recursion.
 */
public class PrunningUtils {
    /** exp(-2): boundary between the central and tail approximations in normalInverse. */
    private static final double EXP_MINUS_2 = 0.1353352832366127;
    /** 1 - exp(-2): probabilities above this are reflected (y -> 1 - y) before evaluation. */
    private static final double ONE_MINUS_EXP_MINUS_2 = 0.8646647167633873;
    /** sqrt(2 * pi), scale factor of the central approximation. */
    private static final double S2PI = Math.sqrt(Math.PI * 2);

    // Rational-approximation coefficients used by normalInverse (Cephes ndtri):
    // P0/Q0 for the central region, P1/Q1 for the tail while sqrt(-2 ln y) < 8, P2/Q2 beyond.
    private static final double[] P0 = new double[]{-59.96335010141079, 98.00107541859997, -56.67628574690703, 13.931260938727968, -1.2391658386738125};
    private static final double[] Q0 = new double[]{1.9544885833814176, 4.676279128988815, 86.36024213908905, -225.46268785411937, 200.26021238006066, -82.03722561683334, 15.90562251262117, -1.1833162112133};
    private static final double[] P1 = new double[]{4.0554489230596245, 31.525109459989388, 57.16281922464213, 44.08050738932008, 14.684956192885803, 2.1866330685079025, -0.1402560791713545, -0.03504246268278482, -8.574567851546854E-4};
    private static final double[] Q1 = new double[]{15.779988325646675, 45.39076351288792, 41.3172038254672, 15.04253856929075, 2.504649462083094, -0.14218292285478779, -0.03808064076915783, -9.332594808954574E-4};
    private static final double[] P2 = new double[]{3.2377489177694603, 6.915228890689842, 3.9388102529247444, 1.3330346081580755, 0.20148538954917908, 0.012371663481782003, 3.0158155350823543E-4, 2.6580697468673755E-6, 6.239745391849833E-9};
    private static final double[] Q2 = new double[]{6.02427039364742, 3.6798356385616087, 1.3770209948908132, 0.21623699359449663, 0.013420400608854318, 3.2801446468212774E-4, 2.8924786474538068E-6, 6.790194080099813E-9};

    /** Class attribute of the tree currently being pruned; set by pruneTree, read when building leaves. */
    private Attribute classAttribute;

    /**
     * Returns a pruned copy of {@code root}.
     *
     * @param root           root of the tree to prune; its nodes are not modified by the pruning
     * @param classAttribute class attribute of the tree; its number of distinct nominal values
     *                       fixes the length of the class-count arrays
     * @return the root of the pruned copy
     */
    public Node pruneTree(Node root, Attribute classAttribute) {
        this.classAttribute = classAttribute;
        int numClasses = classAttribute.getDistinticNominalValues().length;
        // Read the confidence factor once so the whole tree is pruned with a single value.
        float confidence = Options.getInstance().getConfidenceLevel();
        Node rootClone = this.cloneTree(root);
        this.pruneSubtree(rootClone, new float[numClasses], confidence);
        return rootClone;
    }

    /**
     * Recursively prunes the subtree below {@code node} and returns its pessimistic error count.
     *
     * @param node         node whose children are examined; they may be replaced by one leaf
     * @param distribution output: receives the class counts of the instances below {@code node};
     *                     must have one slot per class
     * @param confidence   confidence factor (CF) passed to addErrs
     * @return estimated number of errors of the (possibly pruned) subtree
     */
    private double pruneSubtree(Node node, float[] distribution, float confidence) {
        ArrayList<?> children = node.getChildren();
        if (!children.isEmpty() && children.get(0) instanceof Leaf) {
            Leaf leaf = (Leaf) children.get(0);
            float[] leafDistribution = leaf.getDistribution();
            if (leafDistribution == null) {
                // A leaf without counts covers no instances and so makes no errors;
                // `distribution` keeps the caller's zeros.
                return 0.0;
            }
            System.arraycopy(leafDistribution, 0, distribution, 0, distribution.length);
            return this.estimatedErrors(leafDistribution, leaf.getClassValue(), confidence);
        }

        int numClasses = distribution.length;
        // Estimates are absolute error counts, so a subtree's estimate is the plain sum over its
        // branches; this is what the collapsed-leaf estimate below is compared against.
        double subtreeErrors = 0.0;
        float[] distributionSum = new float[numClasses];
        for (Object child : children) {
            float[] childDistribution = new float[numClasses];
            subtreeErrors += this.pruneSubtree((Node) child, childDistribution, confidence);
            distributionSum = Utils.arraysSum(distributionSum, childDistribution);
        }

        int predictedClass = Utils.maxIndex(distributionSum);
        double leafErrors = this.estimatedErrors(distributionSum, predictedClass, confidence);
        System.arraycopy(distributionSum, 0, distribution, 0, distribution.length);
        if (leafErrors < subtreeErrors) {
            // A single leaf is expected to err less than the branches: collapse the subtree.
            node.removeChildren();
            node.add(new Leaf(this.classAttribute, distributionSum, -1.0f, 0));
            return leafErrors;
        }
        return subtreeErrors;
    }

    /**
     * Pessimistic error count of a leaf that predicts {@code classValue} for the class counts in
     * {@code dist}. It equals the observed errors plus the correction computed by addErrs.
     * Returns 0 when the total count (observed errors plus {@code dist[classValue]}) is zero.
     */
    private double estimatedErrors(float[] dist, int classValue, float confidence) {
        double errors = ClassifierUtils.sumNonClassDistribution(dist, classValue);
        double total = errors + (double) dist[classValue];
        if (total <= 0.0) {
            return 0.0;
        }
        return errors + addErrs(total, errors, confidence);
    }

    /**
     * Builds a structural copy of the tree rooted at {@code parent}, with fresh leaves whose
     * class-count arrays are copied rather than shared.
     */
    private Node cloneTree(Node parent) {
        ArrayList<?> children = parent.getChildren();
        ArrayList<Object> newChildren = new ArrayList<Object>();
        if (!children.isEmpty() && children.get(0) instanceof Leaf) {
            Leaf leaf = (Leaf) children.get(0);
            float[] leafDistribution = leaf.getDistribution();
            float[] distributionCopy = (leafDistribution == null) ? null : leafDistribution.clone();
            // NOTE(review): the last two Leaf arguments are fixed to -1.0f and 0 instead of being
            // copied from the source leaf; confirm that is intended.
            newChildren.add(new Leaf(this.classAttribute, distributionCopy, -1.0f, 0));
        } else {
            for (Object child : children) {
                newChildren.add(this.cloneTree((Node) child));
            }
        }
        if (parent instanceof NumericNode) {
            return new NumericNode((NumericNode) parent, newChildren);
        }
        if (parent instanceof NominalNode) {
            return new NominalNode((NominalNode) parent, newChildren);
        }
        return new Node(parent.getAttribute(), newChildren);
    }

    /**
     * Returns the number of errors to add to {@code e} observed errors out of {@code N}
     * instances so that {@code (e + result) / N} is the upper confidence limit of the error
     * rate for confidence factor {@code CF}.
     */
    private static double addErrs(double N, double e, float CF) {
        if (e < 1.0) {
            // Zero observed errors: the limit p solves (1 - p)^N = CF.
            double base = N * (1.0 - Math.pow(CF, 1.0 / N));
            if (e == 0.0) {
                return base;
            }
            // Fractional error counts: interpolate linearly between the e = 0 and e = 1 results.
            return base + e * (addErrs(N, 1.0, CF) - base);
        }
        if (e + 0.5 >= N) {
            // The limit saturates at an error rate of 1; never return a negative count.
            return Math.max(N - e, 0.0);
        }
        // Normal-approximation upper limit with a 0.5 continuity correction on e.
        double z = normalInverse(1.0f - CF);
        double f = (e + 0.5) / N;
        double zz = z * z;
        double r = (f + zz / (2.0 * N) + z * Math.sqrt(f / N - f * f / N + zz / (4.0 * N * N)))
                / (1.0 + zz / N);
        return r * N - e;
    }

    /**
     * Inverse of the standard normal cumulative distribution function.
     *
     * @param y0 probability, strictly between 0 and 1
     * @return x such that the standard normal CDF at x equals {@code y0}
     * @throws IllegalArgumentException if {@code y0} is not in (0, 1), including NaN
     */
    private static double normalInverse(double y0) {
        if (!(y0 > 0.0 && y0 < 1.0)) {
            throw new IllegalArgumentException("probability must be in (0, 1): " + y0);
        }
        boolean reflected = y0 > ONE_MINUS_EXP_MINUS_2;
        double y = reflected ? 1.0 - y0 : y0;

        if (y > EXP_MINUS_2) {
            // Central region.
            double centered = y - 0.5;
            double c2 = centered * centered;
            double x = centered + centered * (c2 * polevl(c2, P0, 4) / p1evl(c2, Q0, 8));
            return x * S2PI;
        }

        // Tail region.
        double x = Math.sqrt(-2.0 * Math.log(y));
        double x0 = x - Math.log(x) / x;
        double z = 1.0 / x;
        double x1 = (x < 8.0)
                ? z * polevl(z, P1, 8) / p1evl(z, Q1, 8)
                : z * polevl(z, P2, 8) / p1evl(z, Q2, 8);
        double result = x0 - x1;
        // Unreflected inputs lie in the lower tail, where the quantile is negative.
        return reflected ? result : -result;
    }

    /**
     * Evaluates x^N + coef[0]*x^(N-1) + ... + coef[N-1] (implicit leading coefficient 1)
     * by Horner's rule.
     */
    private static double p1evl(double x, double[] coef, int N) {
        double ans = x + coef[0];
        for (int i = 1; i < N; i++) {
            ans = ans * x + coef[i];
        }
        return ans;
    }

    /** Evaluates coef[0]*x^N + coef[1]*x^(N-1) + ... + coef[N] by Horner's rule. */
    private static double polevl(double x, double[] coef, int N) {
        double ans = coef[0];
        for (int i = 1; i <= N; i++) {
            ans = ans * x + coef[i];
        }
        return ans;
    }
}

