/**
   * Returns the words whose vectors are closest, by cosine similarity, to the given vector.
   *
   * <p>If the lookup table is an {@link InMemoryLookupTable}, each row of its {@code syn0} matrix
   * is divided in place by its L2 norm (guarded by the {@code normalized} flag). A single matrix
   * product then scores {@code words} against every vocabulary row. The {@code top + 20}
   * best-scoring indices are mapped back to words, and null, "UNK" and "STOP" entries are dropped.
   * The remaining words are re-scored with {@link Transforms#cosineSim} and sorted with
   * {@code SimilarityComparator}, and {@code getLabels(result, top)} produces the return value.
   *
   * <p>For any other lookup table, every vocabulary word is scored individually and the
   * {@code top} highest-scoring words are kept.
   *
   * @param words the query vector
   * @param top the number of words to return
   * @return the words nearest the query vector
   */
  @Override
  public Collection<String> wordsNearest(INDArray words, int top) {
    if (lookupTable instanceof InMemoryLookupTable) {
      InMemoryLookupTable l = (InMemoryLookupTable) lookupTable;

      INDArray syn0 = l.getSyn0();

      if (!normalized) {
        synchronized (this) {
          if (!normalized) {
            // The mmul below only yields cosine similarities if every row has unit L2
            // (Euclidean) norm. Dividing by the L1 norm (sum of absolute values) would
            // distort the ranking used to pick candidates.
            syn0.diviColumnVector(syn0.norm2(1));
            normalized = true;
          }
        }
      }

      INDArray similarity = Transforms.unitVec(words).mmul(syn0.transpose());

      // Over-fetch candidates, presumably so enough remain after the UNK/STOP/null filter below.
      List<Double> highToLowSimList = getTopN(similarity, top + 20);

      List<WordSimilarity> result = new ArrayList<>();

      for (int i = 0; i < highToLowSimList.size(); i++) {
        String word = vocabCache.wordAtIndex(highToLowSimList.get(i).intValue());
        if (word != null && !word.equals("UNK") && !word.equals("STOP")) {
          INDArray otherVec = lookupTable.vector(word);
          double sim = Transforms.cosineSim(words, otherVec);

          result.add(new WordSimilarity(word, sim));
        }
      }

      Collections.sort(result, new SimilarityComparator());

      return getLabels(result, top);
    }

    Counter<String> distances = new Counter<>();

    for (String s : vocabCache.words()) {
      INDArray otherVec = lookupTable.vector(s);
      double sim = Transforms.cosineSim(words, otherVec);
      distances.incrementCount(s, sim);
    }

    distances.keepTopNKeys(top);
    return distances.keySet();
  }
  /**
   * Returns the cosine similarity of two words. The result lies in [-1, 1]: 1.0 means the two word
   * vectors point in the same direction, 0.0 means they are orthogonal (unrelated), and -1.0 means
   * they point in opposite directions. In practice values are often in [0, 1], depending on the
   * training corpus.
   *
   * <p>Returns NaN if either label is null or has no vector in the lookup table.
   *
   * @param label1 the first word
   * @param label2 the second word
   * @return the cosine similarity of the two word vectors, or NaN as described above
   */
  @Override
  public double similarity(String label1, String label2) {
    if (label1 == null || label2 == null) {
      log.debug(
          "LABELS: "
              + label1
              + ": "
              + (label1 == null ? "null" : EXISTS)
              + ";"
              + label2
              + " vec2:"
              + (label2 == null ? "null" : EXISTS));
      return Double.NaN;
    }

    // Null-check before dup(). Calling dup() on a missing vector would throw a
    // NullPointerException instead of returning NaN as documented.
    INDArray vec1 = lookupTable.vector(label1);
    INDArray vec2 = lookupTable.vector(label2);

    if (vec1 == null || vec2 == null) {
      log.debug(
          label1
              + ": "
              + (vec1 == null ? "null" : EXISTS)
              + ";"
              + label2
              + " vec2:"
              + (vec2 == null ? "null" : EXISTS));
      return Double.NaN;
    }

    if (label1.equals(label2)) return 1.0;

    // Pass copies, as the original did, so the stored vectors are never handed to cosineSim.
    return Transforms.cosineSim(vec1.dup(), vec2.dup());
  }
  /**
   * Returns the words nearest to the sum of the {@code positive} word vectors minus the
   * {@code negative} word vectors.
   *
   * <p>If the lookup table is an {@link InMemoryLookupTable}, every vocabulary row {@code i} is
   * scored as {@code sum_j syn0[i, j] * words[j] / ||syn0[:, j]||}, i.e. each dimension is weighted
   * by the reciprocal L2 norm of its column. Rows are visited in descending score order. The query
   * words themselves, and null, "UNK" and "STOP" entries, are skipped until {@code top} words have
   * been collected or the vocabulary is exhausted. For any other lookup table, every vocabulary
   * word is scored with {@link Transforms#cosineSim} and the {@code top} best are kept.
   *
   * @param positive the positive words
   * @param negative the negative words
   * @param top the maximum number of words to return
   * @return the words nearest the summed vector
   */
  public Collection<String> wordsNearestSum(
      Collection<String> positive, Collection<String> negative, int top) {
    INDArray words = Nd4j.create(lookupTable().layerSize());
    Set<String> union = SetUtils.union(new HashSet<>(positive), new HashSet<>(negative));
    for (String s : positive) words.addi(lookupTable().vector(s));

    for (String s : negative) words.subi(lookupTable().vector(s));

    if (lookupTable() instanceof InMemoryLookupTable) {
      InMemoryLookupTable l = (InMemoryLookupTable) lookupTable();
      INDArray syn0 = l.getSyn0();
      INDArray weights = syn0.norm2(0).rdivi(1).muli(words);
      INDArray distances = syn0.mulRowVector(weights).sum(1);
      // Element 0 holds the row indices ordered by descending score.
      INDArray sort = Nd4j.sortWithIndices(distances, 0, false)[0];
      List<String> ret = new ArrayList<>();
      long candidates = sort.length();
      // Keep scanning until `top` words are collected or the candidates run out. The previous
      // "end++ / break when end >= length" scheme stopped early, and when `top` equalled the
      // vocabulary size it returned nothing after the first skipped word.
      for (int i = 0; i < candidates && ret.size() < top; i++) {
        String word = vocab().wordAtIndex(sort.getInt(i));
        if (word == null || union.contains(word) || word.equals("UNK") || word.equals("STOP")) {
          continue;
        }
        ret.add(word);
      }

      return ret;
    }

    Counter<String> distances = new Counter<>();

    for (String s : vocab().words()) {
      INDArray otherVec = getWordVectorMatrix(s);
      double sim = Transforms.cosineSim(words, otherVec);
      distances.incrementCount(s, sim);
    }

    distances.keepTopNKeys(top);
    return distances.keySet();
  }
  /**
   * Returns the {@code n} words whose vectors score highest against the vector of {@code word}.
   *
   * <p>Returns an empty list when {@code word} has no vector. If the lookup table is an
   * {@link InMemoryLookupTable}, rows are scored with column-norm-weighted dot products and
   * visited in descending score order. The query word itself, and null, "UNK" and "STOP" entries,
   * are skipped until {@code n} words have been collected or the vocabulary is exhausted.
   * Otherwise every other vocabulary word is scored with {@link Transforms#cosineSim} and the
   * {@code n} best are kept.
   *
   * @param word the word to compare
   * @param n the number of words to return
   * @return up to {@code n} nearest words
   */
  public Collection<String> wordsNearestSum(String word, int n) {
    INDArray wordVector = this.getWordVectorMatrix(word);
    // Check before unitVec(): an unknown word has no vector to normalize.
    if (wordVector == null) return new ArrayList<>();
    INDArray vec = Transforms.unitVec(wordVector);

    if (lookupTable() instanceof InMemoryLookupTable) {
      InMemoryLookupTable l = (InMemoryLookupTable) lookupTable();
      INDArray syn0 = l.getSyn0();
      INDArray weights = syn0.norm2(0).rdivi(1).muli(vec);
      INDArray distances = syn0.mulRowVector(weights).sum(1);
      // Element 0 holds the row indices ordered by descending score.
      INDArray sort = Nd4j.sortWithIndices(distances, 0, false)[0];
      List<String> ret = new ArrayList<>();
      long candidates = sort.length();
      // Scan until n words are collected rather than a fixed window of n + 1 candidates. The
      // fixed window overran the index vector when n == sort.length(), and it returned the wrong
      // count whenever UNK/STOP entries were skipped or the query word fell outside it.
      for (int i = 0; i < candidates && ret.size() < n; i++) {
        String add = vocab().wordAtIndex(sort.getInt(i));
        if (add == null || add.equals(word) || add.equals("UNK") || add.equals("STOP")) {
          continue;
        }
        ret.add(add);
      }

      return ret;
    }

    Counter<String> distances = new Counter<>();

    for (String s : vocab().words()) {
      if (s.equals(word)) continue;
      INDArray otherVec = getWordVectorMatrix(s);
      double sim = Transforms.cosineSim(vec, otherVec);
      distances.incrementCount(s, sim);
    }

    distances.keepTopNKeys(n);
    return distances.keySet();
  }
  /**
   * Returns the words whose vectors score highest against the given vector.
   *
   * <p>If the lookup table is an {@link InMemoryLookupTable}, every vocabulary row {@code i} is
   * scored as the mean over dimensions {@code j} of {@code syn0[i, j] * words[j] / ||syn0[:, j]||}.
   * Rows are visited in descending score order, and null, "UNK" and "STOP" entries are skipped
   * until {@code top} words have been collected or the vocabulary is exhausted. For any other
   * lookup table, every vocabulary word is scored with {@link Transforms#cosineSim} and the
   * {@code top} best are kept.
   *
   * @param words the query vector
   * @param top the maximum number of words to return
   * @return the nearest words
   */
  @Override
  public Collection<String> wordsNearest(INDArray words, int top) {
    if (lookupTable() instanceof InMemoryLookupTable) {
      InMemoryLookupTable l = (InMemoryLookupTable) lookupTable();
      INDArray syn0 = l.getSyn0();
      INDArray weights = syn0.norm2(0).rdivi(1).muli(words);
      INDArray distances = syn0.mulRowVector(weights).mean(1);
      // Element 0 holds the row indices ordered by descending score.
      INDArray sort = Nd4j.sortWithIndices(distances, 0, false)[0];
      List<String> ret = new ArrayList<>();
      long candidates = sort.length();
      // Linear getInt(i) works for any vector shape (as in the sibling methods), whereas
      // getInt(0, i) is only valid when the index vector is a row vector. The loop keeps going
      // until `top` words are found instead of breaking early once `end` reaches the length.
      for (int i = 0; i < candidates && ret.size() < top; i++) {
        String add = vocab().wordAtIndex(sort.getInt(i));
        if (add == null || add.equals("UNK") || add.equals("STOP")) {
          continue;
        }
        ret.add(add);
      }

      return ret;
    }

    Counter<String> distances = new Counter<>();

    for (String s : vocab().words()) {
      INDArray otherVec = getWordVectorMatrix(s);
      double sim = Transforms.cosineSim(words, otherVec);
      distances.incrementCount(s, sim);
    }

    distances.keepTopNKeys(top);
    return distances.keySet();
  }
  /**
   * Returns the {@code n} words most similar, by cosine similarity, to the given word.
   *
   * <p>The query word itself is excluded from the result. Returns an empty list when the word is
   * not in the vocabulary.
   *
   * @param word the word to compare
   * @param n the number of words to return
   * @return up to {@code n} nearest words
   */
  public Collection<String> wordsNearest(String word, int n) {
    // TODO: This is a temporary solution; get rid of the flat array scan, probably after the
    // VPTree implementation gets fixed.
    if (!vocab().hasToken(word)) return new ArrayList<>();

    INDArray mean = getWordVectorMatrix(word);

    Counter<String> distances = new Counter<>();

    for (String s : vocab().words()) {
      if (s.equals(word)) continue;

      INDArray otherVec = getWordVectorMatrix(s);
      double sim = Transforms.cosineSim(mean, otherVec);
      distances.incrementCount(s, sim);
    }

    // The query word was never added above, so keep n neighbours (not n - 1).
    distances.keepTopNKeys(n);
    return distances.keySet();
  }
  /**
   * Returns the cosine similarity of two words. The result lies in [-1, 1]: 1.0 means the two word
   * vectors point in the same direction, 0.0 means they are orthogonal, and -1.0 means they point
   * in opposite directions. In practice values are often in [0, 1], depending on the training
   * corpus.
   *
   * <p>Returns 1.0 immediately when both strings are equal, without consulting the vocabulary.
   * Returns -1 when either word has no vector. NOTE(review): this sentinel cannot be distinguished
   * from genuinely opposite vectors.
   *
   * @param word the first word (must not be null; {@code word.equals} is called on it)
   * @param word2 the second word
   * @return the cosine similarity, 1.0 for identical strings, or -1 if a vector is missing
   */
  public double similarity(String word, String word2) {
    if (word.equals(word2)) return 1.0;

    // Look each vector up once; the previous version fetched each one twice.
    INDArray vec1 = getWordVectorMatrix(word);
    INDArray vec2 = getWordVectorMatrix(word2);
    if (vec1 == null || vec2 == null) return -1;

    return Transforms.cosineSim(vec1, vec2);
  }
  /**
   * Returns the words nearest to the mean of the {@code positive} word vectors and the negated
   * {@code negative} word vectors.
   *
   * <p>Returns an empty list when no words are given or when any given word is not in the
   * vocabulary. If the lookup table is an {@link InMemoryLookupTable}, the rows of {@code syn0} are
   * normalized in place to unit L2 norm and ranked by cosine similarity to the mean vector. The
   * query words themselves, and null, "UNK" and "STOP" entries, are excluded. Otherwise every
   * vocabulary word is scored with {@link Transforms#cosineSim} and the {@code top} best are kept.
   *
   * @param positive the positive words
   * @param negative the negative words
   * @param top the maximum number of words to return
   * @return the words nearest the mean vector
   */
  @Override
  public Collection<String> wordsNearest(
      Collection<String> positive, Collection<String> negative, int top) {
    Set<String> union = SetUtils.union(new HashSet<>(positive), new HashSet<>(negative));
    // With no query words there is nothing to average (the matrix below would have zero rows).
    if (union.isEmpty()) {
      return new ArrayList<>();
    }

    // Check every word is in the model
    for (String p : union) {
      if (!vocab().containsWord(p)) {
        return new ArrayList<>();
      }
    }

    WeightLookupTable weightLookupTable = lookupTable();
    INDArray words = Nd4j.create(positive.size() + negative.size(), weightLookupTable.layerSize());
    int row = 0;
    for (String s : positive) {
      words.putRow(row++, weightLookupTable.vector(s));
    }

    for (String s : negative) {
      words.putRow(row++, weightLookupTable.vector(s).mul(-1));
    }

    INDArray mean = words.isMatrix() ? words.mean(0) : words;
    // TODO this should probably be replaced with wordsNearest(mean, top)
    if (weightLookupTable instanceof InMemoryLookupTable) {
      InMemoryLookupTable l = (InMemoryLookupTable) weightLookupTable;

      INDArray syn0 = l.getSyn0();
      // Scale every word vector (row) to unit L2 norm so that the product below is a cosine
      // similarity. The previous diviRowVector(norm2(0)) divided each *column* by its norm, which
      // does not make rows unit-length. Like the original, this modifies syn0 in place.
      syn0.diviColumnVector(syn0.norm2(1));

      INDArray similarity = Transforms.unitVec(mean).mmul(syn0.transpose());
      // Over-fetch by the number of query words, since those are filtered out below.
      List<Double> highToLowSimList = getTopN(similarity, top + union.size());
      List<String> ret = new ArrayList<>();

      for (int i = 0; i < highToLowSimList.size(); i++) {
        String word = vocab().wordAtIndex(highToLowSimList.get(i).intValue());
        if (word != null && !word.equals("UNK") && !word.equals("STOP") && !union.contains(word)) {
          ret.add(word);
          if (ret.size() >= top) {
            break;
          }
        }
      }

      return ret;
    }

    Counter<String> distances = new Counter<>();

    for (String s : vocab().words()) {
      INDArray otherVec = getWordVectorMatrix(s);
      double sim = Transforms.cosineSim(mean, otherVec);
      distances.incrementCount(s, sim);
    }

    distances.keepTopNKeys(top);
    return distances.keySet();
  }