/**
 * Builds a new Weka {@link SMO} classifier configured from this object's settings.
 *
 * <p>The complexity constant, the logistic-model flag and the ridge are taken from the
 * {@code c}, {@code buildLogisticModels} and {@code ridge} fields. A new kernel of the same
 * runtime class as {@code this.kernel} is created through its no-argument constructor; a
 * {@link PolyKernel} gets its exponent from {@code exp}, an {@link RBFKernel} gets its gamma
 * from {@code gamma}, and any other kernel type is used as constructed.
 *
 * @return the configured {@code SMO} instance (it is not trained by this method)
 * @throws RuntimeException wrapping any exception raised while instantiating or configuring
 *     the kernel or the classifier
 */
@Override
 public Classifier getWekaClassifer() {
   try {
     SMO smo = new SMO();
     smo.setC(c);
     smo.setBuildLogisticModels(buildLogisticModels);
     smo.setRidge(ridge);

     // Class.newInstance() is deprecated because it rethrows checked constructor exceptions
     // without declaring them; going through the Constructor object is the supported route.
     Kernel kernel = this.kernel.getClass().getDeclaredConstructor().newInstance();
     if (kernel instanceof PolyKernel) {
       ((PolyKernel) kernel).setExponent(exp);
     } else if (kernel instanceof RBFKernel) {
       ((RBFKernel) kernel).setGamma(gamma);
     }
     smo.setKernel(kernel);
     return smo;
   } catch (Exception e) {
     // Keep the original cause so callers can see why the kernel could not be created.
     throw new RuntimeException(e);
   }
 }
