/**
   * Parses a given list of options. Options understood by the superclass are consumed first, then
   * the {@code -naive} flag; whatever remains configures the wrapped IBk. Afterwards IBk's
   * cross-validation of K is forced on and its window size / mean-squared settings are reset,
   * regardless of what the options requested.
   *
   * <p>
   * <!-- options-start -->
   * Valid options are:
   *
   * <p>
   *
   * <pre> -folds &lt;folds&gt;
   *  The number of folds for splitting the training set into
   *  train and test set. The first fold is always the training
   *  set. With '-V' you can invert this, i.e., instead of 20/80
   *  for 5 folds you'll get 80/20.
   *  (default 5)</pre>
   *
   * <pre> -V
   *  Inverts the fold selection, i.e., instead of using the first
   *  fold for the training set it is used for test set and the
   *  remaining folds for training.</pre>
   *
   * <pre> -verbose
   *  Whether to print some more information during building the
   *  classifier.
   *  (default is off)</pre>
   *
   * <pre> -insight
   *  Whether to use the labels of the original test set for more
   *  statistics (not used for learning!).
   *  (default is off)</pre>
   *
   * <pre> -S &lt;num&gt;
   *  Random number seed.
   *  (default 1)</pre>
   *
   * <pre> -D
   *  If set, classifier is run in debug mode and
   *  may output additional info to the console</pre>
   *
   * <pre> -naive
   *  Uses a sorted list (ordered according to distance) instead of the
   *  KDTree for finding the neighbors.
   *  (default is KDTree)</pre>
   *
   * <pre> -I
   *  Weight neighbours by the inverse of their distance
   *  (use when k &gt; 1)</pre>
   *
   * <pre> -F
   *  Weight neighbours by 1 - their distance
   *  (use when k &gt; 1)</pre>
   *
   * <pre> -K &lt;number of neighbors&gt;
   *  Number of nearest neighbours (k) used in classification.
   *  (Default = 1)</pre>
   *
   * <pre> -A
   *  The nearest neighbour search algorithm to use (default: LinearNN).
   * </pre>
   *
   * <!-- options-end -->
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {
    super.setOptions(options);

    // Utils.getFlag removes "-naive" from the array, so IBk never receives it
    boolean naive = Utils.getFlag("naive", options);
    setUseNaiveSearch(naive);

    // everything still left in the array configures the wrapped IBk
    m_Classifier.setOptions(options);

    // keep the requested K; initialize() hands it back to IBk before K is cross-validated
    m_KNN = m_Classifier.getKNN();

    // IBk settings this classifier enforces no matter what the options said
    m_Classifier.setCrossValidate(true);
    m_Classifier.setWindowSize(0);
    m_Classifier.setMeanSquared(false);
  }
  /**
   * Describes every option this classifier accepts: the superclass options, the {@code -naive}
   * switch, and the options of the wrapped IBk except -X, -W and -E (cross-validation, window size
   * and mean-squared are fixed by {@code setOptions}).
   *
   * @return an enumeration of all the available options.
   */
  public Enumeration listOptions() {
    Vector available = new Vector();

    // options inherited from the superclass
    for (Enumeration inherited = super.listOptions(); inherited.hasMoreElements(); ) {
      available.addElement(inherited.nextElement());
    }

    available.addElement(
        new Option(
            "\tUses a sorted list (ordered according to distance) instead of the\n"
                + "\tKDTree for finding the neighbors.\n"
                + "\t(default is KDTree)",
            "naive",
            0,
            "-naive"));

    // options of the wrapped IBk, minus the ones this class controls itself
    for (Enumeration knnOptions = m_Classifier.listOptions(); knnOptions.hasMoreElements(); ) {
      Option candidate = (Option) knnOptions.nextElement();
      String name = candidate.name();
      boolean controlledHere = name.equals("X") || name.equals("W") || name.equals("E");
      if (!controlledHere) {
        available.addElement(candidate);
      }
    }

    return available.elements();
  }
  /**
   * Determines "K" for the neighbours on the training set, sets the class labels of the test set
   * to "missing" (in place, on {@code m_TestsetNew}) and pre-computes the neighbours of every test
   * instance.
   *
   * <p>K is chosen by running IBk with cross-validation switched on, starting from the
   * user-supplied {@code m_KNN}; the chosen value ends up in {@code m_KNNdetermined}.
   *
   * @throws Exception if initialization fails
   */
  protected void initialize() throws Exception {
    int i;
    double timeStart;
    double timeEnd;
    Instances trainingNew;
    Instances testNew;

    // determine K
    if (getVerbose()) System.out.println("\nOriginal KNN = " + m_KNN);
    // hand the user-supplied K back to IBk (presumably a previous run left the selected K there)
    ((IBk) m_Classifier).setKNN(m_KNN);
    ((IBk) m_Classifier).setCrossValidate(true);
    m_Classifier.buildClassifier(m_TrainsetNew);
    // Called only for its side effect. NOTE(review): IBk appears to run the K selection lazily and
    // toString() triggers it -- confirm against the IBk version in use before removing this call.
    m_Classifier.toString(); // necessary to crossvalidate IBk!
    ((IBk) m_Classifier).setCrossValidate(false);
    // after the cross-validation IBk reports the selected K
    m_KNNdetermined = ((IBk) m_Classifier).getKNN();
    if (getVerbose()) System.out.println("Determined KNN = " + m_KNNdetermined);

    // set class labels in test set to "missing"
    for (i = 0; i < m_TestsetNew.numInstances(); i++) m_TestsetNew.instance(i).setClassMissing();

    // copy data, so the filtering below does not touch m_TrainsetNew / m_TestsetNew
    trainingNew = new Instances(m_TrainsetNew);
    testNew = new Instances(m_TestsetNew);

    // filter data: ReplaceMissingValues is set up on the training copy, then the same filter
    // instance is applied to both copies
    m_Missing.setInputFormat(trainingNew);
    trainingNew = Filter.useFilter(trainingNew, m_Missing);
    testNew = Filter.useFilter(testNew, m_Missing);

    // create the list of neighbors for the instances in the test set; each Neighbors object gets
    // the filtered test instance, its unfiltered counterpart, the determined K and both data sets
    m_NeighborsTestset = new Neighbors[m_TestsetNew.numInstances()];
    timeStart = System.currentTimeMillis();
    for (i = 0; i < testNew.numInstances(); i++) {
      m_NeighborsTestset[i] =
          new Neighbors(
              testNew.instance(i), m_TestsetNew.instance(i), m_KNNdetermined, trainingNew, testNew);
      m_NeighborsTestset[i].setVerbose(getVerbose() || getDebug());
      m_NeighborsTestset[i].setUseNaiveSearch(getUseNaiveSearch());
      m_NeighborsTestset[i].find();
    }
    timeEnd = System.currentTimeMillis();

    // wall-clock duration of the neighbour search, printed in seconds
    if (getVerbose())
      System.out.println(
          "Time for finding neighbors: " + Utils.doubleToString((timeEnd - timeStart) / 1000.0, 3));
  }
  /**
   * Renders the best model as text: the description produced by the superclass, a line break, and
   * then the wrapped IBk's own {@code toString()}. Subclasses add further details here.
   *
   * @return the string representation of the best model
   */
  protected String toStringModel() {
    // start empty so that a null from the superclass is rendered as "null", not rejected
    StringBuilder model = new StringBuilder();
    model.append(super.toStringModel()).append("\n").append(m_Classifier.toString());
    return model.toString();
  }
  /**
   * Resets all members to their initial state. The wrapped IBk is created with K = 10,
   * cross-validation on, window size 0 and mean-squared off; its K is also remembered in
   * {@code m_KNN}, and "measureDeterminedKNN" is registered as an additional measure.
   */
  protected void initializeMembers() {
    super.initializeMembers();

    m_Missing = new ReplaceMissingValues();
    m_LabeledTestset = null;
    m_UseNaiveSearch = false;
    m_TestsetNew = null;
    m_TrainsetNew = null;
    m_NeighborsTestset = null;
    m_KNNdetermined = -1;

    IBk knn = new IBk();
    knn.setKNN(10);
    knn.setCrossValidate(true);
    knn.setWindowSize(0);
    knn.setMeanSquared(false);
    m_Classifier = knn;
    m_KNN = knn.getKNN();

    m_AdditionalMeasures.add("measureDeterminedKNN");
  }
