Example 1
  /**
   * Tests with a classifier.
   *
   * @param trainFileName the training data file
   * @param testFileName the test data file
   */
  public static void classify(String trainFileName, String testFileName) {
    try {
      File inputFile = new File(fileName + trainFileName); // training data file
      ArffLoader atf = new ArffLoader();
      atf.setFile(inputFile);
      Instances instancesTrain = atf.getDataSet(); // load the training data

      // test data file
      inputFile = new File(fileName + testFileName);
      atf.setFile(inputFile);
      Instances instancesTest = atf.getDataSet(); // load the test data

      // the class attribute is the last attribute in both sets
      instancesTest.setClassIndex(instancesTest.numAttributes() - 1);
      instancesTrain.setClassIndex(instancesTrain.numAttributes() - 1);

      classifier = (Classifier) Class.forName(CLASSIFIERNAME).newInstance();
      classifier.buildClassifier(instancesTrain);

      Evaluation eval = new Evaluation(instancesTrain);
      // the first argument is a trained classifier, the second is the data set to evaluate it on
      eval.evaluateModel(classifier, instancesTest);

      System.out.println(eval.toClassDetailsString());
      System.out.println(eval.toSummaryString());
      System.out.println(eval.toMatrixString());
      System.out.println("precision is :" + (1 - eval.errorRate()));

    } catch (Exception e) {
      e.printStackTrace();
    }
  }
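A minimal invocation sketch; fileName and CLASSIFIERNAME are fields assumed to be defined elsewhere in the class (e.g. a data directory prefix and a fully qualified classifier name such as "weka.classifiers.bayes.NaiveBayes"):

  // hypothetical call; both arguments name ARFF files under the fileName prefix
  classify("train.arff", "test.arff");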
Example 2
  /**
   * Parses a given list of options.
   *
   * <p>
   * <!-- options-start -->
   * Valid options are:
   *
   * <p>
   *
   * <pre> -i &lt;the input file&gt;
   * The input file</pre>
   *
   * <pre> -o &lt;the output file&gt;
   * The output file</pre>
   *
   * <pre> -c &lt;the class index&gt;
   * The class index</pre>
   *
   * <!-- options-end -->
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {

    String outputString = Utils.getOption('o', options);
    String inputString = Utils.getOption('i', options);
    String indexString = Utils.getOption('c', options);

    ArffLoader loader = new ArffLoader();

    resetOptions();

    // parse index
    int index = -1;
    if (indexString.length() != 0) {
      if (indexString.equals("first")) index = 0;
      else {
        if (indexString.equals("last")) index = -1;
        else index = Integer.parseInt(indexString);
      }
    }

    if (inputString.length() != 0) {
      try {
        File input = new File(inputString);
        loader.setFile(input);
        Instances inst = loader.getDataSet();
        if (index == -1) inst.setClassIndex(inst.numAttributes() - 1);
        else inst.setClassIndex(index);
        setInstances(inst);
      } catch (Exception ex) {
        throw new IOException(
            "No data set loaded. Data set has to be arff format (Reason: " + ex.toString() + ").");
      }
    } else throw new IOException("No data set to save.");

    if (outputString.length() != 0) {
      // add appropriate file extension
      if (!outputString.endsWith(getFileExtension())) {
        if (outputString.lastIndexOf('.') != -1)
          outputString =
              (outputString.substring(0, outputString.lastIndexOf('.'))) + getFileExtension();
        else outputString = outputString + getFileExtension();
      }
      try {
        File output = new File(outputString);
        setFile(output);
      } catch (Exception ex) {
        throw new IOException("Cannot create output file.");
      }
    }

    if (index == -1) index = getInstances().numAttributes() - 1;
    getInstances().setClassIndex(index);
  }
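A usage sketch for this setOptions implementation, assuming it belongs to a Weka saver-style class (the name MySaver below is hypothetical); weka.core.Utils.splitOptions turns a command-line style option string into the expected String[]:

    // splitOptions and setOptions both declare throws Exception
    MySaver saver = new MySaver();
    String[] options = weka.core.Utils.splitOptions("-i iris.arff -o iris.out -c last");
    saver.setOptions(options);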
Example 3
  private void loadExistingData() {
    if (isExternalStorageAvailable()) {
      try {
        if (file.exists()) {
          ArffLoader loader = new ArffLoader();
          loader.setFile(file);

          Instances existingData = loader.getDataSet();
          addManyInstances(existingData);
        }
      } catch (IOException e) {
        e.printStackTrace();
      }
    }
  }
Example 4
  public static void main(String[] args) throws Exception {

    /*
     * First we load the test data from our ARFF file
     */
    ArffLoader testLoader = new ArffLoader();
    testLoader.setSource(new File("data/titanic/test.arff"));
    testLoader.setRetrieval(Loader.BATCH);
    Instances testDataSet = testLoader.getDataSet();

    /*
     * Now we tell the data set which attribute we want to classify, in our
     * case, we want to classify the first column: survived
     */
    Attribute testAttribute = testDataSet.attribute(0);
    testDataSet.setClass(testAttribute);
    testDataSet.deleteStringAttributes();

    /*
     * Now we read in the serialized model from disk
     */
    Classifier classifier = (Classifier) SerializationHelper.read("data/titanic/titanic.model");

    /*
     * This part may be a little confusing. We load up the test data again
     * so we have a prediction data set to populate. As we iterate over the
     * first data set we also iterate over the second data set. After an
     * instance is classified, we set the value of the prediction data set
     * to be the value of the classification
     */
    ArffLoader test1Loader = new ArffLoader();
    test1Loader.setSource(new File("data/titanic/test.arff"));
    Instances test1DataSet = test1Loader.getDataSet();
    Attribute test1Attribute = test1DataSet.attribute(0);
    test1DataSet.setClass(test1Attribute);

    /*
     * Now we iterate over the test data and classify each entry and set the
     * value of the 'survived' column to the result of the classification
     */
    Enumeration testInstances = testDataSet.enumerateInstances();
    Enumeration test1Instances = test1DataSet.enumerateInstances();
    while (testInstances.hasMoreElements()) {
      Instance instance = (Instance) testInstances.nextElement();
      Instance instance1 = (Instance) test1Instances.nextElement();
      double classification = classifier.classifyInstance(instance);
      instance1.setClassValue(classification);
    }

    /*
     * Now we want to write out our predictions. The resulting file is in a
     * format suitable to submit to Kaggle.
     */
    CSVSaver predictedCsvSaver = new CSVSaver();
    predictedCsvSaver.setFile(new File("data/titanic/predict.csv"));
    predictedCsvSaver.setInstances(test1DataSet);
    predictedCsvSaver.writeBatch();

    System.out.println("Prediciton saved to predict.csv");
  }
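If an ARFF copy of the predictions is wanted instead of CSV, weka.core.converters.ArffSaver follows the same setFile/setInstances/writeBatch pattern (a sketch, not part of the original example):

    ArffSaver predictedArffSaver = new ArffSaver();
    predictedArffSaver.setFile(new File("data/titanic/predict.arff"));
    predictedArffSaver.setInstances(test1DataSet);
    predictedArffSaver.writeBatch();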
Example 5
  public static void main(String args[]) throws Exception {
    ArffLoader trainLoader = new ArffLoader();
    trainLoader.setSource(new File("src/train.arff"));
    trainLoader.setRetrieval(Loader.BATCH);
    Instances trainDataSet = trainLoader.getDataSet();
    weka.core.Attribute trainAttribute = trainDataSet.attribute("class");

    trainDataSet.setClass(trainAttribute);
    // trainDataSet.deleteStringAttributes();

    NaiveBayes classifier = new NaiveBayes();

    final double startTime = System.currentTimeMillis();
    classifier.buildClassifier(trainDataSet);
    final double endTime = System.currentTimeMillis();
    double executionTime = (endTime - startTime) / (1000.0);
    System.out.println("Total execution time: " + executionTime);

    SerializationHelper.write("NaiveBayes.model", classifier);
    System.out.println("Saved trained model to classifier.model");
  }
Example 6
 /**
  * Main method.
  *
  * @param args should contain the name of an input file.
  */
 public static void main(String[] args) {
   if (args.length > 0) {
     File inputfile;
     inputfile = new File(args[0]);
     try {
       ArffLoader atf = new ArffLoader();
       atf.setSource(inputfile);
       System.out.println(atf.getStructure());
       Instance temp;
       do {
         temp = atf.getNextInstance();
         if (temp != null) {
           System.out.println(temp);
         }
       } while (temp != null);
     } catch (Exception ex) {
       ex.printStackTrace();
     }
   } else {
     System.err.println("Usage:\n\tArffLoader <file.arff>\n");
   }
 }
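In more recent Weka releases the no-argument getNextInstance() was replaced by a variant that takes the header structure as a parameter; a minimal sketch of the same loop against that API:

       Instances structure = atf.getStructure();
       Instance inst;
       while ((inst = atf.getNextInstance(structure)) != null) {
         System.out.println(inst);
       }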
Example 7
  public static void main(String[] args) throws Exception {
    // NaiveBayesSimple nb = new NaiveBayesSimple();

    //		BufferedReader br_train = new BufferedReader(new FileReader("src/train.arff.txt"));
    //		String s = null;
    //		long st_time = System.currentTimeMillis();
    //		Instances inst_train = new Instances(br_train);
    //		System.out.println(inst_train.numAttributes());
    //		inst_train.setClassIndex(inst_train.numAttributes()-1);
    //		System.out.println("train time"+(System.currentTimeMillis()-st_time));
    // NaiveBayes nb1 = new NaiveBayes();
    // nb1.buildClassifier(inst_train);
    // br_train.close();
    long st_time = System.currentTimeMillis();

    Classifier classifier = (Classifier) SerializationHelper.read("NaiveBayes.model");

    //		BufferedReader br_test = new BufferedReader(new FileReader("src/test.arff.txt"));
    //		Instances inst_test = new Instances(br_test);
    //		inst_test.setClassIndex(inst_test.numAttributes()-1);
    //		System.out.println("test time"+(System.currentTimeMillis()-st_time));
    //

    ArffLoader testLoader = new ArffLoader();
    testLoader.setSource(new File("src/test.arff"));
    testLoader.setRetrieval(Loader.BATCH);
    Instances testDataSet = testLoader.getDataSet();

    Attribute testAttribute = testDataSet.attribute("class");
    testDataSet.setClass(testAttribute);

    int correct = 0;
    int incorrect = 0;
    FastVector attInfo = new FastVector();
    attInfo.addElement(new Attribute("Id"));
    attInfo.addElement(new Attribute("Category"));

    Instances outputInstances = new Instances("predict", attInfo, testDataSet.numInstances());

    Enumeration testInstances = testDataSet.enumerateInstances();
    int index = 1;
    while (testInstances.hasMoreElements()) {
      Instance instance = (Instance) testInstances.nextElement();
      double classification = classifier.classifyInstance(instance);
      Instance predictInstance = new Instance(outputInstances.numAttributes());
      predictInstance.setValue(0, index++);
      predictInstance.setValue(1, (int) classification + 1);
      outputInstances.add(predictInstance);
    }

    System.out.println("Correct Instance: " + correct);
    System.out.println("IncCorrect Instance: " + incorrect);
    double accuracy = (double) (correct) / (double) (correct + incorrect);
    System.out.println("Accuracy: " + accuracy);
    CSVSaver predictedCsvSaver = new CSVSaver();
    predictedCsvSaver.setFile(new File("predict.csv"));
    predictedCsvSaver.setInstances(outputInstances);
    predictedCsvSaver.writeBatch();

    System.out.println("Prediciton saved to predict.csv");
  }
Example 8
  public double getLiblinear(String path, String train, String test) {
    // accuracy for this run
    double accuracy = 0.0;

    try {
      LibLINEAR c1 = new LibLINEAR();

      // String[] options = weka.core.Utils.splitOptions("-S 1 -C 1.0 -E 0.001 -B 0");
      // c1.setOptions(options);

      ArffLoader atf = new ArffLoader();
      File TraininputFile = new File(train);
      atf.setFile(TraininputFile); // training data file
      Instances instancesTrain = atf.getDataSet(); // load the training data
      instancesTrain.setClassIndex(instancesTrain.numAttributes() - 1);

      File TestinputFile = new File(test);
      atf.setFile(TestinputFile); // test data file
      Instances instancesTest = atf.getDataSet(); // load the test data
      // set the class attribute index (the first attribute is index 0); instancesTest.numAttributes() gives the attribute count
      instancesTest.setClassIndex(instancesTest.numAttributes() - 1);

      c1.buildClassifier(instancesTrain); // train

      Evaluation eval = new Evaluation(instancesTrain);
      eval.evaluateModel(c1, instancesTest);
      // eval.crossValidateModel(c1, instancesTrain, 10, new
      // Random(1));
      File newfile = new File(path + "OutLiblinear_temp" + ".txt");

      BufferedWriter bufferedWriter =
          new BufferedWriter(new OutputStreamWriter(new FileOutputStream(newfile), "utf-8"));

      bufferedWriter.write(eval.toSummaryString() + "\r\n");
      bufferedWriter.write(eval.toClassDetailsString() + "\r\n");
      bufferedWriter.write(eval.toMatrixString() + "\r\n");

      bufferedWriter.flush();
      bufferedWriter.close();

      BufferedReader bufferedReader = new BufferedReader(new FileReader(newfile));
      String[] splitLineString = new String[5];
      while (bufferedReader.ready()) {
        bufferedReader.readLine();
        String lineString = bufferedReader.readLine();
        splitLineString = lineString.split(" ");
        System.out.println(splitLineString[4]);
        break;
      }
      bufferedReader.close();

      // extract the classification accuracy
      String tempLine;
      BufferedReader tempBF = new BufferedReader(new FileReader(newfile));
      while (tempBF.ready()) {
        tempLine = tempBF.readLine();
        if (tempLine.contains("Correctly Classified Instances")) {
          tempLine = tempLine.substring(tempLine.lastIndexOf(".") - 2, tempLine.lastIndexOf(" "));
          accuracy = Double.parseDouble(tempLine);
          break;
        }
      }

      tempBF.close();

    } catch (Exception e) {
      System.out.println("Can't run linlinear of weka.");
    }

    return accuracy;
  }
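Parsing the accuracy back out of the temporary summary file can be avoided: the Evaluation object already exposes the same figure via pctCorrect(), the percentage of correctly classified instances, so a simpler sketch is:

      // same value that the text parsing above tries to recover (as a percentage)
      double accuracy = eval.pctCorrect();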
Example 9
  public static void main(String[] args) throws Exception {

    /*
     * First we load our predictions from the CSV formatted file.
     */
    CSVLoader predictCsvLoader = new CSVLoader();
    predictCsvLoader.setSource(new File("predict.csv"));

    /*
     * Since we are not using the ARFF format here, we have to give the
     * loader a little bit of information about the data types. Columns
     * 3,8,10 need to be of type string and columns 1,4,11 are nominal
     * types.
     */
    predictCsvLoader.setStringAttributes("3,8,10");
    predictCsvLoader.setNominalAttributes("1,4,11");
    Instances predictDataSet = predictCsvLoader.getDataSet();

    /*
     * Here we set the attribute we want to test the predictions with
     */
    Attribute testAttribute = predictDataSet.attribute(0);
    predictDataSet.setClass(testAttribute);

    /*
     * We still have to remove all string attributes before we can test
     */
    predictDataSet.deleteStringAttributes();

    /*
     * Next we load the training data from our ARFF file
     */
    ArffLoader trainLoader = new ArffLoader();
    trainLoader.setSource(new File("train.arff"));
    trainLoader.setRetrieval(Loader.BATCH);
    Instances trainDataSet = trainLoader.getDataSet();

    /*
     * Now we tell the data set which attribute we want to classify, in our
     * case, we want to classify the first column: survived
     */
    Attribute trainAttribute = trainDataSet.attribute(0);
    trainDataSet.setClass(trainAttribute);

    /*
     * The RandomForest implementation cannot handle columns of type string,
     * so we remove them for now.
     */
    trainDataSet.deleteStringAttributes();

    /*
     * Now we read in the serialized model from disk
     */
    Classifier classifier = (Classifier) SerializationHelper.read("titanic.model");

    /*
     * Next we will use an Evaluation class to evaluate the performance of
     * our Classifier.
     */
    Evaluation evaluation = new Evaluation(trainDataSet);
    evaluation.evaluateModel(classifier, predictDataSet, new Object[] {});

    /*
     * After we evaluate the Classifier, we write out the summary
     * information to the screen.
     */
    System.out.println(classifier);
    System.out.println(evaluation.toSummaryString());
  }
Example 10
  public JSONArray Cluster(String wekaFilePath, int clusterNum) throws Exception {
    File inputFile = new File(wekaFilePath);
    ArffLoader arf = new ArffLoader();
    arf.setFile(inputFile);
    Instances originIns = arf.getDataSet();
    Instances insTest = new Instances(originIns);
    insTest.deleteStringAttributes();
    int totalNum = insTest.numInstances();

    // SimpleKMeans sm = new SimpleKMeans();
    EM em = new EM();
    em.setNumClusters(clusterNum);
    MakeDensityBasedClusterer sm = new MakeDensityBasedClusterer();
    sm.setClusterer(em);
    sm.buildClusterer(insTest);

    System.out.println("totalNum:" + insTest.numInstances());
    System.out.println("============================");
    System.out.println(sm.toString());
    Map<Integer, ArrayList<String>> result = new HashMap<Integer, ArrayList<String>>();
    for (int i = 0; i < clusterNum; i++) {
      result.put(i, new ArrayList<String>());
    }

    for (int i = 0; i < totalNum; i++) {
      Instance ins = originIns.instance(i);
      String word = ins.stringValue(0);
      Instance tempIns = new Instance(ins);
      tempIns.deleteAttributeAt(0);
      int cluster = sm.clusterInstance(tempIns);
      result.get(cluster).add(word);
    }

    // print the result
    ArrayList<String> words = new ArrayList<String>();
    JSONArray keyWords = new JSONArray();
    for (int k : result.keySet()) {
      words = result.get(k);
      PriorityQueue<MyTerm> clusterQueue = new PriorityQueue<MyTerm>(1, MyTermCompare);
      for (int i = 0; i < words.size(); i++) {
        String s = words.get(i);
        assert linkMap.containsKey(s);
        int freq = linkMap.get(s).totalFreq;
        clusterQueue.add(linkMap.get(s));
        words.set(i, "(" + s + ":" + freq + ")");
      }

      JSONArray clusterArray = new JSONArray();
      int num = clusterQueue.size() / 10 + 1; // roughly the top 10% of terms
      int totalFreq = 0;
      int totalLength = 0;
      for (int i = 0; i < num && !clusterQueue.isEmpty(); ) {
        JSONObject mem = new JSONObject();
        MyTerm myTerm = clusterQueue.poll();
        String word = myTerm.originTrem.text();
        if (word.length() == 1) {
          continue;
        }
        mem.put("text", word);
        mem.put("freq", myTerm.totalFreq);
        clusterArray.put(mem);
        i++;
        totalFreq += myTerm.totalFreq;
        totalLength += word.length();
      }

      double averFreq = totalFreq * 1.0 / num;
      double averLength = totalLength * 1.0 / num;
      int count = 0;
      while (!clusterQueue.isEmpty() && count < num) {
        MyTerm myTerm = clusterQueue.poll();
        String word = myTerm.originTrem.text();
        int freq = myTerm.totalFreq;
        int times = (int) (word.length() / averFreq) + 1;
        if (freq > averFreq / times) {
          JSONObject mem = new JSONObject();
          mem.put("text", word);
          mem.put("freq", freq);
          mem.put("extra", true);
          clusterArray.put(mem);
          count++; // cap the number of extra terms at num
        }
      }

      keyWords.put(clusterArray);
      System.out.println(
          "cluster" + k + ":" + words.size() + ":\t" + (int) (words.size() * 1.0 / totalNum * 100));
      if (result.get(k).size() < 100) {
        System.out.println(result.get(k));
      }
    }
    // System.out.println("errorNum:"+errorNum);
    return keyWords;
  }
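The commented-out SimpleKMeans line hints at a k-means alternative to EM; a minimal sketch of swapping it in on the same string-free Instances (hypothetical, not part of the original method):

    SimpleKMeans km = new SimpleKMeans();
    km.setNumClusters(clusterNum);
    km.buildClusterer(insTest);
    // assign a single (string-attribute-free) instance, as in the loop above
    int cluster = km.clusterInstance(tempIns);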