/** * Trains the classifier on the array of Signal objects. Implementations of this method should * also produce an ordered list of the class names which can be returned with the <code> * getClassNames</code> method. * * @param inputData the Signal array that the model should be trained on. * @throws noMetadataException Thrown if there is no class metadata to train the Gaussian model * with */ public void train(Signal[] inputData) { List classNamesList = new ArrayList(); for (int i = 0; i < inputData.length; i++) { try { String className = inputData[i].getStringMetadata(Signal.PROP_CLASS); if ((className != null) && (!classNamesList.contains(className))) { classNamesList.add(className); } } catch (noMetadataException ex) { throw new RuntimeException("No class metadata found to train model on!", ex); } } Collections.sort(classNamesList); classnames = (String[]) classNamesList.toArray(new String[classNamesList.size()]); FastVector classValues = new FastVector(classnames.length); for (int i = 0; i < classnames.length; i++) { classValues.addElement(classnames[i]); } classAttribute = new Attribute(Signal.PROP_CLASS, classValues); Instances trainingDataSet = new Instances(Signal2Instance.convert(inputData[0], classAttribute)); if (inputData.length > 1) { for (int i = 1; i < inputData.length; i++) { Instances aSignalInstance = Signal2Instance.convert(inputData[i], classAttribute); for (int j = 0; j < aSignalInstance.numInstances(); j++) trainingDataSet.add(aSignalInstance.instance(j)); } } trainingDataSet.setClass(classAttribute); inputData = null; theRule = new MISMO(); // parse options StringTokenizer stOption = new StringTokenizer(this.MISMOOptions, " "); String[] options = new String[stOption.countTokens()]; for (int i = 0; i < options.length; i++) { options[i] = stOption.nextToken(); } try { theRule.setOptions(options); } catch (Exception ex) { throw new RuntimeException("Failed to set MISMO classifier options!", ex); } try { theRule.buildClassifier(trainingDataSet); 
System.out.println("WEKA: outputting MISMO classifier; " + theRule.globalInfo()); } catch (Exception ex) { throw new RuntimeException("Failed to train classifier!", ex); } }
/**
 * Classifies a single feature vector with the current classifier (<code>theRule</code>).
 *
 * <p>The vector is wrapped in an <code>Instance</code> of weight 1.0 and passed to
 * <code>theRule.classifyInstance</code>; the result is truncated to an <code>int</code>.
 *
 * <p>NOTE(review): the interface contract asks for a RuntimeException when the classifier is
 * untrained, but any exception raised during classification is caught here, reported on
 * standard output, and converted into a return value of -1. Confirm which behaviour callers
 * expect.
 *
 * @param input Vector to classify
 * @return Integer indicating the class, or -1 if classification raised an exception.
 */
public int classifyVector(double[] input) {
  final Instance candidate = new Instance(1.0, input);
  int predictedClass;
  try {
    predictedClass = (int) theRule.classifyInstance(candidate);
  } catch (Exception failure) {
    System.out.println(
        "Exception occured when classifying a vector!\n" + failure.getMessage());
    failure.printStackTrace();
    predictedClass = -1;
  }
  return predictedClass;
}
/**
 * Calculates the probability of class membership of a single data vector.
 *
 * <p>The vector is wrapped in an <code>Instance</code> of weight 1.0 and the result of
 * <code>theRule.distributionForInstance</code> is returned unchanged. The contract intends the
 * probabilities to be ordered so that the indexes match the class names returned by
 * <code>getClassNames</code> and to sum to 1.0; this method does not enforce either property
 * itself.
 *
 * @param input The data vector to calculate the probabilities of class membership for.
 * @return An array of the probabilities of class membership, or <code>null</code> if the
 *     classifier raised an exception (which is reported on standard output).
 */
public double[] probabilities(double[] input) {
  final Instance candidate = new Instance(1.0, input);
  double[] distribution;
  try {
    distribution = theRule.distributionForInstance(candidate);
  } catch (Exception failure) {
    System.out.println(
        "Exception occured when calculating distribution for a vector!\n"
            + failure.getMessage());
    failure.printStackTrace();
    distribution = null;
  }
  return distribution;
}
/** * Calculates the probability of class membership of a Signal Object. * * @param inputSignal The Signal object to calculate the probabilities of class membership for. * This probabilities should be ordered such that the indexes match the class names returned * by <code>getClassNames</code>. * @return An array of the probabilities of class membership */ public double[] probabilities(Signal inputSignal) { Instances theData = Signal2Instance.convert(inputSignal, this.classAttribute); double[] probsAccum = new double[this.getNumClasses()]; for (int i = 0; i < theData.numInstances(); i++) { double[] temp; try { temp = theRule.distributionForInstance(theData.instance(i)); // if(this.verbose){ // System.out.print("distribution:" ); // for (int j=0;j<this.getNumClasses();j++) { // System.out.print(" " + temp[j]); // } // System.out.println(""); // } } catch (Exception ex) { System.out.println( "Exception occured when calculating distribution for a vector!\n" + ex.getMessage()); ex.printStackTrace(); return null; } for (int j = 0; j < this.getNumClasses(); j++) { probsAccum[j] += Math.log(temp[j]); } } // if(this.verbose){ // System.out.print("Log Probs:" ); // for (int j=0;j<this.getNumClasses();j++) { // System.out.print(" " + probsAccum[j]); // } // System.out.println(""); // } // normalise to range 0:1 double min = Double.POSITIVE_INFINITY; double max = Double.NEGATIVE_INFINITY; for (int j = 0; j < this.getNumClasses(); j++) { if (probsAccum[j] > max) { max = probsAccum[j]; } if (probsAccum[j] < min) { min = probsAccum[j]; } } for (int j = 0; j < this.getNumClasses(); j++) { probsAccum[j] -= min; probsAccum[j] /= (max - min); } // normalise to sum to 1 double total = 0.0; for (int j = 0; j < this.getNumClasses(); j++) { total += probsAccum[j]; } if (verbose) { DecimalFormat dec = new DecimalFormat(); dec.setMaximumFractionDigits(3); if (this.verbose) { System.out.print("Probabilities:"); } for (int j = 0; j < this.getNumClasses(); j++) { probsAccum[j] /= total; if 
(verbose) { System.out.print(" " + dec.format(probsAccum[j])); } } if (verbose) { try { System.out.println("\t" + inputSignal.getStringMetadata(Signal.PROP_FILE_LOCATION)); } catch (noMetadataException nme) { nme.printStackTrace(System.out); } } } return probsAccum; }