예제 #1
0
  /**
   * Classifies every row of the Titanic test ARFF file with a previously
   * serialized model and saves the rows, with the class column filled in by
   * the model, as a CSV file.
   *
   * @param args command-line arguments (unused)
   * @throws Exception if loading the data or model, classifying, or saving fails
   */
  public static void main(String[] args) throws Exception {

    // Copy of the test data handed to the classifier. The first attribute is
    // declared as the class, and string attributes are dropped before the
    // instances are classified.
    ArffLoader featureLoader = new ArffLoader();
    featureLoader.setSource(new File("data/titanic/test.arff"));
    featureLoader.setRetrieval(Loader.BATCH);
    Instances features = featureLoader.getDataSet();
    features.setClass(features.attribute(0));
    features.deleteStringAttributes();

    // Restore the trained classifier from its serialized form.
    Classifier model = (Classifier) SerializationHelper.read("data/titanic/titanic.model");

    // Second, untouched copy of the same file. It keeps its string attributes
    // and receives the predicted class values, so the saved output still
    // contains every original column.
    ArffLoader outputLoader = new ArffLoader();
    outputLoader.setSource(new File("data/titanic/test.arff"));
    Instances predictions = outputLoader.getDataSet();
    predictions.setClass(predictions.attribute(0));

    // Both data sets come from the same file, so row i of one corresponds to
    // row i of the other: classify the stripped row, store the result in the
    // full row.
    int rowCount = features.numInstances();
    for (int row = 0; row < rowCount; row++) {
      Instance input = features.instance(row);
      Instance output = predictions.instance(row);
      double predictedClass = model.classifyInstance(input);
      output.setClassValue(predictedClass);
    }

    // Persist the populated data set as CSV (the original author intended
    // this file for a Kaggle submission).
    CSVSaver csvOut = new CSVSaver();
    csvOut.setFile(new File("data/titanic/predict.csv"));
    csvOut.setInstances(predictions);
    csvOut.writeBatch();

    System.out.println("Prediciton saved to predict.csv");
  }
예제 #2
0
  /**
   * Loads the predictions CSV and the training ARFF file, restores the
   * serialized classifier, evaluates it against the predictions data set and
   * prints the classifier together with the evaluation summary.
   *
   * @param args command-line arguments (unused)
   * @throws Exception if any file cannot be loaded or the evaluation fails
   */
  public static void main(String[] args) throws Exception {

    // CSV carries no type information, so the loader is told explicitly which
    // columns are strings (3, 8, 10) and which are nominal (1, 4, 11).
    CSVLoader csvIn = new CSVLoader();
    csvIn.setSource(new File("predict.csv"));
    csvIn.setStringAttributes("3,8,10");
    csvIn.setNominalAttributes("1,4,11");
    Instances predicted = csvIn.getDataSet();

    // The first column is the class to evaluate; string attributes have to be
    // removed before the data set can be evaluated.
    predicted.setClass(predicted.attribute(0));
    predicted.deleteStringAttributes();

    // Training data, loaded in batch mode, with the same class column and with
    // string attributes removed as well.
    ArffLoader arffIn = new ArffLoader();
    arffIn.setSource(new File("train.arff"));
    arffIn.setRetrieval(Loader.BATCH);
    Instances training = arffIn.getDataSet();
    training.setClass(training.attribute(0));
    training.deleteStringAttributes();

    // Restore the trained classifier from disk.
    Classifier model = (Classifier) SerializationHelper.read("titanic.model");

    // The Evaluation is initialised from the training data and then run on the
    // predictions data set; no prediction-output options are passed.
    Evaluation report = new Evaluation(training);
    report.evaluateModel(model, predicted, new Object[0]);

    // Print the model description followed by the summary statistics.
    System.out.println(model);
    System.out.println(report.toSummaryString());
  }
