/*
 * Decompiled with CFR 0.152.
 */
package unbbayes.simulation.likelihoodweighting.inference;

import java.io.File;
import java.text.NumberFormat;
import java.util.Locale;
import unbbayes.io.BaseIO;
import unbbayes.io.NetIO;
import unbbayes.io.XMLBIFIO;
import unbbayes.prs.Node;
import unbbayes.prs.bn.ProbabilisticNetwork;
import unbbayes.prs.bn.TreeVariable;
import unbbayes.simulation.likelihoodweighting.sampling.LikelihoodWeightingSampling;

/**
 * Estimates node marginals of a {@link ProbabilisticNetwork} from the weighted
 * samples produced by {@link LikelihoodWeightingSampling}.
 *
 * <p>Typical use: construct with a network and a number of trials, call
 * {@link #run()}, then read the marginals from the network's nodes. {@link #run()}
 * may be called again after evidence has been added to the network.
 */
public class LikelihoodWeightingInference {
    /** Network file used by {@link #main(String[])} when no file argument is given. */
    private static final String DEFAULT_NETWORK_FILE = "../UnBBayes/examples/asia.net";

    /** Number of trials used by {@link #main(String[])} when no size argument is given. */
    private static final int DEFAULT_SAMPLE_SIZE = 100000;

    /** Sampler that generates the state matrix and the per-trial weights. */
    protected LikelihoodWeightingSampling lwSampling;

    /** Network whose node marginals are updated by {@link #run()}. */
    protected ProbabilisticNetwork pn;

    /** Number of trials requested from the sampler. */
    protected int nTrials;

    /**
     * Creates an inference object and its sampler for the given network.
     *
     * @param pn      network to sample from and to update
     * @param nTrials number of trials passed on to the sampler
     */
    public LikelihoodWeightingInference(ProbabilisticNetwork pn, int nTrials) {
        this.pn = pn;
        this.nTrials = nTrials;
        this.lwSampling = new LikelihoodWeightingSampling(pn, nTrials);
    }

    /**
     * Runs the sampler and then recomputes the marginal of every node that has
     * no evidence. Nodes with evidence are left untouched.
     */
    public void run() {
        this.lwSampling.start();
        for (int i = 0; i < this.pn.getNodeCount(); i++) {
            Node node = this.pn.getNodeAt(i);
            // Every node is treated as a TreeVariable, exactly as the rest of this class does.
            if (!((TreeVariable) node).hasEvidence()) {
                this.updateMarginal(node);
            }
        }
    }

    /**
     * Recomputes the marginal of one node from the current samples.
     *
     * <p>For every trial, the weight of that trial is added to the bucket of the
     * state sampled for {@code node}; the buckets are then normalized and handed
     * to the node through {@code initMarginalList()} followed by
     * {@code addLikeliHood(...)}.
     *
     * @param node node to update; must be a {@link TreeVariable}
     * @throws IllegalStateException if {@code node} is not in the sampler's node order queue
     */
    protected void updateMarginal(Node node) {
        float[] marginal = new float[node.getStatesSize()];
        byte[][] sampledMatrix = this.lwSampling.getSampledStatesMatrix();
        // One weight per trial (row of sampledMatrix); presumably P(evidence | parents) -- confirm in sampler.
        float[] probEvdGivenPar = this.lwSampling.getProbabilityEvidenceGivenParentList();
        // Column of sampledMatrix that belongs to this node.
        int nodeIndex = this.lwSampling.getSamplingNodeOrderQueue().indexOf(node);
        if (nodeIndex < 0) {
            // Fail with a meaningful message instead of indexing the matrix with -1.
            throw new IllegalStateException(
                    "Node \"" + node.getDescription() + "\" is not in the sampling node order queue");
        }
        for (int trial = 0; trial < sampledMatrix.length; trial++) {
            byte state = sampledMatrix[trial][nodeIndex];
            marginal[state] += probEvdGivenPar[trial];
        }
        this.normalize(marginal);
        ((TreeVariable) node).initMarginalList();
        ((TreeVariable) node).addLikeliHood(marginal);
    }

