public FirstIterationFunction( Broadcast<Map<String, Object>> word2vecVarMapBroadcast, Broadcast<double[]> expTableBroadcast, Broadcast<VocabCache<VocabWord>> vocabCacheBroadcast) { Map<String, Object> word2vecVarMap = word2vecVarMapBroadcast.getValue(); this.expTable = expTableBroadcast.getValue(); this.vectorLength = (int) word2vecVarMap.get("vectorLength"); this.useAdaGrad = (boolean) word2vecVarMap.get("useAdaGrad"); this.negative = (double) word2vecVarMap.get("negative"); this.window = (int) word2vecVarMap.get("window"); this.alpha = (double) word2vecVarMap.get("alpha"); this.minAlpha = (double) word2vecVarMap.get("minAlpha"); this.totalWordCount = (long) word2vecVarMap.get("totalWordCount"); this.seed = (long) word2vecVarMap.get("seed"); this.maxExp = (int) word2vecVarMap.get("maxExp"); this.iterations = (int) word2vecVarMap.get("iterations"); this.batchSize = (int) word2vecVarMap.get("batchSize"); this.indexSyn0VecMap = new HashMap<>(); this.pointSyn1VecMap = new HashMap<>(); this.vocab = vocabCacheBroadcast.getValue(); if (this.vocab == null) throw new RuntimeException("VocabCache is null"); if (negative > 0) { negativeHolder = NegativeHolder.getInstance(); negativeHolder.initHolder(vocab, expTable, this.vectorLength); } }
public void iterateSample(VocabWord w1, VocabWord w2, double currentSentenceAlpha) { if (w1 == null || w2 == null || w2.getIndex() < 0 || w2.getIndex() == w1.getIndex()) return; final int currentWordIndex = w2.getIndex(); // error for current word and context INDArray neu1e = Nd4j.create(vectorLength); // First iteration Syn0 is random numbers INDArray l1 = null; if (indexSyn0VecMap.containsKey(vocab.elementAtIndex(currentWordIndex))) { l1 = indexSyn0VecMap.get(vocab.elementAtIndex(currentWordIndex)); } else { l1 = getRandomSyn0Vec(vectorLength, (long) currentWordIndex); } // for (int i = 0; i < w1.getCodeLength(); i++) { int code = w1.getCodes().get(i); int point = w1.getPoints().get(i); if (point < 0) throw new IllegalStateException("Illegal point " + point); // Point to INDArray syn1; if (pointSyn1VecMap.containsKey(point)) { syn1 = pointSyn1VecMap.get(point); } else { syn1 = Nd4j.zeros(1, vectorLength); // 1 row of vector length of zeros pointSyn1VecMap.put(point, syn1); } // Dot product of Syn0 and Syn1 vecs double dot = Nd4j.getBlasWrapper().level1().dot(vectorLength, 1.0, l1, syn1); if (dot < -maxExp || dot >= maxExp) continue; int idx = (int) ((dot + maxExp) * ((double) expTable.length / maxExp / 2.0)); if (idx > expTable.length) continue; // score double f = expTable[idx]; // gradient double g = (1 - code - f) * (useAdaGrad ? w1.getGradient(i, currentSentenceAlpha, currentSentenceAlpha) : currentSentenceAlpha); Nd4j.getBlasWrapper().level1().axpy(vectorLength, g, syn1, neu1e); Nd4j.getBlasWrapper().level1().axpy(vectorLength, g, l1, syn1); } int target = w1.getIndex(); int label; // negative sampling if (negative > 0) for (int d = 0; d < negative + 1; d++) { if (d == 0) label = 1; else { nextRandom.set(nextRandom.get() * 25214903917L + 11); int idx = Math.abs((int) (nextRandom.get() >> 16) % negativeHolder.getTable().length()); target = negativeHolder.getTable().getInt(idx); if (target <= 0) target = (int) nextRandom.get() % (vocab.numWords() - 1) + 1; if (target == w1.getIndex()) continue; label = 0; } if (target >= negativeHolder.getSyn1Neg().rows() || target < 0) continue; double f = Nd4j.getBlasWrapper().dot(l1, negativeHolder.getSyn1Neg().slice(target)); double g; if (f > maxExp) g = useAdaGrad ? w1.getGradient(target, (label - 1), alpha) : (label - 1) * alpha; else if (f < -maxExp) g = label * (useAdaGrad ? w1.getGradient(target, alpha, alpha) : alpha); else { int idx = (int) ((f + maxExp) * (expTable.length / maxExp / 2)); if (idx >= expTable.length) continue; g = useAdaGrad ? w1.getGradient(target, label - expTable[idx], alpha) : (label - expTable[idx]) * alpha; } Nd4j.getBlasWrapper().axpy((float) g, negativeHolder.getSyn1Neg().slice(target), neu1e); Nd4j.getBlasWrapper().axpy((float) g, l1, negativeHolder.getSyn1Neg().slice(target)); } // Updated the Syn0 vector based on gradient. Syn0 is not random anymore. Nd4j.getBlasWrapper().level1().axpy(vectorLength, 1.0f, neu1e, l1); if (aff.get() == 0) { synchronized (this) { cid.set(EnvironmentUtils.buildCId()); aff.set(EnvironmentUtils.buildEnvironment().getAvailableMemory()); } } VocabWord word = vocab.elementAtIndex(currentWordIndex); word.setVocabId(cid.get()); word.setAffinityId(aff.get()); indexSyn0VecMap.put(word, l1); }