Ejemplo n.º 2
0
  /**
   * Parses the command line and builds the {@link ClassifierContext} describing the run.
   *
   * <p>The method declares all supported options, parses {@code args} with a
   * {@code DefaultParser}, and then validates each supplied value. Defaults applied when an
   * option is absent: training percentage {@code 0.7}, cross-validation {@code false},
   * number of folds {@code 5}. The classifier option is mandatory and selects which Weka
   * classifier is instantiated and which classifier-specific options are read.
   *
   * <p>Whenever help is requested, a mandatory option is missing, or a value cannot be
   * parsed/validated, a message is printed to standard output and {@code null} is returned.
   *
   * @param args the raw command line arguments
   * @return the populated context, or {@code null} if help was shown or the arguments were
   *     rejected
   * @throws ParseException if the command line itself cannot be parsed
   */
  private static ClassifierContext createContext(String[] args) throws ParseException {
    Options options =
        new Options()
            .addOption(new Option(help, "help", false, "show help"))
            .addOption(new Option(file, "file", true, "file path containing the data set"))
            .addOption(
                new Option(
                    trainPercentage,
                    "train-percentage",
                    true,
                    "training set in percentage, rest is test set; type double; default value training set: 0.7, test set 0.3"))
            .addOption(
                new Option(
                    crossValidation,
                    "cross-validation",
                    true,
                    "use cross-validation or not;type boolean; default value false"))
            .addOption(
                new Option(
                    kFolds,
                    "kfolds",
                    true,
                    "#folds used in cross-validation; type integer, default 5"))
            .addOption(
                new Option(
                    classifier,
                    "classifier",
                    true,
                    "classifier to use <dt, knn, boost, nn, svm> : Decision Tree(dt), Nearest Neighbours(knn), Boosting(boost), Neural Networks(nn), Support Vector Machine(svm)"))
            .addOption(
                new Option(
                    dtPruning,
                    "pruning",
                    true,
                    "dt: use pruning; type boolean, default value true>"))
            .addOption(
                new Option(
                    dtCf,
                    "confidence_factor",
                    true,
                    "dt: set confidence-factor, used for pruning; type double, default value 0.25"))
            .addOption(
                new Option(
                    boostC,
                    "boost_classifier",
                    true,
                    "boost: set the classifier learner; one of <stump|dt>, default stump"))
            .addOption(
                new Option(
                    boostNrIterations,
                    "boost_nr_iterations",
                    true,
                    "boost: set the nr of bagging iterations; type integer, default value 10>"))
            .addOption(
                new Option(
                    boostDtPruning,
                    "boost_dt_pruning",
                    true,
                    "boost: use pruning for dt; type boolean, default value true>"))
            .addOption(
                new Option(
                    boostDtCf,
                    "boost_dt_cf",
                    true,
                    "boost: use dt confidence-factor, used for pruning; type double, default value 0.25"))
            .addOption(
                new Option(
                    knnK,
                    "knn_k",
                    true,
                    "knn: specify the #nearest neighbours; type integer, default value 1"))
            .addOption(
                new Option(
                    knnWeightDistance,
                    "knn_weight_distance",
                    true,
                    // The accepted ids are 1, 2 and 4 (see validation below), so the help
                    // text must advertise 4 for similarity weighting, not 3.
                    "knn: specify a weight distance; type integer <1=None,2=Inverse,4=Similarity>, default value 1"))
            .addOption(
                new Option(
                    nnLearningRate,
                    "nn_learning_rate",
                    true,
                    "nn: backpropagation learning rate; type double, default value 0.3"))
            .addOption(
                new Option(
                    nnMomentum,
                    "nn_momentum",
                    true,
                    "nn: backpropagation momentum rate; type double, default value 0.2"))
            .addOption(
                new Option(
                    nnHiddentUnits,
                    "nn_hidden_units",
                    true,
                    "nn: comma-separated string for #hidden layers and nodes per layer. e.g. \"a,3,4\"; see weka for more details"))
            .addOption(
                new Option(
                    svmKernelFunction,
                    "svm_kernel_function",
                    true,
                    "svm: one of: <poly,radial>, default value poly"))
            .addOption(
                new Option(
                    svmPolyExp,
                    "svm_poly_exponent",
                    true,
                    "svm: the exponent value of a polynomial kernel; type double, default value 1.0"))
            .addOption(
                new Option(
                    svmRadialGamma,
                    "svm_radial_gamma",
                    true,
                    "svm: the gamma parameter value of the radial kernel; type double, default value 0.01"));

    // Defaults used when the corresponding option is not supplied.
    String f = null;
    Double trainP = 0.7;
    Boolean cv = false;
    Integer kfolds = 5;
    ClassifierTypes classifierType = null;
    Classifier cls = null;

    CommandLine commandLine = new DefaultParser().parse(options, args);
    if (commandLine.hasOption(help)) {
      new HelpFormatter().printHelp("Run Classifiers", options);
      return null;
    }

    // The data set file is mandatory.
    if (commandLine.hasOption(file)) {
      f = commandLine.getOptionValue(file);
    } else {
      System.out.println("Please provide data set file. See help for more details");
      return null;
    }

    // Training split must be strictly inside (0, 1).
    // The parse helpers (getDouble/getInt/getFloat) are defined elsewhere; the null checks
    // below assume they return null for unparseable input.
    if (commandLine.hasOption(trainPercentage)) {
      String trainp = commandLine.getOptionValue(trainPercentage);
      trainP = getDouble(trainp);
      if (trainP == null || trainP >= 1 || trainP <= 0) {
        System.out.println(
            "Please provide an double between (0,1) range for training set pct. See help for more details");
        return null;
      }
    }

    // The fold count is only read when cross-validation is actually switched on.
    if (commandLine.hasOption(crossValidation)) {
      cv = Boolean.valueOf(commandLine.getOptionValue(crossValidation));
      if (cv && commandLine.hasOption(kFolds)) {
        kfolds = getInt(commandLine.getOptionValue(kFolds));
        if (kfolds == null) {
          System.out.println(
              "Please provide an integer for kfolds e.g. 10. See help for more details");
          return null;
        }
      }
    }

    if (commandLine.hasOption(classifier)) {
      String c = commandLine.getOptionValue(classifier);
      classifierType = ClassifierTypes.toEnum(c);
      // toEnum's handling of unknown names is not visible here; switching on null would
      // throw a NullPointerException, so reject that case explicitly.
      if (classifierType == null) {
        System.out.println("Please provide a valid classifier. See help for more details");
        return null;
      }
      switch (classifierType) {
        case DECISION_TREE:
          J48 dt = new J48();
          if (commandLine.hasOption(dtPruning)) {
            Boolean pruning = Boolean.valueOf(commandLine.getOptionValue(dtPruning));
            // J48 exposes the inverse flag ("unpruned").
            dt.setUnpruned(!pruning);
          }
          if (commandLine.hasOption(dtCf)) {
            String cfStr = commandLine.getOptionValue(dtCf);
            Float cf = getFloat(cfStr);
            if (cf == null) {
              System.out.println(
                  "Please provide a floating point number for dt_cf e.g. 0.25. See help for more details");
              return null;
            }
            dt.setConfidenceFactor(cf);
          }
          cls = dt;
          break;
        case BOOSTING:
          AdaBoostM1 boost = new AdaBoostM1();
          if (commandLine.hasOption(boostC)) {
            String boostCls = commandLine.getOptionValue(boostC);
            if ("stump".equalsIgnoreCase(boostCls)) {
              // Nothing to do: the base learner of a fresh AdaBoostM1 is left untouched.
            } else if ("dt".equalsIgnoreCase(boostCls)) {
              J48 boostDt = new J48();
              if (commandLine.hasOption(boostDtPruning)) {
                Boolean boostPruning = Boolean.valueOf(commandLine.getOptionValue(boostDtPruning));
                boostDt.setUnpruned(!boostPruning);
              }
              if (commandLine.hasOption(boostDtCf)) {
                String cfStr = commandLine.getOptionValue(boostDtCf);
                Float cf = getFloat(cfStr);
                if (cf == null) {
                  System.out.println(
                      "Please provide a floating point number for boost_dt_cf e.g. 0.25. See help for more details");
                  return null;
                }
                boostDt.setConfidenceFactor(cf);
              }
              boost.setClassifier(boostDt);
            } else {
              System.out.println(
                  "boost_c can use one of <stump, dt> as values. See help for more details");
              return null;
            }
          }
          if (commandLine.hasOption(boostNrIterations)) {
            Integer nrIt = getInt(commandLine.getOptionValue(boostNrIterations));
            if (nrIt == null) {
              System.out.println(
                  "Please provide an integer for boost_nr_it. See help for more details");
              return null;
            }
            boost.setNumIterations(nrIt);
          }
          cls = boost;
          break;
        case KNN:
          IBk ibk = new IBk();
          if (commandLine.hasOption(knnK)) {
            Integer k = getInt(commandLine.getOptionValue(knnK));
            if (k == null) {
              System.out.println("Please provide an integer for knn_n. For more details see help");
              return null;
            }
            ibk.setKNN(k);
          }
          if (commandLine.hasOption(knnWeightDistance)) {
            Integer knnwd = getInt(commandLine.getOptionValue(knnWeightDistance));
            if (knnwd == null) {
              System.out.println(
                  "Please provide one of 1,2,4 for knn_w_d. For more details see help");
              return null;
            }
            // Reject only values that are none of the allowed ids. The previous "||" chain
            // was true for every integer and therefore rejected all input.
            if (1 != knnwd && 2 != knnwd && 4 != knnwd) {
              System.out.println(
                  "Please provide one of 1,2,4 for knn_w_d. See help for more details");
              return null;
            }
            ibk.setDistanceWeighting(new SelectedTag(knnwd, IBk.TAGS_WEIGHTING));
          }
          cls = ibk;
          break;
        case NN:
          MultilayerPerceptron nn = new MultilayerPerceptron();
          if (commandLine.hasOption(nnLearningRate)) {
            Double nnLR = getDouble(commandLine.getOptionValue(nnLearningRate));
            if (nnLR == null) {
              System.out.println(
                  "Please provide a double for NN learning rate. See help for more details");
              return null;
            }
            nn.setLearningRate(nnLR);
          }
          if (commandLine.hasOption(nnMomentum)) {
            Double nnMR = getDouble(commandLine.getOptionValue(nnMomentum));
            if (nnMR == null) {
              System.out.println(
                  "Please provide a double for NN momentum rate. See help for more details");
              return null;
            }
            nn.setMomentum(nnMR);
          }
          if (commandLine.hasOption(nnHiddentUnits)) {
            // Passed through verbatim; the string is interpreted by Weka.
            String nnHU = commandLine.getOptionValue(nnHiddentUnits);
            nn.setHiddenLayers(nnHU);
          }

          cls = nn;
          break;
        case SVM:
          SMO svm = new SMO();
          if (commandLine.hasOption(svmKernelFunction)) {
            String svmkf = commandLine.getOptionValue(svmKernelFunction);
            Kernel kernel = null;
            if ("poly".equalsIgnoreCase(svmkf)) {
              PolyKernel pk = new PolyKernel();
              if (commandLine.hasOption(svmPolyExp)) {
                Double expValue = getDouble(commandLine.getOptionValue(svmPolyExp));
                if (expValue == null) {
                  System.out.println(
                      "Please provide a double value for svm_poly_exp. See help for more details");
                  return null;
                }
                pk.setExponent(expValue);
              }
              kernel = pk;
            } else if ("radial".equalsIgnoreCase(svmkf)) {
              RBFKernel rbfk = new RBFKernel();
              if (commandLine.hasOption(svmRadialGamma)) {
                Double gamma = getDouble(commandLine.getOptionValue(svmRadialGamma));
                if (gamma == null) {
                  System.out.println(
                      "Please provide a double value for svm_radial_gamma. See help for more details");
                  return null;
                }
                rbfk.setGamma(gamma);
              }
              kernel = rbfk;
            } else {
              System.out.println(
                  "Please provide one of <poly, radial> for svm_kernel_fct. See help for more details");
              return null;
            }
            svm.setKernel(kernel);
          } else {
            // No kernel function given: only the polynomial exponent is honoured here;
            // a lone svm_radial_gamma is ignored.
            if (commandLine.hasOption(svmPolyExp)) {
              PolyKernel polyKernel = new PolyKernel();
              Double expValue = getDouble(commandLine.getOptionValue(svmPolyExp));
              if (expValue == null) {
                System.out.println(
                    "Please provide a double value for svm_poly_exp. For more details, see help");
                return null;
              }
              polyKernel.setExponent(expValue);
              svm.setKernel(polyKernel);
            }
          }
          cls = svm;
          break;
      }
    } else {
      System.out.println("Please provide a classifier. See help for more details");
      return null;
    }

    return new ClassifierContext(f, trainP, cls, cv, kfolds);
  }
    /**
     * Trains and evaluates three classifiers on {@code cross_validation_data.arff} and pushes
     * the results to the UI.
     *
     * <p>Steps, in order:
     *
     * <ol>
     *   <li>Load the ARFF file, use attribute index 13 as the class and shuffle with seed 1.
     *   <li>Train a {@code NaiveBayes} model, save it to {@code naivebayes.model} and evaluate
     *       it with 10-fold cross-validation.
     *   <li>Train a {@code MultilayerPerceptron}, save it to {@code mlp.model} and evaluate it
     *       on the training data.
     *   <li>Train an {@code SMO} model, save it to {@code svm.model} and evaluate it on the
     *       training data.
     *   <li>Update the bar chart, line chart, confusion table and summary table.
     * </ol>
     *
     * <p>Progress messages are emitted through {@code publish}. Any exception is printed and
     * the method still returns normally.
     *
     * @return always {@code null}
     */
    @Override
    public Void doInBackground() {
      try {
        publish("Reading data...");
        Instances trainingdata;
        // try-with-resources so the file handle is released even if parsing fails.
        try (BufferedReader reader =
            new BufferedReader(new FileReader("cross_validation_data.arff"))) {
          trainingdata = new Instances(reader);
        }
        // setting class attribute
        trainingdata.setClassIndex(13);
        trainingdata.randomize(new Random(1));

        publish("Training Naive Bayes Classifier...");

        NaiveBayes nb = new NaiveBayes();
        long startTime = System.nanoTime();
        nb.buildClassifier(trainingdata);
        // Elapsed times are kept as fractional seconds; integer division would report 0 for
        // any model that trains in under a second.
        double runningTimeNB = (System.nanoTime() - startTime) / 1e9;
        // saving the naive bayes model
        weka.core.SerializationHelper.write("naivebayes.model", nb);
        System.out.println("running time" + runningTimeNB);
        publish("Done training NB.\nEvaluating NB using 10-fold cross-validation...");
        evalNB = new Evaluation(trainingdata);
        evalNB.crossValidateModel(nb, trainingdata, 10, new Random(1));
        publish("Done evaluating NB.");

        MultilayerPerceptron mlp = new MultilayerPerceptron();
        mlp.setOptions(Utils.splitOptions("-L 0.3 -M 0.2 -N 500 -V 0 -S 0 -E 20 -H a"));
        publish("Training ANN...");
        startTime = System.nanoTime();
        mlp.buildClassifier(trainingdata);
        double runningTimeANN = (System.nanoTime() - startTime) / 1e9;
        // saving the MLP model
        weka.core.SerializationHelper.write("mlp.model", mlp);

        publish("Done training ANN.\nEvaluating ANN using 10-fold cross-validation...");

        // NOTE(review): despite the progress message, the ANN is scored on the training data
        // itself (evaluateModel), not by cross-validation.
        evalANN = new Evaluation(trainingdata);
        evalANN.evaluateModel(mlp, trainingdata);

        publish("Done evaluating ANN.");
        publish("Training SVM...");
        SMO svm = new SMO();

        startTime = System.nanoTime();
        svm.buildClassifier(trainingdata);
        double runningTimeSVM = (System.nanoTime() - startTime) / 1e9;
        weka.core.SerializationHelper.write("svm.model", svm);
        publish("Done training SVM.\nEvaluating SVM using 10-fold cross-validation...");
        // NOTE(review): same as for the ANN - this is a training-set evaluation.
        evalSVM = new Evaluation(trainingdata);
        evalSVM.evaluateModel(svm, trainingdata);
        publish("Done evaluating SVM.");

        // Percentage of correctly classified instances per model; computed once and shared
        // by the chart update and the summary table.
        final double accuracyANN = evalANN.correct() / trainingdata.size() * 100;
        final double accuracySVM = evalSVM.correct() / trainingdata.size() * 100;
        final double accuracyNB = evalNB.correct() / trainingdata.size() * 100;

        // Chart updates are handed to the JavaFX application thread.
        Platform.runLater(
            new Runnable() {
              @Override
              public void run() {
                bc.getData().get(0).getData().get(0).setYValue(accuracyANN);
                bc.getData().get(0).getData().get(1).setYValue(accuracySVM);
                bc.getData().get(0).getData().get(2).setYValue(accuracyNB);

                // One point per class: series 0 = ANN, 1 = SVM, 2 = NB.
                for (int i = 0; i < NUM_CLASSES; i++) {
                  lineChart.getData().get(0).getData().get(i).setYValue(evalANN.recall(i) * 100);
                  lineChart.getData().get(1).getData().get(i).setYValue(evalSVM.recall(i) * 100);
                  lineChart.getData().get(2).getData().get(i).setYValue(evalNB.recall(i) * 100);
                }
              }
            });

        // NOTE(review): panel and summaryTable are updated directly from this background
        // method; if they are Swing components this should happen on the EDT - confirm.
        panel.fillConfTable(evalSVM.confusionMatrix());

        // Row 0: accuracy in percent; row 1: training time in seconds.
        summaryTable.setValueAt(accuracyANN, 0, 1);
        summaryTable.setValueAt(accuracySVM, 0, 2);
        summaryTable.setValueAt(accuracyNB, 0, 3);

        summaryTable.setValueAt(runningTimeANN, 1, 1);
        summaryTable.setValueAt(runningTimeSVM, 1, 2);
        summaryTable.setValueAt(runningTimeNB, 1, 3);

      } catch (Exception e1) {
        // Best effort: report the failure and let the worker finish normally.
        e1.printStackTrace();
      }
      return null;
    }
