Ejemplo n.º 1
0
  /**
   * Loads the training data from {@code cr}, optionally grid-searches C and gamma, and trains a
   * libsvm model with probability estimates enabled.
   *
   * <p>When {@code findBestParameters} is set and {@code sample < 1}, the grid search runs on the
   * problem returned by {@code do_sample} (falling back to the full problem if that returns null).
   * The final model is always trained on the full, unsampled problem.
   *
   * @param cr source of the training data
   * @return a collector holding a single {@code LibSvmTrainerAnalysis}
   * @throws AnalyzerFailureException if the analysis fails
   */
  @Override
  public AnalysisCollector analyze(CollectionReader cr) throws AnalyzerFailureException {
    final svm_problem fullProblem = loadData(cr);

    if (findBestParameters) {
      svm_problem searchProblem = fullProblem;
      if (sample < 1d) {
        logger.debug("Sampling.");
        final svm_problem subset = do_sample(fullProblem);
        if (subset != null) {
          searchProblem = subset;
        }
      }
      logger.debug("Performing grid search.");
      // May overwrite the c and gamma fields that are read below.
      do_find_best_parameters(searchProblem);
    }

    final svm_parameter trainingParams = getDefaultSvmParameters();
    trainingParams.probability = 1;
    trainingParams.C = c;
    trainingParams.gamma = gamma;
    setWeights(trainingParams);

    logger.debug("Training with C=" + c + "  gamma=" + gamma);
    final svm_model trained = svm.svm_train(fullProblem, trainingParams);
    logger.debug("Done!");

    return new SingletonAnalysisCollector(
        new LibSvmTrainerAnalysis(trained, scaler, labelList, c, gamma));
  }
Ejemplo n.º 2
0
  /**
   * Trains a linear nu-SVC model and stores it in the static {@code _model} field.
   *
   * <p>Each row of {@code train} holds the target value at position 0, followed by the feature
   * values. Feature {@code k} of a row becomes an {@code svm_node} with index {@code k} (0-based).
   *
   * @param train one row per sample: {@code [label, f0, f1, ...]}
   */
  public static void svmTrain(ArrayList<ArrayList<Double>> train) {
    final int rows = train.size();
    final svm_problem problem = new svm_problem();
    problem.l = rows;
    problem.y = new double[rows];
    problem.x = new svm_node[rows][];

    for (int r = 0; r < rows; r++) {
      final ArrayList<Double> row = train.get(r);
      final int width = row.size() - 1;
      final svm_node[] nodes = new svm_node[width];
      for (int k = 0; k < width; k++) {
        final svm_node node = new svm_node();
        node.index = k;
        node.value = row.get(k + 1);
        nodes[k] = node;
      }
      problem.x[r] = nodes;
      problem.y[r] = row.get(0);
    }

    final svm_parameter config = new svm_parameter();
    config.svm_type = svm_parameter.NU_SVC;
    config.kernel_type = svm_parameter.LINEAR;
    config.nu = 0.5;
    config.C = 1;
    config.gamma = 0.5; // not used by the LINEAR kernel
    config.eps = 0.001;
    config.cache_size = 20000;
    config.probability = 1;

    _model = svm.svm_train(problem, config);
  }
Ejemplo n.º 3
0
 /**
  * Copies the per-class weights (if any) into {@code svmParam}.
  *
  * <p>The weight at list position {@code i} is assigned to class label {@code i}. This presumes
  * class labels are 0..n-1 in list order; TODO(review) confirm against how labels are encoded.
  * Does nothing when {@code weights} is null.
  *
  * @param svmParam parameter object to update in place
  */
 private void setWeights(svm_parameter svmParam) {
   if (weights == null) {
     return;
   }
   final int classCount = weights.size();
   svmParam.nr_weight = classCount;
   svmParam.weight_label = new int[classCount];
   svmParam.weight = new double[classCount];
   int label = 0;
   while (label < classCount) {
     svmParam.weight_label[label] = label;
     svmParam.weight[label] = weights.get(label);
     label++;
   }
   logger.debug("Class weights: " + weights);
 }