예제 #3
0
  /**
   * Clusters the instances of an ARFF file with EM (wrapped in a
   * {@code MakeDensityBasedClusterer}) and returns, for every cluster, a JSON
   * array describing its selected terms.
   *
   * <p>Attribute 0 of each instance is read as the term text; the clusterer
   * itself is trained on a copy of the data with all string attributes removed.
   * Each term is looked up in {@code linkMap} and ranked through a priority
   * queue ordered by {@code MyTermCompare}. The first {@code size / 10 + 1}
   * terms polled (single-character terms are skipped) are emitted as
   * {@code {text, freq}}; after that, up to the same number of further terms
   * whose frequency passes a threshold are emitted with {@code extra: true}.
   *
   * <p>Progress and per-cluster statistics are printed to standard output.
   *
   * @param wekaFilePath path of the ARFF file to cluster
   * @param clusterNum number of clusters handed to {@code EM#setNumClusters}
   * @return a JSON array holding one JSON array of term objects per cluster
   * @throws IllegalStateException if a clustered term has no entry in {@code linkMap}
   * @throws Exception if loading the file or building/applying the clusterer fails
   */
  public JSONArray Cluster(String wekaFilePath, int clusterNum) throws Exception {
    ArffLoader arf = new ArffLoader();
    arf.setFile(new File(wekaFilePath));
    Instances originIns = arf.getDataSet();

    // Train on a copy without string attributes; the original data set keeps
    // them so the term text can still be read for every instance below.
    Instances insTest = new Instances(originIns);
    insTest.deleteStringAttributes();
    int totalNum = insTest.numInstances();

    EM em = new EM();
    em.setNumClusters(clusterNum);
    MakeDensityBasedClusterer sm = new MakeDensityBasedClusterer();
    sm.setClusterer(em);
    sm.buildClusterer(insTest);

    System.out.println("totalNum:" + insTest.numInstances());
    System.out.println("============================");
    System.out.println(sm.toString());

    // Pre-register the requested clusters so that empty clusters still show up
    // in the output.
    Map<Integer, ArrayList<String>> result = new HashMap<Integer, ArrayList<String>>();
    for (int i = 0; i < clusterNum; i++) {
      result.put(i, new ArrayList<String>());
    }

    for (int i = 0; i < totalNum; i++) {
      Instance ins = originIns.instance(i);
      String word = ins.stringValue(0);
      // Work on a detached copy with attribute 0 removed before asking the
      // clusterer for an assignment.
      Instance tempIns = new Instance(ins);
      tempIns.deleteAttributeAt(0);
      int cluster = sm.clusterInstance(tempIns);

      // Create the bucket on demand: the clusterer may report an index that was
      // not pre-registered (e.g. when clusterNum is not positive and the loop
      // above registered nothing), which previously caused a
      // NullPointerException here.
      ArrayList<String> members = result.get(cluster);
      if (members == null) {
        members = new ArrayList<String>();
        result.put(cluster, members);
      }
      members.add(word);
    }

    JSONArray keyWords = new JSONArray();
    for (Map.Entry<Integer, ArrayList<String>> entry : result.entrySet()) {
      int k = entry.getKey();
      ArrayList<String> words = entry.getValue();

      PriorityQueue<MyTerm> clusterQueue = new PriorityQueue<MyTerm>(1, MyTermCompare);
      for (int i = 0; i < words.size(); i++) {
        String s = words.get(i);
        MyTerm term = linkMap.get(s);
        if (term == null) {
          // An assert alone is skipped when assertions are disabled and would
          // leave an anonymous NullPointerException; fail with the term instead.
          throw new IllegalStateException("No linkMap entry for clustered term: " + s);
        }
        clusterQueue.add(term);
        // The list entry is replaced by a "(term:freq)" label, which is what
        // gets printed for small clusters at the end of this iteration.
        words.set(i, "(" + s + ":" + term.totalFreq + ")");
      }

      JSONArray clusterArray = new JSONArray();
      int num = clusterQueue.size() / 10 + 1; // about 10% of the cluster, at least one term
      int totalFreq = 0;
      int totalLength = 0;
      // i only advances for accepted terms; single-character terms are polled
      // and dropped without counting towards num.
      for (int i = 0; i < num && !clusterQueue.isEmpty(); ) {
        MyTerm myTerm = clusterQueue.poll();
        String word = myTerm.originTrem.text();
        if (word.length() == 1) {
          continue;
        }
        JSONObject mem = new JSONObject();
        mem.put("text", word);
        mem.put("freq", myTerm.totalFreq);
        clusterArray.put(mem);
        i++;
        totalFreq += myTerm.totalFreq;
        totalLength += word.length();
      }

      double averFreq = totalFreq * 1.0 / num;
      // NOTE(review): averLength is computed but never used, while `times`
      // below divides a word length by averFreq. This looks like averLength was
      // intended there - confirm with the author before changing the ranking.
      double averLength = totalLength * 1.0 / num;

      // Add further terms whose frequency clears a length-scaled threshold.
      // count tracks how many extras were added so that the count < num bound
      // actually takes effect; it was never incremented before, so the loop
      // always drained the whole queue.
      int count = 0;
      while (!clusterQueue.isEmpty() && count < num) {
        MyTerm myTerm = clusterQueue.poll();
        String word = myTerm.originTrem.text();
        int freq = myTerm.totalFreq;
        int times = (int) (word.length() / averFreq) + 1;
        if (freq > averFreq / times) {
          JSONObject mem = new JSONObject();
          mem.put("text", word);
          mem.put("freq", freq);
          mem.put("extra", true);
          clusterArray.put(mem);
          count++;
        }
      }

      keyWords.put(clusterArray);
      // Cluster id, member count and share of all instances in percent.
      System.out.println(
          "cluster" + k + ":" + words.size() + ":\t" + (int) (words.size() * 1.0 / totalNum * 100));
      if (words.size() < 100) {
        System.out.println(words);
      }
    }
    return keyWords;
  }