Пример #1
0
  /**
   * Evaluates the current model on a test set and prints its accuracy to standard output.
   *
   * <p>Each instance is decoded with {@code msolver.getBest(instance, 1)} and the top prediction
   * is scored with {@code loss}. An instance with a positive loss counts as one sentence error,
   * and its loss value is added to the tag-error total.
   *
   * <p>Tag accuracy is {@code 1 - tagErrors / totalTags}; sentence accuracy is {@code 1 -
   * sentenceErrors / testSet.size()}. In simple-output mode only a tab and the tag accuracy are
   * printed. Otherwise the correct/total tag counts and the sentence accuracy are printed too,
   * followed by a blank line.
   *
   * @param testSet instances to evaluate; each target is cast to {@code int[]} to count its tags
   */
  public void test(InstanceSet testSet) {
    int totalTags = 0;
    double tagErrors = 0;
    double sentenceErrors = 0;

    for (int pos = 0; pos < testSet.size(); pos++) {
      Instance sample = testSet.getInstance(pos);
      totalTags += ((int[]) sample.getTarget()).length;
      Results best = (Results) msolver.getBest(sample, 1);
      double sampleLoss = loss.calc(best.getPredAt(0), sample.getTarget());
      if (sampleLoss > 0) {
        // Any positive loss marks the whole instance as mispredicted.
        sentenceErrors++;
        tagErrors += sampleLoss;
      }
    }

    double tagAccuracy = 1 - tagErrors / totalTags;
    if (simpleOutput) {
      System.out.println("\t" + tagAccuracy);
      return;
    }
    double sentenceAccuracy = 1 - sentenceErrors / testSet.size();
    System.out.println(
        "Test:\t" + (totalTags - tagErrors) + "/" + totalTags
            + "\tTag acc:" + tagAccuracy
            + "\tSentence acc:" + sentenceAccuracy);
    // Verbose mode ends the report with an extra blank line.
    System.out.println();
  }
Пример #2
0
  /**
   * Trains the model with online weight updates and returns the result as a {@link Linear}.
   *
   * <p>Each iteration makes one pass over {@code trainingList}. Every instance is decoded with
   * {@code msolver.getBest(instance, 1)} and its top prediction is scored with {@code loss}.
   *
   * <ul>
   *   <li>On a positive loss, {@code update.update} is called with that prediction.
   *   <li>Otherwise, if a second prediction was returned, {@code update.update} is called with it.
   * </ul>
   *
   * <p>After each pass the method:
   *
   * <ul>
   *   <li>prints the training tag and sentence accuracy;
   *   <li>evaluates {@code testList} via {@link #test(InstanceSet)} when it is non-null;
   *   <li>stops early once the tag error rate changes by less than {@code eps};
   *   <li>otherwise, saves an interim model to {@code tmp.model} when {@code interim} is set;
   *   <li>otherwise, zeroes the weights selected by {@code MyArrays.getTop(..., threshold,
   *       false)} when {@code isOptimized} is set.
   * </ul>
   *
   * @param trainingList training instances; each target is cast to {@code int[]} to count tags
   * @param testList optional evaluation set evaluated after every pass; may be {@code null}
   * @return a model built from {@code msolver} and the training set's alphabet factory
   */
  public Linear train(InstanceSet trainingList, InstanceSet testList) {
    int numSamples = trainingList.size();

    // Total number of target tags; the per-iteration tag error rate is normalized by this.
    count = 0;
    for (int ii = 0; ii < numSamples; ii++) {
      Instance inst = trainingList.getInstance(ii);
      count += ((int[]) inst.getTarget()).length;
    }

    System.out.println("Chars Number: " + count);

    double oldErrorRate = Double.MAX_VALUE;
    long beginTime = System.currentTimeMillis();

    // Print one progress dot roughly every 10% of the samples. Clamped to 1 so that training
    // sets with fewer than 10 samples do not cause an integer modulo by zero below.
    int frac = Math.max(1, numSamples / 10);

    int iter = 0;
    while (iter++ < maxIter) {
      if (!simpleOutput) {
        System.out.print("iter:");
        System.out.print(iter + "\t");
      }
      double err = 0; // summed loss over mispredicted instances (tag errors)
      double errorAll = 0; // number of mispredicted instances (sentence errors)
      long beginTimeIter = System.currentTimeMillis();

      for (int ii = 0; ii < numSamples; ii++) {
        Instance inst = trainingList.getInstance(ii);
        Results pred = (Results) msolver.getBest(inst, 1);
        double l = loss.calc(pred.getPredAt(0), inst.getTarget());
        if (l > 0) {
          // Wrong prediction: update the weights against it.
          errorAll += 1.0;
          err += l;
          update.update(inst, weights, pred.getPredAt(0), c);
        } else if (pred.size() > 1) {
          // Correct prediction: still update against the runner-up when one was returned.
          update.update(inst, weights, pred.getPredAt(1), c);
        }
        if (!simpleOutput && ii % frac == 0) { // progress indicator
          System.out.print('.');
        }
      }
      double errRate = err / count;

      long endTimeIter = System.currentTimeMillis();
      if (!simpleOutput) {
        System.out.println("\ttime:" + (endTimeIter - beginTimeIter) / 1000.0 + "s");
        System.out.print("Train:");
        System.out.print("\tTag acc:");
      }
      System.out.print(1 - errRate);
      if (!simpleOutput) {
        System.out.print("\tSentence acc:");
        System.out.print(1 - errorAll / numSamples);
        System.out.println();
      }
      if (testList != null) {
        // In simple-output mode test() prints "\t<acc>" and terminates the line.
        test(testList);
      } else if (simpleOutput) {
        // Without a test set nothing else ends the line; otherwise accuracies run together.
        System.out.println();
      }

      if (Math.abs(errRate - oldErrorRate) < eps) {
        System.out.println("Convergence!");
        break;
      }
      oldErrorRate = errRate;

      if (interim) {
        // Best-effort checkpoint: a failed write is reported but does not stop training.
        Linear checkpoint = new Linear(msolver, trainingList.getAlphabetFactory());
        try {
          checkpoint.saveTo("tmp.model");
        } catch (IOException e) {
          System.err.println("write model error!");
          System.err.println(e);
        }
      }
      if (isOptimized) {
        // Model pruning: the original comment describes the selected weights as insignificant
        // features. The selection is whatever MyArrays.getTop(..., threshold, false) returns.
        int[] idx = MyArrays.getTop(weights.clone(), threshold, false);
        System.out.print("Opt: weight numbers: " + MyArrays.countNoneZero(weights));
        MyArrays.set(weights, idx, 0.0);
        System.out.println(" -> " + MyArrays.countNoneZero(weights));
      }
    }

    long endTime = System.currentTimeMillis();
    System.out.println("done!");
    System.out.println("time escape:" + (endTime - beginTime) / 1000.0 + "s");
    return new Linear(msolver, trainingList.getAlphabetFactory());
  }