Ejemplo n.º 4
0
  /**
   * Trains an RBF-kernel {@code SVMHeaderLinesClassifier} with the tuned {@code BEST_C} and
   * {@code BEST_GAMMA} constants and writes the resulting model to {@code output}.
   *
   * @param trainingElements labelled training samples
   * @param output destination passed to {@code saveModel}
   * @throws AnalysisException if building the classifier fails
   * @throws IOException if saving the model fails
   */
  public static void trainClassifier(
      List<TrainingSample<BxZoneLabel>> trainingElements, String output)
      throws AnalysisException, IOException {
    final SVMHeaderLinesClassifier headerClassifier = new SVMHeaderLinesClassifier();

    final svm_parameter rbfParams = SVMZoneClassifier.getDefaultParam();
    rbfParams.kernel_type = svm_parameter.RBF;
    rbfParams.C = BEST_C;
    rbfParams.gamma = BEST_GAMMA;
    headerClassifier.setParameter(rbfParams);

    headerClassifier.buildClassifier(trainingElements);
    headerClassifier.saveModel(output);
  }
Ejemplo n.º 5
0
  /**
   * Reads the training set from {@code input_file_name} into {@code prob}.
   *
   * <p>Each line is parsed as {@code <label> <index>:<value> <index>:<value> ...}. Blank or
   * whitespace-only lines are skipped. If {@code param.gamma} is still 0 afterwards, it is set to
   * {@code 1 / max_index}, where max_index is the largest last-on-line feature index seen. This
   * presumes indices are ascending within a line.
   *
   * @throws IOException if the file cannot be opened or read
   */
  private void read_problem() throws IOException {
    Vector<Double> vy = new Vector<Double>();
    Vector<svm_node[]> vx = new Vector<svm_node[]>();
    int max_index = 0;

    BufferedReader fp = new BufferedReader(new FileReader(input_file_name));
    try {
      String line;
      while ((line = fp.readLine()) != null) {
        StringTokenizer st = new StringTokenizer(line, " \t\n\r\f:");
        if (!st.hasMoreTokens()) {
          // Empty line (e.g. a trailing newline): nothing to parse, and nextToken() would throw.
          continue;
        }

        vy.addElement(atof(st.nextToken()));
        int m = st.countTokens() / 2;
        svm_node[] x = new svm_node[m];
        for (int j = 0; j < m; j++) {
          x[j] = new svm_node();
          x[j].index = atoi(st.nextToken());
          x[j].value = atof(st.nextToken());
        }
        if (m > 0) {
          max_index = Math.max(max_index, x[m - 1].index);
        }
        vx.addElement(x);
      }
    } finally {
      // Close even when parsing throws, so the file handle is not leaked.
      fp.close();
    }

    prob = new svm_problem();
    prob.l = vy.size();
    prob.x = new svm_node[prob.l][];
    for (int i = 0; i < prob.l; i++) {
      prob.x[i] = vx.elementAt(i);
    }
    prob.y = new double[prob.l];
    for (int i = 0; i < prob.l; i++) {
      prob.y[i] = vy.elementAt(i);
    }

    if (param.gamma == 0 && max_index > 0) {
      param.gamma = 1.0 / max_index;
    }
  }
Ejemplo n.º 6
0
  /**
   * Trains an SVM on {@code svmProblem} with C = 0.125 and gamma = 0.5 and saves it to
   * {@code filename}. The other settings come from {@code DefaultSVMParameter}.
   *
   * <p>Side effects:
   *
   * <ul>
   *   <li>replaces the static {@code defaultSVMParameter}, {@code svmParameter} and
   *       {@code svmModel};
   *   <li>ensures {@code filename} ends with ".model".
   * </ul>
   *
   * <p>A failure to save is printed to stderr and otherwise ignored.
   *
   * @param svmProblem the training problem
   */
  public static void train(svm_problem svmProblem) {
    defaultSVMParameter = new DefaultSVMParameter(LoadFeatureFile.f);
    svmParameter = defaultSVMParameter.svmParameter;
    // Append the extension only once; appending unconditionally would grow the static file name
    // ("x.model.model...") on every call.
    if (filename == null || !filename.endsWith(".model")) {
      filename += ".model";
    }

    // Change the parameters with the desired C and Gamma
    svmParameter.C = 0.125;
    svmParameter.gamma = 0.5;
    // Train the SVM
    svmModel = svm.svm_train(svmProblem, svmParameter);

    try {
      svm.svm_save_model(filename, svmModel);
    } catch (IOException e) {
      // Best effort: the trained model stays available in svmModel even if saving fails.
      e.printStackTrace();
    }
  }
