예제 #1
0
 /**
  * Packs all model parameters into a single flattened matrix.
  *
  * <p>Parameter groups are handed to the flattener in this order: binary transform matrices,
  * binary classification matrices, binary tensors, unary classification matrices, feature
  * vectors. {@code getValueGradient} flattens its derivative containers in the same group order.
  * Whether the element order inside each group also matches presumably depends on the
  * parameter fields being sorted (tree-backed) maps like the derivative containers. TODO confirm.
  *
  * @return the matrix built by {@code MatrixUtil.toFlattenedFloat}, which is given
  *     {@code getNumParameters()} as its first argument (assumed to be the total length)
  */
 public FloatMatrix getParameters() {
   FloatMatrix flattened =
       MatrixUtil.toFlattenedFloat(
           this.getNumParameters(),
           this.binaryTransform.values().iterator(),
           this.binaryClassification.values().iterator(),
           this.binaryFloatTensors.values().iterator(),
           this.unaryClassification.values().iterator(),
           this.featureVectors.values().iterator());
   return flattened;
 }
예제 #2
0
  /**
   * Computes the training objective and its gradient over the current batch of training trees.
   *
   * <p>Each tree in {@code trainingTrees} is re-rooted in a new {@link Tree} and forward
   * propagated in parallel. The propagated trees are then back propagated in parallel. This
   * passes shared derivative containers to {@code backpropDerivativesAndError} and sums each
   * tree's {@code errorSum()}. The summed error is scaled by {@code 1 / trainingTrees.size()}.
   * The values returned by the {@code scaleAndRegularize*} helpers are added, and the total is
   * stored in the {@code value} field.
   *
   * <p>The derivative groups are flattened in the same group order as {@code getParameters()}:
   * binary transform, binary classification, binary tensors, unary classification and word
   * vectors. The flattened gradient is then multiplied in place by the learning rates obtained
   * from {@code paramAdaGrad}, which is created lazily on the first call.
   *
   * @param iterations not read by this method
   * @return the flattened gradient after applying the AdaGrad learning rates
   * @throws IllegalStateException if {@code trainingTrees} is empty. Otherwise the scale would be
   *     {@code 1 / 0 = Infinity}, the objective would become NaN, and a gradient computed with an
   *     infinite scale would be handed to {@code paramAdaGrad}
   */
  public FloatMatrix getValueGradient(int iterations) {
    // An empty batch would make scale = 1 / 0 = Infinity and value = 0 * Infinity = NaN below.
    // Fail before doing any work and before paramAdaGrad sees the resulting gradient.
    if (trainingTrees.isEmpty()) {
      throw new IllegalStateException("getValueGradient requires at least one training tree");
    }

    // Derivative containers: one freshly allocated entry shaped like each parameter. They are
    // tree-backed so they stay in a canonical sorted order when flattened below.
    // binaryTD stands for Transform Derivatives
    final MultiDimensionalMap<String, String, FloatMatrix> binaryTD =
        newBinaryMatricesLike(binaryTransform);

    // binaryCD stands for Classification Derivatives; left empty when classification is combined
    final MultiDimensionalMap<String, String, FloatMatrix> binaryCD;
    if (combineClassification) {
      binaryCD = MultiDimensionalMap.newTreeBackedMap();
    } else {
      binaryCD = newBinaryMatricesLike(binaryClassification);
    }

    // the derivatives of the FloatTensors for the binary nodes; left empty when tensors are off
    final MultiDimensionalMap<String, String, FloatTensor> binaryFloatTensorTD;
    if (useFloatTensors) {
      binaryFloatTensorTD = newBinaryTensorsLike(binaryFloatTensors);
    } else {
      binaryFloatTensorTD = MultiDimensionalMap.newTreeBackedMap();
    }

    // unaryCD stands for Classification Derivatives
    final Map<String, FloatMatrix> unaryCD = newMatricesLike(unaryClassification);

    // word vector derivatives
    final Map<String, FloatMatrix> wordVectorD = newMatricesLike(featureVectors);

    // Phase 1: forward propagate a re-rooted copy of every training tree in parallel.
    // CopyOnWriteArrayList because the workers append to it concurrently.
    // NOTE(review): the children list is copied but its elements are the original child nodes,
    // so presumably only the root is new. Confirm that forward propagation mutating the shared
    // children is intended.
    final List<Tree> forwardPropTrees = new CopyOnWriteArrayList<>();
    Parallelization.iterateInParallel(
        trainingTrees,
        new Parallelization.RunnableWithParams<Tree>() {

          @Override
          public void run(Tree currentItem, Object[] args) {
            Tree trainingTree = new Tree(currentItem);
            trainingTree.connect(new ArrayList<>(currentItem.children()));
            // this will attach the error vectors and the node vectors
            // to each node in the tree
            forwardPropagateTree(trainingTree);
            forwardPropTrees.add(trainingTree);
          }
        },
        rnTnActorSystem);

    // Phase 2: back propagate every tree in parallel and sum the per-tree errors.
    // TODO: we may find a big speedup by separating the derivatives and then summing
    // NOTE(review): all workers pass the same derivative containers to
    // backpropDerivativesAndError. Confirm that its in-place updates are synchronized.
    final AtomicDouble error = new AtomicDouble(0);
    Parallelization.iterateInParallel(
        forwardPropTrees,
        new Parallelization.RunnableWithParams<Tree>() {

          @Override
          public void run(Tree currentItem, Object[] args) {
            backpropDerivativesAndError(
                currentItem, binaryTD, binaryCD, binaryFloatTensorTD, unaryCD, wordVectorD);
            error.addAndGet(currentItem.errorSum());
          }
        },
        new Parallelization.RunnableWithParams<Tree>() {

          @Override
          public void run(Tree currentItem, Object[] args) {
            // intentionally empty
          }
        },
        rnTnActorSystem,
        new Object[] {binaryTD, binaryCD, binaryFloatTensorTD, unaryCD, wordVectorD});

    // scale the error by the number of sentences so that the
    // regularization isn't drowned out for large training batches
    float scale = 1.0f / trainingTrees.size();
    value = error.floatValue() * scale;

    // Add the value returned for each parameter group. These helpers receive the derivatives
    // and the scale, so they presumably also scale the derivatives in place (confirm).
    value += scaleAndRegularize(binaryTD, binaryTransform, scale, regTransformMatrix);
    value += scaleAndRegularize(binaryCD, binaryClassification, scale, regClassification);
    value +=
        scaleAndRegularizeFloatTensor(
            binaryFloatTensorTD, binaryFloatTensors, scale, regTransformFloatTensor);
    value += scaleAndRegularize(unaryCD, unaryClassification, scale, regClassification);
    value += scaleAndRegularize(wordVectorD, featureVectors, scale, regWordVector);

    FloatMatrix derivative =
        MatrixUtil.toFlattenedFloat(
            getNumParameters(),
            binaryTD.values().iterator(),
            binaryCD.values().iterator(),
            binaryFloatTensorTD.values().iterator(),
            unaryCD.values().iterator(),
            wordVectorD.values().iterator());

    // lazily create the AdaGrad state, sized to the flattened gradient's column count
    if (paramAdaGrad == null) {
      paramAdaGrad = new AdaGradFloat(1, derivative.columns);
    }

    derivative.muli(paramAdaGrad.getLearningRates(derivative));

    return derivative;
  }

  /**
   * Allocates a tree-backed map with, for every key pair in {@code params}, a new
   * {@code FloatMatrix} of the same rows and columns (used as a derivative accumulator).
   */
  private static MultiDimensionalMap<String, String, FloatMatrix> newBinaryMatricesLike(
      MultiDimensionalMap<String, String, FloatMatrix> params) {
    MultiDimensionalMap<String, String, FloatMatrix> result = MultiDimensionalMap.newTreeBackedMap();
    for (MultiDimensionalMap.Entry<String, String, FloatMatrix> entry : params.entrySet()) {
      FloatMatrix param = entry.getValue();
      result.put(
          entry.getFirstKey(), entry.getSecondKey(), new FloatMatrix(param.rows, param.columns));
    }
    return result;
  }

  /**
   * Allocates a tree-backed map with, for every key pair in {@code params}, a new
   * {@code FloatTensor} of the same rows, columns and slices (used as a derivative accumulator).
   */
  private static MultiDimensionalMap<String, String, FloatTensor> newBinaryTensorsLike(
      MultiDimensionalMap<String, String, FloatTensor> params) {
    MultiDimensionalMap<String, String, FloatTensor> result = MultiDimensionalMap.newTreeBackedMap();
    for (MultiDimensionalMap.Entry<String, String, FloatTensor> entry : params.entrySet()) {
      FloatTensor param = entry.getValue();
      result.put(
          entry.getFirstKey(),
          entry.getSecondKey(),
          new FloatTensor(param.rows(), param.columns, param.slices()));
    }
    return result;
  }

  /**
   * Allocates a {@link TreeMap} with, for every key in {@code params}, a new {@code FloatMatrix}
   * of the same rows and columns (used as a derivative accumulator).
   */
  private static Map<String, FloatMatrix> newMatricesLike(Map<String, FloatMatrix> params) {
    Map<String, FloatMatrix> result = new TreeMap<>();
    for (Map.Entry<String, FloatMatrix> entry : params.entrySet()) {
      FloatMatrix param = entry.getValue();
      result.put(entry.getKey(), new FloatMatrix(param.rows, param.columns));
    }
    return result;
  }