/**
   * Creates an index of the top {@code size} words, ranked by their tf-idf scores.
   *
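   * <p>A minimal usage sketch ({@code vectorizer} stands for an already configured instance of
   * this class and is illustrative only):
   *
   * <pre>{@code
   * Index vocab = vectorizer.createVocab(10000);
   * // indexOf is negative for words that did not make the cut
   * boolean inVocab = vocab.indexOf("example") >= 0;
   * }</pre>
   *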
   * @param size the number of words to keep in the vocab
   * @return the index of the retained words
   */
  public Index createVocab(int size) {
    Index vocab = new Index();

    // bootstrapping
    calcWordFrequencies();

    // the term-frequency counter has an entry for every word seen in the corpus
    for (String word : tf.keySet()) {
      double tfVal = MathUtils.tf((int) documentWordFrequencies.getCount(word));
      double idfVal = MathUtils.idf(numFiles, idf.getCount(word));
      double tfidfVal = MathUtils.tfidf(tfVal, idfVal);
      // skip stop words and punctuation-only tokens
      java.util.regex.Matcher m = punct.matcher(word);
      if (!stopWords.contains(word) && !m.matches()) tfidf.setCount(word, tfidfVal);
    }

    Counter<String> aggregate = tfidf;

    // keep only the top (size - 1) keys, ranked by tf-idf
    aggregate.keepTopNKeys(size - 1);

    log.info("Created vocab of size " + aggregate.size());

    wordScores = aggregate;

    // add words that made it via rankings
    for (String word : aggregate.keySet()) {
      if (vocab.indexOf(word) < 0) vocab.add(word);
    }

    // cache the vocab
    currVocab = vocab;
    return vocab;
  }

  /**
   * Accumulates frequency statistics for a single document: per-document word counts, corpus-wide
   * term frequencies, and document frequencies (each word counted at most once per document, for
   * idf).
   *
   * @param doc the document to tally
   */
  protected void addForDoc(File doc) {
    Set<String> encountered = new HashSet<String>();
    SentenceIterator iter = new LineSentenceIterator(doc);
    while (iter.hasNext()) {
      String line = iter.nextSentence();
      if (line == null) continue;
      Tokenizer tokenizer = tokenizerFactory.create(new InputHomogenization(line).transform());
      while (tokenizer.hasMoreTokens()) {
        String token = tokenizer.nextToken();
        if (validWord(token)) {
          documentWordFrequencies.incrementCount(token, doc.getAbsolutePath(), 1.0);
          tf.incrementCount(token, 1.0);
          if (!encountered.contains(token)) {
            idf.incrementCount(token, 1.0);
            encountered.add(token);
          }
        }
      }
    }

    // release the underlying stream once the whole document has been read
    iter.finish();
  }

  /**
   * Train on the corpus
   *
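   * <p>A rough usage sketch (the Spark context setup and corpus path are illustrative, and
   * {@code glove} stands for an already configured instance of this class):
   *
   * <pre>{@code
   * JavaSparkContext sc = new JavaSparkContext(sparkConf);
   * JavaRDD<String> corpus = sc.textFile("hdfs:///path/to/corpus.txt");
   * Pair<VocabCache, GloveWeightLookupTable> model = glove.train(corpus);
   * VocabCache vocab = model.getFirst();
   * GloveWeightLookupTable weights = model.getSecond();
   * }</pre>
   *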
   * @param rdd the RDD of raw text documents to train on
   * @return a pair of the vocab cache and the trained GloVe weight lookup table
   */
  public Pair<VocabCache, GloveWeightLookupTable> train(JavaRDD<String> rdd) {
    TextPipeline pipeline = new TextPipeline(rdd);
    final Pair<VocabCache, Long> vocabAndNumWords = pipeline.process();
    SparkConf conf = rdd.context().getConf();
    JavaSparkContext sc = new JavaSparkContext(rdd.context());
    vocabCacheBroadcast = sc.broadcast(vocabAndNumWords.getFirst());

    final GloveWeightLookupTable gloveWeightLookupTable =
        new GloveWeightLookupTable.Builder()
            .cache(vocabAndNumWords.getFirst())
            .lr(conf.getDouble(GlovePerformer.ALPHA, 0.025))
            .maxCount(conf.getDouble(GlovePerformer.MAX_COUNT, 100))
            .vectorLength(conf.getInt(GlovePerformer.VECTOR_LENGTH, 300))
            .xMax(conf.getDouble(GlovePerformer.X_MAX, 0.75))
            .build();
    gloveWeightLookupTable.resetWeights();

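    // pre-allocate AdaGrad history: one squared-gradient entry per word for the biases and a full
    // syn0-shaped matrix for the word vectors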
    gloveWeightLookupTable.getBiasAdaGrad().historicalGradient =
        Nd4j.zeros(gloveWeightLookupTable.getSyn0().rows());
    gloveWeightLookupTable.getWeightAdaGrad().historicalGradient =
        Nd4j.create(gloveWeightLookupTable.getSyn0().shape());

    log.info(
        "Created lookup table of size "
            + Arrays.toString(gloveWeightLookupTable.getSyn0().shape()));
    CounterMap<String, String> coOccurrenceCounts =
        rdd.map(new TokenizerFunction(tokenizerFactoryClazz))
            .map(new CoOccurrenceCalculator(symmetric, vocabCacheBroadcast, windowSize))
            .fold(new CounterMap<String, String>(), new CoOccurrenceCounts());

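    // flatten the driver-side co-occurrence counts into (word1, word2, count) triples so they can
    // be re-parallelized as an RDD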
    List<Triple<String, String, Double>> counts = new ArrayList<>();
    Iterator<Pair<String, String>> pairIter = coOccurrenceCounts.getPairIterator();
    while (pairIter.hasNext()) {
      Pair<String, String> pair = pairIter.next();
      counts.add(
          new Triple<>(
              pair.getFirst(),
              pair.getSecond(),
              coOccurrenceCounts.getCount(pair.getFirst(), pair.getSecond())));
    }

    log.info("Calculated co occurrences");

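    // key each co-occurrence by its first word, then translate both words of the pair from
    // strings into VocabWord entries via the broadcast vocab cache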
    JavaRDD<Triple<String, String, Double>> parallel = sc.parallelize(counts);
    JavaPairRDD<String, Tuple2<String, Double>> pairs =
        parallel.mapToPair(
            new PairFunction<Triple<String, String, Double>, String, Tuple2<String, Double>>() {
              @Override
              public Tuple2<String, Tuple2<String, Double>> call(
                  Triple<String, String, Double> stringStringDoubleTriple) throws Exception {
                return new Tuple2<>(
                    stringStringDoubleTriple.getFirst(),
                    new Tuple2<>(
                        stringStringDoubleTriple.getSecond(), stringStringDoubleTriple.getThird()));
              }
            });

    JavaPairRDD<VocabWord, Tuple2<VocabWord, Double>> pairsVocab =
        pairs.mapToPair(
            new PairFunction<
                Tuple2<String, Tuple2<String, Double>>, VocabWord, Tuple2<VocabWord, Double>>() {
              @Override
              public Tuple2<VocabWord, Tuple2<VocabWord, Double>> call(
                  Tuple2<String, Tuple2<String, Double>> stringTuple2Tuple2) throws Exception {
                return new Tuple2<>(
                    vocabCacheBroadcast.getValue().wordFor(stringTuple2Tuple2._1()),
                    new Tuple2<>(
                        vocabCacheBroadcast.getValue().wordFor(stringTuple2Tuple2._2()._1()),
                        stringTuple2Tuple2._2()._2()));
              }
            });

    for (int i = 0; i < iterations; i++) {

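      // one training pass: compute a GloveChange per co-occurring pair, apply each change to the
      // lookup table, and accumulate the total error for logging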
      JavaRDD<GloveChange> change =
          pairsVocab.map(
              new Function<Tuple2<VocabWord, Tuple2<VocabWord, Double>>, GloveChange>() {
                @Override
                public GloveChange call(
                    Tuple2<VocabWord, Tuple2<VocabWord, Double>> vocabWordTuple2Tuple2)
                    throws Exception {
                  VocabWord w1 = vocabWordTuple2Tuple2._1();
                  VocabWord w2 = vocabWordTuple2Tuple2._2()._1();
                  INDArray w1Vector = gloveWeightLookupTable.getSyn0().slice(w1.getIndex());
                  INDArray w2Vector = gloveWeightLookupTable.getSyn0().slice(w2.getIndex());
                  INDArray bias = gloveWeightLookupTable.getBias();
                  double score = vocabWordTuple2Tuple2._2()._2();
                  double xMax = gloveWeightLookupTable.getxMax();
                  double maxCount = gloveWeightLookupTable.getMaxCount();
                  // prediction: dot(w1, w2) + bias(w1) + bias(w2)
                  double prediction = Nd4j.getBlasWrapper().dot(w1Vector, w2Vector);
                  prediction += bias.getDouble(w1.getIndex()) + bias.getDouble(w2.getIndex());

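                  // GloVe weighting f(X_ij) = min(1, X_ij / x_max)^alpha; here 'maxCount' plays
                  // the x_max role and 'xMax' the exponent alpha (the defaults above, 100 and
                  // 0.75, match the values used in the GloVe paper)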
                  double weight = Math.pow(Math.min(1.0, (score / maxCount)), xMax);

                  double fDiff =
                      score > xMax ? prediction : weight * (prediction - Math.log(score));
                  if (Double.isNaN(fDiff)) fDiff = Nd4j.EPS_THRESHOLD;
                  // amount of change
                  double gradient = fDiff;

                  Pair<INDArray, Double> w1Update =
                      update(
                          gloveWeightLookupTable.getWeightAdaGrad(),
                          gloveWeightLookupTable.getBiasAdaGrad(),
                          gloveWeightLookupTable.getSyn0(),
                          gloveWeightLookupTable.getBias(),
                          w1,
                          w1Vector,
                          w2Vector,
                          gradient);
                  Pair<INDArray, Double> w2Update =
                      update(
                          gloveWeightLookupTable.getWeightAdaGrad(),
                          gloveWeightLookupTable.getBiasAdaGrad(),
                          gloveWeightLookupTable.getSyn0(),
                          gloveWeightLookupTable.getBias(),
                          w2,
                          w2Vector,
                          w1Vector,
                          gradient);
                  return new GloveChange(
                      w1,
                      w2,
                      w1Update.getFirst(),
                      w2Update.getFirst(),
                      w1Update.getSecond(),
                      w2Update.getSecond(),
                      fDiff);
                }
              });

      JavaRDD<Double> error =
          change.map(
              new Function<GloveChange, Double>() {
                @Override
                public Double call(GloveChange gloveChange) throws Exception {
                  gloveChange.apply(gloveWeightLookupTable);
                  return gloveChange.getError();
                }
              });

      final Accumulator<Double> d = sc.accumulator(0.0);
      error.foreach(
          new VoidFunction<Double>() {
            @Override
            public void call(Double aDouble) throws Exception {
              d.$plus$eq(aDouble);
            }
          });

      log.info("Error at iteration " + i + " was " + d.value());
    }

    return new Pair<>(vocabAndNumWords.getFirst(), gloveWeightLookupTable);
  }