예제 #1
0
  /**
   * 训练
   *
   * @throws Exception
   */
  public void train() throws Exception {

    // 建立字典管理器

    Pipe lpipe = new Target2Label(al);
    Pipe fpipe = new StringArray2IndexArray(factory, true);
    // 构造转换器组
    SeriesPipes pipe = new SeriesPipes(new Pipe[] {lpipe, fpipe});

    InstanceSet instset = new InstanceSet(pipe, factory);
    instset.loadThruStagePipes(new SimpleFileReader(trainFile, " ", true, Type.LabelData));
    Generator gen = new SFGenerator();
    ZeroOneLoss l = new ZeroOneLoss();
    Inferencer ms = new LinearMax(gen, factory.getLabelSize());
    Update update = new LinearMaxPAUpdate(l);
    OnlineTrainer trainer = new OnlineTrainer(ms, update, l, factory, 50, 0.005f);
    Linear pclassifier = trainer.train(instset, instset);
    pipe.removeTargetPipe();
    pclassifier.setPipe(pipe);
    factory.setStopIncrement(true);
    pclassifier.saveTo(modelFile);
  }
예제 #2
0
  public static void main(String[] args) throws Exception {

    // 建立字典管理器
    AlphabetFactory af = AlphabetFactory.buildFactory();

    // 使用n元特征
    Pipe ngrampp = new NGram(new int[] {2, 3});
    // 将字符特征转换成字典索引
    Pipe indexpp = new StringArray2IndexArray(af);
    // 将目标值对应的索引号作为类别
    Pipe targetpp = new Target2Label(af.DefaultLabelAlphabet());

    // 建立pipe组合
    SeriesPipes pp = new SeriesPipes(new Pipe[] {ngrampp, targetpp, indexpp});

    InstanceSet instset = new InstanceSet(pp, af);

    // 用不同的Reader读取相应格式的文件
    Reader reader = new FileReader(trainDataPath, "UTF-8", ".data");

    // 读入数据,并进行数据处理
    instset.loadThruStagePipes(reader);

    float percent = 0.8f;

    // 将数据集分为训练是和测试集
    InstanceSet[] splitsets = instset.split(percent);

    InstanceSet trainset = splitsets[0];
    InstanceSet testset = splitsets[1];

    /** 建立分类器 */
    OnlineTrainer trainer = new OnlineTrainer(af);
    Linear pclassifier = trainer.train(trainset);
    pp.removeTargetPipe();
    pclassifier.setPipe(pp);
    af.setStopIncrement(true);

    // 将分类器保存到模型文件
    pclassifier.saveTo(modelFile);
    pclassifier = null;

    // 从模型文件读入分类器
    Linear cl = Linear.loadFrom(modelFile);

    // 性能评测
    Evaluation eval = new Evaluation(testset);
    eval.eval(cl, 1);

    /** 测试 */
    System.out.println("类别 : 文本内容");
    System.out.println("===================");
    for (int i = 0; i < testset.size(); i++) {
      Instance data = testset.getInstance(i);

      Integer gold = (Integer) data.getTarget();
      String pred_label = cl.getStringLabel(data);
      String gold_label = cl.getLabel(gold);

      if (pred_label.equals(gold_label))
        System.out.println(pred_label + " : " + testset.getInstance(i).getSource());
      else
        System.err.println(
            gold_label + "->" + pred_label + " : " + testset.getInstance(i).getSource());
    }

    /** 分类器使用 */
    String str = "韦德:不拿冠军就是失败 詹皇:没拿也不意味失败";
    System.out.println("============\n分类:" + str);
    Pipe p = cl.getPipe();
    Instance inst = new Instance(str);
    try {
      // 特征转换
      p.addThruPipe(inst);
    } catch (Exception e) {
      e.printStackTrace();
    }
    String res = cl.getStringLabel(inst);
    System.out.println("类别:" + res);
    // 清除模型文件
    (new File(modelFile)).deleteOnExit();
    // System.exit(0);
  }