Example No. 1
  public static void main(String[] args) throws Exception {

    /*
     * First we load the test data from our ARFF file
     */
    ArffLoader testLoader = new ArffLoader();
    testLoader.setSource(new File("data/titanic/test.arff"));
    testLoader.setRetrieval(Loader.BATCH);
    Instances testDataSet = testLoader.getDataSet();

    /*
     * Now we tell the data set which attribute we want to classify, in our
     * case, we want to classify the first column: survived
     */
    Attribute testAttribute = testDataSet.attribute(0);
    testDataSet.setClass(testAttribute);
    testDataSet.deleteStringAttributes();

    /*
     * Now we read in the serialized model from disk
     */
    Classifier classifier = (Classifier) SerializationHelper.read("data/titanic/titanic.model");

    /*
     * This part may be a little confusing. We load up the test data again
     * so we have a prediction data set to populate. As we iterate over the
     * first data set we also iterate over the second data set. After an
     * instance is classified, we set the value of the prediction data set
     * to be the value of the classification
     */
    ArffLoader test1Loader = new ArffLoader();
    test1Loader.setSource(new File("data/titanic/test.arff"));
    Instances test1DataSet = test1Loader.getDataSet();
    Attribute test1Attribute = test1DataSet.attribute(0);
    test1DataSet.setClass(test1Attribute);

    /*
     * Now we iterate over the test data and classify each entry and set the
     * value of the 'survived' column to the result of the classification
     */
    Enumeration testInstances = testDataSet.enumerateInstances();
    Enumeration test1Instances = test1DataSet.enumerateInstances();
    while (testInstances.hasMoreElements()) {
      Instance instance = (Instance) testInstances.nextElement();
      Instance instance1 = (Instance) test1Instances.nextElement();
      double classification = classifier.classifyInstance(instance);
      instance1.setClassValue(classification);
    }

    /*
     * Now we want to write out our predictions. The resulting file is in a
     * format suitable to submit to Kaggle.
     */
    CSVSaver predictedCsvSaver = new CSVSaver();
    predictedCsvSaver.setFile(new File("data/titanic/predict.csv"));
    predictedCsvSaver.setInstances(test1DataSet);
    predictedCsvSaver.writeBatch();

    System.out.println("Prediciton saved to predict.csv");
  }
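
The example above assumes that "data/titanic/titanic.model" already exists on disk. For completeness, here is a minimal sketch of how such a model might be trained and serialized, assuming a training file "data/titanic/train.arff" with the same layout (the class attribute 'survived' in the first column); RandomForest is only a guess based on the comments in Example No. 7 (it requires weka.classifiers.trees.RandomForest in addition to the classes already used above), and any other WEKA classifier would be persisted the same way.

  public static void trainAndSaveModel() throws Exception {
    // Load the training data from ARFF (hypothetical path data/titanic/train.arff)
    ArffLoader trainLoader = new ArffLoader();
    trainLoader.setSource(new File("data/titanic/train.arff"));
    Instances trainDataSet = trainLoader.getDataSet();

    // The first column ('survived') is the class attribute
    trainDataSet.setClass(trainDataSet.attribute(0));

    // RandomForest cannot handle string attributes, so drop them before training
    trainDataSet.deleteStringAttributes();

    // Build the classifier and serialize it to disk
    Classifier classifier = new RandomForest();
    classifier.buildClassifier(trainDataSet);
    SerializationHelper.write("data/titanic/titanic.model", classifier);
  }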
Example No. 2
  public void batchPredict() {
    // load the whole test set
    String modelFile =
        "data\\AcquireValueShopper\\decisionTable_bayes_trees.model".replace("\\", File.separator);
    String pathTest = "data/AcquireValueShopper/test_new.csv";
    String pathPredict = "data/AcquireValueShopper/submission.csv";

    Scanner scanner;
    String line = "";
    String[] partsOfLine = null;
    String id = "";
    PrintWriter output;
    Map<String, String> testSet = new HashMap<String, String>();
    try {
      scanner = new Scanner(new File(pathTest));
      while (scanner.hasNextLine()) {
        line = scanner.nextLine().trim();
        partsOfLine = line.split(",");
        id = partsOfLine[0];
        testSet.put(id, line);
      }
      scanner.close();
    } catch (FileNotFoundException e1) {
      // TODO Auto-generated catch block
      e1.printStackTrace();
    }
    double[] returnProb;
    double prob = 0.0;
    // predict
    try {
      // load model
      Classifier classifier = (Classifier) SerializationHelper.read(modelFile);

      output = new PrintWriter(pathPredict);
      output.append("id,repeatProbability" + "\n");
      Iterator<String> idIterator = testSet.keySet().iterator();
      while (idIterator.hasNext()) {
        id = idIterator.next();
        line = testSet.get(id);
        Instances instances = buildInstance(line);
        Instance instance = instances.instance(0);
        returnProb = classifier.distributionForInstance(instance);
        prob = returnProb[1];
        // prob = classifier.classifyInstance(instance);
        output.append(id + "," + prob + "\n");
      }
      output.close();
    } catch (FileNotFoundException e) {
      // TODO Auto-generated catch block
      e.printStackTrace();
    } catch (Exception e) {
      // TODO Auto-generated catch block
      e.printStackTrace();
    }
  }
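
The buildInstance(String) helper called above is not shown in the original example. A hypothetical sketch follows, using the WEKA 3.6-style Instance/FastVector API seen in the other examples (weka.core.Attribute, FastVector, Instance, Instances): it treats every field after the id as a numeric feature and appends a two-valued nominal class attribute that is left missing. The real feature layout of the AcquireValueShopper data may well differ, so the attribute definitions would need to match the data the model was trained on.

  private Instances buildInstance(String line) {
    String[] parts = line.split(",");

    // Numeric feature attributes, skipping the id in column 0 (hypothetical layout)
    FastVector attributes = new FastVector();
    for (int i = 1; i < parts.length; i++) {
      attributes.addElement(new Attribute("f" + i));
    }

    // Two-valued nominal class attribute, e.g. "0"/"1" for a repeat purchase
    FastVector classValues = new FastVector();
    classValues.addElement("0");
    classValues.addElement("1");
    attributes.addElement(new Attribute("repeat", classValues));

    Instances instances = new Instances("test-line", attributes, 1);
    instances.setClassIndex(instances.numAttributes() - 1);

    // Fill in the feature values; the class value stays missing
    Instance instance = new Instance(instances.numAttributes());
    for (int i = 1; i < parts.length; i++) {
      instance.setValue(i - 1, Double.parseDouble(parts[i]));
    }
    instances.add(instance);
    return instances;
  }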
Example No. 3
  public static void main(String[] args) throws Exception {
    // NaiveBayesSimple nb = new NaiveBayesSimple();

    //		BufferedReader br_train = new BufferedReader(new FileReader("src/train.arff.txt"));
    //		String s = null;
    //		long st_time = System.currentTimeMillis();
    //		Instances inst_train = new Instances(br_train);
    //		System.out.println(inst_train.numAttributes());
    //		inst_train.setClassIndex(inst_train.numAttributes()-1);
    //		System.out.println("train time"+(System.currentTimeMillis()-st_time));
    // NaiveBayes nb1 = new NaiveBayes();
    // nb1.buildClassifier(inst_train);
    // br_train.close();
    long st_time = System.currentTimeMillis();

    Classifier classifier = (Classifier) SerializationHelper.read("NaiveBayes.model");

    //		BufferedReader br_test = new BufferedReader(new FileReader("src/test.arff.txt"));
    //		Instances inst_test = new Instances(br_test);
    //		inst_test.setClassIndex(inst_test.numAttributes()-1);
    //		System.out.println("test time"+(System.currentTimeMillis()-st_time));
    //

    ArffLoader testLoader = new ArffLoader();
    testLoader.setSource(new File("src/test.arff"));
    testLoader.setRetrieval(Loader.BATCH);
    Instances testDataSet = testLoader.getDataSet();

    Attribute testAttribute = testDataSet.attribute("class");
    testDataSet.setClass(testAttribute);

    int correct = 0;
    int incorrect = 0;
    FastVector attInfo = new FastVector();
    attInfo.addElement(new Attribute("Id"));
    attInfo.addElement(new Attribute("Category"));

    Instances outputInstances = new Instances("predict", attInfo, testDataSet.numInstances());

    Enumeration testInstances = testDataSet.enumerateInstances();
    int index = 1;
    while (testInstances.hasMoreElements()) {
      Instance instance = (Instance) testInstances.nextElement();
      double classification = classifier.classifyInstance(instance);
      // Count agreement with the labelled class value (when the test file provides labels)
      if (classification == instance.classValue()) {
        correct++;
      } else {
        incorrect++;
      }
      Instance predictInstance = new Instance(outputInstances.numAttributes());
      predictInstance.setValue(0, index++);
      predictInstance.setValue(1, (int) classification + 1);
      outputInstances.add(predictInstance);
    }

    System.out.println("Correct Instance: " + correct);
    System.out.println("IncCorrect Instance: " + incorrect);
    double accuracy = (double) (correct) / (double) (correct + incorrect);
    System.out.println("Accuracy: " + accuracy);
    CSVSaver predictedCsvSaver = new CSVSaver();
    predictedCsvSaver.setFile(new File("predict.csv"));
    predictedCsvSaver.setInstances(outputInstances);
    predictedCsvSaver.writeBatch();

    System.out.println("Prediciton saved to predict.csv");
  }
Example No. 4
  // Given a question as input, return the category the question belongs to.
  public double classifyByBayes(String question) throws Exception {
    double label = -1;
    List<Question> questionID = questionDAO.getQuestionIDLabeled();

    // Define the data format (attributes)
    Attribute att1 = new Attribute("法律政策");
    Attribute att2 = new Attribute("位置交通");
    Attribute att3 = new Attribute("风水");
    Attribute att4 = new Attribute("房价");
    Attribute att5 = new Attribute("楼层");
    Attribute att6 = new Attribute("户型");
    Attribute att7 = new Attribute("小区配套");
    Attribute att8 = new Attribute("贷款");
    Attribute att9 = new Attribute("买房时机");
    Attribute att10 = new Attribute("开发商");
    FastVector labels = new FastVector();
    labels.addElement("1");
    labels.addElement("2");
    labels.addElement("3");
    labels.addElement("4");
    labels.addElement("5");
    labels.addElement("6");
    labels.addElement("7");
    labels.addElement("8");
    labels.addElement("9");
    labels.addElement("10");
    Attribute att11 = new Attribute("类别", labels);

    FastVector attributes = new FastVector();
    attributes.addElement(att1);
    attributes.addElement(att2);
    attributes.addElement(att3);
    attributes.addElement(att4);
    attributes.addElement(att5);
    attributes.addElement(att6);
    attributes.addElement(att7);
    attributes.addElement(att8);
    attributes.addElement(att9);
    attributes.addElement(att10);
    attributes.addElement(att11);
    Instances dataset = new Instances("Test-dataset", attributes, 0);
    dataset.setClassIndex(10);

    Classifier classifier = null;
    if (!new File("naivebayes.model").exists()) {
      // Add the training data
      double[] values = new double[11];
      for (int i = 0; i < questionID.size(); i++) {
        for (int m = 0; m < 11; m++) {
          values[m] = 0;
        }
        int whitewordcount = 0;
        whitewordcount = questionDAO.getHitWhiteWordNum(questionID.get(i).getId());
        if (whitewordcount != 0) {
          List<QuestionWhiteWord> questionwhiteword =
              questionDAO.getHitQuestionWhiteWord(questionID.get(i).getId());
          for (int j = 0; j < questionwhiteword.size(); j++) {
            values[getAttIndex(questionwhiteword.get(j).getWordId()) - 1]++;
          }
          for (int m = 0; m < 11; m++) {
            values[m] = values[m] / whitewordcount;
          }
        }
        values[10] = questionID.get(i).getType() - 1;
        Instance inst = new Instance(1.0, values);
        dataset.add(inst);
      }
      // Build the classifier
      classifier = new NaiveBayes();
      classifier.buildClassifier(dataset);
      SerializationHelper.write("naivebayes.model", classifier);
    } else {
      classifier = (Classifier) SerializationHelper.read("naivebayes.model");
    }

    System.out.println("*************begin evaluation*******************");
    Evaluation evaluation = new Evaluation(dataset);
    evaluation.evaluateModel(classifier, dataset); // Strictly speaking, a separate data set should be used here, not the training data.
    System.out.println(evaluation.toSummaryString());

    // 分类
    System.out.println("*************begin classification*******************");
    Instance subject = new Instance(1.0, getQuestionVector(question));
    subject.setDataset(dataset);
    label = classifier.classifyInstance(subject);
    System.out.println("label: " + label);

    //        double dis[]=classifier.distributionForInstance(inst);
    //        for(double i:dis){
    //            System.out.print(i+"    ");
    //        }

    System.out.println(questionID.size());
    return label + 1;
  }
Example No. 5
 public void loadModel(String filename) throws Exception {
   classifier = (Classifier) SerializationHelper.read(filename);
 }
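
A matching method for persisting a trained classifier (mirroring the SerializationHelper.write call used in Example No. 4) could look like the sketch below; the method name saveModel and the classifier field are assumptions.

 public void saveModel(String filename) throws Exception {
   // Serialize the trained classifier to disk (counterpart of loadModel above)
   SerializationHelper.write(filename, classifier);
 }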
Example No. 6
    @Override
    public Void doInBackground() {
      BufferedReader reader;
      publish("Computing features...");
      int testingSamples = p.getAllFeatures2(path, "testing_data");

      try {
        publish("Reading data...");

        reader = new BufferedReader(new FileReader("testing_data.arff"));
        final Instances testingdata = new Instances(reader);
        reader.close();
        // setting class attribute
        testingdata.setClassIndex(13);
        testingdata.randomize(new Random(1));
        long startTime = System.nanoTime();
        Classifier ann = (Classifier) weka.core.SerializationHelper.read("mlp.model");
        publish("Evaluating ANN...");

        evalANN = new Evaluation(testingdata);
        startTime = System.nanoTime();
        evalANN.evaluateModel(ann, testingdata);
        long runningTimeANN = (System.nanoTime() - startTime) / 1000000;
        // runningTimeANN /= 100;

        publish("Done evaluating ANN");

        publish("Evaluating SVM...");
        Classifier svm = (Classifier) weka.core.SerializationHelper.read("svm.model");

        evalSVM = new Evaluation(testingdata);
        startTime = System.nanoTime();
        evalSVM.evaluateModel(svm, testingdata);
        long runningTimeSVM = (System.nanoTime() - startTime) / 1000000;
        // runningTimeSVM /= 100;
        publish("Done evaluating SVM");

        publish("Evaluating NB...");
        Classifier nb = (Classifier) weka.core.SerializationHelper.read("naivebayes.model");

        evalNB = new Evaluation(testingdata);
        startTime = System.nanoTime();
        evalNB.evaluateModel(nb, testingdata);
        long runningTimeNB = (System.nanoTime() - startTime) / 1000000;
        // runningTimeNB /= 100;
        publish("Done evaluating ANN");

        Platform.runLater(
            new Runnable() {
              @Override
              public void run() {
                bc.getData()
                    .get(0)
                    .getData()
                    .get(0)
                    .setYValue(evalANN.correct() / testingdata.size() * 100);
                bc.getData()
                    .get(0)
                    .getData()
                    .get(1)
                    .setYValue(evalSVM.correct() / testingdata.size() * 100);
                bc.getData()
                    .get(0)
                    .getData()
                    .get(2)
                    .setYValue(evalNB.correct() / testingdata.size() * 100);

                for (int i = 0; i < NUM_CLASSES; i++) {
                  lineChart.getData().get(0).getData().get(i).setYValue(evalANN.recall(i) * 100);
                  lineChart.getData().get(1).getData().get(i).setYValue(evalSVM.recall(i) * 100);
                  lineChart.getData().get(2).getData().get(i).setYValue(evalNB.recall(i) * 100);
                }
              }
            });

        panel.fillConfTable(evalSVM.confusionMatrix());

        summaryTable.setValueAt(evalANN.correct() / testingdata.size() * 100., 0, 1);
        summaryTable.setValueAt(evalSVM.correct() / testingdata.size() * 100, 0, 2);
        summaryTable.setValueAt(evalNB.correct() / testingdata.size() * 100, 0, 3);

        summaryTable.setValueAt(runningTimeANN, 1, 1);
        summaryTable.setValueAt(runningTimeSVM, 1, 2);
        summaryTable.setValueAt(runningTimeNB, 1, 3);

      } catch (Exception e1) {
        // TODO Auto-generated catch block
        e1.printStackTrace();
      }
      return null;
    }
Example No. 7
  public static void main(String[] args) throws Exception {

    /*
     * First we load our predictions from the CSV-formatted file.
     */
    CSVLoader predictCsvLoader = new CSVLoader();
    predictCsvLoader.setSource(new File("predict.csv"));

    /*
     * Since we are not using the ARFF format here, we have to give the
     * loader a little bit of information about the data types. Columns
     * 3,8,10 need to be of type string and columns 1,4,11 are nominal
     * types.
     */
    predictCsvLoader.setStringAttributes("3,8,10");
    predictCsvLoader.setNominalAttributes("1,4,11");
    Instances predictDataSet = predictCsvLoader.getDataSet();

    /*
     * Here we set the attribute we want to test the predictions against
     */
    Attribute testAttribute = predictDataSet.attribute(0);
    predictDataSet.setClass(testAttribute);

    /*
     * We still have to remove all string attributes before we can test
     */
    predictDataSet.deleteStringAttributes();

    /*
     * Next we load the training data from our ARFF file
     */
    ArffLoader trainLoader = new ArffLoader();
    trainLoader.setSource(new File("train.arff"));
    trainLoader.setRetrieval(Loader.BATCH);
    Instances trainDataSet = trainLoader.getDataSet();

    /*
     * Now we tell the data set which attribute we want to classify, in our
     * case, we want to classify the first column: survived
     */
    Attribute trainAttribute = trainDataSet.attribute(0);
    trainDataSet.setClass(trainAttribute);

    /*
     * The RandomForest implementation cannot handle columns of type string,
     * so we remove them for now.
     */
    trainDataSet.deleteStringAttributes();

    /*
     * Now we read in the serialized model from disk
     */
    Classifier classifier = (Classifier) SerializationHelper.read("titanic.model");

    /*
     * Next we will use an Evaluation class to evaluate the performance of
     * our Classifier.
     */
    Evaluation evaluation = new Evaluation(trainDataSet);
    evaluation.evaluateModel(classifier, predictDataSet);

    /*
     * After we evaluate the Classifier, we write out the summary
     * information to the screen.
     */
    System.out.println(classifier);
    System.out.println(evaluation.toSummaryString());
  }
Example No. 8
  public HashMap<String, String> process(
      Sentence sent,
      String dep,
      HashSet<String> terms,
      List<NamedEntity> entities,
      String author,
      String aidx) {
    try {
      // System.out.println("ML start!");
      // System.out.println("List : " + terms);
      HashMap<String, String> ht = new HashMap<String, String>();

      List<NamedEntity> newEntities = new ArrayList<NamedEntity>();
      for (NamedEntity entity : entities) {
        // System.out.println("original: " + entity.entity);
        boolean check = false;

        for (NamedEntity temp : entities) {
          if (entity == temp) continue;

          if (entity.entity.contains(temp.entity)) {
            check = true;
          }
        }

        if (!check) newEntities.add(entity);
      }

      List<DependencyTriple> dtl = getDependencyTripleList(dep);
      List<NamedEntity> targetCands = new ArrayList<NamedEntity>();
      HashMap<NamedEntity, String> tOpinTerm = new HashMap<NamedEntity, String>();
      List<NamedEntity> holderCands = new ArrayList<NamedEntity>();
      HashMap<NamedEntity, String> hOpinTerm = new HashMap<NamedEntity, String>();

      BufferedWriter writer = new BufferedWriter(new FileWriter("weka_target.csv"));
      writer.write("A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,Class\n");

      boolean check = false;
      List<NamedEntity> targetTmp = new ArrayList<NamedEntity>();
      for (NamedEntity entity : newEntities) {
        // System.out.println("extracted: " + entity.entity);
        String temp = getTargetFeatures(entity, author, terms, dtl);
        // System.out.println(temp);
        if (temp.length() > 1) {
          check = true;
          writer.write(temp);
          String[] toks = temp.split("\n");
          for (int i = 0; i < toks.length; i++) {
            targetTmp.add(entity);
            tOpinTerm.put(entity, toks[i].substring(0, toks[i].indexOf(",")));
          }
        }
      }

      writer.close();

      if (check) {
        DataSource source = new DataSource("weka_target.csv");
        Instances testdata = source.getDataSet();
        testdata.setClassIndex(testdata.numAttributes() - 1);

        Classifier models = (Classifier) weka.core.SerializationHelper.read("target_smoreg.model");

        if (testdata.numInstances() != targetTmp.size())
          System.out.println("wrong number of instances");

        for (int i = 0; i < testdata.numInstances(); i++) {
          double pred = models.classifyInstance(testdata.instance(i));
          if (pred >= 1.0) {
            // System.out.println(pred + " , " + targetTmp.get(i).entity);
            targetCands.add(targetTmp.get(i));
          }
        }
      }

      writer = new BufferedWriter(new FileWriter("weka_holder.csv"));
      writer.write("A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,A14,A15,Class\n");

      check = false;
      List<NamedEntity> holderTmp = new ArrayList<NamedEntity>();
      for (NamedEntity entity : newEntities) {
        // System.out.println("extracted: " + entity.entity);
        String temp = getHolderFeatures(entity, author, terms, dtl);
        // System.out.println(temp);
        if (temp.length() > 1) {
          check = true;
          writer.write(temp);
          String[] toks = temp.split("\n");
          for (int i = 0; i < toks.length; i++) {
            holderTmp.add(entity);
            hOpinTerm.put(entity, toks[i].substring(0, toks[i].indexOf(",")));
          }
        }
      }

      writer.close();

      if (check) {
        DataSource source = new DataSource("weka_holder.csv");
        Instances testdata = source.getDataSet();
        testdata.setClassIndex(testdata.numAttributes() - 1);

        Classifier models = (Classifier) weka.core.SerializationHelper.read("holder_smoreg.model");

        if (testdata.numInstances() != holderTmp.size())
          System.out.println("wrong number of instances");

        for (int i = 0; i < testdata.numInstances(); i++) {
          double pred = models.classifyInstance(testdata.instance(i));
          if (pred >= 1.0) {
            // System.out.println(pred + " , " + holderTmp.get(i).entity);
            holderCands.add(holderTmp.get(i));
          }
        }
      }

      if ((targetCands.size() == 0) || (holderCands.size() == 0)) return ht;

      List<NamedEntity> holderCandTmp = new ArrayList<NamedEntity>();
      for (NamedEntity holderCand : holderCands) {
        boolean hasLonger = false;
        for (NamedEntity temp : holderCands) {
          if (temp.entity.compareTo(holderCand.entity) == 0) continue;

          if (temp.entity.contains(holderCand.entity)) {
            hasLonger = true;
            break;
          }
        }

        if (!hasLonger) holderCandTmp.add(holderCand);
      }

      List<NamedEntity> targetCandTmp = new ArrayList<NamedEntity>();
      for (NamedEntity targetCand : targetCands) {
        boolean hasLonger = false;
        for (NamedEntity temp : targetCands) {
          if (temp.entity.compareTo(targetCand.entity) == 0) continue;

          if (temp.entity.contains(targetCand.entity)) {
            hasLonger = true;
            break;
          }
        }

        if (!hasLonger) targetCandTmp.add(targetCand);
      }

      for (NamedEntity targetCand : targetCandTmp) {
        if (targetCand.entity.compareTo(author) == 0) continue;

        for (NamedEntity holderCand : holderCandTmp) {
          if (targetCand.entity.compareTo(holderCand.entity) == 0) continue;

          String targetOpin = tOpinTerm.get(targetCand);
          String holderOpin = hOpinTerm.get(holderCand);

          // System.out.println(targetOpin + ", " + holderOpin);
          if (targetOpin.compareTo(holderOpin) != 0) continue;

          String opin =
              targetOpin
                  .concat("\t")
                  .concat(
                      Integer.toString(sent.sent.indexOf(targetOpin) + sent.beg)
                          .concat("-")
                          .concat(
                              Integer.toString(
                                  sent.sent.indexOf(targetOpin) + sent.beg + targetOpin.length())));

          String holder =
              holderCand
                  .entity
                  .concat("\t")
                  .concat(
                      Integer.toString(holderCand.beg)
                          .concat("-")
                          .concat(Integer.toString(holderCand.end)));
          String target =
              targetCand
                  .entity
                  .concat("\t")
                  .concat(
                      Integer.toString(targetCand.beg)
                          .concat("-")
                          .concat(Integer.toString(targetCand.end)));
          ht.put(targetOpin, opin.concat("\t").concat(holder).concat("\t").concat(target));
        }
      }

      return ht;
    } catch (IOException e) {
      // TODO Auto-generated catch block
      e.printStackTrace();
      return null;
    } catch (Exception e) {
      // TODO Auto-generated catch block
      e.printStackTrace();
      return null;
    }
  }