Ejemplo n.º 1
0
  /**
   * Wraps already-trained parameter matrices in a {@link SentimentModel}. This is intended for
   * importing a model that was trained elsewhere, for example with the original Matlab code.
   *
   * <p>Only the simplified, combined-classification configuration is supported. Every parameter
   * is registered under the empty-string category, and the binary classification map is left
   * empty.
   *
   * @param W the matrix stored as the single binary transform
   * @param Wcat the matrix stored as the single unary classification matrix
   * @param Wt the tensor stored as the single binary tensor
   * @param wordVectors word vectors handed to the model constructor as-is
   * @param op model options; both {@code combineClassification} and {@code simplifiedModel} must
   *     be enabled
   * @return a new model built from the supplied parameters
   * @throws IllegalArgumentException if {@code combineClassification} or {@code simplifiedModel}
   *     is disabled
   */
  static SentimentModel modelFromMatrices(
      SimpleMatrix W,
      SimpleMatrix Wcat,
      SimpleTensor Wt,
      Map<String, SimpleMatrix> wordVectors,
      RNNOptions op) {
    boolean supportedConfiguration = op.combineClassification && op.simplifiedModel;
    if (!supportedConfiguration) {
      throw new IllegalArgumentException(
          "Can only create a model using this method if combineClassification and simplifiedModel are turned on");
    }

    // The simplified model has exactly one category, keyed by the empty string.
    Map<String, SimpleMatrix> unaryClassifiers = Generics.newTreeMap();
    unaryClassifiers.put("", Wcat);

    TwoDimensionalMap<String, String, SimpleTensor> tensors = TwoDimensionalMap.treeMap();
    tensors.put("", "", Wt);

    TwoDimensionalMap<String, String, SimpleMatrix> transforms = TwoDimensionalMap.treeMap();
    transforms.put("", "", W);

    // With combined classification there are no per-production binary classifiers.
    TwoDimensionalMap<String, String, SimpleMatrix> noBinaryClassifiers =
        TwoDimensionalMap.treeMap();

    return new SentimentModel(
        transforms, tensors, noBinaryClassifiers, unaryClassifiers, wordVectors, op);
  }
