/*
 * Decompiled with CFR 0.152.
 */
package unbbayes.evaluation;

import java.io.File;
import java.util.ArrayList;
import java.util.Formatter;
import java.util.List;
import java.util.Locale;
import unbbayes.io.BaseIO;
import unbbayes.io.NetIO;
import unbbayes.io.XMLBIFIO;
import unbbayes.prs.Node;
import unbbayes.prs.bn.ProbabilisticNetwork;
import unbbayes.prs.bn.TreeVariable;
import unbbayes.simulation.montecarlo.sampling.MonteCarloSampling;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Estimates, from Monte Carlo samples of a probabilistic network, how well a set of
 * evidence nodes classifies a single target node.
 *
 * <p>The report produced by the {@code evaluate} methods contains three tables
 * (values printed as percentages) and one summary figure:
 * <ul>
 *   <li>{@code P(T|E)}: obtained by normalizing {@code P(E|T) * P(T)} over the target states;</li>
 *   <li>{@code P(E|T)}: relative frequencies counted in the sample matrix;</li>
 *   <li>{@code P(T|T)}: the product of the two tables above, summed over the evidence
 *       configurations;</li>
 *   <li>the mean of the diagonal of {@code P(T|T)}.</li>
 * </ul>
 *
 * <p>Joint configurations are addressed with a flat index
 * {@code targetState * evidenceStatesProduct + evidenceIndex}, where the first evidence
 * node is the most significant digit of {@code evidenceIndex}.
 *
 * <p>Instances keep per-run state in fields, so one instance must not be used for
 * overlapping evaluations.
 */
public class Evaluation {
    /** Network being evaluated; set by the public {@code evaluate} overloads. */
    private ProbabilisticNetwork net;
    /** Target nodes resolved from the names given by the caller (currently exactly one). */
    private TreeVariable[] targetNodeList;
    /** Evidence nodes resolved from the names given by the caller. */
    private TreeVariable[] evidenceNodeList;
    /** {@code targetStatesProduct * evidenceStatesProduct}. */
    private int statesProduct;
    /** Product of the state counts of all target nodes. */
    private int targetStatesProduct;
    /** Product of the state counts of all evidence nodes. */
    private int evidenceStatesProduct;
    /** Sampled states: one row per sample, columns indexed by position in the sampling order. */
    byte[][] sampleMatrix;
    /** Column of {@link #sampleMatrix} holding each target node. */
    int[] positionTargetNodeList;
    /** Column of {@link #sampleMatrix} holding each evidence node. */
    int[] positionEvidenceNodeList;
    /** Formatter writing into the report buffer of the current run. */
    private Formatter formatter;
    /** The single target node of the current run. */
    private TreeVariable targetNode;

    /**
     * Loads a network from a file and evaluates it.
     *
     * @param netFileName path of the network file; its last three characters must be
     *     {@code xml} or {@code net} (case-insensitive)
     * @param targetNodeNameList names of the target nodes (exactly one is accepted)
     * @param evidenceNodeNameList names of the evidence nodes
     * @param sampleSize sample size handed to the Monte Carlo sampler
     * @return the textual report
     * @throws Exception if the network cannot be loaded, a node name cannot be resolved,
     *     the number of target nodes is not one, or the network fails to compile
     */
    public String evaluate(String netFileName, List<String> targetNodeNameList, List<String> evidenceNodeNameList, int sampleSize) throws Exception {
        this.loadNetwork(netFileName);
        return this.evaluate(targetNodeNameList, evidenceNodeNameList, sampleSize);
    }

    /**
     * Evaluates an already loaded network.
     *
     * @param net the network to evaluate
     * @param targetNodeNameList names of the target nodes (exactly one is accepted)
     * @param evidenceNodeNameList names of the evidence nodes
     * @param sampleSize sample size handed to the Monte Carlo sampler
     * @return the textual report
     * @throws Exception if a node name cannot be resolved, the number of target nodes is
     *     not one, or the network fails to compile
     */
    public String evaluate(ProbabilisticNetwork net, List<String> targetNodeNameList, List<String> evidenceNodeNameList, int sampleSize) throws Exception {
        this.net = net;
        return this.evaluate(targetNodeNameList, evidenceNodeNameList, sampleSize);
    }

