Ejemplo n.º 1
0
  /**
   * Trains the classifier, reloads the saved model from {@code modelFile}, and prints three
   * ratios for the instances read from {@code ../tmp/ar-train.txt}:
   *
   * <ol>
   *   <li>overall: instances whose predicted label equals the gold label / all instances;
   *   <li>label "1": instances with gold and predicted label "1" / instances with gold label "1";
   *   <li>label "0": instances with gold and predicted label "0" / instances whose gold label is
   *       not "1".
   * </ol>
   *
   * <p>The printed messages call label "1" the anaphoric-relation class and the rest the
   * non-anaphoric class. A ratio whose denominator is zero is printed as {@code NaN}, because the
   * division is done in {@code double}.
   *
   * @param args ignored
   * @throws Exception propagated from training, model loading or data loading
   */
  public static void main(String[] args) throws Exception {
    ARClassifier tc = new ARClassifier();
    tc.train();
    Linear cl = Linear.loadFrom(modelFile);

    InstanceSet test = new InstanceSet(cl.getPipe(), cl.getAlphabetFactory());
    SimpleFileReader sfr = new SimpleFileReader("../tmp/ar-train.txt", true);

    // Buffer every instance first so the arrays below can be sized exactly.
    List<Instance> instances = new ArrayList<Instance>();
    while (sfr.hasNext()) {
      instances.add(sfr.next());
    }

    int total = instances.size();
    @SuppressWarnings("unchecked") // generic array creation; only List<String> values are stored
    List<String>[] features = new List[total];
    String[] gold = new String[total];
    int goldPositives = 0;
    for (int idx = 0; idx < total; idx++) {
      Instance inst = instances.get(idx);
      @SuppressWarnings("unchecked") // getData() is untyped; this reader is expected to yield lists
      List<String> data = (List<String>) inst.getData();
      features[idx] = data;
      gold[idx] = (String) inst.getTarget();
      // Constant-first equals: an instance without a target is simply not a positive.
      if ("1".equals(gold[idx])) {
        goldPositives++;
      }
    }

    // Push only the feature lists (no targets) through the model's own pipe.
    test.loadThruPipes(new ListReader(features));

    int correct = 0;
    int truePositives = 0;
    int trueNegatives = 0;
    for (int idx = 0; idx < total; idx++) {
      String predicted = cl.getStringLabel(test.getInstance(idx));
      if (predicted == null || !predicted.equals(gold[idx])) {
        continue;
      }
      correct++;
      if ("1".equals(predicted)) {
        truePositives++;
      } else if ("0".equals(predicted)) {
        trueNegatives++;
      }
    }

    double overall = (double) correct / total;
    System.out.print("整体正确率:" + overall);
    System.out.print('\n');

    double positiveRatio = (double) truePositives / goldPositives;
    System.out.print("判断为指代关系的正确率:" + positiveRatio);
    System.out.print('\n');

    double negativeRatio = (double) trueNegatives / (total - goldPositives);
    System.out.print("判断为非指代关系的正确率:" + negativeRatio);
    System.out.print('\n');
  }
Ejemplo n.º 2
0
  /**
   * Trains a linear classifier on the data in {@code trainFile} and saves it to {@code modelFile}.
   *
   * <p>The instances are loaded through a two-stage pipe (target to label, then string features
   * to indices). The same instance set is passed twice to {@code OnlineTrainer.train}. Before
   * saving, the target pipe is removed, the pipe is attached to the model, and
   * {@code setStopIncrement(true)} is called on the factory.
   *
   * @throws Exception propagated from data loading, training or saving
   */
  public void train() throws Exception {
    // Pipe series: targets become labels first, then string features become indices.
    Pipe[] stages = {new Target2Label(al), new StringArray2IndexArray(factory, true)};
    SeriesPipes pipe = new SeriesPipes(stages);

    // Read the training file and run every instance through the pipe stages.
    InstanceSet trainingData = new InstanceSet(pipe, factory);
    SimpleFileReader source = new SimpleFileReader(trainFile, " ", true, Type.LabelData);
    trainingData.loadThruStagePipes(source);

    // Trainer components. The label count is read after loading, once the data has been piped.
    Generator featureGen = new SFGenerator();
    ZeroOneLoss loss = new ZeroOneLoss();
    Inferencer inferencer = new LinearMax(featureGen, factory.getLabelSize());
    Update paUpdate = new LinearMaxPAUpdate(loss);

    // NOTE(review): 50 and 0.005f are presumably the iteration count and the update constant;
    // confirm against OnlineTrainer's constructor.
    Linear model =
        new OnlineTrainer(inferencer, paUpdate, loss, factory, 50, 0.005f)
            .train(trainingData, trainingData);

    // Drop the target stage before attaching the pipe to the model (presumably so unlabeled
    // input can be piped at prediction time; confirm), then persist the model.
    pipe.removeTargetPipe();
    model.setPipe(pipe);
    factory.setStopIncrement(true);
    model.saveTo(modelFile);
  }
