コード例 #1
0
/**
 * JUnit tests plus reusable static checks for {@link Minimizable} implementations: round-tripping
 * of parameters through the get/set accessors, and agreement between the analytic cost gradient
 * and a finite-difference estimate of it.
 */
public class TestMinimizable extends TestCase {
  private static final Logger logger = MalletLogger.getLogger(TestMinimizable.class.getName());

  public TestMinimizable(String name) {
    super(name);
  }

  /**
   * Verifies that parameters written through {@code setParameters} and {@code setParameter} are
   * read back unchanged through {@code getParameters} and {@code getParameter}.
   *
   * <p>Every parameter is set to its own single (flattened) index, so a value stored at the wrong
   * position is detected. On return, parameter {@code i} of {@code minable} equals {@code i}.
   *
   * @param minable the minimizable under test
   * @return {@code true}; failures are reported through JUnit assertions
   */
  public static boolean testGetSetParameters(Minimizable minable) {
    System.out.println("TestMinimizable testGetSetParameters");
    // Set all the parameters to unique values using setParameters()
    Matrix parameters = minable.getNewMatrix();
    minable.getParameters(parameters);
    for (int i = 0; i < parameters.singleSize(); i++) parameters.setSingleValue(i, (double) i);
    minable.setParameters(parameters);

    // Test to make sure those parameters are there
    parameters.setAll(0.0);
    minable.getParameters(parameters);
    for (int i = 0; i < parameters.singleSize(); i++)
      assertTrue(parameters.singleValue(i) == (double) i);

    // Set all the parameters to unique values using setParameter()
    parameters.setAll(0.0);
    minable.setParameters(parameters);
    int[] indices = new int[parameters.getNumDimensions()];
    for (int i = 0; i < parameters.singleSize(); i++) {
      parameters.singleToIndices(i, indices);
      minable.setParameter(indices, (double) i);
    }

    // Test to make sure those parameters are there
    parameters.setAll(0.0);
    minable.getParameters(parameters);
    for (int i = 0; i < parameters.singleSize(); i++) {
      assertTrue(parameters.singleValue(i) == (double) i);
    }

    // Test to make sure they are also there when we look individually
    for (int i = 0; i < parameters.singleSize(); i++) {
      parameters.singleToIndices(i, indices);
      assertTrue(minable.getParameter(indices) == (double) i);
    }

    return true;
  }

  /**
   * Checks that the analytic gradient reported by {@code minable} agrees in direction with a
   * forward-difference estimate of the cost slope, at the minimizable's current parameters.
   *
   * <p>Each parameter is perturbed in turn by {@code epsilon}, the cost is re-evaluated, and the
   * resulting slope is recorded. After the perturbation loop the minimizable's parameters are set
   * back to the values they had on entry. Both gradient vectors are then scaled to unit L2 length,
   * and the angle between them is compared against a tolerance.
   *
   * @param minable the function under test; its current parameters are the evaluation point
   * @return the angle, in radians, between the analytic and empirical gradients
   * @throws IllegalStateException if the angle is NaN or larger than the tolerance
   */
  public static double testCostAndGradientCurrentParameters(Minimizable.ByGradient minable) {
    Matrix parameters = minable.getParameters(minable.getNewMatrix());
    double cost = minable.getCost();
    // the gradient from the minimizable function
    Matrix analyticGradient = minable.getCostGradient(minable.getNewMatrix());
    // the gradient calculated from the slope of the cost
    Matrix empiricalGradient = (Matrix) analyticGradient.cloneMatrix();
    // This setting of epsilon should make the individual elements of
    // the analytical gradient and the empirical gradient equal.  This
    // simplifies the comparison of the individual dimensions of the
    // gradient and thus makes debugging easier.
    double epsilon = 0.1 / analyticGradient.twoNorm();
    double tolerance = epsilon * 5;
    System.out.println("epsilon = " + epsilon + " tolerance=" + tolerance);

    // Check each direction, perturb it, measure new cost,
    // and make sure it agrees with the gradient from minable.getCostGradient()
    for (int i = 0; i < parameters.singleSize(); i++) {
      double param = parameters.singleValue(i);
      parameters.setSingleValue(i, param + epsilon);
      minable.setParameters(parameters);
      double epsCost = minable.getCost();
      double slope = (epsCost - cost) / epsilon;
      System.out.println(
          "cost="
              + cost
              + " epsCost="
              + epsCost
              + " slope["
              + i
              + "] = "
              + slope
              + " gradient[]="
              + analyticGradient.singleValue(i));
      assert (!Double.isNaN(slope));
      logger.fine(
          "TestMinimizable checking singleIndex "
              + i
              + ": gradient slope = "
              + analyticGradient.singleValue(i)
              + ", cost+epsilon slope = "
              + slope
              + ": slope difference = "
              + Math.abs(slope - analyticGradient.singleValue(i)));
      // No negative below because the gradient points in the direction
      // of maximizing the function.
      empiricalGradient.setSingleValue(i, slope);
      parameters.setSingleValue(i, param);
    }
    // The loop only reset the local copy; push it back so the minimizable is not left with its
    // last parameter perturbed by epsilon (callers evaluate it again afterwards).
    minable.setParameters(parameters);

    // Normalize the matrices to have the same L2 length
    System.out.println("empiricalGradient.twoNorm = " + empiricalGradient.twoNorm());
    analyticGradient.timesEquals(1.0 / analyticGradient.twoNorm());
    empiricalGradient.timesEquals(1.0 / empiricalGradient.twoNorm());
    // Rounding can push the dot product of two unit vectors slightly outside [-1, 1], where
    // Math.acos returns NaN. Clamp it; a genuine NaN (e.g. from a zero norm) still propagates
    // because Math.min/Math.max return NaN for a NaN argument.
    double cosine = Math.max(-1.0, Math.min(1.0, analyticGradient.dotProduct(empiricalGradient)));
    // Angle between the two vectors, in radians (Math.acos yields a value in [0, pi])
    double angle = Math.acos(cosine);
    logger.info("TestMinimizable angle = " + angle);
    if (Double.isNaN(angle)) throw new IllegalStateException("Gradient/Cost error: angle is NaN!");
    if (angle > tolerance)
      throw new IllegalStateException("Gradient/Cost mismatch: angle=" + angle);
    return angle;
  }

