Example #1
0
  /**
   * Trains a linear classifier from {@code trainFile} and saves the resulting model to
   * {@code modelFile}.
   *
   * @throws Exception if any step (loading, training, saving) fails
   */
  public void train() throws Exception {
    // Pipe chain: a target-to-label pipe followed by a string-array-to-index-array pipe.
    Pipe labelPipe = new Target2Label(al);
    Pipe featurePipe = new StringArray2IndexArray(factory, true);
    SeriesPipes pipeline = new SeriesPipes(new Pipe[] {labelPipe, featurePipe});

    // Load the training file through the pipe chain.
    // The " " argument is presumably the field separator -- TODO confirm against SimpleFileReader.
    InstanceSet trainingSet = new InstanceSet(pipeline, factory);
    SimpleFileReader source = new SimpleFileReader(trainFile, " ", true, Type.LabelData);
    trainingSet.loadThruStagePipes(source);

    // Trainer components: feature generator, loss, inferencer and update rule.
    Generator generator = new SFGenerator();
    ZeroOneLoss loss = new ZeroOneLoss();
    Inferencer inferencer = new LinearMax(generator, factory.getLabelSize());
    Update updater = new LinearMaxPAUpdate(loss);

    // NOTE(review): 50 and 0.005f look like an iteration limit and a tolerance -- confirm against
    // the OnlineTrainer constructor.
    OnlineTrainer onlineTrainer =
        new OnlineTrainer(inferencer, updater, loss, factory, 50, 0.005f);

    // The same instance set is passed for both arguments of train().
    Linear model = onlineTrainer.train(trainingSet, trainingSet);

    // Drop the target pipe before attaching the chain to the model, freeze the alphabets,
    // then persist the model.
    pipeline.removeTargetPipe();
    model.setPipe(pipeline);
    factory.setStopIncrement(true);
    model.saveTo(modelFile);
  }
Example #2
0
/**
 * Trains a linear classifier and then reports how well the saved model reproduces the labels of
 * the training file. The printed report refers to anaphoric relations; the label {@code "1"} is
 * scored as the positive class and {@code "0"} as the negative class.
 *
 * @author jszhao
 * @version 1.0
 * @since FudanNLP 1.5
 */
public class ARClassifier {

  /** Label string scored as the positive class. */
  private static final String POSITIVE_LABEL = "1";

  /** Label string scored as the negative class. */
  private static final String NEGATIVE_LABEL = "0";

  static InstanceSet train;
  static InstanceSet test;
  static AlphabetFactory factory = AlphabetFactory.buildFactory();
  static LabelAlphabet al = factory.DefaultLabelAlphabet();
  static String path = null;
  static Pipe pipe;

  /** Training file. */
  private String trainFile = "../tmp/ar-train.txt";

  /** Model file. */
  private static String modelFile = "../models/ar.m";

  /**
   * Trains the model, reloads it from {@code modelFile}, classifies every instance of the
   * training file and prints three ratios: overall accuracy, correctly predicted positives over
   * gold positives, and correctly predicted negatives over gold negatives.
   *
   * @param args unused
   * @throws Exception if training, model loading or data loading fails
   */
  public static void main(String[] args) throws Exception {
    ARClassifier tc = new ARClassifier();
    tc.train();
    Linear cl = Linear.loadFrom(modelFile);

    InstanceSet evalSet = new InstanceSet(cl.getPipe(), cl.getAlphabetFactory());

    // Evaluate on the same file that was used for training; read the path from the field so the
    // two locations cannot drift apart.
    SimpleFileReader sfr = new SimpleFileReader(tc.trainFile, true);
    List<Instance> instances = new ArrayList<Instance>();
    while (sfr.hasNext()) {
      instances.add(sfr.next());
    }

    int total = instances.size();
    // ListReader is handed an array of lists, so a generic array is unavoidable here.
    @SuppressWarnings("unchecked")
    List<String>[] features = new List[total];
    String[] goldLabels = new String[total];
    int goldPositives = 0;
    for (int idx = 0; idx < total; idx++) {
      Instance instance = instances.get(idx);
      @SuppressWarnings("unchecked")
      List<String> data = (List<String>) instance.getData();
      features[idx] = data;
      goldLabels[idx] = (String) instance.getTarget();
      // Constant-first comparison: an instance without a target is simply not counted.
      if (POSITIVE_LABEL.equals(goldLabels[idx])) {
        goldPositives++;
      }
    }

    evalSet.loadThruPipes(new ListReader(features));

    int correct = 0;
    int correctPositives = 0;
    int correctNegatives = 0;
    for (int idx = 0; idx < total; idx++) {
      String predicted = cl.getStringLabel(evalSet.getInstance(idx));
      if (predicted == null || !predicted.equals(goldLabels[idx])) {
        continue;
      }
      correct++;
      if (POSITIVE_LABEL.equals(predicted)) {
        correctPositives++;
      } else if (NEGATIVE_LABEL.equals(predicted)) {
        correctNegatives++;
      }
    }

    printRate("整体正确率:", correct, total);
    // NOTE(review): the two ratios below divide by the number of GOLD positives/negatives, not by
    // the number of predicted ones -- confirm this is the intended metric.
    printRate("判断为指代关系的正确率:", correctPositives, goldPositives);
    printRate("判断为非指代关系的正确率:", correctNegatives, total - goldPositives);
  }