示例#6
0
  /**
   * Computes the class distribution for one instance by looking it up in the decision table.
   *
   * <p>The instance goes through the discretisation filter and then the attribute-removal filter
   * set up during training. If the reduced instance matches a table entry, that entry is returned
   * (normalised for a nominal class, or as a single ratio value for a numeric class). Otherwise
   * the IBk fallback is consulted when enabled, else the class priors (nominal) or the majority
   * value (numeric) are returned.
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @throws Exception if distribution can't be computed
   */
  public double[] distributionForInstance(Instance instance) throws Exception {
    m_disTransform.input(instance);
    m_disTransform.batchFinished();
    Instance discretized = m_disTransform.output();

    m_delTransform.input(discretized);
    m_delTransform.batchFinished();
    Instance reduced = m_delTransform.output();

    DecisionTableHashKey key =
        new DecisionTableHashKey(reduced, reduced.numAttributes(), false);
    double[] stored = (double[]) m_entries.get(key);

    if (stored != null) {
      if (m_classIsNominal) {
        // work on a copy so the table entry itself stays unnormalised
        double[] normalized = stored.clone();
        Utils.normalize(normalized);
        return normalized;
      }
      // numeric class: entry presumably holds (sum, weight) -- verify against insertIntoTable
      return new double[] {stored[0] / stored[1]};
    }

    // no matching table entry
    if (m_useIBk) {
      return m_ibk.distributionForInstance(reduced);
    }
    if (m_classIsNominal) {
      return m_classPriors.clone();
    }
    return new double[] {m_majority};
  }
  /**
   * Collects the current settings: the superclass options, followed by the wrapped IBk's options,
   * followed by {@code -naive} when the naive neighbour search is enabled.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  public String[] getOptions() {
    Vector<String> collected = new Vector<String>();

    for (String option : super.getOptions()) {
      collected.add(option);
    }
    for (String option : m_Classifier.getOptions()) {
      collected.add(option);
    }
    if (getUseNaiveSearch()) {
      collected.add("-naive");
    }

    return collected.toArray(new String[collected.size()]);
  }
 /**
  * Chooses the nearest-neighbour search strategy of the wrapped IBk.
  *
  * @param value the NearestNeighbourSearch instance IBk should use
  */
 public void setNearestNeighbourSearchAlgorithm(NearestNeighbourSearch value) {
   this.m_Classifier.setNearestNeighbourSearchAlgorithm(value);
 }
 /**
  * Reports the nearest-neighbour search strategy the wrapped IBk is configured with.
  *
  * @return the NearestNeighbourSearch currently held by IBk
  */
 public NearestNeighbourSearch getNearestNeighbourSearchAlgorithm() {
   NearestNeighbourSearch current = this.m_Classifier.getNearestNeighbourSearchAlgorithm();
   return current;
 }
 /**
  * Supplies the GUI tip text for the search-algorithm property, taken verbatim from IBk.
  *
  * @return tip text for this property suitable for displaying in the explorer/experimenter gui
  */
 public String nearestNeighbourSearchAlgorithmTipText() {
   String tip = this.m_Classifier.nearestNeighbourSearchAlgorithmTipText();
   return tip;
 }
 /**
  * Supplies the GUI tip text for the distance-weighting property, taken verbatim from IBk.
  *
  * @return tip text for this property suitable for displaying in the explorer/experimenter gui
  */
 public String distanceWeightingTipText() {
   String tip = this.m_Classifier.distanceWeightingTipText();
   return tip;
 }
 /**
  * Returns the K currently configured on the wrapped IBk.
  *
  * @return the number of neighbours.
  */
 public int getKNN() {
   int neighbours = this.m_Classifier.getKNN();
   return neighbours;
 }
 /**
  * Configures the K used by the wrapped IBk.
  *
  * @param k the number of neighbours.
  */
 public void setKNN(int k) {
   this.m_Classifier.setKNN(k);
 }
 /**
  * Supplies the GUI tip text for the KNN property, taken verbatim from IBk.
  *
  * @return tip text for this property suitable for displaying in the explorer/experimenter gui
  */
 public String KNNTipText() {
   String tip = this.m_Classifier.KNNTipText();
   return tip;
 }
  /**
   * Takes a question and returns the category it belongs to, using an IBk (k = 3) model.
   *
   * <p>Each labelled question is turned into 10 numeric features, one per topic attribute. Each
   * feature holds the share of the question's hit "white words" that {@code getAttIndex} maps to
   * that attribute. The 11th attribute is the nominal class "1".."10". If {@code knn.model} does
   * not exist, the model is trained on all labelled questions from {@code questionDAO} and written
   * to that file; otherwise the serialised model is loaded and no training data is built.
   *
   * @param question the raw question text
   * @return the predicted category, 1-based
   * @throws Exception if data access, training, (de)serialisation or classification fails
   */
  public double classifyByKnn(String question) throws Exception {
    List<Question> questionID = questionDAO.getQuestionIDLabeled();

    Instances dataset = createQuestionDataset();

    Classifier classifier;
    if (!new File("knn.model").exists()) {
      // add the training data
      for (int i = 0; i < questionID.size(); i++) {
        Question labelled = questionID.get(i);
        double[] values = toTrainingValues(labelled);
        System.out.println("第" + i + "个问题。");
        System.out.println(labelled.getQuestion());
        dataset.add(new Instance(1.0, values));
      }
      // build the classifier and persist it for later calls
      IBk ibk = new IBk();
      ibk.setKNN(3);
      classifier = ibk;
      classifier.buildClassifier(dataset);
      SerializationHelper.write("knn.model", classifier);
    } else {
      classifier = (Classifier) SerializationHelper.read("knn.model");
      System.out.println("串行化解析。");
    }

    // A loaded model leaves the dataset empty, and evaluating on zero instances only prints a
    // meaningless summary, so evaluate only when training data was actually built.
    if (dataset.numInstances() > 0) {
      System.out.println("*************begin evaluation*******************");
      // NOTE: this evaluates on the training data itself, so the figures are optimistic; a
      // separate held-out set should be used for a real estimate.
      Evaluation evaluation = new Evaluation(dataset);
      evaluation.evaluateModel(classifier, dataset);
      System.out.println(evaluation.toSummaryString());
    }

    // classify the incoming question
    System.out.println("*************begin classification*******************");
    Instance subject = new Instance(1.0, getQuestionVector(question));
    subject.setDataset(dataset);
    double label = classifier.classifyInstance(subject);
    System.out.println("label: " + label);

    System.out.println(questionID.size());
    // class values are 0-based indices into "1".."10"
    return label + 1;
  }

  /** Builds the empty header: ten numeric topic attributes plus the nominal class "1".."10". */
  private static Instances createQuestionDataset() {
    String[] topics = {
      "法律政策", "位置交通", "风水", "房价", "楼层", "户型", "小区配套", "贷款", "买房时机", "开发商"
    };
    FastVector attributes = new FastVector();
    for (String topic : topics) {
      attributes.addElement(new Attribute(topic));
    }
    FastVector labels = new FastVector();
    for (int k = 1; k <= 10; k++) {
      labels.addElement(String.valueOf(k));
    }
    attributes.addElement(new Attribute("类别", labels));

    Instances dataset = new Instances("Test-dataset", attributes, 0);
    dataset.setClassIndex(10);
    return dataset;
  }

  /** Converts one labelled question into its 11 values: 10 topic shares plus the 0-based class. */
  private double[] toTrainingValues(Question labelled) throws Exception {
    double[] values = new double[11];
    int whitewordcount = questionDAO.getHitWhiteWordNum(labelled.getId());
    if (whitewordcount != 0) {
      List<QuestionWhiteWord> questionwhiteword =
          questionDAO.getHitQuestionWhiteWord(labelled.getId());
      for (QuestionWhiteWord word : questionwhiteword) {
        values[getAttIndex(word.getWordId()) - 1]++;
      }
      // normalise only the feature slots; the class slot is assigned below
      for (int m = 0; m < 10; m++) {
        values[m] = values[m] / whitewordcount;
        System.out.println(m + "<>" + values[m]);
      }
    }
    values[10] = labelled.getType() - 1;
    return values;
  }