  /**
   * Runs {@link #testCostAndGradientCurrentParameters} at two points: the all-zero parameter
   * vector, and the point reached from it by adding {@code -0.0001} times the cost gradient
   * evaluated at the all-zero vector.
   *
   * @param minable the function under test; its parameters are overwritten
   * @return {@code true}; mismatches are reported by the thrown {@link IllegalStateException}
   */
  public static boolean testCostAndGradient(Minimizable.ByGradient minable) {
    Matrix parameters = minable.getNewMatrix();
    parameters.setAll(0.0);
    minable.setParameters(parameters);
    testCostAndGradientCurrentParameters(minable);
    // minable is back at the all-zero point here, so this gradient is taken at the origin.
    parameters.setAll(0.0);
    Matrix delta = minable.getNewMatrix();
    minable.getCostGradient(delta);
    delta.timesEquals(-0.0001);
    parameters.plusEquals(delta);
    minable.setParameters(parameters);
    testCostAndGradientCurrentParameters(minable);
    return true;
  }

  public void testTestCostAndGradient() {
    testCostAndGradient(new Quadratic(10, 2, 3));
  }

  public static Test suite() {
    return new TestSuite(TestMinimizable.class);
  }

  protected void setUp() {}

  public static void main(String[] args) {
    junit.textui.TestRunner.run(suite());
  }
}
コード例 #2
0
/**
 * Evaluates a {@link Transducer} by per-field (segment-level) precision, recall and F1.
 *
 * <p>A segment is a maximal run of consecutive, equal output labels whose {@code toString()} is
 * one of {@code segmentTags}. A predicted segment counts as correct when a true segment with the
 * same tag, start and end exists in the same instance. A ratio whose denominator is zero (tag never
 * predicted, or never present) is reported as 0 rather than NaN.
 *
 * <p>Created: May 12, 2004
 *
 * @author <A HREF="mailto:[email protected]">[email protected]</A>
 * @version $Id: FieldF1Evaluator.java,v 1.1 2008-11-17 10:36:40 rja Exp $
 */
public class FieldF1Evaluator extends TransducerEvaluator {

  private static final Logger logger = MalletLogger.getLogger(FieldF1Evaluator.class.getName());

  /** Label strings that delimit the fields being scored; a segment's tag is an index into this. */
  String[] segmentTags;

  public FieldF1Evaluator(String[] segmentTags) {
    this.segmentTags = segmentTags;
  }

  /** Inclusive boundaries and tag index of one segment in an output sequence. */
  private static class Segment {
    final int start;
    final int end;
    final int tag;

    Segment(int t, int s, int e) {
      tag = t;
      start = s;
      end = e;
    }

    public int hashCode() {
      return 31 * (31 * tag + start) + end;
    }

