Пример #1
0
  /**
   * Trains the parser by running {@code options.numIters} passes of {@link #trainingIter} over
   * the serialized feature forest, then calling {@code params.averageParams} with the number of
   * passes times the number of non-ignored instances. When {@code options.separateLab} is set, a
   * separate {@code LabelClassifier} is trained afterwards.
   *
   * <p>For each pass this prints {@code " Iteration n["} to stdout, then whatever the pass itself
   * prints, then {@code "|Time:<elapsed ms>]"}.
   *
   * <p>A failure while training the label classifier is only printed (stack trace); it does not
   * propagate, and {@code classifier} keeps its previous value in that case.
   *
   * @param instanceLengths length of each training instance, in forest order
   * @param ignore per-instance flags forwarded to each pass and to the label classifier
   * @param trainfile training file name, forwarded to each pass and to the label classifier
   * @param train_forest serialized feature forest read by each pass
   * @throws IOException if a training pass fails to read the forest
   */
  public void train(int[] instanceLengths, int[] ignore, String trainfile, File train_forest)
      throws IOException {

    int iteration = 0;
    while (iteration < options.numIters) {
      System.out.print(" Iteration " + iteration);
      System.out.print("[");

      long startMillis = System.currentTimeMillis();
      // Passes are numbered from 1 for trainingIter.
      trainingIter(instanceLengths, ignore, trainfile, train_forest, iteration + 1);
      long elapsedMillis = System.currentTimeMillis() - startMillis;

      System.out.println("|Time:" + elapsedMillis + "]");
      iteration++;
    }

    // 'iteration' now equals the number of passes actually run.
    params.averageParams(iteration * countActualInstances(ignore));

    // afm 06-04-08: optional separately-trained label classifier.
    if (!options.separateLab) {
      return;
    }
    LabelClassifier labelTrainer =
        new LabelClassifier(
            options, instanceLengths, ignore, trainfile, train_forest, this, pipe);
    try {
      // NOTE(review): the meaning of 100 (presumably an iteration count) is not visible here.
      classifier = labelTrainer.trainClassifier(100);
    } catch (Exception e) {
      e.printStackTrace();
    }
  }
Пример #2
0
  /**
   * Restores model state from a file of serialized Java objects.
   *
   * <p>The file is read in this order:
   * <ol>
   *   <li>the weight vector into {@code params.parameters};
   *   <li>{@code pipe.dataAlphabet};
   *   <li>{@code pipe.typeAlphabet};
   *   <li>only when {@code options.separateLab} is set, the label {@code classifier}.
   * </ol>
   * The file must therefore contain exactly these objects in this order. It must also have been
   * written with the same {@code separateLab} setting. After a successful load, the alphabets are
   * closed via {@code pipe.closeAlphabets()}.
   *
   * <p>The file is always closed, including when a read or cast fails part-way. Fields assigned
   * before the failure keep their new values.
   *
   * <p>Security: this uses Java native deserialization. Only load model files from trusted
   * sources; a crafted file can execute code during {@code readObject}.
   *
   * @param file path of the serialized model file
   * @throws Exception if the file cannot be opened or read, a serialized class cannot be found, or
   *     the stream contents do not match the expected types ({@code ClassCastException})
   */
  public void loadModel(String file) throws Exception {
    FileInputStream fileIn = new FileInputStream(file);
    try {
      ObjectInputStream in = new ObjectInputStream(fileIn);
      params.parameters = (double[]) in.readObject();
      pipe.dataAlphabet = (Alphabet) in.readObject();
      pipe.typeAlphabet = (Alphabet) in.readObject();

      // afm 06-04-08
      if (options.separateLab) {
        classifier = (Classifier) in.readObject();
      }
    } finally {
      // Closing the underlying file releases the handle even if the ObjectInputStream header or
      // any readObject/cast above failed; ObjectInputStream holds no other resources.
      fileIn.close();
    }
    pipe.closeAlphabets();
  }