    /**
     * Scales the entries of {@code floatList} in place so that they sum to one.
     *
     * <p>If the entries sum to zero (for example an empty array, or all weights
     * zero) the array is left unchanged, because dividing would only fill it
     * with NaN values.
     *
     * @param floatList values to normalize; modified in place
     */
    protected void normalize(float[] floatList) {
        double total = 0.0;
        for (float value : floatList) {
            total += (double) value;
        }
        if (total == 0.0) {
            return;
        }
        for (int i = 0; i < floatList.length; i++) {
            floatList[i] = (float) ((double) floatList[i] / total);
        }
    }

    /**
     * Loads a network for the command-line demo. The format is chosen from the
     * last three characters of the file name ({@code xml} or {@code net},
     * case-insensitive).
     *
     * <p>On any failure the stack trace is printed and the JVM exits with
     * status 1, so this helper is only suitable for {@link #main(String[])}.
     *
     * @param netFileName path of the network file
     * @return the loaded network
     */
    private static ProbabilisticNetwork loadNetwork(String netFileName) {
        File netFile = new File(netFileName);
        // Names shorter than three characters have no usable extension; treat them
        // as an unsupported format instead of failing inside substring().
        int extStart = netFileName.length() - 3;
        String fileExt = extStart >= 0 ? netFileName.substring(extStart) : "";
        ProbabilisticNetwork pn = null;
        try {
            BaseIO io;
            if (fileExt.equalsIgnoreCase("xml")) {
                io = new XMLBIFIO();
            } else if (fileExt.equalsIgnoreCase("net")) {
                io = new NetIO();
            } else {
                throw new IllegalArgumentException("The network must be in XMLBIF 0.4 or NET format!");
            }
            pn = io.load(netFile);
        }
        catch (Exception e) {
            e.printStackTrace();
            System.exit(1);
        }
        return pn;
    }

    /**
     * Prints, for every node of the network, its description followed by one
     * line per state with the marginal expressed as a percentage.
     *
     * @param pn network whose nodes are printed
     * @param nf formatter applied to each percentage
     */
    private static void printMarginals(ProbabilisticNetwork pn, NumberFormat nf) {
        for (Node node : pn.getNodes()) {
            System.out.println(node.getDescription());
            for (int i = 0; i < node.getStatesSize(); i++) {
                System.out.println("\t" + node.getStateAt(i) + ": " + nf.format(((TreeVariable)node).getMarginalAt(i) * 100.0f));
            }
            System.out.println();
        }
    }

    /**
     * Command-line demo: loads a network, prints the marginals without
     * evidence, then adds a finding on state 0 of the first node and prints the
     * marginals again.
     *
     * @param args optional; {@code args[0]} is the network file (default
     *             {@code ../UnBBayes/examples/asia.net}) and {@code args[1]} is
     *             the number of trials (default 100000)
     * @throws Exception if the run fails; a non-numeric {@code args[1]} raises
     *                   {@link NumberFormatException}
     */
    public static void main(String[] args) throws Exception {
        NumberFormat nf = NumberFormat.getInstance(Locale.US);
        nf.setMaximumFractionDigits(2);
        String netFileName = args != null && args.length > 0 ? args[0] : DEFAULT_NETWORK_FILE;
        int sampleSize = args != null && args.length > 1 ? Integer.parseInt(args[1]) : DEFAULT_SAMPLE_SIZE;
        ProbabilisticNetwork pn = LikelihoodWeightingInference.loadNetwork(netFileName);
        LikelihoodWeightingInference lw = new LikelihoodWeightingInference(pn, sampleSize);
        lw.run();
        LikelihoodWeightingInference.printMarginals(pn, nf);
        ((TreeVariable)pn.getNodeAt(0)).addFinding(0);
        lw.run();
        LikelihoodWeightingInference.printMarginals(pn, nf);
    }
}

