/**
   * Adds {@code word} to the word index at position {@code index}, registering it as a token and
   * promoting it to a vocabulary word first when needed.
   *
   * <p>If {@code word} is not yet in {@code tokens}, a new token with frequency 1.0 is created and
   * its frequency counter is incremented by 1. If it is not yet in {@code vocabs}, the object
   * returned by {@code tokenFor(word)} is given {@code index} and stored in {@code vocabs}. A word
   * already in {@code vocabs} keeps its existing token index; only the word-index mapping is added.
   *
   * @param index position to record for {@code word} in the word index
   * @param word the word to add; must be non-null and non-empty
   * @throws IllegalArgumentException if {@code word} is null or empty
   */
  @Override
  public synchronized void addWordToIndex(int index, String word) {
    if (word == null || word.isEmpty()) {
      throw new IllegalArgumentException("Word can't be empty or null");
    }

    if (!tokens.containsKey(word)) {
      VocabWord token = new VocabWord(1.0, word);
      tokens.put(word, token);
      wordFrequencies.incrementCount(word, 1.0);
    }

    // Adding a word to the index directly means it becomes a vocab word, not just a token.
    if (!vocabs.containsKey(word)) {
      VocabWord vw = tokenFor(word);
      vw.setIndex(index);
      vocabs.put(word, vw);
    }

    if (!wordFrequencies.containsKey(word)) {
      wordFrequencies.incrementCount(word, 1);
    }

    wordIndex.add(word, index);
  }
 /**
  * Recomputes the total word-occurrence counter as the sum of every vocab word's element
  * frequency. Each frequency is truncated to {@code long} before it is added.
  */
 @Override
 public void updateWordsOccurencies() {
   long total = 0L;
   for (VocabWord vocabWord : vocabWords()) {
     total += (long) vocabWord.getElementFrequency();
   }
   totalWordOccurrences.set(total);
 }
 /**
  * Creates a cache, optionally pre-seeded with the {@code Word2Vec.UNK} entry.
  *
  * @param addUnk when {@code true}, registers {@code Word2Vec.UNK} as a token with index 0, adds it
  *     to the word index at 0 and then passes it to {@code putVocabWord}
  */
 public InMemoryLookupCache(boolean addUnk) {
   if (!addUnk) {
     return;
   }
   VocabWord unknown = new VocabWord(1.0, Word2Vec.UNK);
   unknown.setIndex(0);
   addToken(unknown);
   addWordToIndex(0, Word2Vec.UNK);
   putVocabWord(Word2Vec.UNK);
 }
  /**
   * Increments the count for the given word by {@code increment}.
   *
   * <p>Updates the word-frequency counter, the token's own counter (only when the word is a known
   * token), and the total word-occurrence counter.
   *
   * @param word the word to increment the count for; must be non-null and non-empty
   * @param increment the amount to increment by
   * @throws IllegalArgumentException if {@code word} is null or empty
   */
  @Override
  public synchronized void incrementWordCount(String word, int increment) {
    if (word == null || word.isEmpty()) {
      throw new IllegalArgumentException("Word can't be empty or null");
    }
    wordFrequencies.incrementCount(word, increment);

    if (hasToken(word)) {
      VocabWord token = tokenFor(word);
      token.increment(increment);
    }
    // Atomic add: other methods update this counter without holding this monitor, so a separate
    // get() followed by set() could silently drop their concurrent updates.
    totalWordOccurrences.addAndGet(increment);
  }
 /**
  * Promotes an existing token to a vocabulary word, keeping the token's current index.
  *
  * <p>The literal words {@code "STOP"} and {@code "UNK"} are ignored.
  *
  * @param word the token to promote; must be non-null and non-empty
  * @throws IllegalArgumentException if {@code word} is null or empty
  * @throws IllegalStateException if {@code word} is not a known token
  */
 @Override
 public synchronized void putVocabWord(String word) {
   boolean blank = word == null || word.isEmpty();
   if (blank) {
     throw new IllegalArgumentException("Word can't be empty or null");
   }
   // STOP and UNK are never promoted here
   if ("STOP".equals(word) || "UNK".equals(word)) {
     return;
   }

   VocabWord existing = tokenFor(word);
   if (existing == null) {
     throw new IllegalStateException("Word " + word + " not found as token in vocab");
   }
   addWordToIndex(existing.getIndex(), word);

   if (!hasToken(word)) {
     throw new IllegalStateException("Unable to add token " + word + " when not already a token");
   }
   vocabs.put(word, existing);
   wordIndex.add(word, existing.getIndex());
 }
  /**
   * Computes the AdaGrad-scaled weight update and the new bias value for {@code w1}'s row.
   *
   * <p>The weight gradient is {@code contextVector * gradient}. The new bias is the current bias
   * at {@code w1}'s index minus the AdaGrad-scaled scalar gradient. {@code wordVector} is not used.
   *
   * @return pair of (weight update from {@code weightAdaGrad}, updated bias value)
   */
  private Pair<INDArray, Double> update(
      AdaGrad weightAdaGrad,
      AdaGrad biasAdaGrad,
      INDArray syn0,
      INDArray bias,
      VocabWord w1,
      INDArray wordVector,
      INDArray contextVector,
      double gradient) {
    int row = w1.getIndex();

    INDArray weightStep =
        weightAdaGrad.getGradient(contextVector.mul(gradient), row, syn0.shape());

    double currentBias = bias.getDouble(row);
    double newBias = currentBias - biasAdaGrad.getGradient(gradient, row, bias.shape());

    return new Pair<>(weightStep, newBias);
  }
  /**
   * Loads an in-memory lookup table and its vocabulary from a text vectors file.
   *
   * <p>Each line is split on single spaces as {@code word v1 v2 ... vN}. The values become the
   * line's row of syn0, in file order. Each word is registered as a token and vocab word whose
   * index is {@code cache.numWords()} as read just before that word is added.
   *
   * @param vectorsFile the path of the file to load
   * @return the lookup table (with syn0 set) paired with the vocabulary it was built from
   * @throws FileNotFoundException if {@code vectorsFile} cannot be opened
   * @throws IllegalArgumentException if the file contains no lines
   */
  public static Pair<InMemoryLookupTable, VocabCache> loadTxt(File vectorsFile)
      throws FileNotFoundException {
    BufferedReader reader = new BufferedReader(new FileReader(vectorsFile));
    LineIterator iter = IOUtils.lineIterator(reader);
    VocabCache cache = new InMemoryLookupCache();
    List<INDArray> arrays = new ArrayList<>();

    try {
      while (iter.hasNext()) {
        String line = iter.nextLine();
        String[] split = line.split(" ");
        String word = split[0];

        // Read the index once. addWordToIndex adds the word to the vocab, so numWords() read
        // afterwards may be one larger, leaving the token's index out of step with the index
        // recorded in the word index.
        int index = cache.numWords();
        VocabWord token = new VocabWord(1.0, word);
        token.setIndex(index);
        cache.addToken(token);
        cache.addWordToIndex(index, word);
        cache.putVocabWord(word);

        INDArray row = Nd4j.create(Nd4j.createBuffer(split.length - 1));
        for (int i = 1; i < split.length; i++) {
          row.putScalar(i - 1, Float.parseFloat(split[i]));
        }
        arrays.add(row);
      }
    } finally {
      // Close on every path, including parse failures, so the file handle is not leaked.
      iter.close();
    }

    if (arrays.isEmpty()) {
      throw new IllegalArgumentException("No word vectors found in " + vectorsFile);
    }

    int vectorLength = arrays.get(0).columns();
    INDArray syn = Nd4j.create(new int[] {arrays.size(), vectorLength});
    for (int i = 0; i < syn.rows(); i++) {
      syn.putRow(i, arrays.get(i));
    }

    InMemoryLookupTable lookupTable =
        (InMemoryLookupTable)
            new InMemoryLookupTable.Builder()
                .vectorLength(vectorLength)
                .useAdaGrad(false)
                .cache(cache)
                .build();
    Nd4j.clearNans(syn);
    lookupTable.setSyn0(syn);

    return new Pair<>(lookupTable, cache);
  }
