/**
 * Builds a proposition over {@code variableSource} and carries over, for every variable that
 * also occurs in the variable source of {@code proposition}, the allowed/disallowed status of
 * each of its categories. Variables with no counterpart in the old proposition are left in
 * whatever state {@code this(variableSource)} put them in.
 *
 * @param variableSource the variable source the new proposition is defined over.
 * @param proposition the existing proposition whose category settings are copied.
 * @throws NullPointerException if {@code proposition} is null.
 */
public Proposition(VariableSource variableSource, Proposition proposition) {
    this(variableSource);

    if (proposition == null) {
        throw new NullPointerException();
    }

    List<Node> newVars = variableSource.getVariables();
    List<Node> priorVars = proposition.getVariableSource().getVariables();

    for (int target = 0; target < newVars.size(); target++) {
        DiscreteVariable wanted = (DiscreteVariable) newVars.get(target);

        // Scan the old variable list for the first variable equal to the wanted one.
        int source = 0;
        while (source < priorVars.size()) {
            DiscreteVariable candidate = (DiscreteVariable) priorVars.get(source);
            if (wanted.equals(candidate)) {
                break;
            }
            source++;
        }

        if (source >= priorVars.size()) {
            // No matching variable in the old proposition; nothing to copy.
            continue;
        }

        for (int category = 0; category < allowedCategories[target].length; category++) {
            allowedCategories[target][category] = proposition.isAllowed(source, category);
        }
    }
}
/**
 * Finds the first category of {@code variable} that lies strictly after {@code currentIndex}
 * and is allowed by {@code proposition}.
 *
 * @param proposition the proposition consulted for allowed categories.
 * @param variable the index of the variable whose categories are scanned.
 * @param currentIndex the category index to start after; pass -1 to scan from category 0.
 * @return the index of the next allowed category, or -1 if there is none.
 */
private static int nextValue(Proposition proposition, int variable, int currentIndex) {
    int candidate = currentIndex + 1;

    while (candidate < proposition.getNumCategories(variable)) {
        if (proposition.isAllowed(variable, candidate)) {
            return candidate;
        }
        candidate++;
    }

    return -1;
}
/**
 * Narrows a single variable of this proposition: every category of {@code variable} that the
 * given proposition does not allow is removed from this proposition as well.
 *
 * @param proposition the proposition supplying the allowed categories; it must share this
 *     proposition's variable source (compared by identity).
 * @param variable the index of the variable to restrict.
 * @throws IllegalArgumentException if the two propositions have different variable sources.
 */
