/** * Returns a textual description of this classifier. * * @return a textual description of this classifier. */ @Override public String toString() { if (m_probOfClass == null) { return "NaiveBayesMultinomialText: No model built yet.\n"; } StringBuffer result = new StringBuffer(); // build a master dictionary over all classes HashSet<String> master = new HashSet<String>(); for (int i = 0; i < m_data.numClasses(); i++) { LinkedHashMap<String, Count> classDict = m_probOfWordGivenClass.get(i); for (String key : classDict.keySet()) { master.add(key); } } result.append("Dictionary size: " + master.size()).append("\n\n"); result.append("The independent frequency of a class\n"); result.append("--------------------------------------\n"); for (int i = 0; i < m_data.numClasses(); i++) { result .append(m_data.classAttribute().value(i)) .append("\t") .append(Double.toString(m_probOfClass[i])) .append("\n"); } result.append("\nThe frequency of a word given the class\n"); result.append("-----------------------------------------\n"); for (int i = 0; i < m_data.numClasses(); i++) { result.append(Utils.padLeft(m_data.classAttribute().value(i), 11)).append("\t"); } result.append("\n"); Iterator<String> masterIter = master.iterator(); while (masterIter.hasNext()) { String word = masterIter.next(); for (int i = 0; i < m_data.numClasses(); i++) { LinkedHashMap<String, Count> classDict = m_probOfWordGivenClass.get(i); Count c = classDict.get(word); if (c == null) { result.append("<laplace=1>\t"); } else { result.append(Utils.padLeft(Double.toString(c.m_count), 11)).append("\t"); } } result.append(word); result.append("\n"); } return result.toString(); }
/** * Generates the classifier. * * @param data set of instances serving as training data * @throws Exception if the classifier has not been generated successfully */ @Override public void buildClassifier(Instances data) throws Exception { reset(); // can classifier handle the data? getCapabilities().testWithFail(data); m_data = new Instances(data, 0); data = new Instances(data); m_wordsPerClass = new double[data.numClasses()]; m_probOfClass = new double[data.numClasses()]; m_probOfWordGivenClass = new HashMap<Integer, LinkedHashMap<String, Count>>(); double laplace = 1.0; for (int i = 0; i < data.numClasses(); i++) { LinkedHashMap<String, Count> dict = new LinkedHashMap<String, Count>(10000 / data.numClasses()); m_probOfWordGivenClass.put(i, dict); m_probOfClass[i] = laplace; // this needs to be updated for laplace correction every time we see a new // word (attribute) m_wordsPerClass[i] = 0; } for (int i = 0; i < data.numInstances(); i++) { updateClassifier(data.instance(i)); } }
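A minimal usage sketch for the two methods above, assuming an ARFF file at a hypothetical path "data.arff" with at least one string attribute and the class as the last attribute; the loader, evaluation, and classifier calls are standard WEKA API:

import java.util.Random;
import weka.classifiers.Evaluation;
import weka.classifiers.bayes.NaiveBayesMultinomialText;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class NBMTextDemo {
  public static void main(String[] args) throws Exception {
    Instances data = DataSource.read("data.arff"); // hypothetical path
    data.setClassIndex(data.numAttributes() - 1);

    NaiveBayesMultinomialText nb = new NaiveBayesMultinomialText();
    nb.buildClassifier(data); // invokes the buildClassifier(...) shown above

    Evaluation eval = new Evaluation(data);
    eval.crossValidateModel(nb, data, 10, new Random(1));
    System.out.println(eval.toSummaryString());
    System.out.println(nb); // prints the dictionary/frequency dump from toString()
  }
}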
public Vector<int[]> getMultiKmodesResults( Instances data, Instances dataforcluster, int ensemblesize) throws Exception { Vector<int[]> cls = new Vector<int[]>(); int k = data.numClasses(); for (int i = 0; i < ensemblesize; ++i) { int tmpk = Rnd.nextInt(ensemblesize * k); tmpk = tmpk <= k ? (tmpk + k) : tmpk; SimpleKMeans km = new SimpleKMeans(); km.setMaxIterations(100); km.setNumClusters(tmpk); km.setDontReplaceMissingValues(true); km.setSeed(Rnd.nextInt()); km.setPreserveInstancesOrder(true); km.buildClusterer(dataforcluster); SquaredError[i] = km.getSquaredError(); /*EM em = new EM(); em.setMaxIterations(100); em.setNumClusters(k); //em.setSeed(Rnd.nextInt()); em.buildClusterer(dataforcluster); int[] res2 = new int[dataforcluster.numInstances()]; for(int r=0;r<dataforcluster.numInstances();++r){ res2[r]=em.clusterInstance(dataforcluster.instance(r)); }*/ // System.out.println(Arrays.toString(km.getAssignments())); cls.add(km.getAssignments()); // cls.add(res2); } return cls; }
/** * Builds the model of the base learner. * * @param data the training data * @throws Exception if the classifier could not be built successfully */ public void buildClassifier(Instances data) throws Exception { // can classifier handle the data? getCapabilities().testWithFail(data); // remove instances with missing class data = new Instances(data); data.deleteWithMissingClass(); if (m_Classifier == null) { throw new Exception("No base classifier has been set!"); } if (m_MatrixSource == MATRIX_ON_DEMAND) { String costName = data.relationName() + CostMatrix.FILE_EXTENSION; File costFile = new File(getOnDemandDirectory(), costName); if (!costFile.exists()) { throw new Exception("On-demand cost file doesn't exist: " + costFile); } setCostMatrix(new CostMatrix(new BufferedReader(new FileReader(costFile)))); } else if (m_CostMatrix == null) { // try loading an old format cost file m_CostMatrix = new CostMatrix(data.numClasses()); m_CostMatrix.readOldFormat(new BufferedReader(new FileReader(m_CostFile))); } if (!m_MinimizeExpectedCost) { Random random = null; if (!(m_Classifier instanceof WeightedInstancesHandler)) { random = new Random(m_Seed); } data = m_CostMatrix.applyCostMatrix(data, random); } m_Classifier.buildClassifier(data); }
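A hedged sketch of driving this cost-sensitive wrapper, assuming WEKA's CostSensitiveClassifier and a cost-matrix file in WEKA's standard format; the file names are hypothetical, and the CostMatrix reader constructor mirrors the setCostMatrix(...) call in the method above:

import java.io.BufferedReader;
import java.io.FileReader;
import java.util.Random;
import weka.classifiers.CostMatrix;
import weka.classifiers.Evaluation;
import weka.classifiers.meta.CostSensitiveClassifier;
import weka.classifiers.trees.J48;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class CostSensitiveDemo {
  public static void main(String[] args) throws Exception {
    Instances data = DataSource.read("credit.arff"); // hypothetical path
    data.setClassIndex(data.numAttributes() - 1);

    CostSensitiveClassifier csc = new CostSensitiveClassifier();
    csc.setClassifier(new J48());
    // explicit matrix, as in the setCostMatrix(...) branch of the method above
    csc.setCostMatrix(new CostMatrix(new BufferedReader(new FileReader("credit.cost"))));
    // false: reweight the training data via applyCostMatrix (the branch above);
    // true: train on unmodified data and minimize expected cost at prediction time
    csc.setMinimizeExpectedCost(false);

    Evaluation eval = new Evaluation(data);
    eval.crossValidateModel(csc, data, 10, new Random(1));
    System.out.println(eval.toSummaryString());
  }
}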
@Override public void buildClassifier(Instances data) throws Exception { trainingData = data; Attribute classAttribute = data.classAttribute(); prototypes = new ArrayList<>(); classedData = new HashMap<String, ArrayList<Sequence>>(); indexClassedDataInFullData = new HashMap<String, ArrayList<Integer>>(); for (int c = 0; c < data.numClasses(); c++) { classedData.put(data.classAttribute().value(c), new ArrayList<Sequence>()); indexClassedDataInFullData.put(data.classAttribute().value(c), new ArrayList<Integer>()); } sequences = new Sequence[data.numInstances()]; classMap = new String[sequences.length]; for (int i = 0; i < sequences.length; i++) { Instance sample = data.instance(i); MonoDoubleItemSet[] sequence = new MonoDoubleItemSet[sample.numAttributes() - 1]; int shift = (sample.classIndex() == 0) ? 1 : 0; for (int t = 0; t < sequence.length; t++) { sequence[t] = new MonoDoubleItemSet(sample.value(t + shift)); } sequences[i] = new Sequence(sequence); String clas = sample.stringValue(classAttribute); classMap[i] = clas; classedData.get(clas).add(sequences[i]); indexClassedDataInFullData.get(clas).add(i); // System.out.println("Element "+i+" of train is classed "+clas+" and went to element // "+(indexClassedDataInFullData.get(clas).size()-1)); } buildSpecificClassifier(data); }
private static void writePredictedDistributions( Classifier c, Instances data, int idIndex, Writer out) throws Exception { // header out.write("id"); for (int i = 0; i < data.numClasses(); i++) { out.write(",\""); out.write(data.classAttribute().value(i).replaceAll("[\"\\\\]", "_")); out.write("\""); } out.write("\n"); // data for (int i = 0; i < data.numInstances(); i++) { final String id = data.instance(i).stringValue(idIndex); double[] distribution = c.distributionForInstance(data.instance(i)); // final String label = data.attribute(classIndex).value(); out.write(id); for (double probability : distribution) { out.write(","); out.write(String.valueOf(probability > 1e-5 ? (float) probability : 0f)); } out.write("\n"); } }
/** * Stratify the given data into the given number of bags based on the class values. It differs * from <code>Instances.stratify(int fold)</code> in that, before stratification, it sorts the * instances according to the class order in the header file. It assumes there are no missing * class values. * * @param data the given data * @param folds the given number of folds * @param rand the random object used to randomize the instances * @return the stratified instances */ public static final Instances stratify(Instances data, int folds, Random rand) { if (!data.classAttribute().isNominal()) return data; Instances result = new Instances(data, 0); Instances[] bagsByClasses = new Instances[data.numClasses()]; for (int i = 0; i < bagsByClasses.length; i++) bagsByClasses[i] = new Instances(data, 0); // Sort by class for (int j = 0; j < data.numInstances(); j++) { Instance datum = data.instance(j); bagsByClasses[(int) datum.classValue()].add(datum); } // Randomize each class for (int j = 0; j < bagsByClasses.length; j++) bagsByClasses[j].randomize(rand); for (int k = 0; k < folds; k++) { int offset = k, bag = 0; oneFold: while (true) { while (offset >= bagsByClasses[bag].numInstances()) { offset -= bagsByClasses[bag].numInstances(); if (++bag >= bagsByClasses.length) // Next bag break oneFold; } result.add(bagsByClasses[bag].instance(offset)); offset += folds; } } return result; }
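A short sketch of how this helper might be used ahead of manual cross-validation. InstancesHelper is a placeholder name for whatever class hosts the stratify(...) method above, and the ARFF path is hypothetical; trainCV/testCV are standard Instances methods:

import java.util.Random;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class StratifyDemo {
  public static void main(String[] args) throws Exception {
    Instances data = DataSource.read("iris.arff"); // hypothetical path
    data.setClassIndex(data.numAttributes() - 1);

    int folds = 10;
    // class-sorted, per-class shuffled, then interleaved across folds
    Instances stratified = InstancesHelper.stratify(data, folds, new Random(42));

    for (int k = 0; k < folds; k++) {
      Instances train = stratified.trainCV(folds, k);
      Instances test = stratified.testCV(folds, k);
      // build and evaluate a classifier on each train/test split here
    }
  }
}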
/** * @param args * @throws Exception */ public static void main(String[] args) throws Exception { // TODO Auto-generated method stub oneAlgorithm oneAlg = new oneAlgorithm(); oneAlg.category = xCategory.RSandFCBFalg; oneAlg.style = xStyle.fuzzySU; oneAlg.flag = false; oneAlg.alpha = 2.0; // String fn = "C:/Users/Eric/Desktop/2011秋冬/Code/Xreducer/data/Data/wine.arff"; // String fn = "C:/Users/Eric/Desktop/2011秋冬/Code/Xreducer/data/Data/wdbc.arff"; String fn = "C:/Users/Eric/Desktop/2011秋冬/Code/Xreducer/data/Data/glass.arff"; // String fn = "C:/Users/Eric/Desktop/2011秋冬/Code/Xreducer/data/shen/wine-shen.arff"; // String fn = "C:/Users/Eric/Desktop/2011秋冬/Code/Xreducer/data/fuzzy/fuzzy-ex.arff"; // String fn = "C:/Users/Eric/Desktop/2011秋冬/Code/Xreducer/data/derm.arff"; oneFile onef = new oneFile(new File(fn)); Instances dataset = new Instances(new FileReader(fn)); dataset.setClassIndex(dataset.numAttributes() - 1); onef.ins = dataset.numInstances(); onef.att = dataset.numAttributes(); onef.cla = dataset.numClasses(); RSandFCBFReduceMethod rs = new RSandFCBFReduceMethod(onef, oneAlg); boolean[] B = new boolean[rs.NumAttr]; boolean[] rq = rs.getOneReduction(B); System.out.println(Arrays.toString(Utils.boolean2select(rq))); }
public double ExpectedClassificationError(Instances pool, int attr_i) { // initialize alphas to one int[][][] alpha; int NumberOfFeatures = pool.numAttributes() - 1; int NumberOfLabels = pool.numClasses(); alpha = new int[NumberOfFeatures][NumberOfLabels][]; for (int i = 0; i < NumberOfFeatures; i++) for (int j = 0; j < NumberOfLabels; j++) alpha[i][j] = new int[pool.attribute(i).numValues()]; for (int i = 0; i < NumberOfFeatures; i++) for (int j = 0; j < NumberOfLabels; j++) for (int k = 0; k < alpha[i][j].length; k++) alpha[i][j][k] = 1; // construct alphas for (int i = 0; i < NumberOfFeatures; i++) // for each attribute { if (i == pool.classIndex()) continue; // skip the class attribute (continue, not i++, so alpha is never indexed out of bounds) for (Enumeration<Instance> e = pool.enumerateInstances(); e.hasMoreElements(); ) // for each instance { Instance inst = e.nextElement(); if (!inst.isMissing(i)) // if attribute i is not missing (i.e. it's been bought) { int j = (int) inst.classValue(); int k = (int) inst.value(i); alpha[i][j][k]++; } } } return ExpectedClassificationError(alpha, attr_i); }
/** * Computes the entropy of a dataset. * * @param data the data for which entropy is to be computed * @return the entropy of the data's class distribution * @throws Exception if computation fails */ private double computeEntropy(Instances data) throws Exception { double[] classCounts = new double[data.numClasses()]; Enumeration instEnum = data.enumerateInstances(); while (instEnum.hasMoreElements()) { Instance inst = (Instance) instEnum.nextElement(); classCounts[(int) inst.classValue()]++; } double entropy = 0; for (int j = 0; j < data.numClasses(); j++) { if (classCounts[j] > 0) { entropy -= classCounts[j] * Utils.log2(classCounts[j]); } } entropy /= (double) data.numInstances(); return entropy + Utils.log2(data.numInstances()); }
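For reference, the arithmetic behind the last two statements: with class counts c_j and N = sum_j c_j, the loop accumulates -sum_j c_j * log2(c_j); dividing by N and adding log2(N) gives -(1/N) * sum_j c_j * log2(c_j) + log2(N) = -sum_j (c_j/N) * log2(c_j/N), i.e. the standard entropy of the class distribution in bits.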
/** Initializes the m_Attributes of the class. */ private void init_m_Attributes() { try { m_NumInstances = m_Train.numInstances(); m_NumClasses = m_Train.numClasses(); m_NumAttributes = m_Train.numAttributes(); m_ClassType = m_Train.classAttribute().type(); m_InitFlag = ON; } catch (Exception e) { e.printStackTrace(); } }
/** * Method for building an Id3 tree. * * @param data the training data * @exception Exception if decision tree can't be built successfully */ private void makeTree(Instances data) throws Exception { // Check if no instances have reached this node. if (data.numInstances() == 0) { m_Attribute = null; m_ClassValue = Utils.missingValue(); m_Distribution = new double[data.numClasses()]; return; } // Compute attribute with maximum information gain. double[] infoGains = new double[data.numAttributes()]; Enumeration attEnum = data.enumerateAttributes(); while (attEnum.hasMoreElements()) { Attribute att = (Attribute) attEnum.nextElement(); infoGains[att.index()] = computeInfoGain(data, att); } m_Attribute = data.attribute(Utils.maxIndex(infoGains)); // Make leaf if information gain is zero. // Otherwise create successors. if (Utils.eq(infoGains[m_Attribute.index()], 0)) { m_Attribute = null; m_Distribution = new double[data.numClasses()]; Enumeration instEnum = data.enumerateInstances(); while (instEnum.hasMoreElements()) { Instance inst = (Instance) instEnum.nextElement(); m_Distribution[(int) inst.classValue()]++; } Utils.normalize(m_Distribution); m_ClassValue = Utils.maxIndex(m_Distribution); m_ClassAttribute = data.classAttribute(); } else { Instances[] splitData = splitData(data, m_Attribute); m_Successors = new Id3[m_Attribute.numValues()]; for (int j = 0; j < m_Attribute.numValues(); j++) { m_Successors[j] = new Id3(); m_Successors[j].makeTree(splitData[j]); } } }
/** * Returns a description of the classifier. * * @return a description of the classifier as a string. */ @Override public String toString() { if (m_Instances == null) { return "Naive Bayes (simple): No model built yet."; } try { StringBuffer text = new StringBuffer("Naive Bayes (simple)"); int attIndex; for (int i = 0; i < m_Instances.numClasses(); i++) { text.append( "\n\nClass " + m_Instances.classAttribute().value(i) + ": P(C) = " + Utils.doubleToString(m_Priors[i], 10, 8) + "\n\n"); Enumeration<Attribute> enumAtts = m_Instances.enumerateAttributes(); attIndex = 0; while (enumAtts.hasMoreElements()) { Attribute attribute = enumAtts.nextElement(); text.append("Attribute " + attribute.name() + "\n"); if (attribute.isNominal()) { for (int j = 0; j < attribute.numValues(); j++) { text.append(attribute.value(j) + "\t"); } text.append("\n"); for (int j = 0; j < attribute.numValues(); j++) { text.append(Utils.doubleToString(m_Counts[i][attIndex][j], 10, 8) + "\t"); } } else { text.append("Mean: " + Utils.doubleToString(m_Means[i][attIndex], 10, 8) + "\t"); text.append("Standard Deviation: " + Utils.doubleToString(m_Devs[i][attIndex], 10, 8)); } text.append("\n\n"); attIndex++; } } return text.toString(); } catch (Exception e) { return "Can't print Naive Bayes classifier!"; } }
private static void analyse(Instances train, Instances datapredict) { String mpOptions = "-L 0.3 -M 0.2 -N 500 -V 0 -S 0 -E 20 -H a"; try { train.setClassIndex(train.numAttributes() - 1); train.deleteAttributeAt(0); int numClasses = train.numClasses(); for (int i = 0; i < numClasses; i++) { System.out.println("class value [" + i + "]=" + train.classAttribute().value(i) + ""); } // Instance of NN MultilayerPerceptron mlp = new MultilayerPerceptron(); mlp.setOptions(weka.core.Utils.splitOptions(mpOptions)); mlp.buildClassifier(train); datapredict.setClassIndex(datapredict.numAttributes() - 1); datapredict.deleteAttributeAt(0); // Instances predicteddata = new Instances(datapredict); for (int i = 0; i < datapredict.numInstances(); i++) { Instance newInst = datapredict.instance(i); double pred = mlp.classifyInstance(newInst); int predInt = (int) pred; // Math.round(pred); String predString = train.classAttribute().value(predInt); System.out.println( "cliente[" + i + "] pred[" + pred + "] predInt[" + predInt + "] desertor[" + predString + "]"); } } catch (Exception e) { e.printStackTrace(); } }
private Vector<int[]> getMultiKmodesResultswithRandomSelectFeature( Instances data, Instances dataforcluster, int ensemblesize) throws Exception { // TODO Auto-generated method stub Vector<int[]> cls = new Vector<int[]>(); for (int i = 0; i < ensemblesize; ++i) { SimpleKMeans km = new SimpleKMeans(); km.setMaxIterations(100); km.setNumClusters(data.numClasses()); km.setDontReplaceMissingValues(true); km.setSeed(Rnd.nextInt()); Instances newData = getNewRandomData(dataforcluster, Rnd.nextInt()); km.setPreserveInstancesOrder(true); km.buildClusterer(newData); SquaredError[i] = km.getSquaredError(); cls.add(km.getAssignments()); } return cls; }
public int SelectRow_L2Norm( Instances pool, Classifier myEstimator, int desiredAttr, int desiredLabel) { // for each instance with unbought desiredAttr and label = desiredLabel // measure distance from uniform // choose (row) that is closest to uniform as your instance to buy from double leastDistance = Double.MAX_VALUE; int leastIndex = -1; Instance inst; int n = pool.numClasses(); double[] uniform; double[] probs; uniform = new double[n]; for (int i = 0; i < n; i++) uniform[i] = 1.0 / (double) n; for (int i = 0; i < pool.numInstances(); i++) { inst = pool.instance(i); // System.out.println("currentlabel="+(int)inst.classValue()+" // isMissing="+inst.isMissing(desiredAttr)); if ((int) inst.classValue() == desiredLabel && inst.isMissing(desiredAttr)) { // valid instance // measure the distance from uniform: // sqrt{ sum_i (a_i - b_i)^2 } probs = new double[n]; try { probs = myEstimator.distributionForInstance(inst); } catch (Exception e) { // TODO Auto-generated catch block e.printStackTrace(); } double distance = 0.0; for (int j = 0; j < n; j++) distance += (probs[j] - uniform[j]) * (probs[j] - uniform[j]); distance = Math.sqrt(distance); // System.out.println("current distance="+distance); if (distance < leastDistance) { leastDistance = distance; leastIndex = i; } } } return leastIndex; }
/** * Creates split on enumerated attribute. * * @exception Exception if something goes wrong */ private void handleEnumeratedAttribute(Instances trainInstances) throws Exception { Distribution newDistribution, secondDistribution; int numAttValues; double currIG, currGR; Instance instance; int i; numAttValues = trainInstances.attribute(m_attIndex).numValues(); newDistribution = new Distribution(numAttValues, trainInstances.numClasses()); // Only Instances with known values are relevant. Enumeration enu = trainInstances.enumerateInstances(); while (enu.hasMoreElements()) { instance = (Instance) enu.nextElement(); if (!instance.isMissing(m_attIndex)) newDistribution.add((int) instance.value(m_attIndex), instance); } m_distribution = newDistribution; // For all values for (i = 0; i < numAttValues; i++) { if (Utils.grOrEq(newDistribution.perBag(i), m_minNoObj)) { secondDistribution = new Distribution(newDistribution, i); // Check if minimum number of Instances in the two // subsets. if (secondDistribution.check(m_minNoObj)) { m_numSubsets = 2; currIG = m_infoGainCrit.splitCritValue(secondDistribution, m_sumOfWeights); currGR = m_gainRatioCrit.splitCritValue(secondDistribution, m_sumOfWeights, currIG); if ((i == 0) || Utils.gr(currGR, m_gainRatio)) { m_gainRatio = currGR; m_infoGain = currIG; m_splitPoint = (double) i; m_distribution = secondDistribution; } } } } }
/** * Creates split on enumerated attribute. * * @exception Exception if something goes wrong */ private void handleEnumeratedAttribute(Instances trainInstances) throws Exception { Instance instance; m_distribution = new Distribution(m_complexityIndex, trainInstances.numClasses()); // Only Instances with known values are relevant. Enumeration<Instance> enu = trainInstances.enumerateInstances(); while (enu.hasMoreElements()) { instance = enu.nextElement(); if (!instance.isMissing(m_attIndex)) { m_distribution.add((int) instance.value(m_attIndex), instance); } } // Check if minimum number of Instances in at least two // subsets. if (m_distribution.check(m_minNoObj)) { m_numSubsets = m_complexityIndex; m_infoGain = infoGainCrit.splitCritValue(m_distribution, m_sumOfWeights); m_gainRatio = gainRatioCrit.splitCritValue(m_distribution, m_sumOfWeights, m_infoGain); } }
public Map<FV, Collection<FV>> extractValuesFromData(Instances inst) { Multimap<FV, FV> fv_list = ArrayListMultimap.create(); // Instances outFormat = getOutputFormat(); for (int i = 0; i < inst.numInstances(); i++) { Instance ins = inst.instance(i); // Skip the class label for (int x = 0; x < ins.numAttributes() - 1; x++) { Object value = null; try { value = ins.stringValue(x); } catch (Exception e) { value = ins.value(x); } FV fv = new FV(x, value, ins.classValue()); fv.setNumLabels(inst.numClasses()); if (!fv_list.put(fv, fv)) { System.err.println("Couldn't put duplicates: " + fv); } } } Map<FV, Collection<FV>> original_map = fv_list.asMap(); return original_map; }
/** * Generates the classifier. * * @param instances set of instances serving as training data * @throws Exception if the classifier has not been generated successfully */ public void buildClassifier(Instances instances) throws Exception { // can classifier handle the data? getCapabilities().testWithFail(instances); // remove instances with missing class instances = new Instances(instances); instances.deleteWithMissingClass(); m_NumClasses = instances.numClasses(); m_ClassType = instances.classAttribute().type(); m_Train = new Instances(instances, 0, instances.numInstances()); // Throw away initial instances until within the specified window size if ((m_WindowSize > 0) && (instances.numInstances() > m_WindowSize)) { m_Train = new Instances(m_Train, m_Train.numInstances() - m_WindowSize, m_WindowSize); } m_NumAttributesUsed = 0.0; for (int i = 0; i < m_Train.numAttributes(); i++) { if ((i != m_Train.classIndex()) && (m_Train.attribute(i).isNominal() || m_Train.attribute(i).isNumeric())) { m_NumAttributesUsed += 1.0; } } m_NNSearch.setInstances(m_Train); // Invalidate any currently cross-validation selected k m_kNNValid = false; m_defaultModel = new ZeroR(); m_defaultModel.buildClassifier(instances); m_defaultModel.setOptions(getOptions()); // System.out.println("hello world"); }
/** * Selects cutpoints for sorted subset. * * @param instances * @param attIndex * @param first * @param lastPlusOne * @return */ private double[] cutPointsForSubset( Instances instances, int attIndex, int first, int lastPlusOne) { double[][] counts, bestCounts; double[] priorCounts, left, right, cutPoints; double currentCutPoint = -Double.MAX_VALUE, bestCutPoint = -1, currentEntropy, bestEntropy, priorEntropy, gain; int bestIndex = -1, numCutPoints = 0; double numInstances = 0; // Compute number of instances in set if ((lastPlusOne - first) < 2) { return null; } // Compute class counts. counts = new double[2][instances.numClasses()]; for (int i = first; i < lastPlusOne; i++) { numInstances += instances.instance(i).weight(); counts[1][(int) instances.instance(i).classValue()] += instances.instance(i).weight(); } // Save prior counts priorCounts = new double[instances.numClasses()]; System.arraycopy(counts[1], 0, priorCounts, 0, instances.numClasses()); // Entropy of the full set priorEntropy = ContingencyTables.entropy(priorCounts); bestEntropy = priorEntropy; // Find best entropy. bestCounts = new double[2][instances.numClasses()]; for (int i = first; i < (lastPlusOne - 1); i++) { counts[0][(int) instances.instance(i).classValue()] += instances.instance(i).weight(); counts[1][(int) instances.instance(i).classValue()] -= instances.instance(i).weight(); if (instances.instance(i).value(attIndex) < instances.instance(i + 1).value(attIndex)) { currentCutPoint = (instances.instance(i).value(attIndex) + instances.instance(i + 1).value(attIndex)) / 2.0; currentEntropy = ContingencyTables.entropyConditionedOnRows(counts); if (currentEntropy < bestEntropy) { bestCutPoint = currentCutPoint; bestEntropy = currentEntropy; bestIndex = i; System.arraycopy(counts[0], 0, bestCounts[0], 0, instances.numClasses()); System.arraycopy(counts[1], 0, bestCounts[1], 0, instances.numClasses()); } numCutPoints++; } } // Use worse encoding? if (!m_UseBetterEncoding) { numCutPoints = (lastPlusOne - first) - 1; } // Checks if gain is zero gain = priorEntropy - bestEntropy; if (gain <= 0) { return null; } // Check if split is to be accepted if ((m_UseKononenko && KononenkosMDL(priorCounts, bestCounts, numInstances, numCutPoints)) || (!m_UseKononenko && FayyadAndIranisMDL(priorCounts, bestCounts, numInstances, numCutPoints))) { // Select split points for the left and right subsets left = cutPointsForSubset(instances, attIndex, first, bestIndex + 1); right = cutPointsForSubset(instances, attIndex, bestIndex + 1, lastPlusOne); // Merge cutpoints and return them if ((left == null) && (right) == null) { cutPoints = new double[1]; cutPoints[0] = bestCutPoint; } else if (right == null) { cutPoints = new double[left.length + 1]; System.arraycopy(left, 0, cutPoints, 0, left.length); cutPoints[left.length] = bestCutPoint; } else if (left == null) { cutPoints = new double[1 + right.length]; cutPoints[0] = bestCutPoint; System.arraycopy(right, 0, cutPoints, 1, right.length); } else { cutPoints = new double[left.length + right.length + 1]; System.arraycopy(left, 0, cutPoints, 0, left.length); cutPoints[left.length] = bestCutPoint; System.arraycopy(right, 0, cutPoints, left.length + 1, right.length); } return cutPoints; } else { return null; } }
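This recursion appears to be the core of WEKA's supervised MDL discretization (Fayyad & Irani, with Kononenko's criterion as an option); rather than calling it directly, it is normally reached through the supervised Discretize filter. A minimal sketch, assuming the standard filter API and a hypothetical ARFF path; the setter names correspond to the m_UseKononenko and m_UseBetterEncoding flags above:

import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import weka.filters.Filter;
import weka.filters.supervised.attribute.Discretize;

public class DiscretizeDemo {
  public static void main(String[] args) throws Exception {
    Instances data = DataSource.read("numeric.arff"); // hypothetical path
    data.setClassIndex(data.numAttributes() - 1);

    Discretize disc = new Discretize();
    disc.setUseKononenko(false);     // Fayyad & Irani MDL criterion, as in the code above
    disc.setUseBetterEncoding(true); // the m_UseBetterEncoding flag above
    disc.setInputFormat(data);

    Instances discretized = Filter.useFilter(data, disc);
    System.out.println(discretized.toSummaryString());
  }
}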
/** * Generates the classifier. * * @param instances set of instances serving as training data * @throws Exception if the classifier has not been generated successfully */ public void buildClassifier(Instances instances) throws Exception { // can classifier handle the data? getCapabilities().testWithFail(instances); // remove instances with missing class instances = new Instances(instances); instances.deleteWithMissingClass(); m_headerInfo = new Instances(instances, 0); m_numClasses = instances.numClasses(); m_numAttributes = instances.numAttributes(); m_probOfWordGivenClass = new double[m_numClasses][]; /* initialising the matrix of word counts NOTE: Laplace estimator introduced in case a word that does not appear for a class in the training set does so for the test set */ for (int c = 0; c < m_numClasses; c++) { m_probOfWordGivenClass[c] = new double[m_numAttributes]; for (int att = 0; att < m_numAttributes; att++) { m_probOfWordGivenClass[c][att] = 1; } } // enumerate through the instances Instance instance; int classIndex; double numOccurences; double[] docsPerClass = new double[m_numClasses]; double[] wordsPerClass = new double[m_numClasses]; java.util.Enumeration enumInsts = instances.enumerateInstances(); while (enumInsts.hasMoreElements()) { instance = (Instance) enumInsts.nextElement(); classIndex = (int) instance.value(instance.classIndex()); docsPerClass[classIndex] += instance.weight(); for (int a = 0; a < instance.numValues(); a++) if (instance.index(a) != instance.classIndex()) { if (!instance.isMissing(a)) { numOccurences = instance.valueSparse(a) * instance.weight(); if (numOccurences < 0) throw new Exception("Numeric attribute values must all be greater or equal to zero."); wordsPerClass[classIndex] += numOccurences; m_probOfWordGivenClass[classIndex][instance.index(a)] += numOccurences; } } } /* normalising probOfWordGivenClass values and saving each value as the log of each value */ for (int c = 0; c < m_numClasses; c++) for (int v = 0; v < m_numAttributes; v++) m_probOfWordGivenClass[c][v] = Math.log(m_probOfWordGivenClass[c][v] / (wordsPerClass[c] + m_numAttributes - 1)); /* calculating Pr(H) NOTE: Laplace estimator introduced in case a class does not get mentioned in the set of training instances */ final double numDocs = instances.sumOfWeights() + m_numClasses; m_probOfClass = new double[m_numClasses]; for (int h = 0; h < m_numClasses; h++) m_probOfClass[h] = (double) (docsPerClass[h] + 1) / numDocs; }
/** * Generates the classifier. * * @param instances set of instances serving as training data * @throws Exception if the classifier has not been generated successfully */ public void buildClassifier(Instances instances) throws Exception { if (!m_weightByConfidence) { TINY = 0.0; } // can classifier handle the data? getCapabilities().testWithFail(instances); // remove instances with missing class instances = new Instances(instances); instances.deleteWithMissingClass(); m_ClassIndex = instances.classIndex(); m_NumClasses = instances.numClasses(); m_globalCounts = new double[m_NumClasses]; m_maxEntrop = Math.log(m_NumClasses) / Math.log(2); m_Instances = new Instances(instances, 0); // Copy the structure for ref m_intervalBounds = new double[instances.numAttributes()][2 + (2 * m_NumClasses)]; for (int j = 0; j < instances.numAttributes(); j++) { boolean alt = false; for (int i = 0; i < m_NumClasses * 2 + 2; i++) { if (i == 0) { m_intervalBounds[j][i] = Double.NEGATIVE_INFINITY; } else if (i == m_NumClasses * 2 + 1) { m_intervalBounds[j][i] = Double.POSITIVE_INFINITY; } else { if (alt) { m_intervalBounds[j][i] = Double.NEGATIVE_INFINITY; alt = false; } else { m_intervalBounds[j][i] = Double.POSITIVE_INFINITY; alt = true; } } } } // find upper and lower bounds for numeric attributes for (int j = 0; j < instances.numAttributes(); j++) { if (j != m_ClassIndex && instances.attribute(j).isNumeric()) { for (int i = 0; i < instances.numInstances(); i++) { Instance inst = instances.instance(i); if (!inst.isMissing(j)) { if (inst.value(j) < m_intervalBounds[j][((int) inst.classValue() * 2 + 1)]) { m_intervalBounds[j][((int) inst.classValue() * 2 + 1)] = inst.value(j); } if (inst.value(j) > m_intervalBounds[j][((int) inst.classValue() * 2 + 2)]) { m_intervalBounds[j][((int) inst.classValue() * 2 + 2)] = inst.value(j); } } } } } m_counts = new double[instances.numAttributes()][][]; // sort intervals for (int i = 0; i < instances.numAttributes(); i++) { if (instances.attribute(i).isNumeric()) { int[] sortedIntervals = Utils.sort(m_intervalBounds[i]); // remove any duplicate bounds int count = 1; for (int j = 1; j < sortedIntervals.length; j++) { if (m_intervalBounds[i][sortedIntervals[j]] != m_intervalBounds[i][sortedIntervals[j - 1]]) { count++; } } double[] reordered = new double[count]; count = 1; reordered[0] = m_intervalBounds[i][sortedIntervals[0]]; for (int j = 1; j < sortedIntervals.length; j++) { if (m_intervalBounds[i][sortedIntervals[j]] != m_intervalBounds[i][sortedIntervals[j - 1]]) { reordered[count] = m_intervalBounds[i][sortedIntervals[j]]; count++; } } m_intervalBounds[i] = reordered; m_counts[i] = new double[count][m_NumClasses]; } else if (i != m_ClassIndex) { // nominal attribute m_counts[i] = new double[instances.attribute(i).numValues()][m_NumClasses]; } } // collect class counts for (int i = 0; i < instances.numInstances(); i++) { Instance inst = instances.instance(i); m_globalCounts[(int) instances.instance(i).classValue()] += inst.weight(); for (int j = 0; j < instances.numAttributes(); j++) { if (!inst.isMissing(j) && j != m_ClassIndex) { if (instances.attribute(j).isNumeric()) { double val = inst.value(j); int k; for (k = m_intervalBounds[j].length - 1; k >= 0; k--) { if (val > m_intervalBounds[j][k]) { m_counts[j][k][(int) inst.classValue()] += inst.weight(); break; } else if (val == m_intervalBounds[j][k]) { m_counts[j][k][(int) inst.classValue()] += (inst.weight() / 2.0); m_counts[j][k - 1][(int) inst.classValue()] += (inst.weight() / 2.0); break; } } } else { // nominal attribute m_counts[j][(int) inst.value(j)][(int) inst.classValue()] += inst.weight(); } } } } }
public static void analyze_accuracy_NHBS(int rng_seed) throws Exception { HashMap<String, Object> population_params = load_defaults(null); RawLoader rl = new RawLoader(population_params, true, false, rng_seed); List<DrugUser> learningData = rl.getLearningData(); Instances nhbs_data = new Instances("learning_instances", DrugUser.getAttInfo(), learningData.size()); for (DrugUser du : learningData) { nhbs_data.add(du.getInstance()); } System.out.println(nhbs_data.toSummaryString()); nhbs_data.setClass(DrugUser.getAttribMap().get("hcv_state")); // wishlist: remove infrequent values // weka.filters.unsupervised.instance.RemoveFrequentValues() Filter f1 = new RemoveUseless(); f1.setInputFormat(nhbs_data); nhbs_data = Filter.useFilter(nhbs_data, f1); System.out.println("NHBS IDU 2009 Dataset"); System.out.println("Summary of input:"); // System.out.printlnnhbs_data.toSummaryString()); System.out.println(" Num of classes: " + nhbs_data.numClasses()); System.out.println(" Num of attributes: " + nhbs_data.numAttributes()); for (int idx = 0; idx < nhbs_data.numAttributes(); ++idx) { Attribute attr = nhbs_data.attribute(idx); System.out.println("" + idx + ": " + attr.toString()); System.out.println(" distinct values:" + nhbs_data.numDistinctValues(idx)); // System.out.println("" + attr.enumerateValues()); } ArrayList<String> options = new ArrayList<String>(); options.add("-Q"); options.add("" + rng_seed); // System.exit(0); // nhbs_data.deleteAttributeAt(0); //response ID // nhbs_data.deleteAttributeAt(16); //zip // Classifier classifier = new NNge(); //best nearest-neighbor classifier: 40.00 // ROC=0.60 // Classifier classifier = new MINND(); // Classifier classifier = new CitationKNN(); // Classifier classifier = new LibSVM(); //requires LibSVM classes. only gets 37.7% // Classifier classifier = new SMOreg(); Classifier classifier = new Logistic(); // ROC=0.686 // Classifier classifier = new LinearNNSearch(); // LinearRegression: Cannot handle multi-valued nominal class! // Classifier classifier = new LinearRegression(); // Classifier classifier = new RandomForest(); // String[] options = {"-I", "100", "-K", "4"}; //-I trees, -K features per tree. generally, // might want to optimize (or not // https://cwiki.apache.org/confluence/display/MAHOUT/Random+Forests) // options.add("-I"); options.add("100"); options.add("-K"); options.add("4"); // ROC=0.673 // KStar classifier = new KStar(); // classifier.setGlobalBlend(20); //the amount of not greedy, in percent // ROC=0.633 // Classifier classifier = new AdaBoostM1(); // ROC=0.66 // Classifier classifier = new MultiBoostAB(); // ROC=0.67 // Classifier classifier = new Stacking(); // ROC=0.495 // J48 classifier = new J48(); // new instance of tree //building a C45 tree classifier // ROC=0.585 // String[] options = new String[1]; // options[0] = "-U"; // unpruned tree // classifier.setOptions(options); // set the options classifier.setOptions((String[]) options.toArray(new String[0])); // not needed before CV: http://weka.wikispaces.com/Use+WEKA+in+your+Java+code // classifier.buildClassifier(nhbs_data); // build classifier // evaluation Evaluation eval = new Evaluation(nhbs_data); eval.crossValidateModel(classifier, nhbs_data, 10, new Random(1)); // 10-fold cross validation System.out.println(eval.toSummaryString("\nResults\n\n", false)); System.out.println(eval.toClassDetailsString()); // System.out.println(eval.toCumulativeMarginDistributionString()); }
public static void test_NHBS_old() throws Exception { // load the data CSVLoader loader = new CSVLoader(); // these must come before the getDataSet() // loader.setEnclosureCharacters(",\'\"S"); // loader.setNominalAttributes("16,71"); //zip code, drug name // loader.setStringAttributes(""); // loader.setDateAttributes("0,1"); // loader.setSource(new File("hcv/data/NHBS/IDU2_HCV_model_012913_cleaned_for_weka.csv")); loader.setSource(new File("/home/sasha/hcv/code/data/IDU2_HCV_model_012913_cleaned.csv")); loader.setMissingValue("NOVALUE"); // loader.setMissingValue(""); Instances nhbs_data = loader.getDataSet(); nhbs_data.deleteAttributeAt(12); // zip code nhbs_data.deleteAttributeAt(1); // date - redundant with age nhbs_data.deleteAttributeAt(0); // date System.out.println("classifying attribute:"); nhbs_data.setClassIndex(1); // new index 3->2->1 nhbs_data.attribute(1).getMetadata().toString(); // HCVEIARSLT1 // wishlist: perhaps it would be smarter to throw out unclassified instances? they interfere // with the scoring nhbs_data.deleteWithMissingClass(); // nhbs_data.setClass(new Attribute("HIVRSLT"));//.setClassIndex(1); //2nd column. all are // mostly negative // nhbs_data.setClass(new Attribute("HCVEIARSLT1"));//.setClassIndex(2); //3rd column // #14, i.e. rds_fem, should be made numeric System.out.println("NHBS IDU 2009 Dataset"); System.out.println("Summary of input:"); // System.out.println(nhbs_data.toSummaryString()); System.out.println(" Num of classes: " + nhbs_data.numClasses()); System.out.println(" Num of attributes: " + nhbs_data.numAttributes()); for (int idx = 0; idx < nhbs_data.numAttributes(); ++idx) { Attribute attr = nhbs_data.attribute(idx); System.out.println("" + idx + ": " + attr.toString()); System.out.println(" distinct values:" + nhbs_data.numDistinctValues(idx)); // System.out.println("" + attr.enumerateValues()); } // System.exit(0); // nhbs_data.deleteAttributeAt(0); //response ID // nhbs_data.deleteAttributeAt(16); //zip // Classifier classifier = new NNge(); //best nearest-neighbor classifier: 40.00 // Classifier classifier = new MINND(); // Classifier classifier = new CitationKNN(); // Classifier classifier = new LibSVM(); //requires LibSVM classes. only gets 37.7% // Classifier classifier = new SMOreg(); // Classifier classifier = new LinearNNSearch(); // LinearRegression: Cannot handle multi-valued nominal class! // Classifier classifier = new LinearRegression(); Classifier classifier = new RandomForest(); String[] options = { "-I", "100", "-K", "4" }; // -I trees, -K features per tree. generally, might want to optimize (or not // https://cwiki.apache.org/confluence/display/MAHOUT/Random+Forests) classifier.setOptions(options); // Classifier classifier = new Logistic(); // KStar classifier = new KStar(); // classifier.setGlobalBlend(20); //the amount of not greedy, in percent // does poorly // Classifier classifier = new AdaBoostM1(); // Classifier classifier = new MultiBoostAB(); // Classifier classifier = new Stacking(); // building a C45 tree classifier // J48 classifier = new J48(); // new instance of tree // String[] options = new String[1]; // options[0] = "-U"; // unpruned tree // classifier.setOptions(options); // set the options // classifier.buildClassifier(nhbs_data); // build classifier // wishlist: remove infrequent values // weka.filters.unsupervised.instance.RemoveFrequentValues() Filter f1 = new RemoveUseless(); f1.setInputFormat(nhbs_data); nhbs_data = Filter.useFilter(nhbs_data, f1); // evaluation Evaluation eval = new Evaluation(nhbs_data); eval.crossValidateModel(classifier, nhbs_data, 10, new Random(1)); System.out.println(eval.toSummaryString("\nResults\n\n", false)); System.out.println(eval.toClassDetailsString()); // System.out.println(eval.toCumulativeMarginDistributionString()); }
/** * Calculates the class membership probabilities for the given test instance. * * @param instance the instance to be classified * @return predicted class probability distribution * @throws Exception if there is a problem generating the prediction */ @Override public double[] distributionForInstance(Instance instance) throws Exception { tokenizeInstance(instance, false); double[] probOfClassGivenDoc = new double[m_data.numClasses()]; double[] logDocGivenClass = new double[m_data.numClasses()]; for (int i = 0; i < m_data.numClasses(); i++) { logDocGivenClass[i] += Math.log(m_probOfClass[i]); LinkedHashMap<String, Count> dictForClass = m_probOfWordGivenClass.get(i); double allWords = 0; // double, not int, so normalized (fractional) frequencies are not truncated // for document normalization (if in use) double iNorm = 0; double fv = 0; if (m_normalize) { for (Map.Entry<String, Count> feature : m_inputVector.entrySet()) { String word = feature.getKey(); Count c = feature.getValue(); // check the word against all the dictionaries (all classes) boolean ok = false; for (int clss = 0; clss < m_data.numClasses(); clss++) { if (m_probOfWordGivenClass.get(clss).get(word) != null) { ok = true; break; } } // only normalize with respect to those words that we've seen during // training // (i.e. dictionary over all classes) if (ok) { // word counts or bag-of-words? fv = (m_wordFrequencies) ? c.m_count : 1.0; iNorm += Math.pow(Math.abs(fv), m_lnorm); } } iNorm = Math.pow(iNorm, 1.0 / m_lnorm); } // System.out.println("---- " + m_inputVector.size()); for (Map.Entry<String, Count> feature : m_inputVector.entrySet()) { String word = feature.getKey(); Count dictCount = dictForClass.get(word); // System.out.print(word + " "); /* * if (dictCount != null) { System.out.println(dictCount.m_count); } * else { System.out.println("*1"); } */ // check the word against all the dictionaries (all classes) boolean ok = false; for (int clss = 0; clss < m_data.numClasses(); clss++) { if (m_probOfWordGivenClass.get(clss).get(word) != null) { ok = true; break; } } // ignore words we haven't seen in the training data if (ok) { double freq = (m_wordFrequencies) ? feature.getValue().m_count : 1.0; // double freq = (feature.getValue().m_count / iNorm * m_norm); if (m_normalize) { freq /= iNorm * m_norm; } allWords += freq; if (dictCount != null) { logDocGivenClass[i] += freq * Math.log(dictCount.m_count); } else { // laplace for zero frequency logDocGivenClass[i] += freq * Math.log(m_leplace); } } } if (m_wordsPerClass[i] > 0) { logDocGivenClass[i] -= allWords * Math.log(m_wordsPerClass[i]); } } double max = logDocGivenClass[Utils.maxIndex(logDocGivenClass)]; for (int i = 0; i < m_data.numClasses(); i++) { probOfClassGivenDoc[i] = Math.exp(logDocGivenClass[i] - max); } Utils.normalize(probOfClassGivenDoc); return probOfClassGivenDoc; }
/** * Generates the classifier. * * @param instances set of instances serving as training data * @throws Exception if the classifier has not been generated successfully */ public void buildClassifier(Instances instances) throws Exception { // can classifier handle the data? getCapabilities().testWithFail(instances); // only class? -> build ZeroR model if (instances.numAttributes() == 1) { System.err.println( "Cannot build model (only class attribute present in data!), " + "using ZeroR model instead!"); m_ZeroR = new weka.classifiers.rules.ZeroR(); m_ZeroR.buildClassifier(instances); return; } else { m_ZeroR = null; } // reset variable m_NumClasses = instances.numClasses(); m_ClassIndex = instances.classIndex(); m_NumAttributes = instances.numAttributes(); m_NumInstances = instances.numInstances(); m_TotalAttValues = 0; // allocate space for attribute reference arrays m_StartAttIndex = new int[m_NumAttributes]; m_NumAttValues = new int[m_NumAttributes]; // set the starting index of each attribute and the number of values for // each attribute and the total number of values for all attributes (not including class). for (int i = 0; i < m_NumAttributes; i++) { if (i != m_ClassIndex) { m_StartAttIndex[i] = m_TotalAttValues; m_NumAttValues[i] = instances.attribute(i).numValues(); m_TotalAttValues += m_NumAttValues[i]; } else { m_StartAttIndex[i] = -1; m_NumAttValues[i] = m_NumClasses; } } // allocate space for counts and frequencies m_ClassCounts = new double[m_NumClasses]; m_AttCounts = new double[m_TotalAttValues]; m_AttAttCounts = new double[m_TotalAttValues][m_TotalAttValues]; m_ClassAttAttCounts = new double[m_NumClasses][m_TotalAttValues][m_TotalAttValues]; m_Header = new Instances(instances, 0); // Calculate the counts for (int k = 0; k < m_NumInstances; k++) { int classVal = (int) instances.instance(k).classValue(); m_ClassCounts[classVal]++; int[] attIndex = new int[m_NumAttributes]; for (int i = 0; i < m_NumAttributes; i++) { if (i == m_ClassIndex) { attIndex[i] = -1; } else { attIndex[i] = m_StartAttIndex[i] + (int) instances.instance(k).value(i); m_AttCounts[attIndex[i]]++; } } for (int Att1 = 0; Att1 < m_NumAttributes; Att1++) { if (attIndex[Att1] == -1) continue; for (int Att2 = 0; Att2 < m_NumAttributes; Att2++) { if ((attIndex[Att2] != -1)) { m_AttAttCounts[attIndex[Att1]][attIndex[Att2]]++; m_ClassAttAttCounts[classVal][attIndex[Att1]][attIndex[Att2]]++; } } } } // compute mutual information between each attribute and class m_mutualInformation = new double[m_NumAttributes]; for (int att = 0; att < m_NumAttributes; att++) { if (att == m_ClassIndex) continue; m_mutualInformation[att] = mutualInfo(att); } }
/** * Creates split on numeric attribute. * * @exception Exception if something goes wrong */ private void handleNumericAttribute(Instances trainInstances) throws Exception { int firstMiss; int next = 1; int last = 0; int index = 0; int splitIndex = -1; double currentInfoGain; double defaultEnt; double minSplit; Instance instance; int i; // Current attribute is a numeric attribute. m_distribution = new Distribution(2, trainInstances.numClasses()); // Only Instances with known values are relevant. Enumeration enu = trainInstances.enumerateInstances(); i = 0; while (enu.hasMoreElements()) { instance = (Instance) enu.nextElement(); if (instance.isMissing(m_attIndex)) break; m_distribution.add(1, instance); i++; } firstMiss = i; // Compute minimum number of Instances required in each // subset. minSplit = 0.1 * (m_distribution.total()) / ((double) trainInstances.numClasses()); if (Utils.smOrEq(minSplit, m_minNoObj)) minSplit = m_minNoObj; else if (Utils.gr(minSplit, 25)) minSplit = 25; // Enough Instances with known values? if (Utils.sm((double) firstMiss, 2 * minSplit)) return; // Compute values of criteria for all possible split // indices. defaultEnt = m_infoGainCrit.oldEnt(m_distribution); while (next < firstMiss) { if (trainInstances.instance(next - 1).value(m_attIndex) + 1e-5 < trainInstances.instance(next).value(m_attIndex)) { // Move class values for all Instances up to next // possible split point. m_distribution.shiftRange(1, 0, trainInstances, last, next); // Check if enough Instances in each subset and compute // values for criteria. if (Utils.grOrEq(m_distribution.perBag(0), minSplit) && Utils.grOrEq(m_distribution.perBag(1), minSplit)) { currentInfoGain = m_infoGainCrit.splitCritValue(m_distribution, m_sumOfWeights, defaultEnt); if (Utils.gr(currentInfoGain, m_infoGain)) { m_infoGain = currentInfoGain; splitIndex = next - 1; } index++; } last = next; } next++; } // Was there any useful split? if (index == 0) return; // Compute modified information gain for best split. if (m_useMDLcorrection) { m_infoGain = m_infoGain - (Utils.log2(index) / m_sumOfWeights); } if (Utils.smOrEq(m_infoGain, 0)) return; // Set instance variables' values to values for // best split. m_numSubsets = 2; m_splitPoint = (trainInstances.instance(splitIndex + 1).value(m_attIndex) + trainInstances.instance(splitIndex).value(m_attIndex)) / 2; // In case we have a numerical precision problem we need to choose the // smaller value if (m_splitPoint == trainInstances.instance(splitIndex + 1).value(m_attIndex)) { m_splitPoint = trainInstances.instance(splitIndex).value(m_attIndex); } // Restore distributioN for best split. m_distribution = new Distribution(2, trainInstances.numClasses()); m_distribution.addRange(0, trainInstances, 0, splitIndex + 1); m_distribution.addRange(1, trainInstances, splitIndex + 1, firstMiss); // Compute modified gain ratio for best split. m_gainRatio = m_gainRatioCrit.splitCritValue(m_distribution, m_sumOfWeights, m_infoGain); }
protected void tokenizeInstance(Instance instance, boolean updateDictionary) { if (m_inputVector == null) { m_inputVector = new LinkedHashMap<String, Count>(); } else { m_inputVector.clear(); } if (m_useStopList && m_stopwords == null) { m_stopwords = new Stopwords(); try { if (getStopwords().exists() && !getStopwords().isDirectory()) { m_stopwords.read(getStopwords()); } } catch (Exception ex) { ex.printStackTrace(); } } for (int i = 0; i < instance.numAttributes(); i++) { if (instance.attribute(i).isString() && !instance.isMissing(i)) { m_tokenizer.tokenize(instance.stringValue(i)); while (m_tokenizer.hasMoreElements()) { String word = m_tokenizer.nextElement(); if (m_lowercaseTokens) { word = word.toLowerCase(); } word = m_stemmer.stem(word); if (m_useStopList) { if (m_stopwords.is(word)) { continue; } } Count docCount = m_inputVector.get(word); if (docCount == null) { m_inputVector.put(word, new Count(instance.weight())); } else { docCount.m_count += instance.weight(); } } } } if (updateDictionary) { int classValue = (int) instance.classValue(); LinkedHashMap<String, Count> dictForClass = m_probOfWordGivenClass.get(classValue); // document normalization double iNorm = 0; double fv = 0; if (m_normalize) { for (Count c : m_inputVector.values()) { // word counts or bag-of-words? fv = (m_wordFrequencies) ? c.m_count : 1.0; iNorm += Math.pow(Math.abs(fv), m_lnorm); } iNorm = Math.pow(iNorm, 1.0 / m_lnorm); } for (Map.Entry<String, Count> feature : m_inputVector.entrySet()) { String word = feature.getKey(); double freq = (m_wordFrequencies) ? feature.getValue().m_count : 1.0; // double freq = (feature.getValue().m_count / iNorm * m_norm); if (m_normalize) { freq /= (iNorm * m_norm); } // check all classes for (int i = 0; i < m_data.numClasses(); i++) { LinkedHashMap<String, Count> dict = m_probOfWordGivenClass.get(i); if (dict.get(word) == null) { dict.put(word, new Count(m_leplace)); m_wordsPerClass[i] += m_leplace; } } Count dictCount = dictForClass.get(word); /* * if (dictCount == null) { dictForClass.put(word, new Count(m_leplace + * freq)); m_wordsPerClass[classValue] += (m_leplace + freq); } else { */ dictCount.m_count += freq; m_wordsPerClass[classValue] += freq; // } } pruneDictionary(); } }
/** * Generates the classifier. * * @param instances set of instances serving as training data * @exception Exception if the classifier has not been generated successfully */ @Override public void buildClassifier(Instances instances) throws Exception { int attIndex = 0; double sum; // can classifier handle the data? getCapabilities().testWithFail(instances); // remove instances with missing class instances = new Instances(instances); instances.deleteWithMissingClass(); m_Instances = new Instances(instances, 0); // Reserve space m_Counts = new double[instances.numClasses()][instances.numAttributes() - 1][0]; m_Means = new double[instances.numClasses()][instances.numAttributes() - 1]; m_Devs = new double[instances.numClasses()][instances.numAttributes() - 1]; m_Priors = new double[instances.numClasses()]; Enumeration<Attribute> enu = instances.enumerateAttributes(); while (enu.hasMoreElements()) { Attribute attribute = enu.nextElement(); if (attribute.isNominal()) { for (int j = 0; j < instances.numClasses(); j++) { m_Counts[j][attIndex] = new double[attribute.numValues()]; } } else { for (int j = 0; j < instances.numClasses(); j++) { m_Counts[j][attIndex] = new double[1]; } } attIndex++; } // Compute counts and sums Enumeration<Instance> enumInsts = instances.enumerateInstances(); while (enumInsts.hasMoreElements()) { Instance instance = enumInsts.nextElement(); if (!instance.classIsMissing()) { Enumeration<Attribute> enumAtts = instances.enumerateAttributes(); attIndex = 0; while (enumAtts.hasMoreElements()) { Attribute attribute = enumAtts.nextElement(); if (!instance.isMissing(attribute)) { if (attribute.isNominal()) { m_Counts[(int) instance.classValue()][attIndex][(int) instance.value(attribute)]++; } else { m_Means[(int) instance.classValue()][attIndex] += instance.value(attribute); m_Counts[(int) instance.classValue()][attIndex][0]++; } } attIndex++; } m_Priors[(int) instance.classValue()]++; } } // Compute means Enumeration<Attribute> enumAtts = instances.enumerateAttributes(); attIndex = 0; while (enumAtts.hasMoreElements()) { Attribute attribute = enumAtts.nextElement(); if (attribute.isNumeric()) { for (int j = 0; j < instances.numClasses(); j++) { if (m_Counts[j][attIndex][0] < 2) { throw new Exception( "attribute " + attribute.name() + ": less than two values for class " + instances.classAttribute().value(j)); } m_Means[j][attIndex] /= m_Counts[j][attIndex][0]; } } attIndex++; } // Compute standard deviations enumInsts = instances.enumerateInstances(); while (enumInsts.hasMoreElements()) { Instance instance = enumInsts.nextElement(); if (!instance.classIsMissing()) { enumAtts = instances.enumerateAttributes(); attIndex = 0; while (enumAtts.hasMoreElements()) { Attribute attribute = enumAtts.nextElement(); if (!instance.isMissing(attribute)) { if (attribute.isNumeric()) { m_Devs[(int) instance.classValue()][attIndex] += (m_Means[(int) instance.classValue()][attIndex] - instance.value(attribute)) * (m_Means[(int) instance.classValue()][attIndex] - instance.value(attribute)); } } attIndex++; } } } enumAtts = instances.enumerateAttributes(); attIndex = 0; while (enumAtts.hasMoreElements()) { Attribute attribute = enumAtts.nextElement(); if (attribute.isNumeric()) { for (int j = 0; j < instances.numClasses(); j++) { if (m_Devs[j][attIndex] <= 0) { throw new Exception( "attribute " + attribute.name() + ": standard deviation is 0 for class " + instances.classAttribute().value(j)); } else { m_Devs[j][attIndex] /= m_Counts[j][attIndex][0] - 1; m_Devs[j][attIndex] = Math.sqrt(m_Devs[j][attIndex]); } } } attIndex++; } // Normalize counts enumAtts = instances.enumerateAttributes(); attIndex = 0; while (enumAtts.hasMoreElements()) { Attribute attribute = enumAtts.nextElement(); if (attribute.isNominal()) { for (int j = 0; j < instances.numClasses(); j++) { sum = Utils.sum(m_Counts[j][attIndex]); for (int i = 0; i < attribute.numValues(); i++) { m_Counts[j][attIndex][i] = (m_Counts[j][attIndex][i] + 1) / (sum + attribute.numValues()); } } } attIndex++; } // Normalize priors sum = Utils.sum(m_Priors); for (int j = 0; j < instances.numClasses(); j++) { m_Priors[j] = (m_Priors[j] + 1) / (sum + instances.numClasses()); } }