/**
 * Reads a word2vec model stored in the binary layout: two header tokens
 * (word count, then vector size) followed by one entry per word consisting of
 * the word token and {@code size} float values.
 *
 * <p>The file is opened through a {@link GZIPInputStream} when
 * {@code GzipUtils.isCompressedFilename} accepts its name, otherwise it is read
 * as-is. Each vector is passed through {@code Transforms.unitVec} before being
 * stored as a row of syn0.
 *
 * <p>An entry whose word token is empty is skipped without reading any vector
 * values for it; that row index is still consumed and no row is written for it.
 *
 * @param modelFile the binary model file to read
 * @return a {@code Word2Vec} whose vocab and lookup table were filled from the file
 * @throws NumberFormatException if a header token is not a valid integer
 * @throws IOException if the file cannot be opened or read
 */
private static Word2Vec readBinaryModel(File modelFile) throws NumberFormatException, IOException {
    final boolean gzipped = GzipUtils.isCompressedFilename(modelFile.getName());

    InMemoryLookupTable table;
    VocabCache vocab;
    INDArray weights;

    try (BufferedInputStream buffered = new BufferedInputStream(
                    gzipped
                            ? new GZIPInputStream(new FileInputStream(modelFile))
                            : new FileInputStream(modelFile));
            DataInputStream in = new DataInputStream(buffered)) {

        // Header: number of words, then the length of every vector.
        final int wordCount = Integer.parseInt(readString(in));
        final int vectorSize = Integer.parseInt(readString(in));

        weights = Nd4j.create(wordCount, vectorSize);
        vocab = new InMemoryLookupCache(false);
        table = (InMemoryLookupTable) new InMemoryLookupTable.Builder()
                        .cache(vocab)
                        .vectorLength(vectorSize)
                        .build();

        for (int row = 0; row < wordCount; row++) {
            final String token = readString(in);
            log.trace("Loading " + token + " with word " + row);

            if (!token.isEmpty()) {
                final float[] values = new float[vectorSize];
                int col = 0;
                while (col < vectorSize) {
                    values[col] = readFloat(in);
                    col++;
                }

                weights.putRow(row, Transforms.unitVec(Nd4j.create(values)));

                // Register the word in the vocabulary under the next free index.
                vocab.addWordToIndex(vocab.numWords(), token);
                vocab.addToken(new VocabWord(1, token));
                vocab.putVocabWord(token);
            }
        }
    }

    Word2Vec model = new Word2Vec();
    table.setSyn0(weights);
    model.setVocab(vocab);
    model.setLookupTable(table);
    return model;
}
/**
 * Reads a word2vec model stored as text: a header line
 * {@code "<words> <layerSize>"} followed by one line per word of the form
 * {@code "<word> <v1> <v2> ..."}, all separated by single spaces.
 *
 * <p>Each vector is passed through {@code Transforms.unitVec} and stored as the
 * next row of syn0. Lines whose first token is empty are skipped and do not
 * consume a row, so row indices stay aligned with the order in which words are
 * added to the vocabulary.
 *
 * @param modelFile the text model file to read
 * @return a {@code Word2Vec} whose vocab and lookup table were filled from the file
 * @throws IOException if the file cannot be opened or read, or if it is empty
 *         (no header line)
 * @throws NumberFormatException if a header field or a vector component is not
 *         a valid number
 */
