예제 #1
0
  /**
   * Classifies every instance of {@code data/titanic/test.arff} with the model stored in
   * {@code data/titanic/titanic.model} and writes the labelled data to
   * {@code data/titanic/predict.csv}.
   *
   * <p>The test file is loaded twice: one copy is what the classifier sees (string attributes
   * removed), the other keeps every attribute and receives the predicted class values before
   * being saved.
   *
   * @param args command-line arguments (unused)
   * @throws Exception if loading the data or the model, classifying, or saving fails
   */
  public static void main(String[] args) throws Exception {
    final File testFile = new File("data/titanic/test.arff");

    // Copy #1: the classifier's input. The class is the first attribute (the original
    // author's notes call it "survived"); string attributes are dropped.
    ArffLoader inputLoader = new ArffLoader();
    inputLoader.setSource(testFile);
    inputLoader.setRetrieval(Loader.BATCH);
    Instances classifierInput = inputLoader.getDataSet();
    classifierInput.setClass(classifierInput.attribute(0));
    classifierInput.deleteStringAttributes();

    // Restore the previously trained model from disk.
    Classifier model = (Classifier) SerializationHelper.read("data/titanic/titanic.model");

    // Copy #2: the output data set. It is read from the same file but keeps its string
    // attributes, so each row can be written out in full with the predicted class filled in.
    ArffLoader outputLoader = new ArffLoader();
    outputLoader.setSource(testFile);
    Instances predictions = outputLoader.getDataSet();
    predictions.setClass(predictions.attribute(0));

    // Walk both copies in lock-step: classify row i of the input copy and store the result
    // as the class value of row i in the output copy.
    for (int row = 0; row < classifierInput.numInstances(); row++) {
      Instance input = classifierInput.instance(row);
      Instance target = predictions.instance(row);
      double predictedClass = model.classifyInstance(input);
      target.setClassValue(predictedClass);
    }

    // Persist the labelled output copy as CSV (the original author intended this file for
    // submission to Kaggle).
    CSVSaver csvSaver = new CSVSaver();
    csvSaver.setFile(new File("data/titanic/predict.csv"));
    csvSaver.setInstances(predictions);
    csvSaver.writeBatch();

    System.out.println("Prediciton saved to predict.csv");
  }
예제 #2
0
  /**
   * Trains a {@code NaiveBayes} classifier on {@code src/train.arff}, reports how long the
   * training took, and serializes the trained model to {@code NaiveBayes.model}.
   *
   * @param args command-line arguments (unused)
   * @throws IllegalStateException if the training data has no attribute named {@code class}
   * @throws Exception if loading the data, training, or writing the model fails
   */
  public static void main(String[] args) throws Exception {
    final String modelPath = "NaiveBayes.model";

    ArffLoader trainLoader = new ArffLoader();
    trainLoader.setSource(new File("src/train.arff"));
    trainLoader.setRetrieval(Loader.BATCH);
    Instances trainDataSet = trainLoader.getDataSet();

    // attribute(String) yields null for an unknown name; fail with a clear message instead
    // of letting setClass() blow up on the null.
    weka.core.Attribute trainAttribute = trainDataSet.attribute("class");
    if (trainAttribute == null) {
      throw new IllegalStateException("src/train.arff has no attribute named \"class\"");
    }
    trainDataSet.setClass(trainAttribute);

    NaiveBayes classifier = new NaiveBayes();

    // nanoTime() is the monotonic clock intended for measuring elapsed time.
    final long startNanos = System.nanoTime();
    classifier.buildClassifier(trainDataSet);
    final long elapsedNanos = System.nanoTime() - startNanos;
    double executionTime = elapsedNanos / 1_000_000_000.0; // seconds
    System.out.println("Total execution time: " + executionTime);

    SerializationHelper.write(modelPath, classifier);
    // Report the file that was actually written (the old message named a different file).
    System.out.println("Saved trained model to " + modelPath);
  }
