public static void wekaAlgorithms(Instances data) throws Exception {
    // Static field of the enclosing class: a FilteredClassifier wrapping the base learner.
    // Swap the base learner (e.g. J48, RandomForest, ZeroR, IBk) to compare algorithms.
    classifier = new FilteredClassifier();
    classifier.setClassifier(new NaiveBayes());

    // The last attribute is used as the class.
    data.setClassIndex(data.numAttributes() - 1);
    Evaluation eval = new Evaluation(data);

    int folds = 10;
    // Fixed seed so repeated runs use the same fold split.
    eval.crossValidateModel(classifier, data, folds, new Random(1));

    // NOTE: the figures below come from 10-fold cross-validation, not from the training data.
    System.out.println("===== Evaluating on filtered (training) dataset =====");
    System.out.println(eval.toSummaryString());
    System.out.println(eval.toClassDetailsString());
    double[][] mat = eval.confusionMatrix();
    System.out.println("========= Confusion Matrix =========");
    // Iterate each row by its own length instead of assuming a square matrix.
    for (double[] row : mat) {
      for (double cell : row) {
        System.out.print(cell + "  ");
      }
      System.out.println(" ");
    }
  }
示例#2
0
文件: WekaTest.java 项目: fsteeg/tm2
  /**
   * Smoke test for a BayesNet classifier on a tiny hand-built data set.
   *
   * <p>Trains on three instances (classes S1, S2, S3) and evaluates on two instances that repeat
   * the S1 and S2 training points. It then prints the evaluation summary and the predicted class
   * distribution for the S2 test instance.
   *
   * @param args unused
   * @throws Exception if building or evaluating the classifier fails
   */
  public static void main(String[] args) throws Exception {
    Instances isTrainingSet = createSet(4);
    Instance instance1 = createInstance(new double[] {1, 0.7, 0.1, 0.7}, "S1", isTrainingSet);
    Instance instance2 = createInstance(new double[] {0.1, 0.2, 1, 0.3}, "S2", isTrainingSet);
    Instance instance22 = createInstance(new double[] {0, 0, 0, 0}, "S3", isTrainingSet);
    isTrainingSet.add(instance1);
    isTrainingSet.add(instance2);
    isTrainingSet.add(instance22);

    // Test instances are created against the set they are added to. Both sets are built by
    // createSet(4); presumably they share the same header.
    Instances isTestingSet = createSet(4);
    Instance instance3 = createInstance(new double[] {1, 0.7, 0.1, 0.7}, "S1", isTestingSet);
    Instance instance4 = createInstance(new double[] {0.1, 0.2, 1, 0.3}, "S2", isTestingSet);
    isTestingSet.add(instance3);
    isTestingSet.add(instance4);

    // Create a Bayesian network classifier.
    Classifier cModel = new BayesNet();
    cModel.buildClassifier(isTrainingSet);

    // Test the model; the evaluation is initialised from the training set.
    Evaluation eTest = new Evaluation(isTrainingSet);
    eTest.evaluateModel(cModel, isTestingSet);

    // Print the result à la Weka explorer:
    String strSummary = eTest.toSummaryString();
    System.out.println(strSummary);

    // One probability per class value of the header, in the order the class attribute declares them.
    double[] fDistribution = cModel.distributionForInstance(instance4);
    for (double probability : fDistribution) {
      System.out.println(probability);
    }
  }