Ejemplo n.º 2
0
  /**
   * Evaluates the training cost and its gradient at {@code theta}. The results are stored in the
   * {@code value} and {@code derivative} fields.
   *
   * <p>The parameters are first loaded into {@code dvModel}. Each tree in {@code trainingBatch}
   * is then scored through a {@code MulticoreWrapper} using
   * {@code op.trainOptions.trainingThreads} threads. Scoring yields a (gold tree,
   * highest-scoring tree) pair. When the highest-scoring tree beats the gold tree by more than
   * {@code 1e-5}:
   *
   * <ul>
   *   <li>the difference {@code best - gold} is added to the cost;
   *   <li>derivatives are backpropagated separately through both trees;
   *   <li>the gradient of that term is the highest-tree derivatives minus the gold-tree
   *       derivatives.
   * </ul>
   *
   * <p>Cost and gradient are then divided by the batch size. This step is skipped for an empty
   * batch. Finally an L2 penalty is applied: {@code 0.5 * regCost * ||params||^2} is added to the
   * cost and {@code regCost * params} to the gradient.
   *
   * @param theta the flattened parameter vector at which to evaluate
   */
  public void calculate(double[] theta) {
    dvModel.vectorToParams(theta);

    double localValue = 0.0;
    double[] localDerivative = new double[theta.length];

    // Derivative accumulators: suffix G collects contributions from the gold trees, suffix B from
    // the highest-scoring trees. Each accumulator matches the shape of its model parameter.
    TwoDimensionalMap<String, String, SimpleMatrix> binaryW_dfsG = TwoDimensionalMap.treeMap();
    TwoDimensionalMap<String, String, SimpleMatrix> binaryW_dfsB = TwoDimensionalMap.treeMap();
    TwoDimensionalMap<String, String, SimpleMatrix> binaryScoreDerivativesG =
        TwoDimensionalMap.treeMap();
    TwoDimensionalMap<String, String, SimpleMatrix> binaryScoreDerivativesB =
        TwoDimensionalMap.treeMap();
    Map<String, SimpleMatrix> unaryW_dfsG = new TreeMap<>();
    Map<String, SimpleMatrix> unaryW_dfsB = new TreeMap<>();
    Map<String, SimpleMatrix> unaryScoreDerivativesG = new TreeMap<>();
    Map<String, SimpleMatrix> unaryScoreDerivativesB = new TreeMap<>();
    Map<String, SimpleMatrix> wordVectorDerivativesG = new TreeMap<>();
    Map<String, SimpleMatrix> wordVectorDerivativesB = new TreeMap<>();

    for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : dvModel.binaryTransform) {
      String firstKey = entry.getFirstKey();
      String secondKey = entry.getSecondKey();
      int numRows = entry.getValue().numRows();
      int numCols = entry.getValue().numCols();
      binaryW_dfsG.put(firstKey, secondKey, new SimpleMatrix(numRows, numCols));
      binaryW_dfsB.put(firstKey, secondKey, new SimpleMatrix(numRows, numCols));
      // Score derivatives are row vectors: 1 x numRows of the transform.
      binaryScoreDerivativesG.put(firstKey, secondKey, new SimpleMatrix(1, numRows));
      binaryScoreDerivativesB.put(firstKey, secondKey, new SimpleMatrix(1, numRows));
    }
    for (Map.Entry<String, SimpleMatrix> entry : dvModel.unaryTransform.entrySet()) {
      int numRows = entry.getValue().numRows();
      int numCols = entry.getValue().numCols();
      unaryW_dfsG.put(entry.getKey(), new SimpleMatrix(numRows, numCols));
      unaryW_dfsB.put(entry.getKey(), new SimpleMatrix(numRows, numCols));
      unaryScoreDerivativesG.put(entry.getKey(), new SimpleMatrix(1, numRows));
      unaryScoreDerivativesB.put(entry.getKey(), new SimpleMatrix(1, numRows));
    }
    if (op.trainOptions.trainWordVectors) {
      for (Map.Entry<String, SimpleMatrix> entry : dvModel.wordVectors.entrySet()) {
        int numRows = entry.getValue().numRows();
        int numCols = entry.getValue().numCols();
        wordVectorDerivativesG.put(entry.getKey(), new SimpleMatrix(numRows, numCols));
        wordVectorDerivativesB.put(entry.getKey(), new SimpleMatrix(numRows, numCols));
      }
    }

    // Score every tree in the batch, then consume the results.
    Timing scoreTiming = new Timing();
    scoreTiming.doing("Scoring trees");
    MulticoreWrapper<Tree, Pair<DeepTree, DeepTree>> wrapper =
        new MulticoreWrapper<>(op.trainOptions.trainingThreads, new ScoringProcessor());
    for (Tree tree : trainingBatch) {
      wrapper.put(tree);
    }
    wrapper.join();
    scoreTiming.done();

    int treeNum = 0;
    while (wrapper.peek()) {
      Pair<DeepTree, DeepTree> result = wrapper.poll();
      DeepTree goldTree = result.first;
      DeepTree bestTree = result.second;

      // A tree contributes nothing once the gold tree scores at least as high as the best
      // hypothesis (within a small tolerance).
      boolean isDone =
          Math.abs(bestTree.getScore() - goldTree.getScore()) <= 0.00001
              || goldTree.getScore() > bestTree.getScore();
      // String.format avoids leaving a Closeable Formatter unclosed on every tree.
      System.err.println(
          String.format(
              "Tree %6d Highest tree: %12.4f Correct tree: %12.4f %s",
              treeNum, bestTree.getScore(), goldTree.getScore(), isDone ? "done" : ""));

      if (!isDone) {
        double valueDelta = bestTree.getScore() - goldTree.getScore();
        localValue += valueDelta;

        // The context words should be the same for either tree, so take them from the gold tree.
        List<String> words = getContextWords(goldTree.getTree());

        // Derivatives are accumulated only for the nodes present in each tree.
        backpropDerivative(
            goldTree.getTree(),
            words,
            goldTree.getVectors(),
            binaryW_dfsG,
            unaryW_dfsG,
            binaryScoreDerivativesG,
            unaryScoreDerivativesG,
            wordVectorDerivativesG);

        backpropDerivative(
            bestTree.getTree(),
            words,
            bestTree.getVectors(),
            binaryW_dfsB,
            unaryW_dfsB,
            binaryScoreDerivativesB,
            unaryScoreDerivativesB,
            wordVectorDerivativesB);
      }
      ++treeNum;
    }

    double[] localDerivativeGood =
        derivativesToVector(
            theta.length,
            binaryW_dfsG,
            unaryW_dfsG,
            binaryScoreDerivativesG,
            unaryScoreDerivativesG,
            wordVectorDerivativesG);
    double[] localDerivativeB =
        derivativesToVector(
            theta.length,
            binaryW_dfsB,
            unaryW_dfsB,
            binaryScoreDerivativesB,
            unaryScoreDerivativesB,
            wordVectorDerivativesB);

    // The per-tree cost is (best - gold), so its gradient is highest-tree minus correct-tree
    // derivatives.
    for (int i = 0; i < localDerivativeGood.length; i++) {
      localDerivative[i] = localDerivativeB[i] - localDerivativeGood[i];
    }

    // TODO: this is where we would combine multiple costs if we had parallelized the calculation
    value = localValue;
    derivative = localDerivative;

    // Normalize by the training batch size. This is skipped for an empty batch, where
    // 1.0 / 0 would otherwise turn the value and every gradient entry into NaN.
    if (!trainingBatch.isEmpty()) {
      double batchScale = 1.0 / trainingBatch.size();
      value = batchScale * value;
      ArrayMath.multiplyInPlace(derivative, batchScale);
    }

    // L2 regularization: add 0.5 * regCost * ||params||^2 to the cost.
    double[] currentParams = dvModel.paramsToVector();
    double sumOfSquares = 0;
    for (double currentParam : currentParams) {
      sumOfSquares += currentParam * currentParam;
    }
    value += op.trainOptions.regCost * 0.5 * sumOfSquares;
    // ...and regCost * params to the gradient.
    ArrayMath.multiplyInPlace(currentParams, op.trainOptions.regCost);
    ArrayMath.pairwiseAddInPlace(derivative, currentParams);
  }

  /**
   * Flattens one set of derivative accumulators into a single vector of length {@code totalSize}
   * via {@code NeuralUtils.paramsToVector}. The order is: binary transforms, unary transforms,
   * binary scores, unary scores, and word vectors. Word vectors are included only when
   * {@code op.trainOptions.trainWordVectors} is set.
   */
  private double[] derivativesToVector(
      int totalSize,
      TwoDimensionalMap<String, String, SimpleMatrix> binaryW_dfs,
      Map<String, SimpleMatrix> unaryW_dfs,
      TwoDimensionalMap<String, String, SimpleMatrix> binaryScoreDerivatives,
      Map<String, SimpleMatrix> unaryScoreDerivatives,
      Map<String, SimpleMatrix> wordVectorDerivatives) {
    if (op.trainOptions.trainWordVectors) {
      return NeuralUtils.paramsToVector(
          totalSize,
          binaryW_dfs.valueIterator(),
          unaryW_dfs.values().iterator(),
          binaryScoreDerivatives.valueIterator(),
          unaryScoreDerivatives.values().iterator(),
          wordVectorDerivatives.values().iterator());
    }
    return NeuralUtils.paramsToVector(
        totalSize,
        binaryW_dfs.valueIterator(),
        unaryW_dfs.values().iterator(),
        binaryScoreDerivatives.valueIterator(),
        unaryScoreDerivatives.values().iterator());
  }