示例#16
0
  /**
   * Runs the census-income experiment: preprocesses the training and test files, trains six
   * classifiers through the project's wrapper classes and scores their unweighted majority vote on
   * the test set, printing the number of correct/misclassified instances, accuracy and error.
   *
   * <p>The voters are DecisionTree, AdaBoostM1, naive Bayes, Bagging, SMO and IBk. An instance is
   * predicted as "&gt;50K" only if strictly more voters say so than not, so a 3-3 tie yields
   * "&lt;=50K".
   *
   * @param args the command line arguments (not used)
   * @throws Exception if preprocessing, reading the ARFF file, training or classification fails
   */
  public static void main(String[] args) throws Exception {
    PreProcessor p = new PreProcessor("census-income.data", "census-income-preprocessed.arff");
    p.smote();

    PreProcessor p_test =
        new PreProcessor("census-income.test", "census-income-test-preprocessed.arff");
    p_test.run();

    // Only the test set is loaded here; each wrapper below is given the training ARFF path.
    BufferedReader testdata =
        new BufferedReader(new FileReader("census-income-test-preprocessed.arff"));
    Instances testinstance;
    try {
      testinstance = new Instances(testdata);
    } finally {
      testdata.close(); // close even if the ARFF cannot be parsed
    }
    int classIndex = testinstance.numAttributes() - 1;
    testinstance.setClassIndex(classIndex);
    int numOfInstances = testinstance.numInstances();

    NaiveBayesClassifier nb = new NaiveBayesClassifier("census-income-preprocessed.arff");
    Classifier cnaive = nb.NBClassify();

    DecisionTree dt = new DecisionTree("census-income-preprocessed.arff");
    Classifier cls = dt.DTClassify();

    AdaBoost ab = new AdaBoost("census-income-preprocessed.arff");
    AdaBoostM1 m1 = ab.AdaBoostDTClassify();

    BaggingMethod b = new BaggingMethod("census-income-preprocessed.arff");
    Bagging bag = b.BaggingDTClassify();

    SVM s = new SVM("census-income-preprocessed.arff");
    SMO svm = s.SMOClassifier();

    knn knnclass = new knn("census-income-preprocessed.arff");
    IBk knnc = knnclass.knnclassifier();

    int match = 0;
    int error = 0;

    for (int i = 0; i < numOfInstances; i++) {
      double[] predictions = {
        cls.classifyInstance(testinstance.instance(i)),
        m1.classifyInstance(testinstance.instance(i)),
        cnaive.classifyInstance(testinstance.instance(i)),
        bag.classifyInstance(testinstance.instance(i)),
        svm.classifyInstance(testinstance.instance(i)),
        knnc.classifyInstance(testinstance.instance(i))
      };

      int greater = 0;
      for (double prediction : predictions) {
        if (">50K".equals(testinstance.classAttribute().value((int) prediction))) {
          greater++;
        }
      }
      int less = predictions.length - greater;
      String predicted = (greater > less) ? ">50K" : "<=50K";

      if (predicted.equals(testinstance.instance(i).stringValue(classIndex))) {
        match++;
      } else {
        error++;
      }
    }

    System.out.println("Correctly classified Instances: " + match);
    System.out.println("Misclassified Instances: " + error);

    double accuracy = (double) match / (double) numOfInstances * 100;
    double error_percent = 100 - accuracy;
    System.out.println("Accuracy: " + accuracy + "%");
    System.out.println("Error: " + error_percent + "%");
  }
 /**
  * Reports how the wrapped IBk weights neighbours by distance: one of WEIGHT_NONE,
  * WEIGHT_INVERSE or WEIGHT_SIMILARITY.
  *
  * @return the distance weighting method used.
  * @see IBk#WEIGHT_NONE
  * @see IBk#WEIGHT_INVERSE
  * @see IBk#WEIGHT_SIMILARITY
  */
 public SelectedTag getDistanceWeighting() {
   SelectedTag weighting = this.m_Classifier.getDistanceWeighting();
   return weighting;
 }
 /**
  * Passes the distance weighting method on to the wrapped IBk. According to IBk's contract,
  * values other than WEIGHT_NONE, WEIGHT_INVERSE or WEIGHT_SIMILARITY are ignored.
  *
  * @param newMethod the distance weighting method to use
  * @see IBk#WEIGHT_NONE
  * @see IBk#WEIGHT_INVERSE
  * @see IBk#WEIGHT_SIMILARITY
  */
 public void setDistanceWeighting(SelectedTag newMethod) {
   this.m_Classifier.setDistanceWeighting(newMethod);
 }