示例#3
0
  /**
   * Builds a read-only, scrollable text view of how {@code myClassifier} performs on the given
   * data: summary statistics, per-class details and the confusion matrix.
   *
   * @param data the instances on which {@code myClassifier} is evaluated
   * @return the panel to show as the evaluation view for the current decision point, or a message
   *     panel if the evaluation could not be computed (the exception is printed)
   */
  protected JPanel createEvaluationVisualization(Instances data) {
    JTextPane statsPane = new JTextPane();

    try {
      Evaluation evaluation = new Evaluation(data);
      evaluation.evaluateModel(myClassifier, data);
      String summary = evaluation.toSummaryString();
      String details = evaluation.toClassDetailsString();
      String matrix = evaluation.toMatrixString();
      statsPane.setText(summary + "\n\n" + details + "\n\n" + matrix);
    } catch (Exception failure) {
      failure.printStackTrace();
      return createMessagePanel("Error while creating the decision tree evaluation view");
    }

    // Monospaced font keeps Weka's tabular output aligned; scroll back to the top.
    statsPane.setFont(new Font("Courier", Font.PLAIN, 14));
    statsPane.setEditable(false);
    statsPane.setCaretPosition(0);

    JPanel panel = new JPanel();
    BoxLayout verticalLayout = new BoxLayout(panel, BoxLayout.PAGE_AXIS);
    panel.setLayout(verticalLayout);
    panel.add(new JScrollPane(statsPane));
    return panel;
  }
  /**
   * Trains a NaiveBayes model on one ARFF file and evaluates it on another. Prints the model, the
   * summary statistics and the confusion matrix. In both files the last attribute is the class.
   *
   * @param args {@code args[0]}: training ARFF path; {@code args[1]}: test ARFF path
   * @throws IllegalArgumentException if fewer than two paths are supplied
   * @throws Exception if loading, training or evaluation fails
   */
  public static void run(String[] args) throws Exception {
    if (args == null || args.length < 2) {
      throw new IllegalArgumentException(
          "Expected 2 arguments (train ARFF path, test ARFF path), got "
              + (args == null ? 0 : args.length));
    }

    DataSource source = new DataSource(args[0]);
    Instances data = source.getDataSet();
    data.setClassIndex(data.numAttributes() - 1);
    NaiveBayes model = new NaiveBayes();
    model.buildClassifier(data);

    // Evaluation: initialised from the training data, applied to the test data.
    Evaluation eval = new Evaluation(data);
    Instances testData = new DataSource(args[1]).getDataSet();
    testData.setClassIndex(testData.numAttributes() - 1);
    eval.evaluateModel(model, testData);
    System.out.println(model.toString());
    System.out.println(eval.toSummaryString("\nResults\n======\n", false));
    System.out.println("======\nConfusion Matrix:");
    double[][] confusionM = eval.confusionMatrix();
    for (double[] row : confusionM) {
      for (double cell : row) {
        System.out.format("%10s ", cell);
      }
      System.out.print("\n");
    }
  }
  /**
   * Finds the best parameter combination. (recursive for each parameter being optimised).
   *
   * <p>Each recursion level sweeps one {@code CVParameter} from {@code m_Lower} up to an upper
   * bound, in steps of {@code (upper - m_Lower) / (m_Steps - 1)}. When every parameter has a
   * value (the base case), the classifier is configured with {@code createOptions()` and scored
   * by {@code m_NumFolds}-fold cross-validation on {@code trainData}. The lowest error rate seen
   * so far is kept in {@code m_BestPerformance}, and the options that produced it in
   * {@code m_BestClassifierOptions}.
   *
   * @param depth the index of the parameter to be optimised at this level
   * @param trainData the data the search is based on
   * @param random a random number generator (NOTE(review): not used for the folds below, which
   *     always use {@code new Random(1)})
   * @throws Exception if an error occurs
   */
  protected void findParamsByCrossValidation(int depth, Instances trainData, Random random)
      throws Exception {

    if (depth < m_CVParams.size()) {
      CVParameter cvParam = (CVParameter) m_CVParams.elementAt(depth);

      // Pick the upper end of the sweep. The rounded (lower - upper) difference appears to act as
      // a sentinel: 1 -> number of attributes, 2 -> training fold size, anything else -> the
      // configured upper bound. NOTE(review): confirm against how CVParameter stores these.
      double upper;
      switch ((int) (cvParam.m_Lower - cvParam.m_Upper + 0.5)) {
        case 1:
          upper = m_NumAttributes;
          break;
        case 2:
          upper = m_TrainFoldSize;
          break;
        default:
          upper = cvParam.m_Upper;
          break;
      }
      // NOTE(review): m_Steps == 1 makes the divisor 0 (increment becomes infinite or NaN).
      // upper == m_Lower with m_Steps > 1 gives an increment of 0, so the loop below would never
      // terminate. Accumulating the increment in floating point can also overshoot upper and
      // skip the last grid point.
      double increment = (upper - cvParam.m_Lower) / (cvParam.m_Steps - 1);
      // The current value is stored on cvParam itself; presumably createOptions() reads it at
      // the leaf of the recursion.
      for (cvParam.m_ParamValue = cvParam.m_Lower;
          cvParam.m_ParamValue <= upper;
          cvParam.m_ParamValue += increment) {
        findParamsByCrossValidation(depth + 1, trainData, random);
      }
    } else {

      // One Evaluation accumulates the results of all folds.
      Evaluation evaluation = new Evaluation(trainData);

      // Set the classifier options
      String[] options = createOptions();
      if (m_Debug) {
        System.err.print("Setting options for " + m_Classifier.getClass().getName() + ":");
        for (int i = 0; i < options.length; i++) {
          System.err.print(" " + options[i]);
        }
        System.err.println("");
      }
      ((OptionHandler) m_Classifier).setOptions(options);
      for (int j = 0; j < m_NumFolds; j++) {

        // A fixed seed gives the same fold split for every parameter combination, so the
        // combinations are compared on identical data.
        Instances train = trainData.trainCV(m_NumFolds, j, new Random(1));
        Instances test = trainData.testCV(m_NumFolds, j);
        m_Classifier.buildClassifier(train);
        // Class priors for this fold come from its training part.
        evaluation.setPriors(train);
        evaluation.evaluateModel(m_Classifier, test);
      }
      double error = evaluation.errorRate();
      if (m_Debug) {
        System.err.println("Cross-validated error rate: " + Utils.doubleToString(error, 6, 4));
      }
      // -99 appears to mean "no combination evaluated yet".
      if ((m_BestPerformance == -99) || (error < m_BestPerformance)) {

        m_BestPerformance = error;
        m_BestClassifierOptions = createOptions();
      }
    }
  }
 /**
  * Evaluates {@code _cl} on {@code _test}, with the evaluation initialised from {@code _train}.
  * Prints the summary statistics and the confusion matrix. If {@code _test} has no class
  * attribute set, its last attribute is used.
  */
 @Override
 public void evaluate() throws Exception {
   boolean classUnset = _test.classIndex() == -1;
   if (classUnset) {
     _test.setClassIndex(_test.numAttributes() - 1);
   }
   Evaluation evaluation = new Evaluation(_train);
   evaluation.evaluateModel(_cl, _test);
   String summary = evaluation.toSummaryString("\nResults\n======\n", false);
   System.out.println(summary);
   System.out.println(evaluation.toMatrixString());
 }
