/**
   * Creates a proposition over {@code variableSource} that inherits the allowed-category settings
   * of an existing proposition (typically one built for an older BayesIm).
   *
   * <p>Each variable of the new source is looked up, via {@code equals}, among the variables of
   * {@code proposition}'s source. When a match is found, its allowed/disallowed flags are copied
   * category by category. Variables with no match keep whatever {@code this(variableSource)} set
   * up.
   *
   * @param variableSource the variable source the new proposition ranges over.
   * @param proposition the proposition whose settings are copied; must not be null.
   * @throws NullPointerException if {@code proposition} is null.
   */
  public Proposition(VariableSource variableSource, Proposition proposition) {
    this(variableSource);

    if (proposition == null) {
      throw new NullPointerException();
    }

    List<Node> newVars = variableSource.getVariables();
    List<Node> previousVars = proposition.getVariableSource().getVariables();

    for (int newIndex = 0; newIndex < newVars.size(); newIndex++) {
      DiscreteVariable target = (DiscreteVariable) newVars.get(newIndex);

      // Linear search for the same variable in the old source; stop at the first hit.
      int matchIndex = -1;
      int k = 0;
      while (matchIndex == -1 && k < previousVars.size()) {
        DiscreteVariable candidate = (DiscreteVariable) previousVars.get(k);
        if (target.equals(candidate)) {
          matchIndex = k;
        }
        k++;
      }

      if (matchIndex < 0) {
        // No counterpart in the old source; leave this variable's defaults alone.
        continue;
      }

      boolean[] row = allowedCategories[newIndex];
      for (int cat = 0; cat < row.length; cat++) {
        row[cat] = proposition.isAllowed(matchIndex, cat);
      }
    }
  }
  /**
   * Finds the first category of {@code variable} after {@code currentIndex} that
   * {@code proposition} allows.
   *
   * @param proposition the proposition consulted for allowed categories.
   * @param variable index of the variable.
   * @param currentIndex the category to start after; pass -1 to search from the first category.
   * @return the index of the next allowed category, or -1 if there is none.
   */
  private static int nextValue(Proposition proposition, int variable, int currentIndex) {
    int candidate = currentIndex + 1;

    while (candidate < proposition.getNumCategories(variable)) {
      if (proposition.isAllowed(variable, candidate)) {
        return candidate;
      }
      candidate++;
    }

    return -1;
  }
  /**
   * Narrows the allowed categories of a single variable so that only categories also allowed by
   * {@code proposition} remain allowed. Disallowed categories are removed one at a time through
   * {@code removeCategory}, in increasing category order.
   *
   * @param proposition the proposition to intersect with; must share this proposition's variable
   *     source.
   * @param variable index of the variable to restrict.
   * @throws IllegalArgumentException if the two propositions use different variable sources.
   */
  public void restrictToProposition(Proposition proposition, int variable) {
    boolean sameSource = proposition.getVariableSource() == this.variableSource;

    if (!sameSource) {
      throw new IllegalArgumentException(
          "Can only restrict to " + "propositions for the same variable source.");
    }

    int category = 0;
    while (category < allowedCategories[variable].length) {
      boolean keep = proposition.isAllowed(variable, category);
      if (!keep) {
        removeCategory(variable, category);
      }
      category++;
    }
  }
  /**
   * Builds a calculator that estimates conditional probabilities on demand, directly from a
   * discrete data set, for the network described by a Bayes PM.
   *
   * <p>A copy of the PM is stored. The node order is taken from the PM's DAG at construction time.
   * The data set is reduced to the IM's variables in the IM's order, and the evidence starts out as
   * a tautology.
   *
   * @param bayesPm the Bayes PM giving the DAG and the variable parametrization (but no parameter
   *     values).
   * @param dataSet the discrete data set from which probabilities are estimated on the fly.
   * @throws NullPointerException if {@code bayesPm} or {@code dataSet} is null.
   */
  public OnTheFlyMarginalCalculator(BayesPm bayesPm, DataSet dataSet)
      throws IllegalArgumentException {
    if (bayesPm == null || dataSet == null) {
      throw new NullPointerException();
    }

    // Estimation is impossible unless every PM variable appears in the data.
    BayesUtils.ensureVarsInData(bayesPm.getVariables(), dataSet);

    this.bayesPm = new BayesPm(bayesPm);

    // Snapshot the node order now, so the IM's ordering stays fixed even if the PM is later
    // changed. This order must be maintained.
    Graph dag = bayesPm.getDag();
    this.nodes = dag.getNodes().toArray(new Node[0]);

    initialize();

    // Keep only the IM's variables as data columns, arranged in the IM's order.
    this.dataSet = dataSet.subsetColumns(getVariables());

    // Begin with evidence that rules nothing out.
    this.evidence = new Evidence(Proposition.tautology(this));
  }
