/**
 * Trains a nu-SVC model with a linear kernel on {@code train} and stores it in {@code _model}.
 *
 * <p>Each row is laid out as {@code [label, f0, f1, ...]}: element 0 is the class label and the
 * remaining elements are dense feature values, which become svm_node indices 0, 1, 2, ... .
 *
 * @param train training rows; must be non-null and non-empty, and every row must contain at
 *     least its label
 * @throws IllegalArgumentException if {@code train} is null or empty, a row is null or empty, or
 *     {@code svm.svm_check_parameter} rejects the problem/parameter combination (for nu-SVC this
 *     includes a nu that is infeasible for the class balance of {@code train})
 */
public static void svmTrain(ArrayList<ArrayList<Double>> train) {
    if (train == null || train.isEmpty()) {
        throw new IllegalArgumentException("training set must be non-null and non-empty");
    }

    int dataCount = train.size();
    svm_problem prob = new svm_problem();
    prob.l = dataCount;
    prob.y = new double[dataCount];
    prob.x = new svm_node[dataCount][];
    for (int i = 0; i < dataCount; i++) {
        ArrayList<Double> row = train.get(i);
        if (row == null || row.isEmpty()) {
            // Without this check an empty row fails with a NegativeArraySizeException below.
            throw new IllegalArgumentException("training row " + i + " has no label");
        }
        prob.y[i] = row.get(0);
        svm_node[] nodes = new svm_node[row.size() - 1];
        for (int j = 0; j < nodes.length; j++) {
            svm_node node = new svm_node();
            node.index = j;
            node.value = row.get(j + 1);
            nodes[j] = node;
        }
        prob.x[i] = nodes;
    }

    svm_parameter param = new svm_parameter();
    param.svm_type = svm_parameter.NU_SVC;
    param.kernel_type = svm_parameter.LINEAR;
    param.nu = 0.5;
    // Per the LIBSVM README, gamma is only read by POLY/RBF/SIGMOID kernels and C only by
    // C_SVC/EPSILON_SVR/NU_SVR, so neither affects this NU_SVC + LINEAR model; kept as before.
    param.gamma = 0.5;
    param.C = 1;
    param.probability = 1;
    // cache_size is in megabytes, so this requests a ~20 GB kernel cache.
    // NOTE(review): far above the svm_train tool's default of 100; confirm this is intended.
    param.cache_size = 20000;
    param.eps = 0.001;

    // libsvm expects svm_check_parameter before svm_train; it returns null when the setup is valid.
    String error = svm.svm_check_parameter(prob, param);
    if (error != null) {
        throw new IllegalArgumentException("invalid SVM configuration: " + error);
    }
    _model = svm.svm_train(prob, param);
}
/**
 * Creates a trainer with default libsvm parameters: C-SVC with an RBF kernel, probability
 * estimates enabled, the shrinking heuristic disabled and no per-class weights.
 */
public SVMTrainer() {
    svm_parameter defaults = new svm_parameter();

    // Model family and kernel shape.
    defaults.svm_type = svm_parameter.C_SVC;
    defaults.kernel_type = svm_parameter.RBF;
    defaults.degree = 3;
    defaults.coef0 = 0;
    // Intended as 1/num_features. NOTE(review): the libsvm library does not replace gamma = 0
    // by itself (only its svm_train command-line tool does); confirm a later step sets it.
    defaults.gamma = 0;

    // Regularisation / loss settings.
    defaults.C = 1;
    defaults.nu = 0.5;
    defaults.p = 0.1;

    // Solver behaviour.
    defaults.cache_size = 100;
    defaults.eps = 1e-3;
    defaults.shrinking = 0;
    defaults.probability = 1;

    // No class-specific weighting.
    defaults.nr_weight = 0;
    defaults.weight_label = new int[0];
    defaults.weight = new double[0];

    setParameter(defaults);
}
private void set_params() { param = new svm_parameter(); // default values param.svm_type = svm_parameter.C_SVC; param.kernel_type = svm_parameter.RBF; param.degree = 3; param.gamma = 0; // 1/num_features param.coef0 = 0; param.nu = 0.5; param.cache_size = 100; param.C = 1; param.eps = 1e-3; param.p = 0.1; param.shrinking = 1; param.probability = 0; param.nr_weight = 0; param.weight_label = new int[0]; param.weight = new double[0]; }
@Override public void train(List<Input> inputs) throws Exception { // What is the balance of the data set? Map<Integer, Integer> labelCounts = new HashMap<Integer, Integer>(); for (Input input : inputs) { int thisLabel = (input.label == label ? 1 : -1); if (!labelCounts.containsKey(thisLabel)) { labelCounts.put(thisLabel, 1); } else { labelCounts.put(thisLabel, labelCounts.get(thisLabel) + 1); } } // How many models do we need to train? int modelCount = (int) Math.ceil(labelCounts.get(-1).intValue() / labelCounts.get(1).intValue()); for (int i = 0; i < modelCount; i++) { svm_problem problem = new svm_problem(); // l appears to be the number of training samples problem.l = labelCounts.get(1) * 2; // The features of each sample svm_node[][] svmTrainingSamples = new svm_node[problem.l][]; double[] labels = new double[problem.l]; int count = 0; // Get all of the 1-labeled inputs for (Input input : inputs) { if (input.label == label) { svmTrainingSamples[count] = input.getSVMFeatures(); labels[count++] = 1; } } // Randomly get -1-labeled inputs for (int j = 0; j < labelCounts.get(1); j++) { Input input = inputs.get(MathUtils.RANDOM.nextInt(inputs.size())); if (input.label == label) { j--; continue; } svmTrainingSamples[count] = input.getSVMFeatures(); labels[count++] = -1; } problem.x = svmTrainingSamples; // y is probably the labels to each sample problem.y = labels; // Set the training parameters svm_parameter param = new svm_parameter(); // SVM Type // 0 -- C-SVC (classification) // 1 -- nu-SVC (classification) // 2 -- one-class SVM // 3 -- epsilon-SVR (regression) // 4 -- nu-SVR (regression) param.svm_type = svm_parameter.C_SVC; // Other C-SVC parameters: // param.weight[] : sets the parameter C of class i to weight*c (default 1) // param.C : cost, set the parameter C of C-SVC (default 1) // Type of kernel // 0 -- linear: u'*v // 1 -- polynomial: (gamma*u'*v + coef0)^degree // 2 -- radial basis function: exp(-gamma*|u-v|^2) // 3 -- sigmoid: tanh(gamma*u'*v + coef0) 
param.kernel_type = svm_parameter.RBF; // Kernel gamma param.gamma = 2e9; // Cost param.C = 2e9; // Degree for poly param.degree = 3; // Coefficient for poly/sigmoid param.coef0 = 0.1; // Stopping criteria param.eps = 0.05; // Make larger for faster training param.cache_size = 800; // Number of label weights param.nr_weight = 2; // Weight labels param.weight_label = new int[] {1, -1}; param.weight = new double[] {1, 1}; // Nu param.nu = 0.1; // p param.p = 0.1; // Shrinking heuristic param.shrinking = 1; // Estimate probabilities param.probability = 0; // Train the SVM classifier svm.svm_set_print_string_function( new svm_print_interface() { @Override public void print(String s) {} }); svm_model model = svm.svm_train(problem, param); models.add(model); } }