Ejemplo n.º 7
0
  /**
   * Trains a libsvm model on every instance of {@code trainingList} that yields SVM features.
   *
   * <p>Instances for which {@code SVM.getSvmNodes} returns null are skipped. Each label is the
   * index of the instance's target {@code Label}. If {@code param.gamma} is 0, it is set to
   * {@code 1 / (data alphabet size)}. The trained classifier is stored in {@code classifier} and
   * returned.
   *
   * @param trainingList the training instances
   * @return the trained classifier
   * @throws IllegalArgumentException if no instance yields SVM features
   * @throws IllegalStateException if libsvm rejects the configured parameters
   */
  public SVM train(InstanceList trainingList) {
    int size = trainingList.size();
    svm_node[][] xs = new svm_node[size][];
    double[] ys = new double[size];
    int count = 0;

    for (int i = 0; i < size; i++) {
      Instance instance = trainingList.get(i);
      svm_node[] input = SVM.getSvmNodes(instance);
      if (input == null) {
        // Drop the instance entirely: leaving a null row inside the problem would make libsvm
        // fail with a NullPointerException during training.
        continue;
      }
      xs[count] = input;
      ys[count] = ((Label) instance.getTarget()).getIndex();
      count++;
    }

    if (count == 0) {
      throw new IllegalArgumentException("No training instance produced SVM features");
    }

    svm_problem problem = new svm_problem();
    problem.l = count;
    problem.x = new svm_node[count][];
    problem.y = new double[count];
    System.arraycopy(xs, 0, problem.x, 0, count);
    System.arraycopy(ys, 0, problem.y, 0, count);

    int max_index = trainingList.getDataAlphabet().size();
    if (param.gamma == 0 && max_index > 0) {
      param.gamma = 1.0 / max_index;
    }

    String error_msg = svm.svm_check_parameter(problem, param);
    if (error_msg != null) {
      // Signal the caller instead of terminating the whole JVM from library code.
      throw new IllegalStateException("Error: " + error_msg);
    }

    svm_model model = svm.svm_train(problem, param);

    classifier = new SVM(model, trainingList.getPipe());

    return classifier;
  }
  /**
   * Builds the parameters for epsilon-SVR regression with a polynomial kernel.
   *
   * <p>Settings: C = 0.5, gamma = 0.5, polynomial degree 3, stopping tolerance 0.001, a 20000 MB
   * kernel cache limit, and no probability estimates.
   *
   * @return a freshly allocated parameter object
   */
  public svm_parameter setupSVM() {

    svm_parameter param = new svm_parameter();
    param.probability = 0;
    param.gamma = 0.5;
    // param.nu = 0.5;
    param.C = 0.5;
    param.svm_type = svm_parameter.EPSILON_SVR;
    param.kernel_type = svm_parameter.POLY;
    // A fresh svm_parameter has degree 0, which turns the POLY kernel (gamma*u'v + coef0)^degree
    // into the constant 1, so nothing can be learned. Use libsvm's default degree, as the other
    // parameter builders do.
    param.degree = 3;
    param.cache_size = 20000;
    param.eps = 0.001;
    // NOTE(review): param.p (the SVR epsilon-tube width) stays 0 here; libsvm's default is 0.1.

    return param;
  }
