/**
 * Trains an {@code ARClassifier}, reloads the saved model from {@code modelFile}, and prints
 * three accuracy figures measured on the samples in {@code ../tmp/ar-train.txt}.
 *
 * <p>The file name suggests these are the training samples, in which case the figures are
 * training-set accuracy rather than held-out accuracy. TODO confirm.
 *
 * @param args unused
 * @throws Exception propagated from training, model loading, file reading or piping
 */
public static void main(String[] args) throws Exception {
    // Train and persist a model, then reload it so evaluation uses the saved copy.
    new ARClassifier().train();
    Linear cl = Linear.loadFrom(modelFile);
    InstanceSet test = new InstanceSet(cl.getPipe(), cl.getAlphabetFactory());

    // Read all labelled samples up front.
    SimpleFileReader sfr = new SimpleFileReader("../tmp/ar-train.txt", true);
    List<Instance> samples = new ArrayList<Instance>();
    while (sfr.hasNext()) {
        samples.add(sfr.next());
    }

    // Split each sample into its feature list and its gold label.
    // Also count the gold "1" labels.
    int total = samples.size();
    @SuppressWarnings("unchecked") // generic array creation; element type is enforced below
    List<String>[] features = new List[total];
    String[] gold = new String[total];
    int goldPositive = 0;
    for (int i = 0; i < total; i++) {
        Instance sample = samples.get(i);
        @SuppressWarnings("unchecked") // the reader yields List<String> data (as the original cast assumed)
        List<String> data = (List<String>) sample.getData();
        features[i] = data;
        gold[i] = (String) sample.getTarget();
        if ("1".equals(gold[i])) {
            goldPositive++;
        }
    }

    // Push the raw feature lists through the model's own pipe, then predict each one.
    test.loadThruPipes(new ListReader(features));
    int correct = 0;
    int truePositive = 0;
    int trueNegative = 0;
    for (int i = 0; i < total; i++) {
        String predicted = cl.getStringLabel(test.getInstance(i));
        if (predicted.equals(gold[i])) {
            correct++;
            if ("1".equals(predicted)) {
                truePositive++;
            } else if ("0".equals(predicted)) {
                trueNegative++;
            }
        }
    }

    // Floating-point division: an empty class yields NaN instead of throwing.
    double overall = (double) correct / total;
    System.out.print("整体正确率:" + overall + '\n');

    // NOTE(review): these two figures divide by the number of GOLD "1"/"0" labels (per-class
    // recall). Per-prediction precision would instead divide by the number of instances
    // PREDICTED as "1"/"0", which the printed wording "判断为…" may suggest. Confirm the intent.
    double positiveRate = (double) truePositive / goldPositive;
    System.out.print("判断为指代关系的正确率:" + positiveRate + '\n');
    double negativeRate = (double) trueNegative / (total - goldPositive);
    System.out.print("判断为非指代关系的正确率:" + negativeRate + '\n');
}
/** * 训练 * * @throws Exception */ public void train() throws Exception { // 建立字典管理器 Pipe lpipe = new Target2Label(al); Pipe fpipe = new StringArray2IndexArray(factory, true); // 构造转换器组 SeriesPipes pipe = new SeriesPipes(new Pipe[] {lpipe, fpipe}); InstanceSet instset = new InstanceSet(pipe, factory); instset.loadThruStagePipes(new SimpleFileReader(trainFile, " ", true, Type.LabelData)); Generator gen = new SFGenerator(); ZeroOneLoss l = new ZeroOneLoss(); Inferencer ms = new LinearMax(gen, factory.getLabelSize()); Update update = new LinearMaxPAUpdate(l); OnlineTrainer trainer = new OnlineTrainer(ms, update, l, factory, 50, 0.005f); Linear pclassifier = trainer.train(instset, instset); pipe.removeTargetPipe(); pclassifier.setPipe(pipe); factory.setStopIncrement(true); pclassifier.saveTo(modelFile); }
/**
 * Demo: trains an n-gram text classifier on the ".data" files under {@code trainDataPath}.
 *
 * <p>It evaluates the classifier on a 20% held-out split, prints the per-instance results and
 * classifies one sample headline. The model is saved to {@code modelFile} and reloaded before
 * use. The file is scheduled for deletion when the JVM exits.
 *
 * @param args unused
 * @throws Exception propagated from reading, piping, training, saving or loading
 */
public static void main(String[] args) throws Exception {
    // Build the alphabet (dictionary) manager shared by the pipes and the classifier.
    AlphabetFactory af = AlphabetFactory.buildFactory();
    // Use 2-gram and 3-gram features.
    Pipe ngrampp = new NGram(new int[] {2, 3});
    // Map string features to dictionary indices.
    Pipe indexpp = new StringArray2IndexArray(af);
    // Use the target's index in the label alphabet as the class.
    Pipe targetpp = new Target2Label(af.DefaultLabelAlphabet());
    // Chain the pipes: n-grams, then label mapping, then feature indexing.
    SeriesPipes pp = new SeriesPipes(new Pipe[] {ngrampp, targetpp, indexpp});

    InstanceSet instset = new InstanceSet(pp, af);
    // Project reader (not java.io.FileReader): UTF-8 input, ".data" files under trainDataPath.
    Reader reader = new FileReader(trainDataPath, "UTF-8", ".data");
    // Read the data and run it through the pipes.
    instset.loadThruStagePipes(reader);

    // Split the data set into a training set (80%) and a test set (20%).
    float percent = 0.8f;
    InstanceSet[] splitsets = instset.split(percent);
    InstanceSet trainset = splitsets[0];
    InstanceSet testset = splitsets[1];

    // Build the classifier.
    OnlineTrainer trainer = new OnlineTrainer(af);
    Linear pclassifier = trainer.train(trainset);
    pp.removeTargetPipe();
    pclassifier.setPipe(pp);
    af.setStopIncrement(true);

    // Save the classifier, then read it back from the model file.
    pclassifier.saveTo(modelFile);
    Linear cl = Linear.loadFrom(modelFile);

    // Performance evaluation.
    Evaluation eval = new Evaluation(testset);
    eval.eval(cl, 1);

    // Per-instance results: correct predictions to stdout, mistakes to stderr as gold->pred.
    System.out.println("类别 : 文本内容");
    System.out.println("===================");
    for (int i = 0; i < testset.size(); i++) {
        Instance data = testset.getInstance(i);
        Integer gold = (Integer) data.getTarget();
        String pred_label = cl.getStringLabel(data);
        String gold_label = cl.getLabel(gold);
        if (pred_label.equals(gold_label)) {
            System.out.println(pred_label + " : " + data.getSource());
        } else {
            System.err.println(gold_label + "->" + pred_label + " : " + data.getSource());
        }
    }

    // Use the classifier on a new raw string.
    String str = "韦德:不拿冠军就是失败 詹皇:没拿也不意味失败";
    System.out.println("============\n分类:" + str);
    Pipe p = cl.getPipe();
    Instance inst = new Instance(str);
    // Feature conversion. Failures propagate (main already throws Exception). Classifying an
    // instance whose conversion failed would likely only fail later with a less clear error.
    p.addThruPipe(inst);
    String res = cl.getStringLabel(inst);
    System.out.println("类别:" + res);

    // Remove the model file when the JVM exits.
    (new File(modelFile)).deleteOnExit();
    System.exit(0);
}