/**
 * Folds a list of CPTs into one table by multiplying them together left to right.
 *
 * <p>When the list holds a single table, that same instance is returned (no copy is made).
 *
 * @param cpts the tables to multiply; must not be {@code null}
 * @return the product of all tables, or {@code null} if {@code cpts} is empty
 */
public static ConditionalProbabilityTable multiplyAll(List<ConditionalProbabilityTable> cpts) {
    if (cpts.isEmpty()) {
        return null;
    }
    ConditionalProbabilityTable accumulated = cpts.get(0);
    // Every table after the seed is multiplied into the running product in list order.
    for (ConditionalProbabilityTable factor : cpts.subList(1, cpts.size())) {
        accumulated = accumulated.multiply(factor);
    }
    return accumulated;
}
/**
 * Selects the tables in a factorization that mention a given variable.
 *
 * <p>The input list is only read. The result is a new mutable list that keeps the input's order.
 *
 * @param factorization the tables to search
 * @param variable the variable to look for, tested via {@code cpt.contains(variable)}
 * @return a new list of every table whose {@code contains(variable)} returned {@code true}
 */
public static List<ConditionalProbabilityTable> getAllCptsContaining(
        List<ConditionalProbabilityTable> factorization, Variable variable) {
    List<ConditionalProbabilityTable> matches = new ArrayList<>();
    for (ConditionalProbabilityTable candidate : factorization) {
        if (!candidate.contains(variable)) {
            continue;
        }
        matches.add(candidate);
    }
    return matches;
}
public static List<ConditionalProbabilityTable> sumOut( List<ConditionalProbabilityTable> factorization, List<Variable> varsToSumOut) { List<ConditionalProbabilityTable> newFactorization = new ArrayList<>(factorization); for (Variable varToSumOut : varsToSumOut) { // Multiply all together List<ConditionalProbabilityTable> relevantCpts = Inference.getAllCptsContaining(newFactorization, varToSumOut); ConditionalProbabilityTable product = Inference.multiplyAll(relevantCpts); // Marginalize out the variable ConditionalProbabilityTable marginal = product.marginalize(varToSumOut); newFactorization.removeAll(relevantCpts); newFactorization.add(marginal); } return newFactorization; }
@Test public void sumOut() { XmlBif xmlBif; try { xmlBif = new XmlBif( new FileInputStream( new File("/Users/jhonatanoliveira/Insync/uregina/research/dataset/asia.xml"))); BayesianNetwork bn = xmlBif.getBn(); List<Variable> varsToSumOut = bn.getAllVariables(); varsToSumOut.remove(bn.getVariable("dysp")); List<ConditionalProbabilityTable> newFactorization = Inference.sumOut(bn.getAllCpts(), varsToSumOut); ConditionalProbabilityTable finalCpt = Inference.multiplyAll(newFactorization); System.out.println(finalCpt); System.out.println(finalCpt.getValues()); } catch (ParserConfigurationException | SAXException | IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } }