@Test
  public void train_setForestSize_randomForestHasCorrectSetting() {
    // Verifies that the "-I" option (number of trees) in the Weka option string
    // reaches the RandomForest that GeneralTrainer builds.

    // Arrange
    final String forestSpec = "weka.classifiers.trees.RandomForest -I 800";

    // Act
    GeneralTrainer trainer = createGeneralTrainer(forestSpec);
    RandomForest forest = (RandomForest) trainer.train(getTestInstances());

    // Assert
    assertEquals(800, forest.getNumTrees());
  }
// Example #2
// 0
  /**
   * Trains a RandomForest on {@code trainFile}, evaluates it on {@code testFile}, and prints the
   * evaluation summary followed by one {@code rowId,prediction} line per test instance.
   *
   * <p>Both files are read with Weka's {@code DataSource}. Attribute 0 is used as the class
   * attribute. Attributes 4, 5, 6 and 8 (zero-based positions in the files as read) are removed
   * from both sets before training.
   *
   * @param trainFile path of the training data file
   * @param testFile path of the test data file
   * @return the prediction for the last test instance classified, or {@code 0.0} if none was
   *     classified. Exceptions are printed to stderr and not rethrown, so a mid-run failure
   *     returns the last prediction computed before it.
   */
  public static Double runClassify(String trainFile, String testFile) {
    // Zero-based attribute indices to drop. They must stay in descending order so that each
    // deletion does not shift the positions of the indices still waiting to be deleted.
    final int[] droppedAttributes = {8, 6, 5, 4};
    // Id printed next to the first test prediction; later rows count up from here.
    // NOTE(review): 892 looks like the first PassengerId of the Kaggle Titanic test set —
    // confirm before reusing this method with other data.
    final int firstRowId = 892;

    double predictOrder = 0.0;
    try {
      Instances train = DataSource.read(trainFile);
      Instances test = DataSource.read(testFile);

      train.setClassIndex(0);
      test.setClassIndex(0);

      // Keep the train and test headers identical by dropping the same columns from both.
      for (int index : droppedAttributes) {
        train.deleteAttributeAt(index);
        test.deleteAttributeAt(index);
      }

      RandomForest classifier = new RandomForest();
      classifier.buildClassifier(train);

      Evaluation eval = new Evaluation(train);
      eval.evaluateModel(classifier, test);
      System.out.println(eval.toSummaryString("\nResults\n\n", true));

      for (int i = 0; i < test.numInstances(); i++) {
        predictOrder = classifier.classifyInstance(test.instance(i));
        System.out.println((firstRowId + i) + "," + (int) predictOrder);
      }
    } catch (Exception e) {
      // Best effort: report the failure and return whatever prediction was last computed.
      e.printStackTrace();
    }
    return predictOrder;
  }
  /**
   * Classifies one GPS record with the trained {@code classifier} and returns the predicted
   * {@code MoveType} label.
   *
   * <p>{@code newInst} is appended as the single {@code @data} row of an ARFF document whose
   * header is built below (Longtitude, Latitude, CurrentSpeed, Timestamp, MoveType, IsGpsFixed).
   * The document is parsed in memory. The previous version wrote a fixed temporary file on
   * device storage, leaked its reader, and left the file behind on failure.
   *
   * <p>If {@code classifier} is null or untrained, the resulting exception is caught and
   * {@code null} is returned.
   *
   * @param newInst one comma-separated ARFF data row; presumably its MoveType value is "?" —
   *     TODO confirm against callers
   * @return the predicted MoveType label, or {@code null} if parsing or classification fails
   *     (the exception is printed, not rethrown)
   */
  public String classifyInstance(String newInst) {
    // NOTE(review): "Longtitude" is misspelled but is left as-is; it presumably mirrors
    // the header of training_data.arff — confirm before renaming.
    String arff =
        "@relation gps_tracking\n"
            + "\n"
            + "@attribute Longtitude numeric\n"
            + "@attribute Latitude numeric\n"
            + "@attribute CurrentSpeed numeric\n"
            + "@attribute Timestamp date \"yyyy-MM-dd HH:mm:ss\"\n"
            + "@attribute MoveType {Walking,Running,Biking,Driving,Metro,Bus,Motionless}\n"
            + "@attribute IsGpsFixed {yes,no}\n"
            + "\n"
            + "@data\n"
            + newInst;

    String type = null;
    try {
      // StringReader holds no OS resource, so there is nothing to close or clean up.
      Instances unlabeled = new Instances(new java.io.StringReader(arff));
      // The class is the second-to-last attribute (MoveType); IsGpsFixed comes after it.
      unlabeled.setClassIndex(unlabeled.numAttributes() - 2);

      double clsLabel = classifier.classifyInstance(unlabeled.instance(0));
      type = unlabeled.classAttribute().value((int) clsLabel);
    } catch (Exception e) {
      e.printStackTrace();
    }
    return type;
  }
  /**
   * Loads {@code training_data.arff} from the app's assets, trains a new RandomForest on it,
   * and stores the trained model in {@code classifier}.
   *
   * <p>The second-to-last attribute of the data set is used as the class attribute. On failure,
   * a Toast is shown and the error is logged under the "GPSTracker" tag. In that case
   * {@code classifier} keeps its previous value and is not replaced by an untrained model.
   */
  public void trainSystem() {
    try {
      Instances data;
      // try-with-resources closes the asset reader even when ARFF parsing throws.
      try (BufferedReader reader =
          new BufferedReader(
              new InputStreamReader(getResources().getAssets().open("training_data.arff")))) {
        data = new Instances(reader);
      }
      data.setClassIndex(data.numAttributes() - 2);

      // Build into a local first so a failed build leaves the existing field untouched.
      RandomForest trained = new RandomForest();
      trained.buildClassifier(data);
      classifier = trained;
    } catch (FileNotFoundException e) {
      Toast.makeText(getApplicationContext(), "File not Found.!", Toast.LENGTH_LONG).show();
      Log.e("GPSTracker", "fnfexception", e);
    } catch (IOException e) {
      Toast.makeText(getApplicationContext(), "IO Exception.!", Toast.LENGTH_LONG).show();
      Log.e("GPSTracker", "ioexception", e);
    } catch (Exception e) {
      Log.e("GPSTracker", "exception", e);
      Toast.makeText(getApplicationContext(), "Exception.!", Toast.LENGTH_LONG).show();
    }
  }