/**
 * Write the tsne format: one line per vocabulary word, the tsne coordinates
 * comma-separated, followed by a comma and the word label.
 *
 * @param vec the word vectors to use for labeling
 * @param tsne the tsne array to write (one row per vocab word index)
 * @param csv the file to use
 * @throws Exception if the file cannot be written
 */
public static void writeTsneFormat(Word2Vec vec, INDArray tsne, File csv) throws Exception {
  int words = 0;
  InMemoryLookupCache l = (InMemoryLookupCache) vec.vocab();
  // try-with-resources: the writer previously leaked if an exception was thrown mid-write
  try (BufferedWriter write = new BufferedWriter(new FileWriter(csv))) {
    for (String word : vec.vocab().words()) {
      if (word == null) {
        continue;
      }
      StringBuilder sb = new StringBuilder();
      // tsne row for this word is addressed by the word's vocabulary index
      INDArray wordVector = tsne.getRow(l.wordFor(word).getIndex());
      for (int j = 0; j < wordVector.length(); j++) {
        sb.append(wordVector.getDouble(j));
        if (j < wordVector.length() - 1) {
          sb.append(",");
        }
      }
      sb.append(",");
      sb.append(word);
      sb.append(" ");
      sb.append("\n");
      write.write(sb.toString());
      words++; // BUG FIX: counter was never incremented, so the log always reported 0
    }
    write.flush();
  }
  log.info("Wrote " + words + " with size of " + vec.lookupTable().layerSize());
}
/**
 * Looks up the vector for sense {@code k} of {@code word} in a stacked syn0 matrix
 * that holds {@code K} vocabularies of {@code vocab.numWords()} rows each.
 *
 * @param syn0 the stacked weight matrix
 * @param vocab the vocabulary used for index lookup
 * @param word the word to resolve (falls back to UNK when absent)
 * @param k the sense index to select
 * @param K the total number of senses stored in syn0
 * @return the row vector for the word's k-th sense, or null for a null word or k &gt; K
 */
public static INDArray getWordVectorMatrix(
    INDArray syn0, InMemoryLookupCache vocab, String word, int k, int K) {
  // Nothing to look up without a word; the sense index must not exceed the sense count.
  if (word == null || k > K) {
    return null;
  }
  // Unknown words resolve to the UNK token's index.
  int index = vocab.indexOf(word);
  if (index < 0) {
    index = vocab.indexOf(org.deeplearning4j.models.word2vec.Word2Vec.UNK);
  }
  // Each sense occupies a contiguous block of numWords() rows.
  int row = vocab.numWords() * k + index;
  return syn0.getRow(row);
}
/**
 * Writes the word vectors to the given path. Note that this assumes an in memory cache.
 * Output format: one line per word, "word v0 v1 ... vn", spaces in the word replaced by '_'.
 *
 * @param lookupTable the lookup table holding syn0 and the per-word vectors
 * @param cache the cache mapping row indices back to words
 * @param path the path to write
 * @throws IOException if the file cannot be written
 */
public static void writeWordVectors(
    InMemoryLookupTable lookupTable, InMemoryLookupCache cache, String path) throws IOException {
  // try-with-resources: the writer previously leaked if an exception was thrown mid-write
  try (BufferedWriter write = new BufferedWriter(new FileWriter(new File(path), false))) {
    for (int i = 0; i < lookupTable.getSyn0().rows(); i++) {
      String word = cache.wordAtIndex(i);
      if (word == null) {
        continue;
      }
      StringBuilder sb = new StringBuilder();
      // spaces would break the space-delimited format, so replace them in the label
      sb.append(word.replaceAll(" ", "_"));
      sb.append(" ");
      INDArray wordVector = lookupTable.vector(word);
      for (int j = 0; j < wordVector.length(); j++) {
        sb.append(wordVector.getDouble(j));
        if (j < wordVector.length() - 1) {
          sb.append(" ");
        }
      }
      sb.append("\n");
      write.write(sb.toString());
    }
    write.flush();
  }
}
/**
 * This method is required for compatibility purposes. It just transfers vocabulary from
 * VocabHolder into VocabCache.
 *
 * @param cache target cache; must be an InMemoryLookupCache (only implementation supported)
 * @param emptyHolder if true, this holder's idxMap and vocabulary are cleared after transfer
 */
