コード例 #1
0
ファイル: LibSvmTrainer.java プロジェクト: rlxrlxrlx/nakala
  /**
   * Trains a LIBSVM model on the data read from {@code cr}.
   *
   * <p>If {@code findBestParameters} is set, a grid search is run first. The search uses a
   * sample of the data when {@code sample < 1}; otherwise, or if sampling yields {@code null}, it
   * uses the full problem. The final model is always trained on the full problem with probability
   * estimates enabled. It uses the values of {@code c} and {@code gamma} current at that point;
   * presumably these are updated by the grid search (verify in {@code do_find_best_parameters}).
   *
   * @param cr reader supplying the training data (consumed by {@code loadData})
   * @return a collector wrapping a single {@code LibSvmTrainerAnalysis} with the trained model,
   *     the scaler, the label list and the C/gamma values used
   * @throws AnalyzerFailureException as declared; presumably propagated from the loading/search
   *     helpers (not visible here)
   */
  @Override
  public AnalysisCollector analyze(CollectionReader cr) throws AnalyzerFailureException {
    final svm_problem trainingProblem = loadData(cr);

    if (findBestParameters) {
      svm_problem searchProblem = null;
      if (sample < 1d) {
        logger.debug("Sampling.");
        searchProblem = do_sample(trainingProblem);
      }
      if (searchProblem == null) {
        // No sampling requested (or the sampler returned nothing): search on all the data.
        searchProblem = trainingProblem;
      }
      logger.debug("Performing grid search.");
      do_find_best_parameters(searchProblem);
    }

    // Start from the defaults, then apply the (possibly tuned) hyper-parameters.
    final svm_parameter params = getDefaultSvmParameters();
    params.C = c;
    params.gamma = gamma;
    params.probability = 1;
    setWeights(params);

    logger.debug("Training with C=" + c + "  gamma=" + gamma);
    final svm_model trainedModel = svm.svm_train(trainingProblem, params);
    logger.debug("Done!");

    final LibSvmTrainerAnalysis analysis =
        new LibSvmTrainerAnalysis(trainedModel, scaler, labelList, c, gamma);
    return new SingletonAnalysisCollector(analysis);
  }
コード例 #2
0
ファイル: LibSvmTrainer.java プロジェクト: rlxrlxrlx/nakala
 /**
  * Creates a fresh {@code svm_parameter} populated with this trainer's baseline settings.
  *
  * <p>Configuration: C-SVC with an RBF kernel, C = 1, no per-class weights, shrinking on and
  * probability estimates off. Callers may overwrite any field; {@code analyze} sets
  * {@code probability}, {@code C} and {@code gamma} afterwards.
  *
  * @return a newly allocated parameter object (never shared between calls)
  */
 public svm_parameter getDefaultSvmParameters() {
   final svm_parameter defaults = new svm_parameter();

   // Problem formulation and kernel shape
   defaults.svm_type = svm_parameter.C_SVC;
   defaults.kernel_type = svm_parameter.RBF;
   defaults.degree = 3;
   // NOTE(review): 0 stands for "1/num_features" by the LIBSVM svm-train tool's convention;
   // nothing here resolves it, so confirm callers always set gamma before training.
   defaults.gamma = 0;
   defaults.coef0 = 0;

   // Cost / loss parameters
   defaults.C = 1;
   defaults.nu = 0.5;
   defaults.p = 0.1;

   // Solver behaviour
   defaults.cache_size = 100;
   defaults.eps = 1e-3;
   defaults.shrinking = 1;
   defaults.probability = 0;

   // No class weighting: empty (non-null) weight arrays
   defaults.nr_weight = 0;
   defaults.weight_label = new int[0];
   defaults.weight = new double[0];

   return defaults;
 }