示例#7
0
  /**
   * Chooses, from the configured base classifiers, the one with the lowest error rate and makes
   * it this scheme's model.
   *
   * <p>The data is copied, stripped of missing-class instances and shuffled with
   * {@code m_Seed}. With {@code m_NumXValFolds > 1}, each candidate is scored by cross-validation
   * (stratified for a nominal class), and the winner is then retrained on all of the data.
   * Otherwise each candidate is trained and scored on the full data. Ties keep the earliest
   * classifier.
   *
   * @param data the training data to be used for generating the boosted classifier.
   * @exception Exception if the classifier could not be built successfully
   */
  public void buildClassifier(Instances data) throws Exception {
    if (m_Classifiers.length == 0) {
      throw new Exception("No base classifiers have been set!");
    }

    Instances prepared = new Instances(data);
    prepared.deleteWithMissingClass();
    prepared.randomize(new Random(m_Seed));
    if (prepared.classAttribute().isNominal() && (m_NumXValFolds > 1)) {
      prepared.stratify(m_NumXValFolds);
    }

    int total = m_Classifiers.length;
    Classifier champion = null;
    int championIndex = -1;
    double championError = Double.NaN;

    for (int idx = 0; idx < total; idx++) {
      Classifier candidate = getClassifier(idx);
      Evaluation evaluation;
      if (m_NumXValFolds <= 1) {
        // No cross-validation: train and test on the full data.
        candidate.buildClassifier(prepared);
        evaluation = new Evaluation(prepared);
        evaluation.evaluateModel(candidate, prepared);
      } else {
        evaluation = new Evaluation(prepared);
        for (int fold = 0; fold < m_NumXValFolds; fold++) {
          Instances foldTrain = prepared.trainCV(m_NumXValFolds, fold);
          Instances foldTest = prepared.testCV(m_NumXValFolds, fold);
          candidate.buildClassifier(foldTrain);
          evaluation.setPriors(foldTrain);
          evaluation.evaluateModel(candidate, foldTest);
        }
      }

      double error = evaluation.errorRate();
      if (m_Debug) {
        String message =
            "Error rate: "
                + Utils.doubleToString(error, 6, 4)
                + " for classifier "
                + candidate.getClass().getName();
        System.err.println(message);
      }

      boolean improves = (idx == 0) || (error < championError);
      if (improves) {
        champion = candidate;
        championError = error;
        championIndex = idx;
      }
    }

    m_ClassifierIndex = championIndex;
    m_Classifier = champion;
    if (m_NumXValFolds > 1) {
      // The scores came from per-fold models; retrain the winner on all of the data.
      m_Classifier.buildClassifier(prepared);
    }
  }
 /**
  * Evaluates {@code bayesNet} on {@code testInstances}, with the evaluation initialised from
  * {@code trainInstances}.
  *
  * @param trainInstances data the evaluation is initialised from
  * @param testInstances data the model is evaluated on
  * @return the populated evaluation, or {@code null} if evaluation failed (the error message and
  *     stack trace are printed)
  */
 public Evaluation evaluateClassifier(Instances trainInstances, Instances testInstances) {
   Evaluation outcome;
   try {
     outcome = new Evaluation(trainInstances);
     outcome.evaluateModel(bayesNet, testInstances);
   } catch (Exception failure) {
     System.err.println(failure.getMessage());
     failure.printStackTrace();
     outcome = null;
   }
   return outcome;
 }
 /**
  * Demonstrates the attribute-selection meta-classifier. A J48 tree is trained on the attributes
  * chosen by CFS subset evaluation with a backward greedy stepwise search. The summary of a
  * 10-fold cross-validation (seed 1) is printed.
  *
  * @param data the data set, presumably with its class attribute already set
  * @throws Exception if the cross-validation fails
  */
 protected static void useClassifier(Instances data) throws Exception {
   System.out.println("\n1. Meta-classfier");

   GreedyStepwise backwardSearch = new GreedyStepwise();
   backwardSearch.setSearchBackwards(true);

   AttributeSelectedClassifier meta = new AttributeSelectedClassifier();
   meta.setClassifier(new J48());
   meta.setEvaluator(new CfsSubsetEval());
   meta.setSearch(backwardSearch);

   Evaluation evaluation = new Evaluation(data);
   evaluation.crossValidateModel(meta, data, 10, new Random(1));
   System.out.println(evaluation.toSummaryString());
 }
示例#10
0
  /**
   * Trains a RandomForest on {@code trainFile} and evaluates it on {@code testFile}. Prints the
   * evaluation summary and one {@code id,prediction} line per test instance.
   *
   * <p>Attribute 0 is the class. Attributes 8, 6, 5 and 4 (original numbering) are removed from
   * both sets. They are deleted in descending order, so no deletion shifts the index of a later
   * one. Output ids start at 892; NOTE(review): presumably the first id of the target test set.
   *
   * @param trainFile path of the training data
   * @param testFile path of the test data
   * @return the prediction for the last classified test instance (0.0 if none); errors are
   *     printed, not thrown
   */
  public static Double runClassify(String trainFile, String testFile) {
    double predictOrder = 0.0;
    try {
      Instances train = DataSource.read(trainFile);
      Instances test = DataSource.read(testFile);

      train.setClassIndex(0);
      test.setClassIndex(0);

      for (int attributeIndex : new int[] {8, 6, 5, 4}) {
        train.deleteAttributeAt(attributeIndex);
        test.deleteAttributeAt(attributeIndex);
      }

      RandomForest classifier = new RandomForest();
      classifier.buildClassifier(train);
      Evaluation eval = new Evaluation(train);
      eval.evaluateModel(classifier, test);

      System.out.println(eval.toSummaryString("\nResults\n\n", true));

      int id = 892;
      for (int i = 0; i < test.numInstances(); i++) {
        predictOrder = classifier.classifyInstance(test.instance(i));
        System.out.println((id++) + "," + (int) predictOrder);
      }
    } catch (Exception e) {
      e.printStackTrace();
    }
    return predictOrder;
  }
