public FirstIterationFunction(
      Broadcast<Map<String, Object>> word2vecVarMapBroadcast,
      Broadcast<double[]> expTableBroadcast,
      Broadcast<VocabCache<VocabWord>> vocabCacheBroadcast) {

    Map<String, Object> word2vecVarMap = word2vecVarMapBroadcast.getValue();
    this.expTable = expTableBroadcast.getValue();
    this.vectorLength = (int) word2vecVarMap.get("vectorLength");
    this.useAdaGrad = (boolean) word2vecVarMap.get("useAdaGrad");
    this.negative = (double) word2vecVarMap.get("negative");
    this.window = (int) word2vecVarMap.get("window");
    this.alpha = (double) word2vecVarMap.get("alpha");
    this.minAlpha = (double) word2vecVarMap.get("minAlpha");
    this.totalWordCount = (long) word2vecVarMap.get("totalWordCount");
    this.seed = (long) word2vecVarMap.get("seed");
    this.maxExp = (int) word2vecVarMap.get("maxExp");
    this.iterations = (int) word2vecVarMap.get("iterations");
    this.batchSize = (int) word2vecVarMap.get("batchSize");
    this.indexSyn0VecMap = new HashMap<>();
    this.pointSyn1VecMap = new HashMap<>();
    this.vocab = vocabCacheBroadcast.getValue();

    if (this.vocab == null) throw new RuntimeException("VocabCache is null");

    if (negative > 0) {
      negativeHolder = NegativeHolder.getInstance();
      negativeHolder.initHolder(vocab, expTable, this.vectorLength);
    }
  }
  /**
   * Performs one skip-gram training step for the pair ({@code w1}, {@code w2}) on this
   * partition's local weights.
   *
   * <p>The input vector {@code l1} of {@code w2} is taken from {@code indexSyn0VecMap}, or created
   * with {@code getRandomSyn0Vec} the first time the word is seen. It is trained against the
   * hierarchical-softmax codes/points of {@code w1} (output vectors cached in
   * {@code pointSyn1VecMap}, created as zero rows on first use) and, when {@code negative > 0},
   * by negative sampling against the shared {@code NegativeHolder}. The accumulated error is
   * added to {@code l1}, the word is stamped with the cached cid/affinity ids, and {@code l1} is
   * stored back into {@code indexSyn0VecMap}.
   *
   * @param w1 word providing the output side (codes, points, index); no-op if {@code null}
   * @param w2 word whose input vector is updated; no-op if {@code null}, if its index is
   *     negative, or if it has the same index as {@code w1}
   * @param currentSentenceAlpha learning rate for this step, used by both the
   *     hierarchical-softmax and the negative-sampling updates
   * @throws IllegalStateException if {@code w1} contains a negative point
   */
  public void iterateSample(VocabWord w1, VocabWord w2, double currentSentenceAlpha) {

    if (w1 == null || w2 == null || w2.getIndex() < 0 || w2.getIndex() == w1.getIndex()) {
      return;
    }
    final int currentWordIndex = w2.getIndex();
    // Looked up once: it is both the syn0 cache key and the object stamped at the end.
    final VocabWord word = vocab.elementAtIndex(currentWordIndex);

    // error for current word and context
    final INDArray neu1e = Nd4j.create(vectorLength);

    // First iteration Syn0 is random numbers; afterwards reuse the vector trained so far.
    // (The cache never stores null values, so a null lookup means "not seen yet".)
    INDArray l1 = indexSyn0VecMap.get(word);
    if (l1 == null) {
      l1 = getRandomSyn0Vec(vectorLength, (long) currentWordIndex);
    }

    // Scale from a value in [-maxExp, maxExp] to an index into expTable (double arithmetic).
    final double expScale = (double) expTable.length / maxExp / 2.0;

    // Hierarchical softmax over w1's codes/points.
    for (int i = 0; i < w1.getCodeLength(); i++) {
      int code = w1.getCodes().get(i);
      int point = w1.getPoints().get(i);
      if (point < 0) {
        throw new IllegalStateException("Illegal point " + point);
      }
      // Output vector for this point; lazily created as 1 row of vector length of zeros.
      INDArray syn1 = pointSyn1VecMap.computeIfAbsent(point, p -> Nd4j.zeros(1, vectorLength));

      // Dot product of Syn0 and Syn1 vecs
      double dot = Nd4j.getBlasWrapper().level1().dot(vectorLength, 1.0, l1, syn1);

      // The table only covers [-maxExp, maxExp); outside it the step is skipped.
      if (dot < -maxExp || dot >= maxExp) {
        continue;
      }

      int idx = (int) ((dot + maxExp) * expScale);
      // Valid indices are 0..length-1 (the previous ">" let idx == length through).
      if (idx >= expTable.length) {
        continue;
      }

      // score
      double f = expTable[idx];
      // gradient
      double g =
          (1 - code - f)
              * (useAdaGrad
                  ? w1.getGradient(i, currentSentenceAlpha, currentSentenceAlpha)
                  : currentSentenceAlpha);

      // neu1e += g * syn1; syn1 += g * l1
      Nd4j.getBlasWrapper().level1().axpy(vectorLength, g, syn1, neu1e);
      Nd4j.getBlasWrapper().level1().axpy(vectorLength, g, l1, syn1);
    }

    // Negative sampling: d == 0 is the positive example (w1 itself), later draws are negatives.
    if (negative > 0) {
      for (int d = 0; d < negative + 1; d++) {
        int target;
        int label;
        if (d == 0) {
          target = w1.getIndex();
          label = 1;
        } else {
          // Linear congruential step on the shared random state.
          nextRandom.set(nextRandom.get() * 25214903917L + 11);
          int idx = Math.abs((int) (nextRandom.get() >> 16) % negativeHolder.getTable().length());

          target = negativeHolder.getTable().getInt(idx);
          if (target <= 0) {
            // Uniform fallback over [1, numWords - 1]. floorMod keeps the draw non-negative even
            // though the random state is signed; a single-word vocabulary has nothing to draw.
            int range = vocab.numWords() - 1;
            if (range <= 0) {
              continue;
            }
            target = (int) Math.floorMod(nextRandom.get(), (long) range) + 1;
          }

          if (target == w1.getIndex()) {
            continue;
          }
          label = 0;
        }

        INDArray syn1Neg = negativeHolder.getSyn1Neg();
        if (target >= syn1Neg.rows() || target < 0) {
          continue;
        }
        INDArray syn1NegRow = syn1Neg.slice(target);

        double f = Nd4j.getBlasWrapper().dot(l1, syn1NegRow);
        double g;
        if (f > maxExp) {
          g =
              useAdaGrad
                  ? w1.getGradient(target, (label - 1), currentSentenceAlpha)
                  : (label - 1) * currentSentenceAlpha;
        } else if (f < -maxExp) {
          g =
              label
                  * (useAdaGrad
                      ? w1.getGradient(target, currentSentenceAlpha, currentSentenceAlpha)
                      : currentSentenceAlpha);
        } else {
          // Same double-precision scaling as the hierarchical-softmax lookup above.
          int idx = (int) ((f + maxExp) * expScale);
          if (idx >= expTable.length) {
            continue;
          }

          g =
              useAdaGrad
                  ? w1.getGradient(target, label - expTable[idx], currentSentenceAlpha)
                  : (label - expTable[idx]) * currentSentenceAlpha;
        }

        // neu1e += g * syn1Neg[target]; syn1Neg[target] += g * l1
        Nd4j.getBlasWrapper().axpy((float) g, syn1NegRow, neu1e);
        Nd4j.getBlasWrapper().axpy((float) g, l1, syn1NegRow);
      }
    }

    // Updated the Syn0 vector based on gradient. Syn0 is not random anymore.
    Nd4j.getBlasWrapper().level1().axpy(vectorLength, 1.0f, neu1e, l1);

    // NOTE(review): aff is checked outside the lock and not re-checked inside it; if cid/aff are
    // shared between threads the environment may be built more than once — confirm their types.
    if (aff.get() == 0) {
      synchronized (this) {
        cid.set(EnvironmentUtils.buildCId());
        aff.set(EnvironmentUtils.buildEnvironment().getAvailableMemory());
      }
    }

    word.setVocabId(cid.get());
    word.setAffinityId(aff.get());
    indexSyn0VecMap.put(word, l1);
  }