/**
 * Runs a single training pass over the rows of {@code corpusWeights}, leaving out the
 * documents selected as held-out by {@code testFraction}.
 *
 * <p>A document is left out when {@code testFraction} is non-zero and its id is an exact
 * multiple of {@code 1 / testFraction}. Every other document is trained synchronously
 * against a fresh, uniform topic vector. The trainer is started before the pass and
 * stopped after it, and the elapsed wall time is reported through {@code logTime}.
 *
 * @param testFraction fraction that determines which document ids are skipped;
 *     {@code 0} means every document is trained
 */
public void trainDocuments(double testFraction) {
   final long startedAt = System.nanoTime();
   modelTrainer.start();
   for (int docId = 0; docId < corpusWeights.numRows(); docId++) {
     // Held-out documents take no part in training.
     if (testFraction != 0 && docId % (1 / testFraction) == 0) {
       continue;
     }
     // Each document starts from a uniform topic distribution
     // (docTopicCounts.getRow(docId) is deliberately not used here).
     double uniformWeight = 1.0 / numTopics;
     Vector uniformTopics = new DenseVector(numTopics).assign(uniformWeight);
     modelTrainer.trainSync(corpusWeights.viewRow(docId), uniformTopics, true, 10);
   }
   modelTrainer.stop();
   long elapsedNanos = System.nanoTime() - startedAt;
   logTime("train documents", elapsedNanos);
 }
 /**
  * Trains the model repeatedly until the perplexity stabilises or the iteration budget is
  * used up.
  *
  * <p>The method runs in two phases. First, {@code minIter} warm-up passes are executed
  * unconditionally. Then further passes run while fewer than {@code maxIterations} passes
  * have been made in total and the fractional change in perplexity between consecutive
  * passes is still greater than {@code minFractionalErrorChange}. Perplexity is recomputed
  * after every pass with the same {@code testFraction} that is used for training.
  *
  * @param minFractionalErrorChange convergence threshold on
  *     {@code |new - old| / old} perplexity
  * @param maxIterations upper bound on the total pass count for the convergence phase
  * @param minIter number of warm-up passes that always run
  * @param testFraction hold-out fraction forwarded to both training and perplexity
  *     calculation
  * @return the perplexity computed after the last pass that ran, or {@code 0} if no pass
  *     ran at all
  */
 public double iterateUntilConvergence(
      double minFractionalErrorChange, int maxIterations, int minIter, double testFraction) {
    int iter = 0;
    double oldPerplexity = 0;
    while (iter < minIter) {
      trainDocuments(testFraction);
      if (verbose) {
        log.info("model after: " + iter + ": " + modelTrainer.getReadModel().toString());
      }
      log.info("iteration " + iter + " complete");
      oldPerplexity = modelTrainer.calculatePerplexity(corpusWeights, docTopicCounts, testFraction);
      log.info(oldPerplexity + " = perplexity");
      iter++;
    }
    // Seed with the warm-up result so that it, rather than 0, is reported and returned
    // when the convergence loop below never executes (maxIterations <= minIter).
    double newPerplexity = oldPerplexity;
    double fractionalChange = Double.MAX_VALUE;
    while (iter < maxIterations && fractionalChange > minFractionalErrorChange) {
      // Keep skipping the same held-out documents as the warm-up passes; otherwise the
      // perplexity below would be measured on documents the model was trained on.
      trainDocuments(testFraction);
      if (verbose) {
        log.info("model after: " + iter + ": " + modelTrainer.getReadModel().toString());
      }
      newPerplexity = modelTrainer.calculatePerplexity(corpusWeights, docTopicCounts, testFraction);
      log.info(newPerplexity + " = perplexity");
      iter++;
      fractionalChange = Math.abs(newPerplexity - oldPerplexity) / oldPerplexity;
      log.info(fractionalChange + " = fractionalChange");
      oldPerplexity = newPerplexity;
    }
    if (iter < maxIterations) {
      log.info(
          String.format(
              "Converged! fractional error change: %f, error %f", fractionalChange, newPerplexity));
    } else {
      log.info(
          String.format(
              "Reached max iteration count (%d), fractional error change: %f, error: %f",
              maxIterations, fractionalChange, newPerplexity));
    }
    return newPerplexity;
  }
  // Example #3
  // 0
 /**
  * Registers this object's {@code delegate} with the supplied trainer by calling
  * {@code registerModel} on it.
  *
  * @param modelTrainer the trainer that should receive {@code delegate}
  */
 public void registerWithTrainer(ModelTrainer modelTrainer) {
    modelTrainer.registerModel(delegate);
  }
 /**
  * Writes the model out by delegating to {@code modelTrainer.persist}.
  *
  * @param outputPath destination handed unchanged to {@code persist}
  * @throws IOException declared by this method; presumably propagated from
  *     {@code persist} (confirm against its signature)
  */
 public void writeModel(Path outputPath) throws IOException {
    modelTrainer.persist(outputPath);
  }