public void restrictToProposition(Proposition proposition, int variable) {
    if (this.variableSource != proposition.getVariableSource()) {
        throw new IllegalArgumentException(
                "Can only restrict to propositions for the same variable source.");
    }

    for (int category = 0; category < allowedCategories[variable].length; category++) {
        if (proposition.isAllowed(variable, category)) {
            continue;
        }
        removeCategory(variable, category);
    }
}
/** * Constructs a new on-the-fly BayesIM that will calculate conditional probabilities on the fly * from the given discrete data set, for the given Bayes PM. * * @param bayesPm the given Bayes PM, which specifies a directed acyclic graph for a Bayes net and * parametrization for the Bayes net, but not actual values for the parameters. * @param dataSet the discrete data set from which conditional probabilities should be estimated * on the fly. */ public OnTheFlyMarginalCalculator(BayesPm bayesPm, DataSet dataSet) throws IllegalArgumentException { if (bayesPm == null) { throw new NullPointerException(); } if (dataSet == null) { throw new NullPointerException(); } // Make sure all of the variables in the PM are in the data set; // otherwise, estimation is impossible. BayesUtils.ensureVarsInData(bayesPm.getVariables(), dataSet); // DataUtils.ensureVariablesExist(bayesPm, dataSet); this.bayesPm = new BayesPm(bayesPm); // Get the nodes from the BayesPm. This fixes the order of the nodes // in the BayesIm, independently of any change to the BayesPm. // (This order must be maintained.) Graph graph = bayesPm.getDag(); this.nodes = graph.getNodes().toArray(new Node[0]); // Initialize. initialize(); // Create a subset of the data set with the variables of the IM, in // the order of the IM. List<Node> variables = getVariables(); this.dataSet = dataSet.subsetColumns(variables); // Create a tautologous proposition for evidence. this.evidence = new Evidence(Proposition.tautology(this)); }
/**
 * Estimates, by counting rows of the data set, the probability of {@code assertion} given
 * {@code condition}.
 *
 * <p>Rows containing a missing value in any column are skipped, and {@code missingValueCaseFound}
 * is set to true when at least one such row is encountered (it is reset to false at the start of
 * every call).
 *
 * @param assertion the proposition whose probability is estimated.
 * @param condition the proposition conditioned on; it must share the assertion's variable source
 *     (compared by identity).
 * @return the number of complete rows satisfying both propositions divided by the number of
 *     complete rows satisfying the condition; {@code NaN} if no complete row satisfies the
 *     condition.
 * @throws IllegalArgumentException if the two propositions have different variable sources, or
 *     if the assertion's variables and the data set's variables are not the same set.
 */
public double getConditionalProb(Proposition assertion, Proposition condition) {
    if (assertion.getVariableSource() != condition.getVariableSource()) {
        throw new IllegalArgumentException(
                "Assertion and condition must be " + "for the same Bayes IM.");
    }

    List<Node> assertionVars = assertion.getVariableSource().getVariables();
    List<Node> dataVars = dataSet.getVariables();

    assertionVars = GraphUtils.replaceNodes(assertionVars, dataVars);

    if (!new HashSet<Node>(assertionVars).equals(new HashSet<Node>(dataVars))) {
        throw new IllegalArgumentException(
                "Assertion variable and data variables"
                        + " are either different or in a different order: "
                        + "\n\tAssertion vars: "
                        + assertionVars
                        + "\n\tData vars: "
                        + dataVars);
    }

    int[] point = new int[dims.length];
    int conditionCount = 0;
    int jointCount = 0;

    this.missingValueCaseFound = false;

    rows:
    for (int i = 0; i < numRows; i++) {
        for (int j = 0; j < dims.length; j++) {
            point[j] = dataSet.getInt(i, j);

            if (point[j] == DiscreteVariable.MISSING_VALUE) {
                // Record that an incomplete row was dropped, as getProb does; without this the
                // flag reset above would always remain false after this method.
                this.missingValueCaseFound = true;
                continue rows;
            }
        }

        if (condition.isPermissibleCombination(point)) {
            conditionCount++;

            if (assertion.isPermissibleCombination(point)) {
                jointCount++;
            }
        }
    }

    // 0 / 0.0 evaluates to NaN, which signals an undefined conditional probability.
    return jointCount / (double) conditionCount;
}
/**
 * Narrows every variable of this proposition: any category that the given proposition marks as
 * not allowed becomes not allowed here too. Categories the given proposition allows are left
 * untouched.
 *
 * @param proposition the proposition supplying the allowed categories; it must share this
 *     proposition's variable source (compared by identity).
 * @throws IllegalArgumentException if the two propositions have different variable sources.
 */
public void restrictToProposition(Proposition proposition) {
    if (this.variableSource != proposition.getVariableSource()) {
        throw new IllegalArgumentException(
                "Can only restrict to propositions for the same variable source.");
    }

    for (int var = 0; var < allowedCategories.length; var++) {
        for (int category = 0; category < allowedCategories[var].length; category++) {
            if (proposition.allowedCategories[var][category]) {
                continue;
            }
            this.allowedCategories[var][category] = false;
        }
    }
}
/**
 * Estimates the probability of a proposition by counting rows of the data set.
 *
 * <p>A row containing a missing value in any column is never counted as a match, and sets
 * {@code missingValueCaseFound} to true (the flag is reset to false at the start of every call).
 * Such rows are still included in the denominator, which is always the total row count.
 *
 * @param assertion the proposition whose probability is estimated.
 * @return the number of complete rows permitted by {@code assertion} divided by the total number
 *     of rows.
 */
public double getProb(Proposition assertion) {
    int[] row = new int[dims.length];
    int matches = 0;

    this.missingValueCaseFound = false;

    for (int r = 0; r < numRows; r++) {
        boolean complete = true;

        for (int c = 0; c < dims.length; c++) {
            int value = dataSet.getInt(r, c);
            row[c] = value;

            if (value == DiscreteVariable.MISSING_VALUE) {
                this.missingValueCaseFound = true;
                complete = false;
                break;
            }
        }

        if (complete && assertion.isPermissibleCombination(row)) {
            matches++;
        }
    }

    return matches / (double) this.numRows;
}
/**
 * Sums, over combinations of category values drawn from the current evidence proposition, a
 * product of conditional probabilities obtained from {@code getDiscreteProbs()}.
 *
 * <p>The combinations are enumerated odometer-style: {@code variableValues} holds one category
 * index per variable, the right-most variable that can still advance is advanced, and every
 * variable to its right is reset to its first value.
 *
 * @param variable the index of the variable placed in the assertion.
 * @param category the category of {@code variable} placed in the assertion.
 * @return the accumulated sum of the per-combination products.
 */
private double getUpdatedMarginalFromModel(int variable, int category) {
    Proposition evidence = getEvidence().getProposition();
    int[] variableValues = new int[evidence.getNumVariables()];

    // Seed each variable with the value nextValue reports when scanning from index -1
    // (presumably its first allowed category, or -1 if it has none -- confirm against nextValue).
    for (int i = 0; i < evidence.getNumVariables(); i++) {
        variableValues[i] = nextValue(evidence, i, -1);
    }

    // Rewind the last variable so that the first advance in the loop below lands on its first
    // value instead of skipping it.
    // NOTE(review): this indexes position -1 if the evidence has zero variables -- confirm that
    // cannot happen.
    variableValues[variableValues.length - 1] = -1;
    double sum = 0.0;

    loop:
    while (true) {
        // Look, from the right-most variable leftwards, for one that can still advance.
        for (int i = evidence.getNumVariables() - 1; i >= 0; i--) {
            if (hasNextValue(evidence, i, variableValues[i])) {
                variableValues[i] = nextValue(evidence, i, variableValues[i]);

                // Reset every variable to the right of i to its first value. If one of them has
                // no value at all, the whole enumeration is abandoned and the sum so far is
                // returned.
                for (int j = i + 1; j < evidence.getNumVariables(); j++) {
                    if (hasNextValue(evidence, j, -1)) {
                        variableValues[j] = nextValue(evidence, j, -1);
                    } else {
                        break loop;
                    }
                }

                // Multiply one conditional probability per node for the current combination.
                double product = 1.0;

                for (int m = 0; m < getNumNodes(); m++) {
                    // NOTE(review): the assertion is built only from (variable, category) and
                    // does not depend on m, so it is identical for every node -- confirm this
                    // is intended rather than an assertion about node m's current value.
                    Proposition assertion = Proposition.tautology(this);
                    assertion.setCategory(variable, category);

                    // Condition on each parent of node m taking its value in the current
                    // combination.
                    Proposition condition = new Proposition(evidence);
                    int[] parents = getParents(m);

                    for (int parent : parents) {
                        condition.disallowComplement(parent, variableValues[parent]);
                    }

                    // Nodes whose condition admits no combination contribute no factor.
                    if (condition.existsCombination()) {
                        product *= getDiscreteProbs().getConditionalProb(assertion, condition);
                    }
                }

                sum += product;
                // Restart the search for the next combination from the right-most variable.
                continue loop;
            }
        }

        // No variable could be advanced: every combination has been visited.
        break;
    }

    return sum;
}