/** * Estimates a Bayes IM using the variables, graph, and parameters in the given Bayes PM and the * data columns in the given data set. Each variable in the given Bayes PM must be equal to a * variable in the given data set. The Bayes IM so estimated is used as the initial Bayes net in * the iterative procedure implemented in the maximize method. */ private void estimateIM(BayesPm bayesPm, DataSet dataSet) { if (bayesPm == null) { throw new NullPointerException(); } if (dataSet == null) { throw new NullPointerException(); } // Make sure all of the variables in the PM are in the data set; // otherwise, estimation is impossible. // List pmvars = bayesPm.getVariables(); // List dsvars = dataSet.getVariables(); // List obsVars = observedIm.getBayesPm().getVariables(); // System.out.println("Bayes PM as received by estimateMixedIM: "); // System.out.println(bayesPm); // Graph g = bayesPm.getDag(); // System.out.println(g); // DEBUG Prints: // System.out.println("PM VARS " + pmvars); // System.out.println("DS VARS " + dsvars); // System.out.println("OBS IM Vars" + obsVars); BayesUtils.ensureVarsInData(bayesPm.getVariables(), dataSet); // DataUtils.ensureVariablesExist(bayesPm, dataSet); // Create a new Bayes IM to store the estimated values. 
this.estimatedIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); int numNodes = estimatedIm.getNumNodes(); for (int node = 0; node < numNodes; node++) { int numRows = estimatedIm.getNumRows(node); int numCols = estimatedIm.getNumColumns(node); int[] parentVarIndices = estimatedIm.getParents(node); if (nodes[node].getNodeType() == NodeType.LATENT) { continue; } // int nodeObsIndex = estimatedIm.getCorrespondingNodeIndex(node, observedIm); // System.out.println("nodes[node] name = " + nodes[node].getName()); Node nodeObs = observedIm.getNode(nodes[node].getName()); // System.out.println("nodeObs name = " + nodeObs.getName()); int nodeObsIndex = observedIm.getNodeIndex(nodeObs); // int[] parentsObs = observedIm.getParents(nodeObsIndex); // System.out.println("For node " + nodes[node] + " parents are: "); boolean anyParentLatent = false; for (int parentVarIndice : parentVarIndices) { // System.out.println(nodes[parentVarIndices[p]]); if (nodes[parentVarIndice].getNodeType() == NodeType.LATENT) { anyParentLatent = true; break; } } if (anyParentLatent) { continue; } // At this point node is measured in bayesPm and so are its parents. for (int row = 0; row < numRows; row++) { // int[] parentValues = estimatedIm.getParentValues(node, row); // estimatedIm.randomizeRow(node, row); // if the node and all its parents are measured get the probs // from observedIm // loop: for (int col = 0; col < numCols; col++) { double p = observedIm.getProbability(nodeObsIndex, row, col); estimatedIm.setProbability(node, row, col, p); } } } }
private void initialize() { DirichletBayesIm prior = DirichletBayesIm.symmetricDirichletIm(bayesPmObs, 0.5); observedIm = DirichletEstimator.estimate(prior, dataSet); // MLBayesEstimator dirichEst = new MLBayesEstimator(); // observedIm = dirichEst.estimate(bayesPmObs, dataSet); // System.out.println("Estimated Bayes IM for Measured Variables: "); // System.out.println(observedIm); // mixedData should be ddsNm with new columns for the latent variables. // Each such column should contain missing data for each case. int numFullCases = dataSet.getNumRows(); List<Node> variables = new LinkedList<Node>(); for (Node node : nodes) { if (node.getNodeType() == NodeType.LATENT) { int numCategories = bayesPm.getNumCategories(node); DiscreteVariable latentVar = new DiscreteVariable(node.getName(), numCategories); variables.add(latentVar); } else { String name = bayesPm.getVariable(node).getName(); Node variable = dataSet.getVariable(name); variables.add(variable); } } DataSet dsMixed = new ColtDataSet(numFullCases, variables); for (int j = 0; j < nodes.length; j++) { if (nodes[j].getNodeType() == NodeType.LATENT) { for (int i = 0; i < numFullCases; i++) { dsMixed.setInt(i, j, -99); } } else { String name = bayesPm.getVariable(nodes[j]).getName(); Node variable = dataSet.getVariable(name); int index = dataSet.getColumn(variable); for (int i = 0; i < numFullCases; i++) { dsMixed.setInt(i, j, dataSet.getInt(i, index)); } } } // System.out.println(dsMixed); mixedData = dsMixed; allVariables = mixedData.getVariables(); // Find the bayes net which is parameterized using mixedData or set randomly when that's // not possible. estimateIM(bayesPm, mixedData); // The following DEBUG section tests a case specified by P. 
Spirtes // DEBUG TAIL: For use with embayes_l1x1x2x3V3.dat /* Node l1Node = graph.getNode("L1"); //int l1Index = bayesImMixed.getNodeIndex(l1Node); int l1index = estimatedIm.getNodeIndex(l1Node); Node x1Node = graph.getNode("X1"); //int x1Index = bayesImMixed.getNodeIndex(x1Node); int x1Index = estimatedIm.getNodeIndex(x1Node); Node x2Node = graph.getNode("X2"); //int x2Index = bayesImMixed.getNodeIndex(x2Node); int x2Index = estimatedIm.getNodeIndex(x2Node); Node x3Node = graph.getNode("X3"); //int x3Index = bayesImMixed.getNodeIndex(x3Node); int x3Index = estimatedIm.getNodeIndex(x3Node); estimatedIm.setProbability(l1index, 0, 0, 0.5); estimatedIm.setProbability(l1index, 0, 1, 0.5); //bayesImMixed.setProbability(x1Index, 0, 0, 0.33333); //bayesImMixed.setProbability(x1Index, 0, 1, 0.66667); estimatedIm.setProbability(x1Index, 0, 0, 0.6); //p(x1 = 0 | l1 = 0) estimatedIm.setProbability(x1Index, 0, 1, 0.4); //p(x1 = 1 | l1 = 0) estimatedIm.setProbability(x1Index, 1, 0, 0.4); //p(x1 = 0 | l1 = 1) estimatedIm.setProbability(x1Index, 1, 1, 0.6); //p(x1 = 1 | l1 = 1) //bayesImMixed.setProbability(x2Index, 1, 0, 0.66667); //bayesImMixed.setProbability(x2Index, 1, 1, 0.33333); estimatedIm.setProbability(x2Index, 1, 0, 0.4); //p(x2 = 0 | l1 = 1) estimatedIm.setProbability(x2Index, 1, 1, 0.6); //p(x2 = 1 | l1 = 1) estimatedIm.setProbability(x2Index, 0, 0, 0.6); //p(x2 = 0 | l1 = 0) estimatedIm.setProbability(x2Index, 0, 1, 0.4); //p(x2 = 1 | l1 = 0) //bayesImMixed.setProbability(x3Index, 1, 0, 0.66667); //bayesImMixed.setProbability(x3Index, 1, 1, 0.33333); estimatedIm.setProbability(x3Index, 1, 0, 0.4); //p(x3 = 0 | l1 = 1) estimatedIm.setProbability(x3Index, 1, 1, 0.6); //p(x3 = 1 | l1 = 1) estimatedIm.setProbability(x3Index, 0, 0, 0.6); //p(x3 = 0 | l1 = 0) estimatedIm.setProbability(x3Index, 0, 1, 0.4); //p(x3 = 1 | l1 = 0) */ // END of TAIL // System.out.println("bayes IM estimated by estimateIM"); // System.out.println(bayesImMixed); // 
System.out.println(estimatedIm); estimatedCounts = new double[nodes.length][][]; estimatedCountsDenom = new double[nodes.length][]; condProbs = new double[nodes.length][][]; for (int i = 0; i < nodes.length; i++) { // int numRows = bayesImMixed.getNumRows(i); int numRows = estimatedIm.getNumRows(i); estimatedCounts[i] = new double[numRows][]; estimatedCountsDenom[i] = new double[numRows]; condProbs[i] = new double[numRows][]; // for(int j = 0; j < bayesImMixed.getNumRows(i); j++) { for (int j = 0; j < estimatedIm.getNumRows(i); j++) { // int numCols = bayesImMixed.getNumColumns(i); int numCols = estimatedIm.getNumColumns(i); estimatedCounts[i][j] = new double[numCols]; condProbs[i][j] = new double[numCols]; } } }
/** * This method takes an instantiated Bayes net (BayesIm) whose graph include all the variables * (observed and latent) and computes estimated counts using the data in the DataSet mixedData. * The counts that are estimated correspond to cells in the conditional probability tables of the * Bayes net. The outermost loop (indexed by j) is over the set of variables. If the variable has * no parents, each case in the dataset is examined and the count for the observed value of the * variables is increased by 1.0; if the value of the variable is missing the marginal * probabilities its values given the values of the variables that are available for that case are * used to increment the corresponding estimated counts. If a variable has parents then there is a * loop which steps through all possible sets of values of its parents. This loop is indexed by * the variable "row". Each case in the dataset is examined. It the variable and all its parents * have values in the case the corresponding estimated counts are incremented by 1.0. If the * variable or any of its parents have missing values, the joint marginal is computed for the * variable and the set of values of its parents corresponding to "row" and the corresponding * estimated counts are incremented by the appropriate probability. The estimated counts are * stored in the double[][][] array estimatedCounts. The count (possibly fractional) of the number * of times each combination of parent values occurs is stored in the double[][] array * estimatedCountsDenom. These two arrays are used to compute the estimated conditional * probabilities of the output Bayes net. 
*/ private BayesIm expectation(BayesIm inputBayesIm) { // System.out.println("Entered method expectation."); int numCases = mixedData.getNumRows(); // StoredCellEstCounts estCounts = new StoredCellEstCounts(variables); int numVariables = allVariables.size(); RowSummingExactUpdater rseu = new RowSummingExactUpdater(inputBayesIm); for (int j = 0; j < numVariables; j++) { DiscreteVariable var = (DiscreteVariable) allVariables.get(j); String varName = var.getName(); Node varNode = graph.getNode(varName); int varIndex = inputBayesIm.getNodeIndex(varNode); int[] parentVarIndices = inputBayesIm.getParents(varIndex); // System.out.println("graph = " + graph); // for(int col = 0; col < var.getNumSplits(); col++) // System.out.println("Category " + col + " = " + var.getCategory(col)); // System.out.println("Updating estimated counts for node " + varName); // This segment is for variables with no parents: if (parentVarIndices.length == 0) { // System.out.println("No parents"); for (int col = 0; col < var.getNumCategories(); col++) { estimatedCounts[j][0][col] = 0.0; } for (int i = 0; i < numCases; i++) { // System.out.println("Case " + i); // If this case has a value for var if (mixedData.getInt(i, j) != -99) { estimatedCounts[j][0][mixedData.getInt(i, j)] += 1.0; // System.out.println("Adding 1.0 to " + varName + // " row 0 category " + mixedData[j][i]); } else { // find marginal probability, given obs data in this case, p(v=0) Evidence evidenceThisCase = Evidence.tautology(inputBayesIm); boolean existsEvidence = false; // Define evidence for updating by using the values of the other vars. 
for (int k = 0; k < numVariables; k++) { if (k == j) { continue; } Node otherVar = allVariables.get(k); if (mixedData.getInt(i, k) == -99) { continue; } existsEvidence = true; String otherVarName = otherVar.getName(); Node otherNode = graph.getNode(otherVarName); int otherIndex = inputBayesIm.getNodeIndex(otherNode); evidenceThisCase.getProposition().setCategory(otherIndex, mixedData.getInt(i, k)); } if (!existsEvidence) { continue; // No other variable contained useful data } rseu.setEvidence(evidenceThisCase); for (int m = 0; m < var.getNumCategories(); m++) { estimatedCounts[j][0][m] += rseu.getMarginal(varIndex, m); // System.out.println("Adding " + p + " to " + varName + // " row 0 category " + m); // find marginal probability, given obs data in this case, p(v=1) // estimatedCounts[j][0][1] += 0.5; } } } // Print estimated counts: // System.out.println("Estimated counts: "); // Print counts for each value of this variable with no parents. // for(int m = 0; m < var.getNumSplits(); m++) // System.out.print(" " + m + " " + estimatedCounts[j][0][m]); // System.out.println(); } else { // For variables with parents: int numRows = inputBayesIm.getNumRows(varIndex); for (int row = 0; row < numRows; row++) { int[] parValues = inputBayesIm.getParentValues(varIndex, row); estimatedCountsDenom[varIndex][row] = 0.0; for (int col = 0; col < var.getNumCategories(); col++) { estimatedCounts[varIndex][row][col] = 0.0; } for (int i = 0; i < numCases; i++) { // for a case where the parent values = parValues increment the estCount boolean parentMatch = true; for (int p = 0; p < parentVarIndices.length; p++) { if (parValues[p] != mixedData.getInt(i, parentVarIndices[p]) && mixedData.getInt(i, parentVarIndices[p]) != -99) { parentMatch = false; break; } } if (!parentMatch) { continue; // Not a matching case; go to next. 
} boolean parentMissing = false; for (int parentVarIndice : parentVarIndices) { if (mixedData.getInt(i, parentVarIndice) == -99) { parentMissing = true; break; } } if (mixedData.getInt(i, j) != -99 && !parentMissing) { estimatedCounts[j][row][mixedData.getInt(i, j)] += 1.0; estimatedCountsDenom[j][row] += 1.0; continue; // Next case } // for a case with missing data (either var or one of its parents) // compute the joint marginal // distribution for var & this combination of values of its parents // and update the estCounts accordingly // To compute marginals create the evidence boolean existsEvidence = false; Evidence evidenceThisCase = Evidence.tautology(inputBayesIm); // "evidenceVars" not used. // List<String> evidenceVars = new LinkedList<String>(); // for (int k = 0; k < numVariables; k++) { // //if(k == j) continue; // Variable otherVar = allVariables.get(k); // if (mixedData.getInt(i, k) == -99) { // continue; // } // existsEvidence = true; // String otherVarName = otherVar.getName(); // Node otherNode = graph.getNode(otherVarName); // int otherIndex = inputBayesIm.getNodeIndex( // otherNode); // evidenceThisCase.getProposition().setCategory( // otherIndex, mixedData.getInt(i, k)); // evidenceVars.add(otherVarName); // } if (!existsEvidence) { continue; } rseu.setEvidence(evidenceThisCase); estimatedCountsDenom[j][row] += rseu.getJointMarginal(parentVarIndices, parValues); int[] parPlusChildIndices = new int[parentVarIndices.length + 1]; int[] parPlusChildValues = new int[parentVarIndices.length + 1]; parPlusChildIndices[0] = varIndex; for (int pc = 1; pc < parPlusChildIndices.length; pc++) { parPlusChildIndices[pc] = parentVarIndices[pc - 1]; parPlusChildValues[pc] = parValues[pc - 1]; } for (int m = 0; m < var.getNumCategories(); m++) { parPlusChildValues[0] = m; /* if(varName.equals("X1") && i == 0 ) { System.out.println("Calling getJointMarginal with parvalues"); for(int k = 0; k < parPlusChildIndices.length; k++) { int pIndex = parPlusChildIndices[k]; 
Node pNode = inputBayesIm.getNode(pIndex); String pName = pNode.getName(); System.out.println(pName + " " + parPlusChildValues[k]); } } */ /* if(varName.equals("X1") && i == 0 ) { System.out.println("Evidence = " + evidenceThisCase); //int[] vars = {l1Index, x1Index}; Node nodex1 = inputBayesIm.getNode("X1"); int x1Index = inputBayesIm.getNodeIndex(nodex1); Node nodel1 = inputBayesIm.getNode("L1"); int l1Index = inputBayesIm.getNodeIndex(nodel1); int[] vars = {l1Index, x1Index}; int[] vals = {0, 0}; double ptest = rseu.getJointMarginal(vars, vals); System.out.println("Joint marginal (X1=0, L1 = 0) = " + p); } */ estimatedCounts[j][row][m] += rseu.getJointMarginal(parPlusChildIndices, parPlusChildValues); // System.out.println("Case " + i + " parent values "); // for (int pp = 0; pp < parentVarIndices.length; pp++) { // Variable par = (Variable) allVariables.get(parentVarIndices[pp]); // System.out.print(" " + par.getName() + " " + parValues[pp]); // } // System.out.println(); // System.out.println("Adding " + p + " to " + varName + // " row " + row + " category " + m); } // } } // Print estimated counts: // System.out.println("Estimated counts: "); // System.out.println(" Parent values: "); // for (int i = 0; i < parentVarIndices.length; i++) { // Variable par = (Variable) allVariables.get(parentVarIndices[i]); // System.out.print(" " + par.getName() + " " + parValues[i] + " "); // } // System.out.println(); // for(int m = 0; m < var.getNumSplits(); m++) // System.out.print(" " + m + " " + estimatedCounts[j][row][m]); // System.out.println(); } } // else } // j < numVariables BayesIm outputBayesIm = new MlBayesIm(bayesPm); for (int j = 0; j < nodes.length; j++) { DiscreteVariable var = (DiscreteVariable) allVariables.get(j); String varName = var.getName(); Node varNode = graph.getNode(varName); int varIndex = inputBayesIm.getNodeIndex(varNode); // int[] parentVarIndices = inputBayesIm.getParents(varIndex); int numRows = inputBayesIm.getNumRows(j); // 
System.out.println("Conditional probabilities for variable " + varName); int numCols = inputBayesIm.getNumColumns(j); if (numRows == 1) { double sum = 0.0; for (int m = 0; m < numCols; m++) { sum += estimatedCounts[j][0][m]; } for (int m = 0; m < numCols; m++) { condProbs[j][0][m] = estimatedCounts[j][0][m] / sum; // System.out.print(" " + condProbs[j][0][m]); outputBayesIm.setProbability(varIndex, 0, m, condProbs[j][0][m]); } // System.out.println(); } else { for (int row = 0; row < numRows; row++) { // int[] parValues = inputBayesIm.getParentValues(varIndex, // row); // int numCols = inputBayesIm.getNumColumns(j); // for (int p = 0; p < parentVarIndices.length; p++) { // Variable par = (Variable) allVariables.get(parentVarIndices[p]); // System.out.print(" " + par.getName() + " " + parValues[p]); // } // double sum = 0.0; // for(int m = 0; m < numCols; m++) // sum += estimatedCounts[j][row][m]; for (int m = 0; m < numCols; m++) { if (estimatedCountsDenom[j][row] != 0.0) { condProbs[j][row][m] = estimatedCounts[j][row][m] / estimatedCountsDenom[j][row]; } else { condProbs[j][row][m] = Double.NaN; } // System.out.print(" " + condProbs[j][row][m]); outputBayesIm.setProbability(varIndex, row, m, condProbs[j][row][m]); } // System.out.println(); } } } return outputBayesIm; }
/**
 * Constructs an estimator from an existing Bayes IM and a data set.
 *
 * <p>Only the PM of {@code inputBayesIm} (via {@code getBayesPm()}) is passed on to the delegated
 * constructor. The IM's own parameter values are not used here.
 *
 * @param inputBayesIm the IM whose PM is to be estimated.
 * @param dataSet      the data to estimate from.
 */
public EmBayesEstimator(BayesIm inputBayesIm, DataSet dataSet) {
    this(inputBayesIm.getBayesPm(), dataSet);
}