// Example #5
// 0
  /**
   * Estimates P(assertion | condition) by counting rows of the data set.
   *
   * <p>Only complete rows (rows with no missing value in any column) are counted. The estimate is
   * (number of complete rows satisfying both {@code condition} and {@code assertion}) divided by
   * (number of complete rows satisfying {@code condition}). After the call,
   * {@code missingValueCaseFound} is true iff at least one row was skipped for a missing value.
   *
   * <p>Data columns are matched to the proposition's variables by identity, not by position. A data
   * set whose columns are a permutation of the proposition's variables is therefore counted
   * correctly.
   *
   * @param assertion the proposition whose conditional probability is estimated.
   * @param condition the proposition conditioned on; must share {@code assertion}'s variable
   *     source.
   * @return the estimated conditional probability, or {@code NaN} if no complete row satisfies
   *     {@code condition}.
   * @throws IllegalArgumentException if the two propositions have different variable sources, or if
   *     their variables are not the same set as the data set's variables.
   */
  public double getConditionalProb(Proposition assertion, Proposition condition) {
    if (assertion.getVariableSource() != condition.getVariableSource()) {
      throw new IllegalArgumentException(
          "Assertion and condition must be " + "for the same Bayes IM.");
    }

    List<Node> assertionVars = assertion.getVariableSource().getVariables();
    List<Node> dataVars = dataSet.getVariables();

    // Substitute the data set's own node objects (presumably matched by name) so the comparisons
    // below are between like objects.
    assertionVars = GraphUtils.replaceNodes(assertionVars, dataVars);

    if (!new HashSet<Node>(assertionVars).equals(new HashSet<Node>(dataVars))) {
      throw new IllegalArgumentException(
          "Assertion variables and data variables are different: "
              + "\n\tAssertion vars: "
              + assertionVars
              + "\n\tData vars: "
              + dataVars);
    }

    // columns[k] is the data column holding the proposition's k-th variable. The set comparison
    // above ignores order, so positions must not be assumed to line up.
    int[] columns = new int[assertionVars.size()];
    for (int k = 0; k < columns.length; k++) {
      columns[k] = dataVars.indexOf(assertionVars.get(k));
    }

    int[] point = new int[columns.length];
    int conditionCount = 0;
    int jointCount = 0;
    this.missingValueCaseFound = false;

    rows:
    for (int i = 0; i < numRows; i++) {
      for (int k = 0; k < columns.length; k++) {
        point[k] = dataSet.getInt(i, columns[k]);

        if (point[k] == DiscreteVariable.MISSING_VALUE) {
          // Record the skip so callers can tell that rows were ignored.
          this.missingValueCaseFound = true;
          continue rows;
        }
      }

      if (condition.isPermissibleCombination(point)) {
        conditionCount++;

        if (assertion.isPermissibleCombination(point)) {
          jointCount++;
        }
      }
    }

    // 0 / 0.0 yields NaN when no complete row satisfies the condition.
    return jointCount / (double) conditionCount;
  }
  /**
   * Intersects this proposition with {@code proposition}, variable by variable. A category stays
   * allowed here only if it is allowed in both propositions. The allowed-category table is
   * modified directly.
   *
   * @param proposition the proposition to intersect with; must share this proposition's variable
   *     source.
   * @throws IllegalArgumentException if the two propositions use different variable sources.
   */
  public void restrictToProposition(Proposition proposition) {
    boolean sameSource = proposition.getVariableSource() == this.variableSource;

    if (!sameSource) {
      throw new IllegalArgumentException(
          "Can only restrict to " + "propositions for the same variable source.");
    }

    for (int v = 0; v < allowedCategories.length; v++) {
      boolean[] mine = this.allowedCategories[v];

      for (int c = 0; c < mine.length; c++) {
        boolean allowedThere = proposition.allowedCategories[v][c];
        if (!allowedThere) {
          mine[c] = false;
        }
      }
    }
  }