private static Word2Vec readTextModel(File modelFile) throws IOException, NumberFormatException {
    // try-with-resources: the reader is released even when parsing fails part-way.
    try (BufferedReader reader = new BufferedReader(new FileReader(modelFile))) {
        String line = reader.readLine();
        if (line == null) {
            throw new IOException("Model file " + modelFile + " is empty; expected a header line");
        }

        String[] initial = line.split(" ");
        int words = Integer.parseInt(initial[0]);
        int layerSize = Integer.parseInt(initial[1]);

        INDArray syn0 = Nd4j.create(words, layerSize);
        VocabCache cache = new InMemoryLookupCache();

        // Row of syn0 that the next accepted word is written to.
        int currLine = 0;
        while ((line = reader.readLine()) != null) {
            String[] split = line.split(" ");
            String word = split[0];
            if (word.isEmpty()) {
                continue;
            }

            float[] vector = new float[split.length - 1];
            for (int i = 1; i < split.length; i++) {
                vector[i - 1] = Float.parseFloat(split[i]);
            }

            syn0.putRow(currLine, Transforms.unitVec(Nd4j.create(vector)));

            cache.addWordToIndex(cache.numWords(), word);
            cache.addToken(new VocabWord(1, word));
            cache.putVocabWord(word);

            // Advance the row; without this every vector would overwrite row 0.
            currLine++;
        }

        InMemoryLookupTable lookupTable = (InMemoryLookupTable) new InMemoryLookupTable.Builder()
                        .cache(cache)
                        .vectorLength(layerSize)
                        .build();
        lookupTable.setSyn0(syn0);

        Word2Vec ret = new Word2Vec();
        ret.setVocab(cache);
        ret.setLookupTable(lookupTable);
        return ret;
    }
}
/**
 * Loads an in-memory lookup table and vocabulary cache from a text vectors
 * file (sets syn0 and the vocab).
 *
 * <p>Every line is split on single spaces: the first token is the word and the
 * remaining tokens are parsed as floats forming that word's vector. The vector
 * length of the resulting table is taken from the first line. The assembled
 * matrix is passed through {@code Nd4j.clearNans} before being set as syn0.
 *
 * <p>The underlying reader is closed on every exit path, including failures
 * while parsing.
 *
 * @param vectorsFile the path of the file to load
 * @return a pair of the populated lookup table and the vocabulary cache
 * @throws FileNotFoundException if {@code vectorsFile} cannot be opened
 * @throws NumberFormatException if a vector component is not a valid float
 * @throws IndexOutOfBoundsException if the file contains no lines
 */
public static Pair<InMemoryLookupTable, VocabCache> loadTxt(File vectorsFile) throws FileNotFoundException {
    BufferedReader reader = new BufferedReader(new FileReader(vectorsFile));
    LineIterator iter = IOUtils.lineIterator(reader);
    try {
        VocabCache cache = new InMemoryLookupCache();
        List<INDArray> arrays = new ArrayList<>();

        while (iter.hasNext()) {
            String line = iter.nextLine();
            String[] split = line.split(" ");
            String word = split[0];

            VocabWord word1 = new VocabWord(1.0, word);
            cache.addToken(word1);
            cache.addWordToIndex(cache.numWords(), word);
            // NOTE(review): the index is read after addWordToIndex; whether numWords()
            // has already advanced here depends on the cache implementation - confirm.
            word1.setIndex(cache.numWords());
            cache.putVocabWord(word);

            INDArray row = Nd4j.create(Nd4j.createBuffer(split.length - 1));
            for (int i = 1; i < split.length; i++) {
                row.putScalar(i - 1, Float.parseFloat(split[i]));
            }
            arrays.add(row);
        }

        // The first row defines the vector length for the whole table.
        int vectorLength = arrays.get(0).columns();

        INDArray syn = Nd4j.create(new int[] {arrays.size(), vectorLength});
        for (int i = 0; i < syn.rows(); i++) {
            syn.putRow(i, arrays.get(i));
        }

        InMemoryLookupTable lookupTable = (InMemoryLookupTable) new InMemoryLookupTable.Builder()
                        .vectorLength(vectorLength)
                        .useAdaGrad(false)
                        .cache(cache)
                        .build();

        Nd4j.clearNans(syn);
        lookupTable.setSyn0(syn);

        return new Pair<>(lookupTable, cache);
    } finally {
        // Closing the iterator releases the reader it wraps, on success and on failure.
        iter.close();
    }
}