/**
   * Builds, for every topic, a table that maps each word counted for that topic to its smoothed
   * probability {@code (njw + beta) / (nj + beta * d)}, where {@code njw} is the topic-word count,
   * {@code nj} the topic count, {@code beta} the training parameter and {@code d} the value stored
   * in the model parameters.
   *
   * <p>Every topic id in {@code [0, k)} gets a table, even when no word was counted for it. Each
   * table is passed through {@code MapFunctions.sortAssociativeArrayByValueDescending} before it
   * is returned.
   *
   * @return a two-level table keyed first by topic id and then by word
   */
  public AssociativeArray2D getWordProbabilitiesPerTopic() {
    AssociativeArray2D wordProbabilitiesByTopic = new AssociativeArray2D();

    ModelParameters learnedParameters = knowledgeBase.getModelParameters();
    TrainingParameters configuration = knowledgeBase.getTrainingParameters();

    // start every topic with an empty word table
    final int numberOfTopics = configuration.getK();
    for (int topic = 0; topic < numberOfTopics; topic++) {
      Integer topicKey = Integer.valueOf(topic);
      wordProbabilitiesByTopic.put(topicKey, new AssociativeArray());
    }

    final int d = learnedParameters.getD();
    final double beta = configuration.getBeta();

    Map<List<Object>, Integer> wordCountsPerTopic = learnedParameters.getTopicWordCounts();
    Map<Integer, Integer> totalsPerTopic = learnedParameters.getTopicCounts();

    // the key of every count entry is the pair (topic id, word)
    for (Map.Entry<List<Object>, Integer> countEntry : wordCountsPerTopic.entrySet()) {
      List<Object> topicAndWord = countEntry.getKey();
      Integer topicKey = (Integer) topicAndWord.get(0);
      Object word = topicAndWord.get(1);

      double smoothedWordCount = countEntry.getValue() + beta;
      double smoothedTopicTotal = totalsPerTopic.get(topicKey) + beta * d;

      wordProbabilitiesByTopic.get(topicKey).put(word, smoothedWordCount / smoothedTopicTotal);
    }

    // replace each topic's table with its sorted counterpart
    for (int topic = 0; topic < numberOfTopics; topic++) {
      Integer topicKey = Integer.valueOf(topic);
      AssociativeArray unsorted = wordProbabilitiesByTopic.get(topicKey);
      wordProbabilitiesByTopic.put(
          topicKey, MapFunctions.sortAssociativeArrayByValueDescending(unsorted));
    }

    return wordProbabilitiesByTopic;
  }
  /**
   * Picks the class to report from a table of per-class scores.
   *
   * @param predictionScores the scores, keyed by class
   * @return the key of the entry returned by {@code MapFunctions.selectMaxKeyValue}
   */
  protected Object getSelectedClassFromClassScores(AssociativeArray predictionScores) {
    return MapFunctions.selectMaxKeyValue(predictionScores).getKey();
  }
  /**
   * Assigns topics to the words of {@code newData} with a collapsed Gibbs sampler and reports the
   * perplexity of the last sweep.
   *
   * <p>The approach mirrors training, but the learned counts are only read. All counts produced
   * for the new records live in temporary maps that are combined with the training counts whenever
   * a topic posterior is computed. The temporary maps are dropped before the method returns, also
   * when sampling fails with an exception.
   *
   * <p>After each sweep every record carries its predicted class (the topic holding the largest
   * share of its words) and the per-topic shares, stored through {@code setYPredicted} and {@code
   * setYPredictedProbabilities}. Sampling stops after {@code maxIterations} sweeps or as soon as a
   * sweep changes no record's predicted class.
   *
   * @param newData the records to assign topics to; each record is updated in place
   * @return a validation metrics object carrying the perplexity of the last completed sweep
   *     ({@code Double.MAX_VALUE} when no sweep was run)
   */
  private ValidationMetrics predictAndValidate(Dataset newData) {
    ModelParameters modelParameters = knowledgeBase.getModelParameters();
    TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters();

    // create new validation metrics object
    ValidationMetrics validationMetrics = knowledgeBase.getEmptyValidationMetricsObject();

    String tmpPrefix = StorageConfiguration.getTmpPrefix();

    // get model parameters
    int d = modelParameters.getD();
    int k = trainingParameters.getK(); // number of topics

    Map<List<Object>, Integer> topicWordCounts = modelParameters.getTopicWordCounts();
    Map<Integer, Integer> topicCounts = modelParameters.getTopicCounts();

    BigDataStructureFactory.MapType mapType = knowledgeBase.getMemoryConfiguration().getMapType();
    int LRUsize = knowledgeBase.getMemoryConfiguration().getLRUsize();
    BigDataStructureFactory bdsf = knowledgeBase.getBdsf();

    // temporary maps for the prediction set, so that the maps learned during training stay
    // untouched
    Map<List<Object>, Integer> tmp_topicAssignmentOfDocumentWord =
        bdsf.getMap(tmpPrefix + "topicAssignmentOfDocumentWord", mapType, LRUsize);
    Map<List<Integer>, Integer> tmp_documentTopicCounts =
        bdsf.getMap(tmpPrefix + "documentTopicCounts", mapType, LRUsize);
    Map<List<Object>, Integer> tmp_topicWordCounts =
        bdsf.getMap(tmpPrefix + "topicWordCounts", mapType, LRUsize);
    Map<Integer, Integer> tmp_topicCounts =
        bdsf.getMap(tmpPrefix + "topicCounts", mapType, LRUsize);

    double perplexity = Double.MAX_VALUE;
    try {
      // initialize topic assignments of each word randomly and update the counters
      for (Record r : newData) {
        Integer documentId = r.getId();

        for (Map.Entry<Object, Object> entry : r.getX().entrySet()) {
          Object wordPosition = entry.getKey();
          Object word = entry.getValue();

          // sample a topic
          Integer topic = PHPfunctions.mt_rand(0, k - 1);

          increase(tmp_topicCounts, topic);
          tmp_topicAssignmentOfDocumentWord.put(Arrays.asList(documentId, wordPosition), topic);
          increase(tmp_documentTopicCounts, Arrays.asList(documentId, topic));
          increase(tmp_topicWordCounts, Arrays.asList(topic, word));
        }
      }

      double alpha = trainingParameters.getAlpha();
      double beta = trainingParameters.getBeta();

      int maxIterations = trainingParameters.getMaxIterations();

      for (int iteration = 0; iteration < maxIterations; ++iteration) {

        if (GeneralConfiguration.DEBUG) {
          System.out.println("Iteration " + iteration);
        }

        // collapsed gibbs sampler
        int changedCounter = 0;
        perplexity = 0.0;
        double totalDatasetWords = 0.0;
        for (Record r : newData) {
          Integer documentId = r.getId();

          // share of the document's words assigned to each topic during this sweep
          AssociativeArray topicAssignments = new AssociativeArray();
          for (int j = 0; j < k; ++j) {
            topicAssignments.put(j, 0.0);
          }

          int totalDocumentWords = r.getX().size();
          totalDatasetWords += totalDocumentWords;

          // document length without the word that is currently being resampled
          int numberOfDocumentWords = totalDocumentWords - 1;

          for (Map.Entry<Object, Object> entry : r.getX().entrySet()) {
            Object wordPosition = entry.getKey();
            Object word = entry.getValue();

            // remove the word from the dataset
            Integer topic =
                tmp_topicAssignmentOfDocumentWord.get(Arrays.asList(documentId, wordPosition));
            decrease(tmp_topicCounts, topic);
            decrease(tmp_documentTopicCounts, Arrays.asList(documentId, topic));
            decrease(tmp_topicWordCounts, Arrays.asList(topic, word));

            // compute the posteriors of the topics and sample from it
            AssociativeArray topicProbabilities = new AssociativeArray();
            for (int j = 0; j < k; ++j) {
              double enumerator = 0.0;

              // get the counts from the current testing data
              List<Object> topicWordKey = Arrays.asList(j, word);
              Integer njw = tmp_topicWordCounts.get(topicWordKey);
              if (njw != null) {
                enumerator = njw + beta;
              } else {
                enumerator = beta;
              }

              // get also the counts from the training data
              Integer njw_original = topicWordCounts.get(topicWordKey);
              if (njw_original != null) {
                enumerator += njw_original;
              }

              Integer njd = tmp_documentTopicCounts.get(Arrays.asList(documentId, j));
              if (njd != null) {
                enumerator *= (njd + alpha);
              } else {
                enumerator *= alpha;
              }

              // A topic that no word was ever assigned to has no entry in the count maps;
              // treat the missing entry as a zero count instead of unboxing null.
              Integer njTesting = tmp_topicCounts.get(j);
              Integer njTraining = topicCounts.get(j);

              // add the counts from testing data
              // NOTE(review): the current word was already removed by decrease() above and the
              // training sampler has no "- 1" term; confirm whether this subtraction is intended.
              double denominator = (njTesting != null ? njTesting : 0) + beta * d - 1;
              // and the ones from training data
              denominator += (njTraining != null ? njTraining : 0);
              denominator *= numberOfDocumentWords + alpha * k;

              topicProbabilities.put(j, enumerator / denominator);
            }

            // accumulate the log of the unnormalized posterior mass for the perplexity
            perplexity += Math.log(Descriptives.sum(topicProbabilities.toFlatDataCollection()));

            // normalize probabilities
            Descriptives.normalize(topicProbabilities);

            // sample from these probabilities
            Integer newTopic =
                (Integer)
                    SRS.weightedProbabilitySampling(topicProbabilities, 1, true).iterator().next();
            topic = newTopic; // new topic assignment

            // add back the word in the dataset
            tmp_topicAssignmentOfDocumentWord.put(Arrays.asList(documentId, wordPosition), topic);
            increase(tmp_topicCounts, topic);
            increase(tmp_documentTopicCounts, Arrays.asList(documentId, topic));
            increase(tmp_topicWordCounts, Arrays.asList(topic, word));

            topicAssignments.put(
                topic, Dataset.toDouble(topicAssignments.get(topic)) + 1.0 / totalDocumentWords);
          }

          Object mainTopic = MapFunctions.selectMaxKeyValue(topicAssignments).getKey();

          if (!mainTopic.equals(r.getYPredicted())) {
            ++changedCounter;
          }
          r.setYPredicted(mainTopic);
          r.setYPredictedProbabilities(topicAssignments);
        }

        perplexity = Math.exp(-perplexity / totalDatasetWords);

        if (GeneralConfiguration.DEBUG) {
          System.out.println(
              "Reassigned Records " + changedCounter + " - Perplexity: " + perplexity);
        }

        // stop as soon as a full sweep leaves every record's predicted class unchanged
        if (changedCounter == 0) {
          break;
        }
      }
    } finally {
      // Drop the temporary collections, also when sampling failed
      bdsf.dropTable(
          tmpPrefix + "topicAssignmentOfDocumentWord", tmp_topicAssignmentOfDocumentWord);
      bdsf.dropTable(tmpPrefix + "documentTopicCounts", tmp_documentTopicCounts);
      bdsf.dropTable(tmpPrefix + "topicWordCounts", tmp_topicWordCounts);
      bdsf.dropTable(tmpPrefix + "topicCounts", tmp_topicCounts);
    }

    validationMetrics.setPerplexity(perplexity);

    return validationMetrics;
  }
  /**
   * Learns the topic model from {@code trainingData} with a collapsed Gibbs sampler.
   *
   * <p>Every word occurrence first receives a random topic. Each sweep then removes one word at a
   * time from the counters, samples a new topic for it from the posterior given all other
   * assignments, and adds it back. The counters held by the model parameters are updated in place.
   *
   * <p>After each sweep every record carries its predicted class (the topic holding the largest
   * share of its words) and the per-topic shares. Sampling stops after {@code maxIterations}
   * sweeps or as soon as a sweep changes no record's predicted class; the number of sweeps that
   * were run is stored through {@code setTotalIterations}.
   *
   * @param trainingData the records to learn from; their predictions are updated in place
   */
  @Override
  protected void estimateModelParameters(Dataset trainingData) {
    int n = trainingData.size();
    int d = trainingData.getColumnSize();

    ModelParameters modelParameters = knowledgeBase.getModelParameters();
    TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters();

    modelParameters.setN(n);
    modelParameters.setD(d);

    // get model parameters
    int k = trainingParameters.getK(); // number of topics
    Map<List<Object>, Integer> topicAssignmentOfDocumentWord =
        modelParameters.getTopicAssignmentOfDocumentWord();
    Map<List<Integer>, Integer> documentTopicCounts = modelParameters.getDocumentTopicCounts();
    Map<List<Object>, Integer> topicWordCounts = modelParameters.getTopicWordCounts();
    Map<Integer, Integer> documentWordCounts = modelParameters.getDocumentWordCounts();
    Map<Integer, Integer> topicCounts = modelParameters.getTopicCounts();

    // initialize topic assignments of each word randomly and update the counters
    for (Record r : trainingData) {
      Integer documentId = r.getId();

      documentWordCounts.put(documentId, r.getX().size());

      for (Map.Entry<Object, Object> entry : r.getX().entrySet()) {
        Object wordPosition = entry.getKey();
        Object word = entry.getValue();

        // sample a topic
        Integer topic = PHPfunctions.mt_rand(0, k - 1);

        increase(topicCounts, topic);
        topicAssignmentOfDocumentWord.put(Arrays.asList(documentId, wordPosition), topic);
        increase(documentTopicCounts, Arrays.asList(documentId, topic));
        increase(topicWordCounts, Arrays.asList(topic, word));
      }
    }

    double alpha = trainingParameters.getAlpha();
    double beta = trainingParameters.getBeta();

    int maxIterations = trainingParameters.getMaxIterations();

    int iteration = 0;
    while (iteration < maxIterations) {

      if (GeneralConfiguration.DEBUG) {
        System.out.println("Iteration " + iteration);
      }

      int changedCounter = 0;
      // collapsed gibbs sampler
      for (Record r : trainingData) {
        Integer documentId = r.getId();

        // share of the document's words assigned to each topic during this sweep
        AssociativeArray topicAssignments = new AssociativeArray();
        for (int j = 0; j < k; ++j) {
          topicAssignments.put(j, 0.0);
        }

        int totalWords = r.getX().size();

        for (Map.Entry<Object, Object> entry : r.getX().entrySet()) {
          Object wordPosition = entry.getKey();
          Object word = entry.getValue();

          // Remove the word from the dataset. documentWordCounts is deliberately left untouched
          // here: the document-length factor is the same for every topic, so it is not needed
          // for sampling and updating it would only slow the sweep down.
          Integer topic =
              topicAssignmentOfDocumentWord.get(Arrays.asList(documentId, wordPosition));
          decrease(topicCounts, topic);
          decrease(documentTopicCounts, Arrays.asList(documentId, topic));
          decrease(topicWordCounts, Arrays.asList(topic, word));

          // compute the posteriors of the topics and sample from it
          AssociativeArray topicProbabilities = new AssociativeArray();
          for (int j = 0; j < k; ++j) {
            double enumerator = 0.0;
            Integer njw = topicWordCounts.get(Arrays.asList(j, word));
            if (njw != null) {
              enumerator = njw + beta;
            } else {
              enumerator = beta;
            }

            Integer njd = documentTopicCounts.get(Arrays.asList(documentId, j));
            if (njd != null) {
              enumerator *= (njd + alpha);
            } else {
              enumerator *= alpha;
            }

            // A topic that the random initialisation never drew has no entry in topicCounts;
            // treat the missing entry as a zero count instead of unboxing null.
            Integer nj = topicCounts.get(j);
            double denominator = (nj != null ? nj : 0) + beta * d;
            // The (numberOfDocumentWords + alpha * k) factor is omitted because it is the same
            // for all topics and disappears when the probabilities are normalized.

            topicProbabilities.put(j, enumerator / denominator);
          }

          // normalize probabilities
          Descriptives.normalize(topicProbabilities);

          // sample from these probabilities
          Integer newTopic =
              (Integer)
                  SRS.weightedProbabilitySampling(topicProbabilities, 1, true).iterator().next();
          topic = newTopic; // new topic assignment

          // add back the word in the dataset
          topicAssignmentOfDocumentWord.put(Arrays.asList(documentId, wordPosition), topic);
          increase(topicCounts, topic);
          increase(documentTopicCounts, Arrays.asList(documentId, topic));
          increase(topicWordCounts, Arrays.asList(topic, word));

          topicAssignments.put(
              topic, Dataset.toDouble(topicAssignments.get(topic)) + 1.0 / totalWords);
        }

        Object mainTopic = MapFunctions.selectMaxKeyValue(topicAssignments).getKey();

        if (!mainTopic.equals(r.getYPredicted())) {
          ++changedCounter;
        }
        r.setYPredicted(mainTopic);
        r.setYPredictedProbabilities(topicAssignments);
      }
      // count the sweep before the convergence check so the stored total includes it
      ++iteration;

      if (GeneralConfiguration.DEBUG) {
        System.out.println("Reassigned Records " + changedCounter);
      }

      // stop as soon as a full sweep leaves every record's predicted class unchanged
      if (changedCounter == 0) {
        break;
      }
    }

    modelParameters.setTotalIterations(iteration);
  }