Ejemplo n.º 4
0
  /**
   * Pre-processes the census-income data, trains several classifiers on the training set and
   * reports the accuracy of their majority vote on the test set.
   *
   * <p>Six classifiers vote per test instance (decision tree, AdaBoost, naive Bayes, bagging,
   * SMO and k-NN). An instance is predicted as {@code ">50K"} only when strictly more than
   * half of the voters say so; ties fall back to {@code "<=50K"}. The number of matches and
   * mismatches, plus accuracy and error percentages, are printed to standard output.
   *
   * @param args the command line arguments (unused)
   * @throws Exception if pre-processing, file access, training or classification fails
   */
  public static void main(String[] args) throws Exception {
    PreProcessor p = new PreProcessor("census-income.data", "census-income-preprocessed.arff");
    p.smote();

    PreProcessor p_test =
        new PreProcessor("census-income.test", "census-income-test-preprocessed.arff");
    p_test.run();

    final Instances traininstance;
    final Instances testinstance;
    // try-with-resources closes both files even when parsing one of them fails.
    try (BufferedReader traindata =
            new BufferedReader(new FileReader("census-income-preprocessed.arff"));
        BufferedReader testdata =
            new BufferedReader(new FileReader("census-income-test-preprocessed.arff"))) {
      traininstance = new Instances(traindata);
      testinstance = new Instances(testdata);
    }

    // The class label is the last attribute in both sets.
    traininstance.setClassIndex(traininstance.numAttributes() - 1);
    testinstance.setClassIndex(testinstance.numAttributes() - 1);
    int numOfAttributes = testinstance.numAttributes();
    int numOfInstances = testinstance.numInstances();

    NaiveBayesClassifier nb = new NaiveBayesClassifier("census-income-preprocessed.arff");
    Classifier cnaive = nb.NBClassify();

    DecisionTree dt = new DecisionTree("census-income-preprocessed.arff");
    Classifier cls = dt.DTClassify();

    AdaBoost ab = new AdaBoost("census-income-preprocessed.arff");
    AdaBoostM1 m1 = ab.AdaBoostDTClassify();

    BaggingMethod b = new BaggingMethod("census-income-preprocessed.arff");
    Bagging bag = b.BaggingDTClassify();

    SVM s = new SVM("census-income-preprocessed.arff");
    SMO svm = s.SMOClassifier();

    knn knnclass = new knn("census-income-preprocessed.arff");
    IBk knnc = knnclass.knnclassifier();

    // NOTE(review): this logistic model is trained but never takes part in the vote below.
    // Either add it as a seventh voter (which would also remove ties) or drop it - confirm
    // the intent before changing behaviour.
    Logistic log = new Logistic();
    log.buildClassifier(traininstance);

    // Voters in the same order the original code queried them.
    Classifier[] voters = {cls, m1, cnaive, bag, svm, knnc};

    int match = 0;
    int error = 0;

    for (int i = 0; i < numOfInstances; i++) {
      int greater = 0;
      int less = 0;

      for (Classifier voter : voters) {
        double prediction = voter.classifyInstance(testinstance.instance(i));
        // Map the predicted class index back to its label.
        String label = testinstance.instance(i).classAttribute().value((int) prediction);
        if (">50K".equals(label)) {
          greater++;
        } else {
          less++;
        }
      }

      // Strict majority is required for ">50K"; a tie counts as "<=50K".
      String predicted = greater > less ? ">50K" : "<=50K";

      if (testinstance.instance(i).stringValue(numOfAttributes - 1).equals(predicted)) {
        match++;
      } else {
        error++;
      }
    }

    System.out.println("Correctly classified Instances: " + match);
    System.out.println("Misclassified Instances: " + error);

    double accuracy = (double) match / (double) numOfInstances * 100;
    double error_percent = 100 - accuracy;
    System.out.println("Accuracy: " + accuracy + "%");
    System.out.println("Error: " + error_percent + "%");
  }