// Example #7
// 0
  /**
   * Estimates the probability of {@code assertion} as a relative frequency in the data set.
   *
   * <p>Rows with a missing value in any column are skipped and also left out of the denominator.
   * The result is therefore the fraction of complete rows that satisfy {@code assertion}. This
   * matches {@code getConditionalProb(assertion, tautology)`; with no missing data it equals
   * {@code count / numRows}. After the call, {@code missingValueCaseFound} is true iff at least
   * one row was skipped.
   *
   * @param assertion the proposition whose probability is estimated.
   * @return the estimated probability, or {@code NaN} if the data set has no complete rows.
   */
  public double getProb(Proposition assertion) {
    int[] point = new int[dims.length];
    int completeRows = 0;
    int count = 0;

    this.missingValueCaseFound = false;

    point:
    for (int i = 0; i < numRows; i++) {
      for (int j = 0; j < dims.length; j++) {
        point[j] = dataSet.getInt(i, j);

        if (point[j] == DiscreteVariable.MISSING_VALUE) {
          this.missingValueCaseFound = true;
          continue point;
        }
      }

      // Only rows that survive the missing-value check contribute to the denominator.
      completeRows++;

      if (assertion.isPermissibleCombination(point)) {
        count++;
      }
    }

    // 0 / 0.0 yields NaN when there are no complete rows (including an empty data set).
    return count / (double) completeRows;
  }
  /**
   * Enumerates every combination of values allowed by the current evidence. For each combination
   * it computes a product of conditional probabilities estimated through
   * {@code getDiscreteProbs().getConditionalProb}, and returns the sum of those products.
   *
   * <p>Enumeration is odometer-style. {@code variableValues} starts at each variable's first
   * allowed category ({@code nextValue(evidence, i, -1)}), with the last slot forced to -1 so
   * that the first pass advances it to its first allowed value. Each pass advances the right-most
   * variable that still has a next allowed value, and resets every variable to its right to its
   * first allowed value.
   *
   * <p>NOTE(review): inside the per-node loop, the assertion always fixes {@code variable} to
   * {@code category}. It depends neither on {@code m} nor on {@code variableValues}; only the
   * condition varies with {@code m}, through node m's parents. The same assertion probability is
   * therefore multiplied once per node. This does not look like a standard marginal computation;
   * confirm the intended semantics before relying on this method.
   *
   * <p>NOTE(review): {@code variableValues[variableValues.length - 1]} throws if the evidence has
   * zero variables.
   *
   * @param variable index of the variable whose (updated) marginal is requested.
   * @param category index of the category of {@code variable}.
   * @return the accumulated sum of products described above.
   */
  private double getUpdatedMarginalFromModel(int variable, int category) {
    Proposition evidence = getEvidence().getProposition();
    int[] variableValues = new int[evidence.getNumVariables()];

    // Start every variable at its first allowed category (-1 if it has none).
    for (int i = 0; i < evidence.getNumVariables(); i++) {
      variableValues[i] = nextValue(evidence, i, -1);
    }

    // Back the last variable off by one step so the first pass of the odometer lands on the
    // all-first-values combination.
    variableValues[variableValues.length - 1] = -1;
    double sum = 0.0;

    loop:
    while (true) {
      // Find the right-most variable that can still advance (presumably what hasNextValue
      // reports; its definition is not shown here).
      for (int i = evidence.getNumVariables() - 1; i >= 0; i--) {
        if (hasNextValue(evidence, i, variableValues[i])) {
          variableValues[i] = nextValue(evidence, i, variableValues[i]);

          // Reset everything to the right to its first allowed value. If some variable has no
          // allowed value at all, the enumeration stops entirely.
          for (int j = i + 1; j < evidence.getNumVariables(); j++) {
            if (hasNextValue(evidence, j, -1)) {
              variableValues[j] = nextValue(evidence, j, -1);
            } else {
              break loop;
            }
          }

          double product = 1.0;

          for (int m = 0; m < getNumNodes(); m++) {
            // NOTE(review): this assertion does not depend on m (see method comment).
            Proposition assertion = Proposition.tautology(this);
            assertion.setCategory(variable, category);

            // Condition = evidence, further pinned at each parent of node m to that parent's
            // value in the current combination. Assumes parent indices share the evidence's
            // variable indexing — confirm.
            Proposition condition = new Proposition(evidence);
            int[] parents = getParents(m);

            for (int parent : parents) {
              condition.disallowComplement(parent, variableValues[parent]);
            }

            // Skip factors whose condition cannot be satisfied.
            if (condition.existsCombination()) {
              product *= getDiscreteProbs().getConditionalProb(assertion, condition);
            }
          }

          sum += product;
          // Restart the scan from the right-most variable for the next combination.
          continue loop;
        }
      }

      // No variable can advance: every combination has been visited.
      break;
    }

    return sum;
  }