/**
   * Adjusts the feature weights of this model from the disagreement between the correct and the
   * predicted dependency heads of one example.
   *
   * <p>For every non-root token whose predicted head differs from its correct head, the weights
   * of the features on the correct edge are updated by {@code +learningRate} and the weights of
   * the features on the predicted edge are updated by {@code -learningRate}. Every touched
   * parameter is also added to {@code updatedWeights}.
   *
   * @param input the example; supplies the number of tokens and the feature codes of each
   *     (head, dependent) edge
   * @param outputCorrect the correct head of each token; a head of {@code -1} marks a token
   *     whose correct edge is missing (due to prune preprocessing) and is skipped
   * @param outputPredicted the predicted head of each token; a head of {@code -1} is logged as
   *     a warning and only the correct edge is updated for that token
   * @param learningRate the step passed to each parameter update (negated for the predicted
   *     edge)
   * @return the number of tokens for which both the correct and the predicted edge were updated
   */
  private double update(
      DPInput input, DPOutput outputCorrect, DPOutput outputPredicted, double learningRate) {
    // Token zero is the root: inference must ignore it, so both structures are expected to
    // agree on its head (it always points to itself).
    assert outputCorrect.getHead(0) == outputPredicted.getHead(0);

    // Counts one unit of loss per fully updated token.
    double mistakes = 0d;
    for (int dependent = 1; dependent < input.getNumberOfTokens(); ++dependent) {
      final int goldHead = outputCorrect.getHead(dependent);
      final int guessedHead = outputPredicted.getHead(dependent);

      // Nothing to learn when the head was predicted correctly, or when the correct edge is
      // missing altogether (head == -1, a consequence of prune preprocessing).
      if (goldHead == guessedHead || goldHead == -1) {
        continue;
      }

      // Reward the features of the edge that should have been chosen.
      for (int feature : input.getFeatures(goldHead, dependent)) {
        AveragedParameter weight = getFeatureWeightOrCreate(feature);
        weight.update(learningRate);
        updatedWeights.add(weight);
      }

      // NOTE(review): in this case the correct edge has already been rewarded above, but the
      // token is not added to the returned loss -- confirm this is intended.
      if (guessedHead == -1) {
        LOG.warn("Predicted head is -1");
        continue;
      }

      // Penalize the features of the edge that was wrongly chosen.
      for (int feature : input.getFeatures(guessedHead, dependent)) {
        AveragedParameter weight = getFeatureWeightOrCreate(feature);
        weight.update(-learningRate);
        updatedWeights.add(weight);
      }

      mistakes += 1d;
    }

    return mistakes;
  }
 /**
  * Scores the edge from {@code idxHead} to {@code idxDependent} by summing the values of the
  * parameters associated with the edge's features.
  *
  * @param input the example that supplies the feature codes of the edge
  * @param idxHead index of the head token of the edge
  * @param idxDependent index of the dependent token of the edge
  * @return the accumulated score, or {@link Double#NaN} when the input has no feature array for
  *     this edge (the edge does not exist); features without a parameter contribute nothing
  */
 @Override
 public double getEdgeScore(DPInput input, int idxHead, int idxDependent) {
   final int[] edgeFeatures = input.getFeatures(idxHead, idxDependent);
   if (edgeFeatures == null) {
     // Edge does not exist.
     return Double.NaN;
   }

   double total = 0d;
   for (int i = 0; i < edgeFeatures.length; ++i) {
     AveragedParameter weight = getFeatureWeight(edgeFeatures[i]);
     if (weight == null) {
       // No parameter is stored for this feature; it adds nothing to the score.
       continue;
     }
     total += weight.get();
   }
   return total;
 }