/**
   * Returns only the Chisquare statistic of a contingency table.
   *
   * <p>This is a shortcut over {@code getScore} that keeps the value stored under the "score" key
   * and discards the table dimensions.
   *
   * @param dataTable the contingency table to test
   * @return the Chisquare score computed by {@code getScore(dataTable)}
   */
  public static double getScoreValue(DataTable2D dataTable) {
    return getScore(dataTable).getDouble("score");
  }
  /**
   * Computes the p-value of the null hypothesis of the Chisquare test for a contingency table.
   *
   * <p>The statistic and the table dimensions are taken from {@code getScore} and handed to
   * {@code scoreToPvalue(score, n, k)}. Here n is the number of rows and k the number of distinct
   * columns, as counted by {@code getScore}.
   *
   * @param dataTable the contingency table to test
   * @return the p-value returned by {@code scoreToPvalue}
   */
  public static double getPvalue(DataTable2D dataTable) {
    AssociativeArray stats = getScore(dataTable);
    double chisquare = stats.getDouble("score");
    int rows = stats.getDouble("n").intValue();
    int columns = stats.getDouble("k").intValue();
    return scoreToPvalue(chisquare, rows, columns);
  }
 /**
  * Checks that Ranks.getRanksFromValues overwrites the values with their ranks and returns the
  * tie counts.
  */
 @Test
 public void testGetRanksFromValues() {
   logger.info("getRanksFromValues");

   FlatDataList values = new FlatDataList(Arrays.<Object>asList(50, 10, 10, 30, 40));
   AssociativeArray actualTies = Ranks.getRanksFromValues(values);

   // 10 occurs twice, so both occurrences share the average rank (1 + 2) / 2 = 1.5
   FlatDataList expectedRanks =
       new FlatDataList(Arrays.<Object>asList(5.0, 1.5, 1.5, 3.0, 4.0));
   AssociativeArray expectedTies = new AssociativeArray(new ConcurrentSkipListMap<>());
   expectedTies.put(10, 2);

   assertEquals(expectedRanks, values);
   assertEquals(expectedTies, actualTies);
 }
  /**
   * Calculates the p-value of the null hypothesis that the sample is centred on {@code median},
   * using the signed-rank statistic.
   *
   * <p>Differences equal to zero are discarded. The absolute values of the remaining differences
   * are replaced by their ranks via {@code Ranks.getRanksFromValues}. W is the sum of the ranks
   * of the positive differences. W and n, the number of non-zero differences, are passed to
   * {@code scoreToPvalue(W, n)}.
   *
   * @param flatDataCollection the sample
   * @param median the hypothesized median
   * @return the p-value returned by {@code scoreToPvalue}
   * @throws IllegalArgumentException if the sample is empty or every value equals {@code median}
   */
  public static double getPvalue(FlatDataCollection flatDataCollection, double median)
      throws IllegalArgumentException {
    int n = 0;
    AssociativeArray absDifferences = new AssociativeArray();
    Iterator<Double> it = flatDataCollection.iteratorDouble();
    while (it.hasNext()) {
      double delta = it.next() - median;

      if (delta == 0.0) {
        continue; // zero differences carry no sign, so they are not counted at all
      }

      // the first character of the key records the sign; the running index keeps keys unique
      String key = (delta < 0 ? "-" : "+") + n;
      absDifferences.put(key, Math.abs(delta));
      ++n;
    }
    if (n <= 0) {
      throw new IllegalArgumentException(
          "At least one value different from the median is required.");
    }

    // replace each absolute difference with its rank
    Ranks.getRanksFromValues(absDifferences);

    double W = 0.0;
    for (Map.Entry<Object, Object> entry : absDifferences.entrySet()) {
      if (entry.getKey().toString().charAt(0) == '+') {
        W += TypeInference.toDouble(entry.getValue());
      }
    }

    return scoreToPvalue(W, n);
  }
Example #5
0
  /**
   * Calculates the p-value of the null hypothesis for two groups of possibly right-censored
   * observations.
   *
   * <p>Each value is read through its string form. A value ending with {@code
   * CensoredDescriptives.CENSORED_NUMBER_POSTFIX} is treated as censored; any other value is
   * converted with {@code TypeInference.toDouble}.
   *
   * <p>All values of both groups are visited once, in ascending order. For every distinct key
   * (the number, plus the postfix when censored) the method records:
   * <ul>
   *   <li>mi: the number of occurrences of the key;
   *   <li>rti: the number of observations of either group (censored or not) whose value is
   *       {@code >= ti}.
   * </ul>
   *
   * <p>Scores are then assigned per key:
   * <ul>
   *   <li>uncensored keys: eti = (eti of the previous uncensored key) + mi / rti and
   *       wi = 1 - eti;
   *   <li>censored keys: wi = -(eti of the previous uncensored key).
   * </ul>
   *
   * <p>The statistic is assembled as follows:
   * <ul>
   *   <li>S is the sum of wi over the observations of the first group;
   *   <li>VarS = n0 * n1 / ((n0 + n1) * (n0 + n1 - 1)) * sum(mi * wi^2);
   *   <li>Z = S / sqrt(VarS), which is passed to {@code scoreToPvalue}.
   * </ul>
   *
   * <p>NOTE(review): the wi look like logrank scores built from a cumulative mi / rti sum; confirm
   * the intended test name.
   *
   * @param transposeDataCollection exactly two groups of observations, keyed by group
   * @return the p-value returned by {@code scoreToPvalue(Z)}
   * @throws IllegalArgumentException if the collection does not contain exactly 2 groups
   */
  public static double getPvalue(TransposeDataCollection transposeDataCollection)
      throws IllegalArgumentException {
    if (transposeDataCollection.size() != 2) {
      throw new IllegalArgumentException();
    }

    Object[] keys = transposeDataCollection.keySet().toArray();

    // number of observations (censored and uncensored) in each group
    Map<Object, Integer> n = new HashMap<>();
    n.put(keys[0], 0);
    n.put(keys[1], 0);

    // values of both groups split by censoring status; each queue polls its smallest value first
    Queue<Double> censoredData = new PriorityQueue<>();
    Queue<Double> uncensoredData = new PriorityQueue<>();
    for (Map.Entry<Object, FlatDataCollection> entry : transposeDataCollection.entrySet()) {
      Object j = entry.getKey();
      FlatDataCollection flatDataCollection = entry.getValue();

      for (Object value : flatDataCollection) {
        String str = value.toString();
        if (str.endsWith(CensoredDescriptives.CENSORED_NUMBER_POSTFIX)) {
          // censored internalData encoded as 4.3+ or -4.3+
          censoredData.add(
              Double.valueOf(
                  str.substring(
                      0,
                      str.length()
                          - CensoredDescriptives.CENSORED_NUMBER_POSTFIX
                              .length()))); // strip the postfix and parse the number
        } else {
          // uncensored internalData
          uncensoredData.add(TypeInference.toDouble(value)); // convert it to double
        }
        n.put(j, n.get(j) + 1);
      }
    }

    // Merge the two sorted queues so every value is visited once, in ascending order. Build one
    // testTable row per distinct key holding "mi" (occurrences) and "rti" (values >= ti).
    Double currentCensored = null;
    Double currentUncensored = null;
    AssociativeArray2D testTable = new AssociativeArray2D();

    // NOTE(review): if both groups are empty, both polls return null on the first pass and
    // currentCensored.toString() below throws NullPointerException.
    do {
      if (currentCensored == null) {
        currentCensored = censoredData.poll();
      }
      if (currentUncensored == null) {
        currentUncensored = uncensoredData.poll();
      }

      // Take the smaller candidate. The consumed one is reset to null so that the next pass
      // polls a replacement from its queue.
      Double ti;
      String key;
      if (currentUncensored == null) {
        key = currentCensored.toString().concat((CensoredDescriptives.CENSORED_NUMBER_POSTFIX));
        ti = currentCensored;
        currentCensored = null;
      } else if (currentCensored == null) {
        key = currentUncensored.toString();
        ti = currentUncensored;
        currentUncensored = null;
      } else if (currentCensored
          < currentUncensored) { // strict '<': on a tie the uncensored value is taken first, so a
                                 // censored value ranks above an uncensored one of equal value
        key = currentCensored.toString().concat(CensoredDescriptives.CENSORED_NUMBER_POSTFIX);
        ti = currentCensored;
        currentCensored = null;
      } else {
        key = currentUncensored.toString();
        ti = currentUncensored;
        currentUncensored = null;
      }

      Object value = testTable.get2d(key, "mi");
      if (value == null) {
        testTable.put2d(key, "mi", 1);
        testTable.put2d(key, "rti", 0);
      } else {
        testTable.put2d(key, "mi", ((Integer) value) + 1);
        continue; // repeated key: rti was already computed when the key was first seen
      }

      // rti: count the observations of both groups (censored or not) whose value is >= ti
      for (Map.Entry<Object, FlatDataCollection> entry : transposeDataCollection.entrySet()) {
        Object j = entry.getKey();
        FlatDataCollection flatDataCollection = entry.getValue();

        for (Object value2 : flatDataCollection) {
          double v;
          String str = value2.toString();
          if (str.endsWith(CensoredDescriptives.CENSORED_NUMBER_POSTFIX)) {
            // censored internalData encoded as 4.3+ or -4.3+
            v =
                Double.valueOf(
                    str.substring(
                        0,
                        str.length()
                            - CensoredDescriptives.CENSORED_NUMBER_POSTFIX
                                .length())); // strip the postfix and parse the number
          } else {
            // uncensored internalData
            v = TypeInference.toDouble(value2); // convert it to double
          }

          if (v >= ti) {
            testTable.put2d(key, "rti", (Integer) testTable.get2d(key, "rti") + 1);
          }
        }
      }

    } while (currentCensored != null
        || currentUncensored != null
        || !censoredData.isEmpty()
        || !uncensoredData.isEmpty());

    censoredData = null;
    uncensoredData = null;

    // Score every key and accumulate sum(mi * wi^2) into VarS.
    // NOTE(review): the "previous uncensored key" logic assumes testTable iterates in insertion
    // (ascending) order; confirm that AssociativeArray2D's backing map preserves it.
    double VarS = 0.0;

    Object previousUncencoredKey = null;
    for (Map.Entry<Object, AssociativeArray> entry : testTable.entrySet()) {
      Object ti = entry.getKey();
      AssociativeArray testRow = entry.getValue();

      double previousUncencoredValue = 0;

      // previousUncencoredKey is null until the first uncensored key is scored. This presumably
      // relies on get2d returning null for it, leaving the value at 0 (TODO confirm).
      Object tmp = testTable.get2d(previousUncencoredKey, "eti");
      if (tmp != null) {
        previousUncencoredValue = TypeInference.toDouble(tmp);
      }

      if (!ti.toString().endsWith(CensoredDescriptives.CENSORED_NUMBER_POSTFIX)) { // uncensored
        double mi = testRow.getDouble("mi");
        double rti = testRow.getDouble("rti");
        double eti = previousUncencoredValue + mi / rti;

        testRow.put("eti", eti);
        testRow.put("wi", 1 - eti);
        previousUncencoredKey = ti;
      } else { // censored
        testRow.put("wi", -previousUncencoredValue);
      }

      double wi = testRow.getDouble("wi");
      VarS += testRow.getDouble("mi") * wi * wi;
    }

    // S: sum of the wi of every observation in the first group (keys[0])
    double S = 0.0;
    for (Object value : transposeDataCollection.get(keys[0])) { // observations of group keys[0]
      Object
          key; // rebuild the testTable key: parse the number to Double and re-print it (plus the
               // postfix when censored) so it matches the Double.toString() keys used above
      String str = value.toString();
      if (str.endsWith(CensoredDescriptives.CENSORED_NUMBER_POSTFIX)) {
        // censored internalData encoded as 4.3+ or -4.3+
        Double v =
            Double.valueOf(
                str.substring(
                    0,
                    str.length()
                        - CensoredDescriptives.CENSORED_NUMBER_POSTFIX
                            .length())); // strip the postfix and parse the number
        key = v.toString() + CensoredDescriptives.CENSORED_NUMBER_POSTFIX;
      } else {
        // uncensored internalData
        Double v = TypeInference.toDouble(value); // convert it to double
        key = v.toString();
      }
      double wi = TypeInference.toDouble(testTable.get2d(key, "wi"));
      S += wi;
    }
    testTable = null;

    // group sizes, counting censored and uncensored observations
    double n0 = n.get(keys[0]).doubleValue();
    double n1 = n.get(keys[1]).doubleValue();

    // multiply sum(mi * wi^2) by n0 * n1 / ((n0 + n1) * (n0 + n1 - 1)) to obtain VarS
    VarS *= n0 * n1 / ((n0 + n1) * (n0 + n1 - 1.0));

    // NOTE(review): if VarS is 0 (e.g. every wi is 0), Z is NaN or infinite
    double Z = S / Math.sqrt(VarS);

    double pvalue = scoreToPvalue(Z);

    return pvalue;
  }
  /**
   * Scores records with Data Envelopment Analysis (DEA) by solving one linear program per
   * evaluated record.
   *
   * <p>Each record of {@code id2DeaRecordMapDatabase} contributes one constraint.
   *
   * <p>When the records have inputs, the LP variables are laid out as [output weights u, input
   * weights v]:
   * <ul>
   *   <li>every database record gives {@code sum(output * u) - sum(input * v) <= 0};
   *   <li>each evaluated record temporarily adds {@code sum(input * v) = 1} and uses
   *       {@code sum(output * u)} as the objective.
   * </ul>
   *
   * <p>When no record has inputs, every database record gives {@code sum(output * u) <= 1} and the
   * evaluated record's outputs are the objective.
   *
   * <p>NOTE(review): evaluated records are not checked against the database's column layout.
   *
   * @param id2DeaRecordMapDatabase records whose inputs/outputs define the LP constraints
   * @param id2DeaRecordMapEvaluation records to score
   * @return AssociativeArray mapping each evaluated record id to its LP objective value
   * @throws IllegalArgumentException if the database records do not all have the same number of
   *     columns, or only some of them have inputs
   * @throws RuntimeException wrapping any LpSolveException raised by the solver
   */
  public AssociativeArray estimateEfficiency(
      Map<Object, DeaRecord> id2DeaRecordMapDatabase,
      Map<Object, DeaRecord> id2DeaRecordMapEvaluation) {
    AssociativeArray efficiencies = new AssociativeArray();
    List<LPSolver.LPConstraint> lpConstraints = new ArrayList<>();

    // Validate the database records and turn each one into a constraint.
    Integer columnCount = null;
    boolean usesInput = false;
    for (DeaRecord record : id2DeaRecordMapDatabase.values()) {
      double[] in = record.getInput();
      double[] out = record.getOutput();
      boolean recordUsesInput = in.length > 0;
      int recordColumns = in.length + out.length;

      if (columnCount == null) {
        // the first record fixes the layout every other record must follow
        columnCount = recordColumns;
        usesInput = recordUsesInput;
      } else if (columnCount != recordColumns) {
        throw new IllegalArgumentException(
            "The input and output columns do not match in all records.");
      } else if (usesInput != recordUsesInput) {
        throw new IllegalArgumentException("The input should be used in all records or in none.");
      }

      if (!usesInput) {
        lpConstraints.add(new LPSolver.LPConstraint(out, LpSolve.LE, 1.0));
        continue;
      }

      // outputs first, then the negated inputs
      double[] body = new double[columnCount];
      System.arraycopy(out, 0, body, 0, out.length);
      for (int c = 0; c < in.length; ++c) {
        body[out.length + c] = -in[c];
      }
      lpConstraints.add(new LPSolver.LPConstraint(body, LpSolve.LE, 0.0));
    }

    // No bounds or integrality restrictions are passed to the solver.
    double[] lowerBounds = null;
    double[] upperBounds = null;
    boolean[] integerFlags = null;

    for (Map.Entry<Object, DeaRecord> entry : id2DeaRecordMapEvaluation.entrySet()) {
      DeaRecord record = entry.getValue();

      double[] objective;
      if (usesInput) {
        double[] out = record.getOutput();
        double[] in = record.getInput();
        // new arrays are zero-filled: the objective ignores inputs, the constraint ignores outputs
        objective = new double[columnCount];
        double[] denominator = new double[columnCount];
        System.arraycopy(out, 0, objective, 0, out.length);
        System.arraycopy(in, 0, denominator, out.length, in.length);
        lpConstraints.add(new LPSolver.LPConstraint(denominator, LpSolve.EQ, 1.0));
      } else {
        objective = record.getOutput();
      }

      Integer scaling = LpSolve.SCALE_GEOMETRIC;

      Double objectiveValue;
      try {
        objectiveValue =
            LPSolver.solve(objective, lpConstraints, lowerBounds, upperBounds, integerFlags, scaling)
                .getObjectiveValue();
      } catch (LpSolveException ex) {
        throw new RuntimeException(ex);
      }

      if (usesInput) {
        // discard the normalisation constraint that belongs only to this record
        lpConstraints.remove(lpConstraints.size() - 1);
      }

      efficiencies.put(entry.getKey(), objectiveValue);
    }

    return efficiencies;
  }
  /**
   * Runs the collapsed Gibbs sampler on {@code newData} and returns validation metrics holding the
   * resulting perplexity.
   *
   * <p>The learned parameters are never modified. The topic assignments and counts of the new
   * documents live in temporary maps, obtained from the BigDataStructureFactory under the storage
   * tmp prefix. Those counts are combined with the training counts only when the topic posteriors
   * are computed. The temporary maps are always dropped before the method returns, even if
   * sampling throws.
   *
   * <p>As a side effect every record of {@code newData} receives its main topic
   * ({@code setYPredicted}) and its topic proportions ({@code setYPredictedProbabilities}).
   *
   * @param newData the documents to evaluate; each record's X maps word positions to words
   * @return validation metrics with the perplexity of the last sweep, or Double.MAX_VALUE if no
   *     sweep ran
   */
  private ValidationMetrics predictAndValidate(Dataset newData) {
    ModelParameters modelParameters = knowledgeBase.getModelParameters();
    TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters();

    // create new validation metrics object
    ValidationMetrics validationMetrics = knowledgeBase.getEmptyValidationMetricsObject();

    String tmpPrefix = StorageConfiguration.getTmpPrefix();

    // get model parameters
    int d = modelParameters.getD();
    int k = trainingParameters.getK(); // number of topics

    Map<List<Object>, Integer> topicWordCounts = modelParameters.getTopicWordCounts();
    Map<Integer, Integer> topicCounts = modelParameters.getTopicCounts();

    BigDataStructureFactory.MapType mapType = knowledgeBase.getMemoryConfiguration().getMapType();
    int LRUsize = knowledgeBase.getMemoryConfiguration().getLRUsize();
    BigDataStructureFactory bdsf = knowledgeBase.getBdsf();

    // temporary maps for the new documents, so the learned maps are left untouched
    Map<List<Object>, Integer> tmp_topicAssignmentOfDocumentWord =
        bdsf.getMap(tmpPrefix + "topicAssignmentOfDocumentWord", mapType, LRUsize);
    Map<List<Integer>, Integer> tmp_documentTopicCounts =
        bdsf.getMap(tmpPrefix + "documentTopicCounts", mapType, LRUsize);
    Map<List<Object>, Integer> tmp_topicWordCounts =
        bdsf.getMap(tmpPrefix + "topicWordCounts", mapType, LRUsize);
    Map<Integer, Integer> tmp_topicCounts =
        bdsf.getMap(tmpPrefix + "topicCounts", mapType, LRUsize);

    double perplexity = Double.MAX_VALUE;
    try {
      // initialize topic assignments of each word randomly and update the counters
      for (Record r : newData) {
        Integer documentId = r.getId();

        for (Map.Entry<Object, Object> entry : r.getX().entrySet()) {
          Object wordPosition = entry.getKey();
          Object word = entry.getValue();

          // sample a topic
          Integer topic = PHPfunctions.mt_rand(0, k - 1);

          increase(tmp_topicCounts, topic);
          tmp_topicAssignmentOfDocumentWord.put(Arrays.asList(documentId, wordPosition), topic);
          increase(tmp_documentTopicCounts, Arrays.asList(documentId, topic));
          increase(tmp_topicWordCounts, Arrays.asList(topic, word));
        }
      }

      double alpha = trainingParameters.getAlpha();
      double beta = trainingParameters.getBeta();

      int maxIterations = trainingParameters.getMaxIterations();

      for (int iteration = 0; iteration < maxIterations; ++iteration) {

        if (GeneralConfiguration.DEBUG) {
          System.out.println("Iteration " + iteration);
        }

        // collapsed gibbs sampler
        int changedCounter = 0;
        perplexity = 0.0;
        double totalDatasetWords = 0.0;
        for (Record r : newData) {
          Integer documentId = r.getId();

          AssociativeArray topicAssignments = new AssociativeArray();
          for (int j = 0; j < k; ++j) {
            topicAssignments.put(j, 0.0);
          }

          int totalDocumentWords = r.getX().size();
          totalDatasetWords += totalDocumentWords;
          // the current word is removed before sampling, so the document counts one word less
          int numberOfDocumentWords = totalDocumentWords - 1;
          for (Map.Entry<Object, Object> entry : r.getX().entrySet()) {
            Object wordPosition = entry.getKey();
            Object word = entry.getValue();

            // remove the word from the dataset
            Integer topic =
                tmp_topicAssignmentOfDocumentWord.get(Arrays.asList(documentId, wordPosition));
            decrease(tmp_topicCounts, topic);
            decrease(tmp_documentTopicCounts, Arrays.asList(documentId, topic));
            decrease(tmp_topicWordCounts, Arrays.asList(topic, word));

            // compute the unnormalized posteriors of the topics from test + training counts
            AssociativeArray topicProbabilities = new AssociativeArray();
            for (int j = 0; j < k; ++j) {
              double enumerator;

              // get the counts from the current testing data
              List<Object> topicWordKey = Arrays.asList(j, word);
              Integer njw = tmp_topicWordCounts.get(topicWordKey);
              if (njw != null) {
                enumerator = njw + beta;
              } else {
                enumerator = beta;
              }

              // get also the counts from the training data
              Integer njw_original = topicWordCounts.get(topicWordKey);
              if (njw_original != null) {
                enumerator += njw_original;
              }

              Integer njd = tmp_documentTopicCounts.get(Arrays.asList(documentId, j));
              if (njd != null) {
                enumerator *= (njd + alpha);
              } else {
                enumerator *= alpha;
              }

              // A topic that never received a word has no entry in a count map; count it as 0
              // instead of unboxing null.
              // NOTE(review): estimateModelParameters has no "- 1" term in this denominator;
              // confirm which form is intended.
              Integer nj = tmp_topicCounts.get(j);
              double denominator = (nj != null ? nj : 0) + beta * d - 1;
              Integer nj_original = topicCounts.get(j);
              if (nj_original != null) {
                denominator += nj_original;
              }
              denominator *= numberOfDocumentWords + alpha * k;

              topicProbabilities.put(j, enumerator / denominator);
            }

            perplexity += Math.log(Descriptives.sum(topicProbabilities.toFlatDataCollection()));

            // normalize probabilities
            Descriptives.normalize(topicProbabilities);

            // sample from these probabilities
            Integer newTopic =
                (Integer)
                    SRS.weightedProbabilitySampling(topicProbabilities, 1, true)
                        .iterator()
                        .next();
            topic = newTopic; // new topic assignment

            // add back the word in the dataset
            tmp_topicAssignmentOfDocumentWord.put(
                Arrays.asList(documentId, wordPosition), topic);
            increase(tmp_topicCounts, topic);
            increase(tmp_documentTopicCounts, Arrays.asList(documentId, topic));
            increase(tmp_topicWordCounts, Arrays.asList(topic, word));

            topicAssignments.put(
                topic,
                Dataset.toDouble(topicAssignments.get(topic)) + 1.0 / totalDocumentWords);
          }

          Object mainTopic = MapFunctions.selectMaxKeyValue(topicAssignments).getKey();

          if (!mainTopic.equals(r.getYPredicted())) {
            ++changedCounter;
          }
          r.setYPredicted(mainTopic);
          r.setYPredictedProbabilities(topicAssignments);
        }

        perplexity = Math.exp(-perplexity / totalDatasetWords);

        if (GeneralConfiguration.DEBUG) {
          System.out.println(
              "Reassigned Records " + changedCounter + " - Perplexity: " + perplexity);
        }

        if (changedCounter == 0) {
          break;
        }
      }
    } finally {
      // drop the temporary collections even when sampling fails
      bdsf.dropTable(
          tmpPrefix + "topicAssignmentOfDocumentWord", tmp_topicAssignmentOfDocumentWord);
      bdsf.dropTable(tmpPrefix + "documentTopicCounts", tmp_documentTopicCounts);
      bdsf.dropTable(tmpPrefix + "topicWordCounts", tmp_topicWordCounts);
      bdsf.dropTable(tmpPrefix + "topicCounts", tmp_topicCounts);
    }

    validationMetrics.setPerplexity(perplexity);

    return validationMetrics;
  }
  /**
   * Estimates the topic model parameters from {@code trainingData} with a collapsed Gibbs sampler.
   *
   * <p>Every word is first given a random topic via {@code PHPfunctions.mt_rand(0, k - 1)}. Each
   * sweep then processes every word as follows:
   * <ol>
   *   <li>remove the word's current assignment from the counts;
   *   <li>compute unnormalized topic posteriors from the topic-word, document-topic and topic
   *       counts;
   *   <li>sample a new topic;
   *   <li>add the word back under the new topic.
   * </ol>
   *
   * <p>Sampling stops after {@code maxIterations} sweeps, or as soon as a sweep changes no
   * record's main topic. The number of sweeps is stored with {@code setTotalIterations}. Each
   * record's predicted topic and topic proportions are updated as a side effect.
   *
   * @param trainingData the documents to learn from; each record's X maps word positions to words
   */
  @Override
  protected void estimateModelParameters(Dataset trainingData) {
    int n = trainingData.size();
    int d = trainingData.getColumnSize();

    ModelParameters modelParameters = knowledgeBase.getModelParameters();
    TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters();

    modelParameters.setN(n);
    modelParameters.setD(d);

    // get model parameters
    int k = trainingParameters.getK(); // number of topics
    Map<List<Object>, Integer> topicAssignmentOfDocumentWord =
        modelParameters.getTopicAssignmentOfDocumentWord();
    Map<List<Integer>, Integer> documentTopicCounts = modelParameters.getDocumentTopicCounts();
    Map<List<Object>, Integer> topicWordCounts = modelParameters.getTopicWordCounts();
    Map<Integer, Integer> documentWordCounts = modelParameters.getDocumentWordCounts();
    Map<Integer, Integer> topicCounts = modelParameters.getTopicCounts();

    // initialize topic assignments of each word randomly and update the counters
    for (Record r : trainingData) {
      Integer documentId = r.getId();

      documentWordCounts.put(documentId, r.getX().size());

      for (Map.Entry<Object, Object> entry : r.getX().entrySet()) {
        Object wordPosition = entry.getKey();
        Object word = entry.getValue();

        // sample a topic
        Integer topic = PHPfunctions.mt_rand(0, k - 1);

        increase(topicCounts, topic);
        topicAssignmentOfDocumentWord.put(Arrays.asList(documentId, wordPosition), topic);
        increase(documentTopicCounts, Arrays.asList(documentId, topic));
        increase(topicWordCounts, Arrays.asList(topic, word));
      }
    }

    double alpha = trainingParameters.getAlpha();
    double beta = trainingParameters.getBeta();

    int maxIterations = trainingParameters.getMaxIterations();

    int iteration = 0;
    while (iteration < maxIterations) {

      if (GeneralConfiguration.DEBUG) {
        System.out.println("Iteration " + iteration);
      }

      int changedCounter = 0;
      // collapsed gibbs sampler
      for (Record r : trainingData) {
        Integer documentId = r.getId();

        AssociativeArray topicAssignments = new AssociativeArray();
        for (int j = 0; j < k; ++j) {
          topicAssignments.put(j, 0.0);
        }

        int totalWords = r.getX().size();

        for (Map.Entry<Object, Object> entry : r.getX().entrySet()) {
          Object wordPosition = entry.getKey();
          Object word = entry.getValue();

          // remove the word from the dataset (documentWordCounts is not updated: see below)
          Integer topic =
              topicAssignmentOfDocumentWord.get(Arrays.asList(documentId, wordPosition));
          decrease(topicCounts, topic);
          decrease(documentTopicCounts, Arrays.asList(documentId, topic));
          decrease(topicWordCounts, Arrays.asList(topic, word));

          // compute the unnormalized posteriors of the topics and sample from it
          AssociativeArray topicProbabilities = new AssociativeArray();
          for (int j = 0; j < k; ++j) {
            double enumerator;
            Integer njw = topicWordCounts.get(Arrays.asList(j, word));
            if (njw != null) {
              enumerator = njw + beta;
            } else {
              enumerator = beta;
            }

            Integer njd = documentTopicCounts.get(Arrays.asList(documentId, j));
            if (njd != null) {
              enumerator *= (njd + alpha);
            } else {
              enumerator *= alpha;
            }

            // A topic never drawn during initialisation has no entry in topicCounts; count it
            // as 0 instead of unboxing null.
            Integer nj = topicCounts.get(j);
            double denominator = (nj != null ? nj : 0) + beta * d;
            // The document-length factor (words in document - 1 + alpha * k) is the same for
            // every topic, so it cancels out in the normalization below and is omitted.

            topicProbabilities.put(j, enumerator / denominator);
          }

          // normalize probabilities
          Descriptives.normalize(topicProbabilities);

          // sample from these probabilities
          Integer newTopic =
              (Integer)
                  SRS.weightedProbabilitySampling(topicProbabilities, 1, true).iterator().next();
          topic = newTopic; // new topic assignment

          // add back the word in the dataset
          topicAssignmentOfDocumentWord.put(Arrays.asList(documentId, wordPosition), topic);
          increase(topicCounts, topic);
          increase(documentTopicCounts, Arrays.asList(documentId, topic));
          increase(topicWordCounts, Arrays.asList(topic, word));

          topicAssignments.put(
              topic, Dataset.toDouble(topicAssignments.get(topic)) + 1.0 / totalWords);
        }

        Object mainTopic = MapFunctions.selectMaxKeyValue(topicAssignments).getKey();

        if (!mainTopic.equals(r.getYPredicted())) {
          ++changedCounter;
        }
        r.setYPredicted(mainTopic);
        r.setYPredictedProbabilities(topicAssignments);
      }
      ++iteration;

      if (GeneralConfiguration.DEBUG) {
        System.out.println("Reassigned Records " + changedCounter);
      }

      if (changedCounter == 0) {
        break;
      }
    }

    modelParameters.setTotalIterations(iteration);
  }
  /**
   * Calculates the score of the Chisquare test for a contingency table.
   *
   * <p>The expected count of cell (i, j) is {@code rowTotal(i) * columnTotal(j) / grandTotal}.
   * Yates' continuity correction ({@code |v - eij| - 0.5}) is applied when the table has exactly 2
   * rows and 2 columns. Cells whose expected count is 0 (their row or column sums to 0) are
   * skipped in every case. Previously only the 2x2 path skipped them; larger tables added 0/0 and
   * the score became NaN.
   *
   * @param dataTable contingency table of counts
   * @return AssociativeArray with "k" (number of distinct columns), "n" (number of rows) and
   *     "score" (the Chisquare statistic)
   * @throws IllegalArgumentException if {@code dataTable.isValid()} is false
   */
  public static AssociativeArray getScore(DataTable2D dataTable) throws IllegalArgumentException {
    if (!dataTable.isValid()) {
      throw new IllegalArgumentException("The provided DataTable2D is not valid.");
    }

    // Marginal sums: XdotJ per column, XIdot per row, Xdotdot over the whole table.
    Map<Object, Double> XdotJ = new HashMap<>();
    Map<Object, Double> XIdot = new HashMap<>();
    double Xdotdot = 0.0;

    for (Map.Entry<Object, AssociativeArray> entry1 : dataTable.entrySet()) {
      Object i = entry1.getKey();
      AssociativeArray row = entry1.getValue();

      for (Map.Entry<Object, Object> entry2 : row.entrySet()) {
        Object j = entry2.getKey();
        double v = Dataset.toDouble(entry2.getValue());

        Double columnSum = XdotJ.get(j);
        XdotJ.put(j, columnSum == null ? v : columnSum + v);

        Double rowSum = XIdot.get(i);
        XIdot.put(i, rowSum == null ? v : rowSum + v);

        Xdotdot += v;
      }
    }

    int k = XdotJ.size(); // number of distinct columns
    int n = XIdot.size(); // number of rows

    // Yates' continuity correction is only applied to 2x2 tables
    boolean yatesCorrection = (k == 2 && n == 2);

    double ChisquareScore = 0.0;
    for (Map.Entry<Object, AssociativeArray> entry1 : dataTable.entrySet()) {
      Object i = entry1.getKey();
      AssociativeArray row = entry1.getValue();

      for (Map.Entry<Object, Object> entry2 : row.entrySet()) {
        Object j = entry2.getKey();
        double v = Dataset.toDouble(entry2.getValue());

        // expected value under null hypothesis
        double eij = XIdot.get(i) * XdotJ.get(j) / Xdotdot;
        if (eij == 0) {
          continue; // a zero row or column has no expected count; avoid adding 0/0
        }

        double deviation = yatesCorrection ? Math.abs(v - eij) - 0.5 : v - eij;
        ChisquareScore += Math.pow(deviation, 2) / eij;
      }
    }

    AssociativeArray result = new AssociativeArray();
    result.put("k", k);
    result.put("n", n);
    result.put("score", ChisquareScore);

    return result;
  }