public static void main(String[] args) throws Exception {
  /*
   * First we load the test data from our ARFF file
   */
  ArffLoader testLoader = new ArffLoader();
  testLoader.setSource(new File("data/titanic/test.arff"));
  testLoader.setRetrieval(Loader.BATCH);
  Instances testDataSet = testLoader.getDataSet();

  /*
   * Now we tell the data set which attribute we want to classify, in our
   * case, we want to classify the first column: survived
   */
  Attribute testAttribute = testDataSet.attribute(0);
  testDataSet.setClass(testAttribute);
  testDataSet.deleteStringAttributes();

  /*
   * Now we read in the serialized model from disk
   */
  Classifier classifier = (Classifier) SerializationHelper.read("data/titanic/titanic.model");

  /*
   * This part may be a little confusing. We load the test data a second
   * time so we have a prediction data set to populate. As we iterate over
   * the first data set we also iterate over the second one. After an
   * instance is classified, we set the corresponding value in the
   * prediction data set to the classification result.
   */
  ArffLoader test1Loader = new ArffLoader();
  test1Loader.setSource(new File("data/titanic/test.arff"));
  Instances test1DataSet = test1Loader.getDataSet();
  Attribute test1Attribute = test1DataSet.attribute(0);
  test1DataSet.setClass(test1Attribute);

  /*
   * Now we iterate over the test data, classify each entry and set the
   * value of the 'survived' column to the result of the classification
   */
  Enumeration testInstances = testDataSet.enumerateInstances();
  Enumeration test1Instances = test1DataSet.enumerateInstances();
  while (testInstances.hasMoreElements()) {
    Instance instance = (Instance) testInstances.nextElement();
    Instance instance1 = (Instance) test1Instances.nextElement();
    double classification = classifier.classifyInstance(instance);
    instance1.setClassValue(classification);
  }

  /*
   * Now we want to write out our predictions. The resulting file is in a
   * format suitable to submit to Kaggle.
   */
  CSVSaver predictedCsvSaver = new CSVSaver();
  predictedCsvSaver.setFile(new File("data/titanic/predict.csv"));
  predictedCsvSaver.setInstances(test1DataSet);
  predictedCsvSaver.writeBatch();

  System.out.println("Prediction saved to predict.csv");
}
private Instances buildInstance(String data) {
  Instances instance;
  double[] values;

  // 3. fill with data
  // data = "15073302,9,0,1,0,0,5,f";
  String[] s_parts = data.split(",");

  // first instance
  // 2. create Instances object
  String id = s_parts[0];
  instance = new Instances(id, attributes, 0);
  values = new double[instance.numAttributes()];
  for (int j = 1; j < s_parts.length; j++) {
    String text = s_parts[j];
    if (j == s_parts.length - 1) {
      // the last column is nominal: store the index of its value
      values[j - 1] = attributeReturn.indexOf(text);
    } else {
      values[j - 1] = Double.parseDouble(text);
    }
  }
  // System.out.println(values[6]);

  // add
  instance.add(new DenseInstance(1.0, values));
  instance.setClass(instance.attribute(s_parts.length - 2));
  return instance;
}
/**
 * Trains the classifier on the array of Signal objects. Implementations of this method should
 * also produce an ordered list of the class names which can be returned with the
 * <code>getClassNames</code> method.
 *
 * @param inputData the Signal array that the model should be trained on.
 * @throws RuntimeException if there is no class metadata to train the model with, or if the
 *     MISMO classifier cannot be configured or trained
 */
public void train(Signal[] inputData) {
  List<String> classNamesList = new ArrayList<String>();
  for (int i = 0; i < inputData.length; i++) {
    try {
      String className = inputData[i].getStringMetadata(Signal.PROP_CLASS);
      if ((className != null) && (!classNamesList.contains(className))) {
        classNamesList.add(className);
      }
    } catch (noMetadataException ex) {
      throw new RuntimeException("No class metadata found to train model on!", ex);
    }
  }
  Collections.sort(classNamesList);
  classnames = classNamesList.toArray(new String[classNamesList.size()]);

  FastVector classValues = new FastVector(classnames.length);
  for (int i = 0; i < classnames.length; i++) {
    classValues.addElement(classnames[i]);
  }
  classAttribute = new Attribute(Signal.PROP_CLASS, classValues);

  Instances trainingDataSet =
      new Instances(Signal2Instance.convert(inputData[0], classAttribute));
  if (inputData.length > 1) {
    for (int i = 1; i < inputData.length; i++) {
      Instances aSignalInstance = Signal2Instance.convert(inputData[i], classAttribute);
      for (int j = 0; j < aSignalInstance.numInstances(); j++) {
        trainingDataSet.add(aSignalInstance.instance(j));
      }
    }
  }
  trainingDataSet.setClass(classAttribute);
  inputData = null;

  theRule = new MISMO();

  // parse options
  StringTokenizer stOption = new StringTokenizer(this.MISMOOptions, " ");
  String[] options = new String[stOption.countTokens()];
  for (int i = 0; i < options.length; i++) {
    options[i] = stOption.nextToken();
  }
  try {
    theRule.setOptions(options);
  } catch (Exception ex) {
    throw new RuntimeException("Failed to set MISMO classifier options!", ex);
  }

  try {
    theRule.buildClassifier(trainingDataSet);
    System.out.println("WEKA: outputting MISMO classifier; " + theRule.globalInfo());
  } catch (Exception ex) {
    throw new RuntimeException("Failed to train classifier!", ex);
  }
}
public Instances initializeInstances() {
  FastVector wekaAttributes = buildCosineAttributes();
  Attribute label = (Attribute) wekaAttributes.lastElement();
  Instances data = new Instances("semantic-space", wekaAttributes, 1000);
  data.setClass(label);
  return data;
}
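A small, hypothetical usage sketch (not part of the original code): it assumes, based on the buildWekaInstance example further down, that buildCosineAttributes() yields one numeric cosine attribute followed by the nominal label.

// Hypothetical usage of initializeInstances(): add one labelled cosine score.
Instances data = initializeInstances();
double[] vals = new double[data.numAttributes()];
vals[0] = 0.87;              // assumed: the first attribute holds the cosine score
vals[data.classIndex()] = 0; // assumed: index of the first label value
Instance row = new Instance(1.0, vals); // Weka 3.6-style API, matching the FastVector usage above
data.add(row);
row.setDataset(data);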
public Instances getInstances(List<ImageData> data) {
  CSVLoader loader = new CSVLoader();
  Instances instances;
  try {
    // Create a temp csv file
    tempFile = new File("tmp");
    PrintWriter pw = null;
    try {
      pw = new PrintWriter(tempFile);
    } catch (FileNotFoundException e) {
      throw new Error(e);
    }

    // Write the header row followed by one row per image
    for (int i = 0; i < Reader.featureSize; i++) {
      pw.print(i + ",");
    }
    pw.println("class");
    for (int i = 0; i < data.size(); i++) {
      List<Double> features = data.get(i).getFeatures();
      for (int j = 0; j < features.size(); j++) {
        pw.print(features.get(j) + ",");
      }
      pw.println(data.get(i).getClassType());
      pw.flush();
    }
    pw.close();

    // Load the instances from the temp csv file
    loader.setSource(tempFile);
    instances = loader.getDataSet();
    instances.setClass(instances.attribute("class"));
    return instances;
  } catch (IOException e) {
    throw new Error(e);
  } finally {
    if (tempFile != null) {
      tempFile.delete();
      tempFile = null;
    }
  }
}
private static Instances initializeAttributes() {
  String nameOfDataset = "Badges";
  Instances instances;

  FastVector attributes = new FastVector(9);
  for (String featureName : features) {
    attributes.addElement(new Attribute(featureName, zeroOne));
  }
  Attribute classLabel = new Attribute("Class", labels); // labels is a FastVector of '+' and '-'
  attributes.addElement(classLabel);

  instances = new Instances(nameOfDataset, attributes, 0);
  instances.setClass(classLabel);
  return instances;
}
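A hedged usage sketch (not in the original snippet): it assumes the zeroOne FastVector holds the nominal values "0" and "1"; the class values "+" and "-" come from the comment above.

// Hypothetical usage of initializeAttributes(): add one badge example by value name.
Instances badges = initializeAttributes();
Instance row = new Instance(badges.numAttributes()); // Weka 3.6-style API, all values missing
row.setDataset(badges);                              // needed so nominal values can be resolved
for (int i = 0; i < badges.numAttributes() - 1; i++) {
  row.setValue(i, "1");                              // assumed value name for a set feature
}
row.setClassValue("+");
badges.add(row);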
public static void main(String args[]) throws Exception {
  ArffLoader trainLoader = new ArffLoader();
  trainLoader.setSource(new File("src/train.arff"));
  trainLoader.setRetrieval(Loader.BATCH);
  Instances trainDataSet = trainLoader.getDataSet();

  weka.core.Attribute trainAttribute = trainDataSet.attribute("class");
  trainDataSet.setClass(trainAttribute);
  // trainDataSet.deleteStringAttributes();

  NaiveBayes classifier = new NaiveBayes();
  final double startTime = System.currentTimeMillis();
  classifier.buildClassifier(trainDataSet);
  final double endTime = System.currentTimeMillis();
  double executionTime = (endTime - startTime) / 1000.0;
  System.out.println("Total execution time: " + executionTime);

  SerializationHelper.write("NaiveBayes.model", classifier);
  System.out.println("Saved trained model to NaiveBayes.model");
}
public Instance buildWekaInstance(QAPair pair) {
  double[] query = projector.transform(pair.getQueryList());
  double[] answer = projector.transform(pair.getAnswerList());
  double[] cosine = {projector.computeCosignSimilarity(query, answer), 0.0};

  FastVector attributes = buildCosineAttributes();
  Attribute label = (Attribute) attributes.lastElement();
  Instances testInstances = new Instances("test", attributes, 1);
  testInstances.setClass(label);

  Instance example = new Instance(1, cosine);
  testInstances.add(example);
  example.setDataset(testInstances);

  if (!pair.getLabel().equals("-1")) {
    example.setClassValue(pair.getLabel());
  } else {
    example.setClassMissing();
  }
  return example;
}
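A short, hypothetical follow-on (not in the original source) showing how the instance returned by buildWekaInstance might be scored; the trained classifier and the pair variable are assumed to exist elsewhere.

// Hypothetical scoring step: classify the QAPair built above.
// The enclosing method would need to declare `throws Exception` for these calls.
Instance example = buildWekaInstance(pair);
double predictedIndex = classifier.classifyInstance(example);
double[] distribution = classifier.distributionForInstance(example);
System.out.println("predicted class index: " + predictedIndex
    + ", P(class 0) = " + distribution[0]);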
public static void analyze_accuracy_NHBS(int rng_seed) throws Exception {
  HashMap<String, Object> population_params = load_defaults(null);
  RawLoader rl = new RawLoader(population_params, true, false, rng_seed);
  List<DrugUser> learningData = rl.getLearningData();

  Instances nhbs_data =
      new Instances("learning_instances", DrugUser.getAttInfo(), learningData.size());
  for (DrugUser du : learningData) {
    nhbs_data.add(du.getInstance());
  }
  System.out.println(nhbs_data.toSummaryString());
  nhbs_data.setClass(DrugUser.getAttribMap().get("hcv_state"));

  // wishlist: remove infrequent values
  // weka.filters.unsupervised.instance.RemoveFrequentValues()
  Filter f1 = new RemoveUseless();
  f1.setInputFormat(nhbs_data);
  nhbs_data = Filter.useFilter(nhbs_data, f1);

  System.out.println("NHBS IDU 2009 Dataset");
  System.out.println("Summary of input:");
  // System.out.println(nhbs_data.toSummaryString());
  System.out.println("  Num of classes: " + nhbs_data.numClasses());
  System.out.println("  Num of attributes: " + nhbs_data.numAttributes());
  for (int idx = 0; idx < nhbs_data.numAttributes(); ++idx) {
    Attribute attr = nhbs_data.attribute(idx);
    System.out.println("" + idx + ": " + attr.toString());
    System.out.println("  distinct values: " + nhbs_data.numDistinctValues(idx));
    // System.out.println("" + attr.enumerateValues());
  }

  ArrayList<String> options = new ArrayList<String>();
  options.add("-Q");
  options.add("" + rng_seed);

  // System.exit(0);
  // nhbs_data.deleteAttributeAt(0);  // response ID
  // nhbs_data.deleteAttributeAt(16); // zip

  // Classifier classifier = new NNge(); // best nearest-neighbor classifier: 40.00, ROC=0.60
  // Classifier classifier = new MINND();
  // Classifier classifier = new CitationKNN();
  // Classifier classifier = new LibSVM(); // requires LibSVM classes. only gets 37.7%
  // Classifier classifier = new SMOreg();
  Classifier classifier = new Logistic(); // ROC=0.686
  // Classifier classifier = new LinearNNSearch();

  // LinearRegression: Cannot handle multi-valued nominal class!
  // Classifier classifier = new LinearRegression();

  // Classifier classifier = new RandomForest();
  // String[] options = {"-I", "100", "-K", "4"}; // -I trees, -K features per tree.
  // generally, might want to optimize (or not:
  // https://cwiki.apache.org/confluence/display/MAHOUT/Random+Forests)
  // options.add("-I"); options.add("100"); options.add("-K"); options.add("4"); // ROC=0.673

  // KStar classifier = new KStar();
  // classifier.setGlobalBlend(20); // the amount of not greedy, in percent
  // ROC=0.633

  // Classifier classifier = new AdaBoostM1();   // ROC=0.66
  // Classifier classifier = new MultiBoostAB(); // ROC=0.67
  // Classifier classifier = new Stacking();     // ROC=0.495

  // J48 classifier = new J48(); // building a C4.5 tree classifier, ROC=0.585
  // String[] options = new String[1];
  // options[0] = "-U";              // unpruned tree
  // classifier.setOptions(options); // set the options

  classifier.setOptions(options.toArray(new String[0]));

  // not needed before CV: http://weka.wikispaces.com/Use+WEKA+in+your+Java+code
  // classifier.buildClassifier(nhbs_data); // build classifier

  // evaluation
  Evaluation eval = new Evaluation(nhbs_data);
  eval.crossValidateModel(classifier, nhbs_data, 10, new Random(1)); // 10-fold cross-validation
  System.out.println(eval.toSummaryString("\nResults\n\n", false));
  System.out.println(eval.toClassDetailsString());
  // System.out.println(eval.toCumulativeMarginDistributionString());
}
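The commented-out experiments above track ROC figures for each classifier; the following is a hedged sketch (not in the original method) of how such numbers could be read off the Evaluation object after cross-validation. Which class index counts as the "positive" HCV state is an assumption here.

// Hypothetical follow-on to the cross-validation above: report per-class AUC and accuracy.
for (int c = 0; c < nhbs_data.numClasses(); c++) {
  System.out.println("AUC(" + nhbs_data.classAttribute().value(c) + ") = "
      + eval.areaUnderROC(c));
}
System.out.println("Accuracy: " + eval.pctCorrect() + "%");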
public static void main(String[] args) throws Exception {
  // NaiveBayesSimple nb = new NaiveBayesSimple();

  // BufferedReader br_train = new BufferedReader(new FileReader("src/train.arff.txt"));
  // String s = null;
  // long st_time = System.currentTimeMillis();
  // Instances inst_train = new Instances(br_train);
  // System.out.println(inst_train.numAttributes());
  // inst_train.setClassIndex(inst_train.numAttributes() - 1);
  // System.out.println("train time" + (System.currentTimeMillis() - st_time));
  // NaiveBayes nb1 = new NaiveBayes();
  // nb1.buildClassifier(inst_train);
  // br_train.close();

  long st_time = System.currentTimeMillis();
  st_time = System.currentTimeMillis();

  Classifier classifier = (Classifier) SerializationHelper.read("NaiveBayes.model");

  // BufferedReader br_test = new BufferedReader(new FileReader("src/test.arff.txt"));
  // Instances inst_test = new Instances(br_test);
  // inst_test.setClassIndex(inst_test.numAttributes() - 1);
  // System.out.println("test time" + (System.currentTimeMillis() - st_time));

  ArffLoader testLoader = new ArffLoader();
  testLoader.setSource(new File("src/test.arff"));
  testLoader.setRetrieval(Loader.BATCH);
  Instances testDataSet = testLoader.getDataSet();

  Attribute testAttribute = testDataSet.attribute("class");
  testDataSet.setClass(testAttribute);

  int correct = 0;
  int incorrect = 0;
  FastVector attInfo = new FastVector();
  attInfo.addElement(new Attribute("Id"));
  attInfo.addElement(new Attribute("Category"));
  Instances outputInstances = new Instances("predict", attInfo, testDataSet.numInstances());

  Enumeration testInstances = testDataSet.enumerateInstances();
  int index = 1;
  while (testInstances.hasMoreElements()) {
    Instance instance = (Instance) testInstances.nextElement();
    double classification = classifier.classifyInstance(instance);
    // Track accuracy only for instances that actually carry a label
    if (!instance.classIsMissing()) {
      if ((int) classification == (int) instance.classValue()) {
        correct++;
      } else {
        incorrect++;
      }
    }
    Instance predictInstance = new Instance(outputInstances.numAttributes());
    predictInstance.setValue(0, index++);
    predictInstance.setValue(1, (int) classification + 1);
    outputInstances.add(predictInstance);
  }

  System.out.println("Correct Instances: " + correct);
  System.out.println("Incorrect Instances: " + incorrect);
  if (correct + incorrect > 0) {
    double accuracy = (double) correct / (double) (correct + incorrect);
    System.out.println("Accuracy: " + accuracy);
  }

  CSVSaver predictedCsvSaver = new CSVSaver();
  predictedCsvSaver.setFile(new File("predict.csv"));
  predictedCsvSaver.setInstances(outputInstances);
  predictedCsvSaver.writeBatch();
  System.out.println("Prediction saved to predict.csv");
}
public SensitivityAnalysis(Instances d) {
  data = d;
  data.setClass(data.attribute(data.numAttributes() - 1));
  System.out.println(data.classIndex());
}
public MappingInfo(Instances dataSet, MiningSchema miningSchema, Logger log) throws Exception {
  m_log = log;
  // miningSchema.convertStringAttsToNominal();
  Instances fieldsI = miningSchema.getMiningSchemaAsInstances();

  m_fieldsMap = new int[fieldsI.numAttributes()];
  m_nominalValueMaps = new int[fieldsI.numAttributes()][];

  for (int i = 0; i < fieldsI.numAttributes(); i++) {
    String schemaAttName = fieldsI.attribute(i).name();
    boolean found = false;
    for (int j = 0; j < dataSet.numAttributes(); j++) {
      if (dataSet.attribute(j).name().equals(schemaAttName)) {
        Attribute miningSchemaAtt = fieldsI.attribute(i);
        Attribute incomingAtt = dataSet.attribute(j);

        // check type match
        if (miningSchemaAtt.type() != incomingAtt.type()) {
          throw new Exception(
              "[MappingInfo] type mismatch for field " + schemaAttName
                  + ". Mining schema type " + miningSchemaAtt.toString()
                  + ". Incoming type " + incomingAtt.toString() + ".");
        }

        // check nominal values (number, names...)
        if (miningSchemaAtt.numValues() != incomingAtt.numValues()) {
          String warningString =
              "[MappingInfo] WARNING: incoming nominal attribute " + incomingAtt.name()
                  + " does not have the same number of values as the corresponding mining "
                  + "schema attribute.";
          if (m_log != null) {
            m_log.logMessage(warningString);
          } else {
            System.err.println(warningString);
          }
        }

        if (miningSchemaAtt.isNominal() || miningSchemaAtt.isString()) {
          int[] valuesMap = new int[incomingAtt.numValues()];
          for (int k = 0; k < incomingAtt.numValues(); k++) {
            String incomingNomVal = incomingAtt.value(k);
            int indexInSchema = miningSchemaAtt.indexOfValue(incomingNomVal);
            if (indexInSchema < 0) {
              String warningString =
                  "[MappingInfo] WARNING: incoming nominal attribute " + incomingAtt.name()
                      + " has value " + incomingNomVal
                      + " that doesn't occur in the mining schema.";
              if (m_log != null) {
                m_log.logMessage(warningString);
              } else {
                System.err.println(warningString);
              }
              valuesMap[k] = UNKNOWN_NOMINAL_VALUE;
            } else {
              valuesMap[k] = indexInSchema;
            }
          }
          m_nominalValueMaps[i] = valuesMap;
        }

        /*if (miningSchemaAtt.isNominal()) {
          for (int k = 0; k < miningSchemaAtt.numValues(); k++) {
            if (!miningSchemaAtt.value(k).equals(incomingAtt.value(k))) {
              throw new Exception("[PMMLUtils] value " + k + " ("
                  + miningSchemaAtt.value(k) + ") does not match "
                  + "incoming value (" + incomingAtt.value(k)
                  + ") for attribute " + miningSchemaAtt.name() + ".");
            }
          }
        }*/

        found = true;
        m_fieldsMap[i] = j;
      }
    }
    if (!found) {
      throw new Exception(
          "[MappingInfo] Unable to find a match for mining schema attribute "
              + schemaAttName + " in the incoming instances!");
    }
  }

  // check class attribute (if set)
  if (fieldsI.classIndex() >= 0) {
    if (dataSet.classIndex() < 0) {
      // first see if we can find a matching class
      String className = fieldsI.classAttribute().name();
      Attribute classMatch = dataSet.attribute(className);
      if (classMatch == null) {
        throw new Exception(
            "[MappingInfo] Can't find match for target field " + className
                + " in incoming instances!");
      }
      dataSet.setClass(classMatch);
    } else if (!fieldsI.classAttribute().name().equals(dataSet.classAttribute().name())) {
      throw new Exception(
          "[MappingInfo] class attribute in mining schema does not match "
              + "class attribute in incoming instances!");
    }
  }

  // Set up the textual description of the mapping
  fieldsMappingString(fieldsI, dataSet);
}
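The maps built by this constructor are presumably consumed when scoring incoming data; the following is a hedged sketch, not taken from the original class, of how they could be used. The variable incoming and the missing-value policy are assumptions.

// Illustrative (hypothetical) use of m_fieldsMap / m_nominalValueMaps: copy values
// from an incoming Instance into mining-schema attribute order.
double[] schemaValues = new double[m_fieldsMap.length];
for (int i = 0; i < m_fieldsMap.length; i++) {
  int incomingIndex = m_fieldsMap[i];
  if (incoming.isMissing(incomingIndex)) {
    schemaValues[i] = Instance.missingValue();
    continue;
  }
  double v = incoming.value(incomingIndex);
  if (m_nominalValueMaps[i] != null) {
    int mapped = m_nominalValueMaps[i][(int) v]; // translate the nominal value index
    v = (mapped == UNKNOWN_NOMINAL_VALUE) ? Instance.missingValue() : mapped;
  }
  schemaValues[i] = v;
}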
public static void main(String[] args) throws Exception {
  /*
   * First we load our predictions from the CSV formatted file.
   */
  CSVLoader predictCsvLoader = new CSVLoader();
  predictCsvLoader.setSource(new File("predict.csv"));

  /*
   * Since we are not using the ARFF format here, we have to give the
   * loader a little bit of information about the data types. Columns
   * 3, 8 and 10 need to be of type string and columns 1, 4 and 11 are
   * nominal types.
   */
  predictCsvLoader.setStringAttributes("3,8,10");
  predictCsvLoader.setNominalAttributes("1,4,11");
  Instances predictDataSet = predictCsvLoader.getDataSet();

  /*
   * Here we set the attribute we want to test the predictions with
   */
  Attribute testAttribute = predictDataSet.attribute(0);
  predictDataSet.setClass(testAttribute);

  /*
   * We still have to remove all string attributes before we can test
   */
  predictDataSet.deleteStringAttributes();

  /*
   * Next we load the training data from our ARFF file
   */
  ArffLoader trainLoader = new ArffLoader();
  trainLoader.setSource(new File("train.arff"));
  trainLoader.setRetrieval(Loader.BATCH);
  Instances trainDataSet = trainLoader.getDataSet();

  /*
   * Now we tell the data set which attribute we want to classify, in our
   * case, we want to classify the first column: survived
   */
  Attribute trainAttribute = trainDataSet.attribute(0);
  trainDataSet.setClass(trainAttribute);

  /*
   * The RandomForest implementation cannot handle columns of type string,
   * so we remove them for now.
   */
  trainDataSet.deleteStringAttributes();

  /*
   * Now we read in the serialized model from disk
   */
  Classifier classifier = (Classifier) SerializationHelper.read("titanic.model");

  /*
   * Next we will use an Evaluation class to evaluate the performance of
   * our Classifier.
   */
  Evaluation evaluation = new Evaluation(trainDataSet);
  evaluation.evaluateModel(classifier, predictDataSet, new Object[] {});

  /*
   * After we evaluate the Classifier, we write out the summary
   * information to the screen.
   */
  System.out.println(classifier);
  System.out.println(evaluation.toSummaryString());
}