コード例 #1
0
  /**
   * Updates this model using the differences between the correct output and the predicted output,
   * both given as arguments.
   *
   * <p>For every token except the root (index 0) whose predicted head differs from its correct
   * head, each feature weight of the correct edge {@code (correctHead, token)} is updated by
   * {@code +learningRate}. Each feature weight of the predicted edge {@code (predictedHead,
   * token)} is updated by {@code -learningRate}. Tokens whose correct head is {@code -1} (edge
   * removed by pruning preprocessing) are skipped. When the predicted head is {@code -1}, only the
   * correct edge is updated and a warning is logged. Every weight touched is recorded in {@code
   * updatedWeights}.
   *
   * @param input the example; edge features are read with {@code getFeatures(head, token)}
   * @param outputCorrect the gold-standard output giving the correct head of each token
   * @param outputPredicted the output predicted by the current model
   * @param learningRate magnitude of the update applied to each touched feature weight
   * @return the number of tokens whose head was mispredicted (tokens with a missing correct head
   *     are not counted)
   */
  private double update(
      DPInput input, DPOutput outputCorrect, DPOutput outputPredicted, double learningRate) {
    /*
     * The root token (zero) must always be ignored during the inference, so
     * it is always correctly classified (its head must always point to
     * itself).
     */
    assert outputCorrect.getHead(0) == outputPredicted.getHead(0);

    // Number of mispredicted tokens in this example.
    double loss = 0d;
    for (int idxTkn = 1; idxTkn < input.getNumberOfTokens(); ++idxTkn) {
      int idxCorrectHead = outputCorrect.getHead(idxTkn);
      int idxPredictedHead = outputPredicted.getHead(idxTkn);
      if (idxCorrectHead == idxPredictedHead) {
        // Correctly predicted head.
        continue;
      }

      if (idxCorrectHead == -1) {
        // Skip tokens with a missing CORRECT edge (this is due to prune preprocessing).
        continue;
      }

      // Misclassified head. Count it here, before any early exit, so that a
      // token with no predicted head (-1) is also reflected in the loss.
      loss += 1d;

      // Increment missed (correct) edge weights.
      for (int ftr : input.getFeatures(idxCorrectHead, idxTkn)) {
        AveragedParameter param = getFeatureWeightOrCreate(ftr);
        param.update(learningRate);
        updatedWeights.add(param);
      }

      if (idxPredictedHead == -1) {
        // There is no predicted edge whose weights could be decremented.
        LOG.warn("Predicted head is -1");
        continue;
      }

      // Decrement mispredicted edge weights.
      for (int ftr : input.getFeatures(idxPredictedHead, idxTkn)) {
        AveragedParameter param = getFeatureWeightOrCreate(ftr);
        param.update(-learningRate);
        updatedWeights.add(param);
      }
    }

    return loss;
  }
  /**
   * Applies an update of {@code value} to the parameter associated with the given feature code.
   *
   * <p>If no parameter is stored in {@code parameters} for {@code code} yet, a new {@code
   * AveragedParameter} is created and stored there first. The updated parameter is then added to
   * {@code updatedParameters}, which keeps track of the parameters updated within the current
   * example.
   *
   * <p>NOTE(review): an earlier version of this comment described updating the inverted-index
   * active-feature lists when a parameter is first created; this method does not do that.
   *
   * @param code the feature code identifying the parameter
   * @param value the amount passed to {@code AveragedParameter.update}
   */
  protected void updateFeatureParam(int code, double value) {
    AveragedParameter param = parameters.get(code);
    if (param == null) {
      // First time this feature is updated: create and register its parameter.
      param = new AveragedParameter();
      parameters.put(code, param);
    }

    // Update parameter value.
    param.update(value);

    // Keep track of the updated parameter within this example.
    updatedParameters.add(param);
  }