@Override public String classify(User user, Sample sample) { Instances trainingSet = new TrainingSetBuilder() .setAttributes(user.getBssids()) .setClassAttribute( "Location", user.getLocations().stream().map(Location::getName).collect(Collectors.toList())) .build("TrainingSet", 1); // Create instance Map<String, Integer> BSSIDLevelMap = getBSSIDLevelMap(sample); Instance instance = new Instance(trainingSet.numAttributes()); for (Enumeration e = trainingSet.enumerateAttributes(); e.hasMoreElements(); ) { Attribute attribute = (Attribute) e.nextElement(); String bssid = attribute.name(); int level = (BSSIDLevelMap.containsKey(bssid)) ? BSSIDLevelMap.get(bssid) : 0; instance.setValue(attribute, level); } if (sample.getLocation() != null) instance.setValue(trainingSet.classAttribute(), sample.getLocation()); instance.setDataset(trainingSet); trainingSet.add(instance); int predictedClass = classify(fromBase64(user.getClassifiers()), instance); return trainingSet.classAttribute().value(predictedClass); }
/**
 * Renders the model as text: one line per class with its entry from {@code m_probOfClass},
 * followed by a table with one row per attribute and one column per class holding
 * {@code Math.exp(m_probOfWordGivenClass[class][attribute])}.
 *
 * @return a string representation of the classifier
 */
public String toString() {
  StringBuilder text =
      new StringBuilder(
          "The independent probability of a class\n--------------------------------------\n");

  for (int cls = 0; cls < m_numClasses; cls++) {
    text.append(m_headerInfo.classAttribute().value(cls));
    text.append("\t");
    text.append(Double.toString(m_probOfClass[cls]));
    text.append("\n");
  }

  text.append(
      "\nThe probability of a word given the class\n-----------------------------------------\n\t");

  // Column headers: one per class.
  for (int cls = 0; cls < m_numClasses; cls++) {
    text.append(m_headerInfo.classAttribute().value(cls));
    text.append("\t");
  }
  text.append("\n");

  // One row per attribute.
  for (int word = 0; word < m_numAttributes; word++) {
    text.append(m_headerInfo.attribute(word).name());
    text.append("\t");
    for (int cls = 0; cls < m_numClasses; cls++) {
      double probability = Math.exp(m_probOfWordGivenClass[cls][word]);
      text.append(Double.toString(probability));
      text.append("\t");
    }
    text.append("\n");
  }

  return text.toString();
}
@Override public void buildClassifier(Instances data) throws Exception { trainingData = data; Attribute classAttribute = data.classAttribute(); prototypes = new ArrayList<>(); classedData = new HashMap<String, ArrayList<Sequence>>(); indexClassedDataInFullData = new HashMap<String, ArrayList<Integer>>(); for (int c = 0; c < data.numClasses(); c++) { classedData.put(data.classAttribute().value(c), new ArrayList<Sequence>()); indexClassedDataInFullData.put(data.classAttribute().value(c), new ArrayList<Integer>()); } sequences = new Sequence[data.numInstances()]; classMap = new String[sequences.length]; for (int i = 0; i < sequences.length; i++) { Instance sample = data.instance(i); MonoDoubleItemSet[] sequence = new MonoDoubleItemSet[sample.numAttributes() - 1]; int shift = (sample.classIndex() == 0) ? 1 : 0; for (int t = 0; t < sequence.length; t++) { sequence[t] = new MonoDoubleItemSet(sample.value(t + shift)); } sequences[i] = new Sequence(sequence); String clas = sample.stringValue(classAttribute); classMap[i] = clas; classedData.get(clas).add(sequences[i]); indexClassedDataInFullData.get(clas).add(i); // System.out.println("Element "+i+" of train is classed "+clas+" and went to element // "+(indexClassedDataInFullData.get(clas).size()-1)); } buildSpecificClassifier(data); }
/** * Returns a textual description of this classifier. * * @return a textual description of this classifier. */ @Override public String toString() { if (m_probOfClass == null) { return "NaiveBayesMultinomialText: No model built yet.\n"; } StringBuffer result = new StringBuffer(); // build a master dictionary over all classes HashSet<String> master = new HashSet<String>(); for (int i = 0; i < m_data.numClasses(); i++) { LinkedHashMap<String, Count> classDict = m_probOfWordGivenClass.get(i); for (String key : classDict.keySet()) { master.add(key); } } result.append("Dictionary size: " + master.size()).append("\n\n"); result.append("The independent frequency of a class\n"); result.append("--------------------------------------\n"); for (int i = 0; i < m_data.numClasses(); i++) { result .append(m_data.classAttribute().value(i)) .append("\t") .append(Double.toString(m_probOfClass[i])) .append("\n"); } result.append("\nThe frequency of a word given the class\n"); result.append("-----------------------------------------\n"); for (int i = 0; i < m_data.numClasses(); i++) { result.append(Utils.padLeft(m_data.classAttribute().value(i), 11)).append("\t"); } result.append("\n"); Iterator<String> masterIter = master.iterator(); while (masterIter.hasNext()) { String word = masterIter.next(); for (int i = 0; i < m_data.numClasses(); i++) { LinkedHashMap<String, Count> classDict = m_probOfWordGivenClass.get(i); Count c = classDict.get(word); if (c == null) { result.append("<laplace=1>\t"); } else { result.append(Utils.padLeft(Double.toString(c.m_count), 11)).append("\t"); } } result.append(word); result.append("\n"); } return result.toString(); }
/** * Inserts an instance into the hash table * * @param inst instance to be inserted * @param instA to create the hash key from * @throws Exception if the instance can't be inserted */ private void insertIntoTable(Instance inst, double[] instA) throws Exception { double[] tempClassDist2; double[] newDist; DecisionTableHashKey thekey; if (instA != null) { thekey = new DecisionTableHashKey(instA); } else { thekey = new DecisionTableHashKey(inst, inst.numAttributes(), false); } // see if this one is already in the table tempClassDist2 = (double[]) m_entries.get(thekey); if (tempClassDist2 == null) { if (m_classIsNominal) { newDist = new double[m_theInstances.classAttribute().numValues()]; // Leplace estimation for (int i = 0; i < m_theInstances.classAttribute().numValues(); i++) { newDist[i] = 1.0; } newDist[(int) inst.classValue()] = inst.weight(); // add to the table m_entries.put(thekey, newDist); } else { newDist = new double[2]; newDist[0] = inst.classValue() * inst.weight(); newDist[1] = inst.weight(); // add to the table m_entries.put(thekey, newDist); } } else { // update the distribution for this instance if (m_classIsNominal) { tempClassDist2[(int) inst.classValue()] += inst.weight(); // update the table m_entries.put(thekey, tempClassDist2); } else { tempClassDist2[0] += (inst.classValue() * inst.weight()); tempClassDist2[1] += inst.weight(); // update the table m_entries.put(thekey, tempClassDist2); } } }
/** * Stratify the given data into the given number of bags based on the class values. It differs * from the <code>Instances.stratify(int fold)</code> that before stratification it sorts the * instances according to the class order in the header file. It assumes no missing values in the * class. * * @param data the given data * @param folds the given number of folds * @param rand the random object used to randomize the instances * @return the stratified instances */ public static final Instances stratify(Instances data, int folds, Random rand) { if (!data.classAttribute().isNominal()) return data; Instances result = new Instances(data, 0); Instances[] bagsByClasses = new Instances[data.numClasses()]; for (int i = 0; i < bagsByClasses.length; i++) bagsByClasses[i] = new Instances(data, 0); // Sort by class for (int j = 0; j < data.numInstances(); j++) { Instance datum = data.instance(j); bagsByClasses[(int) datum.classValue()].add(datum); } // Randomize each class for (int j = 0; j < bagsByClasses.length; j++) bagsByClasses[j].randomize(rand); for (int k = 0; k < folds; k++) { int offset = k, bag = 0; oneFold: while (true) { while (offset >= bagsByClasses[bag].numInstances()) { offset -= bagsByClasses[bag].numInstances(); if (++bag >= bagsByClasses.length) // Next bag break oneFold; } result.add(bagsByClasses[bag].instance(offset)); offset += folds; } } return result; }
private static void writePredictedDistributions( Classifier c, Instances data, int idIndex, Writer out) throws Exception { // header out.write("id"); for (int i = 0; i < data.numClasses(); i++) { out.write(",\""); out.write(data.classAttribute().value(i).replaceAll("[\"\\\\]", "_")); out.write("\""); } out.write("\n"); // data for (int i = 0; i < data.numInstances(); i++) { final String id = data.instance(i).stringValue(idIndex); double[] distribution = c.distributionForInstance(data.instance(i)); // final String label = data.attribute(classIndex).value(); out.write(id); for (double probability : distribution) { out.write(","); out.write(String.valueOf(probability > 1e-5 ? (float) probability : 0f)); } out.write("\n"); } }
/**
 * Renders the linear regression model as a sum of {@code coefficient * attribute} terms for
 * every selected non-class attribute, followed by one trailing coefficient taken from the slot
 * after the last term.
 *
 * @return the model as string, or a fixed message when no model exists or printing fails
 */
public String toString() {
  if (m_TransformedData == null) {
    return "Linear Regression: No model built yet.";
  }

  try {
    StringBuilder out = new StringBuilder();
    out.append("\nLinear Regression Model\n\n");
    out.append(m_TransformedData.classAttribute().name() + " =\n\n");

    // Index of the next coefficient; it also tells whether a term was already printed.
    int column = 0;
    for (int att = 0; att < m_TransformedData.numAttributes(); att++) {
      if ((att == m_ClassIndex) || !m_SelectedAttributes[att]) {
        continue;
      }
      if (column > 0) {
        out.append(" +\n");
      }
      out.append(Utils.doubleToString(m_Coefficients[column], 12, 4) + " * ");
      out.append(m_TransformedData.attribute(att).name());
      column++;
    }

    out.append(" +\n" + Utils.doubleToString(m_Coefficients[column], 12, 4));
    return out.toString();
  } catch (Exception e) {
    return "Can't print Linear Regression!";
  }
}
protected void buildSpecificClassifier(Instances data) { if (distancesPerClass == null) { initDistances(); } ArrayList<String> classes = new ArrayList<String>(classedData.keySet()); for (String clas : classes) { // if the class is empty, continue if (classedData.get(clas).isEmpty()) continue; KMeansCachedSymbolicSequence kmeans = new KMeansCachedSymbolicSequence( nbPrototypesPerClass[trainingData.classAttribute().indexOfValue(clas)], classedData.get(clas), distancesPerClass.get(clas)); kmeans.cluster(); for (int i = 0; i < kmeans.centers.length; i++) { if (kmeans.centers[i] != null) { // ~ if empty cluster ClassedSequence s = new ClassedSequence(kmeans.centers[i], clas); prototypes.add(s); } } } }
@Override public List<Classifier> buildClassifiers(User user, List<Sample> validSamples) { Instances trainingSet = new TrainingSetBuilder() .setAttributes(user.getBssids()) .setClassAttribute( "Location", user.getLocations().stream().map(Location::getName).collect(Collectors.toList())) .build("TrainingSet", validSamples.size()); // Create instances validSamples.forEach( sample -> { Map<String, Integer> BSSIDLevelMap = getBSSIDLevelMap(sample); Instance instance = new Instance(trainingSet.numAttributes()); for (Enumeration e = trainingSet.enumerateAttributes(); e.hasMoreElements(); ) { Attribute attribute = (Attribute) e.nextElement(); String bssid = attribute.name(); int level = (BSSIDLevelMap.containsKey(bssid)) ? BSSIDLevelMap.get(bssid) : 0; instance.setValue(attribute, level); } instance.setValue(trainingSet.classAttribute(), sample.getLocation()); instance.setDataset(trainingSet); trainingSet.add(instance); }); // Build classifiers List<Classifier> classifiers = buildClassifiers(trainingSet); return classifiers; }
private static void analyse(Instances train, Instances datapredict) { String mpOptions = "-L 0.3 -M 0.2 -N 500 -V 0 -S 0 -E 20 -H a"; try { train.setClassIndex(train.numAttributes() - 1); train.deleteAttributeAt(0); int numClasses = train.numClasses(); for (int i = 0; i < numClasses; i++) { System.out.println("class value [" + i + "]=" + train.classAttribute().value(i) + ""); } // Instance of NN MultilayerPerceptron mlp = new MultilayerPerceptron(); mlp.setOptions(weka.core.Utils.splitOptions(mpOptions)); mlp.buildClassifier(train); datapredict.setClassIndex(datapredict.numAttributes() - 1); datapredict.deleteAttributeAt(0); // Instances predicteddata = new Instances(datapredict); for (int i = 0; i < datapredict.numInstances(); i++) { Instance newInst = datapredict.instance(i); double pred = mlp.classifyInstance(newInst); int predInt = (int) pred; // Math.round(pred); String predString = train.classAttribute().value(predInt); System.out.println( "cliente[" + i + "] pred[" + pred + "] predInt[" + predInt + "] desertor[" + predString + "]"); } } catch (Exception e) { e.printStackTrace(); } }
/**
 * Converts every instance with {@code convert} and passes the resulting rows, together with
 * the index of the class attribute, to {@code classifier.learnModel}.
 *
 * @param instances the training data
 * @throws Exception if the model cannot be learned
 */
@Override
public void buildClassifier(Instances instances) throws Exception {
  List<List<Object>> rows = new LinkedList<List<Object>>();
  final int classColumn = instances.classAttribute().index();

  determineNumericAttributes(instances);

  for (Instance instance : instances) {
    List<Object> row = convert(instance);
    rows.add(row);
  }

  classifier.learnModel(rows, classColumn);
}
/** * Buildclassifier selects a classifier from the set of classifiers by minimising error on the * training data. * * @param data the training data to be used for generating the boosted classifier. * @exception Exception if the classifier could not be built successfully */ public void buildClassifier(Instances data) throws Exception { if (m_Classifiers.length == 0) { throw new Exception("No base classifiers have been set!"); } Instances newData = new Instances(data); newData.deleteWithMissingClass(); newData.randomize(new Random(m_Seed)); if (newData.classAttribute().isNominal() && (m_NumXValFolds > 1)) newData.stratify(m_NumXValFolds); Instances train = newData; // train on all data by default Instances test = newData; // test on training data by default Classifier bestClassifier = null; int bestIndex = -1; double bestPerformance = Double.NaN; int numClassifiers = m_Classifiers.length; for (int i = 0; i < numClassifiers; i++) { Classifier currentClassifier = getClassifier(i); Evaluation evaluation; if (m_NumXValFolds > 1) { evaluation = new Evaluation(newData); for (int j = 0; j < m_NumXValFolds; j++) { train = newData.trainCV(m_NumXValFolds, j); test = newData.testCV(m_NumXValFolds, j); currentClassifier.buildClassifier(train); evaluation.setPriors(train); evaluation.evaluateModel(currentClassifier, test); } } else { currentClassifier.buildClassifier(train); evaluation = new Evaluation(train); evaluation.evaluateModel(currentClassifier, test); } double error = evaluation.errorRate(); if (m_Debug) { System.err.println( "Error rate: " + Utils.doubleToString(error, 6, 4) + " for classifier " + currentClassifier.getClass().getName()); } if ((i == 0) || (error < bestPerformance)) { bestClassifier = currentClassifier; bestPerformance = error; bestIndex = i; } } m_ClassifierIndex = bestIndex; m_Classifier = bestClassifier; if (m_NumXValFolds > 1) { m_Classifier.buildClassifier(newData); } }
/**
 * Copies the instance count, class count, attribute count and class-attribute type of
 * {@code m_Train} into the corresponding fields and then sets {@code m_InitFlag} to
 * {@code ON}. Any exception is printed and swallowed, leaving the flag untouched.
 */
private void init_m_Attributes() {
  try {
    final Instances train = m_Train;

    m_NumInstances = train.numInstances();
    m_NumClasses = train.numClasses();
    m_NumAttributes = train.numAttributes();
    m_ClassType = train.classAttribute().type();

    // Only reached when every value above was read successfully.
    m_InitFlag = ON;
  } catch (Exception e) {
    e.printStackTrace();
  }
}
/** * processes the instances using the HAAR algorithm * * @param instances the data to process * @return the modified data * @throws Exception in case the processing goes wrong */ protected Instances processHAAR(Instances instances) throws Exception { Instances result; int i; int n; int j; int clsIdx; double[] oldVal; double[] newVal; int level; int length; double[] clsVal; Attribute clsAtt; clsIdx = instances.classIndex(); clsVal = null; clsAtt = null; if (clsIdx > -1) { clsVal = instances.attributeToDoubleArray(clsIdx); clsAtt = (Attribute) instances.classAttribute().copy(); instances.setClassIndex(-1); instances.deleteAttributeAt(clsIdx); } result = new Instances(instances, 0); level = (int) StrictMath.ceil(StrictMath.log(instances.numAttributes()) / StrictMath.log(2.0)); for (i = 0; i < instances.numInstances(); i++) { oldVal = instances.instance(i).toDoubleArray(); newVal = new double[oldVal.length]; for (n = level; n > 0; n--) { length = (int) StrictMath.pow(2, n - 1); for (j = 0; j < length; j++) { newVal[j] = (oldVal[j * 2] + oldVal[j * 2 + 1]) / StrictMath.sqrt(2); newVal[j + length] = (oldVal[j * 2] - oldVal[j * 2 + 1]) / StrictMath.sqrt(2); } System.arraycopy(newVal, 0, oldVal, 0, newVal.length); } // add new transformed instance result.add(new DenseInstance(1, newVal)); } // add class again if (clsIdx > -1) { result.insertAttributeAt(clsAtt, clsIdx); result.setClassIndex(clsIdx); for (i = 0; i < clsVal.length; i++) result.instance(i).setClassValue(clsVal[i]); } return result; }
public String classifyInstance(String newInst) { File f = null; String type = null; try { f = new File("/data/data/com.example.gpstracker/tmp.arff"); f.createNewFile(); FileWriter fw = new FileWriter(f); BufferedWriter bw = new BufferedWriter(fw); bw.write("@relation gps_tracking"); bw.newLine(); bw.newLine(); bw.write("@attribute Longtitude numeric"); bw.newLine(); bw.write("@attribute Latitude numeric"); bw.newLine(); bw.write("@attribute CurrentSpeed numeric"); bw.newLine(); bw.write("@attribute Timestamp date \"yyyy-MM-dd HH:mm:ss\""); bw.newLine(); bw.write("@attribute MoveType {Walking,Running,Biking,Driving,Metro,Bus,Motionless}"); bw.newLine(); bw.write("@attribute IsGpsFixed {yes,no}"); bw.newLine(); bw.newLine(); bw.write("@data"); bw.newLine(); bw.write(newInst); bw.close(); // load unlabeled data Instances unlabeled = new Instances( new BufferedReader(new FileReader("/data/data/com.example.gpstracker/tmp.arff"))); // set class attribute unlabeled.setClassIndex(unlabeled.numAttributes() - 2); // label instances double clsLabel = classifier.classifyInstance(unlabeled.instance(0)); type = unlabeled.classAttribute().value((int) clsLabel); boolean deleted = f.delete(); } catch (FileNotFoundException e) { e.printStackTrace(); } catch (IOException e) { e.printStackTrace(); } catch (Exception e) { e.printStackTrace(); } return type; }
@Override public Instances labelData(String data) throws Exception { Instances unlabeled = new Instances(new BufferedReader(new FileReader(data))); // set class attribute unlabeled.setClassIndex(unlabeled.numAttributes() - 1); // create copy Instances labeled = new Instances(unlabeled); for (int i = 0; i < unlabeled.numInstances(); i++) { Instance ui = unlabeled.instance(i); double clsLabel = this.classifier.classifyInstance(ui); labeled.instance(i).setClassValue(clsLabel); System.out.println(ui.toString() + " -> " + unlabeled.classAttribute().value((int) clsLabel)); } return labeled; }
/**
 * Sets the format of the input instances.
 *
 * @param instanceInfo an Instances object containing the input instance structure (any
 *     instances contained in the object are ignored - only the structure is required).
 * @return true when the class attribute is nominal, false otherwise
 * @throws Exception if no class has been assigned or the input format can't be set
 */
@Override
public boolean setInputFormat(Instances instanceInfo) throws Exception {
  super.setInputFormat(instanceInfo);

  final boolean classAssigned = instanceInfo.classIndex() >= 0;
  if (!classAssigned) {
    throw new UnassignedClassException("No class has been assigned to the instances");
  }

  setOutputFormat();
  m_Indices = null;

  return instanceInfo.classAttribute().isNominal();
}
/**
 * Appends a rule to the ruleset and records its statistics.
 *
 * <p>The rule is evaluated by {@code computeSimpleStats} on {@code m_Data} when nothing has been
 * filtered yet, otherwise on element 1 of the most recently stored filter result. The filter
 * result, the 6-element stats array and the per-class counts are appended to their vectors,
 * which are created on first use.
 *
 * @param lastRule the rule to be added
 */
public void addAndUpdate(Rule lastRule) {
  if (m_Ruleset == null) {
    m_Ruleset = new FastVector();
  }
  m_Ruleset.addElement(lastRule);

  // Choose the data the new rule is evaluated on.
  Instances data;
  if (m_Filtered == null) {
    data = m_Data;
  } else {
    Instances[] previous = (Instances[]) m_Filtered.lastElement();
    data = previous[1];
  }

  double[] stats = new double[6];
  double[] classCounts = new double[m_Data.classAttribute().numValues()];
  final int ruleIndex = m_Ruleset.size() - 1;
  Instances[] filtered = computeSimpleStats(ruleIndex, data, stats, classCounts);

  if (m_Filtered == null) {
    m_Filtered = new FastVector();
  }
  m_Filtered.addElement(filtered);

  if (m_SimpleStats == null) {
    m_SimpleStats = new FastVector();
  }
  m_SimpleStats.addElement(stats);

  if (m_Distributions == null) {
    m_Distributions = new FastVector();
  }
  m_Distributions.addElement(classCounts);
}
/**
 * Describes the classifier: for every class its prior, and for every non-class attribute either
 * the attribute values with the stored counts (nominal attributes) or the stored mean and
 * standard deviation (all other attributes).
 *
 * @return a description of the classifier as a string, or a fixed message when no model
 *     exists or printing fails
 */
@Override
public String toString() {
  if (m_Instances == null) {
    return "Naive Bayes (simple): No model built yet.";
  }

  try {
    StringBuilder out = new StringBuilder("Naive Bayes (simple)");

    for (int cls = 0; cls < m_Instances.numClasses(); cls++) {
      out.append(
          "\n\nClass "
              + m_Instances.classAttribute().value(cls)
              + ": P(C) = "
              + Utils.doubleToString(m_Priors[cls], 10, 8)
              + "\n\n");

      // attIndex counts the enumerated (non-class) attributes in order.
      int attIndex = 0;
      for (Enumeration<Attribute> atts = m_Instances.enumerateAttributes();
          atts.hasMoreElements();
          attIndex++) {
        Attribute attribute = atts.nextElement();
        out.append("Attribute " + attribute.name() + "\n");

        if (!attribute.isNominal()) {
          out.append("Mean: " + Utils.doubleToString(m_Means[cls][attIndex], 10, 8) + "\t");
          out.append(
              "Standard Deviation: " + Utils.doubleToString(m_Devs[cls][attIndex], 10, 8));
        } else {
          for (int v = 0; v < attribute.numValues(); v++) {
            out.append(attribute.value(v) + "\t");
          }
          out.append("\n");
          for (int v = 0; v < attribute.numValues(); v++) {
            out.append(Utils.doubleToString(m_Counts[cls][attIndex][v], 10, 8) + "\t");
          }
        }
        out.append("\n\n");
      }
    }

    return out.toString();
  } catch (Exception e) {
    return "Can't print Naive Bayes classifier!";
  }
}
/** * Filter the data according to the ruleset and compute the basic stats: coverage/uncoverage, * true/false positive/negatives of each rule */ public void countData() { if ((m_Filtered != null) || (m_Ruleset == null) || (m_Data == null)) return; int size = m_Ruleset.size(); m_Filtered = new FastVector(size); m_SimpleStats = new FastVector(size); m_Distributions = new FastVector(size); Instances data = new Instances(m_Data); for (int i = 0; i < size; i++) { double[] stats = new double[6]; // 6 statistics parameters double[] classCounts = new double[m_Data.classAttribute().numValues()]; Instances[] filtered = computeSimpleStats(i, data, stats, classCounts); m_Filtered.addElement(filtered); m_SimpleStats.addElement(stats); m_Distributions.addElement(classCounts); data = filtered[1]; // Data not covered } }
/** * Generates the classifier. * * @param instances set of instances serving as training data * @throws Exception if the classifier has not been generated successfully */ public void buildClassifier(Instances instances) throws Exception { // can classifier handle the data? getCapabilities().testWithFail(instances); // remove instances with missing class Instances trainData = new Instances(instances); trainData.deleteWithMissingClass(); if (!(m_Classifier instanceof OptionHandler)) { throw new IllegalArgumentException("Base classifier should be OptionHandler."); } m_InitOptions = ((OptionHandler) m_Classifier).getOptions(); m_BestPerformance = -99; m_NumAttributes = trainData.numAttributes(); Random random = new Random(m_Seed); trainData.randomize(random); m_TrainFoldSize = trainData.trainCV(m_NumFolds, 0).numInstances(); // Check whether there are any parameters to optimize if (m_CVParams.size() == 0) { m_Classifier.buildClassifier(trainData); m_BestClassifierOptions = m_InitOptions; return; } if (trainData.classAttribute().isNominal()) { trainData.stratify(m_NumFolds); } m_BestClassifierOptions = null; // Set up m_ClassifierOptions -- take getOptions() and remove // those being optimised. m_ClassifierOptions = ((OptionHandler) m_Classifier).getOptions(); for (int i = 0; i < m_CVParams.size(); i++) { Utils.getOption(((CVParameter) m_CVParams.elementAt(i)).m_ParamChar, m_ClassifierOptions); } findParamsByCrossValidation(0, trainData, random); String[] options = (String[]) m_BestClassifierOptions.clone(); ((OptionHandler) m_Classifier).setOptions(options); m_Classifier.buildClassifier(trainData); }
/** * Method for building an Id3 tree. * * @param data the training data * @exception Exception if decision tree can't be built successfully */ private void makeTree(Instances data) throws Exception { // Check if no instances have reached this node. if (data.numInstances() == 0) { m_Attribute = null; m_ClassValue = Utils.missingValue(); m_Distribution = new double[data.numClasses()]; return; } // Compute attribute with maximum information gain. double[] infoGains = new double[data.numAttributes()]; Enumeration attEnum = data.enumerateAttributes(); while (attEnum.hasMoreElements()) { Attribute att = (Attribute) attEnum.nextElement(); infoGains[att.index()] = computeInfoGain(data, att); } m_Attribute = data.attribute(Utils.maxIndex(infoGains)); // Make leaf if information gain is zero. // Otherwise create successors. if (Utils.eq(infoGains[m_Attribute.index()], 0)) { m_Attribute = null; m_Distribution = new double[data.numClasses()]; Enumeration instEnum = data.enumerateInstances(); while (instEnum.hasMoreElements()) { Instance inst = (Instance) instEnum.nextElement(); m_Distribution[(int) inst.classValue()]++; } Utils.normalize(m_Distribution); m_ClassValue = Utils.maxIndex(m_Distribution); m_ClassAttribute = data.classAttribute(); } else { Instances[] splitData = splitData(data, m_Attribute); m_Successors = new Id3[m_Attribute.numValues()]; for (int j = 0; j < m_Attribute.numValues(); j++) { m_Successors[j] = new Id3(); m_Successors[j].makeTree(splitData[j]); } } }
private static void evaluateClassifier(Classifier c, Instances trainData, Instances testData) throws Exception { System.err.println( "INFO: Starting split validation to predict '" + trainData.classAttribute().name() + "' using '" + c.getClass().getCanonicalName() + ":" + Arrays.toString(c.getOptions()) + "' (#train=" + trainData.numInstances() + ",#test=" + testData.numInstances() + ") ..."); if (trainData.classIndex() < 0) throw new IllegalStateException("class attribute not set"); c.buildClassifier(trainData); Evaluation eval = new Evaluation(testData); eval.useNoPriors(); double[] predictions = eval.evaluateModel(c, testData); System.out.println(eval.toClassDetailsString()); System.out.println(eval.toSummaryString("\nResults\n======\n", false)); // write predictions to file { System.err.println("INFO: Writing predictions to file ..."); Writer out = new FileWriter("prediction.trec"); writePredictionsTrecEval(predictions, testData, 0, trainData.classIndex(), out); out.close(); } // write predicted distributions to CSV { System.err.println("INFO: Writing predicted distributions to CSV ..."); Writer out = new FileWriter("predicted_distribution.csv"); writePredictedDistributions(c, testData, 0, out); out.close(); } }
/** * test on one sample * * @param sample * @return p(y|sample) forall y * @throws Exception */ public double classifyInstance(Instance sample) throws Exception { // transform instance to sequence MonoDoubleItemSet[] sequence = new MonoDoubleItemSet[sample.numAttributes() - 1]; int shift = (sample.classIndex() == 0) ? 1 : 0; for (int t = 0; t < sequence.length; t++) { sequence[t] = new MonoDoubleItemSet(sample.value(t + shift)); } Sequence seq = new Sequence(sequence); // for each class String classValue = null; double maxProb = 0.0; double[] pr = new double[classedData.keySet().size()]; for (String clas : classedData.keySet()) { int c = trainingData.classAttribute().indexOfValue(clas); double prob = 0.0; for (int k = 0; k < centroidsPerClass[c].length; k++) { // compute P(Q|k_c) if (sigmasPerClass[c][k] == Double.NaN || sigmasPerClass[c][k] == 0) { System.err.println("sigma=NAN||sigma=0"); continue; } double dist = seq.distanceEuc(centroidsPerClass[c][k]); double p = computeProbaForQueryAndCluster(sigmasPerClass[c][k], dist); prob += p / centroidsPerClass[c].length; // prob += p*prior[c][k]; if (p > maxProb) { maxProb = p; classValue = clas; } } // if (prob > maxProb) { // maxProb = prob; // classValue = clas; // } } // System.out.println(Arrays.toString(pr)); // System.out.println(classValue); return sample.classAttribute().indexOfValue(classValue); }
/**
 * Merges the label instances into the training instances, makes the last attribute of the
 * merged data the class, and records the result on the given meta object. Intended for
 * training data.
 *
 * <p>The class attribute's values are joined with commas, in index order, and passed to
 * {@code setClassAtrributeValues}; the merged data is passed to {@code setInstances}.
 *
 * @param instances the training instances
 * @param instancesWithMeta the meta object that receives the class values and merged data
 * @param labelInstances the instances holding the labels to merge in
 * @return the merged instances with the class index set
 * @throws Exception declared for callers; nothing is thrown explicitly here
 */
public static Instances addNominalLabelsForClassificationToTrainingData(
    Instances instances, AttributeFilterMeta instancesWithMeta, Instances labelInstances)
    throws Exception {
  Instances merged = Instances.mergeInstances(instances, labelInstances);
  merged.setClassIndex(merged.numAttributes() - 1);

  Attribute classAttribute = merged.classAttribute();
  final int valueCount = classAttribute.numValues();

  // Comma-separated list of the class values.
  StringBuilder joined = new StringBuilder();
  for (int v = 0; v < valueCount; v++) {
    if (v > 0) {
      joined.append(",");
    }
    joined.append(classAttribute.value(v));
  }

  instancesWithMeta.setClassAtrributeValues(joined.toString());
  instancesWithMeta.setInstances(merged);
  return merged;
}
public void chooseClassifier() { int classIndex = 0; // number of attributes must be greater than 1 /** * We can use either a supervised or an un-supervised algorithm if a class attribute already * exists in the dataset (meaning some labeled instances exists), depending on the size of the * training set, the decision is taken. */ classIndex = traindata.numAttributes() - 1; traindata.setClassIndex(classIndex); if (classIndex == traindata.numAttributes() - 1 || traindata.attribute("class") != null || traindata.attribute("Class") != null && traindata.size() >= testdata.size()) { System.out.println("class attribute found...."); System.out.println("Initial training set is larger than the test set...." + traindata.size()); // Go ahead to generate folds, then call classifier try { ce.generateFolds(traindata); } catch (Exception ex) { Logger.getLogger(FileTypeEnablerAndProcessor.class.getName()).log(Level.SEVERE, null, ex); } } /** * When there is no class attribute to show labeled instances exists then use an un-supervised * algorithm straight; no need for the cross-validation folds. */ else { try { System.out.println("class attribute not found"); classIndex = traindata.numAttributes() - 1; traindata.setClassIndex(classIndex); System.out.println("Class to predict is = " + traindata.classAttribute() + "\n"); uc.autoProbClass(traindata); } catch (Exception ex) { Logger.getLogger(FileTypeEnablerAndProcessor.class.getName()).log(Level.SEVERE, null, ex); } } }
/** * Generates the classifier. * * @param instances set of instances serving as training data * @throws Exception if the classifier has not been generated successfully */ public void buildClassifier(Instances instances) throws Exception { // can classifier handle the data? getCapabilities().testWithFail(instances); // remove instances with missing class instances = new Instances(instances); instances.deleteWithMissingClass(); m_NumClasses = instances.numClasses(); m_ClassType = instances.classAttribute().type(); m_Train = new Instances(instances, 0, instances.numInstances()); // Throw away initial instances until within the specified window size if ((m_WindowSize > 0) && (instances.numInstances() > m_WindowSize)) { m_Train = new Instances(m_Train, m_Train.numInstances() - m_WindowSize, m_WindowSize); } m_NumAttributesUsed = 0.0; for (int i = 0; i < m_Train.numAttributes(); i++) { if ((i != m_Train.classIndex()) && (m_Train.attribute(i).isNominal() || m_Train.attribute(i).isNumeric())) { m_NumAttributesUsed += 1.0; } } m_NNSearch.setInstances(m_Train); // Invalidate any currently cross-validation selected k m_kNNValid = false; m_defaultModel = new ZeroR(); m_defaultModel.buildClassifier(instances); m_defaultModel.setOptions(getOptions()); // System.out.println("hello world"); }
static void evaluateClassifier(Classifier c, Instances data, int folds) throws Exception { System.err.println( "INFO: Starting crossvalidation to predict '" + data.classAttribute().name() + "' using '" + c.getClass().getCanonicalName() + ":" + Arrays.toString(c.getOptions()) + "' ..."); StringBuffer sb = new StringBuffer(); Evaluation eval = new Evaluation(data); eval.crossValidateModel(c, data, folds, new Random(1), sb, new Range("first"), Boolean.FALSE); // write predictions to file { Writer out = new FileWriter("cv.log"); out.write(sb.toString()); out.close(); } System.out.println(eval.toClassDetailsString()); System.out.println(eval.toSummaryString("\nResults\n======\n", false)); }
@Override public NaiveBayesMultinomialText aggregate(NaiveBayesMultinomialText toAggregate) throws Exception { if (m_numModels == Integer.MIN_VALUE) { throw new Exception( "Can't aggregate further - model has already been " + "aggregated and finalized"); } if (m_probOfClass == null) { throw new Exception("No model built yet, can't aggregate"); } // just check the class attribute for compatibility as we will be // merging dictionaries if (!m_data.classAttribute().equals(toAggregate.m_data.classAttribute())) { throw new Exception( "Can't aggregate - class attribute in data headers " + "does not match: " + m_data.classAttribute().equalsMsg(toAggregate.m_data.classAttribute())); } for (int i = 0; i < m_probOfClass.length; i++) { m_probOfClass[i] += toAggregate.m_probOfClass[i]; } Map<Integer, LinkedHashMap<String, Count>> dicts = toAggregate.m_probOfWordGivenClass; Iterator<Map.Entry<Integer, LinkedHashMap<String, Count>>> perClass = dicts.entrySet().iterator(); while (perClass.hasNext()) { Map.Entry<Integer, LinkedHashMap<String, Count>> currentClassDict = perClass.next(); LinkedHashMap<String, Count> masterDict = m_probOfWordGivenClass.get(currentClassDict.getKey()); if (masterDict == null) { // we haven't seen this class during our training masterDict = new LinkedHashMap<String, Count>(); m_probOfWordGivenClass.put(currentClassDict.getKey(), masterDict); } // now process words seen for this class Iterator<Map.Entry<String, Count>> perClassEntries = currentClassDict.getValue().entrySet().iterator(); while (perClassEntries.hasNext()) { Map.Entry<String, Count> entry = perClassEntries.next(); Count masterCount = masterDict.get(entry.getKey()); if (masterCount == null) { // we haven't seen this entry (or its been pruned) masterCount = new Count(entry.getValue().m_count); masterDict.put(entry.getKey(), masterCount); } else { // add up masterCount.m_count += entry.getValue().m_count; } } } m_numModels++; return this; }