/**
 * Builds a Logistic-regression text classifier wrapped in a FilteredClassifier
 * whose StringToWordVector filter is configured from the runtime parameters.
 *
 * @return a (possibly already trained) FilteredClassifier
 * @throws Exception if filter initialisation or training fails
 */
public weka.classifiers.Classifier getClassifier() throws Exception {
  // Configure the filter completely BEFORE setInputFormat(); Weka does not
  // guarantee that option changes made afterwards are honoured.
  StringToWordVector stwv = new StringToWordVector();
  stwv.setTFTransform(hasParam(Constant.RUNTIME_PARAMS.USE_TFIDF));
  stwv.setIDFTransform(hasParam(Constant.RUNTIME_PARAMS.USE_TFIDF));
  stwv.setLowerCaseTokens(hasParam(Constant.RUNTIME_PARAMS.CONV_LOWERCASE));
  stwv.setUseStoplist(hasParam(Constant.RUNTIME_PARAMS.REM_STOP_WORDS));
  stwv.setOutputWordCounts(hasParam(Constant.RUNTIME_PARAMS.USE_WORD_FREQ));

  if (hasParam(Constant.RUNTIME_PARAMS.USE_BIGRAM)) {
    NGramTokenizer tokenizer = new NGramTokenizer();
    // FIX: NGramTokenizer's default max size is 3, so setting only the min
    // produced a mix of 2- and 3-grams; pin both bounds for pure bigrams.
    tokenizer.setNGramMinSize(2);
    tokenizer.setNGramMaxSize(2);
    stwv.setTokenizer(tokenizer);
  } else if (hasParam(Constant.RUNTIME_PARAMS.USE_TRIGRAM)) {
    NGramTokenizer tokenizer = new NGramTokenizer();
    tokenizer.setNGramMinSize(3);
    tokenizer.setNGramMaxSize(3); // explicit: trigrams only
    stwv.setTokenizer(tokenizer);
  }

  if (hasParam(Constant.RUNTIME_PARAMS.USE_STEMMER)) {
    SnowballStemmer stemmer = new SnowballStemmer("porter");
    stwv.setStemmer(stemmer);
  }

  // Only initialise the input format when training data is actually available.
  if (hasParam(Constant.RUNTIME_PARAMS.TRAIN_AND_TEST)) {
    stwv.setInputFormat(getTrainData());
  }

  Logistic l = new Logistic();
  FilteredClassifier cls = new FilteredClassifier();
  cls.setClassifier(l);
  cls.setFilter(stwv);
  if (hasParam(Constant.RUNTIME_PARAMS.TRAIN_AND_TEST)) {
    cls.buildClassifier(getTrainData());
  }
  return cls;
}
/** 分类过程 */ public double classifyMessage(String message) throws Exception { filter.input(makeInstance(message, instances.stringFreeStructure())); Instance filteredInstance = filter.output(); // 必须使用原来的filter double predicted = classifier.classifyInstance(filteredInstance); // (int)predicted是类标索引 // System.out.println("Message classified as : " // + instances.classAttribute().value((int) predicted)); return predicted; }
/**
 * Loads the raw data set from {@code source}, converts its string attributes
 * into a bag-of-words representation and stores the result in {@code inst}
 * with the class attribute at index 0.
 *
 * @throws Exception if loading or filtering fails
 */
public void filterData() throws Exception {
  Instances raw = source.getDataSet();

  // Bag-of-words options: up to 1000 words, no pruning, plain word tokenizer.
  String bowOptions =
      "-R first-last -W 1000 "
          + "-prune-rate -1.0 -N 0 "
          + "-stemmer weka.core.stemmers.NullStemmer -M 1 "
          + "-tokenizer \"weka.core.tokenizers.WordTokenizer -delimiters \\\" \\r\\n\\t.,;:\\\'\\\"()?!\"";

  StringToWordVector bow = new StringToWordVector();
  bow.setOptions(weka.core.Utils.splitOptions(bowOptions));
  bow.setInputFormat(raw);

  this.inst = Filter.useFilter(raw, bow);
  // StringToWordVector keeps the original attributes first, so the class ends
  // up at position 0 of the filtered data.
  this.inst.setClassIndex(0);
}
/**
 * Runs a 10-fold cross-validation of {@code this.classifier} on the given
 * training file after bag-of-words vectorization, printing the summary and
 * confusion matrix to stdout.
 *
 * @param traindata path to the training data file
 * @throws Exception if loading, filtering or evaluation fails
 */
@Override
public void crossValidation(String traindata) throws Exception {
  Instances data = new DataSource(traindata).getDataSet();

  String bowOptions =
      "-R first-last -W 1000 "
          + "-prune-rate -1.0 -N 0 "
          + "-stemmer weka.core.stemmers.NullStemmer -M 1 "
          + "-tokenizer \"weka.core.tokenizers.WordTokenizer -delimiters \\\" \\r\\n\\t.,;:\\\'\\\"()?!\"";

  StringToWordVector bow = new StringToWordVector();
  bow.setOptions(weka.core.Utils.splitOptions(bowOptions));
  bow.setInputFormat(data);
  data = Filter.useFilter(data, bow);
  data.setClassIndex(0);

  // NOTE(review): vectorizing BEFORE cross-validation lets each test fold
  // influence the dictionary; wrapping the filter in a FilteredClassifier
  // would avoid that leakage — confirm whether this is intentional.
  Evaluation eval = new Evaluation(data);
  eval.crossValidateModel(this.classifier, data, 10, new Random(1));
  System.out.println(eval.toSummaryString());
  System.out.println(eval.toMatrixString());
}
/** * Make data sets and train and test model * * @param filePathTrain * @param filePathTest * @param gram */ public static void makeDataSet(String filePathTrain, String filePathTest, int gram) { TextDirectoryLoader loader = new TextDirectoryLoader(); try { loader.setDirectory(new File(filePathTrain)); Instances dataRawTrain = loader.getDataSet(); loader.setDirectory(new File(filePathTest)); Instances dataRawTest = loader.getDataSet(); StringToWordVector filter = new StringToWordVector(); NGramTokenizer tokeniser = new NGramTokenizer(); tokeniser.setNGramMinSize(gram); tokeniser.setNGramMaxSize(gram); filter.setTokenizer(tokeniser); filter.setInputFormat(dataRawTrain); Instances train = Filter.useFilter(dataRawTrain, filter); // filter.setInputFormat(dataRawTest); Instances test = Filter.useFilter(dataRawTest, filter); /** * * * * <p>Replace this function each time to change models */ trainModelNaiveBayes(train, test); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } catch (Exception e) { // TODO Auto-generated catch block e.printStackTrace(); } }
public static void main(String[] args) { if (args.length < 1) { System.out.println("usage: C4_5TweetTopicCategorization <root_path>"); System.exit(-1); } String rootPath = args[0]; File dataFolder = new File(rootPath + "/data"); String resultFolderPath = rootPath + "/results/C4_5/"; CrisisMailer crisisMailer = CrisisMailer.getCrisisMailer(); Logger logger = Logger.getLogger(C4_5TweetTopicCategorization.class); PropertyConfigurator.configure(Constants.LOG4J_PROPERTIES_FILE_PATH); File resultFolder = new File(resultFolderPath); if (!resultFolder.exists()) resultFolder.mkdir(); CSVLoader csvLoader = new CSVLoader(); try { for (File dataSetName : dataFolder.listFiles()) { Instances data = null; try { csvLoader.setSource(dataSetName); csvLoader.setStringAttributes("2"); data = csvLoader.getDataSet(); } catch (IOException ioe) { logger.error(ioe); crisisMailer.sendEmailAlert(ioe); System.exit(-1); } data.setClassIndex(data.numAttributes() - 1); data.deleteWithMissingClass(); Instances vectorizedData = null; StringToWordVector stringToWordVectorFilter = new StringToWordVector(); try { stringToWordVectorFilter.setInputFormat(data); stringToWordVectorFilter.setAttributeIndices("2"); stringToWordVectorFilter.setIDFTransform(true); stringToWordVectorFilter.setLowerCaseTokens(true); stringToWordVectorFilter.setOutputWordCounts(false); stringToWordVectorFilter.setUseStoplist(true); vectorizedData = Filter.useFilter(data, stringToWordVectorFilter); vectorizedData.deleteAttributeAt(0); // System.out.println(vectorizedData); } catch (Exception exception) { logger.error(exception); crisisMailer.sendEmailAlert(exception); System.exit(-1); } J48 j48Classifier = new J48(); /* FilteredClassifier filteredClassifier = new FilteredClassifier(); filteredClassifier.setFilter(stringToWordVectorFilter); filteredClassifier.setClassifier(j48Classifier); */ try { Evaluation eval = new Evaluation(vectorizedData); eval.crossValidateModel( j48Classifier, vectorizedData, 5, new 
Random(System.currentTimeMillis())); FileOutputStream resultOutputStream = new FileOutputStream(new File(resultFolderPath + dataSetName.getName())); resultOutputStream.write(eval.toSummaryString("=== Summary ===", false).getBytes()); resultOutputStream.write(eval.toMatrixString().getBytes()); resultOutputStream.write(eval.toClassDetailsString().getBytes()); resultOutputStream.close(); } catch (Exception exception) { logger.error(exception); crisisMailer.sendEmailAlert(exception); System.exit(-1); } } } catch (Exception exception) { logger.error(exception); crisisMailer.sendEmailAlert(exception); System.out.println(-1); } }
/**
 * Classifies the instances of the given test ARFF file against 16 per-category
 * Naive-Bayes-Multinomial models (one per ACM CCS label in {@code lab}), each
 * trained on its own per-node ARFF file, and returns one probability per label.
 *
 * @param test path to the test ARFF file
 * @return prd[16], where prd[i] is the model-i probability taken from
 *     distributionForInstance of the last instance classified for label i
 */
private double[] classify(String test) {
  // ACM CCS labels: 4 top-level nodes (indices 0-3) followed by 12 sub-nodes.
  String[] lab = {
    "I.2", "I.3", "I.5", "I.6", "I.2.1", "I.2.6", "I.2.8", "I.3.5", "I.3.6", "I.3.7", "I.5.1",
    "I.5.2", "I.5.4", "I.6.3", "I.6.5", "I.6.8",
  };
  int NSel = 1000; // Number of selection (attributes kept by the Ranker)
  Filter[] filters = new Filter[2];
  double[] x = new double[16];
  double[] prd = new double[16]; // per-label probabilities returned to the caller
  double clsLabel;
  Ranker rank = new Ranker();
  Evaluation eval = null;
  StringToWordVector stwv = new StringToWordVector();
  weka.filters.supervised.attribute.AttributeSelection featSel =
      new weka.filters.supervised.attribute.AttributeSelection();
  WordTokenizer wtok = new WordTokenizer();
  String delim = " \r\n\t.,;:'\"()?!$*-&[]+/|\\";
  InfoGainAttributeEval ig = new InfoGainAttributeEval();
  String[] stwvOpts;
  wtok.setDelimiters(delim);
  // NOTE(review): sized 10000 but only indices 0-15 are ever used.
  Instances[] dataRaw = new Instances[10000];
  DataSource[] source = new DataSource[16];
  String str;
  Instances testset = null;
  DataSource testsrc = null;
  try {
    testsrc = new DataSource(test);
    testset = testsrc.getDataSet();
  } catch (Exception e1) {
    // TODO Auto-generated catch block
    e1.printStackTrace();
  }
  // Load one training set per label from the hard-coded per-node directory.
  // NOTE(review): absolute Windows path — should be configurable.
  for (int j = 0; j < 16; j++) // 16 element 0-15
  {
    try {
      str = lab[j];
      source[j] =
          new DataSource(
              "D:/Users/nma1g11/workspace2/WebScraperFlatNew/dataPernode/new/" + str + ".arff");
      dataRaw[j] = source[j].getDataSet();
    } catch (Exception e) {
      e.printStackTrace();
    }
    System.out.println(lab[j]);
    if (dataRaw[j].classIndex() == -1) dataRaw[j].setClassIndex(dataRaw[j].numAttributes() - 1);
  }
  if (testset.classIndex() == -1) testset.setClassIndex(testset.numAttributes() - 1);
  try {
    // Bag-of-words with up to 1,000,000 words, counts + TF + IDF, Lovins stemming,
    // followed by InfoGain attribute selection keeping the top NSel attributes.
    stwvOpts =
        weka.core.Utils.splitOptions(
            "-R first-last -W 1000000 -prune-rate -1.0 -C -T -I -N 1 -L -S -stemmer weka.core.stemmers.LovinsStemmer -M 2 ");
    stwv.setOptions(stwvOpts);
    stwv.setTokenizer(wtok);
    rank.setOptions(weka.core.Utils.splitOptions("-T -1.7976931348623157E308 -N 100"));
    rank.setNumToSelect(NSel);
    featSel.setEvaluator(ig);
    featSel.setSearch(rank);
  } catch (Exception e) {
    e.printStackTrace();
  }
  filters[0] = stwv;
  filters[1] = featSel;
  System.out.println("Loading is Done!");
  // Chain vectorization + feature selection inside the classifier so the test
  // instances are transformed with the training statistics.
  MultiFilter mfilter = new MultiFilter();
  mfilter.setFilters(filters);
  FilteredClassifier classify = new FilteredClassifier();
  classify.setClassifier(new NaiveBayesMultinomial()); ///////// Algorithm of The Classification /////////
  classify.setFilter(mfilter);
  String ss2 = "";
  try {
    // One independent copy of the configured classifier per label.
    Classifier[] clsArr = new Classifier[16];
    clsArr = Classifier.makeCopies(classify, 16);
    String strcls = "";
    List<String> clsList = new ArrayList<String>();
    String s = null;
    String newcls = null;
    String lb = "";
    String prev = "";
    boolean flag = false;
    String Ocls = null;
    int q = 0;
    for (int i = 0; i < 16; i++) {
      // Scan the test set for occurrences of label i; 'flag' keeps only the
      // result for the LAST instance (k-loop overwrites it each iteration).
      for (int k = 0; k < testset.numInstances(); k++) {
        flag = false;
        s = testset.instance(k).stringValue(1);
        clsList.add(s);
        if (lab[i].equals(s)) {
          flag = true;
          newcls = s;
        }
      }
      clsArr[i].buildClassifier(dataRaw[i]);
      eval = new Evaluation(dataRaw[i]);
      for (int j = 0; j < testset.numInstances(); j++) {
        // Remember the original class so it can be restored after scoring.
        Ocls = testset.instance(j).stringValue(1);
        // NOTE(review): !s.equals(null) is always true when s is non-null and
        // throws NPE if s is null — probably meant s != null.
        if (flag && !s.equals(null)) testset.instance(j).setClassValue(lab[i]);
        // -----------------------------------------
        // Align the instance's class with label i when they share the same
        // 3-char top-level prefix (i < 4 are the top-level labels).
        strcls = testset.instance(j).stringValue(1);
        if (i < 4) {
          if (strcls.substring(0, 3).equals(lab[i])) testset.instance(j).setClassValue(lab[i]);
        } else if (lab[i].substring(0, 3).equals(strcls)) testset.instance(j).setClassValue(lab[i]);
        // ------------------------------------------------
        // NOTE(review): eval is freshly constructed and never fed predictions,
        // so pctCorrect()/fMeasure() always report 0 here — confirm intent.
        System.out.println(
            dataRaw[i].classAttribute().value(i)
                + " --- > Correct%:"
                + eval.pctCorrect()
                + " F-measure:"
                + eval.fMeasure(i));
        // Only classify when the document id (attr 0) or the label changed,
        // skipping duplicate consecutive instances.
        if (!prev.equals(testset.instance(j).stringValue(0)) || !lab[i].equals(lb)) {
          clsLabel = clsArr[i].classifyInstance(testset.instance(j));
          x = clsArr[i].distributionForInstance(testset.instance(j));
          prd[i] = x[i];
          System.out.println(" --- > prob: " + clsLabel);
          System.out.println(" --- > x :" + x[i]);
          System.out.println(clsLabel + " --> " + testset.classAttribute().value((int) clsLabel));
        }
        // Restore the instance's original class value.
        testset.instance(j).setClassValue(Ocls);
        prev = testset.instance(j).stringValue(0);
        lb = lab[i];
      }
      System.out.println("Done with " + lab[i].replace("99", "") + " !!!!!!!!!!!");
    }
    System.out.println(eval.correct());
  } catch (Exception e) {
    // TODO Auto-generated catch block
    e.printStackTrace();
  }
  return prd;
}
/**
 * Finishes batch training. Text classification is special because
 * StringToWordVector needs corpus-wide statistics (e.g. document frequency for
 * IDF) when weighting terms (attributes), so the whole batch must be collected
 * before filtering; note that in Weka some learning algorithms are
 * batch-oriented and some are not.
 *
 * @throws Exception if filtering or classifier training fails
 */
public void finishBatch() throws Exception {
  filter.setIDFTransform(true);
  filter.setInputFormat(instances);
  // Produce the numeric data set that Weka learners actually consume.
  Instances filteredData = Filter.useFilter(instances, filter);
  // Train the classifier for real on the vectorized batch.
  classifier.buildClassifier(filteredData);
}