示例#11
0
  /**
   * Returns the knowledge base, loading it on first use.
   *
   * <p>On first load, unless {@code isOnlyLearning()} is true and provided the knowledge base is
   * non-empty, this also:
   *
   * <ul>
   *   <li>trains the static {@code classifier} as a MultilayerPerceptron on the whole knowledge
   *       base;
   *   <li>stores a 10-fold cross-validation of it in the static {@code evaluation}, seeded from
   *       {@code randomData.nextLong(1, 1000)}.
   * </ul>
   *
   * <p>Errors are printed and swallowed, so the result may still be {@code null}.
   * NOTE(review): the lazy initialisation is not synchronized.
   */
  public static Instances getKnowledgeBase() {
    if (knowledgeBase != null) {
      return knowledgeBase;
    }
    try {
      CreateAppInsertIntoVm.knowledgeBase =
          Action.loadKnowledge(Configuration.getInstance().getKBCreateAppInsertIntoVm());

      // Learning-only mode needs no classifier or evaluator.
      if (isOnlyLearning()) {
        return knowledgeBase;
      }

      System.out.println("Classify data CreateAppInsertInto");
      if (knowledgeBase.numInstances() <= 0) {
        System.out.println("No Instancedata for classifier CreateAppInsertIntoVm");
        return knowledgeBase;
      }

      classifier = new MultilayerPerceptron();
      classifier.buildClassifier(knowledgeBase);
      evaluation = new Evaluation(knowledgeBase);
      long seed = randomData.nextLong(1, 1000);
      evaluation.crossValidateModel(
          classifier, knowledgeBase, 10, knowledgeBase.getRandomNumberGenerator(seed));
      System.out.println("Classified data CreateAppInsertInto");
    } catch (Exception e) {
      e.printStackTrace();
    }
    return knowledgeBase;
  }
  /**
   * Process a classifier's prediction for an instance and update a set of plotting instances and
   * additional plotting info. m_PlotShape for nominal class datasets holds shape types (actual data
   * points have automatic shape type assignment; classifier error data points have box shape type).
   * For numeric class datasets, the actual data points are stored in m_PlotInstances and m_PlotSize
   * stores the error (which is later converted to shape size values).
   *
   * @param toPredict the actual data point
   * @param classifier the classifier
   * @param eval the evaluation object to use for evaluating the classifier on the instance to
   *     predict
   * @see #m_PlotShapes
   * @see #m_PlotSizes
   * @see #m_PlotInstances
   */
  public void process(Instance toPredict, Classifier classifier, Evaluation eval) {
    double pred;
    double[] values;
    int i;

    try {
      // Scores the instance and records the prediction in eval.
      pred = eval.evaluateModelOnceAndRecordPrediction(classifier, toPredict);

      // Plot in the attribute space of the wrapped model.
      if (classifier instanceof weka.classifiers.misc.InputMappedClassifier) {
        toPredict =
            ((weka.classifiers.misc.InputMappedClassifier) classifier)
                .constructMappedInstance(toPredict);
      }

      if (!m_SaveForVisualization) {
        return;
      }

      if (m_PlotInstances != null) {
        // Per the indexing below, the plot row has the prediction inserted at the class
        // position, followed by the actual class value; later attributes shift right by one.
        values = new double[m_PlotInstances.numAttributes()];
        for (i = 0; i < m_PlotInstances.numAttributes(); i++) {
          if (i < toPredict.classIndex()) {
            values[i] = toPredict.value(i);
          } else if (i == toPredict.classIndex()) {
            values[i] = pred;
            values[i + 1] = toPredict.value(i);
            i++;
          } else {
            values[i] = toPredict.value(i - 1);
          }
        }

        m_PlotInstances.add(new DenseInstance(1.0, values));

        if (toPredict.classAttribute().isNominal()) {
          if (toPredict.isMissing(toPredict.classIndex()) || Utils.isMissingValue(pred)) {
            m_PlotShapes.addElement(Integer.valueOf(Plot2D.MISSING_SHAPE));
          } else if (pred != toPredict.classValue()) {
            // set to default error point shape
            m_PlotShapes.addElement(Integer.valueOf(Plot2D.ERROR_SHAPE));
          } else {
            // otherwise set to constant (automatically assigned) point shape
            m_PlotShapes.addElement(Integer.valueOf(Plot2D.CONST_AUTOMATIC_SHAPE));
          }
          m_PlotSizes.addElement(Integer.valueOf(Plot2D.DEFAULT_SHAPE_SIZE));
        } else {
          // store the error (to be converted to a point size later)
          Double errd = null;
          if (!toPredict.isMissing(toPredict.classIndex()) && !Utils.isMissingValue(pred)) {
            errd = Double.valueOf(pred - toPredict.classValue());
            m_PlotShapes.addElement(Integer.valueOf(Plot2D.CONST_AUTOMATIC_SHAPE));
          } else {
            // missing shape if actual class not present or prediction is missing
            m_PlotShapes.addElement(Integer.valueOf(Plot2D.MISSING_SHAPE));
          }
          m_PlotSizes.addElement(errd);
        }
      }
    } catch (Exception ex) {
      ex.printStackTrace();
    }
  }