예제 #3
0
  /**
   * Classifies every instance of {@code src/test.arff} with the model stored in
   * {@code NaiveBayes.model} and writes an {@code Id}/{@code Category} table to
   * {@code predict.csv}.
   *
   * <p>{@code Id} is the 1-based position of the instance in the test file and
   * {@code Category} is the predicted class index plus one. Instances whose class value is
   * present in the test file are also used to score the predictions; the counts and the
   * resulting accuracy are printed before the CSV is written.
   *
   * @param args command-line arguments (unused)
   * @throws IllegalStateException if the test data has no attribute named {@code class}
   * @throws Exception if loading the model or data, classifying, or saving fails
   */
  public static void main(String[] args) throws Exception {
    Classifier classifier = (Classifier) SerializationHelper.read("NaiveBayes.model");

    ArffLoader testLoader = new ArffLoader();
    testLoader.setSource(new File("src/test.arff"));
    testLoader.setRetrieval(Loader.BATCH);
    Instances testDataSet = testLoader.getDataSet();

    // attribute(String) yields null for an unknown name; fail with a clear message instead
    // of letting setClass() blow up on the null.
    Attribute testAttribute = testDataSet.attribute("class");
    if (testAttribute == null) {
      throw new IllegalStateException("src/test.arff has no attribute named \"class\"");
    }
    testDataSet.setClass(testAttribute);

    // Two numeric output columns: Id and Category.
    FastVector attInfo = new FastVector();
    attInfo.addElement(new Attribute("Id"));
    attInfo.addElement(new Attribute("Category"));
    Instances outputInstances = new Instances("predict", attInfo, testDataSet.numInstances());

    int correct = 0;
    int incorrect = 0;
    for (int i = 0; i < testDataSet.numInstances(); i++) {
      Instance instance = testDataSet.instance(i);
      double classification = classifier.classifyInstance(instance);

      // Score the prediction only when the test file actually carries a label for this
      // instance; unlabeled instances cannot be counted either way.
      if (!instance.classIsMissing()) {
        if (classification == instance.classValue()) {
          correct++;
        } else {
          incorrect++;
        }
      }

      Instance predictInstance = new Instance(outputInstances.numAttributes());
      predictInstance.setValue(0, i + 1); // 1-based Id
      predictInstance.setValue(1, (int) classification + 1); // class index shifted to 1-based
      outputInstances.add(predictInstance);
    }

    System.out.println("Correct Instance: " + correct);
    System.out.println("IncCorrect Instance: " + incorrect);
    int scored = correct + incorrect;
    if (scored > 0) {
      double accuracy = (double) correct / (double) scored;
      System.out.println("Accuracy: " + accuracy);
    } else {
      // Avoid printing the NaN that 0/0 would produce.
      System.out.println("Accuracy: n/a (no labeled test instances)");
    }

    CSVSaver predictedCsvSaver = new CSVSaver();
    predictedCsvSaver.setFile(new File("predict.csv"));
    predictedCsvSaver.setInstances(outputInstances);
    predictedCsvSaver.writeBatch();

    System.out.println("Prediciton saved to predict.csv");
  }
예제 #4
0
  /**
   * Evaluates the classifier stored in {@code titanic.model} against the data in
   * {@code predict.csv} and prints the classifier followed by the evaluation summary.
   *
   * <p>The training data from {@code train.arff} is loaded only to initialise the
   * {@code Evaluation} object.
   *
   * @param args command-line arguments (unused)
   * @throws Exception if loading the data or the model, or the evaluation itself, fails
   */
  public static void main(String[] args) throws Exception {

    // --- Data to evaluate against: predict.csv ---------------------------------------------
    // CSV carries no type information, so the loader is told explicitly which columns are
    // strings (3, 8, 10) and which are nominal (1, 4, 11).
    CSVLoader csvLoader = new CSVLoader();
    csvLoader.setSource(new File("predict.csv"));
    csvLoader.setStringAttributes("3,8,10");
    csvLoader.setNominalAttributes("1,4,11");
    Instances evaluationData = csvLoader.getDataSet();

    // The first attribute is the class; string attributes are removed afterwards because
    // they must be gone before the data can be tested (per the original author's note).
    evaluationData.setClass(evaluationData.attribute(0));
    evaluationData.deleteStringAttributes();

    // --- Training data: train.arff ---------------------------------------------------------
    ArffLoader arffLoader = new ArffLoader();
    arffLoader.setSource(new File("train.arff"));
    arffLoader.setRetrieval(Loader.BATCH);
    Instances trainingData = arffLoader.getDataSet();

    // Same preparation as above: first attribute is the class (called "survived" in the
    // original notes), then string attributes are dropped. The original note attributes this
    // to RandomForest being unable to handle string columns.
    trainingData.setClass(trainingData.attribute(0));
    trainingData.deleteStringAttributes();

    // --- Model ------------------------------------------------------------------------------
    Classifier model = (Classifier) SerializationHelper.read("titanic.model");

    // --- Evaluation -------------------------------------------------------------------------
    Evaluation evaluator = new Evaluation(trainingData);
    evaluator.evaluateModel(model, evaluationData, new Object[0]);

    // Report: the classifier's own description first, then the summary statistics.
    System.out.println(model);
    System.out.println(evaluator.toSummaryString());
  }