Ejemplo n.º 3
0
  /**
   * End-to-end demo: loads labelled data from {@code trainDataPath}, splits it 0.8 for training
   * and the rest for testing, trains a linear classifier, saves it to {@code modelFile} and
   * reloads it, evaluates it on the test split, prints each test instance with its label, and
   * finally classifies one hard-coded sentence.
   *
   * <p>Correct predictions go to standard output; wrong ones go to standard error in the form
   * {@code gold->predicted : source}. The model file is registered for deletion on JVM exit.
   *
   * @param args ignored
   * @throws Exception propagated from data loading, training, saving or model loading
   */
  public static void main(String[] args) throws Exception {
    // Dictionary (alphabet) manager shared by the pipes and the trainer.
    AlphabetFactory af = AlphabetFactory.buildFactory();

    // Stages: 2- and 3-gram features, feature strings to dictionary indices,
    // and target values to label indices.
    Pipe ngramStage = new NGram(new int[] {2, 3});
    Pipe indexStage = new StringArray2IndexArray(af);
    Pipe labelStage = new Target2Label(af.DefaultLabelAlphabet());

    // Combined pipe; the label stage runs between n-gram extraction and indexing.
    Pipe[] stages = {ngramStage, labelStage, indexStage};
    SeriesPipes pp = new SeriesPipes(stages);

    // Read the ".data" files as UTF-8 and run every instance through the pipe.
    InstanceSet allData = new InstanceSet(pp, af);
    Reader reader = new FileReader(trainDataPath, "UTF-8", ".data");
    allData.loadThruStagePipes(reader);

    // Split into a training part and a test part.
    InstanceSet[] parts = allData.split(0.8f);
    InstanceSet trainset = parts[0];
    InstanceSet testset = parts[1];

    // Train the classifier, then prepare its pipe for label-free input.
    Linear pclassifier = new OnlineTrainer(af).train(trainset);
    pp.removeTargetPipe();
    pclassifier.setPipe(pp);
    af.setStopIncrement(true);

    // Save the classifier to the model file and drop the in-memory reference.
    pclassifier.saveTo(modelFile);
    pclassifier = null;

    // Reload the classifier from the model file.
    Linear cl = Linear.loadFrom(modelFile);

    // Performance evaluation on the held-out split.
    Evaluation eval = new Evaluation(testset);
    eval.eval(cl, 1);

    // Per-instance report.
    System.out.println("类别 : 文本内容");
    System.out.println("===================");
    int testCount = testset.size();
    for (int i = 0; i < testCount; i++) {
      Instance data = testset.getInstance(i);

      Integer gold = (Integer) data.getTarget();
      String predicted = cl.getStringLabel(data);
      String expected = cl.getLabel(gold);

      if (!predicted.equals(expected)) {
        // Misclassified: report gold->predicted on stderr.
        System.err.println(
            expected + "->" + predicted + " : " + testset.getInstance(i).getSource());
        continue;
      }
      System.out.println(predicted + " : " + testset.getInstance(i).getSource());
    }

    // Using the classifier on a new piece of text.
    String str = "韦德:不拿冠军就是失败 詹皇:没拿也不意味失败";
    System.out.println("============\n分类:" + str);
    Pipe modelPipe = cl.getPipe();
    Instance inst = new Instance(str);
    try {
      // Feature conversion; a failure is printed and classification is still attempted.
      modelPipe.addThruPipe(inst);
    } catch (Exception e) {
      e.printStackTrace();
    }
    String res = cl.getStringLabel(inst);
    System.out.println("类别:" + res);

    // Clean up the model file when the JVM exits.
    File savedModel = new File(modelFile);
    savedModel.deleteOnExit();
  }