    /**
     * Runs the sampler on the current network and builds the report.
     *
     * @throws IllegalArgumentException if a node cannot be resolved or the number of
     *     target nodes is not exactly one
     * @throws Exception if compiling the network fails
     */
    private String evaluate(List<String> targetNodeNameList, List<String> evidenceNodeNameList, int sampleSize) throws Exception {
        StringBuilder sb = new StringBuilder();
        this.formatter = new Formatter(sb, Locale.US);
        this.init(targetNodeNameList, evidenceNodeNameList);

        // Validate before touching targetNodeList[0] and before paying for the sampling run.
        if (this.targetNodeList.length != 1) {
            throw new IllegalArgumentException("For now, just one target node is accepted!");
        }
        this.targetNode = this.targetNodeList[0];

        MonteCarloSampling mc = new MonteCarloSampling(this.net, sampleSize);
        mc.start();
        this.sampleMatrix = mc.getSampledStatesMatrix();
        List<Node> samplingOrder = mc.getSamplingNodeOrderQueue();
        this.positionTargetNodeList = this.locatePositions(this.targetNodeList, samplingOrder);
        this.positionEvidenceNodeList = this.locatePositions(this.evidenceNodeList, samplingOrder);

        float[] probEvidenceGivenTarget = this.estimateEvidenceGivenTarget();
        float[] probTargetGivenEvidence = this.computeTargetGivenEvidence(probEvidenceGivenTarget);
        float[] probTargetGivenTarget = this.computeTargetGivenTarget(probTargetGivenEvidence, probEvidenceGivenTarget);

        int statesSize = this.targetNode.getStatesSize();
        float averageClassification = 0.0f;
        for (int i = 0; i < statesSize; ++i) {
            averageClassification += probTargetGivenTarget[i * statesSize + i];
        }
        averageClassification /= (float)statesSize;

        this.appendReport(probTargetGivenEvidence, probEvidenceGivenTarget, probTargetGivenTarget, averageClassification);
        return sb.toString();
    }

    /** Returns, for every node, its index in the sampler's node order (the sample matrix column). */
    private int[] locatePositions(TreeVariable[] nodes, List<Node> samplingOrder) {
        int[] positions = new int[nodes.length];
        for (int i = 0; i < nodes.length; ++i) {
            positions[i] = samplingOrder.indexOf(this.net.getNode(nodes[i].getName()));
        }
        return positions;
    }

    /** Counts the samples and returns the relative frequencies P(E|T), flat-indexed by joint configuration. */
    private float[] estimateEvidenceGivenTarget() {
        int[] jointCount = new int[this.statesProduct];
        int[] targetCount = new int[this.targetStatesProduct];
        for (byte[] sample : this.sampleMatrix) {
            int row = 0;
            int weight = this.evidenceStatesProduct;
            for (int j = this.positionTargetNodeList.length - 1; j >= 0; --j) {
                row += sample[this.positionTargetNodeList[j]] * weight;
                // The state count comes from the node itself: the stored position is an index
                // into the sampling order, which need not match the network's own node indexing.
                weight *= this.targetNodeList[j].getStatesSize();
            }
            targetCount[row / this.evidenceStatesProduct]++;

            weight = this.evidenceStatesProduct;
            for (int j = 0; j < this.positionEvidenceNodeList.length; ++j) {
                weight /= this.evidenceNodeList[j].getStatesSize();
                row += sample[this.positionEvidenceNodeList[j]] * weight;
            }
            jointCount[row]++;
        }

        float[] probEvidenceGivenTarget = new float[this.statesProduct];
        for (int i = 0; i < probEvidenceGivenTarget.length; ++i) {
            float total = targetCount[i / this.evidenceStatesProduct];
            // A target state that was never sampled keeps probability 0 for all its rows.
            if (total != 0.0f) {
                probEvidenceGivenTarget[i] = (float)jointCount[i] / total;
            }
        }
        return probEvidenceGivenTarget;
    }

