/**
 * Writes a search direction into {@code dir}, starting from the gradient {@code fg} and
 * correcting it with the stored history in {@code sList}, {@code yList} and {@code roList}.
 *
 * <p>The computation runs in three stages: a backward sweep over the history (newest entry
 * first), a scalar rescaling, and a forward sweep (oldest entry first). The result is negated
 * before returning. NOTE(review): this has the shape of the L-BFGS two-loop recursion with an
 * extra {@code cPosDef} factor on the forward sweep; confirm against the minimizer's
 * documentation.
 *
 * @param dir output buffer; its previous contents are overwritten
 * @param fg gradient to start from; the first {@code fg.length} entries are copied into
 *     {@code dir}
 * @throws SQNMinimizer.SurpriseConvergence if the most recent {@code y} vector has a zero inner
 *     product with itself, which would make the scaling factor undefined
 */
private void computeDir(double[] dir, double[] fg) throws SQNMinimizer.SurpriseConvergence {
    System.arraycopy(fg, 0, dir, 0, fg.length);

    int mmm = sList.size();
    // as[i] is saved in the backward sweep and reused in the forward sweep.
    double[] as = new double[mmm];

    // Backward sweep: newest history entry first.
    // presumably plusAndConstMult(a, b, c, d) stores a + c * b into d; verify against the helper
    for (int i = mmm - 1; i >= 0; i--) {
      as[i] = roList.get(i) * ArrayMath.innerProduct(sList.get(i), dir);
      plusAndConstMult(dir, yList.get(i), -as[i], dir);
    }

    // multiply by hessian approximation
    if (mmm != 0) {
      double[] y = yList.get(mmm - 1);
      double yDotY = ArrayMath.innerProduct(y, y);
      if (yDotY == 0) {
        // gamma below would divide by zero
        throw new SQNMinimizer.SurpriseConvergence("Y is 0!!");
      }
      double gamma = ArrayMath.innerProduct(sList.get(mmm - 1), y) / yDotY;
      ArrayMath.multiplyInPlace(dir, gamma);
    } else {
      // This is a safety feature preventing too large of an initial step (see Yu Schraudolph
      // Gunter)
      ArrayMath.multiplyInPlace(dir, epsilon);
    }

    // Forward sweep: oldest history entry first.
    for (int i = 0; i < mmm; i++) {
      double b = roList.get(i) * ArrayMath.innerProduct(yList.get(i), dir);
      plusAndConstMult(dir, sList.get(i), cPosDef * as[i] - b, dir);
    }

    // Flip the sign of the accumulated vector before handing it back.
    ArrayMath.multiplyInPlace(dir, -1);
  }
Exemple #2
0
 /**
  * Evaluates the gradient at the given flat coefficient vector.
  *
  * <p>The coefficients are loaded into {@code model}, {@code model.computeGradient} is invoked
  * once per sentence in {@code mSentences} with a shared accumulator array, the accumulator is
  * multiplied by -1, and finally {@code addL2regularizerGradient} is applied to it.
  *
  * @param flatCoefs flat coefficient vector at which to evaluate
  * @return a newly allocated array of length {@code model.flatIDsize()}
  */
 @Override
 public double[] derivativeAt(double[] flatCoefs) {
   final double[] gradient = new double[model.flatIDsize()];
   model.setCoefsFromFlat(flatCoefs);

   // Every sentence adds its contribution into the same accumulator.
   for (ModelSentence sentence : mSentences) {
     model.computeGradient(sentence, gradient);
   }

   // Sign flip first, regularizer second: the regularizer term is not negated.
   ArrayMath.multiplyInPlace(gradient, -1);
   addL2regularizerGradient(gradient, flatCoefs);
   return gradient;
 }
  /**
   * Fills the {@code value} and {@code derivative} fields for the parameter vector {@code theta}.
   *
   * <p>Outline:
   *
   * <ol>
   *   <li>{@code theta} is loaded into {@code dvModel}.
   *   <li>Every tree in {@code trainingBatch} is pushed through a {@link MulticoreWrapper}
   *       running a {@code ScoringProcessor}, which yields a (gold tree, best tree) pair.
   *   <li>For each pair whose best-tree score exceeds the gold-tree score by more than 0.00001,
   *       the score gap is added to the cost and {@code backpropDerivative} is run on both
   *       trees, accumulating into separate gold and best matrices.
   *   <li>The gradient is (best accumulators) minus (gold accumulators), flattened with
   *       {@code NeuralUtils.paramsToVector}.
   *   <li>Cost and gradient are multiplied by {@code 1.0 / trainingBatch.size()}.
   *   <li>A squared-parameter penalty weighted by {@code op.trainOptions.regCost} is added to
   *       the cost, and the correspondingly scaled parameters are added to the gradient.
   * </ol>
   *
   * <p>A one-line summary for every tree is printed to {@code System.err}.
   *
   * @param theta flat parameter vector; its length also fixes the length of the gradient
   */
  public void calculate(double[] theta) {
    dvModel.vectorToParams(theta);

    double localValue = 0.0;
    double[] localDerivative = new double[theta.length];

    // Two parallel sets of accumulators: "gold" for the correct trees, "best" for the
    // highest-scoring hypotheses. Tree maps are used throughout, as in the flattening step below.
    TwoDimensionalMap<String, String, SimpleMatrix> goldBinaryW = TwoDimensionalMap.treeMap();
    TwoDimensionalMap<String, String, SimpleMatrix> bestBinaryW = TwoDimensionalMap.treeMap();
    TwoDimensionalMap<String, String, SimpleMatrix> goldBinaryScore = TwoDimensionalMap.treeMap();
    TwoDimensionalMap<String, String, SimpleMatrix> bestBinaryScore = TwoDimensionalMap.treeMap();
    Map<String, SimpleMatrix> goldUnaryW = new TreeMap<>();
    Map<String, SimpleMatrix> bestUnaryW = new TreeMap<>();
    Map<String, SimpleMatrix> goldUnaryScore = new TreeMap<>();
    Map<String, SimpleMatrix> bestUnaryScore = new TreeMap<>();
    Map<String, SimpleMatrix> goldWordVectors = new TreeMap<>();
    Map<String, SimpleMatrix> bestWordVectors = new TreeMap<>();

    // One fresh matrix per model parameter: same shape as the transform for the W accumulators,
    // a single row of numRows columns for the score accumulators.
    for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> transform :
        dvModel.binaryTransform) {
      String firstKey = transform.getFirstKey();
      String secondKey = transform.getSecondKey();
      SimpleMatrix shape = transform.getValue();
      int rows = shape.numRows();
      int cols = shape.numCols();
      goldBinaryW.put(firstKey, secondKey, new SimpleMatrix(rows, cols));
      bestBinaryW.put(firstKey, secondKey, new SimpleMatrix(rows, cols));
      goldBinaryScore.put(firstKey, secondKey, new SimpleMatrix(1, rows));
      bestBinaryScore.put(firstKey, secondKey, new SimpleMatrix(1, rows));
    }
    for (Map.Entry<String, SimpleMatrix> transform : dvModel.unaryTransform.entrySet()) {
      String key = transform.getKey();
      SimpleMatrix shape = transform.getValue();
      int rows = shape.numRows();
      int cols = shape.numCols();
      goldUnaryW.put(key, new SimpleMatrix(rows, cols));
      bestUnaryW.put(key, new SimpleMatrix(rows, cols));
      goldUnaryScore.put(key, new SimpleMatrix(1, rows));
      bestUnaryScore.put(key, new SimpleMatrix(1, rows));
    }
    // Word-vector accumulators are only populated when word vectors are being trained.
    if (op.trainOptions.trainWordVectors) {
      for (Map.Entry<String, SimpleMatrix> wordVector : dvModel.wordVectors.entrySet()) {
        String word = wordVector.getKey();
        SimpleMatrix shape = wordVector.getValue();
        int rows = shape.numRows();
        int cols = shape.numCols();
        goldWordVectors.put(word, new SimpleMatrix(rows, cols));
        bestWordVectors.put(word, new SimpleMatrix(rows, cols));
      }
    }

    // Some optimization methods print out a line without an end, so our
    // debugging statements are misaligned
    Timing scoreTiming = new Timing();
    scoreTiming.doing("Scoring trees");
    MulticoreWrapper<Tree, Pair<DeepTree, DeepTree>> scorer =
        new MulticoreWrapper<>(op.trainOptions.trainingThreads, new ScoringProcessor());
    for (Tree tree : trainingBatch) {
      scorer.put(tree);
    }
    scorer.join();
    scoreTiming.done();

    // Drain the results; treeNum is only used to label the debug output.
    for (int treeNum = 0; scorer.peek(); ++treeNum) {
      Pair<DeepTree, DeepTree> scored = scorer.poll();
      DeepTree goldTree = scored.first;
      DeepTree bestTree = scored.second;
      double goldScore = goldTree.getScore();
      double bestScore = bestTree.getScore();

      // A tree needs no update when the gold tree already scores at least as high as the best
      // hypothesis, or the two scores agree to within 0.00001.
      boolean isDone = goldScore > bestScore || Math.abs(bestScore - goldScore) <= 0.00001;
      System.err.println(
          String.format(
              "Tree %6d Highest tree: %12.4f Correct tree: %12.4f %s",
              treeNum, bestScore, goldScore, isDone ? "done" : ""));
      if (isDone) {
        continue;
      }

      // The cost contribution is how far the best hypothesis outscores the gold tree.
      localValue += bestScore - goldScore;

      // get the context words for this tree - should be the same
      // for either goldTree or bestTree
      List<String> words = getContextWords(goldTree.getTree());

      // The derivatives affected by this tree are only based on the
      // nodes present in this tree, eg not all matrix derivatives
      // will be affected by this tree
      backpropDerivative(
          goldTree.getTree(),
          words,
          goldTree.getVectors(),
          goldBinaryW,
          goldUnaryW,
          goldBinaryScore,
          goldUnaryScore,
          goldWordVectors);

      backpropDerivative(
          bestTree.getTree(),
          words,
          bestTree.getVectors(),
          bestBinaryW,
          bestUnaryW,
          bestBinaryScore,
          bestUnaryScore,
          bestWordVectors);
    }

    // Flatten both accumulator sets in the same component order; the word-vector component is
    // appended only when word vectors are trained.
    double[] goldDerivative;
    double[] bestDerivative;
    if (op.trainOptions.trainWordVectors) {
      goldDerivative =
          NeuralUtils.paramsToVector(
              theta.length,
              goldBinaryW.valueIterator(),
              goldUnaryW.values().iterator(),
              goldBinaryScore.valueIterator(),
              goldUnaryScore.values().iterator(),
              goldWordVectors.values().iterator());

      bestDerivative =
          NeuralUtils.paramsToVector(
              theta.length,
              bestBinaryW.valueIterator(),
              bestUnaryW.values().iterator(),
              bestBinaryScore.valueIterator(),
              bestUnaryScore.values().iterator(),
              bestWordVectors.values().iterator());
    } else {
      goldDerivative =
          NeuralUtils.paramsToVector(
              theta.length,
              goldBinaryW.valueIterator(),
              goldUnaryW.values().iterator(),
              goldBinaryScore.valueIterator(),
              goldUnaryScore.values().iterator());

      bestDerivative =
          NeuralUtils.paramsToVector(
              theta.length,
              bestBinaryW.valueIterator(),
              bestUnaryW.values().iterator(),
              bestBinaryScore.valueIterator(),
              bestUnaryScore.values().iterator());
    }

    // highest - correct, the same sign convention as the cost accumulated above
    for (int i = 0; i < goldDerivative.length; i++) {
      localDerivative[i] = bestDerivative[i] - goldDerivative[i];
    }

    // TODO: this is where we would combine multiple costs if we had parallelized the calculation

    // normalizing by training batch size
    double batchScale = 1.0 / trainingBatch.size();
    value = batchScale * localValue;
    derivative = localDerivative;
    ArrayMath.multiplyInPlace(derivative, batchScale);

    // add regularization to cost:
    double[] currentParams = dvModel.paramsToVector();
    double sumOfSquares = 0;
    for (double param : currentParams) {
      sumOfSquares += param * param;
    }
    value += op.trainOptions.regCost * 0.5 * sumOfSquares;

    // add regularization to gradient
    ArrayMath.multiplyInPlace(currentParams, op.trainOptions.regCost);
    ArrayMath.pairwiseAddInPlace(derivative, currentParams);
  }