public static void main(String[] args) throws Exception {
    /*
     * First we load the test data from our ARFF file
     */
    ArffLoader testLoader = new ArffLoader();
    testLoader.setSource(new File("data/titanic/test.arff"));
    testLoader.setRetrieval(Loader.BATCH);
    Instances testDataSet = testLoader.getDataSet();

    /*
     * Now we tell the data set which attribute we want to classify, in our
     * case, we want to classify the first column: survived
     */
    Attribute testAttribute = testDataSet.attribute(0);
    testDataSet.setClass(testAttribute);
    testDataSet.deleteStringAttributes();

    /*
     * Now we read in the serialized model from disk
     */
    Classifier classifier = (Classifier) SerializationHelper.read("data/titanic/titanic.model");

    /*
     * This part may be a little confusing. We load up the test data again
     * so we have a prediction data set to populate. As we iterate over the
     * first data set we also iterate over the second data set. After an
     * instance is classified, we set the value of the prediction data set
     * to be the value of the classification.
     */
    ArffLoader test1Loader = new ArffLoader();
    test1Loader.setSource(new File("data/titanic/test.arff"));
    Instances test1DataSet = test1Loader.getDataSet();
    Attribute test1Attribute = test1DataSet.attribute(0);
    test1DataSet.setClass(test1Attribute);

    /*
     * Now we iterate over the test data, classify each entry, and set the
     * value of the 'survived' column to the result of the classification.
     */
    Enumeration testInstances = testDataSet.enumerateInstances();
    Enumeration test1Instances = test1DataSet.enumerateInstances();
    while (testInstances.hasMoreElements()) {
        Instance instance = (Instance) testInstances.nextElement();
        Instance instance1 = (Instance) test1Instances.nextElement();
        double classification = classifier.classifyInstance(instance);
        instance1.setClassValue(classification);
    }

    /*
     * Now we want to write out our predictions. The resulting file is in a
     * format suitable to submit to Kaggle.
     */
    CSVSaver predictedCsvSaver = new CSVSaver();
    predictedCsvSaver.setFile(new File("data/titanic/predict.csv"));
    predictedCsvSaver.setInstances(test1DataSet);
    predictedCsvSaver.writeBatch();

    System.out.println("Prediction saved to predict.csv");
}
public void batchPredict() {
    // load the whole test set
    String modelFile = "data\\AcquireValueShopper\\decisionTable_bayes_trees.model".replace("\\", File.separator);
    String pathTest = "data/AcquireValueShopper/test_new.csv";
    String pathPredict = "data/AcquireValueShopper/submission.csv";

    Scanner scanner;
    String line = "";
    String[] partsOfLine = null;
    String id = "";
    PrintWriter output;
    Map<String, String> testSet = new HashMap<String, String>();
    try {
        scanner = new Scanner(new File(pathTest));
        while (scanner.hasNext()) {
            line = scanner.nextLine().trim();
            partsOfLine = line.split(",");
            id = partsOfLine[0];
            testSet.put(id, line);
        }
        scanner.close();
    } catch (FileNotFoundException e1) {
        e1.printStackTrace();
    }

    double[] returnProb;
    double prob = 0.0;
    // predict
    try {
        // load model
        Classifier classifier = (Classifier) SerializationHelper.read(modelFile);
        output = new PrintWriter(pathPredict);
        output.append("id,repeatProbability" + "\n");
        Iterator<String> idIterator = testSet.keySet().iterator();
        while (idIterator.hasNext()) {
            id = idIterator.next();
            line = testSet.get(id);
            Instances instances = buildInstance(line);
            Instance instance = instances.instance(0);
            returnProb = classifier.distributionForInstance(instance);
            prob = returnProb[1];
            // prob = classifier.classifyInstance(instance);
            output.append(id + "," + prob + "\n");
        }
        output.close();
    } catch (FileNotFoundException e) {
        e.printStackTrace();
    } catch (Exception e) {
        e.printStackTrace();
    }
}
public static void main(String[] args) throws Exception {
    // NaiveBayesSimple nb = new NaiveBayesSimple();

    // BufferedReader br_train = new BufferedReader(new FileReader("src/train.arff.txt"));
    // String s = null;
    // long st_time = System.currentTimeMillis();
    // Instances inst_train = new Instances(br_train);
    // System.out.println(inst_train.numAttributes());
    // inst_train.setClassIndex(inst_train.numAttributes()-1);
    // System.out.println("train time" + (System.currentTimeMillis() - st_time));
    // NaiveBayes nb1 = new NaiveBayes();
    // nb1.buildClassifier(inst_train);
    // br_train.close();

    long st_time = System.currentTimeMillis();
    st_time = System.currentTimeMillis();

    Classifier classifier = (Classifier) SerializationHelper.read("NaiveBayes.model");

    // BufferedReader br_test = new BufferedReader(new FileReader("src/test.arff.txt"));
    // Instances inst_test = new Instances(br_test);
    // inst_test.setClassIndex(inst_test.numAttributes()-1);
    // System.out.println("test time" + (System.currentTimeMillis() - st_time));

    ArffLoader testLoader = new ArffLoader();
    testLoader.setSource(new File("src/test.arff"));
    testLoader.setRetrieval(Loader.BATCH);
    Instances testDataSet = testLoader.getDataSet();

    Attribute testAttribute = testDataSet.attribute("class");
    testDataSet.setClass(testAttribute);

    int correct = 0;
    int incorrect = 0;

    FastVector attInfo = new FastVector();
    attInfo.addElement(new Attribute("Id"));
    attInfo.addElement(new Attribute("Category"));
    Instances outputInstances = new Instances("predict", attInfo, testDataSet.numInstances());

    Enumeration testInstances = testDataSet.enumerateInstances();
    int index = 1;
    while (testInstances.hasMoreElements()) {
        Instance instance = (Instance) testInstances.nextElement();
        double classification = classifier.classifyInstance(instance);

        // Track accuracy whenever the test instance actually carries a class label;
        // without this the counters below would never be updated.
        if (!instance.classIsMissing()) {
            if (classification == instance.classValue()) {
                correct++;
            } else {
                incorrect++;
            }
        }

        Instance predictInstance = new Instance(outputInstances.numAttributes());
        predictInstance.setValue(0, index++);
        predictInstance.setValue(1, (int) classification + 1);
        outputInstances.add(predictInstance);
    }

    System.out.println("Correct Instances: " + correct);
    System.out.println("Incorrect Instances: " + incorrect);
    double accuracy = (double) (correct) / (double) (correct + incorrect);
    System.out.println("Accuracy: " + accuracy);

    CSVSaver predictedCsvSaver = new CSVSaver();
    predictedCsvSaver.setFile(new File("predict.csv"));
    predictedCsvSaver.setInstances(outputInstances);
    predictedCsvSaver.writeBatch();

    System.out.println("Prediction saved to predict.csv");
}
// Takes a question as input and returns the predicted question type.
public double classifyByBayes(String question) throws Exception {
    double label = -1;
    List<Question> questionID = questionDAO.getQuestionIDLabeled();

    // Define the data format
    Attribute att1 = new Attribute("法律政策");
    Attribute att2 = new Attribute("位置交通");
    Attribute att3 = new Attribute("风水");
    Attribute att4 = new Attribute("房价");
    Attribute att5 = new Attribute("楼层");
    Attribute att6 = new Attribute("户型");
    Attribute att7 = new Attribute("小区配套");
    Attribute att8 = new Attribute("贷款");
    Attribute att9 = new Attribute("买房时机");
    Attribute att10 = new Attribute("开发商");
    FastVector labels = new FastVector();
    labels.addElement("1");
    labels.addElement("2");
    labels.addElement("3");
    labels.addElement("4");
    labels.addElement("5");
    labels.addElement("6");
    labels.addElement("7");
    labels.addElement("8");
    labels.addElement("9");
    labels.addElement("10");
    Attribute att11 = new Attribute("类别", labels);

    FastVector attributes = new FastVector();
    attributes.addElement(att1);
    attributes.addElement(att2);
    attributes.addElement(att3);
    attributes.addElement(att4);
    attributes.addElement(att5);
    attributes.addElement(att6);
    attributes.addElement(att7);
    attributes.addElement(att8);
    attributes.addElement(att9);
    attributes.addElement(att10);
    attributes.addElement(att11);
    Instances dataset = new Instances("Test-dataset", attributes, 0);
    dataset.setClassIndex(10);

    Classifier classifier = null;
    if (!new File("naivebayes.model").exists()) {
        // Add the training data
        double[] values = new double[11];
        for (int i = 0; i < questionID.size(); i++) {
            for (int m = 0; m < 11; m++) {
                values[m] = 0;
            }
            int whitewordcount = 0;
            whitewordcount = questionDAO.getHitWhiteWordNum(questionID.get(i).getId());
            if (whitewordcount != 0) {
                List<QuestionWhiteWord> questionwhiteword = questionDAO.getHitQuestionWhiteWord(questionID.get(i).getId());
                for (int j = 0; j < questionwhiteword.size(); j++) {
                    values[getAttIndex(questionwhiteword.get(j).getWordId()) - 1]++;
                }
                for (int m = 0; m < 11; m++) {
                    values[m] = values[m] / whitewordcount;
                }
            }
            values[10] = questionID.get(i).getType() - 1;
            Instance inst = new Instance(1.0, values);
            dataset.add(inst);
        }

        // Build the classifier
        classifier = new NaiveBayes();
        classifier.buildClassifier(dataset);
        SerializationHelper.write("naivebayes.model", classifier);
    } else {
        classifier = (Classifier) SerializationHelper.read("naivebayes.model");
    }

    System.out.println("*************begin evaluation*******************");
    Evaluation evaluation = new Evaluation(dataset);
    evaluation.evaluateModel(classifier, dataset); // Strictly speaking, a separate data set should be used here rather than the training set.
    System.out.println(evaluation.toSummaryString());

    // Classify
    System.out.println("*************begin classification*******************");
    Instance subject = new Instance(1.0, getQuestionVector(question));
    subject.setDataset(dataset);
    label = classifier.classifyInstance(subject);
    System.out.println("label: " + label);

    // double dis[] = classifier.distributionForInstance(inst);
    // for (double i : dis) {
    //     System.out.print(i + " ");
    // }

    System.out.println(questionID.size());
    return label + 1;
}
public void loadModel(String filename) throws Exception {
    classifier = (Classifier) SerializationHelper.read(filename);
}
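For completeness, here is a minimal companion sketch of the save side of the same workflow; the method name saveModel is an assumption, but it reuses the classifier field from the example above and the standard weka.core.SerializationHelper.write call for persisting a trained model:

public void saveModel(String filename) throws Exception {
    // Assumed helper: serializes the trained classifier field to disk so that
    // loadModel(filename) above can restore it later via SerializationHelper.read.
    SerializationHelper.write(filename, classifier);
}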
@Override
public Void doInBackground() {
    BufferedReader reader;
    publish("Computing features...");
    int testingSamples = p.getAllFeatures2(path, "testing_data");
    try {
        publish("Reading data...");
        reader = new BufferedReader(new FileReader("testing_data.arff"));
        final Instances testingdata = new Instances(reader);
        reader.close();
        // setting class attribute
        testingdata.setClassIndex(13);
        testingdata.randomize(new Random(1));

        long startTime = System.nanoTime();

        Classifier ann = (Classifier) weka.core.SerializationHelper.read("mlp.model");
        publish("Evaluating ANN...");
        evalANN = new Evaluation(testingdata);
        startTime = System.nanoTime();
        evalANN.evaluateModel(ann, testingdata);
        long runningTimeANN = (System.nanoTime() - startTime) / 1000000;
        // runningTimeANN /= 100;
        publish("Done evaluating ANN");

        publish("Evaluating SVM...");
        Classifier svm = (Classifier) weka.core.SerializationHelper.read("svm.model");
        evalSVM = new Evaluation(testingdata);
        startTime = System.nanoTime();
        evalSVM.evaluateModel(svm, testingdata);
        long runningTimeSVM = (System.nanoTime() - startTime) / 1000000;
        // runningTimeSVM /= 100;
        publish("Done evaluating SVM");

        publish("Evaluating NB...");
        Classifier nb = (Classifier) weka.core.SerializationHelper.read("naivebayes.model");
        evalNB = new Evaluation(testingdata);
        startTime = System.nanoTime();
        evalNB.evaluateModel(nb, testingdata);
        long runningTimeNB = (System.nanoTime() - startTime) / 1000000;
        // runningTimeNB /= 100;
        publish("Done evaluating NB");

        Platform.runLater(new Runnable() {
            @Override
            public void run() {
                bc.getData().get(0).getData().get(0).setYValue(evalANN.correct() / testingdata.size() * 100);
                bc.getData().get(0).getData().get(1).setYValue(evalSVM.correct() / testingdata.size() * 100);
                bc.getData().get(0).getData().get(2).setYValue(evalNB.correct() / testingdata.size() * 100);
                for (int i = 0; i < NUM_CLASSES; i++) {
                    lineChart.getData().get(0).getData().get(i).setYValue(evalANN.recall(i) * 100);
                    lineChart.getData().get(1).getData().get(i).setYValue(evalSVM.recall(i) * 100);
                    lineChart.getData().get(2).getData().get(i).setYValue(evalNB.recall(i) * 100);
                }
            }
        });

        panel.fillConfTable(evalSVM.confusionMatrix());
        summaryTable.setValueAt(evalANN.correct() / testingdata.size() * 100., 0, 1);
        summaryTable.setValueAt(evalSVM.correct() / testingdata.size() * 100, 0, 2);
        summaryTable.setValueAt(evalNB.correct() / testingdata.size() * 100, 0, 3);
        summaryTable.setValueAt(runningTimeANN, 1, 1);
        summaryTable.setValueAt(runningTimeSVM, 1, 2);
        summaryTable.setValueAt(runningTimeNB, 1, 3);
    } catch (Exception e1) {
        e1.printStackTrace();
    }
    return null;
}
public static void main(String[] args) throws Exception {
    /*
     * First we load our predictions from the CSV formatted file.
     */
    CSVLoader predictCsvLoader = new CSVLoader();
    predictCsvLoader.setSource(new File("predict.csv"));

    /*
     * Since we are not using the ARFF format here, we have to give the
     * loader a little bit of information about the data types. Columns
     * 3,8,10 need to be of type string and columns 1,4,11 are nominal
     * types.
     */
    predictCsvLoader.setStringAttributes("3,8,10");
    predictCsvLoader.setNominalAttributes("1,4,11");
    Instances predictDataSet = predictCsvLoader.getDataSet();

    /*
     * Here we set the attribute we want to test the predictions with
     */
    Attribute testAttribute = predictDataSet.attribute(0);
    predictDataSet.setClass(testAttribute);

    /*
     * We still have to remove all string attributes before we can test
     */
    predictDataSet.deleteStringAttributes();

    /*
     * Next we load the training data from our ARFF file
     */
    ArffLoader trainLoader = new ArffLoader();
    trainLoader.setSource(new File("train.arff"));
    trainLoader.setRetrieval(Loader.BATCH);
    Instances trainDataSet = trainLoader.getDataSet();

    /*
     * Now we tell the data set which attribute we want to classify, in our
     * case, we want to classify the first column: survived
     */
    Attribute trainAttribute = trainDataSet.attribute(0);
    trainDataSet.setClass(trainAttribute);

    /*
     * The RandomForest implementation cannot handle columns of type string,
     * so we remove them for now.
     */
    trainDataSet.deleteStringAttributes();

    /*
     * Now we read in the serialized model from disk
     */
    Classifier classifier = (Classifier) SerializationHelper.read("titanic.model");

    /*
     * Next we will use an Evaluation class to evaluate the performance of
     * our Classifier.
     */
    Evaluation evaluation = new Evaluation(trainDataSet);
    evaluation.evaluateModel(classifier, predictDataSet, new Object[] {});

    /*
     * After we evaluate the Classifier, we write out the summary
     * information to the screen.
     */
    System.out.println(classifier);
    System.out.println(evaluation.toSummaryString());
}
public HashMap<String, String> process(
        Sentence sent, String dep, HashSet<String> terms, List<NamedEntity> entities, String author, String aidx) {
    try {
        // System.out.println("ML start!");
        // System.out.println("List : " + terms);
        HashMap<String, String> ht = new HashMap<String, String>();

        List<NamedEntity> newEntities = new ArrayList<NamedEntity>();
        for (NamedEntity entity : entities) {
            // System.out.println("original: " + entity.entity);
            boolean check = false;
            for (NamedEntity temp : entities) {
                if (entity == temp) continue;
                if (entity.entity.contains(temp.entity)) {
                    check = true;
                }
            }
            if (!check) newEntities.add(entity);
        }

        List<DependencyTriple> dtl = getDependencyTripleList(dep);
        List<NamedEntity> targetCands = new ArrayList<NamedEntity>();
        HashMap<NamedEntity, String> tOpinTerm = new HashMap<NamedEntity, String>();
        List<NamedEntity> holderCands = new ArrayList<NamedEntity>();
        HashMap<NamedEntity, String> hOpinTerm = new HashMap<NamedEntity, String>();

        BufferedWriter writer = new BufferedWriter(new FileWriter("weka_target.csv"));
        writer.write("A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,Class\n");
        boolean check = false;
        List<NamedEntity> targetTmp = new ArrayList<NamedEntity>();
        for (NamedEntity entity : newEntities) {
            // System.out.println("extracted: " + entity.entity);
            String temp = getTargetFeatures(entity, author, terms, dtl);
            // System.out.println(temp);
            if (temp.length() > 1) {
                check = true;
                writer.write(temp);
                String[] toks = temp.split("\n");
                for (int i = 0; i < toks.length; i++) {
                    targetTmp.add(entity);
                    tOpinTerm.put(entity, toks[i].substring(0, toks[i].indexOf(",")));
                }
            }
        }
        writer.close();

        if (check) {
            DataSource source = new DataSource("weka_target.csv");
            Instances testdata = source.getDataSet();
            testdata.setClassIndex(testdata.numAttributes() - 1);
            Classifier models = (Classifier) weka.core.SerializationHelper.read("target_smoreg.model");
            if (testdata.numInstances() != targetTmp.size())
                System.out.println("wrong number of instances");
            for (int i = 0; i < testdata.numInstances(); i++) {
                double pred = models.classifyInstance(testdata.instance(i));
                if (pred >= 1.0) {
                    // System.out.println(pred + " , " + targetTmp.get(i).entity);
                    targetCands.add(targetTmp.get(i));
                }
            }
        }

        writer = new BufferedWriter(new FileWriter("weka_holder.csv"));
        writer.write("A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,A11,A12,A13,A14,A15,Class\n");
        check = false;
        List<NamedEntity> holderTmp = new ArrayList<NamedEntity>();
        for (NamedEntity entity : newEntities) {
            // System.out.println("extracted: " + entity.entity);
            String temp = getHolderFeatures(entity, author, terms, dtl);
            // System.out.println(temp);
            if (temp.length() > 1) {
                check = true;
                writer.write(temp);
                String[] toks = temp.split("\n");
                for (int i = 0; i < toks.length; i++) {
                    holderTmp.add(entity);
                    hOpinTerm.put(entity, toks[i].substring(0, toks[i].indexOf(",")));
                }
            }
        }
        writer.close();

        if (check) {
            DataSource source = new DataSource("weka_holder.csv");
            Instances testdata = source.getDataSet();
            testdata.setClassIndex(testdata.numAttributes() - 1);
            Classifier models = (Classifier) weka.core.SerializationHelper.read("holder_smoreg.model");
            if (testdata.numInstances() != holderTmp.size())
                System.out.println("wrong number of instances");
            for (int i = 0; i < testdata.numInstances(); i++) {
                double pred = models.classifyInstance(testdata.instance(i));
                if (pred >= 1.0) {
                    // System.out.println(pred + " , " + holderTmp.get(i).entity);
                    holderCands.add(holderTmp.get(i));
                }
            }
        }

        if ((targetCands.size() == 0) || (holderCands.size() == 0)) return ht;

        List<NamedEntity> holderCandTmp = new ArrayList<NamedEntity>();
        for (NamedEntity holderCand : holderCands) {
            boolean hasLonger = false;
            for (NamedEntity temp : holderCands) {
                if (temp.entity.compareTo(holderCand.entity) == 0) continue;
                if (temp.entity.contains(holderCand.entity)) {
                    hasLonger = true;
                    break;
                }
            }
            if (!hasLonger) holderCandTmp.add(holderCand);
        }

        List<NamedEntity> targetCandTmp = new ArrayList<NamedEntity>();
        for (NamedEntity targetCand : targetCands) {
            boolean hasLonger = false;
            for (NamedEntity temp : targetCands) {
                if (temp.entity.compareTo(targetCand.entity) == 0) continue;
                if (temp.entity.contains(targetCand.entity)) {
                    hasLonger = true;
                    break;
                }
            }
            if (!hasLonger) targetCandTmp.add(targetCand);
        }

        for (NamedEntity targetCand : targetCandTmp) {
            if (targetCand.entity.compareTo(author) == 0) continue;
            for (NamedEntity holderCand : holderCandTmp) {
                if (targetCand.entity.compareTo(holderCand.entity) == 0) continue;
                String targetOpin = tOpinTerm.get(targetCand);
                String holderOpin = hOpinTerm.get(holderCand);
                // System.out.println(targetOpin + ", " + holderOpin);
                if (targetOpin.compareTo(holderOpin) != 0) continue;
                String opin = targetOpin
                        .concat("\t")
                        .concat(Integer.toString(sent.sent.indexOf(targetOpin) + sent.beg)
                                .concat("-")
                                .concat(Integer.toString(sent.sent.indexOf(targetOpin) + sent.beg + targetOpin.length())));
                String holder = holderCand.entity
                        .concat("\t")
                        .concat(Integer.toString(holderCand.beg)
                                .concat("-")
                                .concat(Integer.toString(holderCand.end)));
                String target = targetCand.entity
                        .concat("\t")
                        .concat(Integer.toString(targetCand.beg)
                                .concat("-")
                                .concat(Integer.toString(targetCand.end)));
                ht.put(targetOpin, opin.concat("\t").concat(holder).concat("\t").concat(target));
            }
        }
        return ht;
    } catch (IOException e) {
        e.printStackTrace();
        return null;
    } catch (Exception e) {
        e.printStackTrace();
        return null;
    }
}