示例#13
0
  /**
   * Scores how well {@code app2} fits into {@code vm}.
   *
   * <p>Returns 0 unless the VM has free capacity, meaning CPU, memory and storage usage plus the
   * app's demand all stay strictly below the VM's current allocation. With capacity available:
   *
   * <ul>
   *   <li>When not in learning-only mode and a trained evaluation exists, the score is the
   *       classifier's prediction for the VM times 100, truncated to int (0 if prediction fails).
   *   <li>Otherwise the score is {@code randomData.nextInt(1, 100)}.
   * </ul>
   *
   * <p>NOTE(review): the storage check previously compared against the CPU allocation, which
   * looked like a copy-paste slip. It now uses the storage allocation.
   */
  // TODO: gst: use WEKA to calc fit factor!!
  private int calculateFit(App app2, VirtualMachine vm) {
    boolean fits =
        app2.getCpu() + vm.getCurrentCpuUsage() < vm.getCurrentCpuAllocation()
            && app2.getMemory() + vm.getCurrentMemoryUsage() < vm.getCurrentMemoryAllocation()
            && app2.getStorage() + vm.getCurrentStorageUsage()
                < vm.getCurrentStorageAllocation();
    if (!fits) {
      return 0;
    }

    // Without a trained model (or in learning-only mode) fall back to a random score.
    if (Action.isOnlyLearning() || CreateAppInsertIntoVm.evaluation == null) {
      return randomData.nextInt(1, 100);
    }

    // The class value is unknown; the classifier predicts it.
    Instance instance = createInstance(Instance.missingValue(), vm);
    instance.setDataset(CreateAppInsertIntoVm.getKnowledgeBase());
    try {
      return (int) (evaluation.evaluateModelOnce(classifier, instance) * 100);
    } catch (Exception e) {
      e.printStackTrace();
      return 0;
    }
  }
  /**
   * Appends the prediction intervals recorded in {@code m_Evaluation} to
   * {@code m_PlotInstances} as extra numeric attributes. Each interval k (1-based) contributes
   * three attributes: lower boundary, upper boundary and width (upper minus lower). Classifiers
   * may return a different number of intervals per instance. Instances with fewer intervals than
   * the maximum are padded with missing values.
   *
   * <p>NOTE(review): assumes every prediction is a {@code NumericPrediction} and that predictions
   * line up one-to-one, in order, with the plot instances.
   */
  protected void addPredictionIntervals() {
    FastVector preds = m_Evaluation.predictions();

    // Largest number of intervals reported for any single prediction.
    int maxNum = 0;
    for (int p = 0; p < preds.size(); p++) {
      int count = ((NumericPrediction) preds.elementAt(p)).predictionIntervals().length;
      maxNum = Math.max(maxNum, count);
    }

    // New header: all existing attributes followed by three per interval.
    int base = m_PlotInstances.numAttributes();
    FastVector atts = new FastVector();
    for (int a = 0; a < base; a++) {
      atts.addElement(m_PlotInstances.attribute(a));
    }
    for (int k = 1; k <= maxNum; k++) {
      String prefix = "predictionInterval_" + k;
      atts.addElement(new Attribute(prefix + "-lowerBoundary"));
      atts.addElement(new Attribute(prefix + "-upperBoundary"));
      atts.addElement(new Attribute(prefix + "-width"));
    }
    Instances data =
        new Instances(m_PlotInstances.relationName(), atts, m_PlotInstances.numInstances());
    data.setClassIndex(m_PlotInstances.classIndex());

    for (int row = 0; row < m_PlotInstances.numInstances(); row++) {
      Instance inst = m_PlotInstances.instance(row);
      double[] values = new double[data.numAttributes()];
      System.arraycopy(inst.toDoubleArray(), 0, values, 0, inst.numAttributes());

      double[][] predInt = ((NumericPrediction) preds.elementAt(row)).predictionIntervals();
      for (int n = 0; n < maxNum; n++) {
        int col = base + 3 * n;
        boolean present = n < predInt.length;
        values[col] = present ? predInt[n][0] : Utils.missingValue();
        values[col + 1] = present ? predInt[n][1] : Utils.missingValue();
        values[col + 2] = present ? predInt[n][1] - predInt[n][0] : Utils.missingValue();
      }
      data.add(new DenseInstance(inst.weight(), values));
    }

    m_PlotInstances = data;
  }
  /**
   * Trains a LibLINEAR model (bias 10) on {@code dataTrain}, evaluates it on {@code dataTest} and
   * prints the summary statistics. Any failure is printed and swallowed.
   *
   * @param dataTrain training data, presumably with the class attribute already set
   * @param dataTest test data, presumably with the same structure as {@code dataTrain}
   */
  public static void trainModel(Instances dataTrain, Instances dataTest) {
    try {
      LibLINEAR linear = new LibLINEAR();
      linear.setBias(10);
      linear.buildClassifier(dataTrain);

      Evaluation evaluation = new Evaluation(dataTrain);
      evaluation.evaluateModel(linear, dataTest);
      String summary = evaluation.toSummaryString("\nResults\n======\n", false);
      System.out.println(summary);
    } catch (Exception failure) {
      failure.printStackTrace();
    }
  }
示例#16
0
  /**
   * Command-line entry point for testing this class. Evaluates a {@code Decorate} classifier
   * configured by the given options and prints the result. On failure only the exception message
   * is printed, to standard error.
   *
   * @param argv the options
   */
  public static void main(String[] argv) {
    String report;
    try {
      report = Evaluation.evaluateModel(new Decorate(), argv);
    } catch (Exception e) {
      System.err.println(e.getMessage());
      return;
    }
    System.out.println(report);
  }
示例#17
0
  /**
   * Converts the Weka evaluation obtained from {@code test()} into the ontology message type.
   * Metrics whose computation throws are reported as -1. If {@code test()} returns {@code null},
   * every metric is reported as -1 instead of failing with a NullPointerException.
   *
   * @return the ontology evaluation message
   */
  @Override
  protected pikater.ontology.messages.Evaluation evaluateCA() {
    Evaluation eval = test();

    pikater.ontology.messages.Evaluation result = new pikater.ontology.messages.Evaluation();
    if (eval == null) {
      // Same -1 "unavailable" sentinel as the per-metric fallbacks below.
      result.setError_rate(-1f);
      result.setKappa_statistic(-1f);
      result.setMean_absolute_error(-1f);
      result.setRelative_absolute_error(-1f);
      result.setRoot_mean_squared_error(-1f);
      result.setRoot_relative_squared_error(-1f);
      return result;
    }

    result.setError_rate((float) eval.errorRate());

    try {
      result.setKappa_statistic((float) eval.kappa());
    } catch (Exception e) {
      result.setKappa_statistic(-1);
    }

    result.setMean_absolute_error((float) eval.meanAbsoluteError());

    try {
      result.setRelative_absolute_error((float) eval.relativeAbsoluteError());
    } catch (Exception e) {
      result.setRelative_absolute_error(-1);
    }

    result.setRoot_mean_squared_error((float) eval.rootMeanSquaredError());
    result.setRoot_relative_squared_error((float) eval.rootRelativeSquaredError());

    return result;
  }
  /**
   * Command-line entry point for testing this class. Evaluates a
   * {@code UnivariateLinearRegression} configured by the given options and prints the result. On
   * failure, prints the exception message and its stack trace.
   *
   * @param argv options
   */
  public static void main(String[] argv) {
    try {
      String evaluationOutput = Evaluation.evaluateModel(new UnivariateLinearRegression(), argv);
      System.out.println(evaluationOutput);
    } catch (Exception failure) {
      System.out.println(failure.getMessage());
      failure.printStackTrace();
    }
  }
