Exemplo n.º 1
0
  /**
   * Classifies every row of {@code data/titanic/test.arff} with the serialized model
   * {@code data/titanic/titanic.model} and writes the rows, with the class column replaced
   * by the prediction, to {@code data/titanic/predict.csv}.
   *
   * @param args ignored
   * @throws Exception if loading the data or the model, classifying, or saving fails
   */
  public static void main(String[] args) throws Exception {

    // Copy #1 of the test data: this is the copy handed to the classifier.
    ArffLoader featureLoader = new ArffLoader();
    featureLoader.setSource(new File("data/titanic/test.arff"));
    featureLoader.setRetrieval(Loader.BATCH);
    Instances featureSet = featureLoader.getDataSet();

    // The first column is the class attribute ("survived" per the original comment).
    // String attributes are dropped from this copy only.
    featureSet.setClass(featureSet.attribute(0));
    featureSet.deleteStringAttributes();

    // Restore the previously serialized classifier from disk.
    Classifier model = (Classifier) SerializationHelper.read("data/titanic/titanic.model");

    // Copy #2 of the same file: it keeps all of its columns and receives the predictions,
    // so the saved output still contains the attributes removed from copy #1.
    ArffLoader outputLoader = new ArffLoader();
    outputLoader.setSource(new File("data/titanic/test.arff"));
    Instances outputSet = outputLoader.getDataSet();
    outputSet.setClass(outputSet.attribute(0));

    // Walk both copies in lock-step: classify the row from copy #1 and store the
    // result as the class value of the row at the same position in copy #2.
    for (int row = 0; row < featureSet.numInstances(); row++) {
      Instance toClassify = featureSet.instance(row);
      Instance toFill = outputSet.instance(row);
      toFill.setClassValue(model.classifyInstance(toClassify));
    }

    // Persist the populated copy as CSV (the original author notes this layout is
    // meant for a Kaggle submission).
    CSVSaver csvOut = new CSVSaver();
    csvOut.setFile(new File("data/titanic/predict.csv"));
    csvOut.setInstances(outputSet);
    csvOut.writeBatch();

    System.out.println("Prediciton saved to predict.csv");
  }
Exemplo n.º 2
0
 /**
  * Saves a data set to {@code path + fileName + "." + extension} in ARFF or CSV format.
  *
  * <p>Example arguments: path {@code "dataSet\\"}, fileName
  * {@code "accelerometer_instances"}, extension {@code "csv"}.
  *
  * <p>Any extension other than {@code "arff"} or {@code "csv"} prints a message to standard
  * output and terminates the JVM with status -1. An {@link IOException} raised while saving
  * is not propagated; only its stack trace is printed.
  *
  * @param dataSet   the instances to write
  * @param path      prefix placed directly in front of the file name (no separator is added)
  * @param fileName  file name without extension
  * @param extension {@code "arff"} or {@code "csv"}
  */
 public static void writeFile(Instances dataSet, String path, String fileName, String extension) {
   boolean asArff = extension.equals("arff");
   boolean asCsv = !asArff && extension.equals("csv");

   // Reject unsupported formats up front.
   if (!asArff && !asCsv) {
     System.out.println("arff or csv only!");
     System.exit(-1);
     return;
   }

   // Both formats write to the same target location.
   File target = new File(path + fileName + "." + extension);
   try {
     if (asArff) {
       ArffSaver saver = new ArffSaver();
       saver.setInstances(dataSet);
       saver.setFile(target);
       saver.writeBatch();
     } else {
       CSVSaver saver = new CSVSaver();
       saver.setInstances(dataSet);
       saver.setFile(target);
       saver.writeBatch();
     }
   } catch (IOException e) {
     e.printStackTrace();
   }
 }
Exemplo n.º 3
0
  /**
   * Classifies every instance of {@code src/test.arff} with the serialized model
   * {@code NaiveBayes.model}, prints how many labelled instances were predicted
   * correctly/incorrectly, and writes the predictions to {@code predict.csv}.
   *
   * <p>Each output row has two numeric columns: {@code Id} (1-based position of the instance
   * in the test set) and {@code Category} (predicted class value truncated to an int, plus
   * one).
   *
   * @param args ignored
   * @throws Exception if the model or test data cannot be read, the test data has no
   *     {@code class} attribute, classification fails, or the CSV cannot be written
   */
  public static void main(String[] args) throws Exception {
    Classifier classifier = (Classifier) SerializationHelper.read("NaiveBayes.model");

    ArffLoader testLoader = new ArffLoader();
    testLoader.setSource(new File("src/test.arff"));
    testLoader.setRetrieval(Loader.BATCH);
    Instances testDataSet = testLoader.getDataSet();

    // Fail with a clear message instead of an opaque NullPointerException inside setClass.
    Attribute testAttribute = testDataSet.attribute("class");
    if (testAttribute == null) {
      throw new IllegalStateException("src/test.arff has no attribute named \"class\"");
    }
    testDataSet.setClass(testAttribute);

    int correct = 0;
    int incorrect = 0;
    FastVector attInfo = new FastVector();
    attInfo.addElement(new Attribute("Id"));
    attInfo.addElement(new Attribute("Category"));

    Instances outputInstances = new Instances("predict", attInfo, testDataSet.numInstances());

    for (int i = 0; i < testDataSet.numInstances(); i++) {
      Instance instance = testDataSet.instance(i);
      double classification = classifier.classifyInstance(instance);

      // Only instances that carry a class label can be scored; unlabelled ones are
      // still predicted and written out, they just do not affect the accuracy figures.
      if (!instance.classIsMissing()) {
        if (classification == instance.classValue()) {
          correct++;
        } else {
          incorrect++;
        }
      }

      Instance predictInstance = new Instance(outputInstances.numAttributes());
      predictInstance.setValue(0, i + 1);
      predictInstance.setValue(1, (int) classification + 1);
      outputInstances.add(predictInstance);
    }

    System.out.println("Correct Instance: " + correct);
    System.out.println("IncCorrect Instance: " + incorrect);
    // With no labelled instances this is 0.0 / 0.0, i.e. NaN, which is printed as such.
    double accuracy = (double) (correct) / (double) (correct + incorrect);
    System.out.println("Accuracy: " + accuracy);

    CSVSaver predictedCsvSaver = new CSVSaver();
    predictedCsvSaver.setFile(new File("predict.csv"));
    predictedCsvSaver.setInstances(outputInstances);
    predictedCsvSaver.writeBatch();

    System.out.println("Prediciton saved to predict.csv");
  }