コード例 #1
0
  /**
   * Reads a word2vec model stored in the binary format.
   *
   * <p>The stream begins with two string tokens that are parsed as the number of words and the
   * vector length. For each declared word one string token is read; an empty token is skipped,
   * otherwise {@code vectorLength} floats are read, passed through {@code Transforms.unitVec}
   * and stored as the matrix row for that iteration.
   *
   * @param modelFile the model file; it is wrapped in a {@code GZIPInputStream} when
   *     {@code GzipUtils.isCompressedFilename} accepts its name
   * @return a {@code Word2Vec} whose vocab and lookup table were set from the file contents
   * @throws NumberFormatException if either header token is not a valid integer
   * @throws IOException if the file cannot be opened or read
   */
  private static Word2Vec readBinaryModel(File modelFile)
      throws NumberFormatException, IOException {
    InMemoryLookupTable table;
    VocabCache vocab;
    INDArray weights;

    try (BufferedInputStream buffered =
            new BufferedInputStream(
                GzipUtils.isCompressedFilename(modelFile.getName())
                    ? new GZIPInputStream(new FileInputStream(modelFile))
                    : new FileInputStream(modelFile));
        DataInputStream in = new DataInputStream(buffered)) {
      // Header: word count first, then vector length.
      int wordCount = Integer.parseInt(readString(in));
      int vectorLength = Integer.parseInt(readString(in));

      weights = Nd4j.create(wordCount, vectorLength);
      vocab = new InMemoryLookupCache(false);
      table =
          (InMemoryLookupTable)
              new InMemoryLookupTable.Builder().cache(vocab).vectorLength(vectorLength).build();

      for (int row = 0; row < wordCount; row++) {
        String token = readString(in);
        log.trace("Loading " + token + " with word " + row);

        // NOTE(review): an empty token consumes a row index without storing anything, so the
        // matrix row and the vocab index can diverge afterwards — confirm this is intended.
        if (!token.isEmpty()) {
          float[] values = new float[vectorLength];
          int col = 0;
          while (col < vectorLength) {
            values[col] = readFloat(in);
            col++;
          }

          weights.putRow(row, Transforms.unitVec(Nd4j.create(values)));

          vocab.addWordToIndex(vocab.numWords(), token);
          vocab.addToken(new VocabWord(1, token));
          vocab.putVocabWord(token);
        }
      }
    }

    table.setSyn0(weights);

    Word2Vec model = new Word2Vec();
    model.setVocab(vocab);
    model.setLookupTable(table);
    return model;
  }
コード例 #2
0
  /**
   * Reads a word2vec model stored in the plain-text format.
   *
   * <p>The first line is split on single spaces; its first two tokens are parsed as the number of
   * words and the layer size. Every following line is split on single spaces into a word followed
   * by its vector components. Lines whose first token is empty are skipped. Each vector is passed
   * through {@code Transforms.unitVec} before being stored as the next row of the weight matrix.
   *
   * @param modelFile the text model file to read
   * @return a {@code Word2Vec} whose vocab and lookup table were set from the file contents
   * @throws FileNotFoundException if the file cannot be opened
   * @throws IOException if the file cannot be read or contains no header line
   * @throws NumberFormatException if a header token or vector component is not a valid number
   */
  private static Word2Vec readTextModel(File modelFile) throws IOException, NumberFormatException {
    // try-with-resources releases the reader even when parsing fails part-way through.
    try (BufferedReader reader = new BufferedReader(new FileReader(modelFile))) {
      String line = reader.readLine();
      if (line == null) {
        throw new IOException("Empty model file: " + modelFile);
      }
      String[] initial = line.split(" ");
      int words = Integer.parseInt(initial[0]);
      int layerSize = Integer.parseInt(initial[1]);
      INDArray syn0 = Nd4j.create(words, layerSize);

      VocabCache cache = new InMemoryLookupCache();

      int currLine = 0;
      while ((line = reader.readLine()) != null) {
        String[] split = line.split(" ");
        String word = split[0];

        if (word.isEmpty()) {
          continue;
        }

        float[] vector = new float[split.length - 1];
        for (int i = 1; i < split.length; i++) {
          vector[i - 1] = Float.parseFloat(split[i]);
        }

        syn0.putRow(currLine, Transforms.unitVec(Nd4j.create(vector)));

        cache.addWordToIndex(cache.numWords(), word);
        cache.addToken(new VocabWord(1, word));
        cache.putVocabWord(word);

        // Advance only after a word was stored, so each stored word gets its own row instead of
        // every vector overwriting row 0.
        currLine++;
      }

      InMemoryLookupTable lookupTable =
          (InMemoryLookupTable)
              new InMemoryLookupTable.Builder().cache(cache).vectorLength(layerSize).build();
      lookupTable.setSyn0(syn0);

      Word2Vec ret = new Word2Vec();
      ret.setVocab(cache);
      ret.setLookupTable(lookupTable);
      return ret;
    }
  }
コード例 #3
0
  /**
   * Loads an in-memory lookup table and vocab cache from a text vectors file (sets syn0 and the
   * vocab).
   *
   * <p>Each line is split on single spaces: the first token is the word and the remaining tokens
   * are parsed as floats forming that word's vector. The number of columns of the resulting
   * matrix is taken from the first line's vector. {@code Nd4j.clearNans} is applied to the matrix
   * before it is set on the lookup table.
   *
   * @param vectorsFile the path of the file to load
   * @return a pair of the populated lookup table and the vocab cache
   * @throws FileNotFoundException if the file cannot be opened
   * @throws IllegalArgumentException if the file contains no lines
   */
  public static Pair<InMemoryLookupTable, VocabCache> loadTxt(File vectorsFile)
      throws FileNotFoundException {
    VocabCache cache = new InMemoryLookupCache();
    List<INDArray> arrays = new ArrayList<>();

    BufferedReader reader = new BufferedReader(new FileReader(vectorsFile));
    LineIterator iter = IOUtils.lineIterator(reader);
    try {
      while (iter.hasNext()) {
        String line = iter.nextLine();
        String[] split = line.split(" ");
        String word = split[0];
        VocabWord word1 = new VocabWord(1.0, word);
        cache.addToken(word1);
        cache.addWordToIndex(cache.numWords(), word);
        word1.setIndex(cache.numWords());
        cache.putVocabWord(word);
        INDArray row = Nd4j.create(Nd4j.createBuffer(split.length - 1));
        for (int i = 1; i < split.length; i++) {
          row.putScalar(i - 1, Float.parseFloat(split[i]));
        }
        arrays.add(row);
      }
    } finally {
      // Close here so the file is released even when a line fails to parse.
      iter.close();
    }

    if (arrays.isEmpty()) {
      // Fail with a clear message instead of an IndexOutOfBoundsException from arrays.get(0).
      throw new IllegalArgumentException("No word vectors found in " + vectorsFile);
    }

    int columns = arrays.get(0).columns();
    INDArray syn = Nd4j.create(new int[] {arrays.size(), columns});
    for (int i = 0; i < syn.rows(); i++) {
      syn.putRow(i, arrays.get(i));
    }

    InMemoryLookupTable lookupTable =
        (InMemoryLookupTable)
            new InMemoryLookupTable.Builder()
                .vectorLength(columns)
                .useAdaGrad(false)
                .cache(cache)
                .build();
    Nd4j.clearNans(syn);
    lookupTable.setSyn0(syn);

    return new Pair<>(lookupTable, cache);
  }