/**
 * Writes the word vectors to the given path as text: one line per word, made of the
 * word (spaces replaced by underscores) followed by its space-separated vector
 * components. Note that this assumes an in memory cache.
 *
 * @param vec the word2vec to write
 * @param path the path to write; an existing file is overwritten
 * @throws IOException if the file cannot be opened or written
 */
public static void writeWordVectors(Word2Vec vec, String path) throws IOException {
    int words = 0;
    // try-with-resources closes (and therefore flushes) the writer even when a
    // write fails part-way through, so the file handle is never leaked.
    try (BufferedWriter write = new BufferedWriter(new FileWriter(new File(path), false))) {
        for (String word : vec.vocab().words()) {
            if (word == null) {
                continue;
            }
            StringBuilder sb = new StringBuilder();
            // The word is the first space-delimited token of the line, so it must
            // not contain spaces itself.
            sb.append(word.replace(' ', '_'));
            sb.append(" ");
            INDArray wordVector = vec.getWordVectorMatrix(word);
            for (int j = 0; j < wordVector.length(); j++) {
                sb.append(wordVector.getDouble(j));
                if (j < wordVector.length() - 1) {
                    sb.append(" ");
                }
            }
            sb.append("\n");
            write.write(sb.toString());
            words++;
        }
    }
    // Logged only after the file has been closed successfully.
    log.info("Wrote " + words + " with size of " + vec.lookupTable().layerSize());
}
/**
 * Write the tsne format: one line per word, holding the comma-separated components of
 * that word's row in {@code tsne}, followed by a comma, the word, a trailing space and
 * a newline.
 *
 * @param vec the word vectors to use for labeling; its vocab must be an
 *        {@link InMemoryLookupCache}
 * @param tsne the tsne array to write; rows are looked up by each word's vocab index
 * @param csv the file to use
 * @throws Exception if the vocab is not an in-memory cache or the file cannot be written
 */
public static void writeTsneFormat(Word2Vec vec, INDArray tsne, File csv) throws Exception {
    // Cast before opening the file so a wrong cache type fails without
    // truncating an existing output file.
    InMemoryLookupCache cache = (InMemoryLookupCache) vec.vocab();
    int words = 0;
    // try-with-resources closes (and flushes) the writer even if a write fails.
    try (BufferedWriter write = new BufferedWriter(new FileWriter(csv))) {
        for (String word : cache.words()) {
            if (word == null) {
                continue;
            }
            StringBuilder sb = new StringBuilder();
            INDArray wordVector = tsne.getRow(cache.wordFor(word).getIndex());
            for (int j = 0; j < wordVector.length(); j++) {
                sb.append(wordVector.getDouble(j));
                if (j < wordVector.length() - 1) {
                    sb.append(",");
                }
            }
            sb.append(",");
            sb.append(word);
            sb.append(" ");
            sb.append("\n");
            write.write(sb.toString());
            // Count the rows actually written so the log below reports a real total.
            words++;
        }
    }
    log.info("Wrote " + words + " with size of " + vec.lookupTable().layerSize());
}
/**
 * Looks up the ten nearest words to "energy" using {@code FlatModelUtils} and prints
 * them. There are no assertions; the test only checks that the lookup runs.
 */
@Test
public void testWordsNearestFlat1() throws Exception {
    final String target = "energy";

    vec.setModelUtils(new FlatModelUtils<VocabWord>());
    Collection<String> nearest = vec.wordsNearest(target, 10);

    log.info("Flat model results:");
    printWords(target, nearest, vec);
}
/**
 * Reads a binary word-vector model into a {@link Word2Vec} instance.
 *
 * The stream starts with two string-encoded integers (word count, then vector length).
 * Each entry is then a word followed by that many floats. The input is wrapped in a
 * {@link GZIPInputStream} when the file name looks compressed. Every vector is
 * converted with {@code Transforms.unitVec} before being stored.
 *
 * @param modelFile the binary model file, optionally gzip-compressed
 * @return a Word2Vec whose vocab and lookup table were filled from the file
 * @throws NumberFormatException if either header value is not an integer
 * @throws IOException if the file cannot be opened or read
 */
private static Word2Vec readBinaryModel(File modelFile) throws NumberFormatException, IOException {
    InMemoryLookupTable table;
    VocabCache vocab;
    INDArray weights;

    try (BufferedInputStream buffered = new BufferedInputStream(
                    GzipUtils.isCompressedFilename(modelFile.getName())
                                    ? new GZIPInputStream(new FileInputStream(modelFile))
                                    : new FileInputStream(modelFile));
                    DataInputStream in = new DataInputStream(buffered)) {
        final int wordCount = Integer.parseInt(readString(in));
        final int vectorLength = Integer.parseInt(readString(in));

        weights = Nd4j.create(wordCount, vectorLength);
        vocab = new InMemoryLookupCache(false);
        table = (InMemoryLookupTable) new InMemoryLookupTable.Builder()
                        .cache(vocab)
                        .vectorLength(vectorLength)
                        .build();

        for (int row = 0; row < wordCount; row++) {
            String word = readString(in);
            log.trace("Loading " + word + " with word " + row);

            // NOTE(review): an empty word is skipped without reading its floats, and
            // the matrix row uses the loop counter while the vocab index uses
            // numWords(). These may drift apart after a skip; confirm against readString.
            if (!word.isEmpty()) {
                float[] components = new float[vectorLength];
                for (int k = 0; k < vectorLength; k++) {
                    components[k] = readFloat(in);
                }

                weights.putRow(row, Transforms.unitVec(Nd4j.create(components)));

                vocab.addWordToIndex(vocab.numWords(), word);
                vocab.addToken(new VocabWord(1, word));
                vocab.putVocabWord(word);
            }
        }
    }

    Word2Vec model = new Word2Vec();
    table.setSyn0(weights);
    model.setVocab(vocab);
    model.setLookupTable(table);
    return model;
}
/**
 * Reads a text word-vector model into a {@link Word2Vec} instance.
 *
 * The first line holds the word count and the layer size, separated by a space. Each
 * following line holds a word and its space-separated vector components. Lines whose
 * first token is empty are skipped. Every vector is converted with
 * {@code Transforms.unitVec} before being stored.
 *
 * @param modelFile the text model file
 * @return a Word2Vec whose vocab and lookup table were filled from the file
 * @throws IOException if the file is empty or cannot be opened or read
 * @throws NumberFormatException if a header value or vector component is not numeric
 */
