Пример #1
0
  /**
   * Builds a document stream bound to a single timestamp.
   *
   * <p>The stream's directory is {@code rootDir + "/" + ts + "/"}. The training list comes from
   * {@code LdaModel.createInstanceList(ts)}. An empty {@code testing} list sharing that list's
   * pipe is prepared, and a fresh {@code LdaModel} is created.
   *
   * @param ts timestamp this stream is bound to (also used as a path segment)
   * @param rootDir base directory under which the per-timestamp directory lives
   * @param settings feed settings retained by this stream
   */
  public DocumentStream(long ts, String rootDir, FeedSettings settings) {
    this.ts = ts;
    this.rootDir = rootDir + '/' + ts + '/';
    this.settings = settings;

    this.list = LdaModel.createInstanceList(this.ts);
    // Reuse the training list's pipe for the held-out list (presumably so both are featurized
    // the same way — confirm against LdaModel).
    this.testing = new InstanceList(this.list.getPipe());
    this.lda = new LdaModel();
  }
Пример #2
0
  /**
   * Trains a libsvm model on {@code trainingList} and wraps it as an {@code SVM} classifier.
   *
   * <p>Each instance is converted with {@code SVM.getSvmNodes}. Instances for which that returns
   * {@code null} are skipped entirely; they do not leave empty rows in the libsvm problem. The
   * index of the instance's {@code Label} target is used as the libsvm class value. If
   * {@code param.gamma} is 0 and the data alphabet is non-empty, gamma is set to
   * {@code 1 / dataAlphabetSize}; note this mutates the {@code param} field.
   *
   * @param trainingList labelled training instances whose targets are {@code Label}s
   * @return the trained classifier, which is also stored in the {@code classifier} field
   * @throws IllegalArgumentException if libsvm rejects the problem/parameter combination
   */
  public SVM train(InstanceList trainingList) {
    int size = trainingList.size();
    svm_node[][] features = new svm_node[size][];
    double[] labels = new double[size];
    int kept = 0;

    for (int i = 0; i < size; i++) {
      Instance instance = trainingList.get(i);
      svm_node[] input = SVM.getSvmNodes(instance);
      if (input == null) {
        // Skipped rows must not leave a hole: libsvm walks all problem.l rows and would
        // dereference a null feature vector.
        continue;
      }
      features[kept] = input;
      labels[kept] = ((Label) instance.getTarget()).getIndex();
      kept++;
    }

    svm_problem problem = new svm_problem();
    problem.l = kept;
    problem.x = java.util.Arrays.copyOf(features, kept);
    problem.y = java.util.Arrays.copyOf(labels, kept);

    int maxIndex = trainingList.getDataAlphabet().size();
    if (param.gamma == 0 && maxIndex > 0) {
      // Default RBF/poly gamma to 1 / number of features.
      param.gamma = 1.0 / maxIndex;
    }

    String errorMsg = svm.svm_check_parameter(problem, param);
    if (errorMsg != null) {
      // Library code must not terminate the JVM; let the caller decide how to react.
      throw new IllegalArgumentException("Error: " + errorMsg);
    }

    svm_model model = svm.svm_train(problem, param);
    classifier = new SVM(model, trainingList.getPipe());
    return classifier;
  }
