@Override protected void _transform(Dataset data, boolean trainingMode) { // handle non-numeric types, extract columns to dummy variables Map<Object, Dataset.ColumnType> newColumns = new HashMap<>(); int n = data.size(); Iterator<Map.Entry<Object, Dataset.ColumnType>> it = data.getColumns().entrySet().iterator(); while (it.hasNext()) { Map.Entry<Object, Dataset.ColumnType> entry = it.next(); Object column = entry.getKey(); Dataset.ColumnType columnType = entry.getValue(); if (columnType == Dataset.ColumnType.CATEGORICAL || columnType == Dataset.ColumnType .ORDINAL) { // ordinal and categorical are converted into dummyvars // Remove the old column from the column map it.remove(); // create dummy variables for all the levels for (Record r : data) { if (!r.getX().containsKey(column)) { continue; // does not contain column } Object value = r.getX().get(column); // remove the column from data r.getX().remove(column); List<Object> newColumn = Arrays.<Object>asList(column, value); // add a new boolean feature with combination of column and value r.getX().put(newColumn, true); // add the new column in the list for insertion newColumns.put(newColumn, Dataset.ColumnType.DUMMYVAR); } } } // add the new columns in the dataset column map if (!newColumns.isEmpty()) { data.getColumns().putAll(newColumns); } }
@Override protected VM validateModel(Dataset validationData) { predictDataset(validationData); Set<Object> classesSet = knowledgeBase.getModelParameters().getClasses(); // create new validation metrics object VM validationMetrics = knowledgeBase.getEmptyValidationMetricsObject(); // short notation Map<List<Object>, Double> ctMap = validationMetrics.getContingencyTable(); for (Object theClass : classesSet) { ctMap.put(Arrays.<Object>asList(theClass, SensitivityRates.TP), 0.0); // true possitive ctMap.put(Arrays.<Object>asList(theClass, SensitivityRates.FP), 0.0); // false possitive ctMap.put(Arrays.<Object>asList(theClass, SensitivityRates.TN), 0.0); // true negative ctMap.put(Arrays.<Object>asList(theClass, SensitivityRates.FN), 0.0); // false negative } int n = validationData.size(); int c = classesSet.size(); int correctCount = 0; for (Record r : validationData) { if (r.getYPredicted().equals(r.getY())) { ++correctCount; for (Object cl : classesSet) { if (cl.equals(r.getYPredicted())) { List<Object> tpk = Arrays.<Object>asList(cl, SensitivityRates.TP); ctMap.put(tpk, ctMap.get(tpk) + 1.0); } else { List<Object> tpk = Arrays.<Object>asList(cl, SensitivityRates.TN); ctMap.put(tpk, ctMap.get(tpk) + 1.0); } } } else { for (Object cl : classesSet) { if (cl.equals(r.getYPredicted())) { List<Object> tpk = Arrays.<Object>asList(cl, SensitivityRates.FP); ctMap.put(tpk, ctMap.get(tpk) + 1.0); } else if (cl.equals(r.getY())) { List<Object> tpk = Arrays.<Object>asList(cl, SensitivityRates.FN); ctMap.put(tpk, ctMap.get(tpk) + 1.0); } else { List<Object> tpk = Arrays.<Object>asList(cl, SensitivityRates.TN); ctMap.put(tpk, ctMap.get(tpk) + 1.0); } } } } validationMetrics.setAccuracy(correctCount / (double) n); // Average Precision, Recall and F1: // http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.104.8244&rep=rep1&type=pdf for (Object theClass : classesSet) { double tp = ctMap.get(Arrays.<Object>asList(theClass, SensitivityRates.TP)); double fp = ctMap.get(Arrays.<Object>asList(theClass, SensitivityRates.FP)); double fn = ctMap.get(Arrays.<Object>asList(theClass, SensitivityRates.FN)); double classPrecision = 0.0; double classRecall = 0.0; double classF1 = 0.0; if (tp > 0.0) { classPrecision = tp / (tp + fp); classRecall = tp / (tp + fn); classF1 = 2.0 * classPrecision * classRecall / (classPrecision + classRecall); } else if (tp == 0.0 && fp == 0.0 && fn == 0.0) { // if this category did not appear in the dataset then set the metrics to 1 classPrecision = 1.0; classRecall = 1.0; classF1 = 1.0; } validationMetrics.getMicroPrecision().put(theClass, classPrecision); validationMetrics.getMicroRecall().put(theClass, classRecall); validationMetrics.getMicroF1().put(theClass, classF1); validationMetrics.setMacroPrecision( validationMetrics.getMacroPrecision() + classPrecision / c); validationMetrics.setMacroRecall(validationMetrics.getMacroRecall() + classRecall / c); validationMetrics.setMacroF1(validationMetrics.getMacroF1() + classF1 / c); } return validationMetrics; }
@Override protected void estimateModelParameters(Dataset trainingData) { int n = trainingData.size(); int d = trainingData.getColumnSize(); ModelParameters modelParameters = knowledgeBase.getModelParameters(); TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters(); modelParameters.setN(n); modelParameters.setD(d); // get model parameters int k = trainingParameters.getK(); // number of topics Map<List<Object>, Integer> topicAssignmentOfDocumentWord = modelParameters.getTopicAssignmentOfDocumentWord(); Map<List<Integer>, Integer> documentTopicCounts = modelParameters.getDocumentTopicCounts(); Map<List<Object>, Integer> topicWordCounts = modelParameters.getTopicWordCounts(); Map<Integer, Integer> documentWordCounts = modelParameters.getDocumentWordCounts(); Map<Integer, Integer> topicCounts = modelParameters.getTopicCounts(); // initialize topic assignments of each word randomly and update the counters for (Record r : trainingData) { Integer documentId = r.getId(); documentWordCounts.put(documentId, r.getX().size()); for (Map.Entry<Object, Object> entry : r.getX().entrySet()) { Object wordPosition = entry.getKey(); Object word = entry.getValue(); // sample a topic Integer topic = PHPfunctions.mt_rand(0, k - 1); increase(topicCounts, topic); topicAssignmentOfDocumentWord.put(Arrays.asList(documentId, wordPosition), topic); increase(documentTopicCounts, Arrays.asList(documentId, topic)); increase(topicWordCounts, Arrays.asList(topic, word)); } } double alpha = trainingParameters.getAlpha(); double beta = trainingParameters.getBeta(); int maxIterations = trainingParameters.getMaxIterations(); int iteration = 0; while (iteration < maxIterations) { if (GeneralConfiguration.DEBUG) { System.out.println("Iteration " + iteration); } int changedCounter = 0; // collapsed gibbs sampler for (Record r : trainingData) { Integer documentId = r.getId(); AssociativeArray topicAssignments = new AssociativeArray(); for (int j = 0; j < k; ++j) { topicAssignments.put(j, 0.0); } int totalWords = r.getX().size(); for (Map.Entry<Object, Object> entry : r.getX().entrySet()) { Object wordPosition = entry.getKey(); Object word = entry.getValue(); // remove the word from the dataset Integer topic = topicAssignmentOfDocumentWord.get(Arrays.asList(documentId, wordPosition)); // decrease(documentWordCounts, documentId); //slow decrease(topicCounts, topic); decrease(documentTopicCounts, Arrays.asList(documentId, topic)); decrease(topicWordCounts, Arrays.asList(topic, word)); // int numberOfDocumentWords = r.getX().size()-1; //fast - decreased by 1 // compute the posteriors of the topics and sample from it AssociativeArray topicProbabilities = new AssociativeArray(); for (int j = 0; j < k; ++j) { double enumerator = 0.0; Integer njw = topicWordCounts.get(Arrays.asList(j, word)); if (njw != null) { enumerator = njw + beta; } else { enumerator = beta; } Integer njd = documentTopicCounts.get(Arrays.asList(documentId, j)); if (njd != null) { enumerator *= (njd + alpha); } else { enumerator *= alpha; } double denominator = topicCounts.get((Integer) j) + beta * d; // denominator *= numberOfDocumentWords+alpha*k; //this is not necessary because it is // the same for all categories, so it can be omited topicProbabilities.put(j, enumerator / denominator); } // normalize probabilities Descriptives.normalize(topicProbabilities); // sample from these probabilieis Integer newTopic = (Integer) SRS.weightedProbabilitySampling(topicProbabilities, 1, true).iterator().next(); topic = newTopic; // new topic assigment // add back the word in the dataset topicAssignmentOfDocumentWord.put(Arrays.asList(documentId, wordPosition), topic); // increase(documentWordCounts, documentId); //slow increase(topicCounts, topic); increase(documentTopicCounts, Arrays.asList(documentId, topic)); increase(topicWordCounts, Arrays.asList(topic, word)); topicAssignments.put( topic, Dataset.toDouble(topicAssignments.get(topic)) + 1.0 / totalWords); } Object mainTopic = MapFunctions.selectMaxKeyValue(topicAssignments).getKey(); if (!mainTopic.equals(r.getYPredicted())) { ++changedCounter; } r.setYPredicted(mainTopic); r.setYPredictedProbabilities(topicAssignments); } ++iteration; if (GeneralConfiguration.DEBUG) { System.out.println("Reassigned Records " + changedCounter); } if (changedCounter == 0) { break; } } modelParameters.setTotalIterations(iteration); }