private static Word2Vec readTextModel(File modelFile) throws IOException, NumberFormatException {
    InMemoryLookupTable lookupTable;
    VocabCache cache;
    INDArray syn0;
    int layerSize;

    // try-with-resources closes the reader even when parsing fails part-way.
    try (BufferedReader reader = new BufferedReader(new FileReader(modelFile))) {
        String line = reader.readLine();
        if (line == null) {
            // An empty file has no header; fail with a clear message instead of an NPE.
            throw new IOException("Word vector file is empty: " + modelFile);
        }
        String[] initial = line.split(" ");
        int words = Integer.parseInt(initial[0]);
        layerSize = Integer.parseInt(initial[1]);
        syn0 = Nd4j.create(words, layerSize);
        cache = new InMemoryLookupCache();

        int currLine = 0;
        while ((line = reader.readLine()) != null) {
            String[] split = line.split(" ");
            String word = split[0];
            if (word.isEmpty()) {
                continue;
            }
            float[] vector = new float[split.length - 1];
            for (int i = 1; i < split.length; i++) {
                vector[i - 1] = Float.parseFloat(split[i]);
            }
            syn0.putRow(currLine, Transforms.unitVec(Nd4j.create(vector)));
            cache.addWordToIndex(cache.numWords(), word);
            cache.addToken(new VocabWord(1, word));
            cache.putVocabWord(word);
            // Advance to the next matrix row. Without this every vector would
            // overwrite row 0. Only stored words advance the row.
            currLine++;
        }
    }

    lookupTable = (InMemoryLookupTable) new InMemoryLookupTable.Builder().cache(cache).vectorLength(layerSize).build();
    lookupTable.setSyn0(syn0);
    Word2Vec ret = new Word2Vec();
    ret.setVocab(cache);
    ret.setLookupTable(lookupTable);
    return ret;
}
/**
 * Builds the fixtures used by each test. The Word2Vec model is loaded from
 * {@code wordvectors.ser} when that file exists. Otherwise it is trained on the
 * sentence iterator and saved to that file.
 */
@Before
public void init() throws Exception {
    vectorizer = new TreeVectorizer();
    tokenizerFactory = new UimaTokenizerFactory(false);
    sentenceIter = new CollectionSentenceIterator(Arrays.asList(sentence));

    File serialized = new File("wordvectors.ser");
    if (serialized.exists()) {
        // Reuse the saved model and give it a fresh in-memory cache.
        vec = SerializationUtils.readObject(serialized);
        vec.setCache(new InMemoryLookupCache(vec.getLayerSize()));
        return;
    }

    // No saved model yet: train one and save it.
    vec = new Word2Vec.Builder().iterate(sentenceIter).build();
    vec.fit();
    SerializationUtils.saveObject(vec, serialized);
}
@Test public void testWordsNearestBasic1() throws Exception { // WordVectors vec = WordVectorSerializer.loadTxtVectors(new // File("/ext/Temp/Models/model.dat_trans")); vec.setModelUtils(new BasicModelUtils<VocabWord>()); String target = "energy"; INDArray arr1 = vec.getWordVectorMatrix(target).dup(); System.out.println("[-]: " + arr1); System.out.println("[+]: " + Transforms.unitVec(arr1)); Collection<String> list = vec.wordsNearest(target, 10); log.info("Transpose model results:"); printWords(target, list, vec); list = vec.wordsNearest(target, 10); log.info("Transpose model results 2:"); printWords(target, list, vec); list = vec.wordsNearest(target, 10); log.info("Transpose model results 3:"); printWords(target, list, vec); INDArray arr2 = vec.getWordVectorMatrix(target).dup(); assertEquals(arr1, arr2); }
/**
 * Configures the feature vectors from a Word2Vec model: stores the model's vocab in
 * {@code vocabCache}, then delegates to the {@code setFeatureVectors} overload that
 * takes the model's lookup table.
 *
 * @param vec the model that supplies both the vocab and the lookup table
 * @return the result of the delegated overload (presumably this builder; confirm)
 */
public Builder setFeatureVectors(Word2Vec vec) { vocabCache = vec.vocab(); return setFeatureVectors(vec.lookupTable()); }