/**
 * Writes the word vectors to the given path as plain text. Note that this assumes an in memory cache.
 *
 * <p>One line is written per syn0 row that has a word in {@code cache}. Each line holds the word
 * (spaces replaced by underscores), then the vector components; all fields are separated by a
 * single space. Rows with no word at their index are skipped.
 *
 * @param lookupTable the lookup table whose syn0 row count drives the iteration and whose vectors are written
 * @param cache the vocabulary used to map each row index back to its word
 * @param path the path to write; an existing file is overwritten (append = false)
 * @throws IOException if the file cannot be created or written
 */
public static void writeWordVectors(
        InMemoryLookupTable lookupTable, InMemoryLookupCache cache, String path) throws IOException {
    // try-with-resources closes the file even when a write fails part-way through
    try (BufferedWriter write = new BufferedWriter(new FileWriter(new File(path), false))) {
        int rows = lookupTable.getSyn0().rows();
        for (int i = 0; i < rows; i++) {
            String word = cache.wordAtIndex(i);
            if (word == null) {
                continue;
            }
            StringBuilder sb = new StringBuilder();
            // the output is space-delimited, so spaces inside a word would break parsing
            sb.append(word.replace(' ', '_'));
            sb.append(" ");
            INDArray wordVector = lookupTable.vector(word);
            for (int j = 0; j < wordVector.length(); j++) {
                sb.append(wordVector.getDouble(j));
                if (j < wordVector.length() - 1) {
                    sb.append(" ");
                }
            }
            sb.append("\n");
            write.write(sb.toString());
        }
        write.flush();
    }
}
/** * Words nearest based on positive and negative words * * @param positive the positive words * @param negative the negative words * @param top the top n words * @return the words nearest the mean of the words */ public Collection<String> wordsNearestSum( Collection<String> positive, Collection<String> negative, int top) { INDArray words = Nd4j.create(lookupTable().layerSize()); Set<String> union = SetUtils.union(new HashSet<>(positive), new HashSet<>(negative)); for (String s : positive) words.addi(lookupTable().vector(s)); for (String s : negative) words.addi(lookupTable.vector(s).mul(-1)); if (lookupTable() instanceof InMemoryLookupTable) { InMemoryLookupTable l = (InMemoryLookupTable) lookupTable(); INDArray syn0 = l.getSyn0(); INDArray weights = syn0.norm2(0).rdivi(1).muli(words); INDArray distances = syn0.mulRowVector(weights).sum(1); INDArray[] sorted = Nd4j.sortWithIndices(distances, 0, false); INDArray sort = sorted[0]; List<String> ret = new ArrayList<>(); if (top > sort.length()) top = sort.length(); // there will be a redundant word int end = top; for (int i = 0; i < end; i++) { String word = vocab.wordAtIndex(sort.getInt(i)); if (union.contains(word)) { end++; if (end >= sort.length()) break; continue; } String add = vocab().wordAtIndex(sort.getInt(i)); if (add == null || add.equals("UNK") || add.equals("STOP")) { end++; if (end >= sort.length()) break; continue; } ret.add(vocab().wordAtIndex(sort.getInt(i))); } return ret; } Counter<String> distances = new Counter<>(); for (String s : vocab().words()) { INDArray otherVec = getWordVectorMatrix(s); double sim = Transforms.cosineSim(words, otherVec); distances.incrementCount(s, sim); } distances.keepTopNKeys(top); return distances.keySet(); }
/**
 * Reads a word2vec model stored in binary form.
 *
 * <p>The file is read through a GZIP stream when {@code GzipUtils.isCompressedFilename} reports the
 * file name as compressed, otherwise as a plain file. The header holds two tokens: the number of
 * words and the vector length. Each entry is a word token followed by {@code size} floats. Every
 * vector is scaled to unit length before being stored. Vocabulary index {@code i} always refers to
 * syn0 row {@code i}.
 *
 * @param modelFile the model file to read
 * @return a Word2Vec backed by an in-memory lookup table and vocabulary
 * @throws NumberFormatException if a header token is not a valid integer
 * @throws IOException if the file cannot be read
 */
private static Word2Vec readBinaryModel(File modelFile) throws NumberFormatException, IOException {
    InMemoryLookupTable lookupTable;
    VocabCache cache;
    INDArray syn0;
    int words, size;

    try (BufferedInputStream bis = new BufferedInputStream(
                    GzipUtils.isCompressedFilename(modelFile.getName())
                            ? new GZIPInputStream(new FileInputStream(modelFile))
                            : new FileInputStream(modelFile));
            DataInputStream dis = new DataInputStream(bis)) {
        words = Integer.parseInt(readString(dis));
        size = Integer.parseInt(readString(dis));
        syn0 = Nd4j.create(words, size);
        cache = new InMemoryLookupCache(false);
        lookupTable = (InMemoryLookupTable)
                new InMemoryLookupTable.Builder().cache(cache).vectorLength(size).build();

        for (int i = 0; i < words; i++) {
            // readString can yield an empty token (presumably at a stray separator such as the line
            // break after a vector -- confirm against readString). Skipping it must not consume an entry
            // slot. Otherwise the vectors are not read for that slot, row i stays zero, the vocab index
            // drifts away from the syn0 row, and the last declared entry is never loaded.
            String word;
            do {
                word = readString(dis);
            } while (word.isEmpty());
            log.trace("Loading " + word + " with word " + i);

            float[] vector = new float[size];
            for (int j = 0; j < size; j++) {
                vector[j] = readFloat(dis);
            }
            syn0.putRow(i, Transforms.unitVec(Nd4j.create(vector)));
            // index the word by the same i as its syn0 row so lookups by index stay aligned
            cache.addWordToIndex(i, word);
            cache.addToken(new VocabWord(1, word));
            cache.putVocabWord(word);
        }
    }

    Word2Vec ret = new Word2Vec();
    lookupTable.setSyn0(syn0);
    ret.setVocab(cache);
    ret.setLookupTable(lookupTable);
    return ret;
}
/**
 * Returns the words whose vectors are most similar to the given vector.
 *
 * <p>For an {@link InMemoryLookupTable}, the syn0 rows are scaled to unit L2 length once, in place,
 * guarded by the {@code normalized} flag. Candidates are then preselected by the dot product with the
 * unit-length query, which equals cosine similarity for unit rows. Candidates whose word is null,
 * "UNK" or "STOP" are dropped. The rest are re-scored with {@link Transforms#cosineSim} and sorted,
 * and the top {@code top} labels are returned. Other lookup tables fall back to cosine similarity
 * over the whole vocabulary.
 *
 * @param words the query vector
 * @param top the top n words
 * @return the words nearest the query vector
 */
@Override
public Collection<String> wordsNearest(INDArray words, int top) {
    if (lookupTable instanceof InMemoryLookupTable) {
        InMemoryLookupTable l = (InMemoryLookupTable) lookupTable;
        INDArray syn0 = l.getSyn0();
        // NOTE(review): this double-checked pattern is only safe if `normalized` is declared volatile --
        // confirm on the field declaration.
        if (!normalized) {
            synchronized (this) {
                if (!normalized) {
                    // L2 (not L1) row norms: the dot product below is only cosine similarity for unit-L2 rows.
                    syn0.diviColumnVector(syn0.norm2(1));
                    normalized = true;
                }
            }
        }

        INDArray similarity = Transforms.unitVec(words).mmul(syn0.transpose());
        // fetch 20 extra candidates, presumably to leave room for the null/UNK/STOP entries filtered below
        List<Double> highToLowSimList = getTopN(similarity, top + 20);

        List<WordSimilarity> result = new ArrayList<>();
        for (int i = 0; i < highToLowSimList.size(); i++) {
            String word = vocabCache.wordAtIndex(highToLowSimList.get(i).intValue());
            if (word != null && !word.equals("UNK") && !word.equals("STOP")) {
                INDArray otherVec = lookupTable.vector(word);
                double sim = Transforms.cosineSim(words, otherVec);
                result.add(new WordSimilarity(word, sim));
            }
        }
        Collections.sort(result, new SimilarityComparator());
        return getLabels(result, top);
    }

    Counter<String> distances = new Counter<>();
    for (String s : vocabCache.words()) {
        INDArray otherVec = lookupTable.vector(s);
        double sim = Transforms.cosineSim(words, otherVec);
        distances.incrementCount(s, sim);
    }
    distances.keepTopNKeys(top);
    return distances.keySet();
}
/**
 * Reads a word2vec model stored as text.
 *
 * <p>The first line is a header of the form {@code "<words> <layerSize>"}. Every following line is
 * a word followed by its vector components, separated by single spaces. Lines whose first field is
 * empty are skipped. Each vector is scaled to unit length. Each word is stored in syn0 row
 * {@code k}, where {@code k} counts the entries loaded so far, and is indexed under the same
 * {@code k} in the vocabulary.
 *
 * @param modelFile the model file to read
 * @return a Word2Vec backed by an in-memory lookup table and vocabulary
 * @throws FileNotFoundException if the file cannot be opened
 * @throws IOException if the file cannot be read or has no header line
 * @throws NumberFormatException if the header or a vector component is not a valid number
 */
private static Word2Vec readTextModel(File modelFile) throws IOException, NumberFormatException {
    INDArray syn0;
    VocabCache cache;
    int layerSize;

    // try-with-resources: the reader is closed even when parsing throws
    try (BufferedReader reader = new BufferedReader(new FileReader(modelFile))) {
        String line = reader.readLine();
        if (line == null) {
            throw new IOException("Model file has no header line: " + modelFile);
        }
        String[] initial = line.split(" ");
        int words = Integer.parseInt(initial[0]);
        layerSize = Integer.parseInt(initial[1]);
        syn0 = Nd4j.create(words, layerSize);
        cache = new InMemoryLookupCache();

        int currLine = 0;
        while ((line = reader.readLine()) != null) {
            String[] split = line.split(" ");
            String word = split[0];
            if (word.isEmpty()) {
                continue;
            }
            float[] vector = new float[split.length - 1];
            for (int i = 1; i < split.length; i++) {
                vector[i - 1] = Float.parseFloat(split[i]);
            }
            syn0.putRow(currLine, Transforms.unitVec(Nd4j.create(vector)));
            cache.addWordToIndex(currLine, word);
            cache.addToken(new VocabWord(1, word));
            cache.putVocabWord(word);
            // advance to the next row; without this every vector would overwrite row 0
            currLine++;
        }
    }

    InMemoryLookupTable lookupTable = (InMemoryLookupTable)
            new InMemoryLookupTable.Builder().cache(cache).vectorLength(layerSize).build();
    lookupTable.setSyn0(syn0);
    Word2Vec ret = new Word2Vec();
    ret.setVocab(cache);
    ret.setLookupTable(lookupTable);
    return ret;
}
/**
 * Loads an in memory cache from the given path (sets syn0 and the vocab).
 *
 * <p>Every line of the file is one entry: a word followed by its vector components, separated by
 * single spaces. There is no header line. Rows are added to syn0 in file order, the vector length
 * is taken from the first entry, and {@code Nd4j.clearNans} is applied to the assembled matrix
 * before it is installed.
 *
 * @param vectorsFile the path of the file to load
 * @return the lookup table (with syn0 set) paired with the vocabulary built from the file
 * @throws FileNotFoundException if the file cannot be opened
 * @throws IllegalArgumentException if the file contains no entries
 */
public static Pair<InMemoryLookupTable, VocabCache> loadTxt(File vectorsFile) throws FileNotFoundException {
    BufferedReader reader = new BufferedReader(new FileReader(vectorsFile));
    LineIterator iter = IOUtils.lineIterator(reader);
    VocabCache cache = new InMemoryLookupCache();
    List<INDArray> arrays = new ArrayList<>();
    try {
        while (iter.hasNext()) {
            String line = iter.nextLine();
            String[] split = line.split(" ");
            String word = split[0];

            VocabWord word1 = new VocabWord(1.0, word);
            cache.addToken(word1);
            cache.addWordToIndex(cache.numWords(), word);
            word1.setIndex(cache.numWords());
            cache.putVocabWord(word);

            INDArray row = Nd4j.create(Nd4j.createBuffer(split.length - 1));
            for (int i = 1; i < split.length; i++) {
                row.putScalar(i - 1, Float.parseFloat(split[i]));
            }
            arrays.add(row);
        }
    } finally {
        // closing the LineIterator closes the underlying reader; in finally so a parse error cannot leak it
        iter.close();
    }

    if (arrays.isEmpty()) {
        throw new IllegalArgumentException("No word vectors found in " + vectorsFile);
    }

    int columns = arrays.get(0).columns();
    INDArray syn = Nd4j.create(new int[] {arrays.size(), columns});
    for (int i = 0; i < syn.rows(); i++) {
        syn.putRow(i, arrays.get(i));
    }

    InMemoryLookupTable lookupTable = (InMemoryLookupTable) new InMemoryLookupTable.Builder()
            .vectorLength(columns)
            .useAdaGrad(false)
            .cache(cache)
            .build();
    Nd4j.clearNans(syn);
    lookupTable.setSyn0(syn);
    return new Pair<>(lookupTable, cache);
}
/** * Get the top n words most similar to the given word * * @param word the word to compare * @param n the n to get * @return the top n words */ public Collection<String> wordsNearestSum(String word, int n) { INDArray vec = Transforms.unitVec(this.getWordVectorMatrix(word)); if (lookupTable() instanceof InMemoryLookupTable) { InMemoryLookupTable l = (InMemoryLookupTable) lookupTable(); INDArray syn0 = l.getSyn0(); INDArray weights = syn0.norm2(0).rdivi(1).muli(vec); INDArray distances = syn0.mulRowVector(weights).sum(1); INDArray[] sorted = Nd4j.sortWithIndices(distances, 0, false); INDArray sort = sorted[0]; List<String> ret = new ArrayList<>(); SequenceElement word2 = vocab().wordFor(word); if (n > sort.length()) n = sort.length(); // there will be a redundant word for (int i = 0; i < n + 1; i++) { if (sort.getInt(i) == word2.getIndex()) continue; String add = vocab().wordAtIndex(sort.getInt(i)); if (add == null || add.equals("UNK") || add.equals("STOP")) { continue; } ret.add(vocab().wordAtIndex(sort.getInt(i))); } return ret; } if (vec == null) return new ArrayList<>(); Counter<String> distances = new Counter<>(); for (String s : vocab().words()) { if (s.equals(word)) continue; INDArray otherVec = getWordVectorMatrix(s); double sim = Transforms.cosineSim(vec, otherVec); distances.incrementCount(s, sim); } distances.keepTopNKeys(n); return distances.keySet(); }
/** * Words nearest based on positive and negative words * @param top the top n words * * @return the words nearest the mean of the words */ @Override public Collection<String> wordsNearest(INDArray words, int top) { if (lookupTable() instanceof InMemoryLookupTable) { InMemoryLookupTable l = (InMemoryLookupTable) lookupTable(); INDArray syn0 = l.getSyn0(); INDArray weights = syn0.norm2(0).rdivi(1).muli(words); INDArray distances = syn0.mulRowVector(weights).mean(1); INDArray[] sorted = Nd4j.sortWithIndices(distances, 0, false); INDArray sort = sorted[0]; List<String> ret = new ArrayList<>(); if (top > sort.length()) top = sort.length(); // there will be a redundant word int end = top; for (int i = 0; i < end; i++) { VocabCache vocabCache = vocab(); int s = sort.getInt(0, i); String add = vocabCache.wordAtIndex(s); if (add == null || add.equals("UNK") || add.equals("STOP")) { end++; if (end >= sort.length()) break; continue; } ret.add(vocabCache.wordAtIndex(s)); } return ret; } Counter<String> distances = new Counter<>(); for (String s : vocab().words()) { INDArray otherVec = getWordVectorMatrix(s); double sim = Transforms.cosineSim(words, otherVec); distances.incrementCount(s, sim); } distances.keepTopNKeys(top); return distances.keySet(); }
/** * Words nearest based on positive and negative words * * @param positive the positive words * @param negative the negative words * @param top the top n words * @return the words nearest the mean of the words */ @Override public Collection<String> wordsNearest( Collection<String> positive, Collection<String> negative, int top) { // Check every word is in the model for (String p : SetUtils.union(new HashSet<>(positive), new HashSet<>(negative))) { if (!vocab().containsWord(p)) { return new ArrayList<>(); } } WeightLookupTable weightLookupTable = lookupTable(); INDArray words = Nd4j.create(positive.size() + negative.size(), weightLookupTable.layerSize()); int row = 0; Set<String> union = SetUtils.union(new HashSet<>(positive), new HashSet<>(negative)); for (String s : positive) { words.putRow(row++, weightLookupTable.vector(s)); } for (String s : negative) { words.putRow(row++, weightLookupTable.vector(s).mul(-1)); } INDArray mean = words.isMatrix() ? words.mean(0) : words; // TODO this should probably be replaced with wordsNearest(mean, top) if (weightLookupTable instanceof InMemoryLookupTable) { InMemoryLookupTable l = (InMemoryLookupTable) weightLookupTable; INDArray syn0 = l.getSyn0(); syn0.diviRowVector(syn0.norm2(0)); INDArray similarity = Transforms.unitVec(mean).mmul(syn0.transpose()); // We assume that syn0 is normalized. // Hence, the following division is not needed anymore. 
// distances.diviRowVector(distances.norm2(1)); // INDArray[] sorted = Nd4j.sortWithIndices(distances,0,false); List<Double> highToLowSimList = getTopN(similarity, top + union.size()); List<String> ret = new ArrayList<>(); for (int i = 0; i < highToLowSimList.size(); i++) { String word = vocab().wordAtIndex(highToLowSimList.get(i).intValue()); if (word != null && !word.equals("UNK") && !word.equals("STOP") && !union.contains(word)) { ret.add(word); if (ret.size() >= top) { break; } } } return ret; } Counter<String> distances = new Counter<>(); for (String s : vocab().words()) { INDArray otherVec = getWordVectorMatrix(s); double sim = Transforms.cosineSim(mean, otherVec); distances.incrementCount(s, sim); } distances.keepTopNKeys(top); return distances.keySet(); }