示例#19
0
 /**
  * Fast 5-fold cross-validation of an updateable naive Bayes model.
  *
  * <p>Nothing is retrained per fold. Each of five copies of {@code fullModel} is updated with the
  * instances of one test fold at negated weight, which "unlearns" them. The copy is then
  * evaluated on that fold. Each instance's weight is restored right after its update. Folds
  * follow the current order of {@code trainingSet}; no shuffling happens here.
  *
  * @param fullModel the model, presumably already trained on all of {@code trainingSet}
  * @param trainingSet the data to split into folds
  * @param r currently unused
  * @return the weighted count of incorrectly classified instances over all folds
  * @exception Exception if an error occurs
  */
 public static double crossValidate(
     NaiveBayesUpdateable fullModel, Instances trainingSet, Random r) throws Exception {
   final int numFolds = 5;
   Classifier[] foldModels = AbstractClassifier.makeCopies(fullModel, numFolds);
   Evaluation eval = new Evaluation(trainingSet);
   for (int fold = 0; fold < numFolds; fold++) {
     Instances heldOut = trainingSet.testCV(numFolds, fold);
     for (int k = 0; k < heldOut.numInstances(); k++) {
       // Flip the weight negative, update (unlearn), then flip it back.
       heldOut.instance(k).setWeight(-heldOut.instance(k).weight());
       ((NaiveBayesUpdateable) foldModels[fold]).updateClassifier(heldOut.instance(k));
       heldOut.instance(k).setWeight(-heldOut.instance(k).weight());
     }
     eval.evaluateModel(foldModels[fold], heldOut);
   }
   return eval.incorrect();
 }
示例#20
0
  /**
   * Evaluates {@code cls} on {@code test}, with the evaluation initialised from {@code train}, and
   * prints the summary. {@code working} is true for the duration of the call and is always
   * cleared afterwards.
   *
   * @return the evaluation, or {@code null} if it could not be created. If evaluation fails part
   *     way, the partially populated object is returned. Errors are printed.
   */
  @Override
  protected Evaluation test() {
    working = true;
    System.out.println("Agent " + getLocalName() + ": Testing...");

    Evaluation eval = null;
    try {
      eval = new Evaluation(train);
      eval.evaluateModel(cls, test);
      System.out.println(
          eval.toSummaryString(getLocalName() + " agent: " + "\nResults\n=======\n", false));
    } catch (Exception e) {
      e.printStackTrace();
    } finally {
      // Clear the busy flag even if an Error (e.g. OutOfMemoryError) escapes the evaluation.
      working = false;
    }
    return eval;
  }
  /**
   * Loads {@code traindata} and converts its text to word vectors with a
   * {@code StringToWordVector} filter. Then prints the summary and confusion matrix of a 10-fold
   * cross-validation (seed 1) of {@code this.classifier}, using attribute 0 as the class.
   *
   * @param traindata location of the data set to load
   * @throws Exception if loading, filtering or evaluation fails
   */
  @Override
  public void crossValidation(String traindata) throws Exception {
    Instances raw = new DataSource(traindata).getDataSet();

    StringToWordVector wordVector = new StringToWordVector();
    String filterOptions =
        "-R first-last -W 1000 "
            + "-prune-rate -1.0 -N 0 "
            + "-stemmer weka.core.stemmers.NullStemmer -M 1 "
            + "-tokenizer \"weka.core.tokenizers.WordTokenizer -delimiters  \\\" \\r\\n\\t.,;:\\\'\\\"()?!\"";
    wordVector.setOptions(weka.core.Utils.splitOptions(filterOptions));
    wordVector.setInputFormat(raw);

    Instances vectors = Filter.useFilter(raw, wordVector);
    vectors.setClassIndex(0);

    Evaluation evaluation = new Evaluation(vectors);
    evaluation.crossValidateModel(this.classifier, vectors, 10, new Random(1));
    System.out.println(evaluation.toSummaryString());
    System.out.println(evaluation.toMatrixString());
  }