Ejemplo n.º 9
0
 public SVMTrainer() {
   svm_parameter param = new svm_parameter();
   param.svm_type = svm_parameter.C_SVC;
   param.kernel_type = svm_parameter.RBF;
   param.degree = 3;
   param.gamma = 0; // 1/num_features
   param.coef0 = 0;
   param.nu = 0.5;
   param.cache_size = 100;
   param.C = 1;
   param.eps = 1e-3;
   param.p = 0.1;
   param.shrinking = 0;
   param.probability = 1;
   param.nr_weight = 0;
   param.weight_label = new int[0];
   param.weight = new double[0];
   setParameter(param);
 }
  /**
   * Trains an ensemble of balanced one-vs-rest RBF C-SVC models for {@code label}.
   *
   * <p>Inputs whose label equals {@code label} are positives (+1); all others are negatives (-1).
   * Each model sees every positive plus an equal number of negatives drawn uniformly at random
   * (with replacement). Enough models are trained, {@code ceil(negatives / positives)}, that on
   * average the ensemble covers the negative set. Models are appended to {@code models}.
   *
   * @param inputs the labelled training inputs
   * @throws IllegalArgumentException if there are no positive or no negative inputs
   * @throws Exception as declared by the interface
   */
  @Override
  public void train(List<Input> inputs) throws Exception {
    // What is the balance of the data set?
    int positiveCount = 0;
    for (Input input : inputs) {
      if (input.label == label) {
        positiveCount++;
      }
    }
    int negativeCount = inputs.size() - positiveCount;
    if (positiveCount == 0 || negativeCount == 0) {
      // Both classes are required: without positives the model count divides by zero, and
      // without negatives the sampling below could never succeed.
      throw new IllegalArgumentException(
          "Need positive and negative examples for label " + label
              + " (positives=" + positiveCount + ", negatives=" + negativeCount + ")");
    }

    // Collect the negatives once, so sampling picks from them directly (uniformly) rather than
    // rejection-sampling the whole input list.
    Input[] negatives = new Input[negativeCount];
    int next = 0;
    for (Input input : inputs) {
      if (input.label != label) {
        negatives[next++] = input;
      }
    }

    // How many models do we need to train? Divide in floating point so the ceiling takes effect.
    int modelCount = (int) Math.ceil((double) negativeCount / positiveCount);

    // Silence libsvm's training output (a global setting, so it only needs to be set once).
    svm.svm_set_print_string_function(
        new svm_print_interface() {

          @Override
          public void print(String s) {}
        });

    for (int i = 0; i < modelCount; i++) {
      // Balanced problem: all positives plus the same number of sampled negatives.
      int sampleCount = positiveCount * 2;
      svm_node[][] svmTrainingSamples = new svm_node[sampleCount][];
      double[] labels = new double[sampleCount];
      int count = 0;
      for (Input input : inputs) {
        if (input.label == label) {
          svmTrainingSamples[count] = input.getSVMFeatures();
          labels[count++] = 1;
        }
      }
      for (int j = 0; j < positiveCount; j++) {
        Input negative = negatives[MathUtils.RANDOM.nextInt(negatives.length)];
        svmTrainingSamples[count] = negative.getSVMFeatures();
        labels[count++] = -1;
      }

      svm_problem problem = new svm_problem();
      problem.l = sampleCount; // number of training samples
      problem.x = svmTrainingSamples; // feature vector of each sample
      problem.y = labels; // label of each sample

      // Each model gets its own parameter object.
      models.add(svm.svm_train(problem, newBinaryRbfParameter()));
    }
  }

  /** Builds the C-SVC / RBF training parameters used for every model in the ensemble. */
  private static svm_parameter newBinaryRbfParameter() {
    svm_parameter param = new svm_parameter();
    // SVM type: 0 C-SVC, 1 nu-SVC, 2 one-class, 3 epsilon-SVR, 4 nu-SVR.
    param.svm_type = svm_parameter.C_SVC;
    // Kernel: 0 linear, 1 polynomial, 2 RBF exp(-gamma*|u-v|^2), 3 sigmoid.
    param.kernel_type = svm_parameter.RBF;
    // NOTE(review): gamma = 2e9 makes the RBF kernel nearly a delta function (and C = 2e9 nearly
    // hard-margin); confirm these values are intended rather than e.g. 2^9.
    param.gamma = 2e9;
    param.C = 2e9;
    param.degree = 3; // polynomial kernel only
    param.coef0 = 0.1; // polynomial/sigmoid kernels only
    param.eps = 0.05; // stopping tolerance
    param.cache_size = 800; // kernel cache in MB; larger trains faster
    // Equal weights (multipliers of C) for both classes.
    param.nr_weight = 2;
    param.weight_label = new int[] {1, -1};
    param.weight = new double[] {1, 1};
    param.nu = 0.1; // nu-SVC / nu-SVR only
    param.p = 0.1; // epsilon-SVR only
    param.shrinking = 1;
    param.probability = 0;
    return param;
  }
Ejemplo n.º 11
0
 /**
  * Returns a fresh parameter object for an RBF C-SVC.
  *
  * <p>Settings: C = 1, gamma = 0 (placeholder meaning 1/num_features), tolerance 1e-3, a 100 MB
  * cache, shrinking on, probability estimates off, and no class weights. Callers may mutate the
  * result freely.
  *
  * @return a newly allocated {@code svm_parameter}
  */
 public svm_parameter getDefaultSvmParameters() {
   final svm_parameter defaults = new svm_parameter();

   // Formulation.
   defaults.svm_type = svm_parameter.C_SVC;
   defaults.C = 1;
   defaults.nu = 0.5;
   defaults.p = 0.1;

   // Kernel.
   defaults.kernel_type = svm_parameter.RBF;
   defaults.gamma = 0; // 1/num_features
   defaults.degree = 3;
   defaults.coef0 = 0;

   // Solver.
   defaults.eps = 1e-3;
   defaults.cache_size = 100;
   defaults.shrinking = 1;
   defaults.probability = 0;

   // Class weights: none.
   defaults.nr_weight = 0;
   defaults.weight_label = new int[0];
   defaults.weight = new double[0];

   return defaults;
 }
