Exemplo n.º 1
0
  /**
   * Trains a classifier from the cheese/disease example data, round-trips it through Java
   * serialization in memory, and prints the prediction of the original and of the deserialized
   * classifier for every line of the test file so they can be compared.
   *
   * @throws IOException if serialization to or from the in-memory buffer fails
   * @throws ClassNotFoundException if the serialized classifier's class cannot be resolved
   */
  public static void demonstrateSerialization() throws IOException, ClassNotFoundException {
    System.out.println("Demonstrating working with a serialized classifier");
    ColumnDataClassifier cdc = new ColumnDataClassifier("examples/cheese2007.prop");
    Classifier<String, String> cl =
        cdc.makeClassifier(cdc.readTrainingExamples("examples/cheeseDisease.train"));

    // Exhibit serialization and deserialization working. Serialized to bytes in memory for
    // simplicity
    System.out.println();
    System.out.println();
    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    // try-with-resources so the stream is closed even if writeObject throws.
    try (ObjectOutputStream oos = new ObjectOutputStream(baos)) {
      oos.writeObject(cl);
    }
    byte[] object = baos.toByteArray();

    // Read back as the Classifier interface that was written, rather than assuming a concrete
    // LinearClassifier, so the demo works for whatever classifier the properties file selects.
    // (These bytes were produced just above; never deserialize untrusted input this way.)
    Classifier<String, String> lc;
    try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(object))) {
      lc = ErasureUtils.uncheckedCast(ois.readObject());
    }
    ColumnDataClassifier cdc2 = new ColumnDataClassifier("examples/cheese2007.prop");

    // We compare the output of the deserialized classifier lc versus the original one cl
    // For both we use a ColumnDataClassifier to convert text lines to examples
    for (String line : ObjectBank.getLineIterator("examples/cheeseDisease.test", "utf-8")) {
      Datum<String, String> d = cdc.makeDatumFromLine(line);
      Datum<String, String> d2 = cdc2.makeDatumFromLine(line);
      System.out.println(line + "  =origi=>  " + cl.classOf(d));
      System.out.println(line + "  =deser=>  " + lc.classOf(d2));
    }
  }
Exemplo n.º 2
0
 /**
  * A helper function for dumping the accuracy of the trained classifier.
  *
  * <p>Logs the dataset size, the number of examples labelled {@code CLAUSE_SPLIT} and
  * {@code CLAUSE_INTERM}, and the classifier's precision, recall and F1 for each of those two
  * labels. F1 is reported as 0 (rather than NaN) when precision and recall are both 0.
  *
  * @param classifier The classifier to evaluate.
  * @param dataset The dataset to evaluate the classifier on.
  */
 public static void dumpAccuracy(
     Classifier<ClauseSplitter.ClauseClassifierLabel, String> classifier,
     GeneralDataset<ClauseSplitter.ClauseClassifierLabel, String> dataset) {
   DecimalFormat df = new DecimalFormat("0.00%");
   log("size:         " + dataset.size());
   log(
       "split count:  "
           + countExamplesWithLabel(dataset, ClauseSplitter.ClauseClassifierLabel.CLAUSE_SPLIT));
   log(
       "interm count: "
           + countExamplesWithLabel(dataset, ClauseSplitter.ClauseClassifierLabel.CLAUSE_INTERM));
   Pair<Double, Double> pr =
       classifier.evaluatePrecisionAndRecall(
           dataset, ClauseSplitter.ClauseClassifierLabel.CLAUSE_SPLIT);
   log("p  (split):   " + df.format(pr.first));
   log("r  (split):   " + df.format(pr.second));
   log("f1 (split):   " + df.format(f1Score(pr.first, pr.second)));
   pr =
       classifier.evaluatePrecisionAndRecall(
           dataset, ClauseSplitter.ClauseClassifierLabel.CLAUSE_INTERM);
   log("p  (interm):  " + df.format(pr.first));
   log("r  (interm):  " + df.format(pr.second));
   log("f1 (interm):  " + df.format(f1Score(pr.first, pr.second)));
 }

 /** Counts the examples in {@code dataset} whose label is {@code label}. */
 private static long countExamplesWithLabel(
     GeneralDataset<ClauseSplitter.ClauseClassifierLabel, String> dataset,
     ClauseSplitter.ClauseClassifierLabel label) {
   return StreamSupport.stream(dataset.spliterator(), false)
       .filter(x -> x.label() == label)
       .count();
 }

 /** Harmonic mean of precision and recall; returns 0 instead of NaN when both are 0. */
 private static double f1Score(double precision, double recall) {
   double sum = precision + recall;
   return sum == 0.0 ? 0.0 : 2 * precision * recall / sum;
 }