Ejemplo n.º 3
0
  /**
   * The traditional way of initializing an empty model suitable for training.
   *
   * <p>Initialization proceeds in this order:
   *
   * <ul>
   *   <li>Word vectors come from {@code initRandomWordVectors(trainingTrees)} when
   *       {@code op.randomWordVectors} is set, and from {@code readWordVectors()} otherwise.
   *   <li>The hidden size is {@code op.numHid} when positive. Otherwise it is the element count
   *       of the first word vector.
   *   <li>Transform, tensor, and classification matrices are created with the
   *       {@code random*} helpers.
   * </ul>
   *
   * @param op model options
   * @param trainingTrees training data; only consulted when {@code op.randomWordVectors} is set
   * @throws UnsupportedOperationException if {@code op.simplifiedModel} is not set; only the
   *     simplified model (a single empty-string production) is implemented
   */
  public SentimentModel(RNNOptions op, List<Tree> trainingTrees) {
    this.op = op;
    // Seeded from the options; presumably drawn on by the random* initializers below.
    rand = new Random(op.randomSeed);

    if (op.randomWordVectors) {
      initRandomWordVectors(trainingTrees);
    } else {
      readWordVectors();
    }
    if (op.numHid > 0) {
      this.numHid = op.numHid;
    } else {
      // Infer the hidden size from the first word vector; stays 0 if there are none.
      // NOTE(review): a hidden size of 0 yields a degenerate model; consider failing fast here.
      int size = 0;
      for (SimpleMatrix vector : wordVectors.values()) {
        size = vector.getNumElements();
        break;
      }
      this.numHid = size;
    }

    TwoDimensionalSet<String, String> binaryProductions = TwoDimensionalSet.hashSet();
    if (op.simplifiedModel) {
      binaryProductions.add("", "");
    } else {
      // TODO
      // figure out what binary productions we have in these trees
      // Note: the current sentiment training data does not actually
      // have any constituent labels
      throw new UnsupportedOperationException("Not yet implemented");
    }

    Set<String> unaryProductions = Generics.newHashSet();
    if (op.simplifiedModel) {
      unaryProductions.add("");
    } else {
      // TODO
      // figure out what unary productions we have in these trees (preterminals only, after the
      // collapsing)
      throw new UnsupportedOperationException("Not yet implemented");
    }

    this.numClasses = op.numClasses;

    identity = SimpleMatrix.identity(numHid);

    binaryTransform = TwoDimensionalMap.treeMap();
    binaryTensors = TwoDimensionalMap.treeMap();
    binaryClassification = TwoDimensionalMap.treeMap();

    // When making a flat model (no semantic untying) the
    // basicCategory function will return the same basic category for
    // all labels, so all entries will map to the same matrix
    for (Pair<String, String> binary : binaryProductions) {
      String left = basicCategory(binary.first);
      String right = basicCategory(binary.second);
      if (binaryTransform.contains(left, right)) {
        continue; // matrices already created for this pair of basic categories
      }
      binaryTransform.put(left, right, randomTransformMatrix());
      if (op.useTensors) {
        binaryTensors.put(left, right, randomBinaryTensor());
      }
      if (!op.combineClassification) {
        binaryClassification.put(left, right, randomClassificationMatrix());
      }
    }
    numBinaryMatrices = binaryTransform.size();
    // numHid x (2 * numHid + 1) transform per binary production.
    binaryTransformSize = numHid * (2 * numHid + 1);
    if (op.useTensors) {
      binaryTensorSize = numHid * numHid * numHid * 4;
    } else {
      binaryTensorSize = 0;
    }
    binaryClassificationSize = (op.combineClassification) ? 0 : numClasses * (numHid + 1);

    unaryClassification = Generics.newTreeMap();

    // When making a flat model (no semantic untying) the
    // basicCategory function will return the same basic category for
    // all labels, so all entries will map to the same matrix
    for (String unary : unaryProductions) {
      String category = basicCategory(unary);
      if (unaryClassification.containsKey(category)) {
        continue; // classifier already created for this basic category
      }
      unaryClassification.put(category, randomClassificationMatrix());
    }
    numUnaryMatrices = unaryClassification.size();
    unaryClassificationSize = numClasses * (numHid + 1);
  }