Ejemplo n.º 12
0
  /**
   * Grid-searches C = 2^log2c and gamma = 2^log2g by cross-validation and stores the best pair.
   *
   * <p>Each grid point is cross-validated ({@code nrFold} folds) on a pool of
   * {@code numberOfThreads} threads, and progress is logged roughly once a second. On success the
   * {@code c} and {@code gamma} fields are overwritten with the best pair found. If no result was
   * collected, they are left unchanged and an error is logged.
   *
   * @param svmProblem the (possibly sampled) problem used for cross-validation
   * @throws IllegalArgumentException if a grid step is zero/NaN while its range is non-empty
   */
  private void do_find_best_parameters(svm_problem svmProblem) {
    svm_parameter svmParam = getDefaultSvmParameters();
    setWeights(svmParam);

    // Grid points are computed by index rather than by repeatedly adding the step, so the loops
    // always terminate and do not drift through floating-point accumulation.
    double[] log2cs = gridPoints(log2cBegin, log2cEnd, log2cStep);
    double[] log2gs = gridPoints(log2gBegin, log2gEnd, log2gStep);
    int maxIter = log2cs.length * log2gs.length;

    // The validators report into this list from pool threads (its size drives the progress
    // report), while this thread reads and iterates it. It must therefore tolerate concurrent
    // add + iteration, which a plain ArrayList does not.
    List<CvParams> cvParamsList = new java.util.concurrent.CopyOnWriteArrayList<CvParams>();

    // Run the grid search in separate CV threads
    ExecutorService executorService = Executors.newFixedThreadPool(numberOfThreads);
    for (double log2c : log2cs) {
      double c1 = Math.pow(2, log2c);
      for (double log2g : log2gs) {
        svm_parameter svmParam1 = (svm_parameter) svmParam.clone();
        svmParam1.C = c1;
        svmParam1.gamma = Math.pow(2, log2g);

        executorService.execute(
            new RunnableSvmCrossValidator(svmProblem, svmParam1, nrFold, cvParamsList));
      }
    }

    // Stop accepting tasks; already-submitted tasks keep running to completion.
    executorService.shutdown();

    boolean interrupted = false;
    while (!executorService.isTerminated()) {
      try {
        executorService.awaitTermination(1, java.util.concurrent.TimeUnit.SECONDS);
      } catch (InterruptedException e) {
        // Keep waiting for the search to finish, but remember to restore the interrupt status.
        interrupted = true;
      }

      // every second, report statistics
      logger.debug(
          String.format("%% complete: %5.2f", cvParamsList.size() / (double) maxIter * 100));
      CvParams best = getBestCvParams(cvParamsList);
      CvParams worst = getWorstcvParams(cvParamsList);
      if (best != null) {
        logger.debug("Best accuracy: " + best.accuracy);
        logger.debug("Best C:        " + best.c);
        logger.debug("Best Gamma:    " + best.gamma);
      }
      if (worst != null) {
        logger.debug("Worst accuracy: " + worst.accuracy);
      }
    }
    if (interrupted) {
      Thread.currentThread().interrupt();
    }

    CvParams best = getBestCvParams(cvParamsList);
    CvParams worst = getWorstcvParams(cvParamsList);
    if (best != null) {
      logger.debug("Best accuracy: " + best.accuracy);
      logger.debug("Best C:        " + best.c);
      logger.debug("Best Gamma:    " + best.gamma);

      c = best.c;
      gamma = best.gamma;
    } else {
      logger.error("Best CV parameters is null.");
    }
    if (worst != null) {
      logger.debug("Worst accuracy: " + worst.accuracy);
    }
  }

  /**
   * Returns the points begin, begin ± |step|, ... moving toward end, up to and including end
   * (within a small rounding tolerance). The sign of {@code step} is ignored.
   */
  private static double[] gridPoints(double begin, double end, double step) {
    if (begin == end) {
      return new double[] {begin};
    }
    if (step == 0 || Double.isNaN(step)) {
      throw new IllegalArgumentException(
          "Grid step must be non-zero for range [" + begin + ", " + end + "]: " + step);
    }
    double magnitude = Math.abs(step);
    double direction = end > begin ? 1 : -1;
    int count = (int) Math.floor(Math.abs(end - begin) / magnitude + 1e-9) + 1;
    double[] points = new double[count];
    for (int i = 0; i < count; i++) {
      points[i] = begin + direction * magnitude * i;
    }
    return points;
  }