@Test
  public void testSubSampleLayerNoneBackprop() throws Exception {
    Layer layer = getCNNConfig(nChannelsIn, depth, kernelSize, stride, padding);

    Pair<Gradient, INDArray> out = layer.backpropGradient(epsilon);
    assertEquals(epsilon.shape().length, out.getSecond().shape().length);
    assertEquals(nExamples, out.getSecond().size(1)); // depth retained
  }
  // Note: precision is slightly off in this test, but the numbers are close;
  // a future release should investigate how to resolve this.
  @Test
  public void testBackpropResultsContained() {
    Layer layer = getContainedConfig();
    INDArray input = getContainedData();
    INDArray col = getContainedCol();
    INDArray epsilon = Nd4j.ones(1, 2, 4, 4);

    INDArray expectedBiasGradient =
        Nd4j.create(new double[] {0.16608272, 0.16608272}, new int[] {1, 2});
    INDArray expectedWeightGradient =
        Nd4j.create(
            new double[] {
              0.17238397,
              0.17238397,
              0.33846668,
              0.33846668,
              0.17238397,
              0.17238397,
              0.33846668,
              0.33846668
            },
            new int[] {2, 1, 2, 2});
    INDArray expectedEpsilon =
        Nd4j.create(
            new double[] {
              0.00039383, 0.00039383, 0.00039383, 0.00039383, 0.00039383,
              0.00039383, 0., 0., 0.00039383, 0.00039383,
              0.00039383, 0.00039383, 0.00039383, 0.00039383, 0.,
              0., 0.02036651, 0.02036651, 0.02036651, 0.02036651,
              0.02036651, 0.02036651, 0., 0., 0.02036651,
              0.02036651, 0.02036651, 0.02036651, 0.02036651, 0.02036651,
              0., 0., 0.00039383, 0.00039383, 0.00039383,
              0.00039383, 0.00039383, 0.00039383, 0., 0.,
              0.00039383, 0.00039383, 0.00039383, 0.00039383, 0.00039383,
              0.00039383, 0., 0., 0., 0.,
              0., 0., 0., 0., 0.,
              0., 0., 0., 0., 0.,
              0., 0., 0., 0.
            },
            new int[] {1, 1, 8, 8});

    layer.setInput(input);
    org.deeplearning4j.nn.layers.convolution.ConvolutionLayer layer2 =
        (org.deeplearning4j.nn.layers.convolution.ConvolutionLayer) layer;
    layer2.setCol(col);
    Pair<Gradient, INDArray> pair = layer2.backpropGradient(epsilon);

    assertArrayEquals(expectedEpsilon.shape(), pair.getSecond().shape());
    assertArrayEquals(expectedWeightGradient.shape(), pair.getFirst().getGradientFor("W").shape());
    assertArrayEquals(expectedBiasGradient.shape(), pair.getFirst().getGradientFor("b").shape());
    assertEquals(expectedEpsilon, pair.getSecond());
    assertEquals(expectedWeightGradient, pair.getFirst().getGradientFor("W"));
    assertEquals(expectedBiasGradient, pair.getFirst().getGradientFor("b"));
  }
Example #3
  /**
   * Gibbs sampling step: hidden ---> visible ---> hidden
   *
   * @param h the hidden input
   * @return a pair of pairs: the expected values and samples of the visible units given the
   *     hidden input, and the expected values and samples of the new hidden units given those
   *     visible samples
   */
  public Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>> gibbhVh(INDArray h) {
    Pair<INDArray, INDArray> v1MeanAndSample = sampleVisibleGivenHidden(h);
    INDArray vSample = v1MeanAndSample.getSecond();

    Pair<INDArray, INDArray> h1MeanAndSample = sampleHiddenGivenVisible(vSample);
    return new Pair<>(v1MeanAndSample, h1MeanAndSample);
  }
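  // Illustrative usage sketch (not part of the original example; the method and variable names
  // here are hypothetical): run one hidden -> visible -> hidden Gibbs step and unpack the nested
  // pairs, the same way computeGradientAndScore does later in this listing.
  public void gibbhVhUsageSketch(INDArray hiddenSample) {
    Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>> step = gibbhVh(hiddenSample);
    INDArray vMean = step.getFirst().getFirst(); // expected visible values given the hidden input
    INDArray vSample = step.getFirst().getSecond(); // sampled visible units
    INDArray hMean = step.getSecond().getFirst(); // expected hidden values given the visible sample
    INDArray hSample = step.getSecond().getSecond(); // sampled hidden units
  }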
  public static void main(String[] args) throws Exception {
    int iterations = 100;
    Nd4j.dtype = DataBuffer.Type.DOUBLE;
    Nd4j.factory().setDType(DataBuffer.Type.DOUBLE);
    List<String> cacheList = new ArrayList<>();

    log.info("Load & Vectorize data....");
    File wordFile = new ClassPathResource("words.txt").getFile();
    Pair<InMemoryLookupTable, VocabCache> vectors = WordVectorSerializer.loadTxt(wordFile);
    VocabCache cache = vectors.getSecond();
    INDArray weights = vectors.getFirst().getSyn0();

    for (int i = 0; i < cache.numWords(); i++) cacheList.add(cache.wordAtIndex(i));

    log.info("Build model....");
    BarnesHutTsne tsne =
        new BarnesHutTsne.Builder()
            .setMaxIter(iterations)
            .theta(0.5)
            .normalize(false)
            .learningRate(500)
            .useAdaGrad(false)
            .usePca(false)
            .build();

    log.info("Store TSNE Coordinates for Plotting....");
    String outputFile = "target/archive-tmp/tsne-standard-coords.csv";
    (new File(outputFile)).getParentFile().mkdirs();
    tsne.plot(weights, 2, cacheList, outputFile);
  }
Example #5
 /**
  * An individual iteration
  *
  * @param p the probabilities that certain points are near each other
  * @param i the iteration (primarily for debugging purposes)
  */
 public void step(INDArray p, int i) {
   Pair<Double, INDArray> costGradient = gradient(p);
   INDArray yIncs = costGradient.getSecond();
   log.info("Cost at iteration " + i + " was " + costGradient.getFirst());
   y.addi(yIncs);
   y.addi(yIncs).subiRowVector(y.mean(0));
   INDArray tiled = Nd4j.tile(y.mean(0), new int[] {y.rows(), y.columns()});
   y.subi(tiled);
 }
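 // A hedged reading of the update above, as a simplified sketch only (assuming y holds the
 // low-dimensional coordinates): apply the gradient increments, then subtract the per-column mean
 // so the embedding stays centred at the origin. This is the apparent intent, not the exact lines.
 public void stepSketch(INDArray yIncs) {
   y.addi(yIncs); // apply the gradient increments in place
   y.subiRowVector(y.mean(0)); // re-centre each dimension around zero
 }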
Example #6
  /**
   * Convert data to probability co-occurrences (aka calculating the kernel)
   *
   * @param d the data to convert
   * @param u the perplexity of the model
   * @return the probabilities of co-occurrence
   */
  public INDArray computeGaussianPerplexity(final INDArray d, double u) {
    int n = d.rows();
    final INDArray p = zeros(n, n);
    final INDArray beta = ones(n, 1);
    final double logU = Math.log(u);

    log.info("Calculating probabilities of data similarities..");
    for (int i = 0; i < n; i++) {
      if (i % 500 == 0 && i > 0) log.info("Handled " + i + " records");

      double betaMin = Double.NEGATIVE_INFINITY;
      double betaMax = Double.POSITIVE_INFINITY;
      int[] vals = Ints.concat(ArrayUtil.range(0, i), ArrayUtil.range(i + 1, d.columns()));
      INDArrayIndex[] range = new INDArrayIndex[] {new NDArrayIndex(vals)};

      INDArray row = d.slice(i).get(range);
      Pair<INDArray, INDArray> pair = hBeta(row, beta.getDouble(i));
      INDArray hDiff = pair.getFirst().sub(logU);
      int tries = 0;

      // while hdiff > tolerance
      while (BooleanIndexing.and(abs(hDiff), Conditions.greaterThan(tolerance)) && tries < 50) {
        // if hdiff > 0
        if (BooleanIndexing.and(hDiff, Conditions.greaterThan(0))) {
          if (Double.isInfinite(betaMax)) beta.putScalar(i, beta.getDouble(i) * 2.0);
          else beta.putScalar(i, (beta.getDouble(i) + betaMax) / 2.0);
          betaMin = beta.getDouble(i);
        } else {
          if (Double.isInfinite(betaMin)) beta.putScalar(i, beta.getDouble(i) / 2.0);
          else beta.putScalar(i, (beta.getDouble(i) + betaMin) / 2.0);
          betaMax = beta.getDouble(i);
        }

        pair = hBeta(row, beta.getDouble(i));
        hDiff = pair.getFirst().subi(logU);
        tries++;
      }

      p.slice(i).put(range, pair.getSecond());
    }

    // don't need the data in memory after this point
    log.info("Mean value of sigma " + sqrt(beta.rdiv(1)).mean(Integer.MAX_VALUE));
    BooleanIndexing.applyWhere(p, Conditions.isNan(), new Value(realMin));

    // set 0 along the diagonal
    INDArray permute = p.transpose();

    INDArray pOut = p.add(permute);

    pOut.divi(pOut.sum(Integer.MAX_VALUE));
    BooleanIndexing.applyWhere(
        pOut, Conditions.lessThan(Nd4j.EPS_THRESHOLD), new Value(Nd4j.EPS_THRESHOLD));
    // ensure no nans
    return pOut;
  }
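  // The hBeta helper called above is not shown in this example. A hedged sketch of what it
  // conventionally computes in t-SNE implementations: a row of squared distances d and a precision
  // beta map to the Shannon entropy H and the row-normalised affinities P. The name hBetaSketch
  // and the use of org.nd4j.linalg.ops.transforms.Transforms are illustrative assumptions.
  private Pair<INDArray, INDArray> hBetaSketch(INDArray d, double beta) {
    INDArray p = Transforms.exp(d.mul(-beta)); // unnormalised affinities exp(-d * beta)
    double sumP = p.sum(Integer.MAX_VALUE).getDouble(0);
    double h = Math.log(sumP) + beta * d.mul(p).sum(Integer.MAX_VALUE).getDouble(0) / sumP;
    p.divi(sumP); // normalise to a probability distribution
    return new Pair<>(Nd4j.scalar(h), p);
  }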
 @Test
 public void testFeedForwardActivationsAndDerivatives() {
   MultiLayerNetwork network = new MultiLayerNetwork(getConf());
   network.init();
   DataSet data = new IrisDataSetIterator(1, 150).next();
   network.fit(data);
   Pair result = network.feedForwardActivationsAndDerivatives();
   List<INDArray> first = (List) result.getFirst();
   List<INDArray> second = (List) result.getSecond();
   assertEquals(first.size(), second.size());
 }
Example #8
  @Override
  public Pair<Gradient, INDArray> backpropGradient(
      INDArray epsilon, Gradient nextGradient, Layer layer) {
    Pair<Gradient, INDArray> pair =
        getGradientsAndDelta(
            output(input)); // Returns Gradient and delta^(this), not Gradient and epsilon^(this-1)
    INDArray delta = pair.getSecond();

    INDArray epsilonNext =
        params.get(DefaultParamInitializer.WEIGHT_KEY).mmul(delta.transpose()).transpose();
    return new Pair<>(pair.getFirst(), epsilonNext);
  }
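  // Since (W * delta^T)^T == delta * W^T, the epsilon passed back to the previous layer can also
  // be computed without the double transpose. A minimal equivalent sketch, assuming W has shape
  // [nIn, nOut] and delta has shape [miniBatch, nOut] (epsilonNextSketch is an illustrative name):
  private INDArray epsilonNextSketch(INDArray delta) {
    INDArray w = params.get(DefaultParamInitializer.WEIGHT_KEY);
    return delta.mmul(w.transpose()); // same result as w.mmul(delta.transpose()).transpose()
  }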
  @Test
  public void testBackpropResults() {
    Layer layer = getContainedConfig();
    INDArray col = getContainedCol();

    INDArray expectedWeightGradient =
        Nd4j.create(
            new double[] {-1440., -1440., -1984., -1984., -1440., -1440., -1984., -1984.},
            new int[] {2, 1, 2, 2});
    INDArray expectedBiasGradient =
        Nd4j.create(new double[] {-544., -544.}, new int[] {2});
    INDArray expectedEpsilon =
        Nd4j.create(
            new double[] {
              -12., -12., -12., -12., -12., -12., -12., -12., -12., -12., -12.,
              -12., -12., -12., -12., -12., -56., -56., -56., -56., -56., -56.,
              -56., -56., -56., -56., -56., -56., -56., -56., -56., -56., -12.,
              -12., -12., -12., -12., -12., -12., -12., -12., -12., -12., -12.,
              -12., -12., -12., -12., -56., -56., -56., -56., -56., -56., -56.,
              -56., -56., -56., -56., -56., -56., -56., -56., -56.
            },
            new int[] {1, 1, 8, 8});

    org.deeplearning4j.nn.layers.convolution.ConvolutionLayer layer2 =
        (org.deeplearning4j.nn.layers.convolution.ConvolutionLayer) layer;
    layer2.setCol(col);
    Pair<Gradient, INDArray> pair = layer2.backpropGradient(epsilon);

    assertArrayEquals(expectedEpsilon.shape(), pair.getSecond().shape());
    assertArrayEquals(expectedWeightGradient.shape(), pair.getFirst().getGradientFor("W").shape());
    assertArrayEquals(expectedBiasGradient.shape(), pair.getFirst().getGradientFor("b").shape());
    assertEquals(expectedEpsilon, pair.getSecond());
    assertEquals(expectedWeightGradient, pair.getFirst().getGradientFor("W"));
    assertEquals(expectedBiasGradient, pair.getFirst().getGradientFor("b"));
  }
  public static void main(String[] args) throws Exception {
    // STEP 1: Initialization
    int iterations = 100;
    // create an n-dimensional array of doubles
    DataTypeUtil.setDTypeForContext(DataBuffer.Type.DOUBLE);
    List<String> cacheList =
        new ArrayList<>(); // cacheList is a dynamic array of strings used to hold all words

    // STEP 2: Turn text input into a list of words
    log.info("Load & Vectorize data....");
    File wordFile = new ClassPathResource("words.txt").getFile(); // Open the file
    // Get the data of all unique word vectors
    Pair<InMemoryLookupTable, VocabCache> vectors = WordVectorSerializer.loadTxt(wordFile);
    VocabCache cache = vectors.getSecond();
    INDArray weights =
        vectors.getFirst().getSyn0(); // separate the weights of unique words into their own list

    for (int i = 0; i < cache.numWords(); i++) // separate the strings of words into their own list
      cacheList.add(cache.wordAtIndex(i));

    // STEP 3: build a dual-tree tsne to use later
    log.info("Build model....");
    BarnesHutTsne tsne =
        new BarnesHutTsne.Builder()
            .setMaxIter(iterations)
            .theta(0.5)
            .normalize(false)
            .learningRate(500)
            .useAdaGrad(false)
            //                .usePca(false)
            .build();

    // STEP 4: establish the tsne values and save them to a file
    log.info("Store TSNE Coordinates for Plotting....");
    String outputFile = "target/archive-tmp/tsne-standard-coords.csv";
    (new File(outputFile)).getParentFile().mkdirs();
    tsne.plot(weights, 2, cacheList, outputFile);
    // This tsne will use the weights of the vectors as its matrix, have two dimensions, use the
    // word strings as labels, and be written to the outputFile created on the previous line.

    // !!! Possible error: plot was recently deprecated; a possible replacement is sketched after
    // this method.
  }
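  // Regarding the "possible error" note above: in newer DL4J releases the same output is usually
  // produced by fitting the model and then saving the coordinates. A hedged sketch only; the
  // fit and saveAsFile method names assume the later BarnesHutTsne API:
  static void plotAlternativeSketch(
      BarnesHutTsne tsne, INDArray weights, List<String> cacheList, String outputFile) {
    tsne.fit(weights); // compute the 2-D coordinates
    tsne.saveAsFile(cacheList, outputFile); // write the labelled coordinates to the CSV file
  }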
Example #11
  private void init() {

    if (rng == null) rng = new MersenneTwister(123);

    MultiDimensionalSet<String, String> binaryProductions = MultiDimensionalSet.hashSet();
    if (simplifiedModel) {
      binaryProductions.add("", "");
    } else {
      // TODO
      // figure out what binary productions we have in these trees
      // Note: the current sentiment training data does not actually
      // have any constituent labels
      throw new UnsupportedOperationException("Not yet implemented");
    }

    Set<String> unaryProductions = new HashSet<>();

    if (simplifiedModel) {
      unaryProductions.add("");
    } else {
      // TODO
      // figure out what unary productions we have in these trees (preterminals only, after the
      // collapsing)
      throw new UnsupportedOperationException("Not yet implemented");
    }

    identity = FloatMatrix.eye(numHidden);

    binaryTransform = MultiDimensionalMap.newTreeBackedMap();
    binaryFloatTensors = MultiDimensionalMap.newTreeBackedMap();
    binaryClassification = MultiDimensionalMap.newTreeBackedMap();

    // When making a flat model (no semantic untying) the
    // basicCategory function will return the same basic category for
    // all labels, so all entries will map to the same matrix
    for (Pair<String, String> binary : binaryProductions) {
      String left = basicCategory(binary.getFirst());
      String right = basicCategory(binary.getSecond());
      if (binaryTransform.contains(left, right)) {
        continue;
      }

      binaryTransform.put(left, right, randomTransformMatrix());
      if (useFloatTensors) {
        binaryFloatTensors.put(left, right, randomBinaryFloatTensor());
      }

      if (!combineClassification) {
        binaryClassification.put(left, right, randomClassificationMatrix());
      }
    }

    numBinaryMatrices = binaryTransform.size();
    binaryTransformSize = numHidden * (2 * numHidden + 1);

    if (useFloatTensors) {
      binaryFloatTensorSize = numHidden * numHidden * numHidden * 4;
    } else {
      binaryFloatTensorSize = 0;
    }

    binaryClassificationSize = (combineClassification) ? 0 : numOuts * (numHidden + 1);

    unaryClassification = new TreeMap<>();

    // When making a flat model (no semantic untying) the
    // basicCategory function will return the same basic category for
    // all labels, so all entries will map to the same matrix

    for (String unary : unaryProductions) {
      unary = basicCategory(unary);
      if (unaryClassification.containsKey(unary)) {
        continue;
      }
      unaryClassification.put(unary, randomClassificationMatrix());
    }

    binaryClassificationSize = (combineClassification) ? 0 : numOuts * (numHidden + 1);

    numUnaryMatrices = unaryClassification.size();
    unaryClassificationSize = numOuts * (numHidden + 1);

    featureVectors.put(UNKNOWN_FEATURE, randomWordVector());
    numUnaryMatrices = unaryClassification.size();
    unaryClassificationSize = numOuts * (numHidden + 1);
    classWeights = new HashMap<>();
  }
Example #12
  @Override
  public void computeGradientAndScore() {
    int k = layerConf().getK();

    // POSITIVE PHASE
    Pair<INDArray, INDArray> probHidden = sampleHiddenGivenVisible(input());

    /*
     * Start the gibbs sampling.
     */
    INDArray chainStart = probHidden.getSecond();

    /*
     * Note that at a later date, we can explore alternative methods of
     * storing the chain transitions for different kinds of sampling
     * and exploring the search space.
     */
    Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>> matrices;
    // negative visible means or expected values
    INDArray nvMeans = null;
    // negative value samples
    INDArray nvSamples = null;
    // negative hidden means or expected values
    INDArray nhMeans = null;
    // negative hidden samples
    INDArray nhSamples = null;

    /*
     * K steps of Gibbs sampling. This is the negative phase of contrastive divergence.
     *
     * Four matrices are computed for each Gibbs step: the expected values (means) and the
     * samples of the visible and hidden units along the chain.
     */

    for (int i = 0; i < k; i++) {

      // NEGATIVE PHASE
      if (i == 0) matrices = gibbhVh(chainStart);
      else matrices = gibbhVh(nhSamples);

      // get the cost updates for sampling in the chain after k iterations
      nvMeans = matrices.getFirst().getFirst();
      nvSamples = matrices.getFirst().getSecond();
      nhMeans = matrices.getSecond().getFirst();
      nhSamples = matrices.getSecond().getSecond();
    }

    /*
     * Update gradient parameters
     */
    INDArray wGradient =
        input().transposei().mmul(probHidden.getSecond()).subi(nvSamples.transpose().mmul(nhMeans));

    INDArray hBiasGradient;

    if (layerConf().getSparsity() != 0)
      // all hidden units must stay around this number
      hBiasGradient = probHidden.getSecond().rsub(layerConf().getSparsity()).sum(0);
    else
      // update rule: the expected values of the hidden input minus the negative hidden means
      hBiasGradient = probHidden.getSecond().sub(nhMeans).sum(0);

    // update rule: the expected values of the input minus the negative visible samples
    INDArray delta = input.sub(nvSamples);
    INDArray vBiasGradient = delta.sum(0);

    Gradient ret = new DefaultGradient();
    ret.gradientForVariable().put(PretrainParamInitializer.VISIBLE_BIAS_KEY, vBiasGradient);
    ret.gradientForVariable().put(PretrainParamInitializer.BIAS_KEY, hBiasGradient);
    ret.gradientForVariable().put(PretrainParamInitializer.WEIGHT_KEY, wGradient);
    gradient = ret;
    setScoreWithZ(delta);
  }
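  // For reference, wGradient above is the usual contrastive-divergence statistic: the
  // positive-phase association minus the negative-phase association. A minimal sketch with
  // illustrative names (v0 / h0Mean come from the data, vk / hkMean from the end of the k-step
  // Gibbs chain):
  private INDArray cdWeightGradientSketch(
      INDArray v0, INDArray h0Mean, INDArray vk, INDArray hkMean) {
    INDArray positive = v0.transpose().mmul(h0Mean); // <v h> under the data distribution
    INDArray negative = vk.transpose().mmul(hkMean); // <v h> under the model's Gibbs chain
    return positive.sub(negative); // the same quantity as wGradient computed above
  }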
 /**
  * Load word vectors from the given pair
  *
  * @param pair the given pair
  * @return a read only word vectors impl based on the given lookup table and vocab
  */
 public static WordVectors fromPair(Pair<InMemoryLookupTable, VocabCache> pair) {
   WordVectorsImpl vectors = new WordVectorsImpl();
   vectors.setLookupTable(pair.getFirst());
   vectors.setVocab(pair.getSecond());
   return vectors;
 }
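 // Minimal usage sketch (not part of the original): load the pair exactly as the loadTxt example
 // earlier in this listing does, then wrap it as read-only word vectors. The file name and the
 // looked-up word are illustrative assumptions.
 public static void fromPairUsageSketch() throws Exception {
   File wordFile = new ClassPathResource("words.txt").getFile();
   Pair<InMemoryLookupTable, VocabCache> pair = WordVectorSerializer.loadTxt(wordFile);
   WordVectors wordVectors = fromPair(pair);
   double[] vector = wordVectors.getWordVector("day"); // any word present in the vocabulary
 }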
Example #14
  /**
   * Train on the corpus
   *
   * @param rdd the rdd to train
   * @return the vocab and weights
   */
  public Pair<VocabCache, GloveWeightLookupTable> train(JavaRDD<String> rdd) {
    TextPipeline pipeline = new TextPipeline(rdd);
    final Pair<VocabCache, Long> vocabAndNumWords = pipeline.process();
    SparkConf conf = rdd.context().getConf();
    JavaSparkContext sc = new JavaSparkContext(rdd.context());
    vocabCacheBroadcast = sc.broadcast(vocabAndNumWords.getFirst());

    final GloveWeightLookupTable gloveWeightLookupTable =
        new GloveWeightLookupTable.Builder()
            .cache(vocabAndNumWords.getFirst())
            .lr(conf.getDouble(GlovePerformer.ALPHA, 0.025))
            .maxCount(conf.getDouble(GlovePerformer.MAX_COUNT, 100))
            .vectorLength(conf.getInt(GlovePerformer.VECTOR_LENGTH, 300))
            .xMax(conf.getDouble(GlovePerformer.X_MAX, 0.75))
            .build();
    gloveWeightLookupTable.resetWeights();

    gloveWeightLookupTable.getBiasAdaGrad().historicalGradient =
        Nd4j.zeros(gloveWeightLookupTable.getSyn0().rows());
    gloveWeightLookupTable.getWeightAdaGrad().historicalGradient =
        Nd4j.create(gloveWeightLookupTable.getSyn0().shape());

    log.info(
        "Created lookup table of size "
            + Arrays.toString(gloveWeightLookupTable.getSyn0().shape()));
    CounterMap<String, String> coOccurrenceCounts =
        rdd.map(new TokenizerFunction(tokenizerFactoryClazz))
            .map(new CoOccurrenceCalculator(symmetric, vocabCacheBroadcast, windowSize))
            .fold(new CounterMap<String, String>(), new CoOccurrenceCounts());

    List<Triple<String, String, Double>> counts = new ArrayList<>();
    Iterator<Pair<String, String>> pairIter = coOccurrenceCounts.getPairIterator();
    while (pairIter.hasNext()) {
      Pair<String, String> pair = pairIter.next();
      counts.add(
          new Triple<>(
              pair.getFirst(),
              pair.getSecond(),
              coOccurrenceCounts.getCount(pair.getFirst(), pair.getSecond())));
    }

    log.info("Calculated co occurrences");

    JavaRDD<Triple<String, String, Double>> parallel = sc.parallelize(counts);
    JavaPairRDD<String, Tuple2<String, Double>> pairs =
        parallel.mapToPair(
            new PairFunction<Triple<String, String, Double>, String, Tuple2<String, Double>>() {
              @Override
              public Tuple2<String, Tuple2<String, Double>> call(
                  Triple<String, String, Double> stringStringDoubleTriple) throws Exception {
                return new Tuple2<>(
                    stringStringDoubleTriple.getFirst(),
                    new Tuple2<>(
                        stringStringDoubleTriple.getFirst(), stringStringDoubleTriple.getThird()));
              }
            });

    JavaPairRDD<VocabWord, Tuple2<VocabWord, Double>> pairsVocab =
        pairs.mapToPair(
            new PairFunction<
                Tuple2<String, Tuple2<String, Double>>, VocabWord, Tuple2<VocabWord, Double>>() {
              @Override
              public Tuple2<VocabWord, Tuple2<VocabWord, Double>> call(
                  Tuple2<String, Tuple2<String, Double>> stringTuple2Tuple2) throws Exception {
                return new Tuple2<>(
                    vocabCacheBroadcast.getValue().wordFor(stringTuple2Tuple2._1()),
                    new Tuple2<>(
                        vocabCacheBroadcast.getValue().wordFor(stringTuple2Tuple2._2()._1()),
                        stringTuple2Tuple2._2()._2()));
              }
            });

    for (int i = 0; i < iterations; i++) {

      JavaRDD<GloveChange> change =
          pairsVocab.map(
              new Function<Tuple2<VocabWord, Tuple2<VocabWord, Double>>, GloveChange>() {
                @Override
                public GloveChange call(
                    Tuple2<VocabWord, Tuple2<VocabWord, Double>> vocabWordTuple2Tuple2)
                    throws Exception {
                  VocabWord w1 = vocabWordTuple2Tuple2._1();
                  VocabWord w2 = vocabWordTuple2Tuple2._2()._1();
                  INDArray w1Vector = gloveWeightLookupTable.getSyn0().slice(w1.getIndex());
                  INDArray w2Vector = gloveWeightLookupTable.getSyn0().slice(w2.getIndex());
                  INDArray bias = gloveWeightLookupTable.getBias();
                  double score = vocabWordTuple2Tuple2._2()._2();
                  double xMax = gloveWeightLookupTable.getxMax();
                  double maxCount = gloveWeightLookupTable.getMaxCount();
                  // w1 * w2 + bias
                  double prediction = Nd4j.getBlasWrapper().dot(w1Vector, w2Vector);
                  prediction += bias.getDouble(w1.getIndex()) + bias.getDouble(w2.getIndex());

                  double weight = Math.pow(Math.min(1.0, (score / maxCount)), xMax);

                  double fDiff =
                      score > xMax ? prediction : weight * (prediction - Math.log(score));
                  if (Double.isNaN(fDiff)) fDiff = Nd4j.EPS_THRESHOLD;
                  // amount of change
                  double gradient = fDiff;
                  // update(w1,w1Vector,w2Vector,gradient);
                  // update(w2,w2Vector,w1Vector,gradient);

                  Pair<INDArray, Double> w1Update =
                      update(
                          gloveWeightLookupTable.getWeightAdaGrad(),
                          gloveWeightLookupTable.getBiasAdaGrad(),
                          gloveWeightLookupTable.getSyn0(),
                          gloveWeightLookupTable.getBias(),
                          w1,
                          w1Vector,
                          w2Vector,
                          gradient);
                  Pair<INDArray, Double> w2Update =
                      update(
                          gloveWeightLookupTable.getWeightAdaGrad(),
                          gloveWeightLookupTable.getBiasAdaGrad(),
                          gloveWeightLookupTable.getSyn0(),
                          gloveWeightLookupTable.getBias(),
                          w2,
                          w2Vector,
                          w1Vector,
                          gradient);
                  return new GloveChange(
                      w1,
                      w2,
                      w1Update.getFirst(),
                      w2Update.getFirst(),
                      w1Update.getSecond(),
                      w2Update.getSecond(),
                      fDiff);
                }
              });

      JavaRDD<Double> error =
          change.map(
              new Function<GloveChange, Double>() {
                @Override
                public Double call(GloveChange gloveChange) throws Exception {
                  gloveChange.apply(gloveWeightLookupTable);
                  return gloveChange.getError();
                }
              });

      final Accumulator<Double> d = sc.accumulator(0.0);
      error.foreach(
          new VoidFunction<Double>() {
            @Override
            public void call(Double aDouble) throws Exception {
              d.$plus$eq(aDouble);
            }
          });

      log.info("Error at iteration " + i + " was " + d.value());
    }

    return new Pair<>(vocabAndNumWords.getFirst(), gloveWeightLookupTable);
  }
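  // The map function above encodes the standard GloVe weighting: each co-occurrence count x is
  // weighted by f(x) = min(1, x / cutoff) ^ exponent before the log-count residual is taken.
  // A small helper mirroring that per-pair computation (the name gloveFDiffSketch and its
  // parameter names are illustrative, not part of the original API):
  private static double gloveFDiffSketch(
      double prediction, double coOccurrence, double cutoff, double exponent) {
    double weight = Math.pow(Math.min(1.0, coOccurrence / cutoff), exponent);
    return weight * (prediction - Math.log(coOccurrence)); // the "fDiff" used as the gradient scale
  }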
 public FloatDataSet(Pair<FloatMatrix, FloatMatrix> pair) {
   this(pair.getFirst(), pair.getSecond());
 }