示例#19
0
  /**
   * Parses the command line into a {@link ClassifierContext}.
   *
   * <p>When the help option is present the help text is printed and {@code null} is returned.
   * Every validation failure also prints a message and returns {@code null}, so callers must treat
   * {@code null} as "do not run".
   *
   * @param args the raw command-line arguments
   * @return the parsed context, or {@code null} if help was requested or an argument was invalid
   * @throws ParseException if commons-cli cannot parse the arguments
   */
  private static ClassifierContext createContext(String[] args) throws ParseException {
    Options options =
        new Options()
            .addOption(new Option(help, "help", false, "show help"))
            .addOption(new Option(file, "file", true, "file path containing the data set"))
            .addOption(
                new Option(
                    trainPercentage,
                    "train-percentage",
                    true,
                    "training set in percentage, rest is test set; type double; default value training set: 0.7, test set 0.3"))
            .addOption(
                new Option(
                    crossValidation,
                    "cross-validation",
                    true,
                    "use cross-validation or not;type boolean; default value false"))
            .addOption(
                new Option(
                    kFolds,
                    "kfolds",
                    true,
                    "#folds used in cross-validation; type integer, default 5"))
            .addOption(
                new Option(
                    classifier,
                    "classifier",
                    true,
                    "classifier to use <dt, knn, boost, nn, svm> : Decision Tree(dt), Nearest Neighbours(knn), Boosting(boost), Neural Networks(nn), Support Vector Machine(svm)"))
            .addOption(
                new Option(
                    dtPruning,
                    "pruning",
                    true,
                    "dt: use pruning; type boolean, default value true>"))
            .addOption(
                new Option(
                    dtCf,
                    "confidence_factor",
                    true,
                    "dt: set confidence-factor, used for pruning; type double, default value 0.25"))
            .addOption(
                new Option(
                    boostC,
                    "boost_classifier",
                    true,
                    "boost: set the classifier learner; one of <stump|dt>, default stump"))
            .addOption(
                new Option(
                    boostNrIterations,
                    "boost_nr_iterations",
                    true,
                    "boost: set the nr of boosting iterations; type integer, default value 10>"))
            .addOption(
                new Option(
                    boostDtPruning,
                    "boost_dt_pruning",
                    true,
                    "boost: use pruning for dt; type boolean, default value true>"))
            .addOption(
                new Option(
                    boostDtCf,
                    "boost_dt_cf",
                    true,
                    "boost: use dt confidence-factor, used for pruning; type double, default value 0.25"))
            .addOption(
                new Option(
                    knnK,
                    "knn_k",
                    true,
                    "knn: specify the #nearest neighbours; type integer, default value 1"))
            .addOption(
                new Option(
                    knnWeightDistance,
                    "knn_weight_distance",
                    true,
                    "knn: specify a weight distance; type integer <1=None,2=Inverse,4=Similarity>, default value 1"))
            .addOption(
                new Option(
                    nnLearningRate,
                    "nn_learning_rate",
                    true,
                    "nn: backpropagation learning rate; type double, default value 0.3"))
            .addOption(
                new Option(
                    nnMomentum,
                    "nn_momentum",
                    true,
                    "nn: backpropagation momentum rate; type double, default value 0.2"))
            .addOption(
                new Option(
                    nnHiddentUnits,
                    "nn_hidden_units",
                    true,
                    "nn: comma-separated string for #hidden layers and nodes per layer. e.g. \"a,3,4\"; see weka for more details"))
            .addOption(
                new Option(
                    svmKernelFunction,
                    "svm_kernel_function",
                    true,
                    "svm: one of: <poly,radial>, default value poly"))
            .addOption(
                new Option(
                    svmPolyExp,
                    "svm_poly_exponent",
                    true,
                    "svm: the exponent value of a polynomial kernel; type double, default value 1.0"))
            .addOption(
                new Option(
                    svmRadialGamma,
                    "svm_radial_gamma",
                    true,
                    "svm: the gamma parameter value of the radial kernel; type double, default value 0.01"));

    String f = null;
    Double trainP = 0.7;
    Boolean cv = false;
    Integer kfolds = 5;
    ClassifierTypes classifierType = null;
    Classifier cls = null;

    CommandLine commandLine = new DefaultParser().parse(options, args);
    if (commandLine.hasOption(help)) {
      new HelpFormatter().printHelp("Run Classifiers", options);
      return null;
    }

    if (commandLine.hasOption(file)) {
      f = commandLine.getOptionValue(file);
    } else {
      System.out.println("Please provide data set file. See help for more details");
      return null;
    }

    if (commandLine.hasOption(trainPercentage)) {
      String trainp = commandLine.getOptionValue(trainPercentage);
      trainP = getDouble(trainp);
      if (trainP == null || trainP >= 1 || trainP <= 0) {
        System.out.println(
            "Please provide an double between (0,1) range for training set pct. See help for more details");
        return null;
      }
    }

    if (commandLine.hasOption(crossValidation)) {
      cv = Boolean.valueOf(commandLine.getOptionValue(crossValidation));
      if (cv) {
        if (commandLine.hasOption(kFolds)) {
          kfolds = getInt(commandLine.getOptionValue(kFolds));
          // cross-validation needs at least two folds
          if (kfolds == null || kfolds < 2) {
            System.out.println(
                "Please provide an integer for kfolds e.g. 10. See help for more details");
            return null;
          }
        }
      }
    }

    if (commandLine.hasOption(classifier)) {
      String c = commandLine.getOptionValue(classifier);
      classifierType = ClassifierTypes.toEnum(c);
      // switching on null would throw a NullPointerException
      if (classifierType == null) {
        System.out.println(
            "Unknown classifier '" + c + "'. Use one of <dt, knn, boost, nn, svm>. See help for more details");
        return null;
      }
      switch (classifierType) {
        case DECISION_TREE:
          J48 dt = new J48();
          if (commandLine.hasOption(dtPruning)) {
            Boolean pruning = Boolean.valueOf(commandLine.getOptionValue(dtPruning));
            dt.setUnpruned(!pruning);
          }
          if (commandLine.hasOption(dtCf)) {
            String cfStr = commandLine.getOptionValue(dtCf);
            Float cf = getFloat(cfStr);
            if (cf == null) {
              System.out.println(
                  "Please provide a floating point number for dt_cf e.g. 0.25. See help for more details");
              return null;
            }
            dt.setConfidenceFactor(cf);
          }
          cls = dt;
          break;
        case BOOSTING:
          AdaBoostM1 boost = new AdaBoostM1();
          if (commandLine.hasOption(boostC)) {
            String boostCls = commandLine.getOptionValue(boostC);
            if ("stump".equalsIgnoreCase(boostCls)) {
              // nothing to do: AdaBoostM1's default base learner is used
            } else if ("dt".equalsIgnoreCase(boostCls)) {
              J48 boostDt = new J48();
              if (commandLine.hasOption(boostDtPruning)) {
                Boolean boostPruning = Boolean.valueOf(commandLine.getOptionValue(boostDtPruning));
                boostDt.setUnpruned(!boostPruning);
              }
              if (commandLine.hasOption(boostDtCf)) {
                String cfStr = commandLine.getOptionValue(boostDtCf);
                Float cf = getFloat(cfStr);
                if (cf == null) {
                  System.out.println(
                      "Please provide a floating point number for boost_dt_cf e.g. 0.25. See help for more details");
                  return null;
                }
                boostDt.setConfidenceFactor(cf);
              }
              boost.setClassifier(boostDt);
            } else {
              System.out.println(
                  "boost_c can use one of <stump, dt> as values. See help for more details");
              return null;
            }
          }
          if (commandLine.hasOption(boostNrIterations)) {
            Integer nrIt = getInt(commandLine.getOptionValue(boostNrIterations));
            if (nrIt == null) {
              System.out.println(
                  "Please provide an integer for boost_nr_it. See help for more details");
              return null;
            }
            boost.setNumIterations(nrIt);
          }
          cls = boost;
          break;
        case KNN:
          IBk ibk = new IBk();
          if (commandLine.hasOption(knnK)) {
            Integer k = getInt(commandLine.getOptionValue(knnK));
            if (k == null) {
              System.out.println("Please provide an integer for knn_n. For more details see help");
              return null;
            }
            ibk.setKNN(k);
          }
          if (commandLine.hasOption(knnWeightDistance)) {
            Integer knnwd = getInt(commandLine.getOptionValue(knnWeightDistance));
            if (knnwd == null) {
              System.out.println(
                  "Please provide one of 1,2,4 for knn_w_d. For more details see help");
              return null;
            }
            // IBk's weighting tag IDs: 1 = none, 2 = inverse, 4 = similarity
            if (knnwd != 1 && knnwd != 2 && knnwd != 4) {
              System.out.println(
                  "Please provide one of 1,2,4 for knn_w_d. See help for more details");
              return null;
            }
            ibk.setDistanceWeighting(new SelectedTag(knnwd, IBk.TAGS_WEIGHTING));
          }
          cls = ibk;
          break;
        case NN:
          MultilayerPerceptron nn = new MultilayerPerceptron();
          if (commandLine.hasOption(nnLearningRate)) {
            Double nnLR = getDouble(commandLine.getOptionValue(nnLearningRate));
            if (nnLR == null) {
              System.out.println(
                  "Please provide a double for NN learning rate. See help for more details");
              return null;
            }
            nn.setLearningRate(nnLR);
          }
          if (commandLine.hasOption(nnMomentum)) {
            Double nnMR = getDouble(commandLine.getOptionValue(nnMomentum));
            if (nnMR == null) {
              System.out.println(
                  "Please provide a double for NN momentum rate. See help for more details");
              return null;
            }
            nn.setMomentum(nnMR);
          }
          if (commandLine.hasOption(nnHiddentUnits)) {
            String nnHU = commandLine.getOptionValue(nnHiddentUnits);
            nn.setHiddenLayers(nnHU);
          }

          cls = nn;
          break;
        case SVM:
          SMO svm = new SMO();
          if (commandLine.hasOption(svmKernelFunction)) {
            String svmkf = commandLine.getOptionValue(svmKernelFunction);
            Kernel kernel = null;
            if ("poly".equalsIgnoreCase(svmkf)) {
              PolyKernel pk = new PolyKernel();
              if (commandLine.hasOption(svmPolyExp)) {
                Double expValue = getDouble(commandLine.getOptionValue(svmPolyExp));
                if (expValue == null) {
                  System.out.println(
                      "Please provide a double value for svm_poly_exp. See help for more details");
                  return null;
                }
                pk.setExponent(expValue);
              }
              kernel = pk;
            } else if ("radial".equalsIgnoreCase(svmkf)) {
              RBFKernel rbfk = new RBFKernel();
              if (commandLine.hasOption(svmRadialGamma)) {
                Double gamma = getDouble(commandLine.getOptionValue(svmRadialGamma));
                if (gamma == null) {
                  System.out.println(
                      "Please provide a double value for svm_radial_gamma. See help for more details");
                  return null;
                }
                rbfk.setGamma(gamma);
              }
              kernel = rbfk;
            } else {
              System.out.println(
                  "Please provide one of <poly, radial> for svm_kernel_fct. See help for more details");
              return null;
            }
            svm.setKernel(kernel);
          } else {
            // no kernel chosen: a polynomial exponent alone still implies a polynomial kernel
            if (commandLine.hasOption(svmPolyExp)) {
              PolyKernel polyKernel = new PolyKernel();
              Double expValue = getDouble(commandLine.getOptionValue(svmPolyExp));
              if (expValue == null) {
                System.out.println(
                    "Please provide a double value for svm_poly_exp. For more details, see help");
                return null;
              }
              polyKernel.setExponent(expValue);
              svm.setKernel(polyKernel);
            }
          }
          cls = svm;
          break;
        default:
          // never hand back a context without a classifier
          System.out.println(
              "Classifier '" + c + "' is not supported. See help for more details");
          return null;
      }
    } else {
      System.out.println("Please provide a classifier. See help for more details");
      return null;
    }

    return new ClassifierContext(f, trainP, cls, cv, kfolds);
  }
