/**
 * Returns the row of {@code syn0} that belongs to {@code word} inside the {@code k}-th
 * block of rows.
 *
 * <p>The row index is computed as {@code vocab.numWords() * k + idx}, where {@code idx} is
 * the vocabulary index of the word. A word that is not in the vocabulary falls back to the
 * index of {@code Word2Vec.UNK}.
 *
 * @param syn0  matrix the row is read from
 * @param vocab vocabulary used to resolve the word to an index
 * @param word  word to look up; {@code null} yields {@code null}
 * @param k     block offset into {@code syn0} (presumably the prototype/sense number of the
 *              word -- TODO confirm against callers)
 * @param K     upper bound for {@code k}; {@code k > K} yields {@code null}
 *              (NOTE(review): if blocks are numbered 0..K-1 this should probably be
 *              {@code k >= K} -- confirm before changing)
 * @return the selected row, or {@code null} when {@code word} is {@code null},
 *         {@code k > K}, or neither the word nor the UNK token is in the vocabulary
 */
public static INDArray getWordVectorMatrix(
        INDArray syn0, InMemoryLookupCache vocab, String word, int k, int K) {
    if (word == null || k > K) {
        return null;
    }

    int idx = vocab.indexOf(word);
    if (idx < 0) {
        // Unknown word: fall back to the UNK token.
        idx = vocab.indexOf(org.deeplearning4j.models.word2vec.Word2Vec.UNK);
    }
    if (idx < 0) {
        // The UNK token is missing as well. Using the negative index would address the
        // wrong row (or a negative one), so report "no vector" instead.
        return null;
    }

    return syn0.getRow(vocab.numWords() * k + idx);
}
/**
 * Returns labels of the rows of {@code syn0} that are most similar to the vector of
 * {@code word} in block {@code k}.
 *
 * <p>The unit-normalised word vector is multiplied with {@code syn0} transposed, and the
 * row indices returned by {@code getTopN} are mapped back to a word
 * ({@code row % vocab.numWords()}) and a block number ({@code row / vocab.numWords()}).
 * Each result is formatted as {@code word(block)}. The first entry returned by
 * {@code getTopN} is skipped (presumably it is the query word itself -- TODO confirm).
 * Entries whose word is missing, {@code "UNK"} or {@code "STOP"} are left out.
 *
 * @param syn0  matrix holding the vectors
 * @param vocab vocabulary used to translate between words and indices
 * @param word  query word
 * @param k     block offset of the query word, forwarded to {@code getWordVectorMatrix}
 * @param n     number of entries requested from {@code getTopN} and maximum result size
 * @param K     upper bound for {@code k}, forwarded to {@code getWordVectorMatrix}
 * @return the formatted neighbours in the order delivered by {@code getTopN}; empty when
 *         no vector can be resolved for the query word
 */
public static Collection<String> wordsNearest(
        INDArray syn0, InMemoryLookupCache vocab, String word, int k, int n, int K) {
    INDArray wordVector = getWordVectorMatrix(syn0, vocab, word, k, K);
    if (wordVector == null) {
        // No vector for this word/block combination; nothing can be ranked.
        return new ArrayList<>();
    }

    INDArray vector = Transforms.unitVec(wordVector);
    INDArray similarity = vector.mmul(syn0.transpose());
    List<Double> highToLowSimList = getTopN(similarity, n);

    List<String> ret = new ArrayList<>();
    int numWords = vocab.numWords();
    for (int i = 1; i < highToLowSimList.size(); i++) {
        int row = highToLowSimList.get(i).intValue();
        String candidate = vocab.wordAtIndex(row % numWords);

        // Filter on the bare word: once the "(block)" suffix is appended the label can
        // never be null or equal to "UNK"/"STOP", which made the old check ineffective.
        if (candidate == null || candidate.equals("UNK") || candidate.equals("STOP")) {
            continue;
        }

        ret.add(candidate + "(" + (row / numWords) + ")");
        if (ret.size() >= n) {
            break;
        }
    }
    return ret;
}
private static void addTokenToVocabCache(InMemoryLookupCache vocab, String stringToken) { // Making string token into actual token if not already an actual token (vocabWord) VocabWord actualToken; if (vocab.hasToken(stringToken)) { actualToken = vocab.tokenFor(stringToken); } else { actualToken = new VocabWord(1, stringToken); } // Set the index of the actual token (vocabWord) // Put vocabWord into vocabs in InMemoryVocabCache boolean vocabContainsWord = vocab.containsWord(stringToken); if (!vocabContainsWord) { vocab.addToken(actualToken); int idx = vocab.numWords(); actualToken.setIndex(idx); vocab.putVocabWord(stringToken); } }