public static void main(String[] args) throws Exception {
    // NOTE: trainDataPath and modelFile are assumed to be fields defined elsewhere in the enclosing class.
    // Create the alphabet (dictionary) manager
    AlphabetFactory af = AlphabetFactory.buildFactory();
    // Use n-gram features (2-grams and 3-grams)
    Pipe ngrampp = new NGram(new int[] {2, 3});
    // Convert string features into dictionary indices
    Pipe indexpp = new StringArray2IndexArray(af);
    // Use the index of the target value as the class label
    Pipe targetpp = new Target2Label(af.DefaultLabelAlphabet());
    // Build the pipe chain
    SeriesPipes pp = new SeriesPipes(new Pipe[] {ngrampp, targetpp, indexpp});
    InstanceSet instset = new InstanceSet(pp, af);

    // Use the Reader that matches the input file format
    Reader reader = new FileReader(trainDataPath, "UTF-8", ".data");
    // Load the data and run it through the pipes
    instset.loadThruStagePipes(reader);

    // Split the data set into training and test sets
    float percent = 0.8f;
    InstanceSet[] splitsets = instset.split(percent);
    InstanceSet trainset = splitsets[0];
    InstanceSet testset = splitsets[1];

    /** Build the classifier */
    OnlineTrainer trainer = new OnlineTrainer(af);
    Linear pclassifier = trainer.train(trainset);
    pp.removeTargetPipe();
    pclassifier.setPipe(pp);
    af.setStopIncrement(true);

    // Save the classifier to a model file
    pclassifier.saveTo(modelFile);
    pclassifier = null;
    // Load the classifier back from the model file
    Linear cl = Linear.loadFrom(modelFile);

    // Performance evaluation
    Evaluation eval = new Evaluation(testset);
    eval.eval(cl, 1);

    /** Test */
    System.out.println("类别 : 文本内容");
    System.out.println("===================");
    for (int i = 0; i < testset.size(); i++) {
        Instance data = testset.getInstance(i);
        Integer gold = (Integer) data.getTarget();
        String pred_label = cl.getStringLabel(data);
        String gold_label = cl.getLabel(gold);
        if (pred_label.equals(gold_label))
            System.out.println(pred_label + " : " + testset.getInstance(i).getSource());
        else
            System.err.println(gold_label + "->" + pred_label + " : "
                    + testset.getInstance(i).getSource());
    }

    /** Using the classifier */
    String str = "韦德:不拿冠军就是失败 詹皇:没拿也不意味失败";
    System.out.println("============\n分类:" + str);
    Pipe p = cl.getPipe();
    Instance inst = new Instance(str);
    try {
        // Feature conversion
        p.addThruPipe(inst);
    } catch (Exception e) {
        e.printStackTrace();
    }
    String res = cl.getStringLabel(inst);
    System.out.println("类别:" + res);

    // Delete the model file on exit
    (new File(modelFile)).deleteOnExit();
    // System.exit(0);
}
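For reuse outside this demo, the prediction path at the end of the method can be isolated into a small helper. The sketch below uses only the FNLP calls already shown above (Linear.loadFrom, getPipe, addThruPipe, getStringLabel); the helper name classifyText and its modelPath parameter are illustrative, not part of FNLP.

// A minimal sketch of the prediction path only, assuming the FNLP API used above.
// classifyText and modelPath are hypothetical names introduced here for illustration.
public static String classifyText(String modelPath, String text) throws Exception {
    // Load the serialized classifier; its pipe chain was attached before saving
    Linear cl = Linear.loadFrom(modelPath);
    Pipe p = cl.getPipe();
    // Wrap the raw text and run it through the feature pipes
    Instance inst = new Instance(text);
    p.addThruPipe(inst);
    // Return the predicted label as a string
    return cl.getStringLabel(inst);
}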
public static void main(String[] args) throws Exception {
    // NOTE: trainDataPath and bayesModelFile are assumed to be fields defined elsewhere in the enclosing class.
    // Word segmentation
    Pipe removepp = new RemoveWords();
    CWSTagger tag = new CWSTagger("../models/seg.m");
    Pipe segpp = new CNPipe(tag);
    Pipe s2spp = new Strings2StringArray();

    /** Bayes */
    // Create the alphabet (dictionary) manager
    AlphabetFactory af = AlphabetFactory.buildFactory();
    // Use n-gram features (2-grams and 3-grams; not added to the pipe chain below)
    Pipe ngrampp = new NGram(new int[] {2, 3});
    // Convert string features into dictionary indices (sparse vector)
    Pipe sparsepp = new StringArray2SV(af);
    // Use the index of the target value as the class label
    Pipe targetpp = new Target2Label(af.DefaultLabelAlphabet());
    // Build the pipe chain
    SeriesPipes pp = new SeriesPipes(new Pipe[] {removepp, segpp, s2spp, targetpp, sparsepp});

    System.out.print("\nReading data......\n");
    InstanceSet instset = new InstanceSet(pp, af);
    Reader reader = new MyDocumentReader(trainDataPath, "gbk");
    instset.loadThruStagePipes(reader);
    System.out.print("..Reading data complete\n");

    // Split the data set into training and test sets
    System.out.print("Splitting....");
    float percent = 0.8f;
    InstanceSet[] splitsets = instset.split(percent);
    InstanceSet trainset = splitsets[0];
    InstanceSet testset = splitsets[1];
    System.out.print("..Splitting complete!\n");

    System.out.print("Training...\n");
    af.setStopIncrement(true);
    BayesTrainer trainer = new BayesTrainer();
    BayesClassifier classifier = (BayesClassifier) trainer.train(trainset);
    System.out.print("..Training complete!\n");

    System.out.print("Saving model...\n");
    classifier.saveTo(bayesModelFile);
    classifier = null;
    System.out.print("..Saving model complete!\n");

    /** Test */
    System.out.print("Loading model...\n");
    BayesClassifier bayes = BayesClassifier.loadFrom(bayesModelFile);
    System.out.print("..Loading model complete!\n");

    System.out.println("Testing Bayes...");
    int flag = 0;

    // Feature selection by chi-square (CS), keeping the given fraction of features
    float[] percents_cs = new float[] {1.0f, 0.9f, 0.8f, 0.7f, 0.5f, 0.3f, 0.2f, 0.1f};
    int[] counts_cs = new int[percents_cs.length];
    for (int test = 0; test < percents_cs.length; test++) {
        System.out.println("Testing Bayes " + percents_cs[test] + "...");
        if (test != 0)
            bayes.fS_CS(percents_cs[test]);
        int count = 0;
        for (int i = 0; i < testset.size(); i++) {
            Instance data = testset.getInstance(i);
            Integer gold = (Integer) data.getTarget();
            Predict<String> pres = bayes.classify(data, Type.STRING, 3);
            String pred_label = pres.getLabel();
            String gold_label = bayes.getLabel(gold);
            if (pred_label.equals(gold_label)) {
                count++;
            } else {
                flag = i;
                // System.err.println(gold_label + "->" + pred_label + " : " + testset.getInstance(i).getTempData());
                // for (int j = 0; j < 3; j++)
                //     System.out.println(pres.getLabel(j) + ":" + pres.getScore(j));
            }
        }
        counts_cs[test] = count;
        System.out.println("Bayes Precision(" + percents_cs[test] + "):"
                + ((float) count / testset.size()) + "(" + count + "/" + testset.size() + ")");
    }

    // Feature selection by maximum chi-square (CS_Max)
    bayes.noFeatureSelection();
    float[] percents_csmax = new float[] {1.0f, 0.9f, 0.8f, 0.7f, 0.5f, 0.3f, 0.2f, 0.1f};
    int[] counts_csmax = new int[percents_csmax.length];
    for (int test = 0; test < percents_csmax.length; test++) {
        System.out.println("Testing Bayes " + percents_csmax[test] + "...");
        if (test != 0)
            bayes.fS_CS_Max(percents_csmax[test]);
        int count = 0;
        for (int i = 0; i < testset.size(); i++) {
            Instance data = testset.getInstance(i);
            Integer gold = (Integer) data.getTarget();
            Predict<String> pres = bayes.classify(data, Type.STRING, 3);
            String pred_label = pres.getLabel();
            String gold_label = bayes.getLabel(gold);
            if (pred_label.equals(gold_label))
                count++;
        }
        counts_csmax[test] = count;
        System.out.println("Bayes Precision(" + percents_csmax[test] + "):"
                + ((float) count / testset.size()) + "(" + count + "/" + testset.size() + ")");
    }

    // Feature selection by information gain (IG)
    bayes.noFeatureSelection();
    float[] percents_ig = new float[] {1.0f, 0.9f, 0.8f, 0.7f, 0.5f, 0.3f, 0.2f, 0.1f};
    int[] counts_ig = new int[percents_ig.length];
    for (int test = 0; test < percents_ig.length; test++) {
        System.out.println("Testing Bayes " + percents_ig[test] + "...");
        if (test != 0)
            bayes.fS_IG(percents_ig[test]);
        int count = 0;
        for (int i = 0; i < testset.size(); i++) {
            Instance data = testset.getInstance(i);
            Integer gold = (Integer) data.getTarget();
            Predict<String> pres = bayes.classify(data, Type.STRING, 3);
            String pred_label = pres.getLabel();
            String gold_label = bayes.getLabel(gold);
            if (pred_label.equals(gold_label))
                count++;
        }
        counts_ig[test] = count;
        System.out.println("Bayes Precision(" + percents_ig[test] + "):"
                + ((float) count / testset.size()) + "(" + count + "/" + testset.size() + ")");
    }

    System.out.println("..Testing Bayes complete!");
    for (int i = 0; i < percents_cs.length; i++)
        System.out.println("Bayes Precision CS(" + percents_cs[i] + "):"
                + ((float) counts_cs[i] / testset.size()) + "(" + counts_cs[i] + "/" + testset.size() + ")");
    for (int i = 0; i < percents_csmax.length; i++)
        System.out.println("Bayes Precision CS_Max(" + percents_csmax[i] + "):"
                + ((float) counts_csmax[i] / testset.size()) + "(" + counts_csmax[i] + "/" + testset.size() + ")");
    for (int i = 0; i < percents_ig.length; i++)
        System.out.println("Bayes Precision IG(" + percents_ig[i] + "):"
                + ((float) counts_ig[i] / testset.size()) + "(" + counts_ig[i] + "/" + testset.size() + ")");
}