コード例 #1
0
  /**
   * Trains the model on the given sections.
   *
   * <p>Assigns the result of {@code defaultPipe()} to {@code pipe}, builds an instance list from
   * {@code trainSections} with {@code prepareInstanceList}, and passes it to {@code trainer}.
   * No evaluation is performed here; helpers such as {@code showNFoldAccuracy} can be called on
   * the prepared instances when an accuracy estimate is wanted.
   *
   * @param trainSections sections to build the training instances from
   * @throws IOException presumably raised while the instances are being prepared
   */
  public void train(String[] trainSections) throws IOException {
    pipe = defaultPipe();

    final InstanceList instances = prepareInstanceList(trainSections);
    trainer.train(instances);
  }
コード例 #2
0
  /**
   * Trains the model on {@code trainSections} and reports the sizes of both data sets.
   *
   * <p>Assigns the result of {@code defaultPipe()} to {@code pipe}, prepares one instance list
   * per argument, trains {@code trainer} on the training list only, and prints the number of
   * training and testing instances to standard output. The testing instances are prepared but
   * are not used for evaluation in this method.
   *
   * @param trainSections sections to build the training instances from
   * @param testSections sections to build the testing instances from (only their count is used)
   * @throws IOException presumably raised while the instances are being prepared
   */
  public void train(String[] trainSections, String[] testSections) throws IOException {
    pipe = defaultPipe();
    InstanceList trainingInstanceList = prepareInstanceList(trainSections);
    InstanceList testingInstanceList = prepareInstanceList(testSections);

    // The classifier returned by train() is not needed here, so it is not bound to a local.
    trainer.train(trainingInstanceList);

    System.out.println("training size: " + trainingInstanceList.size());
    System.out.println("testing size: " + testingInstanceList.size());
  }
コード例 #3
0
  /**
   * Runs n-fold cross-validation over {@code instanceList} and prints the results.
   *
   * <p>For every split a new {@code MyClassifierTrainer} wrapping a {@code RankMaxEntTrainer} is
   * created, stored in the {@code trainer} field, and trained on the split's first list; the
   * resulting classifier is scored on the second list with {@code getAccuracy}. Two figures are
   * printed: the mean of the per-fold accuracies, and the sum of (fold accuracy x fold test
   * size) divided by {@code count}.
   *
   * @param instanceList instances to split into folds
   * @param n number of folds; also the divisor of the averaged accuracy
   * @param count divisor of the second accuracy figure, supplied by the caller (presumably the
   *     total number of evaluated instances; this method does not verify it)
   */
  private void showNFoldAccuracy(InstanceList instanceList, int n, int count) {
    InstanceList.CrossValidationIterator folds = instanceList.crossValidationIterator(n);
    double[] foldAccuracies = new double[n];
    double accuracySum = 0;
    double truePositives = 0;

    for (int fold = 0; folds.hasNext(); fold++) {
      InstanceList[] split = folds.nextSplit();
      InstanceList trainPart = split[0];
      InstanceList testPart = split[1];

      // Each fold starts from an untrained trainer; note that this replaces the trainer field.
      trainer = new MyClassifierTrainer(new RankMaxEntTrainer());
      Classifier foldClassifier = trainer.train(trainPart);

      double foldAccuracy = getAccuracy(foldClassifier, testPart);
      foldAccuracies[fold] = foldAccuracy;
      accuracySum += foldAccuracy;
      // accuracy * test size recovers the number of correct predictions in this fold
      truePositives += foldAccuracy * testPart.size();
    }

    System.out.println(n + "-Fold accuracy(avg): " + accuracySum / n);
    System.out.println("Total tp:" + truePositives);
    System.out.println("Total count:" + count);
    System.out.println(n + "-Fold accuracy: " + truePositives / count);
  }
コード例 #4
0
 /**
  * Saves the model by delegating to {@code trainer.saveModel}.
  *
  * @param fileName passed through unchanged to the trainer; presumably the path of the file
  *     to write
  */
 public void save(String fileName) {
   trainer.saveModel(fileName);
 }
