コード例 #1
0
  /**
   * Trains a NaiveBayes classifier on one ARFF file and evaluates it on another.
   *
   * <p>The last attribute of each data set is used as the class attribute. The trained model, an
   * evaluation summary and the confusion matrix are printed to standard output.
   *
   * @param args {@code args[0]} is the training ARFF path, {@code args[1]} is the test ARFF path
   * @throws IllegalArgumentException if fewer than two paths are supplied
   * @throws Exception if loading the data, training or evaluating fails
   */
  public static void run(String[] args) throws Exception {
    // Fail with a usable message instead of an ArrayIndexOutOfBoundsException further down.
    if (args == null || args.length < 2) {
      throw new IllegalArgumentException("Usage: <train arff path> <test arff path>");
    }

    // Training: load the data, mark the last attribute as the class, fit the model.
    Instances data = new DataSource(args[0]).getDataSet();
    data.setClassIndex(data.numAttributes() - 1);
    NaiveBayes model = new NaiveBayes();
    model.buildClassifier(data);

    // Evaluation on the separate test set.
    Instances testData = new DataSource(args[1]).getDataSet();
    testData.setClassIndex(testData.numAttributes() - 1);
    Evaluation eval = new Evaluation(data);
    eval.evaluateModel(model, testData);

    System.out.println(model.toString());
    System.out.println(eval.toSummaryString("\nResults\n======\n", false));
    System.out.println("======\nConfusion Matrix:");
    for (double[] row : eval.confusionMatrix()) {
      for (double cell : row) {
        System.out.format("%10s ", cell);
      }
      System.out.print("\n");
    }
  }
コード例 #2
0
ファイル: train.java プロジェクト: kavyasrinet/CIFAR-10
  /**
   * Trains a NaiveBayes classifier on {@code src/train.arff} and serializes it to
   * {@code NaiveBayes.model}.
   *
   * <p>The attribute named {@code class} is used as the class attribute. The training time in
   * seconds is printed to standard output.
   *
   * @param args unused
   * @throws IllegalStateException if the data set has no attribute named {@code class}
   * @throws Exception if loading, training or saving the model fails
   */
  public static void main(String args[]) throws Exception {
    final String modelPath = "NaiveBayes.model";

    ArffLoader trainLoader = new ArffLoader();
    trainLoader.setSource(new File("src/train.arff"));
    trainLoader.setRetrieval(Loader.BATCH);
    Instances trainDataSet = trainLoader.getDataSet();

    weka.core.Attribute trainAttribute = trainDataSet.attribute("class");
    if (trainAttribute == null) {
      // Without this check the missing attribute would be passed on as null to setClass.
      throw new IllegalStateException("No attribute named 'class' in src/train.arff");
    }
    trainDataSet.setClass(trainAttribute);

    NaiveBayes classifier = new NaiveBayes();

    final long startTime = System.currentTimeMillis();
    classifier.buildClassifier(trainDataSet);
    final long endTime = System.currentTimeMillis();
    double executionTime = (endTime - startTime) / 1000.0;
    System.out.println("Total execution time: " + executionTime);

    SerializationHelper.write(modelPath, classifier);
    // Report the file that was actually written (the old message named a different file).
    System.out.println("Saved trained model to " + modelPath);
  }
コード例 #3
0
  /**
   * Converts every instance of the input via {@code convertInstance}, lazily training the
   * NaiveBayes estimator and building the attribute-name lookup on first use.
   *
   * <p>If no estimator has been set yet, one is trained on the input (passed through
   * {@code m_remove} first when that filter is set). The lookup maps the name of each non-class
   * attribute in the estimator's header to its array of conditional estimators.
   *
   * @param instances the data to process
   * @return a new data set in the output format containing the converted instances
   * @throws Exception if filtering, training or converting fails
   */
  @Override
  protected Instances process(Instances instances) throws Exception {
    if (m_estimator == null) {
      Instances trainingData =
          (m_remove != null) ? Filter.useFilter(instances, m_remove) : new Instances(instances);

      // Publish the estimator only once training succeeded; otherwise a failed build would
      // leave an untrained model behind that later calls would treat as ready.
      NaiveBayes estimator = new NaiveBayes();
      estimator.buildClassifier(trainingData);
      m_estimator = estimator;
    }

    if (m_estimatorLookup == null) {
      // Build into a local map so that the field never exposes a half-filled lookup.
      HashMap<String, Estimator[]> lookup = new HashMap<String, Estimator[]>();
      Estimator[][] estimators = m_estimator.getConditionalEstimators();
      Instances header = m_estimator.getHeader();
      int estimatorIndex = 0;
      for (int attIndex = 0; attIndex < header.numAttributes(); attIndex++) {
        if (attIndex == header.classIndex()) {
          // The class attribute has no entry in the conditional estimator array.
          continue;
        }
        lookup.put(header.attribute(attIndex).name(), estimators[estimatorIndex]);
        estimatorIndex++;
      }
      m_estimatorLookup = lookup;
    }

    int numInstances = instances.numInstances();
    Instances result = new Instances(getOutputFormat(), numInstances);
    for (int i = 0; i < numInstances; i++) {
      result.add(convertInstance(instances.instance(i)));
    }

    return result;
  }
