/**
 * Parses a given list of options.
 *
 * <p>
 * <!-- options-start -->
 * Valid options are:
 *
 * <p>
 *
 * <pre> -folds <folds>
 *  The number of folds for splitting the training set into
 *  train and test set. The first fold is always the training
 *  set. With '-V' you can invert this, i.e., instead of 20/80
 *  for 5 folds you'll get 80/20.
 *  (default 5)</pre>
 *
 * <pre> -V
 *  Inverts the fold selection, i.e., instead of using the first
 *  fold for the training set it is used for test set and the
 *  remaining folds for training.</pre>
 *
 * <pre> -verbose
 *  Whether to print some more information during building the
 *  classifier.
 *  (default is off)</pre>
 *
 * <pre> -insight
 *  Whether to use the labels of the original test set for more
 *  statistics (not used for learning!).
 *  (default is off)</pre>
 *
 * <pre> -S <num>
 *  Random number seed.
 *  (default 1)</pre>
 *
 * <pre> -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console</pre>
 *
 * <pre> -naive
 *  Uses a sorted list (ordered according to distance) instead of the
 *  KDTree for finding the neighbors.
 *  (default is KDTree)</pre>
 *
 * <pre> -I
 *  Weight neighbours by the inverse of their distance
 *  (use when k > 1)</pre>
 *
 * <pre> -F
 *  Weight neighbours by 1 - their distance
 *  (use when k > 1)</pre>
 *
 * <pre> -K <number of neighbors>
 *  Number of nearest neighbours (k) used in classification.
 *  (Default = 1)</pre>
 *
 * <pre> -A
 *  The nearest neighbour search algorithm to use (default: LinearNN).
 * </pre>
 *
 * <!-- options-end -->
 *
 * @param options the list of options as an array of strings
 * @throws Exception if an option is not supported
 */
public void setOptions(String[] options) throws Exception {
  super.setOptions(options);

  setUseNaiveSearch(Utils.getFlag("naive", options));

  m_Classifier.setOptions(options);
  m_KNN = m_Classifier.getKNN();        // backup KNN
  m_Classifier.setCrossValidate(true);  // always on!
  m_Classifier.setWindowSize(0);        // always off!
  m_Classifier.setMeanSquared(false);   // always off!
}
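// Usage sketch for the options documented above. The method name is ours (the
// enclosing class is not named in this excerpt); the flags follow the Javadoc,
// nothing else is assumed.
public void exampleSetOptions() throws Exception {
  setOptions(new String[] {
      "-folds", "10",  // split into 10 folds instead of the default 5
      "-K", "3",       // wrapped IBk uses 3 nearest neighbours
      "-naive"         // sorted list instead of the KDTree
  });
}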
/**
 * Returns an enumeration describing the available options.
 *
 * @return an enumeration of all the available options.
 */
public Enumeration listOptions() {
  Vector result;
  Enumeration en;

  result = new Vector();

  // ancestor
  en = super.listOptions();
  while (en.hasMoreElements())
    result.addElement(en.nextElement());

  result.addElement(
      new Option(
          "\tUses a sorted list (ordered according to distance) instead of the\n"
              + "\tKDTree for finding the neighbors.\n"
              + "\t(default is KDTree)",
          "naive", 0, "-naive"));

  // IBk
  en = m_Classifier.listOptions();
  while (en.hasMoreElements()) {
    Option o = (Option) en.nextElement();
    // remove -X, -W and -E
    if (!o.name().equals("X") && !o.name().equals("W") && !o.name().equals("E"))
      result.addElement(o);
  }

  return result.elements();
}
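// Sketch of consuming the enumeration, e.g. to print a usage screen. It relies
// only on the standard weka.core.Option accessors; the helper name is ours.
public void printUsage() {
  Enumeration en = listOptions();
  while (en.hasMoreElements()) {
    Option o = (Option) en.nextElement();
    System.out.println(o.synopsis());    // e.g. "-naive"
    System.out.println(o.description()); // indented, human-readable text
  }
}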
/**
 * Determines the "K" for the neighbors from the training set, initializes the labels of the test
 * set to "missing" and generates the neighbors for all instances in the test set.
 *
 * @throws Exception if initialization fails
 */
protected void initialize() throws Exception {
  int i;
  double timeStart;
  double timeEnd;
  Instances trainingNew;
  Instances testNew;

  // determine K
  if (getVerbose())
    System.out.println("\nOriginal KNN = " + m_KNN);
  ((IBk) m_Classifier).setKNN(m_KNN);
  ((IBk) m_Classifier).setCrossValidate(true);
  m_Classifier.buildClassifier(m_TrainsetNew);
  m_Classifier.toString(); // necessary to crossvalidate IBk!
  ((IBk) m_Classifier).setCrossValidate(false);
  m_KNNdetermined = ((IBk) m_Classifier).getKNN();
  if (getVerbose())
    System.out.println("Determined KNN = " + m_KNNdetermined);

  // set class labels in test set to "missing"
  for (i = 0; i < m_TestsetNew.numInstances(); i++)
    m_TestsetNew.instance(i).setClassMissing();

  // copy data
  trainingNew = new Instances(m_TrainsetNew);
  testNew = new Instances(m_TestsetNew);

  // filter data
  m_Missing.setInputFormat(trainingNew);
  trainingNew = Filter.useFilter(trainingNew, m_Missing);
  testNew = Filter.useFilter(testNew, m_Missing);

  // create the list of neighbors for the instances in the test set
  m_NeighborsTestset = new Neighbors[m_TestsetNew.numInstances()];
  timeStart = System.currentTimeMillis();
  for (i = 0; i < testNew.numInstances(); i++) {
    m_NeighborsTestset[i] =
        new Neighbors(
            testNew.instance(i), m_TestsetNew.instance(i), m_KNNdetermined, trainingNew, testNew);
    m_NeighborsTestset[i].setVerbose(getVerbose() || getDebug());
    m_NeighborsTestset[i].setUseNaiveSearch(getUseNaiveSearch());
    m_NeighborsTestset[i].find();
  }
  timeEnd = System.currentTimeMillis();

  if (getVerbose())
    System.out.println(
        "Time for finding neighbors: " + Utils.doubleToString((timeEnd - timeStart) / 1000.0, 3));
}
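// The K-determination trick above deserves a note: IBk runs its leave-one-out
// cross-validation lazily, which is why the otherwise pointless toString() call
// is needed. A standalone sketch of the same idea; the method name is ours and
// 'train' is any class-labelled Instances:
protected int determineBestK(Instances train) throws Exception {
  IBk ibk = new IBk();
  ibk.setKNN(10);             // upper bound for the K search
  ibk.setCrossValidate(true);
  ibk.buildClassifier(train);
  ibk.toString();             // triggers the lazy cross-validation
  return ibk.getKNN();        // now holds the cross-validated K
}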
/**
 * Returns the best model as string representation. Derived classes have to add additional
 * information here, like printing the classifier etc.
 *
 * @return the string representation of the best model
 */
protected String toStringModel() {
  StringBuffer text;

  text = new StringBuffer();
  text.append(super.toStringModel());
  text.append("\n");
  text.append(m_Classifier.toString());

  return text.toString();
}
/** Performs initialization of members. */
protected void initializeMembers() {
  super.initializeMembers();

  m_KNNdetermined = -1;
  m_NeighborsTestset = null;
  m_TrainsetNew = null;
  m_TestsetNew = null;
  m_UseNaiveSearch = false;
  m_LabeledTestset = null;
  m_Missing = new ReplaceMissingValues();

  m_Classifier = new IBk();
  m_Classifier.setKNN(10);
  m_Classifier.setCrossValidate(true);
  m_Classifier.setWindowSize(0);
  m_Classifier.setMeanSquared(false);

  m_KNN = m_Classifier.getKNN();

  m_AdditionalMeasures.add("measureDeterminedKNN");
}
/**
 * Calculates the class membership probabilities for the given test instance.
 *
 * @param instance the instance to be classified
 * @return predicted class probability distribution
 * @throws Exception if distribution can't be computed
 */
public double[] distributionForInstance(Instance instance) throws Exception {
  DecisionTableHashKey thekey;
  double[] tempDist;
  double[] normDist;

  m_disTransform.input(instance);
  m_disTransform.batchFinished();
  instance = m_disTransform.output();

  m_delTransform.input(instance);
  m_delTransform.batchFinished();
  instance = m_delTransform.output();

  thekey = new DecisionTableHashKey(instance, instance.numAttributes(), false);

  // if this one is not in the table
  if ((tempDist = (double[]) m_entries.get(thekey)) == null) {
    if (m_useIBk) {
      tempDist = m_ibk.distributionForInstance(instance);
    } else {
      if (!m_classIsNominal) {
        tempDist = new double[1];
        tempDist[0] = m_majority;
      } else {
        tempDist = m_classPriors.clone();
        /*tempDist = new double [m_theInstances.classAttribute().numValues()];
        tempDist[(int)m_majority] = 1.0; */
      }
    }
  } else {
    if (!m_classIsNominal) {
      normDist = new double[1];
      normDist[0] = (tempDist[0] / tempDist[1]);
      tempDist = normDist;
    } else {
      // normalise distribution
      normDist = new double[tempDist.length];
      System.arraycopy(tempDist, 0, normDist, 0, tempDist.length);
      Utils.normalize(normDist);
      tempDist = normDist;
    }
  }
  return tempDist;
}
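// Sketch: turn the distribution above into a crisp label. 'table' is a
// hypothetical trained instance of the enclosing classifier; Utils.maxIndex is
// the standard Weka argmax helper.
static String predictLabel(Classifier table, Instance testInst) throws Exception {
  double[] dist = table.distributionForInstance(testInst);
  int idx = Utils.maxIndex(dist);               // most probable class index
  return testInst.classAttribute().value(idx);  // its nominal label
}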
/**
 * Gets the current settings of the classifier.
 *
 * @return an array of strings suitable for passing to setOptions
 */
public String[] getOptions() {
  Vector result;
  String[] options;
  int i;

  result = new Vector();

  options = super.getOptions();
  for (i = 0; i < options.length; i++)
    result.add(options[i]);

  options = m_Classifier.getOptions();
  for (i = 0; i < options.length; i++)
    result.add(options[i]);

  if (getUseNaiveSearch())
    result.add("-naive");

  return (String[]) result.toArray(new String[result.size()]);
}
/**
 * Sets the nearestNeighbourSearch algorithm to be used for finding nearest neighbour(s).
 *
 * @param value The NearestNeighbourSearch class.
 */
public void setNearestNeighbourSearchAlgorithm(NearestNeighbourSearch value) {
  m_Classifier.setNearestNeighbourSearchAlgorithm(value);
}
/**
 * Returns the current nearestNeighbourSearch algorithm in use.
 *
 * @return the NearestNeighbourSearch algorithm currently in use.
 */
public NearestNeighbourSearch getNearestNeighbourSearchAlgorithm() {
  return m_Classifier.getNearestNeighbourSearchAlgorithm();
}
/**
 * Returns the tip text for this property.
 *
 * @return tip text for this property suitable for displaying in the explorer/experimenter gui
 */
public String nearestNeighbourSearchAlgorithmTipText() {
  return m_Classifier.nearestNeighbourSearchAlgorithmTipText();
}
/**
 * Returns the tip text for this property.
 *
 * @return tip text for this property suitable for displaying in the explorer/experimenter gui
 */
public String distanceWeightingTipText() {
  return m_Classifier.distanceWeightingTipText();
}
/**
 * Gets the number of neighbours the learner will use.
 *
 * @return the number of neighbours.
 */
public int getKNN() {
  return m_Classifier.getKNN();
}
/**
 * Sets the number of neighbours the learner is to use.
 *
 * @param k the number of neighbours.
 */
public void setKNN(int k) {
  m_Classifier.setKNN(k);
}
/**
 * Returns the tip text for this property.
 *
 * @return tip text for this property suitable for displaying in the explorer/experimenter gui
 */
public String KNNTipText() {
  return m_Classifier.KNNTipText();
}
// Takes a question as input and returns the label of the type the question belongs to.
public double classifyByKnn(String question) throws Exception {
  double label = -1;
  List<Question> questionID = questionDAO.getQuestionIDLabeled();

  // Define the data format (attribute names kept in the original Chinese; English glosses in comments).
  Attribute att1 = new Attribute("法律政策");   // legal policy
  Attribute att2 = new Attribute("位置交通");   // location & transport
  Attribute att3 = new Attribute("风水");       // feng shui
  Attribute att4 = new Attribute("房价");       // housing price
  Attribute att5 = new Attribute("楼层");       // floor
  Attribute att6 = new Attribute("户型");       // floor plan
  Attribute att7 = new Attribute("小区配套");   // neighborhood amenities
  Attribute att8 = new Attribute("贷款");       // mortgage
  Attribute att9 = new Attribute("买房时机");   // timing of purchase
  Attribute att10 = new Attribute("开发商");    // developer
  FastVector labels = new FastVector();
  labels.addElement("1");
  labels.addElement("2");
  labels.addElement("3");
  labels.addElement("4");
  labels.addElement("5");
  labels.addElement("6");
  labels.addElement("7");
  labels.addElement("8");
  labels.addElement("9");
  labels.addElement("10");
  Attribute att11 = new Attribute("类别", labels); // category (the class attribute)

  FastVector attributes = new FastVector();
  attributes.addElement(att1);
  attributes.addElement(att2);
  attributes.addElement(att3);
  attributes.addElement(att4);
  attributes.addElement(att5);
  attributes.addElement(att6);
  attributes.addElement(att7);
  attributes.addElement(att8);
  attributes.addElement(att9);
  attributes.addElement(att10);
  attributes.addElement(att11);

  Instances dataset = new Instances("Test-dataset", attributes, 0);
  dataset.setClassIndex(10);

  Classifier classifier = null;
  if (!new File("knn.model").exists()) {
    // Add the data
    for (int i = 0; i < questionID.size(); i++) {
      double[] values = new double[11];
      for (int m = 0; m < 11; m++) {
        values[m] = 0;
      }
      int whitewordcount = 0;
      whitewordcount = questionDAO.getHitWhiteWordNum(questionID.get(i).getId());
      if (whitewordcount != 0) {
        List<QuestionWhiteWord> questionwhiteword =
            questionDAO.getHitQuestionWhiteWord(questionID.get(i).getId());
        for (int j = 0; j < questionwhiteword.size(); j++) {
          values[getAttIndex(questionwhiteword.get(j).getWordId()) - 1]++;
        }
        for (int m = 0; m < 11; m++) {
          values[m] = values[m] / whitewordcount;
          System.out.println(m + "<>" + values[m]);
        }
      }
      System.out.println("Question #" + i);
      System.out.println(questionID.get(i).getQuestion());
      values[10] = questionID.get(i).getType() - 1;
      Instance inst = new Instance(1.0, values);
      dataset.add(inst);
    }
    // Build the classifier
    IBk ibk = new IBk();
    ibk.setKNN(3);
    classifier = ibk;
    classifier.buildClassifier(dataset);
    SerializationHelper.write("knn.model", classifier);
  } else {
    classifier = (Classifier) SerializationHelper.read("knn.model");
    System.out.println("Model restored via deserialization.");
  }

  System.out.println("*************begin evaluation*******************");
  Evaluation evaluation = new Evaluation(dataset);
  // Strictly speaking, a separate held-out set should be used here rather than the training data.
  evaluation.evaluateModel(classifier, dataset);
  System.out.println(evaluation.toSummaryString());

  // Classify
  System.out.println("*************begin classification*******************");
  Instance subject = new Instance(1.0, getQuestionVector(question));
  subject.setDataset(dataset);
  label = classifier.classifyInstance(subject);
  System.out.println("label: " + label);
  // double dis[] = classifier.distributionForInstance(inst);
  // for (double i : dis) {
  //   System.out.print(i + " ");
  // }
  System.out.println(questionID.size());
  return label + 1;
}
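// The train-once-then-deserialize pattern above is reusable on its own; a minimal
// sketch using only weka.core.SerializationHelper. The helper name is ours and
// 'train' is any prepared, class-labelled Instances:
static Classifier loadOrTrainKnn(Instances train) throws Exception {
  if (!new File("knn.model").exists()) {
    IBk ibk = new IBk();
    ibk.setKNN(3);                // same k as above
    ibk.buildClassifier(train);
    SerializationHelper.write("knn.model", ibk);
    return ibk;
  }
  return (Classifier) SerializationHelper.read("knn.model");
}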
/**
 * @param args the command line arguments
 * @throws Exception if preprocessing, training or classification fails
 */
public static void main(String[] args) throws Exception {
  PreProcessor p = new PreProcessor("census-income.data", "census-income-preprocessed.arff");
  p.smote();
  PreProcessor p_test =
      new PreProcessor("census-income.test", "census-income-test-preprocessed.arff");
  p_test.run();

  BufferedReader traindata =
      new BufferedReader(new FileReader("census-income-preprocessed.arff"));
  BufferedReader testdata =
      new BufferedReader(new FileReader("census-income-test-preprocessed.arff"));
  Instances traininstance = new Instances(traindata);
  Instances testinstance = new Instances(testdata);
  traindata.close();
  testdata.close();
  traininstance.setClassIndex(traininstance.numAttributes() - 1);
  testinstance.setClassIndex(testinstance.numAttributes() - 1);

  int numOfAttributes = testinstance.numAttributes();
  int numOfInstances = testinstance.numInstances();

  // Build the base classifiers on the preprocessed training set.
  NaiveBayesClassifier nb = new NaiveBayesClassifier("census-income-preprocessed.arff");
  Classifier cnaive = nb.NBClassify();

  DecisionTree dt = new DecisionTree("census-income-preprocessed.arff");
  Classifier cls = dt.DTClassify();

  AdaBoost ab = new AdaBoost("census-income-preprocessed.arff");
  AdaBoostM1 m1 = ab.AdaBoostDTClassify();

  BaggingMethod b = new BaggingMethod("census-income-preprocessed.arff");
  Bagging bag = b.BaggingDTClassify();

  SVM s = new SVM("census-income-preprocessed.arff");
  SMO svm = s.SMOClassifier();

  knn knnclass = new knn("census-income-preprocessed.arff");
  IBk knnc = knnclass.knnclassifier();

  // Built but never included in the vote below.
  Logistic log = new Logistic();
  log.buildClassifier(traininstance);

  int match = 0;
  int error = 0;
  int greater = 0;
  int less = 0;

  for (int i = 0; i < numOfInstances; i++) {
    String predicted = "";
    greater = 0;
    less = 0;
    // Only the first six slots are filled; the array is oversized.
    double predictions[] = new double[8];

    double pred = cls.classifyInstance(testinstance.instance(i));
    predictions[0] = pred;
    double abpred = m1.classifyInstance(testinstance.instance(i));
    predictions[1] = abpred;
    double naivepred = cnaive.classifyInstance(testinstance.instance(i));
    predictions[2] = naivepred;
    double bagpred = bag.classifyInstance(testinstance.instance(i));
    predictions[3] = bagpred;
    double smopred = svm.classifyInstance(testinstance.instance(i));
    predictions[4] = smopred;
    double knnpred = knnc.classifyInstance(testinstance.instance(i));
    predictions[5] = knnpred;

    // Majority vote over the six base classifiers.
    for (int j = 0; j < 6; j++) {
      if ((testinstance.instance(i).classAttribute().value((int) predictions[j]))
              .compareTo(">50K") == 0)
        greater++;
      else
        less++;
    }

    if (greater > less)
      predicted = ">50K";
    else
      predicted = "<=50K";

    if ((testinstance.instance(i).stringValue(numOfAttributes - 1)).compareTo(predicted) == 0)
      match++;
    else
      error++;
  }

  System.out.println("Correctly classified Instances: " + match);
  System.out.println("Misclassified Instances: " + error);
  double accuracy = (double) match / (double) numOfInstances * 100;
  double error_percent = 100 - accuracy;
  System.out.println("Accuracy: " + accuracy + "%");
  System.out.println("Error: " + error_percent + "%");
}
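// The voting loop above could be factored out; a hedged sketch (the helper name
// is ours, not the source's) that counts votes for the positive class over the
// first 'used' predictions and mirrors the tie-breaking of the loop above:
static String majorityVote(double[] predictions, int used, Instance inst) {
  int greater = 0;
  for (int j = 0; j < used; j++) {
    if (inst.classAttribute().value((int) predictions[j]).equals(">50K"))
      greater++;
  }
  return (greater > used - greater) ? ">50K" : "<=50K";
}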
/**
 * Gets the distance weighting method used. Will be one of WEIGHT_NONE, WEIGHT_INVERSE, or
 * WEIGHT_SIMILARITY.
 *
 * @return the distance weighting method used.
 * @see IBk#WEIGHT_NONE
 * @see IBk#WEIGHT_INVERSE
 * @see IBk#WEIGHT_SIMILARITY
 */
public SelectedTag getDistanceWeighting() {
  return m_Classifier.getDistanceWeighting();
}
/**
 * Sets the distance weighting method used. Values other than WEIGHT_NONE, WEIGHT_INVERSE, or
 * WEIGHT_SIMILARITY will be ignored.
 *
 * @param newMethod the distance weighting method to use
 * @see IBk#WEIGHT_NONE
 * @see IBk#WEIGHT_INVERSE
 * @see IBk#WEIGHT_SIMILARITY
 */
public void setDistanceWeighting(SelectedTag newMethod) {
  m_Classifier.setDistanceWeighting(newMethod);
}
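// Typical call site for the setter above; IBk.WEIGHT_INVERSE and IBk.TAGS_WEIGHTING
// are the standard Weka constants it expects (equivalent to IBk's -I option):
setDistanceWeighting(new SelectedTag(IBk.WEIGHT_INVERSE, IBk.TAGS_WEIGHTING));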
private static ClassifierContext createContext(String[] args) throws ParseException {
  Options options =
      new Options()
          .addOption(new Option(help, "help", false, "show help"))
          .addOption(new Option(file, "file", true, "file path containing the data set"))
          .addOption(
              new Option(
                  trainPercentage,
                  "train-percentage",
                  true,
                  "training set in percentage, rest is test set; type double; default value training set: 0.7, test set 0.3"))
          .addOption(
              new Option(
                  crossValidation,
                  "cross-validation",
                  true,
                  "use cross-validation or not; type boolean; default value false"))
          .addOption(
              new Option(
                  kFolds, "kfolds", true, "#folds used in cross-validation; type integer, default 5"))
          .addOption(
              new Option(
                  classifier,
                  "classifier",
                  true,
                  "classifier to use <dt, knn, boost, nn, svm> : Decision Tree(dt), Nearest Neighbours(knn), Boosting(boost), Neural Networks(nn), Support Vector Machine(svm)"))
          .addOption(
              new Option(dtPruning, "pruning", true, "dt: use pruning; type boolean, default value true"))
          .addOption(
              new Option(
                  dtCf,
                  "confidence_factor",
                  true,
                  "dt: set confidence-factor, used for pruning; type double, default value 0.25"))
          .addOption(
              new Option(
                  boostC,
                  "boost_classifier",
                  true,
                  "boost: set the classifier learner; one of <stump|dt>, default stump"))
          .addOption(
              new Option(
                  boostNrIterations,
                  "boost_nr_iterations",
                  true,
                  "boost: set the nr of bagging iterations; type integer, default value 10"))
          .addOption(
              new Option(
                  boostDtPruning,
                  "boost_dt_pruning",
                  true,
                  "boost: use pruning for dt; type boolean, default value true"))
          .addOption(
              new Option(
                  boostDtCf,
                  "boost_dt_cf",
                  true,
                  "boost: use dt confidence-factor, used for pruning; type double, default value 0.25"))
          .addOption(
              new Option(
                  knnK,
                  "knn_k",
                  true,
                  "knn: specify the #nearest neighbours; type integer, default value 1"))
          .addOption(
              new Option(
                  knnWeightDistance,
                  "knn_weight_distance",
                  true,
                  "knn: specify a weight distance; type integer <1=None,2=Inverse,4=Similarity>, default value 1"))
          .addOption(
              new Option(
                  nnLearningRate,
                  "nn_learning_rate",
                  true,
                  "nn: backpropagation learning rate; type double, default value 0.3"))
          .addOption(
              new Option(
                  nnMomentum,
                  "nn_momentum",
                  true,
                  "nn: backpropagation momentum rate; type double, default value 0.2"))
          .addOption(
              new Option(
                  nnHiddentUnits,
                  "nn_hidden_units",
                  true,
                  "nn: comma-separated string for #hidden layers and nodes per layer, e.g. \"a,3,4\"; see weka for more details"))
          .addOption(
              new Option(
                  svmKernelFunction,
                  "svm_kernel_function",
                  true,
                  "svm: one of <poly,radial>, default value poly"))
          .addOption(
              new Option(
                  svmPolyExp,
                  "svm_poly_exponent",
                  true,
                  "svm: the exponent value of a polynomial kernel; type double, default value 1.0"))
          .addOption(
              new Option(
                  svmRadialGamma,
                  "svm_radial_gamma",
                  true,
                  "svm: the gamma parameter value of the radial kernel; type double, default value 0.01"));

  String f = null;
  Double trainP = 0.7;
  Boolean cv = false;
  Integer kfolds = 5;
  ClassifierTypes classifierType = null;
  Classifier cls = null;

  CommandLine commandLine = new DefaultParser().parse(options, args);
  if (commandLine.hasOption(help)) {
    new HelpFormatter().printHelp("Run Classifiers", options);
    return null;
  }

  if (commandLine.hasOption(file)) {
    f = commandLine.getOptionValue(file);
  } else {
    System.out.println("Please provide a data set file. See help for more details");
    return null;
  }

  if (commandLine.hasOption(trainPercentage)) {
    String trainp = commandLine.getOptionValue(trainPercentage);
    trainP = getDouble(trainp);
    if (trainP == null || trainP >= 1 || trainP <= 0) {
      System.out.println(
          "Please provide a double in the (0,1) range for the training set pct. See help for more details");
      return null;
    }
  }

  if (commandLine.hasOption(crossValidation)) {
    cv = Boolean.valueOf(commandLine.getOptionValue(crossValidation));
    if (cv) {
      if (commandLine.hasOption(kFolds)) {
        kfolds = getInt(commandLine.getOptionValue(kFolds));
        if (kfolds == null) {
          System.out.println(
              "Please provide an integer for kfolds e.g. 10. See help for more details");
          return null;
        }
      }
    }
  }

  if (commandLine.hasOption(classifier)) {
    String c = commandLine.getOptionValue(classifier);
    classifierType = ClassifierTypes.toEnum(c);
    switch (classifierType) {
      case DECISION_TREE:
        J48 dt = new J48();
        if (commandLine.hasOption(dtPruning)) {
          Boolean pruning = Boolean.valueOf(commandLine.getOptionValue(dtPruning));
          dt.setUnpruned(!pruning);
        }
        if (commandLine.hasOption(dtCf)) {
          String cfStr = commandLine.getOptionValue(dtCf);
          Float cf = getFloat(cfStr);
          if (cf == null) {
            System.out.println(
                "Please provide a floating point number for dt_cf e.g. 0.25. See help for more details");
            return null;
          }
          dt.setConfidenceFactor(cf);
        }
        cls = dt;
        break;
      case BOOSTING:
        AdaBoostM1 boost = new AdaBoostM1();
        if (commandLine.hasOption(boostC)) {
          String boostCls = commandLine.getOptionValue(boostC);
          if ("stump".equalsIgnoreCase(boostCls)) {
            // nothing to do, default value
          } else if ("dt".equalsIgnoreCase(boostCls)) {
            J48 boostDt = new J48();
            if (commandLine.hasOption(boostDtPruning)) {
              Boolean boostPruning = Boolean.valueOf(commandLine.getOptionValue(boostDtPruning));
              boostDt.setUnpruned(!boostPruning);
            }
            if (commandLine.hasOption(boostDtCf)) {
              String cfStr = commandLine.getOptionValue(boostDtCf);
              Float cf = getFloat(cfStr);
              if (cf == null) {
                System.out.println(
                    "Please provide a floating point number for boost_dt_cf e.g. 0.25. See help for more details");
                return null;
              }
              boostDt.setConfidenceFactor(cf);
            }
            boost.setClassifier(boostDt);
          } else {
            System.out.println(
                "boost_c can use one of <stump, dt> as values. See help for more details");
            return null;
          }
        }
        if (commandLine.hasOption(boostNrIterations)) {
          Integer nrIt = getInt(commandLine.getOptionValue(boostNrIterations));
          if (nrIt == null) {
            System.out.println(
                "Please provide an integer for boost_nr_it. See help for more details");
            return null;
          }
          boost.setNumIterations(nrIt);
        }
        cls = boost;
        break;
      case KNN:
        IBk ibk = new IBk();
        if (commandLine.hasOption(knnK)) {
          Integer k = getInt(commandLine.getOptionValue(knnK));
          if (k == null) {
            System.out.println("Please provide an integer for knn_k. For more details see help");
            return null;
          }
          ibk.setKNN(k);
        }
        if (commandLine.hasOption(knnWeightDistance)) {
          Integer knnwd = getInt(commandLine.getOptionValue(knnWeightDistance));
          if (knnwd == null) {
            System.out.println(
                "Please provide one of 1,2,4 for knn_w_d. For more details see help");
            return null;
          }
          // Must be one of the IBk weighting tags: 1=None, 2=Inverse, 4=Similarity.
          // (The original '||' check was always true; '&&' is the intended validation.)
          if (knnwd != 1 && knnwd != 2 && knnwd != 4) {
            System.out.println(
                "Please provide one of 1,2,4 for knn_w_d. See help for more details");
            return null;
          }
          ibk.setDistanceWeighting(new SelectedTag(knnwd, IBk.TAGS_WEIGHTING));
        }
        cls = ibk;
        break;
      case NN:
        MultilayerPerceptron nn = new MultilayerPerceptron();
        if (commandLine.hasOption(nnLearningRate)) {
          Double nnLR = getDouble(commandLine.getOptionValue(nnLearningRate));
          if (nnLR == null) {
            System.out.println(
                "Please provide a double for NN learning rate. See help for more details");
            return null;
          }
          nn.setLearningRate(nnLR);
        }
        if (commandLine.hasOption(nnMomentum)) {
          Double nnMR = getDouble(commandLine.getOptionValue(nnMomentum));
          if (nnMR == null) {
            System.out.println(
                "Please provide a double for NN momentum rate. See help for more details");
            return null;
          }
          nn.setMomentum(nnMR);
        }
        if (commandLine.hasOption(nnHiddentUnits)) {
          String nnHU = commandLine.getOptionValue(nnHiddentUnits);
          nn.setHiddenLayers(nnHU);
        }
        cls = nn;
        break;
      case SVM:
        SMO svm = new SMO();
        if (commandLine.hasOption(svmKernelFunction)) {
          String svmkf = commandLine.getOptionValue(svmKernelFunction);
          Kernel kernel = null;
          if ("poly".equalsIgnoreCase(svmkf)) {
            PolyKernel pk = new PolyKernel();
            if (commandLine.hasOption(svmPolyExp)) {
              Double expValue = getDouble(commandLine.getOptionValue(svmPolyExp));
              if (expValue == null) {
                System.out.println(
                    "Please provide a double value for svm_poly_exp. See help for more details");
                return null;
              }
              pk.setExponent(expValue);
            }
            kernel = pk;
          } else if ("radial".equalsIgnoreCase(svmkf)) {
            RBFKernel rbfk = new RBFKernel();
            if (commandLine.hasOption(svmRadialGamma)) {
              Double gamma = getDouble(commandLine.getOptionValue(svmRadialGamma));
              if (gamma == null) {
                System.out.println(
                    "Please provide a double value for svm_radial_gamma. See help for more details");
                return null;
              }
              rbfk.setGamma(gamma);
            }
            kernel = rbfk;
          } else {
            System.out.println(
                "Please provide one of <poly, radial> for svm_kernel_fct. See help for more details");
            return null;
          }
          svm.setKernel(kernel);
        } else {
          if (commandLine.hasOption(svmPolyExp)) {
            PolyKernel polyKernel = new PolyKernel();
            Double expValue = getDouble(commandLine.getOptionValue(svmPolyExp));
            if (expValue == null) {
              System.out.println(
                  "Please provide a double value for svm_poly_exp. For more details, see help");
              return null;
            }
            polyKernel.setExponent(expValue);
            svm.setKernel(polyKernel);
          }
        }
        cls = svm;
        break;
    }
  } else {
    System.out.println("Please provide a classifier. See help for more details");
    return null;
  }

  return new ClassifierContext(f, trainP, cls, cv, kfolds);
}
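// The Commons CLI plumbing above follows the usual define-parse-query cycle; a
// stripped-down sketch of just that cycle. The short names "f"/"h" are placeholders,
// since the source uses named constants whose values are not shown in this excerpt.
static String parseDataFile(String[] args) throws ParseException {
  Options opts = new Options()
      .addOption(new Option("h", "help", false, "show help"))
      .addOption(new Option("f", "file", true, "file path containing the data set"));
  CommandLine cmd = new DefaultParser().parse(opts, args);
  if (cmd.hasOption("h")) {
    new HelpFormatter().printHelp("Run Classifiers", opts);
    return null;
  }
  return cmd.getOptionValue("f"); // null if -f was not given
}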
/**
 * Generates the classifier.
 *
 * @param data set of instances serving as training data
 * @throws Exception if the classifier has not been generated successfully
 */
public void buildClassifier(Instances data) throws Exception {
  // can classifier handle the data?
  getCapabilities().testWithFail(data);

  // remove instances with missing class
  m_theInstances = new Instances(data);
  m_theInstances.deleteWithMissingClass();

  m_rr = new Random(1);

  if (m_theInstances.classAttribute().isNominal()) {
    // set up class priors
    m_classPriorCounts = new double[data.classAttribute().numValues()];
    Arrays.fill(m_classPriorCounts, 1.0);
    for (int i = 0; i < data.numInstances(); i++) {
      Instance curr = data.instance(i);
      m_classPriorCounts[(int) curr.classValue()] += curr.weight();
    }
    m_classPriors = m_classPriorCounts.clone();
    Utils.normalize(m_classPriors);
  }

  setUpEvaluator();

  if (m_theInstances.classAttribute().isNumeric()) {
    m_disTransform = new weka.filters.unsupervised.attribute.Discretize();
    m_classIsNominal = false;

    // use binned discretisation if the class is numeric
    ((weka.filters.unsupervised.attribute.Discretize) m_disTransform).setBins(10);
    ((weka.filters.unsupervised.attribute.Discretize) m_disTransform).setInvertSelection(true);

    // discretize all attributes EXCEPT the class
    String rangeList = "";
    rangeList += (m_theInstances.classIndex() + 1);
    // System.out.println("The class col: " + m_theInstances.classIndex());
    ((weka.filters.unsupervised.attribute.Discretize) m_disTransform)
        .setAttributeIndices(rangeList);
  } else {
    m_disTransform = new weka.filters.supervised.attribute.Discretize();
    ((weka.filters.supervised.attribute.Discretize) m_disTransform).setUseBetterEncoding(true);
    m_classIsNominal = true;
  }

  m_disTransform.setInputFormat(m_theInstances);
  m_theInstances = Filter.useFilter(m_theInstances, m_disTransform);

  m_numAttributes = m_theInstances.numAttributes();
  m_numInstances = m_theInstances.numInstances();
  m_majority = m_theInstances.meanOrMode(m_theInstances.classAttribute());

  // perform the search
  int[] selected = m_search.search(m_evaluator, m_theInstances);

  m_decisionFeatures = new int[selected.length + 1];
  System.arraycopy(selected, 0, m_decisionFeatures, 0, selected.length);
  m_decisionFeatures[m_decisionFeatures.length - 1] = m_theInstances.classIndex();

  // reduce instances to selected features
  m_delTransform = new Remove();
  m_delTransform.setInvertSelection(true); // set features to keep
  m_delTransform.setAttributeIndicesArray(m_decisionFeatures);
  m_delTransform.setInputFormat(m_theInstances);
  m_dtInstances = Filter.useFilter(m_theInstances, m_delTransform);

  // reset the number of attributes
  m_numAttributes = m_dtInstances.numAttributes();

  // create hash table
  m_entries = new Hashtable((int) (m_dtInstances.numInstances() * 1.5));

  // insert instances into the hash table
  for (int i = 0; i < m_numInstances; i++) {
    Instance inst = m_dtInstances.instance(i);
    insertIntoTable(inst, null);
  }

  // replace the global table majority with nearest neighbour?
  if (m_useIBk) {
    m_ibk = new IBk();
    m_ibk.buildClassifier(m_theInstances);
  }

  // save memory
  if (m_saveMemory) {
    m_theInstances = new Instances(m_theInstances, 0);
    m_dtInstances = new Instances(m_dtInstances, 0);
  }
  m_evaluation = null;
}
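// End-to-end sketch: the build method above matches the shape of Weka's
// weka.classifiers.rules.DecisionTable, so assuming that class; the method name
// is ours and 'data' is any loaded, class-labelled Instances.
static double[] exampleBuildAndPredict(Instances data) throws Exception {
  DecisionTable table = new DecisionTable();
  table.setUseIBk(true);        // fall back to IBk instead of the table majority
  table.buildClassifier(data);  // runs the feature search and fills the hash table
  return table.distributionForInstance(data.instance(0));
}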