    /**
     * Computes P(T|E) as P(E|T) * P(T), normalized over the target states of each
     * evidence configuration. The network is compiled first and P(T) is read from the
     * target node's marginals.
     */
    private float[] computeTargetGivenEvidence(float[] probEvidenceGivenTarget) throws Exception {
        float[] probTargetGivenEvidence = new float[this.statesProduct];
        float[] normalization = new float[this.evidenceStatesProduct];
        this.net.compile();
        int targetStates = this.targetNode.getStatesSize();
        for (int targetState = 0; targetState < targetStates; ++targetState) {
            float prior = this.targetNode.getMarginalAt(targetState);
            for (int evidenceIndex = 0; evidenceIndex < this.evidenceStatesProduct; ++evidenceIndex) {
                int row = evidenceIndex + targetState * this.evidenceStatesProduct;
                float prob = probEvidenceGivenTarget[row] * prior;
                probTargetGivenEvidence[row] = prob;
                normalization[evidenceIndex] += prob;
            }
        }
        for (int i = 0; i < probTargetGivenEvidence.length; ++i) {
            float norm = normalization[i % this.evidenceStatesProduct];
            // Evidence configurations with zero total mass are left at 0.
            if (norm != 0.0f) {
                probTargetGivenEvidence[i] /= norm;
            }
        }
        return probTargetGivenEvidence;
    }

    /**
     * Builds the square P(T|T) table (row-major, side = number of target states) by
     * summing P(T|E) * P(E|T) over all evidence configurations.
     */
    private float[] computeTargetGivenTarget(float[] probTargetGivenEvidence, float[] probEvidenceGivenTarget) {
        int statesSize = this.targetNode.getStatesSize();
        float[] probTargetGivenTarget = new float[statesSize * statesSize];
        for (int i = 0; i < this.statesProduct; ++i) {
            int rowBase = i / this.evidenceStatesProduct * statesSize;
            int evidenceIndex = i % this.evidenceStatesProduct;
            for (int j = 0; j < statesSize; ++j) {
                int index = evidenceIndex + j * this.evidenceStatesProduct;
                probTargetGivenTarget[rowBase + j] += probTargetGivenEvidence[i] * probEvidenceGivenTarget[index];
            }
        }
        return probTargetGivenTarget;
    }

    /** Writes the three probability tables and the average figure to the current formatter. */
    private void appendReport(float[] probTargetGivenEvidence, float[] probEvidenceGivenTarget, float[] probTargetGivenTarget, float averageClassification) {
        int statesSize = this.targetNode.getStatesSize();

        this.formatter.format("P(T|E) = N[ P(E|T)P(T) ]\n");
        for (int i = 0; i < this.targetStatesProduct; ++i) {
            for (int j = 0; j < this.evidenceStatesProduct; ++j) {
                this.formatter.format("%2.2f\t", probTargetGivenEvidence[i * this.evidenceStatesProduct + j] * 100.0f);
            }
            this.formatter.format("\n");
        }
        this.formatter.format("\n");

        // P(E|T) is printed transposed: one line per evidence configuration.
        this.formatter.format("P(E|T)\n");
        for (int i = 0; i < this.evidenceStatesProduct; ++i) {
            for (int j = 0; j < this.targetStatesProduct; ++j) {
                this.formatter.format("%2.2f\t", probEvidenceGivenTarget[j * this.evidenceStatesProduct + i] * 100.0f);
            }
            this.formatter.format("\n");
        }
        this.formatter.format("\n");

        this.formatter.format("P(T|T) = P(T|E)P(E|T)\n");
        for (int i = 0; i < statesSize; ++i) {
            for (int j = 0; j < statesSize; ++j) {
                this.formatter.format("%2.2f\t", probTargetGivenTarget[i * statesSize + j] * 100.0f);
            }
            this.formatter.format("\n");
        }
        this.formatter.format("\n");

        this.formatter.format("Average Correct Classification Probability\n");
        this.formatter.format("%2.2f\t", averageClassification * 100.0f);
    }

