/**
   * Compares the baseline with the inferred result for a single predicate.
   *
   * <p>Runs in two passes. The first pass ({@code countResultDBStats}) tallies statistics over the
   * atoms present in the result database. The second pass walks the (filtered) atoms of the
   * baseline database and looks for atoms that the first pass recorded neither as an error nor as
   * correct; each such atom whose baseline value reaches {@code threshold} is counted as a false
   * negative and stored in {@code errors} with the value {@code 1.0}, keyed by the atom returned by
   * {@code result.getAtom(...)}.
   *
   * <p>The true-negative count is not taken from the first pass: it is overwritten with {@code
   * maxBaseAtoms - tp - fp - fn}.
   *
   * @param p the predicate to compare
   * @param maxBaseAtoms the maximum number of base atoms that can be found for the given predicate
   *     (this will vary, depending on the predicate and the problem)
   * @return a new {@code DiscretePredictionStatistics} built from the counters, the threshold, the
   *     error map and the set of correct atoms
   */
  @Override
  public DiscretePredictionStatistics compare(Predicate p, int maxBaseAtoms) {
    // First pass: statistics for every atom that exists in the result database.
    countResultDBStats(p);

    // Second pass: baseline atoms that the first pass never saw.
    Iterator<GroundAtom> baselineAtoms =
        resultFilter.filter(Queries.getAllAtoms(baseline, p).iterator());
    while (baselineAtoms.hasNext()) {
      GroundAtom baselineAtom = baselineAtoms.next();

      // Already classified by the first pass.
      if (errors.containsKey(baselineAtom) || correctAtoms.contains(baselineAtom)) {
        continue;
      }

      // A baseline value below the threshold means "expected 0": nothing was missed.
      if (!(baselineAtom.getValue() >= threshold)) {
        continue;
      }

      // Missed result: the baseline expects a positive.
      // NOTE(review): false negatives are stored as -1.0 by countResultDBStats but as 1.0 here;
      // confirm with consumers of the error map whether the differing sign is intended.
      errors.put(result.getAtom(baselineAtom.getPredicate(), baselineAtom.getArguments()), 1.0);
      fn++;
    }

    // NOTE(review): this goes negative if maxBaseAtoms is smaller than tp + fp + fn.
    tn = maxBaseAtoms - tp - fp - fn;
    return new DiscretePredictionStatistics(tp, fp, tn, fn, threshold, errors, correctAtoms);
  }
  /**
   * Subroutine used by both compare methods for counting statistics from atoms stored in the
   * result database.
   *
   * <p>Resets the four counters ({@code tp}, {@code fp}, {@code tn}, {@code fn}) and replaces
   * {@code errors} and {@code correctAtoms} with fresh, empty collections. Then, for every
   * (filtered) result atom whose baseline counterpart is an {@code ObservedAtom}, both values are
   * discretized against {@code threshold} and the atom is classified:
   *
   * <ul>
   *   <li>both sides agree: {@code tp} or {@code tn} is incremented and the result atom is added
   *       to {@code correctAtoms};
   *   <li>result positive, baseline negative: {@code fp} is incremented and the atom is stored in
   *       {@code errors} with {@code 1.0};
   *   <li>result negative, baseline positive: {@code fn} is incremented and the atom is stored in
   *       {@code errors} with {@code -1.0}.
   * </ul>
   *
   * Result atoms whose baseline counterpart is not an {@code ObservedAtom} are ignored.
   *
   * @param p Predicate to compare against baseline database
   */
  private void countResultDBStats(Predicate p) {
    // Start from a clean slate on every call.
    tp = 0;
    fp = 0;
    tn = 0;
    fn = 0;
    errors = new HashMap<GroundAtom, Double>();
    correctAtoms = new HashSet<GroundAtom>();

    for (Iterator<GroundAtom> resultAtoms =
            resultFilter.filter(Queries.getAllAtoms(result, p).iterator());
        resultAtoms.hasNext(); ) {
      GroundAtom resultAtom = resultAtoms.next();

      // Copy the arguments into a GroundTerm[] so the same atom can be looked up in the baseline.
      GroundTerm[] groundArgs = new GroundTerm[resultAtom.getArity()];
      for (int idx = 0; idx < groundArgs.length; idx++) {
        groundArgs[idx] = (GroundTerm) resultAtom.getArguments()[idx];
      }
      GroundAtom baselineAtom = baseline.getAtom(resultAtom.getPredicate(), groundArgs);

      // Only observed baseline atoms carry a truth value to compare against.
      if (!(baselineAtom instanceof ObservedAtom)) {
        continue;
      }

      boolean predictedPositive = resultAtom.getValue() >= threshold;
      boolean truthPositive = baselineAtom.getValue() >= threshold;

      if (predictedPositive == truthPositive) {
        if (predictedPositive) {
          // True positive
          tp++;
        } else {
          // True negative
          tn++;
        }
        correctAtoms.add(resultAtom);
      } else if (predictedPositive) {
        // False positive
        fp++;
        errors.put(resultAtom, 1.0);
      } else {
        // False negative
        fn++;
        errors.put(resultAtom, -1.0);
      }
    }
  }
예제 #3
0
 /**
  * Returns the string representation of {@code atom}.
  *
  * @return the result of {@code atom.toString()}
  */
 @Override
 public String toString() {
   return atom.toString();
 }
예제 #4
0
 /**
  * Computes the hash code as the hash code of {@code atom} plus the constant 97.
  *
  * <p>NOTE(review): the offset presumably keeps this object from hashing identically to the atom
  * itself; the matching {@code equals} is not visible here, so confirm the two stay consistent.
  *
  * @return {@code atom.hashCode() + 97}
  */
 @Override
 public int hashCode() {
   return atom.hashCode() + 97;
 }
예제 #5
0
 /**
  * Returns the value of {@code atom}.
  *
  * @return the result of {@code atom.getValue()}
  */
 @Override
 public double getValue() {
   return atom.getValue();
 }
예제 #6
0
 /**
  * Returns the confidence of {@code atom}.
  *
  * <p>Note that the delegate is named differently: this method forwards to {@code
  * atom.getConfidenceValue()}.
  *
  * @return the result of {@code atom.getConfidenceValue()}
  */
 @Override
 public double getConfidence() {
   return atom.getConfidenceValue();
 }