/**
   * Writes the word vectors to the given path as plain text, one line per word. Each line holds the
   * word (with spaces replaced by underscores), followed by its vector components, all separated by
   * single spaces. Note that this assumes an in memory cache: rows {@code 0 ..
   * lookupTable.getSyn0().rows() - 1} are mapped to words via {@code cache.wordAtIndex(i)}, and rows
   * with no word are skipped.
   *
   * @param lookupTable the lookup table whose vectors are written
   * @param cache the vocabulary cache used to map row indices to words
   * @param path the path to write; an existing file at this path is overwritten
   * @throws IOException if the file cannot be opened or written
   */
  public static void writeWordVectors(
      InMemoryLookupTable lookupTable, InMemoryLookupCache cache, String path) throws IOException {
    // try-with-resources closes (and thereby flushes) the writer even if a write fails midway;
    // the original only closed it on the success path and leaked the handle on exceptions.
    // NOTE(review): FileWriter uses the platform default charset, so non-ASCII words may not
    // round-trip across machines. Consider an explicit UTF-8 writer if all readers agree.
    try (BufferedWriter write = new BufferedWriter(new FileWriter(new File(path), false))) {
      for (int i = 0; i < lookupTable.getSyn0().rows(); i++) {
        String word = cache.wordAtIndex(i);
        if (word == null) {
          // No vocabulary entry for this row; nothing to write.
          continue;
        }
        StringBuilder sb = new StringBuilder();
        // Fields on a line are space-delimited, so embedded spaces are encoded as underscores.
        sb.append(word.replace(' ', '_'));
        sb.append(" ");
        INDArray wordVector = lookupTable.vector(word);
        for (int j = 0; j < wordVector.length(); j++) {
          sb.append(wordVector.getDouble(j));
          if (j < wordVector.length() - 1) {
            sb.append(" ");
          }
        }
        sb.append("\n");
        write.write(sb.toString());
      }
    }
  }
  // Example no. 2 (listing annotation: "0")
  /**
   * Returns up to {@code n} vocabulary entries nearest to {@code word}, each formatted as
   * {@code word(suffix)}.
   *
   * <p>The query vector is built by {@code getWordVectorMatrix(syn0, vocab, word, k, K)}, normalized
   * to unit length, and multiplied against every row of {@code syn0}. {@code getTopN} then supplies
   * candidate row indices. Each row index is decoded as {@code wordIndex = row % vocab.numWords()}
   * and {@code suffix = row / vocab.numWords()}. NOTE(review): this presumably means {@code syn0}
   * stacks {@code K} sense blocks of {@code numWords} rows each, with {@code suffix} as the sense
   * number; confirm against {@code getWordVectorMatrix}.
   *
   * <p>Entries whose word is missing, {@code "UNK"} or {@code "STOP"} are skipped.
   *
   * @param syn0 the embedding matrix whose rows are compared against the query
   * @param vocab the vocabulary used to map row indices back to words
   * @param word the query word
   * @param k passed through to {@code getWordVectorMatrix} (presumably the sense index — confirm)
   * @param n the maximum number of results; also passed to {@code getTopN}
   * @param K passed through to {@code getWordVectorMatrix} (presumably the sense count — confirm)
   * @return the matching entries in the order {@code getTopN} returned them
   */
  public static Collection<String> wordsNearest(
      INDArray syn0, InMemoryLookupCache vocab, String word, int k, int n, int K) {

    INDArray vector = Transforms.unitVec(getWordVectorMatrix(syn0, vocab, word, k, K));
    INDArray similarity = vector.mmul(syn0.transpose());
    // Despite the name, the values are used as row indices into syn0 (see intValue() below).
    List<Double> highToLowSimList = getTopN(similarity, n);
    List<String> ret = new ArrayList<>();

    // Starts at 1: the first hit is presumably the query vector itself — TODO confirm vs getTopN.
    for (int i = 1; i < highToLowSimList.size(); i++) {
      int index = highToLowSimList.get(i).intValue();
      String candidate = vocab.wordAtIndex(index % vocab.numWords());
      // Filter on the bare word. The original checked the string after "(suffix)" was appended,
      // which can never be null, "UNK" or "STOP", so the filter never fired.
      if (candidate != null && !candidate.equals("UNK") && !candidate.equals("STOP")) {
        ret.add(candidate + "(" + index / vocab.numWords() + ")");
        if (ret.size() >= n) {
          break;
        }
      }
    }

    return ret;
  }