/**
 * Allocates the persistent map structures that back the model's counters.
 *
 * <p>Delegates to the superclass first so that any inherited structures are
 * initialized, then requests one map per counter from the factory, using the
 * map type and LRU cache size configured in {@code memoryConfiguration}.
 *
 * @param bdsf                the factory that creates the backing maps
 * @param memoryConfiguration supplies the map type and LRU cache size
 */
@Override
public void bigDataStructureInitializer(
        BigDataStructureFactory bdsf, MemoryConfiguration memoryConfiguration) {
    super.bigDataStructureInitializer(bdsf, memoryConfiguration);

    BigDataStructureFactory.MapType storageType = memoryConfiguration.getMapType();
    int lruCacheSize = memoryConfiguration.getLRUsize();

    // One backing map per model counter; the string names identify the
    // underlying tables/collections inside the factory's storage engine.
    topicAssignmentOfDocumentWord =
            bdsf.getMap("topicAssignmentOfDocumentWord", storageType, lruCacheSize);
    documentTopicCounts = bdsf.getMap("documentTopicCounts", storageType, lruCacheSize);
    topicWordCounts = bdsf.getMap("topicWordCounts", storageType, lruCacheSize);
    documentWordCounts = bdsf.getMap("documentWordCounts", storageType, lruCacheSize);
    topicCounts = bdsf.getMap("topicCounts", storageType, lruCacheSize);
}
/**
 * Predicts topic assignments for the records of {@code newData} and returns
 * validation metrics containing the resulting perplexity.
 *
 * <p>This method uses a similar approach to the training but the most important
 * difference is that we do not wish to modify the original training parameters.
 * As a result, additional temporary count maps are created for the testing data
 * and merged with the parameters from the training data in order to make a
 * decision. The temporary maps are dropped before returning.
 *
 * @param newData the dataset to predict and evaluate; each Record's X entries
 *                map word positions to words, and YPredicted is overwritten
 * @return a ValidationMetrics object with its perplexity field set
 */
