Example #1
  @Override
  public AnalysisCollector analyze(CollectionReader cr) throws AnalyzerFailureException {
    svm_problem svmProblem = loadData(cr);
    svm_problem sampled = null;
    if (findBestParameters) {
      if (sample < 1d) {
        logger.debug("Sampling.");
        sampled = do_sample(svmProblem);
      }
      logger.debug("Performing grid search.");
      do_find_best_parameters(sampled != null ? sampled : svmProblem);
    }
    svm_parameter svmParam = getDefaultSvmParameters();
    svmParam.probability = 1;
    svmParam.C = c;
    svmParam.gamma = gamma;
    setWeights(svmParam);

    logger.debug("Training with C=" + c + "  gamma=" + gamma);
    svm_model model = svm.svm_train(svmProblem, svmParam);

    logger.debug("Done!");
    return new SingletonAnalysisCollector(
        new LibSvmTrainerAnalysis(model, scaler, labelList, c, gamma));
  }
Example #2
 private void setWeights(svm_parameter svmParam) {
   if (weights != null) {
     svmParam.nr_weight = weights.size();
     svmParam.weight_label = new int[weights.size()];
     svmParam.weight = new double[weights.size()];
     for (int i = 0; i < weights.size(); ++i) {
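       // Assumes class labels are 0..n-1, so the list index doubles as the libsvm weight_label.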
       svmParam.weight_label[i] = i;
       svmParam.weight[i] = weights.get(i);
     }
     logger.debug("Class weights: " + weights);
   }
 }
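In libsvm, weight_label[i] names a class and weight[i] multiplies the penalty C for that class (the same mechanism as the command-line -wi option), which is how imbalanced classes receive a larger penalty. As a minimal sketch of how such a list might be produced, here is a hypothetical inverse-frequency helper; classCounts and the method name are illustrative assumptions, not part of the code above:

 // Hypothetical helper: inverse-frequency weights, assuming class labels 0..n-1
 // and classCounts[k] = number of training instances of class k.
 private List<Double> inverseFrequencyWeights(int[] classCounts) {
   int total = 0;
   for (int count : classCounts) {
     total += count;
   }
   List<Double> weights = new ArrayList<Double>();
   for (int count : classCounts) {
     // a perfectly balanced class gets weight 1.0; rarer classes get proportionally more
     weights.add(total / (double) (classCounts.length * count));
   }
   return weights;
 }

The resulting list could then be stored in the weights field that setWeights above consumes.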
Example #3
  /** Process the training and test files provided in the parameters file to the constructor. */
  public void process() {
    svm_problem SVMp = null;
    svm_parameter SVMparam = new svm_parameter();
    svm_model svr = null;
    svm_node[] SVMn;

    // SVM PARAMETERS
    SVMparam.C = C;
    SVMparam.cache_size = 10; // 10MB of cache
    SVMparam.degree = degree;
    SVMparam.eps = eps;
    SVMparam.gamma = gamma;
    SVMparam.nr_weight = 0;
    SVMparam.nu = nu;
    SVMparam.p = p;
    SVMparam.shrinking = shrinking;
    SVMparam.probability = 0;
    if (kernelType.compareTo("LINEAR") == 0) {
      SVMparam.kernel_type = svm_parameter.LINEAR;
    } else if (kernelType.compareTo("POLY") == 0) {
      SVMparam.kernel_type = svm_parameter.POLY;
    } else if (kernelType.compareTo("RBF") == 0) {
      SVMparam.kernel_type = svm_parameter.RBF;
    } else if (kernelType.compareTo("SIGMOID") == 0) {
      SVMparam.kernel_type = svm_parameter.SIGMOID;
    }

    SVMparam.svm_type = svm_parameter.EPSILON_SVR;

    try {

      // Load the training set into memory; the single output attribute is the
      // regression target for the epsilon-SVR
      IS.readSet(input_train_name, true);

      ndatos = IS.getNumInstances();
      nvariables = Attributes.getNumAttributes();
      nentradas = Attributes.getInputNumAttributes();
      nsalidas = Attributes.getOutputNumAttributes();

      X = new String[ndatos][2]; // matrix with transformed data

      mostCommon = new String[nvariables];
      SVMp = new svm_problem();
      SVMp.l = ndatos;
      SVMp.y = new double[SVMp.l];
      SVMp.x = new svm_node[SVMp.l][nentradas + 1];
      for (int l = 0; l < SVMp.l; l++) {
        for (int n = 0; n < Attributes.getInputNumAttributes() + 1; n++) {
          SVMp.x[l][n] = new svm_node();
        }
      }

      for (int i = 0; i < ndatos; i++) {
        Instance inst = IS.getInstance(i);

        SVMp.y[i] = inst.getAllOutputValues()[0];
        for (int n = 0; n < Attributes.getInputNumAttributes(); n++) {
          SVMp.x[i][n].index = n;
          SVMp.x[i][n].value = inst.getAllInputValues()[n];
        }
        // end of instance
        SVMp.x[i][nentradas].index = -1;
      }
      String paramError = svm.svm_check_parameter(SVMp, SVMparam);
      if (paramError != null) {
        System.out.println("SVM parameter error in training:");
        System.out.println(paramError);
        System.exit(-1);
      }
      // train the SVM
      if (ndatos > 0) {
        svr = svm.svm_train(SVMp, SVMparam);
      }
      for (int i = 0; i < ndatos; i++) {
        Instance inst = IS.getInstance(i);
        X[i][0] = String.valueOf(inst.getAllOutputValues()[0]);
        // the values used for regression
        SVMn = new svm_node[Attributes.getInputNumAttributes() + 1];
        for (int n = 0; n < Attributes.getInputNumAttributes(); n++) {
          SVMn[n] = new svm_node();
          SVMn[n].index = n;
          SVMn[n].value = inst.getAllInputValues()[n];
        }
        SVMn[nentradas] = new svm_node();
        SVMn[nentradas].index = -1;
        // predict the output value
        X[i][1] = String.valueOf(svm.svm_predict(svr, SVMn));
      }
    } catch (Exception e) {
      System.out.println("Dataset exception = " + e);
      e.printStackTrace();
      System.exit(-1);
    }
    write_results(output_train_name);
    /* ********************************************************************************** */
    try {

      // Load the test set into memory
      IS.readSet(input_test_name, false);

      ndatos = IS.getNumInstances();
      nvariables = Attributes.getNumAttributes();
      nentradas = Attributes.getInputNumAttributes();
      nsalidas = Attributes.getOutputNumAttributes();

      X = new String[ndatos][2]; // matrix with transformed data

      mostCommon = new String[nvariables];

      for (int i = 0; i < ndatos; i++) {
        Instance inst = IS.getInstance(i);
        X[i][0] = new String(String.valueOf(inst.getAllOutputValues()[0]));

        SVMn = new svm_node[Attributes.getInputNumAttributes() + 1];
        for (int n = 0; n < Attributes.getInputNumAttributes(); n++) {
          SVMn[n] = new svm_node();
          SVMn[n].index = n;
          SVMn[n].value = inst.getAllInputValues()[n];
        }
        SVMn[nentradas] = new svm_node();
        SVMn[nentradas].index = -1;
        // pedict the class
        X[i][1] = new String(String.valueOf(svm.svm_predict(svr, SVMn)));
      }
    } catch (Exception e) {
      System.out.println("Dataset exception = " + e);
      e.printStackTrace();
      System.exit(-1);
    }
    System.out.println("escribiendo test");
    write_results(output_test_name);
  }
Example #4
File: svm.java Project: CPernet/CanlabCore
	public static svm_model svm_load_model(String model_file_name) throws IOException
	{
		BufferedReader fp = new BufferedReader(new FileReader(model_file_name));

		// read parameters

		svm_model model = new svm_model();
		svm_parameter param = new svm_parameter();
		model.param = param;
		model.label = null;
		model.nSV = null;

		while(true)
		{
			String cmd = fp.readLine();
			String arg = cmd.substring(cmd.indexOf(' ')+1);

			if(cmd.startsWith("svm_type"))
			{
				int i;
				for(i=0;i<svm_type_table.length;i++)
				{
					if(arg.indexOf(svm_type_table[i])!=-1)
					{
						param.svm_type=i;
						break;
					}
				}
				if(i == svm_type_table.length)
				{
					System.err.print("unknown svm type.\n");
					System.exit(1);
				}
			}
			else if(cmd.startsWith("kernel_type"))
			{
				int i;
				for(i=0;i<kernel_type_table.length;i++)
				{
					if(arg.indexOf(kernel_type_table[i])!=-1)
					{
						param.kernel_type=i;
						break;
					}
				}
				if(i == kernel_type_table.length)
				{
					System.err.print("unknown kernel function.\n");
					System.exit(1);
				}
			}
			else if(cmd.startsWith("degree"))
				param.degree = atof(arg);
			else if(cmd.startsWith("gamma"))
				param.gamma = atof(arg);
			else if(cmd.startsWith("coef0"))
				param.coef0 = atof(arg);
			else if(cmd.startsWith("nr_class"))
				model.nr_class = atoi(arg);
			else if(cmd.startsWith("total_sv"))
				model.l = atoi(arg);
			else if(cmd.startsWith("rho"))
			{
				int n = model.nr_class * (model.nr_class-1)/2;
				model.rho = new double[n];
				StringTokenizer st = new StringTokenizer(arg);
				for(int i=0;i<n;i++)
					model.rho[i] = atof(st.nextToken());
			}
			else if(cmd.startsWith("label"))
			{
				int n = model.nr_class;
				model.label = new int[n];
				StringTokenizer st = new StringTokenizer(arg);
				for(int i=0;i<n;i++)
					model.label[i] = atoi(st.nextToken());					
			}
			else if(cmd.startsWith("nr_sv"))
			{
				int n = model.nr_class;
				model.nSV = new int[n];
				StringTokenizer st = new StringTokenizer(arg);
				for(int i=0;i<n;i++)
					model.nSV[i] = atoi(st.nextToken());
			}
			else if(cmd.startsWith("SV"))
			{
				break;
			}
			else
			{
				System.err.print("unknown text in model file\n");
				System.exit(1);
			}
		}

		// read sv_coef and SV

		int m = model.nr_class - 1;
		int l = model.l;
		model.sv_coef = new double[m][l];
		model.SV = new svm_node[l][];

		for(int i=0;i<l;i++)
		{
			String line = fp.readLine();
			StringTokenizer st = new StringTokenizer(line," \t\n\r\f:");

			for(int k=0;k<m;k++)
				model.sv_coef[k][i] = atof(st.nextToken());
			int n = st.countTokens()/2;
			model.SV[i] = new svm_node[n];
			for(int j=0;j<n;j++)
			{
				model.SV[i][j] = new svm_node();
				model.SV[i][j].index = atoi(st.nextToken());
				model.SV[i][j].value = atof(st.nextToken());
			}
		}

		fp.close();
		return model;
	}
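For context, a minimal usage sketch of the loader above: it reads a model file and scores one feature vector with svm.svm_predict. predictOne and its arguments are illustrative placeholders; only svm_load_model, svm_node, and svm_predict are the libsvm calls shown in these examples.

	// Hypothetical caller (assumed context): not part of the file above.
	public static double predictOne(String modelFile, double[] features) throws IOException
	{
		svm_model model = svm.svm_load_model(modelFile);

		// pack the dense feature vector into libsvm's sparse node format
		svm_node[] x = new svm_node[features.length];
		for(int i=0;i<features.length;i++)
		{
			x[i] = new svm_node();
			x[i].index = i+1;	// libsvm feature indices are conventionally 1-based
			x[i].value = features[i];
		}

		// returns the predicted label for classification models, or the regression value for SVR
		return svm.svm_predict(model, x);
	}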
Example #5
 public svm_parameter getDefaultSvmParameters() {
   svm_parameter param = new svm_parameter();
   // default values
   param.svm_type = svm_parameter.C_SVC;
   param.kernel_type = svm_parameter.RBF;
   param.degree = 3;
   param.gamma = 0; // 1/num_features
   param.coef0 = 0;
   param.nu = 0.5;
   param.cache_size = 100;
   param.C = 1;
   param.eps = 1e-3;
   param.p = 0.1;
   param.shrinking = 1;
   param.probability = 0;
   param.nr_weight = 0;
   param.weight_label = new int[0];
   param.weight = new double[0];
   return param;
 }
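Note that gamma is only a placeholder here: the comment records the conventional default of 1/num_features, but the library's svm_train does not fill it in for you, so callers are expected to overwrite it before training (Example #1 does). A minimal sketch of the usual substitution, where numFeatures and problem are assumed variables:

 // Minimal sketch (assumed context): fill in gamma before training,
 // mirroring the rule applied by the libsvm command-line tool.
 svm_parameter param = getDefaultSvmParameters();
 if (param.gamma == 0 && numFeatures > 0) {
   param.gamma = 1.0 / numFeatures; // numFeatures = number of input features (assumed)
 }
 svm_model model = svm.svm_train(problem, param); // problem: an svm_problem built elsewhere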
Example #6
  private void do_find_best_parameters(svm_problem svmProblem) {
    svm_parameter svmParam = getDefaultSvmParameters();
    setWeights(svmParam);

    int maxIter =
        ((int) Math.ceil(Math.abs((log2cEnd - log2cBegin) / log2cStep)) + 1)
            * ((int) Math.ceil(Math.abs((log2gEnd - log2gBegin) / log2gStep)) + 1);

    // Run the grid search in separate CV threads
    ExecutorService executorService = Executors.newFixedThreadPool(numberOfThreads);

    // Results are appended by the worker threads, so the list must be thread-safe.
    List<CvParams> cvParamsList = Collections.synchronizedList(new ArrayList<CvParams>());

    for (double log2c = log2cBegin;
        (log2cBegin < log2cEnd && log2c <= log2cEnd)
            || (log2cBegin >= log2cEnd && log2c >= log2cEnd);
        log2c += log2cStep) {

      double c1 = Math.pow(2, log2c);

      for (double log2g = log2gBegin;
          (log2gBegin < log2gEnd && log2g <= log2gEnd)
              || (log2gBegin >= log2gEnd && log2g >= log2gEnd);
          log2g += log2gStep) {

        double gamma1 = Math.pow(2, log2g);

        svm_parameter svmParam1 = (svm_parameter) svmParam.clone();
        svmParam1.C = c1;
        svmParam1.gamma = gamma1;

        executorService.execute(
            new RunnableSvmCrossValidator(svmProblem, svmParam1, nrFold, cvParamsList));
      }
    }

    // Wait for all tasks to finish: shutdown() does NOT interrupt the running threads,
    // it just tells the pool to stop accepting new work and to terminate once all
    // queued work is completed. The polling loop below logs progress while we wait.
    executorService.shutdown();

    while (!executorService.isTerminated()) {
      try {
        Thread.sleep(1000);
      } catch (InterruptedException e) {
        // don't care if we get interrupted
      }

      // every second, report statistics
      logger.debug(
          String.format("%% complete: %5.2f", cvParamsList.size() / (double) maxIter * 100));
      CvParams best = getBestCvParams(cvParamsList);
      CvParams worst = getWorstcvParams(cvParamsList);
      if (best != null) {
        logger.debug("Best accuracy: " + best.accuracy);
        logger.debug("Best C:        " + best.c);
        logger.debug("Best Gamma:    " + best.gamma);
      }
      if (worst != null) {
        logger.debug("Worst accuracy: " + worst.accuracy);
      }
    }

    CvParams best = getBestCvParams(cvParamsList);
    CvParams worst = getWorstcvParams(cvParamsList);
    if (best != null) {
      logger.debug("Best accuracy: " + best.accuracy);
      logger.debug("Best C:        " + best.c);
      logger.debug("Best Gamma:    " + best.gamma);

      c = best.c;
      gamma = best.gamma;
    } else {
      logger.error("Grid search produced no cross-validation results; keeping current C and gamma.");
    }
    if (worst != null) {
      logger.debug("Worst accuracy: " + worst.accuracy);
    }
  }