Exemplo n.º 3
0
  /**
   * Demo entry point. Trains a ColumnDataClassifier on the cheese/disease training file, prints
   * the predicted class next to every line of the test file, then runs the serialization demo.
   *
   * @param args unused
   * @throws Exception if reading the example files or the serialization round trip fails
   */
  public static void main(String[] args) throws Exception {
    ColumnDataClassifier columnClassifier = new ColumnDataClassifier("examples/cheese2007.prop");
    Classifier<String, String> trained =
        columnClassifier.makeClassifier(
            columnClassifier.readTrainingExamples("examples/cheeseDisease.train"));

    for (String testLine : ObjectBank.getLineIterator("examples/cheeseDisease.test", "utf-8")) {
      // When the column values are already split out, cdc.makeDatumFromStrings(String[]) can be
      // used instead of parsing the raw line.
      Datum<String, String> datum = columnClassifier.makeDatumFromLine(testLine);
      String predicted = trained.classOf(datum);
      System.out.println(testLine + "  ==>  " + predicted);
    }

    demonstrateSerialization();
  }
  /**
   * Evaluates {@code classifier} on {@code data} and returns the F-measure.
   *
   * <p>First collects every prediction and every gold label, decoding gold ids through the
   * dataset's own label index. It then rebuilds {@code labelIndex} from the dataset's labels
   * followed by the classifier's labels, and reallocates the per-class true-positive,
   * false-positive and false-negative counts. Finally it tallies each (guess, gold) pair. The
   * class at {@code negLabel}'s index never receives counts.
   */
  public <F> double score(Classifier<L, F> classifier, GeneralDataset<L, F> data) {
    // Pass 1: the classifier's guess for every datum.
    List<L> guesses = new ArrayList<L>();
    for (int i = 0; i < data.size(); i++) {
      guesses.add(classifier.classOf(data.getRVFDatum(i)));
    }

    // Pass 2: the gold label objects, decoded with the dataset's own index.
    int[] goldIds = data.getLabelsArray();
    List<L> golds = new ArrayList<L>();
    for (int i = 0; i < data.size(); i++) {
      golds.add(data.labelIndex.get(goldIds[i]));
    }

    // An index spanning every label seen in the data or known to the classifier.
    labelIndex = new HashIndex<L>();
    labelIndex.addAll(data.labelIndex().objectsList());
    labelIndex.addAll(classifier.labels());

    int classCount = labelIndex.size();
    tpCount = new int[classCount];
    fpCount = new int[classCount];
    fnCount = new int[classCount];
    negIndex = labelIndex.indexOf(negLabel);

    for (int k = 0; k < guesses.size(); k++) {
      int guessIdx = labelIndex.indexOf(guesses.get(k));
      int goldIdx = labelIndex.indexOf(golds.get(k));

      if (guessIdx == goldIdx) {
        if (guessIdx != negIndex) {
          tpCount[guessIdx]++;
        }
        continue;
      }
      // Mismatch: the guess is a false positive and the gold label a false negative.
      if (guessIdx != negIndex) {
        fpCount[guessIdx]++;
      }
      if (goldIdx != negIndex) {
        fnCount[goldIdx]++;
      }
    }

    return getFMeasure();
  }
  /**
   * Evaluates {@code classifier} on {@code data} and returns the F-measure.
   *
   * <p>Rebuilds {@code labelIndex} from the classifier's labels followed by the dataset's labels,
   * clears the counts, and records each (guess, gold) pair via {@code addGuess}. It then
   * finalizes the counts before computing the F-measure.
   */
  public <F> double score(Classifier<L, F> classifier, GeneralDataset<L, F> data) {
    labelIndex = new HashIndex<L>();
    labelIndex.addAll(classifier.labels());
    labelIndex.addAll(data.labelIndex.objectsList());
    clearCounts();
    int[] labelsArr = data.getLabelsArray();
    for (int i = 0; i < data.size(); i++) {
      Datum<L, F> d = data.getRVFDatum(i);
      L guess = classifier.classOf(d);
      // labelsArr holds ids from the *dataset's* label index. The freshly built labelIndex lists
      // the classifier's labels first and may number labels differently, so decode through
      // data.labelIndex to get the correct gold label.
      addGuess(guess, data.labelIndex.get(labelsArr[i]));
    }
    finalizeCounts();

    return getFMeasure();
  }
  /**
   * The core implementation of the search.
   *
   * <p>This is a best-first search over {@link State}s, ordered by the log probability
   * accumulated along the path. Every popped state marked as done is reported to
   * {@code candidateFragments}. Each outgoing dependency edge of the state's root word is then
   * expanded with every applicable action. Children whose dependent word has already been
   * expanded are not pushed again.
   *
   * @param root The root word to search from. Traditionally, this is the root of the sentence.
   * @param candidateFragments The callback for the resulting sentence fragments. This is a
   *     predicate of a triple of values. The return value of the predicate determines whether we
   *     should continue searching. The triple is a triple of
   *     <ol>
   *       <li>The log probability of the sentence fragment, according to the featurizer and the
   *           weights
   *       <li>The features along the path to this fragment. The last element of this is the
   *           features from the most recent step.
   *       <li>The sentence fragment. Because it is relatively expensive to compute the resulting
   *           tree, this is returned as a lazy {@link Supplier}.
   *     </ol>
   *
   * @param classifier The classifier for whether an arc should be on the path to a clause split, a
   *     clause split itself, or neither.
   * @param hardCodedSplits Forced action orderings, keyed by dependency relation. If there is no
   *     exact key for a relation containing ':', the key {@code "<prefix>:*"} is tried. For a
   *     matching arc, the first action (in the order produced by {@code orderActions}) whose
   *     prerequisites are met and which yields a state is taken as a clause split with log
   *     probability 0. The remaining actions for that arc are skipped.
   * @param featurizer The featurizer to use. Make sure this matches the weights!
   * @param actionSpace The action space we are allowed to take. Each action defines a means of
   *     splitting a clause on a dependency boundary.
   * @param maxTicks The maximum number of states to pop from the fringe. When it is exceeded,
   *     the search returns without reporting further fragments.
   */
  protected void search(
      // The root to search from
      IndexedWord root,
      // The output specs
      final Predicate<Triple<Double, List<Counter<String>>, Supplier<SentenceFragment>>>
          candidateFragments,
      // The learning specs
      final Classifier<ClauseSplitter.ClauseClassifierLabel, String> classifier,
      Map<String, ? extends List<String>> hardCodedSplits,
      final Function<Triple<State, Action, State>, Counter<String>> featurizer,
      final Collection<Action> actionSpace,
      final int maxTicks) {
    // (the fringe)
    PriorityQueue<Pair<State, List<Counter<String>>>> fringe = new FixedPrioritiesPriorityQueue<>();
    // (avoid duplicate work)
    Set<IndexedWord> seenWords = new HashSet<>();

    State firstState =
        new State(null, null, -9000, null, x -> {}, true); // First state is implicitly "done"
    fringe.add(Pair.makePair(firstState, new ArrayList<>(0)), -0.0);
    int ticks = 0;

    while (!fringe.isEmpty()) {
      if (++ticks > maxTicks) {
        // Search budget exhausted; stop without reporting any further fragments.
        return;
      }
      // Useful variables
      double logProbSoFar = fringe.getPriority();
      assert logProbSoFar <= 0.0;
      Pair<State, List<Counter<String>>> lastStatePair = fringe.removeFirst();
      State lastState = lastStatePair.first;
      List<Counter<String>> featuresSoFar = lastStatePair.second;
      IndexedWord rootWord = lastState.edge == null ? root : lastState.edge.getDependent();

      // Register thunk
      if (lastState.isDone) {
        if (!candidateFragments.test(
            Triple.makeTriple(
                logProbSoFar,
                featuresSoFar,
                () -> {
                  SemanticGraph copy = new SemanticGraph(tree);
                  lastState
                      .thunk
                      .andThen(
                          x -> {
                            // Add the extra edges back in, if they don't break the tree-ness of the
                            // extraction
                            for (IndexedWord newTreeRoot : x.getRoots()) {
                              if (newTreeRoot != null) { // what a strange thing to have happen...
                                for (SemanticGraphEdge extraEdge :
                                    extraEdgesByGovernor.get(newTreeRoot)) {
                                  assert Util.isTree(x);
                                  //noinspection unchecked
                                  addSubtree(
                                      x,
                                      newTreeRoot,
                                      extraEdge.getRelation().toString(),
                                      tree,
                                      extraEdge.getDependent(),
                                      tree.getIncomingEdgesSorted(newTreeRoot));
                                  assert Util.isTree(x);
                                }
                              }
                            }
                          })
                      .accept(copy);
                  return new SentenceFragment(copy, assumedTruth, false);
                }))) {
          // The callback asked us to stop searching.
          break;
        }
      }

      // Find relevant auxiliary terms (if several match, the last one iterated wins)
      SemanticGraphEdge subjOrNull = null;
      SemanticGraphEdge objOrNull = null;
      for (SemanticGraphEdge auxEdge : tree.outgoingEdgeIterable(rootWord)) {
        String relString = auxEdge.getRelation().toString();
        if (relString.contains("obj")) {
          objOrNull = auxEdge;
        } else if (relString.contains("subj")) {
          subjOrNull = auxEdge;
        }
      }

      // Iterate over children
      // For each outgoing edge...
      for (SemanticGraphEdge outgoingEdge : tree.outgoingEdgeIterable(rootWord)) {
        // Computed once per edge; reused by the checks below.
        String outgoingEdgeRelation = outgoingEdge.getRelation().toString();
        // Prohibit indirect speech verbs from splitting off clauses
        // (e.g., 'said', 'think')
        // This fires if the governor is an indirect speech verb, and the outgoing edge is a ccomp
        if (outgoingEdgeRelation.equals("ccomp")
            && ((outgoingEdge.getGovernor().lemma() != null
                    && INDIRECT_SPEECH_LEMMAS.contains(outgoingEdge.getGovernor().lemma()))
                || INDIRECT_SPEECH_LEMMAS.contains(outgoingEdge.getGovernor().word()))) {
          continue;
        }
        // Look up a forced action order, falling back from "prefix:suffix" to "prefix:*"
        List<String> forcedArcOrder = hardCodedSplits.get(outgoingEdgeRelation);
        if (forcedArcOrder == null && outgoingEdgeRelation.contains(":")) {
          forcedArcOrder =
              hardCodedSplits.get(
                  outgoingEdgeRelation.substring(0, outgoingEdgeRelation.indexOf(":")) + ":*");
        }
        boolean doneForcedArc = false;
        // For each action...
        for (Action action :
            (forcedArcOrder == null ? actionSpace : orderActions(actionSpace, forcedArcOrder))) {
          // Check the prerequisite
          if (!action.prerequisitesMet(tree, outgoingEdge)) {
            continue;
          }
          if (forcedArcOrder != null && doneForcedArc) {
            break;
          }
          // 1. Compute the child state
          Optional<State> candidate =
              action.applyTo(tree, lastState, outgoingEdge, subjOrNull, objOrNull);
          if (candidate.isPresent()) {
            double logProbability;
            ClauseClassifierLabel bestLabel;
            Counter<String> features =
                featurizer.apply(Triple.makeTriple(lastState, action, candidate.get()));
            if (forcedArcOrder != null && !doneForcedArc) {
              logProbability = 0.0;
              bestLabel = ClauseClassifierLabel.CLAUSE_SPLIT;
              doneForcedArc = true;
            } else if (features.containsKey("__undocumented_junit_no_classifier")) {
              logProbability = Double.NEGATIVE_INFINITY;
              bestLabel = ClauseClassifierLabel.CLAUSE_INTERM;
            } else {
              Counter<ClauseClassifierLabel> scores = classifier.scoresOf(new RVFDatum<>(features));
              if (scores.size() > 0) {
                Counters.logNormalizeInPlace(scores);
              }
              if ("nsubj".equals(outgoingEdgeRelation) || "dobj".equals(outgoingEdgeRelation)) {
                scores.remove(
                    ClauseClassifierLabel.NOT_A_CLAUSE); // Always at least yield on nsubj and dobj
              }
              logProbability = Counters.max(scores, Double.NEGATIVE_INFINITY);
              bestLabel = Counters.argmax(scores, (x, y) -> 0, ClauseClassifierLabel.CLAUSE_SPLIT);
            }

            if (bestLabel != ClauseClassifierLabel.NOT_A_CLAUSE) {
              // Features along the child's path: the parent's path plus this step's features.
              // A plain list, not a double-brace anonymous subclass (which would capture the
              // enclosing instance and define a new class).
              List<Counter<String>> childFeatures = new ArrayList<>(featuresSoFar.size() + 1);
              childFeatures.addAll(featuresSoFar);
              childFeatures.add(features);
              Pair<State, List<Counter<String>>> childState =
                  Pair.makePair(candidate.get().withIsDone(bestLabel), childFeatures);
              // 2. Register the child state
              if (!seenWords.contains(childState.first.edge.getDependent())) {
                fringe.add(childState, logProbability);
              }
            }
          }
        }
      }

      seenWords.add(rootWord);
    }
  }