/**
 * Creates a logistic-regression classifier that first turns string attributes into word
 * vectors with a {@link StringToWordVector} filter, wrapped in a {@link FilteredClassifier}.
 *
 * <p>Filter options come from the runtime parameters checked via {@code hasParam}:
 * <ul>
 *   <li>{@code USE_TFIDF} - enables both the TF and the IDF transform;</li>
 *   <li>{@code CONV_LOWERCASE} - lower-cases tokens;</li>
 *   <li>{@code REM_STOP_WORDS} - enables the stop list;</li>
 *   <li>{@code USE_WORD_FREQ} - outputs word counts instead of word presence;</li>
 *   <li>{@code USE_BIGRAM} / {@code USE_TRIGRAM} - tokenize into word 2-grams / 3-grams
 *       (bigrams take precedence when both are set);</li>
 *   <li>{@code USE_STEMMER} - applies a Snowball "porter" stemmer.</li>
 * </ul>
 *
 * <p>When {@code TRAIN_AND_TEST} is set, the classifier is trained on {@code getTrainData()}
 * before it is returned; otherwise it is returned untrained.
 *
 * @return the configured (and possibly trained) classifier
 * @throws Exception if loading the training data or training the classifier fails
 */
public weka.classifiers.Classifier getClassifier() throws Exception {
    StringToWordVector stwv = new StringToWordVector();
    boolean useTfIdf = hasParam(Constant.RUNTIME_PARAMS.USE_TFIDF);
    stwv.setTFTransform(useTfIdf);
    stwv.setIDFTransform(useTfIdf);
    stwv.setLowerCaseTokens(hasParam(Constant.RUNTIME_PARAMS.CONV_LOWERCASE));
    stwv.setUseStoplist(hasParam(Constant.RUNTIME_PARAMS.REM_STOP_WORDS));
    stwv.setOutputWordCounts(hasParam(Constant.RUNTIME_PARAMS.USE_WORD_FREQ));

    if (hasParam(Constant.RUNTIME_PARAMS.USE_BIGRAM)) {
        stwv.setTokenizer(newFixedSizeNGramTokenizer(2));
    } else if (hasParam(Constant.RUNTIME_PARAMS.USE_TRIGRAM)) {
        stwv.setTokenizer(newFixedSizeNGramTokenizer(3));
    }
    if (hasParam(Constant.RUNTIME_PARAMS.USE_STEMMER)) {
        stwv.setStemmer(new SnowballStemmer("porter"));
    }

    FilteredClassifier cls = new FilteredClassifier();
    cls.setClassifier(new Logistic());
    cls.setFilter(stwv);
    if (hasParam(Constant.RUNTIME_PARAMS.TRAIN_AND_TEST)) {
        // No explicit stwv.setInputFormat(...) here: FilteredClassifier.buildClassifier calls it on
        // the filter itself, after every option above has been applied, as Weka requires.
        cls.buildClassifier(getTrainData());
    }
    return cls;
}

/**
 * Creates a word n-gram tokenizer that emits n-grams of exactly {@code n} tokens. Both bounds
 * are set because NGramTokenizer's default maximum size is 3; setting only the minimum would
 * produce a mix of sizes.
 */
private static NGramTokenizer newFixedSizeNGramTokenizer(int n) {
    NGramTokenizer tokenizer = new NGramTokenizer();
    tokenizer.setNGramMinSize(n);
    tokenizer.setNGramMaxSize(n);
    return tokenizer;
}
public static void main(String[] args) { if (args.length < 1) { System.out.println("usage: C4_5TweetTopicCategorization <root_path>"); System.exit(-1); } String rootPath = args[0]; File dataFolder = new File(rootPath + "/data"); String resultFolderPath = rootPath + "/results/C4_5/"; CrisisMailer crisisMailer = CrisisMailer.getCrisisMailer(); Logger logger = Logger.getLogger(C4_5TweetTopicCategorization.class); PropertyConfigurator.configure(Constants.LOG4J_PROPERTIES_FILE_PATH); File resultFolder = new File(resultFolderPath); if (!resultFolder.exists()) resultFolder.mkdir(); CSVLoader csvLoader = new CSVLoader(); try { for (File dataSetName : dataFolder.listFiles()) { Instances data = null; try { csvLoader.setSource(dataSetName); csvLoader.setStringAttributes("2"); data = csvLoader.getDataSet(); } catch (IOException ioe) { logger.error(ioe); crisisMailer.sendEmailAlert(ioe); System.exit(-1); } data.setClassIndex(data.numAttributes() - 1); data.deleteWithMissingClass(); Instances vectorizedData = null; StringToWordVector stringToWordVectorFilter = new StringToWordVector(); try { stringToWordVectorFilter.setInputFormat(data); stringToWordVectorFilter.setAttributeIndices("2"); stringToWordVectorFilter.setIDFTransform(true); stringToWordVectorFilter.setLowerCaseTokens(true); stringToWordVectorFilter.setOutputWordCounts(false); stringToWordVectorFilter.setUseStoplist(true); vectorizedData = Filter.useFilter(data, stringToWordVectorFilter); vectorizedData.deleteAttributeAt(0); // System.out.println(vectorizedData); } catch (Exception exception) { logger.error(exception); crisisMailer.sendEmailAlert(exception); System.exit(-1); } J48 j48Classifier = new J48(); /* FilteredClassifier filteredClassifier = new FilteredClassifier(); filteredClassifier.setFilter(stringToWordVectorFilter); filteredClassifier.setClassifier(j48Classifier); */ try { Evaluation eval = new Evaluation(vectorizedData); eval.crossValidateModel( j48Classifier, vectorizedData, 5, new 
Random(System.currentTimeMillis())); FileOutputStream resultOutputStream = new FileOutputStream(new File(resultFolderPath + dataSetName.getName())); resultOutputStream.write(eval.toSummaryString("=== Summary ===", false).getBytes()); resultOutputStream.write(eval.toMatrixString().getBytes()); resultOutputStream.write(eval.toClassDetailsString().getBytes()); resultOutputStream.close(); } catch (Exception exception) { logger.error(exception); crisisMailer.sendEmailAlert(exception); System.exit(-1); } } } catch (Exception exception) { logger.error(exception); crisisMailer.sendEmailAlert(exception); System.out.println(-1); } }