示例#20
0
  /**
   * Generates the decision table.
   *
   * <p>Steps:
   *
   * <ul>
   *   <li>Drops training instances with a missing class.
   *   <li>For a nominal class, computes add-one-smoothed class priors.
   *   <li>Discretises the attributes: unsupervised 10-bin discretisation of all non-class
   *       attributes for a numeric class, supervised discretisation otherwise.
   *   <li>Runs the attribute search and keeps the selected attributes plus the class.
   *   <li>Hashes every reduced instance into the table.
   *   <li>Optionally builds an IBk fallback on the same reduced data.
   * </ul>
   *
   * @param data set of instances serving as training data
   * @throws Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances data) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class
    m_theInstances = new Instances(data);
    m_theInstances.deleteWithMissingClass();

    m_rr = new Random(1);

    if (m_theInstances.classAttribute().isNominal()) {
      // Set up class priors; every class starts with a count of 1. Only instances with a known
      // class are counted: a missing class value is NaN and (int) NaN == 0, which would silently
      // inflate the count of the first class.
      m_classPriorCounts = new double[m_theInstances.classAttribute().numValues()];
      Arrays.fill(m_classPriorCounts, 1.0);
      for (int i = 0; i < m_theInstances.numInstances(); i++) {
        Instance curr = m_theInstances.instance(i);
        m_classPriorCounts[(int) curr.classValue()] += curr.weight();
      }
      m_classPriors = m_classPriorCounts.clone();
      Utils.normalize(m_classPriors);
    }

    setUpEvaluator();

    if (m_theInstances.classAttribute().isNumeric()) {
      m_disTransform = new weka.filters.unsupervised.attribute.Discretize();
      m_classIsNominal = false;

      // use binned discretisation if the class is numeric
      ((weka.filters.unsupervised.attribute.Discretize) m_disTransform).setBins(10);
      ((weka.filters.unsupervised.attribute.Discretize) m_disTransform).setInvertSelection(true);

      // Discretize all attributes EXCEPT the class (the range names the class, selection inverted)
      String rangeList = "";
      rangeList += (m_theInstances.classIndex() + 1);

      ((weka.filters.unsupervised.attribute.Discretize) m_disTransform)
          .setAttributeIndices(rangeList);
    } else {
      m_disTransform = new weka.filters.supervised.attribute.Discretize();
      ((weka.filters.supervised.attribute.Discretize) m_disTransform).setUseBetterEncoding(true);
      m_classIsNominal = true;
    }

    m_disTransform.setInputFormat(m_theInstances);
    m_theInstances = Filter.useFilter(m_theInstances, m_disTransform);

    m_numAttributes = m_theInstances.numAttributes();
    m_numInstances = m_theInstances.numInstances();
    m_majority = m_theInstances.meanOrMode(m_theInstances.classAttribute());

    // Perform the search
    int[] selected = m_search.search(m_evaluator, m_theInstances);

    // selected attributes plus the class, which always has to be kept
    m_decisionFeatures = new int[selected.length + 1];
    System.arraycopy(selected, 0, m_decisionFeatures, 0, selected.length);
    m_decisionFeatures[m_decisionFeatures.length - 1] = m_theInstances.classIndex();

    // reduce instances to selected features
    m_delTransform = new Remove();
    m_delTransform.setInvertSelection(true);

    // set features to keep
    m_delTransform.setAttributeIndicesArray(m_decisionFeatures);
    m_delTransform.setInputFormat(m_theInstances);
    m_dtInstances = Filter.useFilter(m_theInstances, m_delTransform);

    // reset the number of attributes
    m_numAttributes = m_dtInstances.numAttributes();

    // create hash table
    m_entries = new Hashtable((int) (m_dtInstances.numInstances() * 1.5));

    // insert instances into the hash table
    for (int i = 0; i < m_numInstances; i++) {
      Instance inst = m_dtInstances.instance(i);
      insertIntoTable(inst, null);
    }

    // Replace the global table majority with nearest neighbour?
    if (m_useIBk) {
      m_ibk = new IBk();
      // Train on the reduced data: distributionForInstance queries IBk with instances that have
      // gone through m_delTransform, so both must share the same attribute layout.
      m_ibk.buildClassifier(m_dtInstances);
    }

    // Save memory
    if (m_saveMemory) {
      m_theInstances = new Instances(m_theInstances, 0);
      m_dtInstances = new Instances(m_dtInstances, 0);
    }
    m_evaluation = null;
  }