コード例 #4
0
    /**
     * Trains a Naive Bayes, a multilayer perceptron and an SMO classifier on
     * {@code cross_validation_data.arff}, saves each model to disk, evaluates them and writes
     * accuracy, per-class recall and training time to the charts and tables.
     *
     * <p>Progress messages are reported through {@code publish}. Any exception is printed to
     * standard error and ends the run early.
     *
     * @return always {@code null}
     */
    @Override
    public Void doInBackground() {
      try {
        publish("Reading data...");
        final Instances trainingdata;
        BufferedReader reader = new BufferedReader(new FileReader("cross_validation_data.arff"));
        try {
          trainingdata = new Instances(reader);
        } finally {
          // Release the file handle even when the ARFF content cannot be parsed.
          reader.close();
        }
        // setting class attribute
        trainingdata.setClassIndex(13);
        trainingdata.randomize(new Random(1));

        publish("Training Naive Bayes Classifier...");

        NaiveBayes nb = new NaiveBayes();
        long startTime = System.nanoTime();
        nb.buildClassifier(trainingdata);
        // Seconds with millisecond resolution; the same formula is used for all three models.
        final double runningTimeNB = ((System.nanoTime() - startTime) / 1000000) / 1000.0;
        // saving the naive bayes model
        weka.core.SerializationHelper.write("naivebayes.model", nb);
        System.out.println("running time" + runningTimeNB);
        publish("Done training NB.\nEvaluating NB using 10-fold cross-validation...");
        evalNB = new Evaluation(trainingdata);
        evalNB.crossValidateModel(nb, trainingdata, 10, new Random(1));
        publish("Done evaluating NB.");

        MultilayerPerceptron mlp = new MultilayerPerceptron();
        mlp.setOptions(Utils.splitOptions("-L 0.3 -M 0.2 -N 500 -V 0 -S 0 -E 20 -H a"));
        publish("Training ANN...");
        startTime = System.nanoTime();
        mlp.buildClassifier(trainingdata);
        // Kept as fractional seconds; a long here would truncate sub-second runs to 0.
        final double runningTimeANN = ((System.nanoTime() - startTime) / 1000000) / 1000.0;
        // saving the MLP model
        weka.core.SerializationHelper.write("mlp.model", mlp);

        publish("Done training ANN.\nEvaluating ANN using 10-fold cross-validation...");

        // NOTE(review): the message above announces cross-validation, but the model is
        // evaluated on the very data it was trained on, unlike the NB model - confirm intent.
        evalANN = new Evaluation(trainingdata);
        evalANN.evaluateModel(mlp, trainingdata);

        publish("Done evaluating ANN.");
        publish("Training SVM...");
        SMO svm = new SMO();

        startTime = System.nanoTime();
        svm.buildClassifier(trainingdata);
        final double runningTimeSVM = ((System.nanoTime() - startTime) / 1000000) / 1000.0;
        weka.core.SerializationHelper.write("svm.model", svm);
        publish("Done training SVM.\nEvaluating SVM using 10-fold cross-validation...");
        // NOTE(review): as for the ANN, this evaluates on the training data, not by CV.
        evalSVM = new Evaluation(trainingdata);
        evalSVM.evaluateModel(svm, trainingdata);
        publish("Done evaluating SVM.");

        // Chart updates are handed over to the JavaFX application thread.
        Platform.runLater(
            new Runnable() {
              @Override
              public void run() {
                // Bar chart: percentage of correctly classified instances per model.
                bc.getData()
                    .get(0)
                    .getData()
                    .get(0)
                    .setYValue(evalANN.correct() / trainingdata.size() * 100);
                bc.getData()
                    .get(0)
                    .getData()
                    .get(1)
                    .setYValue(evalSVM.correct() / trainingdata.size() * 100);
                bc.getData()
                    .get(0)
                    .getData()
                    .get(2)
                    .setYValue(evalNB.correct() / trainingdata.size() * 100);

                // Line chart: recall per class, one series per model.
                for (int i = 0; i < NUM_CLASSES; i++) {
                  lineChart.getData().get(0).getData().get(i).setYValue(evalANN.recall(i) * 100);
                  lineChart.getData().get(1).getData().get(i).setYValue(evalSVM.recall(i) * 100);
                  lineChart.getData().get(2).getData().get(i).setYValue(evalNB.recall(i) * 100);
                }
              }
            });

        // NOTE(review): panel and summaryTable are updated from this worker thread; if they
        // are Swing components these calls presumably belong on the EDT - confirm.
        panel.fillConfTable(evalSVM.confusionMatrix());

        summaryTable.setValueAt(evalANN.correct() / trainingdata.size() * 100., 0, 1);
        summaryTable.setValueAt(evalSVM.correct() / trainingdata.size() * 100, 0, 2);
        summaryTable.setValueAt(evalNB.correct() / trainingdata.size() * 100, 0, 3);

        summaryTable.setValueAt(runningTimeANN, 1, 1);
        summaryTable.setValueAt(runningTimeSVM, 1, 2);
        summaryTable.setValueAt(runningTimeNB, 1, 3);

      } catch (Exception e1) {
        // Best effort: report the failure on stderr and finish the task without a result.
        e1.printStackTrace();
      }
      return null;
    }