  /**
   * Prints {@code label} followed by {@code hits / population} as a double and a newline. A zero
   * population yields {@code NaN} or {@code Infinity} rather than an exception, because the
   * division is done in floating point.
   */
  private static void printRate(String label, int hits, int population) {
    System.out.print(label + ((double) hits / population));
    System.out.print('\n');
  }

  /**
   * Trains a linear classifier from {@code trainFile} and saves the resulting model to
   * {@code modelFile}.
   *
   * @throws Exception if any step (loading, training, saving) fails
   */
  public void train() throws Exception {
    // Pipe chain: a target-to-label pipe followed by a string-array-to-index-array pipe.
    Pipe lpipe = new Target2Label(al);
    Pipe fpipe = new StringArray2IndexArray(factory, true);
    SeriesPipes pipe = new SeriesPipes(new Pipe[] {lpipe, fpipe});

    InstanceSet instset = new InstanceSet(pipe, factory);
    instset.loadThruStagePipes(new SimpleFileReader(trainFile, " ", true, Type.LabelData));

    Generator gen = new SFGenerator();
    ZeroOneLoss l = new ZeroOneLoss();
    Inferencer ms = new LinearMax(gen, factory.getLabelSize());
    Update update = new LinearMaxPAUpdate(l);
    // NOTE(review): 50 and 0.005f look like an iteration limit and a tolerance -- confirm against
    // the OnlineTrainer constructor.
    OnlineTrainer trainer = new OnlineTrainer(ms, update, l, factory, 50, 0.005f);

    // The same instance set is passed for both arguments of train().
    Linear pclassifier = trainer.train(instset, instset);

    // Drop the target pipe before attaching the chain to the model, freeze the alphabets,
    // then persist the model.
    pipe.removeTargetPipe();
    pclassifier.setPipe(pipe);
    factory.setStopIncrement(true);
    pclassifier.saveTo(modelFile);
  }
}
Example #3
0
  public static void main(String[] args) throws Exception {
    // Alphabet (dictionary) factory shared by the pipes, the data set and the trainer.
    AlphabetFactory alphabets = AlphabetFactory.buildFactory();

    // n-gram feature pipe, constructed with the values {2, 3}.
    Pipe ngramPipe = new NGram(new int[] {2, 3});
    // Converts string features into dictionary indices.
    Pipe indexPipe = new StringArray2IndexArray(alphabets);
    // Uses the index of the target value as the class label.
    Pipe labelPipe = new Target2Label(alphabets.DefaultLabelAlphabet());

    // Pipe chain: n-grams, then target-to-label, then feature indexing.
    SeriesPipes pipeline = new SeriesPipes(new Pipe[] {ngramPipe, labelPipe, indexPipe});

    // Read the training data with the reader for its file format and run it through the pipes.
    InstanceSet allData = new InstanceSet(pipeline, alphabets);
    allData.loadThruStagePipes(new FileReader(trainDataPath, "UTF-8", ".data"));

    // Split the data into a training part and a test part.
    InstanceSet[] parts = allData.split(0.8f);
    InstanceSet trainset = parts[0];
    InstanceSet testset = parts[1];

    // Build the classifier.
    OnlineTrainer onlineTrainer = new OnlineTrainer(alphabets);
    Linear trained = onlineTrainer.train(trainset);
    pipeline.removeTargetPipe();
    trained.setPipe(pipeline);
    alphabets.setStopIncrement(true);

    // Save the classifier to the model file and drop the in-memory reference.
    trained.saveTo(modelFile);
    trained = null;

    // Load the classifier back from the model file.
    Linear cl = Linear.loadFrom(modelFile);

    // Performance evaluation.
    Evaluation evaluation = new Evaluation(testset);
    evaluation.eval(cl, 1);

    // Per-instance report: matches go to stdout, mismatches to stderr.
    System.out.println("类别 : 文本内容");
    System.out.println("===================");
    for (int idx = 0; idx < testset.size(); idx++) {
      Instance sample = testset.getInstance(idx);

      Integer goldIndex = (Integer) sample.getTarget();
      String predicted = cl.getStringLabel(sample);
      String expected = cl.getLabel(goldIndex);

      if (!predicted.equals(expected)) {
        System.err.println(
            expected + "->" + predicted + " : " + testset.getInstance(idx).getSource());
      } else {
        System.out.println(predicted + " : " + testset.getInstance(idx).getSource());
      }
    }

    // Using the classifier on a single piece of raw text.
    String text = "韦德:不拿冠军就是失败 詹皇:没拿也不意味失败";
    System.out.println("============\n分类:" + text);
    Pipe modelPipe = cl.getPipe();
    Instance single = new Instance(text);
    try {
      // Feature conversion.
      modelPipe.addThruPipe(single);
    } catch (Exception e) {
      e.printStackTrace();
    }
    String category = cl.getStringLabel(single);
    System.out.println("类别:" + category);

    // Remove the model file when the JVM exits.
    File savedModel = new File(modelFile);
    savedModel.deleteOnExit();
  }
  /**
   * Trains a Bayes classifier on 80% of the data read from {@code trainDataPath}, saves it to
   * {@code bayesModelFile}, reloads it, and measures test-set precision while sweeping three
   * feature-selection calls ({@code fS_CS}, {@code fS_CS_Max}, {@code fS_IG}) over a list of
   * percentages. The first entry of each list is evaluated without calling the selection method.
   * A summary of all sweeps is printed at the end.
   *
   * @param args unused
   * @throws Exception if loading data, training, or saving/loading the model fails
   */
  public static void main(String[] args) throws Exception {
    // Word segmentation pipes.
    Pipe removepp = new RemoveWords();
    CWSTagger tag = new CWSTagger("../models/seg.m");
    Pipe segpp = new CNPipe(tag);
    Pipe s2spp = new Strings2StringArray();

    // Alphabet (dictionary) factory.
    AlphabetFactory af = AlphabetFactory.buildFactory();
    // NOTE(review): this n-gram pipe is constructed but is not part of the pipe chain below --
    // confirm whether leaving it out is intentional.
    Pipe ngrampp = new NGram(new int[] {2, 3});
    // Converts string features into a sparse vector.
    Pipe sparsepp = new StringArray2SV(af);
    // Uses the index of the target value as the class label.
    Pipe targetpp = new Target2Label(af.DefaultLabelAlphabet());
    // Pipe chain.
    SeriesPipes pp = new SeriesPipes(new Pipe[] {removepp, segpp, s2spp, targetpp, sparsepp});

    System.out.print("\nReading data......\n");
    InstanceSet instset = new InstanceSet(pp, af);
    Reader reader = new MyDocumentReader(trainDataPath, "gbk");
    instset.loadThruStagePipes(reader);
    System.out.print("..Reading data complete\n");

    // Split the data into a training part and a test part.
    System.out.print("Sspliting....");
    float percent = 0.8f;
    InstanceSet[] splitsets = instset.split(percent);

    InstanceSet trainset = splitsets[0];
    InstanceSet testset = splitsets[1];
    System.out.print("..Spliting complete!\n");

    System.out.print("Training...\n");
    af.setStopIncrement(true);
    BayesTrainer trainer = new BayesTrainer();
    BayesClassifier classifier = (BayesClassifier) trainer.train(trainset);
    System.out.print("..Training complete!\n");
    System.out.print("Saving model...\n");
    classifier.saveTo(bayesModelFile);
    classifier = null;
    System.out.print("..Saving model complete!\n");

    // Testing.
    System.out.print("Loading model...\n");
    BayesClassifier bayes = BayesClassifier.loadFrom(bayesModelFile);
    System.out.print("..Loading model complete!\n");

    System.out.println("Testing Bayes...");
    int total = testset.size();

    // Sweep 1: fS_CS.
    float[] percents_cs = new float[] {1.0f, 0.9f, 0.8f, 0.7f, 0.5f, 0.3f, 0.2f, 0.1f};
    int[] counts_cs = new int[percents_cs.length];
    for (int step = 0; step < percents_cs.length; step++) {
      System.out.println("Testing Bayes" + percents_cs[step] + "...");
      if (step != 0) {
        bayes.fS_CS(percents_cs[step]);
      }
      counts_cs[step] = countCorrect(bayes, testset);
      printPrecision("Bayes Precision", percents_cs[step], counts_cs[step], total);
    }

    // Sweep 2: fS_CS_Max, after resetting feature selection.
    bayes.noFeatureSelection();
    float[] percents_csmax = new float[] {1.0f, 0.9f, 0.8f, 0.7f, 0.5f, 0.3f, 0.2f, 0.1f};
    int[] counts_csmax = new int[percents_csmax.length];
    for (int step = 0; step < percents_csmax.length; step++) {
      System.out.println("Testing Bayes" + percents_csmax[step] + "...");
      if (step != 0) {
        bayes.fS_CS_Max(percents_csmax[step]);
      }
      counts_csmax[step] = countCorrect(bayes, testset);
      printPrecision("Bayes Precision", percents_csmax[step], counts_csmax[step], total);
    }

    // Sweep 3: fS_IG, after resetting feature selection.
    bayes.noFeatureSelection();
    float[] percents_ig = new float[] {1.0f, 0.9f, 0.8f, 0.7f, 0.5f, 0.3f, 0.2f, 0.1f};
    int[] counts_ig = new int[percents_ig.length];
    for (int step = 0; step < percents_ig.length; step++) {
      System.out.println("Testing Bayes" + percents_ig[step] + "...");
      if (step != 0) {
        bayes.fS_IG(percents_ig[step]);
      }
      counts_ig[step] = countCorrect(bayes, testset);
      // Report this sweep's own percentage (the original printed the CS_Max list here).
      printPrecision("Bayes Precision", percents_ig[step], counts_ig[step], total);
    }

    System.out.println("..Testing Bayes complete!");
    for (int i = 0; i < percents_cs.length; i++) {
      printPrecision("Bayes Precision CS", percents_cs[i], counts_cs[i], total);
    }
    for (int i = 0; i < percents_csmax.length; i++) {
      printPrecision("Bayes Precision CS_Max", percents_csmax[i], counts_csmax[i], total);
    }
    for (int i = 0; i < percents_ig.length; i++) {
      printPrecision("Bayes Precision IG", percents_ig[i], counts_ig[i], total);
    }
  }

  /** Counts the test instances whose predicted label equals the label of their gold target. */
  private static int countCorrect(BayesClassifier bayes, InstanceSet testset) {
    int correct = 0;
    for (int i = 0; i < testset.size(); i++) {
      Instance data = testset.getInstance(i);
      Integer gold = (Integer) data.getTarget();
      Predict<String> pres = bayes.classify(data, Type.STRING, 3);
      String predLabel = pres.getLabel();
      String goldLabel = bayes.getLabel(gold);
      if (predLabel.equals(goldLabel)) {
        correct++;
      }
    }
    return correct;
  }

  /** Prints one line of the form {@code label(percent):ratio(count/total)}. */
  private static void printPrecision(String label, float percent, int count, int total) {
    System.out.println(
        label + "(" + percent + "):" + ((float) count / total) + "(" + count + "/" + total + ")");
  }