示例#22
0
  /**
   * Trains {@code c} on {@code trainData}, evaluates it on {@code testData} (without priors) and
   * prints per-class details and summary statistics. Writes the predictions to
   * {@code prediction.trec} and the predicted distributions to
   * {@code predicted_distribution.csv}.
   *
   * @param c the classifier to train and evaluate
   * @param trainData training data; its class attribute must be set
   * @param testData test data
   * @throws IllegalStateException if {@code trainData} has no class attribute set
   * @throws Exception if training, evaluation or writing the output files fails
   */
  private static void evaluateClassifier(Classifier c, Instances trainData, Instances testData)
      throws Exception {
    // Check first: the log line below calls classAttribute(), which itself fails when no class
    // is set, so a later check could never be reached.
    if (trainData.classIndex() < 0) {
      throw new IllegalStateException("class attribute not set");
    }

    System.err.println(
        "INFO: Starting split validation to predict '"
            + trainData.classAttribute().name()
            + "' using '"
            + c.getClass().getCanonicalName()
            + ":"
            + Arrays.toString(c.getOptions())
            + "' (#train="
            + trainData.numInstances()
            + ",#test="
            + testData.numInstances()
            + ") ...");

    c.buildClassifier(trainData);
    Evaluation eval = new Evaluation(testData);
    eval.useNoPriors();
    double[] predictions = eval.evaluateModel(c, testData);

    System.out.println(eval.toClassDetailsString());
    System.out.println(eval.toSummaryString("\nResults\n======\n", false));

    // try-with-resources closes each file even if writing fails.
    System.err.println("INFO: Writing predictions to file ...");
    try (Writer out = new FileWriter("prediction.trec")) {
      writePredictionsTrecEval(predictions, testData, 0, trainData.classIndex(), out);
    }

    System.err.println("INFO: Writing predicted distributions to CSV ...");
    try (Writer out = new FileWriter("predicted_distribution.csv")) {
      writePredictedDistributions(c, testData, 0, out);
    }
  }
  /**
   * Adapts an evaluation to an {@code InputMappedClassifier}. For any other classifier,
   * {@code eval} is returned unchanged.
   *
   * <p>For a mapped classifier, a fresh evaluation is built on the header returned by
   * {@code getModelHeader}. If that header differs from {@code inst}, each test instance is
   * mapped to the model's structure. The mapped data then sets the evaluation's priors and
   * reconfigures {@code plotInstances}.
   *
   * @param eval the evaluation to adjust
   * @param classifier the classifier being evaluated
   * @param inst the test instances
   * @param plotInstances plot data holder, reconfigured when the headers differ
   * @return the evaluation to use from now on
   * @throws Exception if mapping fails
   */
  protected static Evaluation adjustForInputMappedClassifier(
      Evaluation eval,
      weka.classifiers.Classifier classifier,
      Instances inst,
      ClassifierErrorsPlotInstances plotInstances)
      throws Exception {

    if (!(classifier instanceof weka.classifiers.misc.InputMappedClassifier)) {
      return eval;
    }
    weka.classifiers.misc.InputMappedClassifier mapper =
        (weka.classifiers.misc.InputMappedClassifier) classifier;

    Instances modelHeader = mapper.getModelHeader(new Instances(inst, 0));
    Evaluation mappedEval = new Evaluation(new Instances(modelHeader, 0));
    if (mappedEval.getHeader().equalHeaders(inst)) {
      return mappedEval;
    }

    // The test data does not match what the wrapped model was trained with. Map every test
    // instance so the plot instances are configured with the model's structure.
    Instances mappedData = mapper.getModelHeader(new Instances(modelHeader, 0));
    for (int idx = 0; idx < inst.numInstances(); idx++) {
      mappedData.add(mapper.constructMappedInstance(inst.instance(idx)));
    }

    mappedEval.setPriors(mappedData);
    plotInstances.setInstances(mappedData);
    plotInstances.setClassifier(classifier);
    plotInstances.setClassIndex(mappedData.classIndex());
    plotInstances.setEvaluation(mappedEval);
    return mappedEval;
  }
示例#24
0
  /**
   * Tests a classifier: trains the classifier named by {@code CLASSIFIERNAME} on an ARFF training
   * file and evaluates it on an ARFF test file. Prints per-class details, the summary, the
   * confusion matrix and the accuracy.
   *
   * <p>Both file names are appended to the {@code fileName} prefix. In both files the last
   * attribute is the class. Errors are printed, not thrown.
   *
   * @param trainFileName training corpus file name
   * @param testFileName test corpus file name
   */
  public static void classify(String trainFileName, String testFileName) {
    try {
      File inputFile = new File(fileName + trainFileName); // training corpus file
      ArffLoader atf = new ArffLoader();
      atf.setFile(inputFile);
      Instances instancesTrain = atf.getDataSet(); // read the training file

      inputFile = new File(fileName + testFileName); // test corpus file
      atf.setFile(inputFile);
      Instances instancesTest = atf.getDataSet(); // read the test file

      // Use the last attribute as the class label.
      instancesTest.setClassIndex(instancesTest.numAttributes() - 1);
      instancesTrain.setClassIndex(instancesTrain.numAttributes() - 1);

      // Class.newInstance() is deprecated (it rethrows checked constructor exceptions
      // unchecked); go through the no-arg constructor instead.
      classifier =
          (Classifier) Class.forName(CLASSIFIERNAME).getDeclaredConstructor().newInstance();
      classifier.buildClassifier(instancesTrain);

      Evaluation eval = new Evaluation(instancesTrain);
      // First argument: the trained classifier; second: the data set to evaluate it on.
      eval.evaluateModel(classifier, instancesTest);

      System.out.println(eval.toClassDetailsString());
      System.out.println(eval.toSummaryString());
      System.out.println(eval.toMatrixString());
      // NOTE: 1 - errorRate() is the accuracy, not precision; the label is kept for output
      // compatibility.
      System.out.println("precision is :" + (1 - eval.errorRate()));

    } catch (Exception e) {
      e.printStackTrace();
    }
  }
示例#25
0
  /**
   * Describes the demo setup and its results. Includes the classifier and filter with their
   * options, the training file, the trained model, and the evaluation summary, confusion matrix
   * and per-class details. Failures to render the matrix or class details are printed and those
   * sections are skipped.
   *
   * @return the textual report
   */
  @Override
  public String toString() {
    StringBuilder report = new StringBuilder("Weka - Demo\n===========\n\n");

    report
        .append("Classifier...: ")
        .append(m_Classifier.getClass().getName())
        .append(' ')
        .append(Utils.joinOptions(m_Classifier.getOptions()))
        .append('\n');

    String filterOptions =
        (m_Filter instanceof OptionHandler)
            ? " " + Utils.joinOptions(((OptionHandler) m_Filter).getOptions())
            : "";
    report
        .append("Filter.......: ")
        .append(m_Filter.getClass().getName())
        .append(filterOptions)
        .append('\n');
    report.append("Training file: " + m_TrainingFile + "\n");
    report.append('\n');

    report.append(m_Classifier.toString()).append('\n');
    report.append(m_Evaluation.toSummaryString()).append('\n');
    try {
      report.append(m_Evaluation.toMatrixString()).append('\n');
    } catch (Exception e) {
      e.printStackTrace();
    }
    try {
      report.append(m_Evaluation.toClassDetailsString()).append('\n');
    } catch (Exception e) {
      e.printStackTrace();
    }

    return report.toString();
  }