Пример #3
0
  /**
   * Runs one online training pass over the serialized instances in {@code train_forest}. Each
   * non-ignored instance is decoded and then passed to {@code params.updateParamsMIRA}.
   *
   * <p>Instances are read from the forest strictly in order. Every instance is read, even one that
   * is to be ignored, because the stream is sequential: skipping the read would misalign every
   * following instance. After reading, an instance with a non-zero {@code ignore[i]} is skipped
   * without decoding or updating.
   *
   * <p>(An older note here suggested marking unwanted instances with -1 in
   * {@code instanceLengths}; the {@code ignore} array serves that purpose.)
   *
   * <p>Progress goes to stdout: {@code "n,"} after every 500th instance read, then the number of
   * non-ignored instances at the end of the pass. The forest stream is always closed, including
   * when reading, decoding or updating throws.
   *
   * @param instanceLengths length of each instance in forest order; sizes the per-instance
   *     feature and score tables
   * @param ignore per-instance flags parallel to {@code instanceLengths}; a non-zero entry
   *     excludes that instance from decoding and updates
   * @param trainfile training file name; not used by this method
   * @param train_forest serialized feature forest to read instances from
   * @param iter 1-based number of the current pass; feeds the update weight
   * @throws IOException if the forest cannot be opened or read
   */
  private void trainingIter(
      int[] instanceLengths, int[] ignore, String trainfile, File train_forest, int iter)
      throws IOException {

    FileInputStream forestIn = new FileInputStream(train_forest);
    try {
      ObjectInputStream in = new ObjectInputStream(forestIn);

      int numInstances = instanceLengths.length;

      // afm -- Count the real number of instances to be considered
      int numActualInstances = countActualInstances(ignore);

      int j = 0; // non-ignored instances processed so far in this pass
      for (int i = 0; i < numInstances; i++) {
        if ((i + 1) % 500 == 0) {
          System.out.print((i + 1) + ",");
        }

        int length = instanceLengths[i];

        // Per-instance feature vectors and scores; filled in by readInstance below.
        FeatureVector[][][] fvs = new FeatureVector[length][length][2];
        double[][][] probs = new double[length][length][2];
        FeatureVector[][][][] nt_fvs = new FeatureVector[length][pipe.types.length][2][2];
        double[][][][] nt_probs = new double[length][pipe.types.length][2][2];
        FeatureVector[][][] fvs_trips = new FeatureVector[length][length][length];
        double[][][] probs_trips = new double[length][length][length];
        FeatureVector[][][] fvs_sibs = new FeatureVector[length][length][2];
        double[][][] probs_sibs = new double[length][length][2];

        DependencyInstance inst;
        if (options.secondOrder) {
          inst =
              ((DependencyPipe2O) pipe)
                  .readInstance(
                      in,
                      length,
                      fvs,
                      probs,
                      fvs_trips,
                      probs_trips,
                      fvs_sibs,
                      probs_sibs,
                      nt_fvs,
                      nt_probs,
                      params);
        } else {
          inst = pipe.readInstance(in, length, fvs, probs, nt_fvs, nt_probs, params);
        }

        // afm 03-06-08: this sentence is to be ignored. It has already been consumed from the
        // stream above, so skipping it here keeps the remaining instances aligned.
        if (ignore[i] != 0) {
          continue;
        }

        // Update weight. When iter runs 1..options.numIters, it counts down from
        // numIters * numActualInstances (first update of the first pass) to 1 (last update of
        // the last pass), presumably for parameter averaging -- confirm.
        // The arithmetic is done in long so the product cannot overflow int on large corpora.
        double upd =
            (double)
                ((long) options.numIters * numActualInstances
                    - ((long) numActualInstances * (iter - 1) + (j + 1))
                    + 1);

        int K = options.trainK;
        Object[][] d = null;
        if (options.decodeType.equals("proj")) {
          if (options.secondOrder) {
            d =
                ((DependencyDecoder2O) decoder)
                    .decodeProjective(
                        inst,
                        fvs,
                        probs,
                        fvs_trips,
                        probs_trips,
                        fvs_sibs,
                        probs_sibs,
                        nt_fvs,
                        nt_probs,
                        K);
          } else {
            d = decoder.decodeProjective(inst, fvs, probs, nt_fvs, nt_probs, K);
          }
        } else if (options.decodeType.equals("non-proj")) {
          if (options.secondOrder) {
            d =
                ((DependencyDecoder2O) decoder)
                    .decodeNonProjective(
                        inst,
                        fvs,
                        probs,
                        fvs_trips,
                        probs_trips,
                        fvs_sibs,
                        probs_sibs,
                        nt_fvs,
                        nt_probs,
                        K);
          } else {
            d = decoder.decodeNonProjective(inst, fvs, probs, nt_fvs, nt_probs, K);
          }
        }
        // NOTE(review): any other decodeType leaves d null, and it is passed on unchanged.
        params.updateParamsMIRA(inst, d, upd);
        j++;
      }

      System.out.print(numActualInstances);
    } finally {
      // Release the forest file on every exit path, including when the ObjectInputStream header
      // read, readInstance, decoding or the parameter update throws.
      forestIn.close();
    }
  }