Example #1
  public static void svmTrain(ArrayList<ArrayList<Double>> train) {
    svm_problem prob = new svm_problem();
    int dataCount = train.size();
    prob.y = new double[dataCount];
    prob.l = dataCount;
    prob.x = new svm_node[dataCount][];

    for (int i = 0; i < dataCount; i++) {
      ArrayList<Double> features = train.get(i);
      // Index 0 of each row holds the class label; the remaining entries are feature values.
      prob.x[i] = new svm_node[features.size() - 1];
      for (int j = 1; j < features.size(); j++) {
        svm_node node = new svm_node();
        node.index = j - 1; // 0-based feature index
        node.value = features.get(j);
        prob.x[i][j - 1] = node;
      }
      prob.y[i] = features.get(0);
    }

    svm_parameter param = new svm_parameter();
    param.probability = 1;
    param.gamma = 0.5;
    param.nu = 0.5;
    param.C = 1;
    param.svm_type = svm_parameter.NU_SVC;
    param.kernel_type = svm_parameter.LINEAR;
    param.cache_size = 20000;
    param.eps = 0.001;

    _model = svm.svm_train(prob, param);
  }
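A natural companion to svmTrain is a prediction helper. The sketch below is not part of the original example; it assumes _model was populated by svmTrain and that the caller passes only the feature values, in the same order used during training.
  public static double svmPredict(ArrayList<Double> features) {
    // features: feature values only (no leading label), same order as in svmTrain
    svm_node[] nodes = new svm_node[features.size()];
    for (int j = 0; j < features.size(); j++) {
      svm_node node = new svm_node();
      node.index = j; // same 0-based feature indices as in svmTrain
      node.value = features.get(j);
      nodes[j] = node;
    }
    // Because param.probability was 1 above, probability estimates are also available
    // via svm.svm_predict_probability(_model, nodes, probEstimates).
    return svm.svm_predict(_model, nodes);
  }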
Example #2
  @Override
  public AnalysisCollector analyze(CollectionReader cr) throws AnalyzerFailureException {
    svm_problem svmProblem = loadData(cr);
    svm_problem sampled = null;
    if (findBestParameters) {
      if (sample < 1d) {
        logger.debug("Sampling.");
        sampled = do_sample(svmProblem);
      }
      logger.debug("Performing grid search.");
      do_find_best_parameters(sampled != null ? sampled : svmProblem);
    }
    svm_parameter svmParam = getDefaultSvmParameters();
    svmParam.probability = 1;
    svmParam.C = c;
    svmParam.gamma = gamma;
    setWeights(svmParam);

    logger.debug("Training with C=" + c + "  gamma=" + gamma);
    svm_model model = svm.svm_train(svmProblem, svmParam);

    logger.debug("Done!");
    return new SingletonAnalysisCollector(
        new LibSvmTrainerAnalysis(model, scaler, labelList, c, gamma));
  }
Example #3
  public static void trainClassifier(
      List<TrainingSample<BxZoneLabel>> trainingElements, String output)
      throws AnalysisException, IOException {
    SVMHeaderLinesClassifier contentFilter = new SVMHeaderLinesClassifier();
    svm_parameter param = SVMZoneClassifier.getDefaultParam();
    param.gamma = BEST_GAMMA;
    param.C = BEST_C;
    param.kernel_type = svm_parameter.RBF;

    contentFilter.setParameter(param);
    contentFilter.buildClassifier(trainingElements);
    contentFilter.saveModel(output);
  }
Example #4
  public svm_parameter setupSVM() {
    svm_parameter param = new svm_parameter();
    param.probability = 0;
    param.gamma = 0.5;
    // param.nu = 0.5;
    param.C = 0.5;
    param.svm_type = svm_parameter.EPSILON_SVR;
    param.kernel_type = svm_parameter.POLY;
    // new svm_parameter() leaves degree and p at 0, which makes a polynomial kernel
    // degenerate; set them to libsvm's usual defaults explicitly.
    param.degree = 3;
    param.p = 0.1;
    param.cache_size = 20000;
    param.eps = 0.001;

    return param;
  }
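A minimal usage sketch for setupSVM, assuming an svm_problem named prob built elsewhere: libsvm can validate a parameter set against a problem before training, and svm_check_parameter returns null when the combination is legal.
  svm_parameter param = setupSVM();
  // prob is an svm_problem assumed to be populated elsewhere
  String error = svm.svm_check_parameter(prob, param);
  if (error != null) {
    throw new IllegalArgumentException("Invalid SVM parameters: " + error);
  }
  svm_model model = svm.svm_train(prob, param);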
Example #5
  public static void train(svm_problem svmProblem) {
    defaultSVMParameter = new DefaultSVMParameter(LoadFeatureFile.f);
    svmParameter = defaultSVMParameter.svmParameter;
    filename += ".model";

    // Change the parameters with the desired C and Gamma
    svmParameter.C = 0.125;
    svmParameter.gamma = 0.5;
    // Train the SVM
    svmModel = svm.svm_train(svmProblem, svmParameter);

    try {
      svm.svm_save_model(filename, svmModel);
    } catch (IOException e) {
      e.printStackTrace();
    }
  }
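Because the model is written to disk above, a later run can reload it without retraining. A hedged sketch, where filename is the same variable used above and sample stands for an svm_node[] built the same way as the training rows:
  try {
    svm_model loaded = svm.svm_load_model(filename);
    // sample: an svm_node[] prepared like the training rows (assumed to exist)
    double predictedLabel = svm.svm_predict(loaded, sample);
  } catch (IOException e) {
    e.printStackTrace();
  }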
Example #6
 public SVMTrainer() {
   svm_parameter param = new svm_parameter();
   param.svm_type = svm_parameter.C_SVC;
   param.kernel_type = svm_parameter.RBF;
   param.degree = 3;
   param.gamma = 0; // 1/num_features
   param.coef0 = 0;
   param.nu = 0.5;
   param.cache_size = 100;
   param.C = 1;
   param.eps = 1e-3;
   param.p = 0.1;
   param.shrinking = 0;
   param.probability = 1;
   param.nr_weight = 0;
   param.weight_label = new int[0];
   param.weight = new double[0];
   setParameter(param);
 }
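One caveat: the "1/num_features" comment mirrors the default of libsvm's command-line tools, but svm.svm_train itself does not appear to substitute anything for gamma == 0, so with an RBF kernel it is usually safer to set gamma explicitly, for example:
  // numFeatures is a placeholder for the dimensionality of the training vectors
  if (numFeatures > 0) {
    param.gamma = 1.0 / numFeatures;
  }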
Example #7
 private void set_params() {
   param = new svm_parameter();
   // default values
   param.svm_type = svm_parameter.C_SVC;
   param.kernel_type = svm_parameter.RBF;
   param.degree = 3;
   param.gamma = 0; // 1/num_features
   param.coef0 = 0;
   param.nu = 0.5;
   param.cache_size = 100;
   param.C = 1;
   param.eps = 1e-3;
   param.p = 0.1;
   param.shrinking = 1;
   param.probability = 0;
   param.nr_weight = 0;
   param.weight_label = new int[0];
   param.weight = new double[0];
 }
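These defaults are commonly tuned with k-fold cross-validation. A minimal sketch, assuming a populated svm_problem named prob and the param field set above:
  // prob is an svm_problem assumed to be populated elsewhere
  double[] target = new double[prob.l];
  svm.svm_cross_validation(prob, param, 5, target); // 5-fold cross-validation
  int correct = 0;
  for (int i = 0; i < prob.l; i++) {
    if (target[i] == prob.y[i]) {
      correct++;
    }
  }
  double accuracyPercent = 100.0 * correct / prob.l;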
Example #8
  @Override
  public void train(List<Input> inputs) throws Exception {
    // What is the balance of the data set?
    Map<Integer, Integer> labelCounts = new HashMap<Integer, Integer>();
    for (Input input : inputs) {
      int thisLabel = (input.label == label ? 1 : -1);
      if (!labelCounts.containsKey(thisLabel)) {
        labelCounts.put(thisLabel, 1);
      } else {
        labelCounts.put(thisLabel, labelCounts.get(thisLabel) + 1);
      }
    }

    // How many models do we need to train? Use floating-point division so that
    // Math.ceil has an effect (integer division would truncate first).
    int modelCount =
        (int) Math.ceil(labelCounts.get(-1).doubleValue() / labelCounts.get(1).doubleValue());

    for (int i = 0; i < modelCount; i++) {
      svm_problem problem = new svm_problem();
      // l is the number of training samples
      problem.l = labelCounts.get(1) * 2;
      // The features of each sample
      svm_node[][] svmTrainingSamples = new svm_node[problem.l][];
      double[] labels = new double[problem.l];
      int count = 0;
      // Get all of the 1-labeled inputs
      for (Input input : inputs) {
        if (input.label == label) {
          svmTrainingSamples[count] = input.getSVMFeatures();
          labels[count++] = 1;
        }
      }
      // Randomly get -1-labeled inputs
      for (int j = 0; j < labelCounts.get(1); j++) {
        Input input = inputs.get(MathUtils.RANDOM.nextInt(inputs.size()));
        if (input.label == label) {
          j--;
          continue;
        }
        svmTrainingSamples[count] = input.getSVMFeatures();
        labels[count++] = -1;
      }
      problem.x = svmTrainingSamples;
      // y holds the label for each sample
      problem.y = labels;

      // Set the training parameters
      svm_parameter param = new svm_parameter();
      // SVM Type
      //  0 -- C-SVC (classification)
      //  1 -- nu-SVC (classification)
      //  2 -- one-class SVM
      //  3 -- epsilon-SVR (regression)
      //  4 -- nu-SVR (regression)
      param.svm_type = svm_parameter.C_SVC;
      // Other C-SVC parameters:
      // param.weight[] : sets the parameter C of class i to weight*c (default 1)
      // param.C : cost, set the parameter C of C-SVC (default 1)
      // Type of kernel
      //  0 -- linear: u'*v
      //  1 -- polynomial: (gamma*u'*v + coef0)^degree
      //  2 -- radial basis function: exp(-gamma*|u-v|^2)
      //  3 -- sigmoid: tanh(gamma*u'*v + coef0)
      param.kernel_type = svm_parameter.RBF;
      // Kernel gamma
      param.gamma = 2e9;
      // Cost
      param.C = 2e9;
      // Degree for poly
      param.degree = 3;
      // Coefficient for poly/sigmoid
      param.coef0 = 0.1;
      // Stopping criteria
      param.eps = 0.05;
      // Make larger for faster training
      param.cache_size = 800;
      // Number of label weights
      param.nr_weight = 2;
      // Weight labels
      param.weight_label = new int[] {1, -1};
      param.weight = new double[] {1, 1};
      // Nu
      param.nu = 0.1;
      // p
      param.p = 0.1;
      // Shrinking heuristic
      param.shrinking = 1;
      // Estimate probabilities
      param.probability = 0;

      // Silence libsvm's console output during training
      svm.svm_set_print_string_function(
          new svm_print_interface() {
            @Override
            public void print(String s) {}
          });

      // Train the SVM classifier
      svm_model model = svm.svm_train(problem, param);
      models.add(model);
    }
  }
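Since train builds an ensemble of models on balanced resamples, prediction presumably aggregates their votes. A hedged sketch of what that could look like (predict is not part of the original snippet; models, Input, and getSVMFeatures() are taken from the example above):
  public int predict(Input input) {
    // Hypothetical companion to train(): majority vote over the ensemble.
    int votes = 0;
    for (svm_model model : models) {
      // Each model solves a one-vs-rest problem with labels +1 / -1.
      votes += svm.svm_predict(model, input.getSVMFeatures()) > 0 ? 1 : -1;
    }
    return votes > 0 ? 1 : -1;
  }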
Example #9
  private void do_find_best_parameters(svm_problem svmProblem) {
    svm_parameter svmParam = getDefaultSvmParameters();
    setWeights(svmParam);

    int maxIter =
        ((int) Math.ceil(Math.abs((log2cEnd - log2cBegin) / log2cStep)) + 1)
            * ((int) Math.ceil(Math.abs((log2gEnd - log2gBegin) / log2gStep)) + 1);

    // Run the grid search in separate CV threads
    ExecutorService executorService = Executors.newFixedThreadPool(numberOfThreads);

    List<CvParams> cvParamsList = new ArrayList<CvParams>();

    for (double log2c = log2cBegin;
        (log2cBegin < log2cEnd && log2c <= log2cEnd)
            || (log2cBegin >= log2cEnd && log2c >= log2cEnd);
        log2c += log2cStep) {

      double c1 = Math.pow(2, log2c);

      for (double log2g = log2gBegin;
          (log2gBegin < log2gEnd && log2g <= log2gEnd)
              || (log2gBegin >= log2gEnd && log2g >= log2gEnd);
          log2g += log2gStep) {

        double gamma1 = Math.pow(2, log2g);

        svm_parameter svmParam1 = (svm_parameter) svmParam.clone();
        svmParam1.C = c1;
        svmParam1.gamma = gamma1;

        executorService.execute(
            new RunnableSvmCrossValidator(svmProblem, svmParam1, nrFold, cvParamsList));
      }
    }

    // Now wait for all threads to complete by calling shutdown().
    // Note that this does NOT terminate the currently running tasks; it just tells the
    // thread pool to stop accepting new work and to shut down once all queued work has
    // completed.
    executorService.shutdown();

    while (!executorService.isTerminated()) {
      try {
        Thread.sleep(1000);
      } catch (InterruptedException e) {
        // don't care if we get interrupted
      }

      // every second, report statistics
      logger.debug(
          String.format("%% complete: %5.2f", cvParamsList.size() / (double) maxIter * 100));
      CvParams best = getBestCvParams(cvParamsList);
      CvParams worst = getWorstcvParams(cvParamsList);
      if (best != null) {
        logger.debug("Best accuracy: " + best.accuracy);
        logger.debug("Best C:        " + best.c);
        logger.debug("Best Gamma:    " + best.gamma);
      }
      if (worst != null) {
        logger.debug("Worst accuracy: " + worst.accuracy);
      }
    }

    CvParams best = getBestCvParams(cvParamsList);
    CvParams worst = getWorstcvParams(cvParamsList);
    if (best != null) {
      logger.debug("Best accuracy: " + best.accuracy);
      logger.debug("Best C:        " + best.c);
      logger.debug("Best Gamma:    " + best.gamma);

      c = best.c;
      gamma = best.gamma;
    } else {
      logger.error("Best CV parameters is null.");
    }
    if (worst != null) {
      logger.debug("Worst accuracy: " + worst.accuracy);
    }
  }
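RunnableSvmCrossValidator is not shown in this example. A plausible sketch is given below, assuming CvParams is a simple holder for c, gamma, and accuracy; note that the shared cvParamsList is read by the progress loop while workers append to it, so the append should be synchronized (or a concurrent collection used instead).
  // Hypothetical reconstruction of the cross-validation worker used above.
  private static class RunnableSvmCrossValidator implements Runnable {
    private final svm_problem problem;
    private final svm_parameter param;
    private final int nrFold;
    private final List<CvParams> results;

    RunnableSvmCrossValidator(
        svm_problem problem, svm_parameter param, int nrFold, List<CvParams> results) {
      this.problem = problem;
      this.param = param;
      this.nrFold = nrFold;
      this.results = results;
    }

    @Override
    public void run() {
      // k-fold cross-validation: target receives the predicted label for each sample.
      double[] target = new double[problem.l];
      svm.svm_cross_validation(problem, param, nrFold, target);
      int correct = 0;
      for (int i = 0; i < problem.l; i++) {
        if (target[i] == problem.y[i]) {
          correct++;
        }
      }
      CvParams cvParams = new CvParams();
      cvParams.c = param.C;
      cvParams.gamma = param.gamma;
      cvParams.accuracy = 100.0 * correct / problem.l;
      synchronized (results) { // the list is read concurrently by the progress loop
        results.add(cvParams);
      }
    }
  }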