コード例 #5
0
  /**
   * Calculates the class membership probabilities for the given test instance.
   *
   * <p>A NaiveBayes model is built for every call from a copy of the training data (without
   * instances that lack a class value) plus the nearest neighbours of the test instance. Each
   * neighbour is added {@code (int) (10 / (distance + 0.1))} times, so closer neighbours have
   * more influence on the model.
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @throws Exception if an error occurred during the prediction
   */
  public double[] distributionForInstance(Instance instance) throws Exception {

    if (m_Train.numInstances() == 0) {
      return m_defaultModel.distributionForInstance(instance);
    }
    if ((m_WindowSize > 0) && (m_Train.numInstances() > m_WindowSize)) {
      m_kNNValid = false;
      while (m_Train.numInstances() > m_WindowSize) {
        m_Train.delete(0);
      }
      // The guard above guarantees that at least one instance was deleted, so the search
      // structure has to be rebuilt. The previous code tracked this with a flag that was
      // never set, which left the deleted instances in the neighbour search.
      m_NNSearch.setInstances(m_Train);
    }

    // Select k by cross validation
    if (!m_kNNValid && (m_CrossValidate) && (m_kNNUpper >= 1)) {
      crossValidate();
    }

    m_NNSearch.addInstanceInfo(instance);
    // NOTE(review): this overwrites the configured / cross-validated k with a fixed value.
    m_kNN = 1000;
    Instances neighbours = m_NNSearch.kNearestNeighbours(instance, m_kNN);
    double[] distances = m_NNSearch.getDistances();

    Instances instances = new Instances(m_Train);
    instances.deleteWithMissingClass();

    // Start from all labelled training instances, then append the weighted neighbours.
    Instances weightedTrain = new Instances(instances, 0, instances.numInstances());

    for (int k = 0; k < neighbours.numInstances(); k++) {
      Instance neighbour = neighbours.instance(k);
      // The 0.1 offset avoids a division by zero for a neighbour at distance 0.
      double inverseDistance = (1 / (distances[k] + 0.1)) * 10;
      int copies = (int) inverseDistance;

      for (int s = 0; s < copies; s++) {
        weightedTrain.add(neighbour);
      }
    }

    NaiveBayes nb = new NaiveBayes();
    nb.buildClassifier(weightedTrain);
    return nb.distributionForInstance(instance);
  }