/**
 * All words with frequency below the threshold will be removed.
 *
 * @param threshold exclusive threshold for removal
 */
public void truncateVocabulary(int threshold) {
  logger.debug("Truncating vocabulary to minWordFrequency: [" + threshold + "]");
  Set<String> keyset = vocabulary.keySet();
  for (String word : keyset) {
    VocabularyWord vw = vocabulary.get(word);

    // please note: we're not applying the threshold to SPECIAL words
    if (!vw.isSpecial() && vw.getCount() < threshold) {
      vocabulary.remove(word);
      if (vw.getHuffmanNode() != null) idxMap.remove(vw.getHuffmanNode().getIdx());
    }
  }
}
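// Hypothetical usage sketch (not part of the original source): how truncation might be driven
// after counting tokens. The VocabularyHolder constructor and the containsWord()/addWord()/
// incrementWordCounter() calls below are assumptions shown only to illustrate the call.
//
//   VocabularyHolder holder = new VocabularyHolder();
//   for (String token : tokens) {
//     if (!holder.containsWord(token)) holder.addWord(token);
//     else holder.incrementWordCounter(token);
//   }
//   holder.truncateVocabulary(5); // drops every non-SPECIAL word seen fewer than 5 times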
/**
 * This method is required for compatibility purposes. It just transfers the vocabulary from the
 * VocabularyHolder into a VocabCache.
 *
 * @param cache target VocabCache
 * @param emptyHolder if true, this holder is cleared after the transfer
 */
public void transferBackToVocabCache(VocabCache cache, boolean emptyHolder) {
  if (!(cache instanceof InMemoryLookupCache))
    throw new IllegalStateException("Sorry, only InMemoryLookupCache is supported.");

  // make sure that Huffman codes are updated before transfer
  List<VocabularyWord> words = words(); // updateHuffmanCodes();

  for (VocabularyWord word : words) {
    if (word.getWord().isEmpty()) continue;
    VocabWord vocabWord = new VocabWord(1, word.getWord());

    // if we're transferring the full model, it CAN contain a HistoricalGradient for the
    // AdaptiveGradient feature
    if (word.getHistoricalGradient() != null) {
      INDArray gradient = Nd4j.create(word.getHistoricalGradient());
      vocabWord.setHistoricalGradient(gradient);
    }

    // put the VocabWord into both the Tokens and Vocabs maps
    ((InMemoryLookupCache) cache).getVocabs().put(word.getWord(), vocabWord);
    ((InMemoryLookupCache) cache).getTokens().put(word.getWord(), vocabWord);

    // update Huffman tree information
    if (word.getHuffmanNode() != null) {
      vocabWord.setIndex(word.getHuffmanNode().getIdx());
      vocabWord.setCodeLength(word.getHuffmanNode().getLength());
      vocabWord.setPoints(
          arrayToList(word.getHuffmanNode().getPoint(), word.getHuffmanNode().getLength()));
      vocabWord.setCodes(
          arrayToList(word.getHuffmanNode().getCode(), word.getHuffmanNode().getLength()));

      // put word into index
      cache.addWordToIndex(word.getHuffmanNode().getIdx(), word.getWord());
    }

    // update the vocabWord counter. Subtract 1, since 1 is the base value for any token;
    // the >1 check is required since the VocabCache impl implies 1 as the base word count, not 0
    if (word.getCount() > 1) cache.incrementWordCount(word.getWord(), word.getCount() - 1);
  }

  // at this moment it's pretty safe to nullify all vocabs
  if (emptyHolder) {
    idxMap.clear();
    vocabulary.clear();
  }
}
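// Hypothetical usage sketch (not part of the original source): the expected call order is to
// refresh the Huffman tree first, then transfer. The `holder` variable and the no-arg
// InMemoryLookupCache constructor are assumptions for illustration only.
//
//   VocabCache cache = new InMemoryLookupCache();
//   holder.updateHuffmanCodes();                  // build/refresh the tree before transfer
//   holder.transferBackToVocabCache(cache, true); // true: clear this holder afterwards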
/**
 * Builds the Huffman binary tree, ordered by word counters.
 *
 * <p>Based on the original word2vec implementation by Google.
 */
public List<VocabularyWord> updateHuffmanCodes() {
  int min1i;
  int min2i;
  int b;
  int i;

  // get vocabulary as a sorted list
  List<VocabularyWord> vocab = this.words();
  int[] count = new int[vocab.size() * 2 + 1];
  int[] parent_node = new int[vocab.size() * 2 + 1];
  byte[] binary = new byte[vocab.size() * 2 + 1];

  // at this point vocab is sorted in descending order
  for (int a = 0; a < vocab.size(); a++) count[a] = vocab.get(a).getCount();

  for (int a = vocab.size(); a < vocab.size() * 2; a++) count[a] = Integer.MAX_VALUE;

  int pos1 = vocab.size() - 1;
  int pos2 = vocab.size();
  for (int a = 0; a < vocab.size(); a++) {
    // First, find the two smallest nodes 'min1, min2'
    if (pos1 >= 0) {
      if (count[pos1] < count[pos2]) {
        min1i = pos1;
        pos1--;
      } else {
        min1i = pos2;
        pos2++;
      }
    } else {
      min1i = pos2;
      pos2++;
    }
    if (pos1 >= 0) {
      if (count[pos1] < count[pos2]) {
        min2i = pos1;
        pos1--;
      } else {
        min2i = pos2;
        pos2++;
      }
    } else {
      min2i = pos2;
      pos2++;
    }

    count[vocab.size() + a] = count[min1i] + count[min2i];
    parent_node[min1i] = vocab.size() + a;
    parent_node[min2i] = vocab.size() + a;
    binary[min2i] = 1;
  }

  // Now assign a binary code to each vocabulary word
  byte[] code = new byte[MAX_CODE_LENGTH];
  int[] point = new int[MAX_CODE_LENGTH];

  for (int a = 0; a < vocab.size(); a++) {
    b = a;
    i = 0;
    byte[] lcode = new byte[MAX_CODE_LENGTH];
    int[] lpoint = new int[MAX_CODE_LENGTH];
    while (true) {
      code[i] = binary[b];
      point[i] = b;
      i++;
      b = parent_node[b];
      if (b == vocab.size() * 2 - 2) break;
    }

    lpoint[0] = vocab.size() - 2;
    for (b = 0; b < i; b++) {
      lcode[i - b - 1] = code[b];
      lpoint[i - b] = point[b] - vocab.size();
    }

    vocab.get(a).setHuffmanNode(new HuffmanNode(lcode, lpoint, a, (byte) i));
  }

  idxMap.clear();
  for (VocabularyWord word : vocab) {
    idxMap.put(word.getHuffmanNode().getIdx(), word);
  }

  return vocab;
}
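// Illustrative sketch (an addition, not part of the original class): the same two-pointer
// Huffman construction used in updateHuffmanCodes() above, reduced to computing code lengths
// for a small, descending-sorted count array. The method name and the sample counts are
// hypothetical; this only demonstrates the algorithm, not the library API.
private static int[] huffmanCodeLengthsSketch(int[] freq) {
  int n = freq.length;
  int[] count = new int[n * 2 + 1];
  int[] parent = new int[n * 2 + 1];
  System.arraycopy(freq, 0, count, 0, n);
  for (int a = n; a < n * 2; a++) count[a] = Integer.MAX_VALUE;

  int pos1 = n - 1; // walks left over the leaves (smallest counts first)
  int pos2 = n;     // walks right over newly created internal nodes
  for (int a = 0; a < n - 1; a++) {
    // pick the two smallest remaining nodes and merge them into internal node n + a
    int min1 = (pos1 >= 0 && count[pos1] < count[pos2]) ? pos1-- : pos2++;
    int min2 = (pos1 >= 0 && count[pos1] < count[pos2]) ? pos1-- : pos2++;
    count[n + a] = count[min1] + count[min2];
    parent[min1] = n + a;
    parent[min2] = n + a;
  }

  // code length of each leaf = number of hops up to the root (index 2n - 2)
  int[] lengths = new int[n];
  for (int a = 0; a < n; a++) {
    for (int b = a; b != n * 2 - 2; b = parent[b]) lengths[a]++;
  }
  // e.g. for counts {45, 13, 12, 9, 5, 4} this yields lengths {1, 3, 3, 3, 4, 4}
  return lengths;
}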