Code example #1
  /** compute the gradient with respect to each beta (kernel weight) */
  private double[] computeGradBeta(
      ArrayList<SimpleCacheKernel<T>> kernels, List<TrainingSample<T>> l) {
    double grad[] = new double[kernels.size()];

    for (int i = 0; i < kernels.size(); i++) {
      double matrix[][] = kernels.get(i).getKernelMatrix(l);
      double a[] = svm.getAlphas();

      for (int x = 0; x < matrix.length; x++) {
        int l1 = l.get(x).label;
        for (int y = 0; y < matrix.length; y++) {
          int l2 = l.get(y).label;
          grad[i] += 0.5 * l1 * l2 * a[x] * a[y] * matrix[x][y];
        }
      }
    }

    debug.print(3, "gradDir : " + Arrays.toString(grad));

    return grad;
  }
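
For reference, the quantity this double loop accumulates for each base kernel K_i (my reading of the code above, written out as a formula rather than quoted from the library's documentation) is

  grad_i = (1/2) * Σ_{x,y} α_x α_y y_x y_y K_i(x, y)

i.e. one half of the quadratic form α^T diag(y) K_i diag(y) α on the trace-normalized kernel matrix; train() below uses these per-kernel values as the gradient direction for its weight update.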
Code example #2
  @Override
  public void train(List<TrainingSample<T>> l) {

    long tim = System.currentTimeMillis();
    debug.println(
        2, "training on " + listOfKernels.size() + " kernels and " + l.size() + " examples");

    // 1. init kernels
    ArrayList<SimpleCacheKernel<T>> kernels = new ArrayList<SimpleCacheKernel<T>>();
    ArrayList<Double> weights = new ArrayList<Double>();

    // normalize each kernel to constant trace and init weights to (1/N)^(1/p_norm)
    for (int i = 0; i < listOfKernels.size(); i++) {
      SimpleCacheKernel<T> sck = new SimpleCacheKernel<T>(listOfKernels.get(i), l);
      sck.setName(listOfKernels.get(i).toString());
      double[][] matrix = sck.getKernelMatrix(l);
      // compute trace
      double trace = 0.;
      for (int x = 0; x < matrix.length; x++) {
        trace += matrix[x][x];
      }
      // divide by trace
      for (int x = 0; x < matrix.length; x++)
        for (int y = x; y < matrix.length; y++) {
          matrix[x][y] *= matrix.length / (double) trace;
          matrix[y][x] = matrix[x][y];
        }
      kernels.add(sck);
      weights.add(Math.pow(1 / (double) listOfKernels.size(), 1 / (double) p_norm));
      debug.println(3, "kernel : " + sck + " weight : " + weights.get(i));
    }

    // train the first svm on the uniformly weighted sum kernel
    ThreadedSumKernel<T> tsk = new ThreadedSumKernel<T>();
    for (int i = 0; i < kernels.size(); i++) tsk.addKernel(kernels.get(i), weights.get(i));
    if (svm == null) svm = new SMOSVM<T>(null);
    svm.setKernel(tsk);
    svm.setC(C);
    svm.train(l);

    // 2. big loop
    double gap = 0;
    long max = 100000; // at most 100,000 iterations
    do {
      debug.println(3, "weights : " + weights);
      // compute sum kernel
      tsk = new ThreadedSumKernel<T>();
      for (int i = 0; i < kernels.size(); i++) tsk.addKernel(kernels.get(i), weights.get(i));

      // train svm
      svm.setKernel(tsk);
      svm.train(l);

      // compute sum of example weights and gradient direction
      double suma = computeSumAlpha();
      double[] grad = computeGradBeta(kernels, l);

      // perform one step
      double objEvol = performMKLStep(suma, grad, kernels, weights, l);

      if (objEvol < 0) {
        debug.println(1, "Error, performMKLStep return wrong value");
        System.exit(0);
        ;
      }
      gap = 1 - objEvol;

      // compute norm
      double norm = 0;
      for (int i = 0; i < weights.size(); i++) norm += Math.pow(weights.get(i), p_norm);
      norm = Math.pow(norm, -1 / (double) p_norm);

      debug.println(1, "objective_gap : " + gap + " norm : " + norm);
      max--;
    } while (gap >= stopGap && max > 0);

    // 3. save weights
    listOfKernelWeights.clear();
    listOfKernelWeights.addAll(weights);

    // 4. retrain svm
    // compute sum kernel
    tsk = new ThreadedSumKernel<T>();
    for (int i = 0; i < kernels.size(); i++)
      tsk.addKernel(listOfKernels.get(i), listOfKernelWeights.get(i));
    // train svm
    svm.setKernel(tsk);
    svm.train(l);

    // 5. save example weights
    listOfExamples = new ArrayList<TrainingSample<T>>();
    listOfExamples.addAll(l);
    listOfExampleWeights.clear();
    for (double d : svm.getAlphas()) listOfExampleWeights.add(d);

    debug.println(1, "MKL trained in " + (System.currentTimeMillis() - tim) + " milis.");
  }
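
To see how these pieces fit together end to end, here is a minimal usage sketch. It assumes the snippets come from the GradMKL classifier of jkernelmachines (the fr.lip6.jkernelmachines package referenced in the @see tag of code example #4), and that the class exposes addKernel(...) and setC(...) to populate listOfKernels and C; the class name, the kernel classes, and the constructor signatures are assumptions on my part, not something shown in this listing.

  import java.util.ArrayList;
  import java.util.List;
  import java.util.Random;

  import fr.lip6.jkernelmachines.classifier.GradMKL;
  import fr.lip6.jkernelmachines.kernel.typed.DoubleGaussL2;
  import fr.lip6.jkernelmachines.kernel.typed.DoubleLinear;
  import fr.lip6.jkernelmachines.type.TrainingSample;

  public class MKLUsageSketch {

    public static void main(String[] args) {
      // toy 2D data: positives around (1,1), negatives around (-1,-1)
      List<TrainingSample<double[]>> train = new ArrayList<TrainingSample<double[]>>();
      Random rand = new Random(42);
      for (int i = 0; i < 50; i++) {
        train.add(new TrainingSample<double[]>(
            new double[] {1 + rand.nextGaussian(), 1 + rand.nextGaussian()}, +1));
        train.add(new TrainingSample<double[]>(
            new double[] {-1 + rand.nextGaussian(), -1 + rand.nextGaussian()}, -1));
      }

      // MKL over two base kernels; addKernel/setC and the kernel classes are assumed
      // to match the jkernelmachines API the snippets above belong to
      GradMKL<double[]> mkl = new GradMKL<double[]>();
      mkl.addKernel(new DoubleLinear());
      mkl.addKernel(new DoubleGaussL2(1.0));
      mkl.setC(10.0);

      // runs the train() shown above: trace-normalize kernels, init uniform weights,
      // then alternate SVM training and MKL weight steps until the gap is small
      mkl.train(train);

      // valueOf() delegates to the SVM trained on the learned kernel combination
      System.out.println("f(+1,+1) = " + mkl.valueOf(new double[] {1.0, 1.0}));
      System.out.println("f(-1,-1) = " + mkl.valueOf(new double[] {-1.0, -1.0}));
    }
  }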
Code example #3
 @Override
 public Kernel<T> getKernel() {
   return svm.getKernel();
 }
Code example #4
 /* (non-Javadoc)
  * @see fr.lip6.jkernelmachines.classifier.KernelSVM#getAlphas()
  */
 @Override
 public double[] getAlphas() {
   return svm.getAlphas();
 }
Code example #5
  @Override
  public double valueOf(T e) {

    return svm.valueOf(e);
  }
Code example #6
 /** compute the sum of the absolute values of the example weights (alphas) */
 private double computeSumAlpha() {
   double sum = 0;
   double[] a = svm.getAlphas();
   for (double d : a) sum += Math.abs(d);
   return sum;
 }
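
Read against the training loop in code example #2, this helper simply returns S = Σ_x |α_x| over the coefficients of the currently trained SVM; train() hands that sum, together with the per-kernel gradients from computeGradBeta, to performMKLStep (not shown in this listing).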