    /**
     * Estimates P(T|E) directly from the sample frequencies and writes each value (as a
     * percentage, one per line) to the current formatter. Not called from within this class.
     *
     * @return the estimated probabilities, flat-indexed by joint configuration
     */
    private float[] computePostProbTargetGivenEvidenceUsingMC() {
        int[] jointCount = new int[this.statesProduct];
        int[] evidenceCount = new int[this.evidenceStatesProduct];
        for (byte[] sample : this.sampleMatrix) {
            int row = 0;
            int weight = 1;
            for (int j = this.positionEvidenceNodeList.length - 1; j >= 0; --j) {
                row += sample[this.positionEvidenceNodeList[j]] * weight;
                // State count taken from the node itself; the stored position indexes the sampling order.
                weight *= this.evidenceNodeList[j].getStatesSize();
            }
            evidenceCount[row]++;
            for (int j = 0; j < this.positionTargetNodeList.length; ++j) {
                row += sample[this.positionTargetNodeList[j]] * weight;
                weight *= this.targetNodeList[j].getStatesSize();
            }
            jointCount[row]++;
        }

        float[] postProbTargetGivenEvidence = new float[this.statesProduct];
        for (int i = 0; i < postProbTargetGivenEvidence.length; ++i) {
            float total = evidenceCount[i % this.evidenceStatesProduct];
            if (total != 0.0f) {
                postProbTargetGivenEvidence[i] = (float)jointCount[i] / total;
            }
            this.formatter.format("%2.2f\n", postProbTargetGivenEvidence[i] * 100.0f);
        }
        this.formatter.format("\n\n");
        return postProbTargetGivenEvidence;
    }

    /**
     * Enumerates every joint configuration, enters the evidence states as findings, reads
     * the target marginal after propagation and prints the resulting table to standard
     * output. Not called from within this class.
     *
     * @throws Exception if compiling the network fails
     */
    private void getExatProbTargetGivenEvidence() throws Exception {
        TreeVariable target = this.targetNodeList[0];
        this.net.compile();
        float[] postProbList = new float[this.statesProduct];
        // Column 0 holds the target state, columns 1..n the evidence states.
        byte[][] stateCombinationMatrix = new byte[this.statesProduct][this.net.getNodes().size()];
        for (int row = 0; row < this.statesProduct; ++row) {
            int sProd = target.getStatesSize();
            stateCombinationMatrix[row][0] = (byte)(row / (this.statesProduct / sProd));
            for (int j = 0; j < this.evidenceNodeList.length; ++j) {
                int evidenceStates = this.evidenceNodeList[j].getStatesSize();
                sProd *= evidenceStates;
                int state = row / (this.statesProduct / sProd) % evidenceStates;
                this.evidenceNodeList[j].addFinding(state);
                stateCombinationMatrix[row][j + 1] = (byte)state;
            }
            try {
                this.net.updateEvidences();
                postProbList[row] = target.getMarginalAt(stateCombinationMatrix[row][0]);
            }
            catch (Exception ignored) {
                // Deliberate: a configuration whose propagation fails is reported with probability 0.
                postProbList[row] = 0.0f;
            }
            // Recompiled before the next configuration is entered
            // (presumably to discard the findings just added; confirm against ProbabilisticNetwork).
            this.net.compile();
        }
        this.printProbMatrix(stateCombinationMatrix, postProbList);
    }

    /** Placeholder: not implemented. */
    private void getExatProbEvidenceGivenTarget() throws Exception {
    }