private ValidationMetrics predictAndValidate(Dataset newData) {
    // This method uses similar approach to the training but the most important
    // difference is that we do not wish to modify the original training params.
    // as a result we need to modify the code to use additional temporary
    // counts for the testing data and merge them with the parameters from the
    // training data in order to make a decision
    ModelParameters modelParameters = knowledgeBase.getModelParameters();
    TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters();

    // create new validation metrics object
    ValidationMetrics validationMetrics = knowledgeBase.getEmptyValidationMetricsObject();

    String tmpPrefix = StorageConfiguration.getTmpPrefix();

    // get model parameters
    int n = modelParameters.getN(); // NOTE(review): n is never read below — looks dead; confirm before removing
    int d = modelParameters.getD(); // presumably the dictionary/feature dimensionality — used in the beta*d smoothing term
    int k = trainingParameters.getK(); // number of topics
    Map<List<Object>, Integer> topicWordCounts = modelParameters.getTopicWordCounts();
    Map<Integer, Integer> topicCounts = modelParameters.getTopicCounts();

    BigDataStructureFactory.MapType mapType = knowledgeBase.getMemoryConfiguration().getMapType();
    int LRUsize = knowledgeBase.getMemoryConfiguration().getLRUsize();
    BigDataStructureFactory bdsf = knowledgeBase.getBdsf();

    // we create temporary maps for the prediction sets to avoid modifying the maps that we already
    // learned
    Map<List<Object>, Integer> tmp_topicAssignmentOfDocumentWord = bdsf.getMap(tmpPrefix + "topicAssignmentOfDocumentWord", mapType, LRUsize);
    Map<List<Integer>, Integer> tmp_documentTopicCounts = bdsf.getMap(tmpPrefix + "documentTopicCounts", mapType, LRUsize);
    Map<List<Object>, Integer> tmp_topicWordCounts = bdsf.getMap(tmpPrefix + "topicWordCounts", mapType, LRUsize);
    Map<Integer, Integer> tmp_topicCounts = bdsf.getMap(tmpPrefix + "topicCounts", mapType, LRUsize);

    // initialize topic assignments of each word randomly and update the counters
    for (Record r : newData) {
        Integer documentId = r.getId();

        for (Map.Entry<Object, Object> entry : r.getX().entrySet()) {
            Object wordPosition = entry.getKey();
            Object word = entry.getValue();

            // sample a topic uniformly at random as the initial assignment
            Integer topic = PHPfunctions.mt_rand(0, k - 1);

            increase(tmp_topicCounts, topic);
            tmp_topicAssignmentOfDocumentWord.put(Arrays.asList(documentId, wordPosition), topic);
            increase(tmp_documentTopicCounts, Arrays.asList(documentId, topic));
            increase(tmp_topicWordCounts, Arrays.asList(topic, word));
        }
    }

    double alpha = trainingParameters.getAlpha(); // document-topic Dirichlet prior
    double beta = trainingParameters.getBeta();   // topic-word Dirichlet prior

    int maxIterations = trainingParameters.getMaxIterations();

    double perplexity = Double.MAX_VALUE;
    for (int iteration = 0; iteration < maxIterations; ++iteration) {
        if (GeneralConfiguration.DEBUG) {
            System.out.println("Iteration " + iteration);
        }

        // collapsed gibbs sampler
        int changedCounter = 0;
        perplexity = 0.0;
        double totalDatasetWords = 0.0;
        for (Record r : newData) {
            Integer documentId = r.getId();

            // per-document topic histogram, normalized by document length below;
            // pre-seed every topic with 0.0 so all k keys exist
            AssociativeArray topicAssignments = new AssociativeArray();
            for (int j = 0; j < k; ++j) {
                topicAssignments.put(j, 0.0);
            }

            int totalDocumentWords = r.getX().size();
            totalDatasetWords += totalDocumentWords;

            for (Map.Entry<Object, Object> entry : r.getX().entrySet()) {
                Object wordPosition = entry.getKey();
                Object word = entry.getValue();

                // remove the word from the dataset: decrement all counters for
                // its current topic before recomputing the posterior
                Integer topic = tmp_topicAssignmentOfDocumentWord.get(Arrays.asList(documentId, wordPosition));
                decrease(tmp_topicCounts, topic);
                decrease(tmp_documentTopicCounts, Arrays.asList(documentId, topic));
                decrease(tmp_topicWordCounts, Arrays.asList(topic, word));

                int numberOfDocumentWords = r.getX().size() - 1; // document length excluding the removed word

                // compute the posteriors of the topics and sample from it
                AssociativeArray topicProbabilities = new AssociativeArray();
                for (int j = 0; j < k; ++j) {
                    double enumerator = 0.0;

                    // get the counts from the current testing data
                    List<Object> topicWordKey = Arrays.asList(j, word);
                    Integer njw = tmp_topicWordCounts.get(topicWordKey);
                    if (njw != null) {
                        enumerator = njw + beta;
                    }
                    else {
                        enumerator = beta; // unseen (topic, word) pair: smoothing prior only
                    }

                    // get also the counts from the training data
                    Integer njw_original = topicWordCounts.get(topicWordKey);
                    if (njw_original != null) {
                        enumerator += njw_original;
                    }

                    Integer njd = tmp_documentTopicCounts.get(Arrays.asList(documentId, j));
                    if (njd != null) {
                        enumerator *= (njd + alpha);
                    }
                    else {
                        enumerator *= alpha; // document has no words in topic j yet
                    }

                    // add the counts from testing data
                    // NOTE(review): tmp_topicCounts.get(j) unboxes without a null
                    // check — if topic j was never sampled during random init this
                    // throws NullPointerException; confirm increase() pre-seeds all
                    // topics or that k << corpus size makes this unreachable
                    double denominator = tmp_topicCounts.get((Integer) j) + beta * d - 1;
                    // and the ones from training data
                    // NOTE(review): same unchecked unboxing on the trained topicCounts
                    denominator += topicCounts.get((Integer) j);

                    denominator *= numberOfDocumentWords + alpha * k;

                    topicProbabilities.put(j, enumerator / denominator);
                }

                // accumulate log-likelihood of this word before normalization
                perplexity += Math.log(Descriptives.sum(topicProbabilities.toFlatDataCollection()));

                // normalize probabilities
                Descriptives.normalize(topicProbabilities);

                // sample from these probabilities
                Integer newTopic = (Integer) SRS.weightedProbabilitySampling(topicProbabilities, 1, true).iterator().next();
                topic = newTopic; // new topic assignment

                // add back the word in the dataset under its new topic
                tmp_topicAssignmentOfDocumentWord.put(Arrays.asList(documentId, wordPosition), topic);
                increase(tmp_topicCounts, topic);
                increase(tmp_documentTopicCounts, Arrays.asList(documentId, topic));
                increase(tmp_topicWordCounts, Arrays.asList(topic, word));

                topicAssignments.put(
                        topic, Dataset.toDouble(topicAssignments.get(topic)) + 1.0 / totalDocumentWords);
            }

            // the document's predicted label is its most frequent topic
            Object mainTopic = MapFunctions.selectMaxKeyValue(topicAssignments).getKey();

            if (!mainTopic.equals(r.getYPredicted())) {
                ++changedCounter;
            }
            r.setYPredicted(mainTopic);
            r.setYPredictedProbabilities(topicAssignments);
        }

        // perplexity = exp(-sum(log p(word)) / #words); lower is better
        perplexity = Math.exp(-perplexity / totalDatasetWords);

        if (GeneralConfiguration.DEBUG) {
            System.out.println("Reassigned Records " + changedCounter + " - Perplexity: " + perplexity);
        }

        // converged: no record changed its dominant topic in this sweep
        if (changedCounter == 0) {
            break;
        }
    }

    // Drop the temporary Collection
    bdsf.dropTable(tmpPrefix + "topicAssignmentOfDocumentWord", tmp_topicAssignmentOfDocumentWord);
    bdsf.dropTable(tmpPrefix + "documentTopicCounts", tmp_documentTopicCounts);
    bdsf.dropTable(tmpPrefix + "topicWordCounts", tmp_topicWordCounts);
    bdsf.dropTable(tmpPrefix + "topicCounts", tmp_topicCounts);

    validationMetrics.setPerplexity(perplexity);

    return validationMetrics;
}