/**
 * Adds a word directly to the index, creating a token for it if one does not exist yet.
 *
 * @param index the index to store the word at
 * @param word  the word to add
 */
@Override
public synchronized void addWordToIndex(int index, String word) {
    if (word == null || word.isEmpty())
        throw new IllegalArgumentException("Word can't be empty or null");

    if (!tokens.containsKey(word)) {
        VocabWord token = new VocabWord(1.0, word);
        tokens.put(word, token);
        wordFrequencies.incrementCount(word, 1.0);
    }

    /*
     * A word added to the index directly is treated as a vocab word, not just a token.
     */
    if (!vocabs.containsKey(word)) {
        VocabWord vw = tokenFor(word);
        vw.setIndex(index);
        vocabs.put(word, vw);
    }

    if (!wordFrequencies.containsKey(word))
        wordFrequencies.incrementCount(word, 1);

    wordIndex.add(word, index);
}
@Override
public void updateWordsOccurencies() {
    totalWordOccurrences.set(0);
    for (VocabWord word : vocabWords()) {
        totalWordOccurrences.addAndGet((long) word.getElementFrequency());
    }
}
public InMemoryLookupCache(boolean addUnk) {
    if (addUnk) {
        VocabWord word = new VocabWord(1.0, Word2Vec.UNK);
        word.setIndex(0);
        addToken(word);
        addWordToIndex(0, Word2Vec.UNK);
        putVocabWord(Word2Vec.UNK);
    }
}
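// Illustrative usage (not part of the original source): a minimal sketch of how the addUnk
// constructor seeds the cache. Imports are omitted and mirror the surrounding file; the accessor
// names below (hasToken, containsWord, indexOf) follow the VocabCache API used elsewhere in this
// file and should be treated as assumptions.
public static void unkSeedingSketch() {
    InMemoryLookupCache cache = new InMemoryLookupCache(true);

    // Word2Vec.UNK is registered as both a token and a vocab word at index 0
    System.out.println(cache.hasToken(Word2Vec.UNK));     // expected: true
    System.out.println(cache.containsWord(Word2Vec.UNK)); // expected: true
    System.out.println(cache.indexOf(Word2Vec.UNK));      // expected: 0
}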
/**
 * Increment the count for the given word by the amount increment
 *
 * @param word      the word to increment the count for
 * @param increment the amount to increment by
 */
@Override
public synchronized void incrementWordCount(String word, int increment) {
    if (word == null || word.isEmpty())
        throw new IllegalArgumentException("Word can't be empty or null");

    wordFrequencies.incrementCount(word, increment);

    if (hasToken(word)) {
        VocabWord token = tokenFor(word);
        token.increment(increment);
    }

    totalWordOccurrences.set(totalWordOccurrences.get() + increment);
}
/**
 * Marks an existing token as a vocab word and registers it in the word index.
 *
 * @param word the word to promote to the vocabulary
 */
@Override
public synchronized void putVocabWord(String word) {
    if (word == null || word.isEmpty())
        throw new IllegalArgumentException("Word can't be empty or null");

    // STOP and UNK are not added as tokens
    if (word.equals("STOP") || word.equals("UNK"))
        return;

    VocabWord token = tokenFor(word);
    if (token == null)
        throw new IllegalStateException("Word " + word + " not found as token in vocab");

    int ind = token.getIndex();
    addWordToIndex(ind, word);
    if (!hasToken(word))
        throw new IllegalStateException("Word " + word + " is not registered as a token");

    vocabs.put(word, token);
    wordIndex.add(word, token.getIndex());
}
private Pair<INDArray, Double> update(AdaGrad weightAdaGrad, AdaGrad biasAdaGrad, INDArray syn0,
                                      INDArray bias, VocabWord w1, INDArray wordVector,
                                      INDArray contextVector, double gradient) {
    // gradient for the word vector: scale the context vector by the scalar gradient,
    // then let AdaGrad adjust it per-index
    INDArray grad1 = contextVector.mul(gradient);
    INDArray update = weightAdaGrad.getGradient(grad1, w1.getIndex(), syn0.shape());

    // AdaGrad-adjusted update for the word's bias term
    double w1Bias = bias.getDouble(w1.getIndex());
    double biasGradient = biasAdaGrad.getGradient(gradient, w1.getIndex(), bias.shape());
    double update2 = w1Bias - biasGradient;

    return new Pair<>(update, update2);
}
/**
 * Loads an in-memory cache from the given path (sets syn0 and the vocab).
 *
 * @param vectorsFile the path of the file to load
 * @return a pair of the populated lookup table and its vocab cache
 * @throws FileNotFoundException if the vectors file does not exist
 */
public static Pair<InMemoryLookupTable, VocabCache> loadTxt(File vectorsFile)
        throws FileNotFoundException {
    BufferedReader reader = new BufferedReader(new FileReader(vectorsFile));
    VocabCache cache = new InMemoryLookupCache();
    InMemoryLookupTable lookupTable;

    LineIterator iter = IOUtils.lineIterator(reader);
    List<INDArray> arrays = new ArrayList<>();
    while (iter.hasNext()) {
        String line = iter.nextLine();
        String[] split = line.split(" ");
        String word = split[0];
        VocabWord word1 = new VocabWord(1.0, word);
        cache.addToken(word1);
        cache.addWordToIndex(cache.numWords(), word);
        word1.setIndex(cache.numWords());
        cache.putVocabWord(word);

        INDArray row = Nd4j.create(Nd4j.createBuffer(split.length - 1));
        for (int i = 1; i < split.length; i++) {
            row.putScalar(i - 1, Float.parseFloat(split[i]));
        }
        arrays.add(row);
    }

    INDArray syn = Nd4j.create(new int[] {arrays.size(), arrays.get(0).columns()});
    for (int i = 0; i < syn.rows(); i++) {
        syn.putRow(i, arrays.get(i));
    }

    lookupTable = (InMemoryLookupTable) new InMemoryLookupTable.Builder()
            .vectorLength(arrays.get(0).columns())
            .useAdaGrad(false)
            .cache(cache)
            .build();

    Nd4j.clearNans(syn);
    lookupTable.setSyn0(syn);

    iter.close();
    return new Pair<>(lookupTable, cache);
}
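// Illustrative usage (not part of the original source): loading a plain-text vectors file with
// loadTxt. The path is hypothetical, and the file is assumed to follow the whitespace-delimited
// "word v1 v2 ... vN" layout that the parser above expects; getSyn0()/getRow()/getFirst()/
// getSecond()/wordAtIndex() are assumed accessors on InMemoryLookupTable, INDArray, Pair and
// VocabCache.
public static void loadTxtSketch() throws FileNotFoundException {
    File vectorsFile = new File("/tmp/vectors.txt"); // hypothetical path
    Pair<InMemoryLookupTable, VocabCache> loaded = loadTxt(vectorsFile);

    InMemoryLookupTable table = loaded.getFirst();
    VocabCache vocab = loaded.getSecond();

    // row i of syn0 holds the vector for the word stored at index i of the vocab
    INDArray firstVector = table.getSyn0().getRow(0);
    System.out.println(vocab.wordAtIndex(0) + " -> " + firstVector);
}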
private static void addTokenToVocabCache(InMemoryLookupCache vocab, String stringToken) {
    // Turn the string token into an actual token (VocabWord) if it is not one already
    VocabWord actualToken;
    if (vocab.hasToken(stringToken)) {
        actualToken = vocab.tokenFor(stringToken);
    } else {
        actualToken = new VocabWord(1, stringToken);
    }

    // Set the index of the actual token (VocabWord)
    // and put it into the vocabs map of the InMemoryLookupCache
    boolean vocabContainsWord = vocab.containsWord(stringToken);
    if (!vocabContainsWord) {
        vocab.addToken(actualToken);
        int idx = vocab.numWords();
        actualToken.setIndex(idx);
        vocab.putVocabWord(stringToken);
    }
}
/**
 * Load a lookup cache from an input stream delimited by \n
 *
 * @param from the input stream to read from
 * @return the in-memory lookup cache
 */
public static InMemoryLookupCache load(InputStream from) {
    Reader inputStream = new InputStreamReader(from);
    LineIterator iter = IOUtils.lineIterator(inputStream);
    String line;
    InMemoryLookupCache ret = new InMemoryLookupCache();
    int count = 0;
    while (iter.hasNext()) {
        line = iter.nextLine();
        if (line.isEmpty())
            continue;
        ret.incrementWordCount(line);
        VocabWord word = new VocabWord(1.0, line);
        word.setIndex(count);
        ret.addToken(word);
        ret.addWordToIndex(count, line);
        ret.putVocabWord(line);
        count++;
    }
    return ret;
}
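// Illustrative usage (not part of the original source): feeding load() a newline-delimited word
// list built in memory. Imports (java.io, java.nio.charset) are omitted; the words are made up,
// and each distinct non-empty line becomes a token and vocab word whose index is its position in
// the stream.
public static void loadSketch() {
    String words = "the\nquick\nbrown\nfox\n";
    InputStream in = new ByteArrayInputStream(words.getBytes(StandardCharsets.UTF_8));

    InMemoryLookupCache cache = load(in);

    // "quick" was the second non-empty line, so it should sit at index 1
    System.out.println(cache.indexOf("quick")); // expected: 1
}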
@Override
public void importVocabulary(VocabCache<VocabWord> vocabCache) {
    for (VocabWord word : vocabCache.vocabWords()) {
        if (!vocabs.containsKey(word.getLabel())) {
            tokens.put(word.getLabel(), word);
            vocabs.put(word.getLabel(), word);
        }
        wordFrequencies.incrementCount(word.getLabel(), word.getElementFrequency());
        totalWordOccurrences.addAndGet((long) word.getElementFrequency());
    }
}
private void addTokenToVocabCache(String stringToken, Double tokenCount) {
    // Turn the string token into an actual token (VocabWord) if it is not one already
    VocabWord actualToken;
    if (vocabCache.hasToken(stringToken)) {
        actualToken = vocabCache.tokenFor(stringToken);
        actualToken.increaseElementFrequency(tokenCount.intValue());
    } else {
        actualToken = new VocabWord(tokenCount, stringToken);
    }

    // Set the index of the actual token (VocabWord)
    // and put it into the vocabs map of the vocab cache
    boolean vocabContainsWord = vocabCache.containsWord(stringToken);
    if (!vocabContainsWord) {
        vocabCache.addToken(actualToken);
        int idx = vocabCache.numWords();
        actualToken.setIndex(idx);
        vocabCache.putVocabWord(stringToken);
    }
}
/**
 * Builds a VocabularyHolder from a VocabCache.
 *
 * <p>Tokens are ignored; only VocabularyWords are transferred, on the assumption that the cache
 * has already been truncated by minWordFrequency.
 *
 * <p>Huffman tree data is ignored and recalculated, due to a suspected flaw in the dl4j Huffman
 * implementation and its excessive memory usage.
 *
 * <p>This code is required for compatibility between the dl4j w2v implementation and the
 * standalone w2v implementation.
 *
 * @param cache         the VocabCache to import words from
 * @param markAsSpecial whether imported words should be marked as special (exempt from
 *                      minWordFrequency)
 */
protected VocabularyHolder(@NonNull VocabCache cache, boolean markAsSpecial) {
    this.vocabCache = cache;
    for (VocabWord word : cache.tokens()) {
        VocabularyWord vw = new VocabularyWord(word.getWord());
        vw.setCount((int) word.getWordFrequency());

        // since we're importing this word from an external VocabCache, we assume it is SPECIAL
        // and should NOT be affected by minWordFrequency
        vw.setSpecial(markAsSpecial);

        // note: we don't transfer Huffman data, since the proper way is to recalculate it after
        // new words have been added
        if (word.getPoints() != null && !word.getPoints().isEmpty()) {
            vw.setHuffmanNode(
                buildNode(word.getCodes(), word.getPoints(), word.getCodeLength(), word.getIndex()));
        }

        vocabulary.put(vw.getWord(), vw);
    }

    // there's no point building a Huffman tree just for the UNK word
    if (numWords() > 1)
        updateHuffmanCodes();

    logger.info("Init from VocabCache is complete. " + numWords() + " word(s) were transferred.");
}
/**
 * This method is required for compatibility purposes. It transfers the vocabulary from this
 * VocabularyHolder into a VocabCache.
 *
 * @param cache       the target VocabCache (only InMemoryLookupCache is supported)
 * @param emptyHolder whether to clear this holder after the transfer
 */
public void transferBackToVocabCache(VocabCache cache, boolean emptyHolder) {
    if (!(cache instanceof InMemoryLookupCache))
        throw new IllegalStateException("Only InMemoryLookupCache is supported.");

    // make sure that Huffman codes are updated before transfer
    List<VocabularyWord> words = words(); // updateHuffmanCodes();

    for (VocabularyWord word : words) {
        if (word.getWord().isEmpty())
            continue;

        VocabWord vocabWord = new VocabWord(1, word.getWord());

        // if we're transferring a full model, it CAN contain a HistoricalGradient for the
        // AdaptiveGradient feature
        if (word.getHistoricalGradient() != null) {
            INDArray gradient = Nd4j.create(word.getHistoricalGradient());
            vocabWord.setHistoricalGradient(gradient);
        }

        // put the VocabWord into both the tokens and vocabs maps
        ((InMemoryLookupCache) cache).getVocabs().put(word.getWord(), vocabWord);
        ((InMemoryLookupCache) cache).getTokens().put(word.getWord(), vocabWord);

        // update Huffman tree information
        if (word.getHuffmanNode() != null) {
            vocabWord.setIndex(word.getHuffmanNode().getIdx());
            vocabWord.setCodeLength(word.getHuffmanNode().getLength());
            vocabWord.setPoints(
                arrayToList(word.getHuffmanNode().getPoint(), word.getHuffmanNode().getLength()));
            vocabWord.setCodes(
                arrayToList(word.getHuffmanNode().getCode(), word.getHuffmanNode().getLength()));

            // put the word into the index
            cache.addWordToIndex(word.getHuffmanNode().getIdx(), word.getWord());
        }

        // update the VocabWord counter; subtract 1, since the VocabCache implementation uses 1
        // (not 0) as the base count for any token
        if (word.getCount() > 1)
            cache.incrementWordCount(word.getWord(), word.getCount() - 1);
    }

    // at this point it is safe to clear the holder's maps
    if (emptyHolder) {
        idxMap.clear();
        vocabulary.clear();
    }
}
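// Illustrative usage (not part of the original source): a sketch of moving a populated
// VocabularyHolder back into an InMemoryLookupCache. The holder is assumed to have been filled
// elsewhere (e.g. while reading a corpus); passing true empties the holder once the words,
// counts and Huffman data have been copied over.
public static InMemoryLookupCache toCacheSketch(VocabularyHolder holder) {
    InMemoryLookupCache cache = new InMemoryLookupCache();
    holder.transferBackToVocabCache(cache, true);
    return cache;
}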
@Override
public synchronized void addToken(VocabWord word) {
    tokens.put(word.getWord(), word);
}
public void iterateSample(VocabWord w1, VocabWord w2, double currentSentenceAlpha) {
    if (w1 == null || w2 == null || w2.getIndex() < 0 || w2.getIndex() == w1.getIndex())
        return;

    final int currentWordIndex = w2.getIndex();

    // accumulated error for the current word and context
    INDArray neu1e = Nd4j.create(vectorLength);

    // on the first pass, syn0 rows are still randomly initialized
    INDArray l1;
    if (indexSyn0VecMap.containsKey(vocab.elementAtIndex(currentWordIndex))) {
        l1 = indexSyn0VecMap.get(vocab.elementAtIndex(currentWordIndex));
    } else {
        l1 = getRandomSyn0Vec(vectorLength, (long) currentWordIndex);
    }

    // hierarchical softmax: walk the Huffman path of w1
    for (int i = 0; i < w1.getCodeLength(); i++) {
        int code = w1.getCodes().get(i);
        int point = w1.getPoints().get(i);
        if (point < 0)
            throw new IllegalStateException("Illegal point " + point);

        // syn1 row for this inner node of the Huffman tree
        INDArray syn1;
        if (pointSyn1VecMap.containsKey(point)) {
            syn1 = pointSyn1VecMap.get(point);
        } else {
            syn1 = Nd4j.zeros(1, vectorLength); // 1 row of vectorLength zeros
            pointSyn1VecMap.put(point, syn1);
        }

        // dot product of the syn0 and syn1 vectors
        double dot = Nd4j.getBlasWrapper().level1().dot(vectorLength, 1.0, l1, syn1);

        if (dot < -maxExp || dot >= maxExp)
            continue;

        int idx = (int) ((dot + maxExp) * ((double) expTable.length / maxExp / 2.0));
        if (idx >= expTable.length)
            continue;

        // score
        double f = expTable[idx];
        // gradient
        double g = (1 - code - f)
                * (useAdaGrad ? w1.getGradient(i, currentSentenceAlpha, currentSentenceAlpha)
                              : currentSentenceAlpha);

        Nd4j.getBlasWrapper().level1().axpy(vectorLength, g, syn1, neu1e);
        Nd4j.getBlasWrapper().level1().axpy(vectorLength, g, l1, syn1);
    }

    int target = w1.getIndex();
    int label;

    // negative sampling
    if (negative > 0)
        for (int d = 0; d < negative + 1; d++) {
            if (d == 0)
                label = 1;
            else {
                nextRandom.set(nextRandom.get() * 25214903917L + 11);
                int idx = Math.abs((int) (nextRandom.get() >> 16) % negativeHolder.getTable().length());

                target = negativeHolder.getTable().getInt(idx);
                if (target <= 0)
                    target = (int) nextRandom.get() % (vocab.numWords() - 1) + 1;

                if (target == w1.getIndex())
                    continue;
                label = 0;
            }

            if (target >= negativeHolder.getSyn1Neg().rows() || target < 0)
                continue;

            double f = Nd4j.getBlasWrapper().dot(l1, negativeHolder.getSyn1Neg().slice(target));
            double g;
            if (f > maxExp)
                g = useAdaGrad ? w1.getGradient(target, (label - 1), alpha) : (label - 1) * alpha;
            else if (f < -maxExp)
                g = label * (useAdaGrad ? w1.getGradient(target, alpha, alpha) : alpha);
            else {
                int idx = (int) ((f + maxExp) * (expTable.length / maxExp / 2));
                if (idx >= expTable.length)
                    continue;
                g = useAdaGrad ? w1.getGradient(target, label - expTable[idx], alpha)
                               : (label - expTable[idx]) * alpha;
            }

            Nd4j.getBlasWrapper().axpy((float) g, negativeHolder.getSyn1Neg().slice(target), neu1e);
            Nd4j.getBlasWrapper().axpy((float) g, l1, negativeHolder.getSyn1Neg().slice(target));
        }

    // update the syn0 vector based on the accumulated gradient; syn0 is no longer random
    Nd4j.getBlasWrapper().level1().axpy(vectorLength, 1.0f, neu1e, l1);

    if (aff.get() == 0) {
        synchronized (this) {
            cid.set(EnvironmentUtils.buildCId());
            aff.set(EnvironmentUtils.buildEnvironment().getAvailableMemory());
        }
    }

    VocabWord word = vocab.elementAtIndex(currentWordIndex);
    word.setVocabId(cid.get());
    word.setAffinityId(aff.get());
    indexSyn0VecMap.put(word, l1);
}
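// Illustrative sketch (not part of the original source): how a precomputed sigmoid table of the
// kind iterateSample() indexes into is typically built. expTableSize and maxExp are assumed
// parameters (word2vec conventionally uses 1000 and 6); the field sizes in this class may differ.
static double[] buildExpTable(int expTableSize, double maxExp) {
    double[] table = new double[expTableSize];
    for (int i = 0; i < expTableSize; i++) {
        // x sweeps the range (-maxExp, maxExp)
        double x = (i / (double) expTableSize * 2.0 - 1.0) * maxExp;
        double e = Math.exp(x);
        table[i] = e / (e + 1.0); // sigmoid(x)
    }
    return table;
}
// A dot product d in [-maxExp, maxExp) is then looked up with the same mapping used above:
// int idx = (int) ((d + maxExp) * (expTableSize / maxExp / 2.0));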
@Override
public void removeElement(VocabWord element) {
    removeElement(element.getLabel());
}