コード例 #5
0
  /**
   * Evaluates an interpolation of a general ranking model with connective-category-specific
   * models and prints the resulting accuracy.
   *
   * <p>A general {@code RankMaxEnt} model is trained on all of {@code trainingInstanceList}
   * (this replaces the {@code trainer} field). Both instance lists are then split into three
   * groups by the category of each instance's connective, and one additional model is trained
   * per group. For every test instance the candidate scores of the general model and of the
   * matching category model are blended as {@code 0.4 * general + 0.6 * category}; the
   * candidate with the highest blended score is compared with
   * {@code trueArg1Candidate}. Total, correct count and accuracy are printed to standard output.
   *
   * @param trainingInstanceList instances used to train the general and per-category models;
   *     every element must be an {@code Arg1RankInstance}
   * @param testingInstanceList instances the interpolated model is evaluated on; every element
   *     must be an {@code Arg1RankInstance}
   */
  void showInterpolatedTCAccuracy(
      InstanceList trainingInstanceList, InstanceList testingInstanceList) {
    // Interpolation weights of the general and the category-specific model.
    final double generalWeight = 0.4;
    final double categoryWeight = 0.6;

    trainer = new MyClassifierTrainer(new RankMaxEntTrainer());
    RankMaxEnt generalClassifier = (RankMaxEnt) trainer.train(trainingInstanceList);

    InstanceList[] trainingByCategory = splitByConnectiveCategory(trainingInstanceList);
    InstanceList[] testingByCategory = splitByConnectiveCategory(testingInstanceList);

    // Kept as doubles so the printed totals keep their original formatting.
    double total = 0;
    double correct = 0;
    for (int i = 0; i < trainingByCategory.length; i++) {
      MyClassifierTrainer categoryTrainer = new MyClassifierTrainer(new RankMaxEntTrainer());
      RankMaxEnt categoryClassifier = (RankMaxEnt) categoryTrainer.train(trainingByCategory[i]);
      total += testingByCategory[i].size();

      for (Instance instance : testingByCategory[i]) {
        Arg1RankInstance rankInstance = (Arg1RankInstance) instance;
        int candidateCount = ((FeatureVectorSequence) instance.getData()).size();

        double[] generalScores = new double[candidateCount];
        generalClassifier.getClassificationScores(instance, generalScores);
        double[] categoryScores = new double[candidateCount];
        categoryClassifier.getClassificationScores(instance, categoryScores);

        int predicted =
            bestInterpolatedIndex(generalScores, generalWeight, categoryScores, categoryWeight);
        if (predicted == rankInstance.trueArg1Candidate) {
          correct++;
        }
      }
    }

    System.out.println("Using interpolated model:");
    System.out.println("Total: " + total);
    System.out.println("Correct: " + correct);
    System.out.println("Accuracy: " + correct / total);
  }

  /**
   * Splits {@code instances} into three new lists (sharing the source list's pipe) indexed by
   * {@link #connectiveCategoryIndex(Arg1RankInstance)}.
   */
  private InstanceList[] splitByConnectiveCategory(InstanceList instances) {
    InstanceList[] byCategory = new InstanceList[3];
    for (int i = 0; i < byCategory.length; i++) {
      byCategory[i] = new InstanceList(instances.getPipe());
    }
    for (Instance instance : instances) {
      byCategory[connectiveCategoryIndex((Arg1RankInstance) instance)].add(instance);
    }
    return byCategory;
  }

  /**
   * Maps an instance's connective to a group index: 0 when its category starts with "Coord",
   * 1 when it starts with "Sub", and 2 otherwise (including connectives with no known category).
   */
  private int connectiveCategoryIndex(Arg1RankInstance rankInstance) {
    Sentence sentence = rankInstance.document.getSentence(rankInstance.getArg2Line());
    String conn = sentence.toString(rankInstance.connStart, rankInstance.connEnd).toLowerCase();
    String category = connAnalyzer.getCategory(conn);
    if (category == null) {
      // Unknown connectives fall into the catch-all group, like "Conj-adverbial" ones.
      return 2;
    }
    if (category.startsWith("Coord")) {
      return 0;
    }
    if (category.startsWith("Sub")) {
      return 1;
    }
    return 2;
  }

  /**
   * Returns the first index whose weighted sum of the two score arrays is the largest, or -1
   * when no blended score is strictly greater than zero.
   */
  private static int bestInterpolatedIndex(
      double[] firstScores, double firstWeight, double[] secondScores, double secondWeight) {
    double best = 0;
    int bestIndex = -1;
    for (int j = 0; j < firstScores.length; j++) {
      double score = firstScores[j] * firstWeight + secondScores[j] * secondWeight;
      if (score > best) {
        best = score;
        bestIndex = j;
      }
    }
    return bestIndex;
  }