public class TestMinimizable extends TestCase { private static Logger logger = MalletLogger.getLogger(TestMinimizable.class.getName()); public TestMinimizable(String name) { super(name); } public static boolean testGetSetParameters(Minimizable minable) { System.out.println("TestMinimizable testGetSetParameters"); // Set all the parameters to unique values using setParameters() Matrix parameters = minable.getNewMatrix(); minable.getParameters(parameters); for (int i = 0; i < parameters.singleSize(); i++) parameters.setSingleValue(i, (double) i); minable.setParameters(parameters); // Test to make sure those parameters are there parameters.setAll(0.0); minable.getParameters(parameters); for (int i = 0; i < parameters.singleSize(); i++) assertTrue(parameters.singleValue(i) == (double) i); // Set all the parameters to unique values using setParameter() parameters.setAll(0.0); minable.setParameters(parameters); int[] indices = new int[parameters.getNumDimensions()]; for (int i = 0; i < parameters.singleSize(); i++) { parameters.singleToIndices(i, indices); minable.setParameter(indices, (double) i); } // Test to make sure those parameters are there parameters.setAll(0.0); minable.getParameters(parameters); for (int i = 0; i < parameters.singleSize(); i++) { // System.out.println ("Got "+parameters.getSingle(i)+", expecting "+((double)i)); assertTrue(parameters.singleValue(i) == (double) i); } // Test to make sure they are also there when we look individually for (int i = 0; i < parameters.singleSize(); i++) { parameters.singleToIndices(i, indices); assertTrue(minable.getParameter(indices) == (double) i); } return true; } public static double testCostAndGradientCurrentParameters(Minimizable.ByGradient minable) { Matrix parameters = minable.getParameters(minable.getNewMatrix()); double cost = minable.getCost(); // the gradient from the minimizable function Matrix analyticGradient = minable.getCostGradient(minable.getNewMatrix()); // the gradient calculate from the slope of the cost Matrix empiricalGradient = (Matrix) analyticGradient.cloneMatrix(); // This setting of epsilon should make the individual elements of // the analytical gradient and the empirical gradient equal. This // simplifies the comparison of the individual dimensions of the // gradient and thus makes debugging easier. double epsilon = 0.1 / analyticGradient.twoNorm(); double tolerance = epsilon * 5; System.out.println("epsilon = " + epsilon + " tolerance=" + tolerance); // Check each direction, perturb it, measure new cost, // and make sure it agrees with the gradient from minable.getCostGradient() for (int i = 0; i < parameters.singleSize(); i++) { double param = parameters.singleValue(i); parameters.setSingleValue(i, param + epsilon); // logger.fine ("Parameters:"); parameters.print(); minable.setParameters(parameters); double epsCost = minable.getCost(); double slope = (epsCost - cost) / epsilon; System.out.println( "cost=" + cost + " epsCost=" + epsCost + " slope[" + i + "] = " + slope + " gradient[]=" + analyticGradient.singleValue(i)); assert (!Double.isNaN(slope)); logger.fine( "TestMinimizable checking singleIndex " + i + ": gradient slope = " + analyticGradient.singleValue(i) + ", cost+epsilon slope = " + slope + ": slope difference = " + Math.abs(slope - analyticGradient.singleValue(i))); // No negative below because the gradient points in the direction // of maximizing the function. empiricalGradient.setSingleValue(i, slope); parameters.setSingleValue(i, param); } // Normalize the matrices to have the same L2 length System.out.println("empiricalGradient.twoNorm = " + empiricalGradient.twoNorm()); analyticGradient.timesEquals(1.0 / analyticGradient.twoNorm()); empiricalGradient.timesEquals(1.0 / empiricalGradient.twoNorm()); // logger.info ("AnalyticGradient:"); analyticGradient.print(); // logger.info ("EmpiricalGradient:"); empiricalGradient.print(); // Return the angle between the two vectors, in radians double angle = Math.acos(analyticGradient.dotProduct(empiricalGradient)); logger.info("TestMinimizable angle = " + angle); if (Math.abs(angle) > tolerance) throw new IllegalStateException("Gradient/Cost mismatch: angle=" + angle); if (Double.isNaN(angle)) throw new IllegalStateException("Gradient/Cost error: angle is NaN!"); return angle; } public static boolean testCostAndGradient(Minimizable.ByGradient minable) { Matrix parameters = minable.getNewMatrix(); parameters.setAll(0.0); minable.setParameters(parameters); testCostAndGradientCurrentParameters(minable); parameters.setAll(0.0); Matrix delta = minable.getNewMatrix(); minable.getCostGradient(delta); delta.timesEquals(-0.0001); parameters.plusEquals(delta); minable.setParameters(parameters); testCostAndGradientCurrentParameters(minable); return true; } public void testTestCostAndGradient() { testCostAndGradient(new Quadratic(10, 2, 3)); } public static Test suite() { return new TestSuite(TestMinimizable.class); } protected void setUp() {} public static void main(String[] args) { junit.textui.TestRunner.run(suite()); } }
/** * Created: May 12, 2004 * * @author <A HREF="mailto:[email protected]>[email protected]</A> * @version $Id: FieldF1Evaluator.java,v 1.1 2008-11-17 10:36:40 rja Exp $ */ public class FieldF1Evaluator extends TransducerEvaluator { private static final Logger logger = MalletLogger.getLogger(FieldF1Evaluator.class.getName()); String[] segmentTags; public FieldF1Evaluator(String[] segmentTags) { this.segmentTags = segmentTags; } // Pair that delineates boundaries of one segment in the sequence. private static class Segment { int start; int end; int tag; Segment(int t, int s, int e) { tag = t; start = s; end = e; } public int hashCode() { return start ^ end; } public boolean equals(Object o) { Segment seg = (Segment) o; return (start == seg.start) && (end == seg.end) && (tag == seg.tag); } } public void test( Transducer transducer, InstanceList data, String description, PrintStream viterbiOutputStream) { int[] ntrue = new int[segmentTags.length]; int[] npred = new int[segmentTags.length]; int[] ncorr = new int[segmentTags.length]; LabelAlphabet dict = (LabelAlphabet) transducer.getInputPipe().getTargetAlphabet(); for (int i = 0; i < data.size(); i++) { Instance instance = data.getInstance(i); Sequence input = (Sequence) instance.getData(); Sequence trueOutput = (Sequence) instance.getTarget(); assert (input.size() == trueOutput.size()); Sequence predOutput = transducer.viterbiPath(input).output(); assert (predOutput.size() == trueOutput.size()); List trueSegs = new ArrayList(); List predSegs = new ArrayList(); addSegs(trueSegs, trueOutput); addSegs(predSegs, predOutput); // System.out.println("FieldF1Evaluator instance "+instance.getName ()); // printSegs(dict, trueSegs, "True"); // printSegs(dict, predSegs, "Pred"); for (Iterator it = predSegs.iterator(); it.hasNext(); ) { Segment seg = (Segment) it.next(); npred[seg.tag]++; if (trueSegs.contains(seg)) { ncorr[seg.tag]++; } } for (Iterator it = trueSegs.iterator(); it.hasNext(); ) { Segment seg = (Segment) it.next(); ntrue[seg.tag]++; } } DecimalFormat f = new DecimalFormat("0.####"); logger.info(description + " per-field F1"); for (int tag = 0; tag < segmentTags.length; tag++) { double precision = ((double) ncorr[tag]) / npred[tag]; double recall = ((double) ncorr[tag]) / ntrue[tag]; double f1 = (2 * precision * recall) / (precision + recall); Label name = dict.lookupLabel(segmentTags[tag]); logger.info( " segments " + name + " true = " + ntrue[tag] + " pred = " + npred[tag] + " correct = " + ncorr[tag]); logger.info( " precision=" + f.format(precision) + " recall=" + f.format(recall) + " f1=" + f.format(f1)); } } private void addSegs(List segs, Sequence output) { int segtype = -1; int startidx = -1; for (int j = 0; j < output.size(); j++) { // System.out.println("addSegs j="+j); Object tag = output.get(j); segtype = ArrayUtils.indexOf(segmentTags, tag.toString()); if (segtype > -1) { // System.out.println("...found segment "+tag); // A new segment is starting startidx = j; while (j < output.size() - 1) { // System.out.println("...inner addSegs j="+j); j++; Object nextTag = output.get(j); if (!nextTag.equals(tag)) { j--; segs.add(new Segment(segtype, startidx, j)); segtype = startidx = -1; break; } } } } // Handle end-of-sequence if (startidx > -1) { segs.add(new Segment(segtype, startidx, output.size() - 1)); } } private void printSegs(LabelAlphabet dict, List segs, String desc) { System.out.println(desc + " segments:"); for (Iterator it = segs.iterator(); it.hasNext(); ) { Segment seg = (Segment) it.next(); Label lbl = dict.lookupLabel(segmentTags[seg.tag]); System.out.println(lbl + " [ " + seg.start + " " + seg.end + " ]"); } } }