Пример #3
0
 /**
  * Create and train a CRF model from the given training data, optionally testing it on the given
  * test data.
  *
  * @param training training data
  * @param testing test data (possibly <code>null</code>)
  * @param eval accuracy evaluator (possibly <code>null</code>)
  * @param orders label Markov orders (main and backoff)
  * @param defaultLabel default label
  * @param forbidden regular expression specifying impossible label transitions <em>current</em>
  *     <code>,</code><em>next</em> (<code>null</code> indicates no forbidden transitions)
  * @param allowed regular expression specifying allowed label transitions (<code>null</code>
  *     indicates everything is allowed that is not forbidden)
  * @param connected whether to include even transitions not occurring in the training data.
  * @param iterations number of training iterations
  * @param var Gaussian prior variance, applied to the trainer that performs the training
  * @param crf an existing model to continue training, or <code>null</code> to build a new one
  *     from the training data
  * @return the trained model
  */
 public static CRF train(
     InstanceList training,
     InstanceList testing,
     TransducerEvaluator eval,
     int[] orders,
     String defaultLabel,
     String forbidden,
     String allowed,
     boolean connected,
     int iterations,
     double var,
     CRF crf) {
   // A null regex means "no restriction" (see the parameter docs). Pattern.compile(null) would
   // throw, so pass a null Pattern through instead.
   Pattern forbiddenPat = forbidden == null ? null : Pattern.compile(forbidden);
   Pattern allowedPat = allowed == null ? null : Pattern.compile(allowed);
   if (crf == null) {
     crf = new CRF(training.getPipe(), (Pipe) null);
     String startName =
         crf.addOrderNStates(
             training, orders, null, defaultLabel, forbiddenPat, allowedPat, connected);
     // Only the start state returned by addOrderNStates may begin a sequence.
     for (int i = 0; i < crf.numStates(); i++) {
       crf.getState(i).setInitialWeight(Transducer.IMPOSSIBLE_WEIGHT);
     }
     crf.getState(startName).setInitialWeight(0.0);
   }
   logger.info("Training on " + training.size() + " instances");
   if (testing != null) {
     logger.info("Testing on " + testing.size() + " instances");
   }

   CRFTrainerByLabelLikelihood crft = new CRFTrainerByLabelLikelihood(crf);
   // Previously the variance was set on a throwaway trainer inside the "new model" branch, so it
   // never affected training. Set it on the trainer that is actually used.
   crft.setGaussianPriorVariance(var);

   if (featureInductionOption.value) {
     crft.trainWithFeatureInduction(
         training, null, testing, eval, iterations, 10, 20, 500, 0.5, false, null);
   } else {
     // Evaluate every evalInterval iterations; raise it to evaluate less often.
     final int evalInterval = 1;
     for (int i = 1; i <= iterations; i++) {
       boolean converged = crft.train(training, 1);
       if (eval != null && i % evalInterval == 0) {
         eval.evaluate(crft);
       }
       if (viterbiOutputOption.value && i % 10 == 0) {
         new ViterbiWriter(
                 "", new InstanceList[] {training, testing}, new String[] {"training", "testing"})
             .evaluate(crft);
       }
       if (converged) {
         break;
       }
     }
   }
   return crf;
 }
  /**
   * Prints the accuracy of interpolating one general {@code RankMaxEnt} ranker with three
   * connective-type-specific rankers.
   *
   * <p>Both lists are bucketed by the category {@code connAnalyzer} assigns to each instance's
   * lower-cased connective. Categories starting with "Coord" go to bucket 0 and "Sub" to bucket 1.
   * Everything else goes to bucket 2, including connectives with no category, which are treated
   * as "Conj-adverbial". A type-specific ranker is trained per bucket. For each test instance, the
   * candidate maximizing {@code 0.4 * general + 0.6 * typeSpecific} is predicted and compared with
   * {@code trueArg1Candidate}.
   *
   * <p>Side effect: the general trainer is stored in the {@code trainer} field.
   *
   * @param trainingInstanceList training instances (each an {@code Arg1RankInstance})
   * @param testingInstanceList test instances (each an {@code Arg1RankInstance})
   */
  void showInterpolatedTCAccuracy(
      InstanceList trainingInstanceList, InstanceList testingInstanceList) {
    final double generalWeight = 0.4;
    final double typeSpecificWeight = 0.6;
    final int numTypes = 3;

    trainer = new MyClassifierTrainer(new RankMaxEntTrainer());
    RankMaxEnt generalClassifier = (RankMaxEnt) trainer.train(trainingInstanceList);

    InstanceList[] trainingInstanceLists = new InstanceList[numTypes];
    InstanceList[] testingInstanceLists = new InstanceList[numTypes];
    for (int i = 0; i < numTypes; i++) {
      trainingInstanceLists[i] = new InstanceList(trainingInstanceList.getPipe());
      testingInstanceLists[i] = new InstanceList(testingInstanceList.getPipe());
    }

    // Split the training list, then the test list, with one shared rule:
    // sources[s] is partitioned into buckets[s][type].
    InstanceList[] sources = {trainingInstanceList, testingInstanceList};
    InstanceList[][] buckets = {trainingInstanceLists, testingInstanceLists};
    for (int s = 0; s < sources.length; s++) {
      for (Instance instance : sources[s]) {
        Arg1RankInstance rankInstance = (Arg1RankInstance) instance;
        Sentence sentence = rankInstance.document.getSentence(rankInstance.getArg2Line());
        String conn =
            sentence.toString(rankInstance.connStart, rankInstance.connEnd).toLowerCase();
        String category = connAnalyzer.getCategory(conn);
        if (category == null) {
          category = "Conj-adverbial";
        }

        int type;
        if (category.startsWith("Coord")) {
          type = 0;
        } else if (category.startsWith("Sub")) {
          type = 1;
        } else {
          type = 2;
        }
        buckets[s][type].add(instance);
      }
    }

    double total = 0;
    double correct = 0;
    for (int i = 0; i < numTypes; i++) {
      MyClassifierTrainer typeTrainer = new MyClassifierTrainer(new RankMaxEntTrainer());
      RankMaxEnt typeClassifier = (RankMaxEnt) typeTrainer.train(trainingInstanceLists[i]);
      total += testingInstanceLists[i].size();

      for (Instance instance : testingInstanceLists[i]) {
        Arg1RankInstance rankInstance = (Arg1RankInstance) instance;
        int trueIndex = rankInstance.trueArg1Candidate;
        int numCandidates = ((FeatureVectorSequence) instance.getData()).size();
        double[] genScores = new double[numCandidates];
        generalClassifier.getClassificationScores(instance, genScores);
        double[] tcScores = new double[numCandidates];
        typeClassifier.getClassificationScores(instance, tcScores);

        // Seed with -infinity so a candidate is chosen even when every score is <= 0.
        double max = Double.NEGATIVE_INFINITY;
        int maxIndex = -1;
        for (int j = 0; j < genScores.length; j++) {
          double score = genScores[j] * generalWeight + tcScores[j] * typeSpecificWeight;
          if (score > max) {
            max = score;
            maxIndex = j;
          }
        }
        if (maxIndex == trueIndex) {
          correct++;
        }
      }
    }

    System.out.println("Using interpolated model:");
    System.out.println("Total: " + total);
    System.out.println("Correct: " + correct);
    System.out.println("Accuracy: " + correct / total);
  }
  /**
   * Trains one {@code RankMaxEnt} ranker per connective type and measures their combined accuracy.
   *
   * <p>Both lists are bucketed by the category {@code connAnalyzer} assigns to each instance's
   * lower-cased connective. Categories starting with "Coord" go to bucket 0 and "Sub" to bucket 1.
   * Everything else goes to bucket 2, including connectives with no category, which are treated
   * as "Conj-adverbial". Each bucket's ranker is evaluated with {@code getAccuracy} on the
   * matching test bucket.
   *
   * @param trainingInstanceList training instances (each an {@code Arg1RankInstance})
   * @param testingInstanceList test instances (each an {@code Arg1RankInstance})
   * @param show whether to print the totals to standard output
   * @return {@code {total, correct, correct / total}}; the ratio is NaN when there are no test
   *     instances
   */
  double[] getTypeSpecificAccuracy(
      InstanceList trainingInstanceList, InstanceList testingInstanceList, boolean show) {
    final int numTypes = 3;
    InstanceList[] trainingInstanceLists = new InstanceList[numTypes];
    InstanceList[] testingInstanceLists = new InstanceList[numTypes];
    for (int i = 0; i < numTypes; i++) {
      trainingInstanceLists[i] = new InstanceList(trainingInstanceList.getPipe());
      testingInstanceLists[i] = new InstanceList(testingInstanceList.getPipe());
    }

    // Split the training list, then the test list, with one shared rule:
    // sources[s] is partitioned into buckets[s][type].
    InstanceList[] sources = {trainingInstanceList, testingInstanceList};
    InstanceList[][] buckets = {trainingInstanceLists, testingInstanceLists};
    for (int s = 0; s < sources.length; s++) {
      for (Instance instance : sources[s]) {
        Arg1RankInstance rankInstance = (Arg1RankInstance) instance;
        Sentence sentence = rankInstance.document.getSentence(rankInstance.getArg2Line());
        String conn =
            sentence.toString(rankInstance.connStart, rankInstance.connEnd).toLowerCase();
        String category = connAnalyzer.getCategory(conn);
        if (category == null) {
          category = "Conj-adverbial";
        }

        int type;
        if (category.startsWith("Coord")) {
          type = 0;
        } else if (category.startsWith("Sub")) {
          type = 1;
        } else {
          type = 2;
        }
        buckets[s][type].add(instance);
      }
    }

    double total = 0;
    double correct = 0;
    for (int i = 0; i < numTypes; i++) {
      Classifier classifier =
          new MyClassifierTrainer(new RankMaxEntTrainer()).train(trainingInstanceLists[i]);
      int testSize = testingInstanceLists[i].size();
      total += testSize;
      // An empty bucket contributes nothing. Skipping it avoids depending on how getAccuracy
      // treats an empty list; a 0/0 ratio would turn the whole sum into NaN.
      if (testSize > 0) {
        correct += getAccuracy(classifier, testingInstanceLists[i]) * testSize;
      }
    }

    if (show) {
      System.out.println("Using type specific models:");
      System.out.println("Total: " + total);
      System.out.println("Correct: " + correct);
      System.out.println("Accuracy: " + correct / total);
    }
    return new double[] {total, correct, correct / total};
  }
Пример #6
0
 /**
  * Materializes the instances selected by {@code m_instIndices} into a new list.
  *
  * <p>The new list uses {@code m_ilist}'s pipe. Instances are added in the order their indices
  * appear in {@code m_instIndices}.
  *
  * @return a new {@code InstanceList} holding the selected instances
  */
 public InstanceList getInstances() {
   InstanceList subset = new InstanceList(m_ilist.getPipe());
   for (int index : m_instIndices) {
     subset.add(m_ilist.get(index));
   }
   return subset;
 }
Пример #7
0
  /**
   * Drops every instance accumulated in {@code list}.
   *
   * <p>It does this by replacing {@code list} with a new, empty {@code InstanceList} that reuses
   * the same pipe.
   */
  public void cleanup() {
    final InstanceList emptied = new InstanceList(this.list.getPipe());
    this.list = emptied;
  }