public void transferBackToVocabCache(VocabCache cache, boolean emptyHolder) {
  if (!(cache instanceof InMemoryLookupCache))
    throw new IllegalStateException("Sorry, only InMemoryLookupCache use implemented.");

  // make sure that huffman codes are updated before transfer
  List<VocabularyWord> words = words(); // updateHuffmanCodes();

  for (VocabularyWord word : words) {
    // skip empty tokens — they carry no vocabulary information
    if (word.getWord().isEmpty()) continue;
    VocabWord vocabWord = new VocabWord(1, word.getWord());

    // if we're transferring full model, it CAN contain HistoricalGradient for AdaptiveGradient
    // feature
    if (word.getHistoricalGradient() != null) {
      INDArray gradient = Nd4j.create(word.getHistoricalGradient());
      vocabWord.setHistoricalGradient(gradient);
    }

    // put VocabWord into both Tokens and Vocabs maps
    ((InMemoryLookupCache) cache).getVocabs().put(word.getWord(), vocabWord);
    ((InMemoryLookupCache) cache).getTokens().put(word.getWord(), vocabWord);

    // update Huffman tree information (index, code length, points and codes are all
    // derived from the holder's Huffman node for this word)
    if (word.getHuffmanNode() != null) {
      vocabWord.setIndex(word.getHuffmanNode().getIdx());
      vocabWord.setCodeLength(word.getHuffmanNode().getLength());
      vocabWord.setPoints(
          arrayToList(word.getHuffmanNode().getPoint(), word.getHuffmanNode().getLength()));
      vocabWord.setCodes(
          arrayToList(word.getHuffmanNode().getCode(), word.getHuffmanNode().getLength()));

      // put word into index
      cache.addWordToIndex(word.getHuffmanNode().getIdx(), word.getWord());
    }

    // update vocabWord counter. substract 1, since its the base value for any token
    // >1 hack is required since VocabCache impl imples 1 as base word count, not 0
    if (word.getCount() > 1) cache.incrementWordCount(word.getWord(), word.getCount() - 1);
  }

  // at this moment its pretty safe to nullify all vocabs.
  if (emptyHolder) {
    idxMap.clear();
    vocabulary.clear();
  }
}
/**
 * Compares two caches for equality on document count, word index, word/doc frequency
 * counters and the full vocabulary contents.
 *
 * @param o the object to compare against
 * @return true when o is an InMemoryLookupCache with identical state
 */
@Override
public boolean equals(Object o) {
  if (this == o) return true;
  if (o == null || getClass() != o.getClass()) return false;

  InMemoryLookupCache that = (InMemoryLookupCache) o;

  if (numDocs != that.numDocs) return false;
  if (wordIndex != null ? !wordIndex.equals(that.wordIndex) : that.wordIndex != null)
    return false;
  if (wordFrequencies != null
      ? !wordFrequencies.equals(that.wordFrequencies)
      : that.wordFrequencies != null) return false;
  if (docFrequencies != null
      ? !docFrequencies.equals(that.docFrequencies)
      : that.docFrequencies != null) return false;
  // BUG FIX: the original returned true unconditionally here, so caches with
  // different vocabularies still compared equal.
  return vocabWords().equals(that.vocabWords());
}
/**
 * Load a look up cache from an input stream delimited by \n. Each non-empty line becomes
 * one vocabulary word with count 1, indexed in order of appearance.
 *
 * <p>Note: the stream is not closed here; the caller retains ownership of it.
 *
 * @param from the input stream to read from
 * @return the in memory lookup cache
 */
public static InMemoryLookupCache load(InputStream from) {
  InMemoryLookupCache cache = new InMemoryLookupCache();
  LineIterator lines = IOUtils.lineIterator(new InputStreamReader(from));
  int index = 0;
  while (lines.hasNext()) {
    String token = lines.nextLine();
    if (token.isEmpty()) {
      continue;
    }
    cache.incrementWordCount(token);
    VocabWord vocabWord = new VocabWord(1.0, token);
    vocabWord.setIndex(index);
    cache.addToken(vocabWord);
    cache.addWordToIndex(index, token);
    cache.putVocabWord(token);
    index++;
  }
  return cache;
}
/**
 * Returns up to {@code n} words nearest (by cosine similarity) to sense {@code k} of
 * {@code word} in a stacked syn0 matrix of {@code K} senses. Each result is rendered as
 * "word(sense)".
 *
 * @param syn0 the stacked weight matrix
 * @param vocab the vocabulary for index/word lookups
 * @param word the query word
 * @param k the sense index of the query word
 * @param n the maximum number of neighbours to return
 * @param K the total number of senses in syn0
 * @return the nearest neighbours, excluding UNK/STOP and unresolvable indices
 */
public static Collection<String> wordsNearest(
    INDArray syn0, InMemoryLookupCache vocab, String word, int k, int n, int K) {
  INDArray vector = Transforms.unitVec(getWordVectorMatrix(syn0, vocab, word, k, K));
  INDArray similarity = vector.mmul(syn0.transpose());
  List<Double> highToLowSimList = getTopN(similarity, n);
  List<String> ret = new ArrayList<>(); // was a raw ArrayList
  // start at 1: index 0 is the query word itself
  for (int i = 1; i < highToLowSimList.size(); i++) {
    int rawIdx = highToLowSimList.get(i).intValue();
    String match = vocab.wordAtIndex(rawIdx % vocab.numWords());
    // BUG FIX: filter the bare word BEFORE appending the "(sense)" suffix. Previously the
    // null/UNK/STOP checks ran on the concatenated string, so they could never match and
    // null lookups produced "null(0)" entries.
    if (match != null && !match.equals("UNK") && !match.equals("STOP")) {
      ret.add(match + "(" + rawIdx / vocab.numWords() + ")");
      if (ret.size() >= n) {
        break;
      }
    }
  }
  return ret;
}
private static void addTokenToVocabCache(InMemoryLookupCache vocab, String stringToken) { // Making string token into actual token if not already an actual token (vocabWord) VocabWord actualToken; if (vocab.hasToken(stringToken)) { actualToken = vocab.tokenFor(stringToken); } else { actualToken = new VocabWord(1, stringToken); } // Set the index of the actual token (vocabWord) // Put vocabWord into vocabs in InMemoryVocabCache boolean vocabContainsWord = vocab.containsWord(stringToken); if (!vocabContainsWord) { vocab.addToken(actualToken); int idx = vocab.numWords(); actualToken.setIndex(idx); vocab.putVocabWord(stringToken); } }