    public boolean equals(Object o) {
      if (this == o) {
        return true;
      }
      // Guard against null / foreign types instead of throwing ClassCastException.
      if (!(o instanceof Segment)) {
        return false;
      }
      Segment seg = (Segment) o;
      return (start == seg.start) && (end == seg.end) && (tag == seg.tag);
    }
  }

  /**
   * Decodes every instance in {@code data} with Viterbi, and logs per-tag segment counts plus
   * precision, recall and F1 for each entry of {@code segmentTags}.
   *
   * @param transducer the model being evaluated
   * @param data instances whose data and target are {@link Sequence}s of equal length
   * @param description prefix for the logged report
   * @param viterbiOutputStream not used by this evaluator
   */
  public void test(
      Transducer transducer,
      InstanceList data,
      String description,
      PrintStream viterbiOutputStream) {
    int[] ntrue = new int[segmentTags.length];
    int[] npred = new int[segmentTags.length];
    int[] ncorr = new int[segmentTags.length];

    LabelAlphabet dict = (LabelAlphabet) transducer.getInputPipe().getTargetAlphabet();

    for (int i = 0; i < data.size(); i++) {
      Instance instance = data.getInstance(i);
      Sequence input = (Sequence) instance.getData();
      Sequence trueOutput = (Sequence) instance.getTarget();
      assert (input.size() == trueOutput.size());
      Sequence predOutput = transducer.viterbiPath(input).output();
      assert (predOutput.size() == trueOutput.size());

      List trueSegs = new ArrayList();
      List predSegs = new ArrayList();

      addSegs(trueSegs, trueOutput);
      addSegs(predSegs, predOutput);

      // Debugging aid:
      //      System.out.println("FieldF1Evaluator instance "+instance.getName ());
      //      printSegs(dict, trueSegs, "True");
      //      printSegs(dict, predSegs, "Pred");

      for (Iterator it = predSegs.iterator(); it.hasNext(); ) {
        Segment seg = (Segment) it.next();
        npred[seg.tag]++;
        if (trueSegs.contains(seg)) {
          ncorr[seg.tag]++;
        }
      }

      for (Iterator it = trueSegs.iterator(); it.hasNext(); ) {
        Segment seg = (Segment) it.next();
        ntrue[seg.tag]++;
      }
    }

    DecimalFormat f = new DecimalFormat("0.####");
    logger.info(description + " per-field F1");
    for (int tag = 0; tag < segmentTags.length; tag++) {
      double precision = ratio(ncorr[tag], npred[tag]);
      double recall = ratio(ncorr[tag], ntrue[tag]);
      double f1 = f1Score(precision, recall);
      Label name = dict.lookupLabel(segmentTags[tag]);
      logger.info(
          " segments "
              + name
              + "  true = "
              + ntrue[tag]
              + "  pred = "
              + npred[tag]
              + "  correct = "
              + ncorr[tag]);
      logger.info(
          " precision="
              + f.format(precision)
              + " recall="
              + f.format(recall)
              + " f1="
              + f.format(f1));
    }
  }

  /** Returns {@code num / den}, or 0 when {@code den} is 0 (avoids reporting NaN). */
  private static double ratio(int num, int den) {
    return den == 0 ? 0.0 : ((double) num) / den;
  }

  /** Returns the harmonic mean of precision and recall, or 0 when both are 0. */
  private static double f1Score(double precision, double recall) {
    double sum = precision + recall;
    return sum == 0.0 ? 0.0 : (2 * precision * recall) / sum;
  }

  /**
   * Appends to {@code segs} one {@link Segment} per maximal run of consecutive equal labels in
   * {@code output} whose {@code toString()} is found in {@code segmentTags}. Runs of any other
   * label are skipped. Segment boundaries are inclusive indices into {@code output}.
   */
  private void addSegs(List segs, Sequence output) {
    int size = output.size();
    int j = 0;
    while (j < size) {
      Object tag = output.get(j);
      int segtype = ArrayUtils.indexOf(segmentTags, tag.toString());
      int startidx = j;
      // Extend j to the last position of the run of labels equal to this one.
      while (j + 1 < size && output.get(j + 1).equals(tag)) {
        j++;
      }
      if (segtype > -1) {
        segs.add(new Segment(segtype, startidx, j));
      }
      j++;
    }
  }

  /** Debugging aid: prints each segment's label and inclusive bounds to standard output. */
  private void printSegs(LabelAlphabet dict, List segs, String desc) {
    System.out.println(desc + " segments:");
    for (Iterator it = segs.iterator(); it.hasNext(); ) {
      Segment seg = (Segment) it.next();
      Label lbl = dict.lookupLabel(segmentTags[seg.tag]);
      System.out.println(lbl + " [ " + seg.start + " " + seg.end + " ]");
    }
  }
}