/**
   * Writes the word vectors to the given path as plain text, one word per line.
   *
   * <p>Each line is the word (spaces replaced by underscores), a single space, and then the
   * components of the word's vector separated by single spaces. Null entries in the vocab are
   * skipped. Any existing file at {@code path} is overwritten, not appended to.
   *
   * <p>Note: the original documentation states that this assumes an in-memory cache.
   *
   * @param vec the word2vec to write
   * @param path the path to write
   * @throws IOException if the file cannot be opened or written
   */
  public static void writeWordVectors(Word2Vec vec, String path) throws IOException {
    int words = 0;
    // try-with-resources releases the file handle even when a write or a vector lookup fails.
    try (BufferedWriter write = new BufferedWriter(new FileWriter(new File(path), false))) {
      for (String word : vec.vocab().words()) {
        if (word == null) {
          continue;
        }
        StringBuilder sb = new StringBuilder();
        // Literal replacement; no regex is needed to swap a single character.
        sb.append(word.replace(' ', '_'));
        sb.append(" ");
        INDArray wordVector = vec.getWordVectorMatrix(word);
        for (int j = 0; j < wordVector.length(); j++) {
          sb.append(wordVector.getDouble(j));
          if (j < wordVector.length() - 1) {
            sb.append(" ");
          }
        }
        sb.append("\n");
        write.write(sb.toString());
        words++;
      }
      write.flush();
    }

    log.info("Wrote " + words + " with size of " + vec.lookupTable().layerSize());
  }
  /**
   * Writes the tsne format: one line per vocab word containing the comma-separated components of
   * that word's row in {@code tsne}, followed by a comma, the word itself and a trailing space.
   *
   * <p>The row is selected by the index that the vocab reports for the word, and the vocab of
   * {@code vec} is cast to {@code InMemoryLookupCache} to obtain it. Null vocab entries are
   * skipped.
   *
   * @param vec the word vectors to use for labeling
   * @param tsne the tsne array to write
   * @param csv the file to use
   * @throws Exception if the file cannot be written or a row lookup fails
   */
  public static void writeTsneFormat(Word2Vec vec, INDArray tsne, File csv) throws Exception {
    int words = 0;
    // try-with-resources releases the file handle even when a row lookup or write fails.
    try (BufferedWriter write = new BufferedWriter(new FileWriter(csv))) {
      InMemoryLookupCache l = (InMemoryLookupCache) vec.vocab();
      for (String word : vec.vocab().words()) {
        if (word == null) {
          continue;
        }
        StringBuilder sb = new StringBuilder();
        INDArray wordVector = tsne.getRow(l.wordFor(word).getIndex());
        for (int j = 0; j < wordVector.length(); j++) {
          sb.append(wordVector.getDouble(j));
          if (j < wordVector.length() - 1) {
            sb.append(",");
          }
        }
        sb.append(",");
        sb.append(word);
        sb.append(" ");

        sb.append("\n");
        write.write(sb.toString());
        // Count every written row so the summary below reports the real total.
        words++;
      }
      write.flush();
    }

    log.info("Wrote " + words + " with size of " + vec.lookupTable().layerSize());
  }
  /**
   * Switches the shared model to {@code FlatModelUtils}, asks for the ten words nearest to
   * "energy" and prints them. The test makes no assertions; it only has to run without throwing.
   */
  @Test
  public void testWordsNearestFlat1() throws Exception {
    final FlatModelUtils<VocabWord> flatUtils = new FlatModelUtils<VocabWord>();
    vec.setModelUtils(flatUtils);

    final Collection<String> nearest = vec.wordsNearest("energy", 10);

    log.info("Flat model results:");
    printWords("energy", nearest, vec);
  }
  /**
   * Loads a model from a binary file. The file is read through a {@code GZIPInputStream} when
   * {@code GzipUtils.isCompressedFilename} accepts its name, and as a plain file otherwise.
   *
   * <p>The stream starts with two strings parsed as integers: the number of entries and the
   * vector length. Each entry is a word string followed by that many floats. An entry whose word
   * is empty is skipped without reading any floats for it. Every vector read is passed through
   * {@code Transforms.unitVec} before being stored in the row matching the entry's position.
   *
   * @param modelFile the binary model file to read
   * @return a {@code Word2Vec} carrying the vocab and lookup table built from the file
   * @throws NumberFormatException if either header value is not a valid integer
   * @throws IOException if the file cannot be opened or read
   */
  private static Word2Vec readBinaryModel(File modelFile)
      throws NumberFormatException, IOException {
    InMemoryLookupTable table;
    VocabCache vocab;
    INDArray weights;
    int entryCount;
    int vectorLength;
    try (BufferedInputStream buffered =
            new BufferedInputStream(
                GzipUtils.isCompressedFilename(modelFile.getName())
                    ? new GZIPInputStream(new FileInputStream(modelFile))
                    : new FileInputStream(modelFile));
        DataInputStream in = new DataInputStream(buffered)) {
      // Header: entry count first, then the length of every vector.
      entryCount = Integer.parseInt(readString(in));
      vectorLength = Integer.parseInt(readString(in));
      weights = Nd4j.create(entryCount, vectorLength);
      vocab = new InMemoryLookupCache(false);
      table =
          (InMemoryLookupTable)
              new InMemoryLookupTable.Builder().cache(vocab).vectorLength(vectorLength).build();

      for (int row = 0; row < entryCount; row++) {
        String token = readString(in);
        log.trace("Loading " + token + " with word " + row);
        if (!token.isEmpty()) {
          float[] values = new float[vectorLength];
          int col = 0;
          while (col < vectorLength) {
            values[col] = readFloat(in);
            col++;
          }

          weights.putRow(row, Transforms.unitVec(Nd4j.create(values)));

          // Register the word under the next free vocab index.
          vocab.addWordToIndex(vocab.numWords(), token);
          vocab.addToken(new VocabWord(1, token));
          vocab.putVocabWord(token);
        }
      }
    }

    Word2Vec model = new Word2Vec();

    table.setSyn0(weights);
    model.setVocab(vocab);
    model.setLookupTable(table);
    return model;
  }
  /**
   * Loads a model from a text file.
   *
   * <p>The first line holds two space-separated integers: the number of words and the layer
   * size. Every following line is a word followed by its space-separated vector components.
   * Lines whose first token is empty are skipped. Each vector is passed through
   * {@code Transforms.unitVec} and stored in the next free row of the weight matrix.
   *
   * @param modelFile the text model file to read
   * @return a {@code Word2Vec} carrying the vocab and lookup table built from the file
   * @throws IOException if the file cannot be opened or read, or if it has no header line
   * @throws NumberFormatException if a header value or a vector component is not a valid number
   */
  private static Word2Vec readTextModel(File modelFile) throws IOException, NumberFormatException {
    InMemoryLookupTable lookupTable;
    VocabCache cache;
    INDArray syn0;
    int layerSize;
    // try-with-resources closes the reader even when parsing fails part-way through.
    try (BufferedReader reader = new BufferedReader(new FileReader(modelFile))) {
      String line = reader.readLine();
      if (line == null) {
        // An empty file has no header; fail with a descriptive error instead of an NPE.
        throw new IOException("Word vector file has no header line: " + modelFile);
      }
      String[] initial = line.split(" ");
      int words = Integer.parseInt(initial[0]);
      layerSize = Integer.parseInt(initial[1]);
      syn0 = Nd4j.create(words, layerSize);

      cache = new InMemoryLookupCache();

      int currLine = 0;
      while ((line = reader.readLine()) != null) {
        String[] split = line.split(" ");
        String word = split[0];

        if (word.isEmpty()) {
          continue;
        }

        float[] vector = new float[split.length - 1];
        for (int i = 1; i < split.length; i++) {
          vector[i - 1] = Float.parseFloat(split[i]);
        }

        syn0.putRow(currLine, Transforms.unitVec(Nd4j.create(vector)));
        // Advance only for stored words so each vector gets its own row rather than all of them
        // overwriting row 0. NOTE(review): this presumes cache.numWords() grows by one per added
        // word, which keeps rows aligned with the indices given to addWordToIndex - confirm.
        currLine++;

        cache.addWordToIndex(cache.numWords(), word);
        cache.addToken(new VocabWord(1, word));
        cache.putVocabWord(word);
      }
    }

    lookupTable =
        (InMemoryLookupTable)
            new InMemoryLookupTable.Builder().cache(cache).vectorLength(layerSize).build();
    lookupTable.setSyn0(syn0);

    Word2Vec ret = new Word2Vec();
    ret.setVocab(cache);
    ret.setLookupTable(lookupTable);
    return ret;
  }
  /**
   * Prepares the fixtures shared by the tests: the vectorizer, the tokenizer factory, the
   * sentence iterator and the word2vec model.
   *
   * <p>The model is read back from {@code wordvectors.ser} when that file exists, after which
   * its cache is replaced with a new {@code InMemoryLookupCache}. Otherwise the model is built
   * from the sentence iterator, fitted, and saved to that file.
   */
  @Before
  public void init() throws Exception {
    vectorizer = new TreeVectorizer();
    tokenizerFactory = new UimaTokenizerFactory(false);
    sentenceIter = new CollectionSentenceIterator(Arrays.asList(sentence));

    final File serialized = new File("wordvectors.ser");
    if (serialized.exists()) {
      // Reuse the previously trained model.
      vec = SerializationUtils.readObject(serialized);
      vec.setCache(new InMemoryLookupCache(vec.getLayerSize()));
      return;
    }

    // No saved model yet: train one and persist it.
    vec = new Word2Vec.Builder().iterate(sentenceIter).build();
    vec.fit();
    SerializationUtils.saveObject(vec, serialized);
  }
  /**
   * Runs three nearest-word queries for "energy" with {@code BasicModelUtils} and checks that
   * the stored vector for the word is the same before and after the queries.
   */
  @Test
  public void testWordsNearestBasic1() throws Exception {
    // To run against an external model instead, load it with:
    // WordVectors vec = WordVectorSerializer.loadTxtVectors(new
    // File("/ext/Temp/Models/model.dat_trans"));
    vec.setModelUtils(new BasicModelUtils<VocabWord>());

    final String target = "energy";

    // Copy of the vector taken before any query runs.
    final INDArray before = vec.getWordVectorMatrix(target).dup();

    System.out.println("[-]: " + before);
    System.out.println("[+]: " + Transforms.unitVec(before));

    final String[] headings = {
      "Transpose model results:", "Transpose model results 2:", "Transpose model results 3:"
    };
    for (String heading : headings) {
      Collection<String> nearest = vec.wordsNearest(target, 10);
      log.info(heading);
      printWords(target, nearest, vec);
    }

    // Copy taken after the queries, for comparison with the first one.
    final INDArray after = vec.getWordVectorMatrix(target).dup();

    assertEquals(before, after);
  }
// Example no. 8
// 0
 /**
  * Configures this builder from a {@code Word2Vec} model: stores the model's vocab in
  * {@code vocabCache}, then delegates to {@code setFeatureVectors} with the model's lookup
  * table.
  *
  * @param vec the model supplying both the vocab and the lookup table
  * @return the {@code Builder} returned by the delegated {@code setFeatureVectors} call
  */
 public Builder setFeatureVectors(Word2Vec vec) {
   // The vocab is recorded before delegating, so it is already set when the delegated call runs.
   vocabCache = vec.vocab();
   return setFeatureVectors(vec.lookupTable());
 }