Code example #1
  @Override
  public void train(List<TrainingSample<T>> l) {

    long tim = System.currentTimeMillis();
    debug.println(
        2, "training on " + listOfKernels.size() + " kernels and " + l.size() + " examples");

    // 1. init kernels
    ArrayList<SimpleCacheKernel<T>> kernels = new ArrayList<SimpleCacheKernel<T>>();
    ArrayList<Double> weights = new ArrayList<Double>();

    // normalize each kernel to constant trace and init weights to (1/N)^(1/p_norm)
    for (int i = 0; i < listOfKernels.size(); i++) {
      SimpleCacheKernel<T> sck = new SimpleCacheKernel<T>(listOfKernels.get(i), l);
      sck.setName(listOfKernels.get(i).toString());
      double[][] matrix = sck.getKernelMatrix(l);
      // compute trace
      double trace = 0.;
      for (int x = 0; x < matrix.length; x++) {
        trace += matrix[x][x];
      }
      // rescale so that the trace equals the number of examples
      for (int x = 0; x < matrix.length; x++)
        for (int y = x; y < matrix.length; y++) {
          matrix[x][y] *= matrix.length / (double) trace;
          matrix[y][x] = matrix[x][y];
        }
      kernels.add(sck);
      weights.add(Math.pow(1 / (double) listOfKernels.size(), 1 / (double) p_norm));
      debug.println(3, "kernel : " + sck + " weight : " + weights.get(i));
    }
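    // at this point every kernel is cached and trace-normalized, and the initial
    // weights satisfy ||w||_p = 1 (N identical entries of (1/N)^(1/p_norm))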

    // train first svm with initial weights
    ThreadedSumKernel<T> tsk = new ThreadedSumKernel<T>();
    for (int i = 0; i < kernels.size(); i++) tsk.addKernel(kernels.get(i), weights.get(i));
    if (svm == null) svm = new SMOSVM<T>(null);
    svm.setKernel(tsk);
    svm.setC(C);
    svm.train(l);

    // 2. big loop
    double gap = 0;
    long max = 100000; // at most 100000 iterations
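    // alternate between training the SVM with the current kernel weights and
    // updating the weights along the gradient of the MKL objective, until the
    // relative improvement of the objective (the gap) falls below stopGap or the
    // iteration budget is exhausted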
    do {
      debug.println(3, "weights : " + weights);
      // compute sum kernel
      tsk = new ThreadedSumKernel<T>();
      for (int i = 0; i < kernels.size(); i++) tsk.addKernel(kernels.get(i), weights.get(i));

      // train svm
      svm.setKernel(tsk);
      svm.train(l);

      // compute sum of example weights and gradient direction
      double suma = computeSumAlpha();
      double[] grad = computeGradBeta(kernels, l);

      // perform one step
      double objEvol = performMKLStep(suma, grad, kernels, weights, l);
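      // performMKLStep presumably updates 'weights' in place and returns the ratio
      // of the new objective value to the old one, so (1 - objEvol) below measures
      // the relative improvement (assumption based on how objEvol is used here)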

      if (objEvol < 0) {
        debug.println(1, "Error: performMKLStep returned an invalid value");
        System.exit(0);
      }
      gap = 1 - objEvol;

      // compute norm
      double norm = 0;
      for (int i = 0; i < weights.size(); i++) norm += Math.pow(weights.get(i), p_norm);
      norm = Math.pow(norm, -1 / (double) p_norm);
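      // with the negative exponent this is 1 / ||w||_p, the reciprocal of the
      // l_p norm of the kernel weights; it is only used for the log line below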

      debug.println(1, "objective_gap : " + gap + " norm : " + norm);
      max--;
    } while (gap >= stopGap && max > 0);

    // 3. save weights
    listOfKernelWeights.clear();
    listOfKernelWeights.addAll(weights);

    // 4. retrain svm
    // compute sum kernel
    tsk = new ThreadedSumKernel<T>();
    for (int i = 0; i < kernels.size(); i++)
      tsk.addKernel(listOfKernels.get(i), listOfKernelWeights.get(i));
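    // note: the final sum kernel is built from the original kernels in listOfKernels,
    // not from the trace-normalized copies used during the optimization above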
    // train svm
    svm.setKernel(tsk);
    svm.train(l);

    // 5. save example weights
    listOfExamples = new ArrayList<TrainingSample<T>>();
    listOfExamples.addAll(l);
    listOfExampleWeights.clear();
    for (double d : svm.getAlphas()) listOfExampleWeights.add(d);

    debug.println(1, "MKL trained in " + (System.currentTimeMillis() - tim) + " milis.");
  }