コード例 #3
0
ファイル: svmRegression.java プロジェクト: Navieclipse/KEEL
  /**
   * Process the training and test files provided in the parameters file to the constructor.
   *
   * <p>Steps:
   *
   * <ol>
   *   <li>Train an epsilon-SVR model on the training set. The first output attribute is the
   *       regression target.
   *   <li>Write (actual, predicted) output pairs for the training set to
   *       {@code output_train_name}.
   *   <li>Do the same for the test set, writing to {@code output_test_name}.
   * </ol>
   *
   * <p>Any exception while reading data or predicting, and any invalid SVM configuration, is
   * reported on standard output and terminates the JVM with status -1.
   */
  public void process() {
    svm_parameter svmParam = buildSvmParameters();
    svm_model model = null;

    try {
      // Load the training set (a regression problem) into memory
      IS.readSet(input_train_name, true);
      refreshDatasetInfo();

      svm_problem problem = new svm_problem();
      problem.l = ndatos;
      problem.y = new double[ndatos];
      problem.x = new svm_node[ndatos][];
      for (int i = 0; i < ndatos; i++) {
        Instance inst = IS.getInstance(i);
        problem.y[i] = inst.getAllOutputValues()[0];
        problem.x[i] = toSvmNodes(inst, nentradas);
      }

      // Validate once and reuse the message rather than running the checker twice.
      String paramError = svm.svm_check_parameter(problem, svmParam);
      if (paramError != null) {
        System.out.println("SVM parameter error in training:");
        System.out.println(paramError);
        System.exit(-1);
      }
      // Train the SVR; with an empty training set no model is built (model stays null).
      if (ndatos > 0) {
        model = svm.svm_train(problem, svmParam);
      }
      predictLoadedSet(model);
    } catch (Exception e) {
      System.out.println("Dataset exception = " + e);
      e.printStackTrace();
      System.exit(-1);
    }
    write_results(output_train_name);

    try {
      // Load the test set and predict it with the model trained above.
      IS.readSet(input_test_name, false);
      refreshDatasetInfo();
      // NOTE(review): if the training set was empty, model is null and svm_predict presumably
      // throws on the first test instance; that is reported by the handler below.
      predictLoadedSet(model);
    } catch (Exception e) {
      System.out.println("Dataset exception = " + e);
      e.printStackTrace();
      System.exit(-1);
    }
    System.out.println("escribiendo test");
    write_results(output_test_name);
  }

  /** Builds the epsilon-SVR configuration from this object's hyper-parameter fields. */
  private svm_parameter buildSvmParameters() {
    svm_parameter param = new svm_parameter();
    param.svm_type = svm_parameter.EPSILON_SVR;
    param.C = C;
    param.cache_size = 10; // 10MB of cache
    param.degree = degree;
    param.eps = eps;
    param.gamma = gamma;
    param.nr_weight = 0;
    param.nu = nu;
    param.p = p;
    param.shrinking = shrinking;
    param.probability = 0;
    // An unrecognised kernel name leaves kernel_type at its initial value (0, which LIBSVM
    // defines as LINEAR — NOTE(review): confirm against the bundled libsvm copy).
    if (kernelType.equals("LINEAR")) {
      param.kernel_type = svm_parameter.LINEAR;
    } else if (kernelType.equals("POLY")) {
      param.kernel_type = svm_parameter.POLY;
    } else if (kernelType.equals("RBF")) {
      param.kernel_type = svm_parameter.RBF;
    } else if (kernelType.equals("SIGMOID")) {
      param.kernel_type = svm_parameter.SIGMOID;
    }
    return param;
  }

  /** Refreshes the dataset-size fields and output buffers after {@code IS.readSet}. */
  private void refreshDatasetInfo() {
    ndatos = IS.getNumInstances();
    nvariables = Attributes.getNumAttributes();
    nentradas = Attributes.getInputNumAttributes();
    nsalidas = Attributes.getOutputNumAttributes();
    X = new String[ndatos][2]; // matrix with transformed data
    mostCommon = new String[nvariables];
  }

  /**
   * Stores, in {@code X}, the actual and predicted output of every instance in {@code IS}.
   * Row i holds (actual, predicted) for instance i.
   */
  private void predictLoadedSet(svm_model model) {
    for (int i = 0; i < ndatos; i++) {
      Instance inst = IS.getInstance(i);
      X[i][0] = String.valueOf(inst.getAllOutputValues()[0]);
      X[i][1] = String.valueOf(svm.svm_predict(model, toSvmNodes(inst, nentradas)));
    }
  }

  /** Converts an instance's input attributes into a LIBSVM node vector terminated by index -1. */
  private static svm_node[] toSvmNodes(Instance inst, int numInputs) {
    double[] inputs = inst.getAllInputValues();
    svm_node[] nodes = new svm_node[numInputs + 1];
    // NOTE(review): feature indices start at 0 (LIBSVM's file format is 1-based). This is used
    // identically for training and prediction, so it is consistent within this class.
    for (int n = 0; n < numInputs; n++) {
      nodes[n] = new svm_node();
      nodes[n].index = n;
      nodes[n].value = inputs[n];
    }
    nodes[numInputs] = new svm_node();
    nodes[numInputs].index = -1;
    return nodes;
  }