  /**
   * Estimates the LDA model parameters with collapsed Gibbs sampling: the
   * topic assignment of every word is initialized randomly and then resampled
   * iteratively until the assignments stabilize or maxIterations is reached.
   */
  @Override
  protected void estimateModelParameters(Dataset trainingData) {
    int n = trainingData.size();
    int d = trainingData.getColumnSize();

    ModelParameters modelParameters = knowledgeBase.getModelParameters();
    TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters();

    modelParameters.setN(n);
    modelParameters.setD(d);

    // get the number of topics and the model's count tables
    int k = trainingParameters.getK(); // number of topics
    Map<List<Object>, Integer> topicAssignmentOfDocumentWord =
        modelParameters.getTopicAssignmentOfDocumentWord(); // (documentId, wordPosition) => topic
    Map<List<Integer>, Integer> documentTopicCounts =
        modelParameters.getDocumentTopicCounts(); // (documentId, topic) => counter
    Map<List<Object>, Integer> topicWordCounts =
        modelParameters.getTopicWordCounts(); // (topic, word) => counter
    Map<Integer, Integer> documentWordCounts =
        modelParameters.getDocumentWordCounts(); // documentId => number of words
    Map<Integer, Integer> topicCounts =
        modelParameters.getTopicCounts(); // topic => number of words assigned to it

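    // The count tables above are the sufficient statistics of the collapsed
    // Gibbs sampler: every sampling step below keeps them in sync with the
    // current topic assignments.
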
    // initialize topic assignments of each word randomly and update the counters
    for (Record r : trainingData) {
      Integer documentId = r.getId();

      documentWordCounts.put(documentId, r.getX().size());

      for (Map.Entry<Object, Object> entry : r.getX().entrySet()) {
        Object wordPosition = entry.getKey();
        Object word = entry.getValue();

        // assign an initial topic uniformly at random
        Integer topic = PHPfunctions.mt_rand(0, k - 1);

        increase(topicCounts, topic);
        topicAssignmentOfDocumentWord.put(Arrays.asList(documentId, wordPosition), topic);
        increase(documentTopicCounts, Arrays.asList(documentId, topic));
        increase(topicWordCounts, Arrays.asList(topic, word));
      }
    }
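
    // at this point every word carries a random topic and the counters reflect
    // those initial assignments; the Gibbs sweeps below will refine them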

    double alpha = trainingParameters.getAlpha(); // Dirichlet prior of the document-topic distributions
    double beta = trainingParameters.getBeta(); // Dirichlet prior of the topic-word distributions

    int maxIterations = trainingParameters.getMaxIterations();

    int iteration = 0;
    while (iteration < maxIterations) {

      if (GeneralConfiguration.DEBUG) {
        System.out.println("Iteration " + iteration);
      }

      int changedCounter = 0;
      // collapsed Gibbs sampler: resample the topic of every word in every document
      for (Record r : trainingData) {
        Integer documentId = r.getId();

        AssociativeArray topicAssignments = new AssociativeArray();
        for (int j = 0; j < k; ++j) {
          topicAssignments.put(j, 0.0);
        }
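
        // topicAssignments accumulates, per topic, the fraction of this
        // document's words assigned to it; it doubles as the document's
        // predicted topic probabilities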

        int totalWords = r.getX().size();

        for (Map.Entry<Object, Object> entry : r.getX().entrySet()) {
          Object wordPosition = entry.getKey();
          Object word = entry.getValue();

          // remove the word's current topic assignment from the counters
          Integer topic =
              topicAssignmentOfDocumentWord.get(Arrays.asList(documentId, wordPosition));
          // decrease(documentWordCounts, documentId); //slow
          decrease(topicCounts, topic);
          decrease(documentTopicCounts, Arrays.asList(documentId, topic));
          decrease(topicWordCounts, Arrays.asList(topic, word));

          // int numberOfDocumentWords = r.getX().size()-1; //fast - decreased by 1

          // compute the (unnormalized) posterior of each topic and sample from it
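          // Collapsed Gibbs update for the current word:
          //   P(topic = j | all other assignments) ∝ (njw + beta) * (njd + alpha) / (nj + beta * d)
          // where njw counts how often this word is assigned to topic j, njd how
          // many of this document's words carry topic j, and nj (topicCounts)
          // the total number of words assigned to topic j.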
          AssociativeArray topicProbabilities = new AssociativeArray();
          for (int j = 0; j < k; ++j) {
            // numerator: (njw + beta) * (njd + alpha); missing counts are treated as 0
            Integer njw = topicWordCounts.get(Arrays.asList(j, word));
            double numerator = (njw != null) ? (njw + beta) : beta;

            Integer njd = documentTopicCounts.get(Arrays.asList(documentId, j));
            numerator *= (njd != null) ? (njd + alpha) : alpha;

            Integer nj = topicCounts.get(j);
            double denominator = ((nj != null) ? nj : 0) + beta * d; // guard against topics never sampled
            // denominator *= numberOfDocumentWords+alpha*k; //this term is omitted because it
            // is identical for all topics and cancels out during normalization

            topicProbabilities.put(j, numerator / denominator);
          }

          // normalize probabilities
          Descriptives.normalize(topicProbabilities);

          // sample a new topic from these probabilities
          Integer newTopic =
              (Integer)
                  SRS.weightedProbabilitySampling(topicProbabilities, 1, true).iterator().next();
          topic = newTopic; // new topic assignment

          // add the word back into the counters under its new topic
          topicAssignmentOfDocumentWord.put(Arrays.asList(documentId, wordPosition), topic);
          // increase(documentWordCounts, documentId); //slow
          increase(topicCounts, topic);
          increase(documentTopicCounts, Arrays.asList(documentId, topic));
          increase(topicWordCounts, Arrays.asList(topic, word));

          topicAssignments.put(
              topic, Dataset.toDouble(topicAssignments.get(topic)) + 1.0 / totalWords);
        }

        // the topic with the highest share of the document's words becomes its prediction
        Object mainTopic = MapFunctions.selectMaxKeyValue(topicAssignments).getKey();

        if (!mainTopic.equals(r.getYPredicted())) {
          ++changedCounter;
        }
        r.setYPredicted(mainTopic);
        r.setYPredictedProbabilities(topicAssignments);
      }
      ++iteration;

      if (GeneralConfiguration.DEBUG) {
        System.out.println("Reassigned Records " + changedCounter);
      }

      // converged: no document changed its dominant topic during this sweep
      if (changedCounter == 0) {
        break;
      }
    }

    modelParameters.setTotalIterations(iteration);
  }
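
  // A minimal sketch of the increase/decrease counter helpers used above,
  // assuming they simply increment/decrement the entry of a key; the actual
  // framework implementations may differ (e.g. they may drop zeroed entries).
  private static <K> void increase(Map<K, Integer> counts, K key) {
    Integer counter = counts.get(key);
    counts.put(key, (counter == null) ? 1 : counter + 1);
  }

  private static <K> void decrease(Map<K, Integer> counts, K key) {
    // the sampler only ever decreases counters that were previously increased
    counts.put(key, counts.get(key) - 1);
  }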