/** * Constructs a new on-the-fly BayesIM that will calculate conditional probabilities on the fly * from the given discrete data set, for the given Bayes PM. * * @param bayesPm the given Bayes PM, which specifies a directed acyclic graph for a Bayes net and * parametrization for the Bayes net, but not actual values for the parameters. * @param dataSet the discrete data set from which conditional probabilities should be estimated * on the fly. */ public OnTheFlyMarginalCalculator(BayesPm bayesPm, DataSet dataSet) throws IllegalArgumentException { if (bayesPm == null) { throw new NullPointerException(); } if (dataSet == null) { throw new NullPointerException(); } // Make sure all of the variables in the PM are in the data set; // otherwise, estimation is impossible. BayesUtils.ensureVarsInData(bayesPm.getVariables(), dataSet); // DataUtils.ensureVariablesExist(bayesPm, dataSet); this.bayesPm = new BayesPm(bayesPm); // Get the nodes from the BayesPm. This fixes the order of the nodes // in the BayesIm, independently of any change to the BayesPm. // (This order must be maintained.) Graph graph = bayesPm.getDag(); this.nodes = graph.getNodes().toArray(new Node[0]); // Initialize. initialize(); // Create a subset of the data set with the variables of the IM, in // the order of the IM. List<Node> variables = getVariables(); this.dataSet = dataSet.subsetColumns(variables); // Create a tautologous proposition for evidence. this.evidence = new Evidence(Proposition.tautology(this)); }
/** * Estimates a Bayes IM using the variables, graph, and parameters in the given Bayes PM and the * data columns in the given data set. Each variable in the given Bayes PM must be equal to a * variable in the given data set. The Bayes IM so estimated is used as the initial Bayes net in * the iterative procedure implemented in the maximize method. */ private void estimateIM(BayesPm bayesPm, DataSet dataSet) { if (bayesPm == null) { throw new NullPointerException(); } if (dataSet == null) { throw new NullPointerException(); } // Make sure all of the variables in the PM are in the data set; // otherwise, estimation is impossible. // List pmvars = bayesPm.getVariables(); // List dsvars = dataSet.getVariables(); // List obsVars = observedIm.getBayesPm().getVariables(); // System.out.println("Bayes PM as received by estimateMixedIM: "); // System.out.println(bayesPm); // Graph g = bayesPm.getDag(); // System.out.println(g); // DEBUG Prints: // System.out.println("PM VARS " + pmvars); // System.out.println("DS VARS " + dsvars); // System.out.println("OBS IM Vars" + obsVars); BayesUtils.ensureVarsInData(bayesPm.getVariables(), dataSet); // DataUtils.ensureVariablesExist(bayesPm, dataSet); // Create a new Bayes IM to store the estimated values. this.estimatedIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); int numNodes = estimatedIm.getNumNodes(); for (int node = 0; node < numNodes; node++) { int numRows = estimatedIm.getNumRows(node); int numCols = estimatedIm.getNumColumns(node); int[] parentVarIndices = estimatedIm.getParents(node); if (nodes[node].getNodeType() == NodeType.LATENT) { continue; } // int nodeObsIndex = estimatedIm.getCorrespondingNodeIndex(node, observedIm); // System.out.println("nodes[node] name = " + nodes[node].getName()); Node nodeObs = observedIm.getNode(nodes[node].getName()); // System.out.println("nodeObs name = " + nodeObs.getName()); int nodeObsIndex = observedIm.getNodeIndex(nodeObs); // int[] parentsObs = observedIm.getParents(nodeObsIndex); // System.out.println("For node " + nodes[node] + " parents are: "); boolean anyParentLatent = false; for (int parentVarIndice : parentVarIndices) { // System.out.println(nodes[parentVarIndices[p]]); if (nodes[parentVarIndice].getNodeType() == NodeType.LATENT) { anyParentLatent = true; break; } } if (anyParentLatent) { continue; } // At this point node is measured in bayesPm and so are its parents. for (int row = 0; row < numRows; row++) { // int[] parentValues = estimatedIm.getParentValues(node, row); // estimatedIm.randomizeRow(node, row); // if the node and all its parents are measured get the probs // from observedIm // loop: for (int col = 0; col < numCols; col++) { double p = observedIm.getProbability(nodeObsIndex, row, col); estimatedIm.setProbability(node, row, col, p); } } } }