    /** Prints one line per row: the state combination followed by its probability. */
    private void printProbMatrix(byte[][] stateCombinationMatrix, float[] postProbList) {
        for (int i = 0; i < stateCombinationMatrix.length; ++i) {
            for (int j = 0; j < stateCombinationMatrix[0].length; ++j) {
                System.out.print(String.valueOf(stateCombinationMatrix[i][j]) + "    ");
            }
            System.out.println(postProbList[i]);
        }
    }

    /**
     * Resolves the named nodes in the current network and computes the state-count products.
     *
     * @throws IllegalArgumentException if a name cannot be resolved to a {@link TreeVariable}
     */
    private void init(List<String> targetNodeNameList, List<String> evidenceNodeNameList) {
        this.targetNodeList = new TreeVariable[targetNodeNameList.size()];
        this.evidenceNodeList = new TreeVariable[evidenceNodeNameList.size()];
        this.targetStatesProduct = 1;
        this.evidenceStatesProduct = 1;

        int count = 0;
        for (String targetNodeName : targetNodeNameList) {
            TreeVariable node = this.lookupTreeVariable(targetNodeName);
            this.targetNodeList[count++] = node;
            this.targetStatesProduct *= node.getStatesSize();
        }
        count = 0;
        for (String evidenceNodeName : evidenceNodeNameList) {
            TreeVariable node = this.lookupTreeVariable(evidenceNodeName);
            this.evidenceNodeList[count++] = node;
            this.evidenceStatesProduct *= node.getStatesSize();
        }
        this.statesProduct = this.targetStatesProduct * this.evidenceStatesProduct;
    }

    /** Looks a node up by name, failing with a descriptive message instead of a bare NPE or cast error. */
    private TreeVariable lookupTreeVariable(String nodeName) {
        Node node = this.net.getNode(nodeName);
        if (node == null) {
            throw new IllegalArgumentException("Node not found in the network: " + nodeName);
        }
        if (!(node instanceof TreeVariable)) {
            throw new IllegalArgumentException("Node is not a TreeVariable: " + nodeName);
        }
        return (TreeVariable)node;
    }

    /**
     * Loads the network from the given file, choosing the reader from the last three
     * characters of the file name.
     *
     * @throws IllegalArgumentException if the name is {@code null} or has an unsupported ending
     * @throws Exception if the selected reader fails to load the file
     */
    private void loadNetwork(String netFileName) throws Exception {
        if (netFileName == null) {
            throw new IllegalArgumentException("The network file name must not be null");
        }
        BaseIO io;
        if (Evaluation.endsWithIgnoreCase(netFileName, "xml")) {
            io = new XMLBIFIO();
        } else if (Evaluation.endsWithIgnoreCase(netFileName, "net")) {
            io = new NetIO();
        } else {
            throw new IllegalArgumentException("The network must be in XMLBIF 0.4 or NET format!");
        }
        // Failures propagate to the caller instead of terminating the JVM.
        this.net = io.load(new File(netFileName));
    }

    /** Case-insensitive suffix test; returns false (rather than throwing) when the name is shorter than the suffix. */
    private static boolean endsWithIgnoreCase(String name, String suffix) {
        return name.regionMatches(true, name.length() - suffix.length(), suffix, 0, suffix.length());
    }

    /**
     * Demo entry point: evaluates the WetGrass example network and prints the report.
     *
     * @param args ignored
     * @throws Exception if loading or evaluating the network fails
     */
    public static void main(String[] args) throws Exception {
        List<String> targetNodeNameList = new ArrayList<String>();
        targetNodeNameList.add("Rain");
        List<String> evidenceNodeNameList = new ArrayList<String>();
        evidenceNodeNameList.add("Springler");
        evidenceNodeNameList.add("Cloudy");
        evidenceNodeNameList.add("Wet");
        String netFileName = "../UnBBayes/examples/xml-bif/WetGrass.xml";
        int sampleSize = 100000;
        Evaluation evaluation = new Evaluation();
        System.out.println(evaluation.evaluate(netFileName, targetNodeNameList, evidenceNodeNameList, sampleSize));
    }
}

