public static void main(String[] args) throws Exception {
    /*
     * First we load the test data from our ARFF file
     */
    ArffLoader testLoader = new ArffLoader();
    testLoader.setSource(new File("data/titanic/test.arff"));
    testLoader.setRetrieval(Loader.BATCH);
    Instances testDataSet = testLoader.getDataSet();

    /*
     * Now we tell the data set which attribute we want to classify, in our
     * case, we want to classify the first column: survived
     */
    Attribute testAttribute = testDataSet.attribute(0);
    testDataSet.setClass(testAttribute);
    testDataSet.deleteStringAttributes();

    /*
     * Now we read in the serialized model from disk
     */
    Classifier classifier = (Classifier) SerializationHelper.read("data/titanic/titanic.model");

    /*
     * This part may be a little confusing. We load up the test data again
     * so we have a prediction data set to populate. As we iterate over the
     * first data set we also iterate over the second data set. After an
     * instance is classified, we set the value of the prediction data set
     * to be the value of the classification.
     */
    ArffLoader test1Loader = new ArffLoader();
    test1Loader.setSource(new File("data/titanic/test.arff"));
    Instances test1DataSet = test1Loader.getDataSet();
    Attribute test1Attribute = test1DataSet.attribute(0);
    test1DataSet.setClass(test1Attribute);

    /*
     * Now we iterate over the test data, classify each entry, and set the
     * value of the 'survived' column to the result of the classification.
     */
    Enumeration testInstances = testDataSet.enumerateInstances();
    Enumeration test1Instances = test1DataSet.enumerateInstances();
    while (testInstances.hasMoreElements()) {
        Instance instance = (Instance) testInstances.nextElement();
        Instance instance1 = (Instance) test1Instances.nextElement();
        double classification = classifier.classifyInstance(instance);
        instance1.setClassValue(classification);
    }

    /*
     * Now we want to write out our predictions. The resulting file is in a
     * format suitable to submit to Kaggle.
     */
    CSVSaver predictedCsvSaver = new CSVSaver();
    predictedCsvSaver.setFile(new File("data/titanic/predict.csv"));
    predictedCsvSaver.setInstances(test1DataSet);
    predictedCsvSaver.writeBatch();

    System.out.println("Prediction saved to predict.csv");
}
/**
 * Test with a classifier.
 *
 * @param trainFileName the training corpus file
 * @param testFileName  the test corpus file
 */
public static void classify(String trainFileName, String testFileName) {
    try {
        File inputFile = new File(fileName + trainFileName); // training corpus file
        ArffLoader atf = new ArffLoader();
        atf.setFile(inputFile);
        Instances instancesTrain = atf.getDataSet(); // read the training file

        // set the class attribute on both data sets
        inputFile = new File(fileName + testFileName); // test corpus file
        atf.setFile(inputFile);
        Instances instancesTest = atf.getDataSet(); // read the test file
        instancesTest.setClassIndex(instancesTest.numAttributes() - 1);
        instancesTrain.setClassIndex(instancesTrain.numAttributes() - 1);

        classifier = (Classifier) Class.forName(CLASSIFIERNAME).newInstance();
        classifier.buildClassifier(instancesTrain);

        // the first argument is a trained classifier, the second is the data set to evaluate on
        Evaluation eval = new Evaluation(instancesTrain);
        eval.evaluateModel(classifier, instancesTest);

        System.out.println(eval.toClassDetailsString());
        System.out.println(eval.toSummaryString());
        System.out.println(eval.toMatrixString());
        System.out.println("accuracy is: " + (1 - eval.errorRate()));
    } catch (Exception e) {
        e.printStackTrace();
    }
}
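A minimal usage sketch for the method above, assuming hypothetical values for the fileName, CLASSIFIERNAME, and classifier fields it references (these names are not defined in the snippet itself):

// Hypothetical fields and entry point assumed by classify(); adjust paths and class name as needed.
private static String fileName = "data/";                                   // directory holding the ARFF files
private static final String CLASSIFIERNAME = "weka.classifiers.bayes.NaiveBayes"; // any Weka classifier class
private static Classifier classifier;

public static void main(String[] args) {
    // train on data/train.arff and evaluate on data/test.arff
    classify("train.arff", "test.arff");
}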
/**
 * Parses a given list of options.
 *
 * <p>
 * <!-- options-start -->
 * Valid options are:
 * <p>
 *
 * <pre> -i &lt;the input file&gt;
 *  The input file</pre>
 *
 * <pre> -o &lt;the output file&gt;
 *  The output file</pre>
 *
 * <pre> -c &lt;the class index&gt;
 *  The class index</pre>
 *
 * <!-- options-end -->
 *
 * @param options the list of options as an array of strings
 * @throws Exception if an option is not supported
 */
public void setOptions(String[] options) throws Exception {
    String outputString = Utils.getOption('o', options);
    String inputString = Utils.getOption('i', options);
    String indexString = Utils.getOption('c', options);

    ArffLoader loader = new ArffLoader();

    resetOptions();

    // parse the class index
    int index = -1;
    if (indexString.length() != 0) {
        if (indexString.equals("first")) {
            index = 0;
        } else if (indexString.equals("last")) {
            index = -1;
        } else {
            index = Integer.parseInt(indexString);
        }
    }

    if (inputString.length() != 0) {
        try {
            File input = new File(inputString);
            loader.setFile(input);
            Instances inst = loader.getDataSet();
            if (index == -1) {
                inst.setClassIndex(inst.numAttributes() - 1);
            } else {
                inst.setClassIndex(index);
            }
            setInstances(inst);
        } catch (Exception ex) {
            throw new IOException("No data set loaded. Data set has to be arff format (Reason: "
                + ex.toString() + ").");
        }
    } else {
        throw new IOException("No data set to save.");
    }

    if (outputString.length() != 0) {
        // add the appropriate file extension
        if (!outputString.endsWith(getFileExtension())) {
            if (outputString.lastIndexOf('.') != -1) {
                outputString = (outputString.substring(0, outputString.lastIndexOf('.')))
                    + getFileExtension();
            } else {
                outputString = outputString + getFileExtension();
            }
        }
        try {
            File output = new File(outputString);
            setFile(output);
        } catch (Exception ex) {
            throw new IOException("Cannot create output file.");
        }
    }

    if (index == -1) {
        index = getInstances().numAttributes() - 1;
    }
    getInstances().setClassIndex(index);
}
private void loadExistingData() {
    if (isExternalStorageAvailable()) {
        try {
            if (file.exists()) {
                ArffLoader loader = new ArffLoader();
                loader.setFile(file);
                Instances existingData = loader.getDataSet();
                addManyInstances(existingData);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}
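The examples here call getDataSet() to read the whole ARFF file in one batch. ArffLoader can also stream instances one at a time, which avoids holding large files in memory. A minimal sketch, assuming a local data.arff whose last attribute is the class:

import java.io.File;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ArffLoader;

public class IncrementalRead {
    public static void main(String[] args) throws Exception {
        ArffLoader loader = new ArffLoader();
        loader.setFile(new File("data.arff"));        // hypothetical input file
        Instances structure = loader.getStructure();  // header only, no instances loaded yet
        structure.setClassIndex(structure.numAttributes() - 1);

        Instance current;
        // read and process one instance at a time instead of loading the full data set
        while ((current = loader.getNextInstance(structure)) != null) {
            System.out.println(current);
        }
    }
}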
public static void main(String args[]) throws Exception {
    ArffLoader trainLoader = new ArffLoader();
    trainLoader.setSource(new File("src/train.arff"));
    trainLoader.setRetrieval(Loader.BATCH);
    Instances trainDataSet = trainLoader.getDataSet();

    weka.core.Attribute trainAttribute = trainDataSet.attribute("class");
    trainDataSet.setClass(trainAttribute);
    // trainDataSet.deleteStringAttributes();

    NaiveBayes classifier = new NaiveBayes();

    final double startTime = System.currentTimeMillis();
    classifier.buildClassifier(trainDataSet);
    final double endTime = System.currentTimeMillis();

    double executionTime = (endTime - startTime) / 1000.0;
    System.out.println("Total execution time: " + executionTime);

    SerializationHelper.write("NaiveBayes.model", classifier);
    System.out.println("Saved trained model to NaiveBayes.model");
}
public static void main(String[] args) throws Exception {
    // NaiveBayesSimple nb = new NaiveBayesSimple();

    // BufferedReader br_train = new BufferedReader(new FileReader("src/train.arff.txt"));
    // String s = null;
    // long st_time = System.currentTimeMillis();
    // Instances inst_train = new Instances(br_train);
    // System.out.println(inst_train.numAttributes());
    // inst_train.setClassIndex(inst_train.numAttributes() - 1);
    // System.out.println("train time" + (System.currentTimeMillis() - st_time));

    // NaiveBayes nb1 = new NaiveBayes();
    // nb1.buildClassifier(inst_train);
    // br_train.close();

    long st_time = System.currentTimeMillis();
    st_time = System.currentTimeMillis();

    // load the previously trained and serialized NaiveBayes model
    Classifier classifier = (Classifier) SerializationHelper.read("NaiveBayes.model");

    // BufferedReader br_test = new BufferedReader(new FileReader("src/test.arff.txt"));
    // Instances inst_test = new Instances(br_test);
    // inst_test.setClassIndex(inst_test.numAttributes() - 1);
    // System.out.println("test time" + (System.currentTimeMillis() - st_time));

    ArffLoader testLoader = new ArffLoader();
    testLoader.setSource(new File("src/test.arff"));
    testLoader.setRetrieval(Loader.BATCH);
    Instances testDataSet = testLoader.getDataSet();

    Attribute testAttribute = testDataSet.attribute("class");
    testDataSet.setClass(testAttribute);

    int correct = 0;
    int incorrect = 0;

    FastVector attInfo = new FastVector();
    attInfo.addElement(new Attribute("Id"));
    attInfo.addElement(new Attribute("Category"));
    Instances outputInstances = new Instances("predict", attInfo, testDataSet.numInstances());

    Enumeration testInstances = testDataSet.enumerateInstances();
    int index = 1;
    while (testInstances.hasMoreElements()) {
        Instance instance = (Instance) testInstances.nextElement();
        double classification = classifier.classifyInstance(instance);

        // count correct/incorrect predictions; this only works when the test
        // file contains true class labels, otherwise the counters stay at zero
        if (!instance.classIsMissing()) {
            if (classification == instance.classValue()) {
                correct++;
            } else {
                incorrect++;
            }
        }

        Instance predictInstance = new Instance(outputInstances.numAttributes());
        predictInstance.setValue(0, index++);
        predictInstance.setValue(1, (int) classification + 1);
        outputInstances.add(predictInstance);
    }

    System.out.println("Correct Instances: " + correct);
    System.out.println("Incorrect Instances: " + incorrect);
    double accuracy = (double) correct / (double) (correct + incorrect);
    System.out.println("Accuracy: " + accuracy);

    CSVSaver predictedCsvSaver = new CSVSaver();
    predictedCsvSaver.setFile(new File("predict.csv"));
    predictedCsvSaver.setInstances(outputInstances);
    predictedCsvSaver.writeBatch();

    System.out.println("Prediction saved to predict.csv");
}
public double getLiblinear(String path, String train, String test) {
    // classification accuracy for this run
    double accuracy = 0.0;
    try {
        LibLINEAR c1 = new LibLINEAR();
        String[] options = weka.core.Utils.splitOptions("-S 1 -C 1.0 -E 0.001 -B 0");
        c1.setOptions(options);

        ArffLoader atf = new ArffLoader();
        File TraininputFile = new File(train);
        atf.setFile(TraininputFile); // training corpus file
        Instances instancesTrain = atf.getDataSet(); // read the training file
        instancesTrain.setClassIndex(instancesTrain.numAttributes() - 1);

        File TestinputFile = new File(test);
        atf.setFile(TestinputFile); // test corpus file
        Instances instancesTest = atf.getDataSet(); // read the test file
        // set the index of the class attribute (the first attribute has index 0);
        // instancesTest.numAttributes() returns the total number of attributes
        instancesTest.setClassIndex(instancesTest.numAttributes() - 1);

        c1.buildClassifier(instancesTrain); // train
        Evaluation eval = new Evaluation(instancesTrain);
        eval.evaluateModel(c1, instancesTest);
        // eval.crossValidateModel(c1, instancesTrain, 10, new Random(1));

        File newfile = new File(path + "OutLiblinear_temp" + ".txt");
        BufferedWriter bufferedWriter = new BufferedWriter(
            new OutputStreamWriter(new FileOutputStream(newfile), "utf-8"));
        bufferedWriter.write(eval.toSummaryString() + "\r\n");
        bufferedWriter.write(eval.toClassDetailsString() + "\r\n");
        bufferedWriter.write(eval.toMatrixString() + "\r\n");
        bufferedWriter.flush();
        bufferedWriter.close();

        BufferedReader bufferedReader = new BufferedReader(new FileReader(newfile));
        String[] splitLineString = new String[5];
        while (bufferedReader.ready()) {
            bufferedReader.readLine();
            String lineString = bufferedReader.readLine();
            splitLineString = lineString.split(" ");
            System.out.println(splitLineString[4]);
            break;
        }
        bufferedReader.close();

        // extract the classification accuracy from the written summary
        String tempLine;
        BufferedReader tempBF = new BufferedReader(new FileReader(newfile));
        while (tempBF.ready()) {
            tempLine = tempBF.readLine();
            if (tempLine.contains("Correctly Classified Instances")) {
                tempLine = tempLine.substring(tempLine.lastIndexOf(".") - 2, tempLine.lastIndexOf(" "));
                accuracy = Double.parseDouble(tempLine);
                break;
            }
        }
        tempBF.close();
    } catch (Exception e) {
        System.out.println("Can't run LibLINEAR of weka.");
    }
    return accuracy;
}
public static void main(String[] args) throws Exception {
    /*
     * First we load our predictions from the CSV formatted file.
     */
    CSVLoader predictCsvLoader = new CSVLoader();
    predictCsvLoader.setSource(new File("predict.csv"));

    /*
     * Since we are not using the ARFF format here, we have to give the
     * loader a little bit of information about the data types. Columns
     * 3,8,10 need to be of type string and columns 1,4,11 are nominal
     * types.
     */
    predictCsvLoader.setStringAttributes("3,8,10");
    predictCsvLoader.setNominalAttributes("1,4,11");
    Instances predictDataSet = predictCsvLoader.getDataSet();

    /*
     * Here we set the attribute we want to test the predictions with.
     */
    Attribute testAttribute = predictDataSet.attribute(0);
    predictDataSet.setClass(testAttribute);

    /*
     * We still have to remove all string attributes before we can test.
     */
    predictDataSet.deleteStringAttributes();

    /*
     * Next we load the training data from our ARFF file.
     */
    ArffLoader trainLoader = new ArffLoader();
    trainLoader.setSource(new File("train.arff"));
    trainLoader.setRetrieval(Loader.BATCH);
    Instances trainDataSet = trainLoader.getDataSet();

    /*
     * Now we tell the data set which attribute we want to classify, in our
     * case, we want to classify the first column: survived.
     */
    Attribute trainAttribute = trainDataSet.attribute(0);
    trainDataSet.setClass(trainAttribute);

    /*
     * The RandomForest implementation cannot handle columns of type string,
     * so we remove them for now.
     */
    trainDataSet.deleteStringAttributes();

    /*
     * Now we read in the serialized model from disk.
     */
    Classifier classifier = (Classifier) SerializationHelper.read("titanic.model");

    /*
     * Next we will use an Evaluation class to evaluate the performance of
     * our Classifier.
     */
    Evaluation evaluation = new Evaluation(trainDataSet);
    evaluation.evaluateModel(classifier, predictDataSet);

    /*
     * After we evaluate the Classifier, we write out the summary
     * information to the screen.
     */
    System.out.println(classifier);
    System.out.println(evaluation.toSummaryString());
}
public JSONArray Cluster(String wekaFilePath, int clusterNum) throws Exception {
    File inputFile = new File(wekaFilePath);
    ArffLoader arf = new ArffLoader();
    arf.setFile(inputFile);
    Instances originIns = arf.getDataSet();
    Instances insTest = new Instances(originIns);
    insTest.deleteStringAttributes();
    int totalNum = insTest.numInstances();

    // SimpleKMeans sm = new SimpleKMeans();
    EM em = new EM();
    em.setNumClusters(clusterNum);
    MakeDensityBasedClusterer sm = new MakeDensityBasedClusterer();
    sm.setClusterer(em);
    sm.buildClusterer(insTest);

    System.out.println("totalNum:" + insTest.numInstances());
    System.out.println("============================");
    System.out.println(sm.toString());

    Map<Integer, ArrayList<String>> result = new HashMap<Integer, ArrayList<String>>();
    for (int i = 0; i < clusterNum; i++) {
        result.put(i, new ArrayList<String>());
    }

    // assign each word (first attribute) to its predicted cluster
    for (int i = 0; i < totalNum; i++) {
        Instance ins = originIns.instance(i);
        String word = ins.stringValue(0);
        Instance tempIns = new Instance(ins);
        tempIns.deleteAttributeAt(0);
        int cluster = sm.clusterInstance(tempIns);
        result.get(cluster).add(word);
    }

    // print the result
    ArrayList<String> words = new ArrayList<String>();
    JSONArray keyWords = new JSONArray();
    for (int k : result.keySet()) {
        words = result.get(k);
        PriorityQueue<MyTerm> clusterQueue = new PriorityQueue<MyTerm>(1, MyTermCompare);
        for (int i = 0; i < words.size(); i++) {
            String s = words.get(i);
            assert linkMap.containsKey(s);
            int freq = linkMap.get(s).totalFreq;
            clusterQueue.add(linkMap.get(s));
            words.set(i, "(" + s + ":" + freq + ")");
        }

        JSONArray clusterArray = new JSONArray();
        int num = clusterQueue.size() / 10 + 1; // roughly the top 10% of the cluster
        int totalFreq = 0;
        int totalLength = 0;
        for (int i = 0; i < num && !clusterQueue.isEmpty(); ) {
            JSONObject mem = new JSONObject();
            MyTerm myTerm = clusterQueue.poll();
            String word = myTerm.originTrem.text();
            if (word.length() == 1) {
                continue;
            }
            mem.put("text", word);
            mem.put("freq", myTerm.totalFreq);
            clusterArray.put(mem);
            i++;
            totalFreq += myTerm.totalFreq;
            totalLength += word.length();
        }

        double averFreq = totalFreq * 1.0 / num;
        double averLength = totalLength * 1.0 / num;

        int count = 0;
        while (!clusterQueue.isEmpty() && count < num) {
            MyTerm myTerm = clusterQueue.poll();
            String word = myTerm.originTrem.text();
            int freq = myTerm.totalFreq;
            int times = (int) (word.length() / averFreq) + 1;
            if (freq > averFreq / times) {
                JSONObject mem = new JSONObject();
                mem.put("text", word);
                mem.put("freq", freq);
                mem.put("extra", true);
                clusterArray.put(mem);
            }
            count++; // advance so at most num additional terms are inspected
        }

        keyWords.put(clusterArray);
        System.out.println("cluster" + k + ":" + words.size() + ":\t"
            + (int) (words.size() * 1.0 / totalNum * 100));
        if (result.get(k).size() < 100) {
            System.out.println(result.get(k));
        }
    }
    // System.out.println("errorNum:" + errorNum);
    return keyWords;
}