示例#26
0
  /**
   * Runs 10-fold cross-validation over the training file.
   *
   * <p>Applies {@code m_Filter} to {@code m_Training} and trains {@code m_Classifier} on the
   * complete filtered data. Then stores in {@code m_Evaluation} a 10-fold cross-validation of the
   * classifier on that data, using {@code m_Training.getRandomNumberGenerator(1)} for the folds.
   *
   * @throws Exception if filtering, training or evaluation fails
   */
  public void execute() throws Exception {
    m_Filter.setInputFormat(m_Training);
    Instances filteredData = Filter.useFilter(m_Training, m_Filter);

    // Model trained on all of the (filtered) data.
    m_Classifier.buildClassifier(filteredData);

    final int numFolds = 10;
    m_Evaluation = new Evaluation(filteredData);
    m_Evaluation.crossValidateModel(
        m_Classifier, filteredData, numFolds, m_Training.getRandomNumberGenerator(1));
  }
  /**
   * Compares several classifiers on {@code src/files/powerpuffgirls.arff}, using attribute index 4
   * as the class. For each classifier, prints {@code summary(eval)} for 30 runs of 10-fold
   * cross-validation, seeded 1 to 30.
   *
   * @throws Exception if loading the data or cross-validation fails
   */
  private static void run() throws Exception {
    DataSource source = new DataSource("src/files/powerpuffgirls.arff");

    int folds = 10;
    int runs = 30;

    // Insertion-ordered, so classifiers are reported in the order listed here (a plain HashMap
    // iterates in an unspecified order). LibSVM can be added here as another entry.
    HashMap<String, Classifier> hash = new java.util.LinkedHashMap<>();

    hash.put("J48", new J48());
    hash.put("NaiveBayes", new NaiveBayes());
    hash.put("IBk=1", new IBk(1));
    hash.put("IBk=3", new IBk(3));
    hash.put("MultilayerPerceptron", new MultilayerPerceptron());

    Instances data = source.getDataSet();
    data.setClassIndex(4);

    System.out.println("#seed \t correctly instances \t percentage of corrects\n");

    for (Entry<String, Classifier> entry : hash.entrySet()) {
      System.out.println("\n Algorithm: " + entry.getKey() + "\n");

      for (int i = 1; i <= runs; i++) {
        Evaluation eval = new Evaluation(data);
        eval.crossValidateModel(entry.getValue(), data, folds, new Random(i));

        System.out.println(summary(eval));
      }
    }
  }
 /**
  * Applies {@code classifier} to every instance of {@code validationSet}, recording the results in
  * {@code evaluation}.
  *
  * @param classifier the trained classifier
  * @param validationSet the instances to predict
  * @param evaluation the evaluation that accumulates the results
  * @return the predictions for {@code validationSet}
  * @throws ClassifierPredictionException if evaluation fails. An ArrayIndexOutOfBoundsException
  *     is reported as a probable feature-count mismatch.
  */
 private double[] makePredictions(
     Classifier classifier, Instances validationSet, Evaluation evaluation) {
   try {
     return evaluation.evaluateModel(classifier, validationSet);
   } catch (ArrayIndexOutOfBoundsException e) {
     // The instances being predicted are the validation/test set, not the training set.
     throw new ClassifierPredictionException(
         "Error applying the trained classifier to the test instances. The number of features of the instance exceeds the number of features the classifier was trained on.",
         e);
   } catch (Exception e) {
     throw new ClassifierPredictionException(
         "Error applying the trained classifier to the test instances.", e);
   }
 }
示例#29
0
  /**
   * Predicts every instance of {@code test} with the model from {@code getModelObject()}. The
   * predictions are then copied, in order, into the instances of {@code onto_test}. A prediction
   * that throws is recorded as {@code Integer.MAX_VALUE}.
   *
   * <p>NOTE(review): {@code test()} is only used to obtain an Evaluation; if it returns
   * {@code null}, every prediction fails and becomes {@code Integer.MAX_VALUE}. This also assumes
   * {@code onto_test} holds no more instances than {@code test}.
   *
   * @param test the Weka instances to predict
   * @param onto_test the ontology instances that receive the predictions
   * @return {@code onto_test}, with predictions set
   */
  @Override
  protected DataInstances getPredictions(Instances test, DataInstances onto_test) {
    Evaluation eval = test();
    int count = test.numInstances();
    double[] predictions = new double[count];
    for (int idx = 0; idx < count; idx++) {
      double value;
      try {
        value = eval.evaluateModelOnce((Classifier) getModelObject(), test.instance(idx));
      } catch (Exception e) {
        value = Integer.MAX_VALUE;
      }
      predictions[idx] = value;
    }

    // Copy the results into the ontology instances, position by position.
    int position = 0;
    for (Iterator itr = onto_test.getInstances().iterator(); itr.hasNext(); position++) {
      ((Instance) itr.next()).setPrediction(predictions[position]);
    }

    return onto_test;
  }
示例#30
0
  /**
   * Runs {@code folds}-fold cross-validation (seed 1) of {@code c} on {@code data}. Writes the
   * per-instance prediction output collected by the evaluation to {@code cv.log}, then prints
   * per-class details and summary statistics.
   *
   * @param c the classifier to cross-validate
   * @param data the data, with its class attribute set
   * @param folds number of cross-validation folds
   * @throws Exception if cross-validation or writing {@code cv.log} fails
   */
  static void evaluateClassifier(Classifier c, Instances data, int folds) throws Exception {
    System.err.println(
        "INFO: Starting crossvalidation to predict '"
            + data.classAttribute().name()
            + "' using '"
            + c.getClass().getCanonicalName()
            + ":"
            + Arrays.toString(c.getOptions())
            + "' ...");

    // Collects the prediction output; presumably the Range selects the extra attribute(s) to
    // print alongside each prediction, and Boolean.FALSE disables printing distributions.
    StringBuffer sb = new StringBuffer();
    Evaluation eval = new Evaluation(data);
    eval.crossValidateModel(c, data, folds, new Random(1), sb, new Range("first"), Boolean.FALSE);

    // Write predictions to file; try-with-resources closes it even if writing fails.
    try (Writer out = new FileWriter("cv.log")) {
      out.write(sb.toString());
    }

    System.out.println(eval.toClassDetailsString());
    System.out.println(eval.toSummaryString("\nResults\n======\n", false));
  }