public static List<ConditionalProbabilityTable> sumOut( List<ConditionalProbabilityTable> factorization, List<Variable> varsToSumOut) { List<ConditionalProbabilityTable> newFactorization = new ArrayList<>(factorization); for (Variable varToSumOut : varsToSumOut) { // Multiply all together List<ConditionalProbabilityTable> relevantCpts = Inference.getAllCptsContaining(newFactorization, varToSumOut); ConditionalProbabilityTable product = Inference.multiplyAll(relevantCpts); // Marginalize out the variable ConditionalProbabilityTable marginal = product.marginalize(varToSumOut); newFactorization.removeAll(relevantCpts); newFactorization.add(marginal); } return newFactorization; }
@Test public void sumOut() { XmlBif xmlBif; try { xmlBif = new XmlBif( new FileInputStream( new File("/Users/jhonatanoliveira/Insync/uregina/research/dataset/asia.xml"))); BayesianNetwork bn = xmlBif.getBn(); List<Variable> varsToSumOut = bn.getAllVariables(); varsToSumOut.remove(bn.getVariable("dysp")); List<ConditionalProbabilityTable> newFactorization = Inference.sumOut(bn.getAllCpts(), varsToSumOut); ConditionalProbabilityTable finalCpt = Inference.multiplyAll(newFactorization); System.out.println(finalCpt); System.out.println(finalCpt.getValues()); } catch (ParserConfigurationException | SAXException | IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } }