/**
 * Classifies the Titanic test set with a previously serialized model and writes the
 * predictions to a CSV file.
 *
 * <p>The test ARFF file is loaded twice. The first copy has its string attributes removed and
 * is handed to the classifier. The second copy is left as loaded and receives the predicted
 * class value of the corresponding row; that second copy is what gets written out.
 *
 * @param args command-line arguments (not used)
 * @throws Exception if loading the data or the model, classifying, or saving fails
 */
public static void main(String[] args) throws Exception {
  // Copy #1: the data the classifier actually sees.
  ArffLoader classifyLoader = new ArffLoader();
  classifyLoader.setSource(new File("data/titanic/test.arff"));
  classifyLoader.setRetrieval(Loader.BATCH);
  Instances classifyData = classifyLoader.getDataSet();

  // The first column is the class attribute (the original author's note calls it 'survived').
  classifyData.setClass(classifyData.attribute(0));
  classifyData.deleteStringAttributes();

  // Model previously serialized to disk.
  Classifier model = (Classifier) SerializationHelper.read("data/titanic/titanic.model");

  // Copy #2: same file, string attributes kept; this is the set that is populated and saved.
  ArffLoader outputLoader = new ArffLoader();
  outputLoader.setSource(new File("data/titanic/test.arff"));
  Instances outputData = outputLoader.getDataSet();
  outputData.setClass(outputData.attribute(0));

  // Walk both copies in lock-step: classify a row of copy #1, store the result in copy #2.
  Enumeration inputs = classifyData.enumerateInstances();
  Enumeration outputs = outputData.enumerateInstances();
  while (inputs.hasMoreElements()) {
    Instance input = (Instance) inputs.nextElement();
    Instance output = (Instance) outputs.nextElement();
    output.setClassValue(model.classifyInstance(input));
  }

  // Persist the populated copy as CSV (per the original note, in a Kaggle-submittable layout).
  CSVSaver csvSaver = new CSVSaver();
  csvSaver.setFile(new File("data/titanic/predict.csv"));
  csvSaver.setInstances(outputData);
  csvSaver.writeBatch();
  System.out.println("Prediciton saved to predict.csv");
}
private double calcNodeScorePlain(int nNode) { Instances instances = m_BayesNet.m_Instances; ParentSet oParentSet = m_BayesNet.getParentSet(nNode); // determine cardinality of parent set & reserve space for frequency counts int nCardinality = oParentSet.getCardinalityOfParents(); int numValues = instances.attribute(nNode).numValues(); int[] nCounts = new int[nCardinality * numValues]; // initialize (don't need this?) for (int iParent = 0; iParent < nCardinality * numValues; iParent++) { nCounts[iParent] = 0; } // estimate distributions Enumeration enumInsts = instances.enumerateInstances(); while (enumInsts.hasMoreElements()) { Instance instance = (Instance) enumInsts.nextElement(); // updateClassifier; double iCPT = 0; for (int iParent = 0; iParent < oParentSet.getNrOfParents(); iParent++) { int nParent = oParentSet.getParent(iParent); iCPT = iCPT * instances.attribute(nParent).numValues() + instance.value(nParent); } nCounts[numValues * ((int) iCPT) + (int) instance.value(nNode)]++; } return calcScoreOfCounts(nCounts, nCardinality, numValues, instances); } // CalcNodeScore
public double ExpectedClassificationError(Instances pool, int attr_i) { // initialize alpha's to one int alpha[][][]; int NumberOfFeatures = pool.numAttributes() - 1; int NumberOfLabels = pool.numClasses(); alpha = new int[NumberOfFeatures][NumberOfLabels][]; for (int i = 0; i < NumberOfFeatures; i++) for (int j = 0; j < NumberOfLabels; j++) alpha[i][j] = new int[pool.attribute(i).numValues()]; for (int i = 0; i < NumberOfFeatures; i++) for (int j = 0; j < NumberOfLabels; j++) for (int k = 0; k < alpha[i][j].length; k++) alpha[i][j][k] = 1; // construct alpha's for (int i = 0; i < NumberOfFeatures; i++) // for each attribute { if (i == pool.classIndex()) // skip the class attribute i++; for (Enumeration<Instance> e = pool.enumerateInstances(); e.hasMoreElements(); ) // for each instance { Instance inst = e.nextElement(); if (!inst.isMissing(i)) // if attribute i is not missing (i.e. its been bought) { int j = (int) inst.classValue(); int k = (int) inst.value(i); alpha[i][j][k]++; } } } return ExpectedClassificationError(alpha, attr_i); }
/** * Calculates the class membership probabilities for the given test instance. * * @param instance the instance to be classified * @return predicted class probability distribution * @throws Exception if an error occurred during the prediction */ public double[] distributionForInstance(Instance instance) throws Exception { String debug = "(KStar.distributionForInstance) "; double transProb = 0.0, temp = 0.0; double[] classProbability = new double[m_NumClasses]; double[] predictedValue = new double[1]; // initialization ... for (int i = 0; i < classProbability.length; i++) { classProbability[i] = 0.0; } predictedValue[0] = 0.0; if (m_InitFlag == ON) { // need to compute them only once and will be used for all instances. // We are doing this because the evaluation module controls the calls. if (m_BlendMethod == B_ENTROPY) { generateRandomClassColomns(); } m_Cache = new KStarCache[m_NumAttributes]; for (int i = 0; i < m_NumAttributes; i++) { m_Cache[i] = new KStarCache(); } m_InitFlag = OFF; // System.out.println("Computing..."); } // init done. Instance trainInstance; Enumeration enu = m_Train.enumerateInstances(); while (enu.hasMoreElements()) { trainInstance = (Instance) enu.nextElement(); transProb = instanceTransformationProbability(instance, trainInstance); switch (m_ClassType) { case Attribute.NOMINAL: classProbability[(int) trainInstance.classValue()] += transProb; break; case Attribute.NUMERIC: predictedValue[0] += transProb * trainInstance.classValue(); temp += transProb; break; } } if (m_ClassType == Attribute.NOMINAL) { double sum = Utils.sum(classProbability); if (sum <= 0.0) for (int i = 0; i < classProbability.length; i++) classProbability[i] = (double) 1 / (double) m_NumClasses; else Utils.normalize(classProbability, sum); return classProbability; } else { predictedValue[0] = (temp != 0) ? predictedValue[0] / temp : 0.0; return predictedValue; } }
/**
 * Gets the subset of instances that apply to a particular branch of the split.
 *
 * <ul>
 *   <li>branch {@code -1}: instances whose split attribute is missing;
 *   <li>branch {@code 0}: instances whose split attribute is present and below the split point;
 *   <li>any other branch: instances whose split attribute is present and at or above the split
 *       point.
 * </ul>
 *
 * @param branch the index of the branch
 * @param instances the instances from which to find the subset
 * @return the set of instances that apply
 */
public ReferenceInstances instancesDownBranch(int branch, Instances instances) {
  ReferenceInstances subset = new ReferenceInstances(instances, 1);
  for (Enumeration e = instances.enumerateInstances(); e.hasMoreElements(); ) {
    Instance inst = (Instance) e.nextElement();
    boolean missing = inst.isMissing(attIndex);

    boolean selected;
    if (branch == -1) {
      selected = missing;
    } else if (missing) {
      selected = false;
    } else if (branch == 0) {
      selected = inst.value(attIndex) < splitPoint;
    } else {
      selected = inst.value(attIndex) >= splitPoint;
    }

    if (selected) {
      subset.addReference(inst);
    }
  }
  return subset;
}
/**
 * Partitions a dataset by the value each instance has for a nominal attribute.
 *
 * @param data the data which is to be split
 * @param att the attribute to be used for splitting
 * @return one dataset per attribute value, in attribute-value order
 */
private Instances[] splitData(Instances data, Attribute att) {
  int numBranches = att.numValues();
  Instances[] subsets = new Instances[numBranches];
  for (int b = 0; b < numBranches; b++) {
    // Same header as the source data; capacity for the worst case.
    subsets[b] = new Instances(data, data.numInstances());
  }

  for (int row = 0; row < data.numInstances(); row++) {
    Instance inst = data.instance(row);
    subsets[(int) inst.value(att)].add(inst);
  }

  // Release the over-allocated capacity.
  for (Instances subset : subsets) {
    subset.compactify();
  }
  return subsets;
}
/**
 * Computes the entropy of a dataset's class distribution.
 *
 * <p>With class counts {@code c_k} and {@code N} instances the value returned is
 * {@code log2(N) - (1/N) * sum_k c_k * log2(c_k)}. Instance weights are not used; every
 * instance counts as one.
 *
 * <p>An empty dataset yields {@code 0}. Without the guard the formula evaluates to
 * {@code 0/0 + log2(0)}, i.e. {@code NaN}, which would silently poison any sum it is added to.
 *
 * @param data the data for which entropy is to be computed
 * @return the entropy of the data's class distribution, or 0 for an empty dataset
 * @throws Exception if computation fails
 */
private double computeEntropy(Instances data) throws Exception {
  int numInstances = data.numInstances();
  if (numInstances == 0) {
    return 0;
  }

  double[] classCounts = new double[data.numClasses()];
  Enumeration instEnum = data.enumerateInstances();
  while (instEnum.hasMoreElements()) {
    Instance inst = (Instance) instEnum.nextElement();
    classCounts[(int) inst.classValue()]++;
  }

  double entropy = 0;
  for (int j = 0; j < classCounts.length; j++) {
    // Zero counts contribute nothing (and log2(0) is undefined).
    if (classCounts[j] > 0) {
      entropy -= classCounts[j] * Utils.log2(classCounts[j]);
    }
  }
  entropy /= (double) numInstances;
  return entropy + Utils.log2(numInstances);
}
/**
 * Moves the split point down to the greatest value actually present in the given data that is
 * smaller than or equal to the current split point. (The original note remarks that C4.5 does
 * this "for some strange reason".)
 *
 * <p>Nothing happens for nominal split attributes or when there is at most one subset. If no
 * instance has a qualifying value the split point becomes {@code -Double.MAX_VALUE}.
 *
 * @param allInstances the data to search for the new split point
 */
public final void setSplitPoint(Instances allInstances) {
  if (allInstances.attribute(m_attIndex).isNominal() || m_numSubsets <= 1) {
    return;
  }

  double best = -Double.MAX_VALUE;
  for (int row = 0; row < allInstances.numInstances(); row++) {
    Instance inst = allInstances.instance(row);
    if (inst.isMissing(m_attIndex)) {
      continue;
    }
    double candidate = inst.value(m_attIndex);
    if (Utils.gr(candidate, best) && Utils.smOrEq(candidate, m_splitPoint)) {
      best = candidate;
    }
  }
  m_splitPoint = best;
}
/** * Creates split on enumerated attribute. * * @exception Exception if something goes wrong */ private void handleEnumeratedAttribute(Instances trainInstances) throws Exception { Distribution newDistribution, secondDistribution; int numAttValues; double currIG, currGR; Instance instance; int i; numAttValues = trainInstances.attribute(m_attIndex).numValues(); newDistribution = new Distribution(numAttValues, trainInstances.numClasses()); // Only Instances with known values are relevant. Enumeration enu = trainInstances.enumerateInstances(); while (enu.hasMoreElements()) { instance = (Instance) enu.nextElement(); if (!instance.isMissing(m_attIndex)) newDistribution.add((int) instance.value(m_attIndex), instance); } m_distribution = newDistribution; // For all values for (i = 0; i < numAttValues; i++) { if (Utils.grOrEq(newDistribution.perBag(i), m_minNoObj)) { secondDistribution = new Distribution(newDistribution, i); // Check if minimum number of Instances in the two // subsets. if (secondDistribution.check(m_minNoObj)) { m_numSubsets = 2; currIG = m_infoGainCrit.splitCritValue(secondDistribution, m_sumOfWeights); currGR = m_gainRatioCrit.splitCritValue(secondDistribution, m_sumOfWeights, currIG); if ((i == 0) || Utils.gr(currGR, m_gainRatio)) { m_gainRatio = currGR; m_infoGain = currIG; m_splitPoint = (double) i; m_distribution = secondDistribution; } } } } }
/** * Method for building an Id3 tree. * * @param data the training data * @exception Exception if decision tree can't be built successfully */ private void makeTree(Instances data) throws Exception { // Check if no instances have reached this node. if (data.numInstances() == 0) { m_Attribute = null; m_ClassValue = Utils.missingValue(); m_Distribution = new double[data.numClasses()]; return; } // Compute attribute with maximum information gain. double[] infoGains = new double[data.numAttributes()]; Enumeration attEnum = data.enumerateAttributes(); while (attEnum.hasMoreElements()) { Attribute att = (Attribute) attEnum.nextElement(); infoGains[att.index()] = computeInfoGain(data, att); } m_Attribute = data.attribute(Utils.maxIndex(infoGains)); // Make leaf if information gain is zero. // Otherwise create successors. if (Utils.eq(infoGains[m_Attribute.index()], 0)) { m_Attribute = null; m_Distribution = new double[data.numClasses()]; Enumeration instEnum = data.enumerateInstances(); while (instEnum.hasMoreElements()) { Instance inst = (Instance) instEnum.nextElement(); m_Distribution[(int) inst.classValue()]++; } Utils.normalize(m_Distribution); m_ClassValue = Utils.maxIndex(m_Distribution); m_ClassAttribute = data.classAttribute(); } else { Instances[] splitData = splitData(data, m_Attribute); m_Successors = new Id3[m_Attribute.numValues()]; for (int j = 0; j < m_Attribute.numValues(); j++) { m_Successors[j] = new Id3(); m_Successors[j].makeTree(splitData[j]); } } }
/**
 * Sets the value of every attribute named in the {@code ATTRIBS} parameter to missing, for
 * all instances of the data set.
 *
 * <p>The parameter value is split on whitespace. A name that does not match any attribute is
 * reported as a warning and skipped.
 *
 * @param data the data set to modify in place
 * @throws TaskException declared by the overridden method
 */
@Override
protected void manipulateAttributes(Instances data) throws TaskException {
  for (String name : this.getParameterVal(ATTRIBS).split("\\s+")) {
    if (data.attribute(name) == null) {
      Logger.getInstance()
          .message(
              "Attribute " + name + " not found in data set " + data.relationName(),
              Logger.V_WARNING);
      continue;
    }

    int attrIndex = data.attribute(name).index();
    for (int row = 0; row < data.numInstances(); row++) {
      data.instance(row).setMissing(attrIndex);
    }
  }
}
/** * Creates split on enumerated attribute. * * @exception Exception if something goes wrong */ private void handleEnumeratedAttribute(Instances trainInstances) throws Exception { Instance instance; m_distribution = new Distribution(m_complexityIndex, trainInstances.numClasses()); // Only Instances with known values are relevant. Enumeration<Instance> enu = trainInstances.enumerateInstances(); while (enu.hasMoreElements()) { instance = enu.nextElement(); if (!instance.isMissing(m_attIndex)) { m_distribution.add((int) instance.value(m_attIndex), instance); } } // Check if minimum number of Instances in at least two // subsets. if (m_distribution.check(m_minNoObj)) { m_numSubsets = m_complexityIndex; m_infoGain = infoGainCrit.splitCritValue(m_distribution, m_sumOfWeights); m_gainRatio = gainRatioCrit.splitCritValue(m_distribution, m_sumOfWeights, m_infoGain); } }
/**
 * Classifies {@code src/test.arff} with the serialized model {@code NaiveBayes.model}, prints
 * a small accuracy report and writes the predictions to {@code predict.csv}.
 *
 * <p>The output file has two numeric columns: {@code Id} (1-based row number) and {@code
 * Category} (predicted class index plus one).
 *
 * <p>The accuracy counters are updated for every test instance whose class value is known.
 * Previously they were declared but never incremented, so the report always showed zero
 * correct, zero incorrect and an accuracy of NaN. When the test file carries no class labels
 * the counters stay at zero and the accuracy is still reported as NaN.
 *
 * @param args command-line arguments (not used)
 * @throws Exception if loading the model or data, classifying, or saving fails
 */
public static void main(String[] args) throws Exception {
  Classifier classifier = (Classifier) SerializationHelper.read("NaiveBayes.model");

  ArffLoader testLoader = new ArffLoader();
  testLoader.setSource(new File("src/test.arff"));
  testLoader.setRetrieval(Loader.BATCH);
  Instances testDataSet = testLoader.getDataSet();
  Attribute testAttribute = testDataSet.attribute("class");
  testDataSet.setClass(testAttribute);

  // Output data set: (Id, Category), both numeric.
  FastVector attInfo = new FastVector();
  attInfo.addElement(new Attribute("Id"));
  attInfo.addElement(new Attribute("Category"));
  Instances outputInstances = new Instances("predict", attInfo, testDataSet.numInstances());

  int correct = 0;
  int incorrect = 0;
  int index = 1;
  Enumeration testInstances = testDataSet.enumerateInstances();
  while (testInstances.hasMoreElements()) {
    Instance instance = (Instance) testInstances.nextElement();
    double classification = classifier.classifyInstance(instance);

    // Only labelled instances can be scored.
    if (!instance.classIsMissing()) {
      if ((int) classification == (int) instance.classValue()) {
        correct++;
      } else {
        incorrect++;
      }
    }

    Instance predictInstance = new Instance(outputInstances.numAttributes());
    predictInstance.setValue(0, index++);
    predictInstance.setValue(1, (int) classification + 1);
    outputInstances.add(predictInstance);
  }

  int evaluated = correct + incorrect;
  double accuracy = (evaluated == 0) ? Double.NaN : (double) correct / (double) evaluated;
  System.out.println("Correct Instance: " + correct);
  System.out.println("IncCorrect Instance: " + incorrect);
  System.out.println("Accuracy: " + accuracy);

  CSVSaver predictedCsvSaver = new CSVSaver();
  predictedCsvSaver.setFile(new File("predict.csv"));
  predictedCsvSaver.setInstances(outputInstances);
  predictedCsvSaver.writeBatch();
  System.out.println("Prediciton saved to predict.csv");
}
/** * Creates split on numeric attribute. * * @exception Exception if something goes wrong */ private void handleNumericAttribute(Instances trainInstances) throws Exception { int firstMiss; int next = 1; int last = 0; int index = 0; int splitIndex = -1; double currentInfoGain; double defaultEnt; double minSplit; Instance instance; int i; // Current attribute is a numeric attribute. m_distribution = new Distribution(2, trainInstances.numClasses()); // Only Instances with known values are relevant. Enumeration enu = trainInstances.enumerateInstances(); i = 0; while (enu.hasMoreElements()) { instance = (Instance) enu.nextElement(); if (instance.isMissing(m_attIndex)) break; m_distribution.add(1, instance); i++; } firstMiss = i; // Compute minimum number of Instances required in each // subset. minSplit = 0.1 * (m_distribution.total()) / ((double) trainInstances.numClasses()); if (Utils.smOrEq(minSplit, m_minNoObj)) minSplit = m_minNoObj; else if (Utils.gr(minSplit, 25)) minSplit = 25; // Enough Instances with known values? if (Utils.sm((double) firstMiss, 2 * minSplit)) return; // Compute values of criteria for all possible split // indices. defaultEnt = m_infoGainCrit.oldEnt(m_distribution); while (next < firstMiss) { if (trainInstances.instance(next - 1).value(m_attIndex) + 1e-5 < trainInstances.instance(next).value(m_attIndex)) { // Move class values for all Instances up to next // possible split point. m_distribution.shiftRange(1, 0, trainInstances, last, next); // Check if enough Instances in each subset and compute // values for criteria. if (Utils.grOrEq(m_distribution.perBag(0), minSplit) && Utils.grOrEq(m_distribution.perBag(1), minSplit)) { currentInfoGain = m_infoGainCrit.splitCritValue(m_distribution, m_sumOfWeights, defaultEnt); if (Utils.gr(currentInfoGain, m_infoGain)) { m_infoGain = currentInfoGain; splitIndex = next - 1; } index++; } last = next; } next++; } // Was there any useful split? 
if (index == 0) return; // Compute modified information gain for best split. if (m_useMDLcorrection) { m_infoGain = m_infoGain - (Utils.log2(index) / m_sumOfWeights); } if (Utils.smOrEq(m_infoGain, 0)) return; // Set instance variables' values to values for // best split. m_numSubsets = 2; m_splitPoint = (trainInstances.instance(splitIndex + 1).value(m_attIndex) + trainInstances.instance(splitIndex).value(m_attIndex)) / 2; // In case we have a numerical precision problem we need to choose the // smaller value if (m_splitPoint == trainInstances.instance(splitIndex + 1).value(m_attIndex)) { m_splitPoint = trainInstances.instance(splitIndex).value(m_attIndex); } // Restore distributioN for best split. m_distribution = new Distribution(2, trainInstances.numClasses()); m_distribution.addRange(0, trainInstances, 0, splitIndex + 1); m_distribution.addRange(1, trainInstances, splitIndex + 1, firstMiss); // Compute modified gain ratio for best split. m_gainRatio = m_gainRatioCrit.splitCritValue(m_distribution, m_sumOfWeights, m_infoGain); }
/** * add noise to the dataset * * <p>a given percentage of the instances are changed in the way, that a set of instances are * randomly selected using seed. The attribute given by its index is changed from its current * value to one of the other possibly ones, also randomly. This is done with leaving the apportion * the same. if m_UseMissing is true, missing value is used as a value of its own * * @param instances is the dataset * @param seed used for random function * @param percent percentage of instances that are changed * @param attIndex index of the attribute changed * @param useMissing if true missing values are treated as extra value */ public void addNoise( Instances instances, int seed, int percent, int attIndex, boolean useMissing) { int indexList[]; int partition_count[]; int partition_max[]; double splitPercent = (double) percent; // percentage used for splits // fill array with the indexes indexList = new int[instances.numInstances()]; for (int i = 0; i < instances.numInstances(); i++) { indexList[i] = i; } // randomize list of indexes Random random = new Random(seed); for (int i = instances.numInstances() - 1; i >= 0; i--) { int hValue = indexList[i]; int hIndex = (int) (random.nextDouble() * (double) i); indexList[i] = indexList[hIndex]; indexList[hIndex] = hValue; } // initialize arrays that are used to count instances // of each value and to keep the amount of instances of that value // that has to be changed // this is done for the missing values in the two variables // missing_count and missing_max int numValues = instances.attribute(attIndex).numValues(); partition_count = new int[numValues]; partition_max = new int[numValues]; int missing_count = 0; ; int missing_max = 0; ; for (int i = 0; i < numValues; i++) { partition_count[i] = 0; partition_max[i] = 0; } // go through the dataset and count all occurrences of values // and all missing values using temporarily .._max arrays and // variable missing_max for (Enumeration e = 
instances.enumerateInstances(); e.hasMoreElements(); ) { Instance instance = (Instance) e.nextElement(); if (instance.isMissing(attIndex)) { missing_max++; } else { int j = (int) instance.value(attIndex); partition_max[(int) instance.value(attIndex)]++; } } // use given percentage to calculate // how many have to be changed per split and // how many of the missing values if (!useMissing) { missing_max = missing_count; } else { missing_max = (int) (((double) missing_max / 100) * splitPercent + 0.5); } int sum_max = missing_max; for (int i = 0; i < numValues; i++) { partition_max[i] = (int) (((double) partition_max[i] / 100) * splitPercent + 0.5); sum_max = sum_max + partition_max[i]; } // initialize sum_count to zero, use this variable to see if // everything is done already int sum_count = 0; // add noise // using the randomized index-array // Random randomValue = new Random(seed); int numOfValues = instances.attribute(attIndex).numValues(); for (int i = 0; i < instances.numInstances(); i++) { if (sum_count >= sum_max) { break; } // finished Instance currInstance = instances.instance(indexList[i]); // if value is missing then... if (currInstance.isMissing(attIndex)) { if (missing_count < missing_max) { changeValueRandomly(randomValue, numOfValues, attIndex, currInstance, useMissing); missing_count++; sum_count++; } } else { int vIndex = (int) currInstance.value(attIndex); if (partition_count[vIndex] < partition_max[vIndex]) { changeValueRandomly(randomValue, numOfValues, attIndex, currInstance, useMissing); partition_count[vIndex]++; sum_count++; } } } }
/** * Generates the classifier. * * @param instances set of instances serving as training data * @exception Exception if the classifier has not been generated successfully */ @Override public void buildClassifier(Instances instances) throws Exception { int attIndex = 0; double sum; // can classifier handle the data? getCapabilities().testWithFail(instances); // remove instances with missing class instances = new Instances(instances); instances.deleteWithMissingClass(); m_Instances = new Instances(instances, 0); // Reserve space m_Counts = new double[instances.numClasses()][instances.numAttributes() - 1][0]; m_Means = new double[instances.numClasses()][instances.numAttributes() - 1]; m_Devs = new double[instances.numClasses()][instances.numAttributes() - 1]; m_Priors = new double[instances.numClasses()]; Enumeration<Attribute> enu = instances.enumerateAttributes(); while (enu.hasMoreElements()) { Attribute attribute = enu.nextElement(); if (attribute.isNominal()) { for (int j = 0; j < instances.numClasses(); j++) { m_Counts[j][attIndex] = new double[attribute.numValues()]; } } else { for (int j = 0; j < instances.numClasses(); j++) { m_Counts[j][attIndex] = new double[1]; } } attIndex++; } // Compute counts and sums Enumeration<Instance> enumInsts = instances.enumerateInstances(); while (enumInsts.hasMoreElements()) { Instance instance = enumInsts.nextElement(); if (!instance.classIsMissing()) { Enumeration<Attribute> enumAtts = instances.enumerateAttributes(); attIndex = 0; while (enumAtts.hasMoreElements()) { Attribute attribute = enumAtts.nextElement(); if (!instance.isMissing(attribute)) { if (attribute.isNominal()) { m_Counts[(int) instance.classValue()][attIndex][(int) instance.value(attribute)]++; } else { m_Means[(int) instance.classValue()][attIndex] += instance.value(attribute); m_Counts[(int) instance.classValue()][attIndex][0]++; } } attIndex++; } m_Priors[(int) instance.classValue()]++; } } // Compute means Enumeration<Attribute> enumAtts = 
instances.enumerateAttributes(); attIndex = 0; while (enumAtts.hasMoreElements()) { Attribute attribute = enumAtts.nextElement(); if (attribute.isNumeric()) { for (int j = 0; j < instances.numClasses(); j++) { if (m_Counts[j][attIndex][0] < 2) { throw new Exception( "attribute " + attribute.name() + ": less than two values for class " + instances.classAttribute().value(j)); } m_Means[j][attIndex] /= m_Counts[j][attIndex][0]; } } attIndex++; } // Compute standard deviations enumInsts = instances.enumerateInstances(); while (enumInsts.hasMoreElements()) { Instance instance = enumInsts.nextElement(); if (!instance.classIsMissing()) { enumAtts = instances.enumerateAttributes(); attIndex = 0; while (enumAtts.hasMoreElements()) { Attribute attribute = enumAtts.nextElement(); if (!instance.isMissing(attribute)) { if (attribute.isNumeric()) { m_Devs[(int) instance.classValue()][attIndex] += (m_Means[(int) instance.classValue()][attIndex] - instance.value(attribute)) * (m_Means[(int) instance.classValue()][attIndex] - instance.value(attribute)); } } attIndex++; } } } enumAtts = instances.enumerateAttributes(); attIndex = 0; while (enumAtts.hasMoreElements()) { Attribute attribute = enumAtts.nextElement(); if (attribute.isNumeric()) { for (int j = 0; j < instances.numClasses(); j++) { if (m_Devs[j][attIndex] <= 0) { throw new Exception( "attribute " + attribute.name() + ": standard deviation is 0 for class " + instances.classAttribute().value(j)); } else { m_Devs[j][attIndex] /= m_Counts[j][attIndex][0] - 1; m_Devs[j][attIndex] = Math.sqrt(m_Devs[j][attIndex]); } } } attIndex++; } // Normalize counts enumAtts = instances.enumerateAttributes(); attIndex = 0; while (enumAtts.hasMoreElements()) { Attribute attribute = enumAtts.nextElement(); if (attribute.isNominal()) { for (int j = 0; j < instances.numClasses(); j++) { sum = Utils.sum(m_Counts[j][attIndex]); for (int i = 0; i < attribute.numValues(); i++) { m_Counts[j][attIndex][i] = (m_Counts[j][attIndex][i] + 1) / (sum + 
attribute.numValues()); } } } attIndex++; } // Normalize priors sum = Utils.sum(m_Priors); for (int j = 0; j < instances.numClasses(); j++) { m_Priors[j] = (m_Priors[j] + 1) / (sum + instances.numClasses()); } }
/** * Generates the classifier. * * @param instances set of instances serving as training data * @throws Exception if the classifier has not been generated successfully */ public void buildClassifier(Instances instances) throws Exception { // can classifier handle the data? getCapabilities().testWithFail(instances); // remove instances with missing class instances = new Instances(instances); instances.deleteWithMissingClass(); m_headerInfo = new Instances(instances, 0); m_numClasses = instances.numClasses(); m_numAttributes = instances.numAttributes(); m_probOfWordGivenClass = new double[m_numClasses][]; /* initialising the matrix of word counts NOTE: Laplace estimator introduced in case a word that does not appear for a class in the training set does so for the test set */ for (int c = 0; c < m_numClasses; c++) { m_probOfWordGivenClass[c] = new double[m_numAttributes]; for (int att = 0; att < m_numAttributes; att++) { m_probOfWordGivenClass[c][att] = 1; } } // enumerate through the instances Instance instance; int classIndex; double numOccurences; double[] docsPerClass = new double[m_numClasses]; double[] wordsPerClass = new double[m_numClasses]; java.util.Enumeration enumInsts = instances.enumerateInstances(); while (enumInsts.hasMoreElements()) { instance = (Instance) enumInsts.nextElement(); classIndex = (int) instance.value(instance.classIndex()); docsPerClass[classIndex] += instance.weight(); for (int a = 0; a < instance.numValues(); a++) if (instance.index(a) != instance.classIndex()) { if (!instance.isMissing(a)) { numOccurences = instance.valueSparse(a) * instance.weight(); if (numOccurences < 0) throw new Exception("Numeric attribute values must all be greater or equal to zero."); wordsPerClass[classIndex] += numOccurences; m_probOfWordGivenClass[classIndex][instance.index(a)] += numOccurences; } } } /* normalising probOfWordGivenClass values and saving each value as the log of each value */ for (int c = 0; c < m_numClasses; c++) for (int v = 0; v < 
m_numAttributes; v++) m_probOfWordGivenClass[c][v] = Math.log(m_probOfWordGivenClass[c][v] / (wordsPerClass[c] + m_numAttributes - 1)); /* calculating Pr(H) NOTE: Laplace estimator introduced in case a class does not get mentioned in the set of training instances */ final double numDocs = instances.sumOfWeights() + m_numClasses; m_probOfClass = new double[m_numClasses]; for (int h = 0; h < m_numClasses; h++) m_probOfClass[h] = (double) (docsPerClass[h] + 1) / numDocs; }