Example #8
0
  /**
   * Registers {@code stringToken} as a vocabulary word in {@code vocab} unless it already is one.
   *
   * <p>Reuses the existing token when there is one; otherwise creates a token with count 1. A newly
   * registered word gets {@code vocab.numWords()}, read just after the token is added, as its index.
   *
   * @param vocab cache to update
   * @param stringToken word to register
   */
  private static void addTokenToVocabCache(InMemoryLookupCache vocab, String stringToken) {
    VocabWord token =
        vocab.hasToken(stringToken) ? vocab.tokenFor(stringToken) : new VocabWord(1, stringToken);

    if (vocab.containsWord(stringToken)) {
      return;
    }
    vocab.addToken(token);
    token.setIndex(vocab.numWords());
    vocab.putVocabWord(stringToken);
  }
  /**
   * Loads a lookup cache from an input stream containing one word per line.
   *
   * <p>Empty lines are skipped. Every other line becomes a token and vocab word whose index is its
   * 0-based position among the non-empty lines, and {@code incrementWordCount(line)} is called once
   * for it. The stream is decoded with the platform default charset and is not closed here.
   *
   * @param from the input stream to read from
   * @return the in memory lookup cache
   */
  public static InMemoryLookupCache load(InputStream from) {
    LineIterator lines = IOUtils.lineIterator(new InputStreamReader(from));
    InMemoryLookupCache cache = new InMemoryLookupCache();
    int nextIndex = 0;
    while (lines.hasNext()) {
      String label = lines.nextLine();
      if (label.isEmpty()) {
        continue;
      }
      int index = nextIndex++;
      cache.incrementWordCount(label);
      VocabWord vocabWord = new VocabWord(1.0, label);
      vocabWord.setIndex(index);
      cache.addToken(vocabWord);
      cache.addWordToIndex(index, label);
      cache.putVocabWord(label);
    }
    return cache;
  }
 /**
  * Merges the vocab words of {@code vocabCache} into this cache.
  *
  * <p>Words not yet in {@code vocabs} are added (the same object) to both {@code tokens} and
  * {@code vocabs}. Every imported word's element frequency is added to the word-frequency
  * counter and, truncated to long, to the total word-occurrence counter.
  *
  * @param vocabCache source vocabulary
  */
 @Override
 public void importVocabulary(VocabCache<VocabWord> vocabCache) {
   for (VocabWord incoming : vocabCache.vocabWords()) {
     String label = incoming.getLabel();
     if (!vocabs.containsKey(label)) {
       tokens.put(label, incoming);
       vocabs.put(label, incoming);
     }
     wordFrequencies.incrementCount(label, incoming.getElementFrequency());
     totalWordOccurrences.addAndGet((long) incoming.getElementFrequency());
   }
 }
  /**
   * Adds {@code tokenCount} occurrences of {@code stringToken} and makes it a vocabulary word.
   *
   * <p>An existing token has its element frequency increased by {@code tokenCount} truncated to
   * int. Otherwise a new token is created with {@code tokenCount} as its count. If the word is not
   * yet a vocab word, it is registered with {@code vocabCache.numWords()}, read after the token is
   * added, as its index.
   *
   * @param stringToken word to add
   * @param tokenCount number of occurrences; must not be null (it is unboxed)
   */
  private void addTokenToVocabCache(String stringToken, Double tokenCount) {
    VocabWord token;
    if (!vocabCache.hasToken(stringToken)) {
      token = new VocabWord(tokenCount, stringToken);
    } else {
      token = vocabCache.tokenFor(stringToken);
      token.increaseElementFrequency(tokenCount.intValue());
    }

    if (vocabCache.containsWord(stringToken)) {
      return;
    }
    vocabCache.addToken(token);
    token.setIndex(vocabCache.numWords());
    vocabCache.putVocabWord(stringToken);
  }
  /**
   * Builds VocabularyHolder from VocabCache.
   *
   * <p>Basically we just ignore tokens, and transfer VocabularyWords, supposing that it's already
   * truncated by minWordFrequency.
   *
   * <p>Huffman tree data is ignored and recalculated, due to suspectable flaw in dl4j huffman impl,
   * and it's exsessive memory usage.
   *
   * <p>This code is required for compatibility between dl4j w2v implementation, and standalone w2v
   *
   * @param cache
   */
  protected VocabularyHolder(@NonNull VocabCache cache, boolean markAsSpecial) {
    this.vocabCache = cache;
    for (VocabWord word : cache.tokens()) {
      VocabularyWord vw = new VocabularyWord(word.getWord());
      vw.setCount((int) word.getWordFrequency());

      // since we're importing this word from external VocabCache, we'll assume that this word is
      // SPECIAL, and should NOT be affected by minWordFrequency
      vw.setSpecial(markAsSpecial);

      // please note: we don't transfer huffman data, since proper way is  to recalculate it after
      // new words being added
      if (word.getPoints() != null && !word.getPoints().isEmpty()) {
        vw.setHuffmanNode(
            buildNode(word.getCodes(), word.getPoints(), word.getCodeLength(), word.getIndex()));
      }

      vocabulary.put(vw.getWord(), vw);
    }

    // there's no sense building huffman tree just for UNK word
    if (numWords() > 1) updateHuffmanCodes();
    logger.info("Init from VocabCache is complete. " + numWords() + " word(s) were transferred.");
  }
  /**
   * Copies this holder's vocabulary into {@code cache}. Kept for compatibility.
   *
   * <p>Every non-empty word becomes a new {@code VocabWord} (count 1) stored in both the tokens and
   * vocabs maps of the cache. Any historical gradient is copied. When the word has a Huffman node,
   * its index, code length, points and codes are copied and the word is added to the cache's word
   * index. Counts above 1 are then added through {@code incrementWordCount}.
   *
   * @param cache destination; must be an {@code InMemoryLookupCache}
   * @param emptyHolder when {@code true}, clears this holder's index map and vocabulary afterwards
   * @throws IllegalStateException if {@code cache} is not an {@code InMemoryLookupCache}
   */
  public void transferBackToVocabCache(VocabCache cache, boolean emptyHolder) {
    if (!(cache instanceof InMemoryLookupCache)) {
      throw new IllegalStateException("Sorry, only InMemoryLookupCache use implemented.");
    }
    InMemoryLookupCache target = (InMemoryLookupCache) cache;

    // NOTE: Huffman codes are NOT recomputed here (the updateHuffmanCodes() call is disabled);
    // whatever each word's Huffman node currently holds is transferred as-is.
    for (VocabularyWord holderWord : words()) {
      String label = holderWord.getWord();
      if (label.isEmpty()) {
        continue;
      }
      VocabWord vocabWord = new VocabWord(1, label);

      // A full model transfer may carry AdaGrad history for this word.
      if (holderWord.getHistoricalGradient() != null) {
        vocabWord.setHistoricalGradient(Nd4j.create(holderWord.getHistoricalGradient()));
      }

      target.getVocabs().put(label, vocabWord);
      target.getTokens().put(label, vocabWord);

      if (holderWord.getHuffmanNode() != null) {
        int length = holderWord.getHuffmanNode().getLength();
        vocabWord.setIndex(holderWord.getHuffmanNode().getIdx());
        vocabWord.setCodeLength(length);
        vocabWord.setPoints(arrayToList(holderWord.getHuffmanNode().getPoint(), length));
        vocabWord.setCodes(arrayToList(holderWord.getHuffmanNode().getCode(), length));

        cache.addWordToIndex(holderWord.getHuffmanNode().getIdx(), label);
      }

      // VocabWord already starts at count 1, so only the surplus (count - 1) is added.
      if (holderWord.getCount() > 1) {
        cache.incrementWordCount(label, holderWord.getCount() - 1);
      }
    }

    // After the transfer the holder's own copies are no longer needed.
    if (emptyHolder) {
      idxMap.clear();
      vocabulary.clear();
    }
  }
 /**
  * Stores {@code word} in the tokens map under its own word string, replacing any previous entry.
  *
  * @param word token to store
  */
 @Override
 public synchronized void addToken(VocabWord word) {
   final String key = word.getWord();
   tokens.put(key, word);
 }
  /**
   * Performs one training update for the pair ({@code w1}, {@code w2}).
   *
   * <p>The vector {@code l1} for {@code w2}'s index is trained against {@code w1}'s Huffman path
   * (hierarchical softmax). When {@code negative > 0}, it is also trained against {@code w1}'s
   * index (label 1) and up to {@code negative} sampled targets (label 0). The accumulated error is
   * then added to {@code l1}, which is stored in {@code indexSyn0VecMap} under
   * {@code vocab.elementAtIndex(w2.getIndex())}.
   *
   * @param w1 word whose codes, points and index drive the output-side updates
   * @param w2 word whose input vector is updated; the call is a no-op if either word is null,
   *     {@code w2}'s index is negative, or both words share an index
   * @param currentSentenceAlpha learning rate for the hierarchical-softmax step (the negative
   *     sampling step uses the {@code alpha} field)
   */
  public void iterateSample(VocabWord w1, VocabWord w2, double currentSentenceAlpha) {

    if (w1 == null || w2 == null || w2.getIndex() < 0 || w2.getIndex() == w1.getIndex()) {
      return;
    }
    final int currentWordIndex = w2.getIndex();

    // error accumulated for the current word's input vector
    INDArray neu1e = Nd4j.create(vectorLength);

    // Reuse the stored syn0 vector for this word; on first sight fall back to getRandomSyn0Vec
    // (presumably seeded by the word index — confirm).
    INDArray l1;
    if (indexSyn0VecMap.containsKey(vocab.elementAtIndex(currentWordIndex))) {
      l1 = indexSyn0VecMap.get(vocab.elementAtIndex(currentWordIndex));
    } else {
      l1 = getRandomSyn0Vec(vectorLength, (long) currentWordIndex);
    }

    // Maps a dot product in [-maxExp, maxExp) onto expTable. Floating-point division is used so
    // the scale is correct even if maxExp is an integer field.
    final double expScale = (double) expTable.length / maxExp / 2.0;

    // hierarchical softmax over w1's Huffman path
    for (int i = 0; i < w1.getCodeLength(); i++) {
      int code = w1.getCodes().get(i);
      int point = w1.getPoints().get(i);
      if (point < 0) {
        throw new IllegalStateException("Illegal point " + point);
      }
      INDArray syn1;
      if (pointSyn1VecMap.containsKey(point)) {
        syn1 = pointSyn1VecMap.get(point);
      } else {
        syn1 = Nd4j.zeros(1, vectorLength); // 1 row of vector length of zeros
        pointSyn1VecMap.put(point, syn1);
      }

      // Dot product of Syn0 and Syn1 vecs
      double dot = Nd4j.getBlasWrapper().level1().dot(vectorLength, 1.0, l1, syn1);

      if (dot < -maxExp || dot >= maxExp) {
        continue;
      }

      int idx = (int) ((dot + maxExp) * expScale);

      // Rounding can produce idx == expTable.length; '>=' keeps that out of bounds value out.
      if (idx >= expTable.length) {
        continue;
      }

      // score
      double f = expTable[idx];
      // gradient
      double g =
          (1 - code - f)
              * (useAdaGrad
                  ? w1.getGradient(i, currentSentenceAlpha, currentSentenceAlpha)
                  : currentSentenceAlpha);

      // accumulate the error using syn1 before syn1 itself is updated
      Nd4j.getBlasWrapper().level1().axpy(vectorLength, g, syn1, neu1e);
      Nd4j.getBlasWrapper().level1().axpy(vectorLength, g, l1, syn1);
    }

    int target = w1.getIndex();
    int label;
    // negative sampling
    if (negative > 0) {
      for (int d = 0; d < negative + 1; d++) {
        if (d == 0) {
          label = 1;
        } else {
          nextRandom.set(nextRandom.get() * 25214903917L + 11);
          int idx = Math.abs((int) (nextRandom.get() >> 16) % negativeHolder.getTable().length());

          target = negativeHolder.getTable().getInt(idx);
          // NOTE(review): the (int) cast of a long can be negative; such targets are skipped by
          // the range check below.
          if (target <= 0) {
            target = (int) nextRandom.get() % (vocab.numWords() - 1) + 1;
          }

          if (target == w1.getIndex()) {
            continue;
          }
          label = 0;
        }

        if (target >= negativeHolder.getSyn1Neg().rows() || target < 0) {
          continue;
        }

        double f = Nd4j.getBlasWrapper().dot(l1, negativeHolder.getSyn1Neg().slice(target));
        double g;
        if (f > maxExp) {
          g = useAdaGrad ? w1.getGradient(target, (label - 1), alpha) : (label - 1) * alpha;
        } else if (f < -maxExp) {
          g = label * (useAdaGrad ? w1.getGradient(target, alpha, alpha) : alpha);
        } else {
          int idx = (int) ((f + maxExp) * expScale);
          if (idx >= expTable.length) {
            continue;
          }

          g =
              useAdaGrad
                  ? w1.getGradient(target, label - expTable[idx], alpha)
                  : (label - expTable[idx]) * alpha;
        }

        Nd4j.getBlasWrapper().axpy((float) g, negativeHolder.getSyn1Neg().slice(target), neu1e);

        Nd4j.getBlasWrapper().axpy((float) g, l1, negativeHolder.getSyn1Neg().slice(target));
      }
    }

    // Apply the accumulated error to the Syn0 vector; from now on it is stored, not random.
    Nd4j.getBlasWrapper().level1().axpy(vectorLength, 1.0f, neu1e, l1);

    if (aff.get() == 0) {
      synchronized (this) {
        cid.set(EnvironmentUtils.buildCId());
        aff.set(EnvironmentUtils.buildEnvironment().getAvailableMemory());
      }
    }

    VocabWord word = vocab.elementAtIndex(currentWordIndex);
    word.setVocabId(cid.get());
    word.setAffinityId(aff.get());
    indexSyn0VecMap.put(word, l1);
  }
 /**
  * Removes {@code element} by delegating to the label-based {@code removeElement(String)}.
  *
  * @param element element whose label identifies the entry to remove
  */
 @Override
 public void removeElement(VocabWord element) {
   String label = element.getLabel();
   removeElement(label);
 }