  /** Computes the gradient of the dual objective with respect to each kernel weight (beta). */
  private double[] computeGradBeta(
      ArrayList<SimpleCacheKernel<T>> kernels, List<TrainingSample<T>> l) {
    double[] grad = new double[kernels.size()];

    for (int i = 0; i < kernels.size(); i++) {
      double[][] matrix = kernels.get(i).getKernelMatrix(l);
      double[] a = svm.getAlphas();
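      // grad_i = 0.5 * sum_{x,y} y_x * y_y * alpha_x * alpha_y * K_i(x, y)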

      for (int x = 0; x < matrix.length; x++) {
        int l1 = l.get(x).label;
        for (int y = 0; y < matrix.length; y++) {
          int l2 = l.get(y).label;
          grad[i] += 0.5 * l1 * l2 * a[x] * a[y] * matrix[x][y];
        }
      }
    }

    debug.print(3, "gradDir : " + Arrays.toString(grad));

    return grad;
  }
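  /**
   * Trains the MKL classifier: each kernel is normalized to constant trace and the
   * weights are initialized uniformly, then SVM training and analytical p-norm weight
   * updates alternate until the relative objective gap drops below stopGap (or the
   * iteration cap is reached). Kernel and example weights are stored at the end.
   */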
  @Override
  public void train(List<TrainingSample<T>> l) {

    long tim = System.currentTimeMillis();
    debug.println(
        2, "training on " + listOfKernels.size() + " kernels and " + l.size() + " examples");

    // 1. init kernels
    ArrayList<SimpleCacheKernel<T>> kernels = new ArrayList<SimpleCacheKernel<T>>();
    ArrayList<Double> weights = new ArrayList<Double>();

    // normalize each kernel to constant trace and init weights to (1/N)^(1/p_norm)
    for (int i = 0; i < listOfKernels.size(); i++) {
      SimpleCacheKernel<T> sck = new SimpleCacheKernel<T>(listOfKernels.get(i), l);
      sck.setName(listOfKernels.get(i).toString());
      double[][] matrix = sck.getKernelMatrix(l);
      // compute trace
      double trace = 0.;
      for (int x = 0; x < matrix.length; x++) {
        trace += matrix[x][x];
      }
      // rescale so the trace equals the number of examples
      for (int x = 0; x < matrix.length; x++)
        for (int y = x; y < matrix.length; y++) {
          matrix[x][y] *= matrix.length / (double) trace;
          matrix[y][x] = matrix[x][y];
        }
      kernels.add(sck);
      weights.add(Math.pow(1 / (double) listOfKernels.size(), 1 / (double) p_norm));
      debug.println(3, "kernel : " + sck + " weight : " + weights.get(i));
    }

    // 2. train initial svm with uniform weights
    ThreadedSumKernel<T> tsk = new ThreadedSumKernel<T>();
    for (int i = 0; i < kernels.size(); i++) tsk.addKernel(kernels.get(i), weights.get(i));
    if (svm == null) svm = new SMOSVM<T>(null);
    svm.setKernel(tsk);
    svm.setC(C);
    svm.train(l);

    // 3. main optimization loop
    double gap = 0;
    long max = 100000; // safety cap: at most 100000 iterations
    do {
      debug.println(3, "weights : " + weights);
      // compute sum kernel
      tsk = new ThreadedSumKernel<T>();
      for (int i = 0; i < kernels.size(); i++) tsk.addKernel(kernels.get(i), weights.get(i));

      // train svm
      svm.setKernel(tsk);
      svm.train(l);

      // compute sum of example weights and gradient direction
      double suma = computeSumAlpha();
      double[] grad = computeGradBeta(kernels, l);

      // perform one step
      double objEvol = performMKLStep(suma, grad, kernels, weights, l);

      if (objEvol < 0) {
        debug.println(1, "Error, performMKLStep returned an invalid value");
        System.exit(0);
      }
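      // gap = relative decrease of the objective; the loop stops once it drops below stopGap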
      gap = 1 - objEvol;

      // compute the inverse p-norm of the current weights (for monitoring)
      double norm = 0;
      for (int i = 0; i < weights.size(); i++) norm += Math.pow(weights.get(i), p_norm);
      norm = Math.pow(norm, -1 / (double) p_norm);

      debug.println(1, "objective_gap : " + gap + " norm : " + norm);
      max--;
    } while (gap >= stopGap && max > 0);

    // 4. save kernel weights
    listOfKernelWeights.clear();
    listOfKernelWeights.addAll(weights);

    // 5. retrain svm with the final weights
    // compute sum kernel
    tsk = new ThreadedSumKernel<T>();
    for (int i = 0; i < kernels.size(); i++)
      tsk.addKernel(listOfKernels.get(i), listOfKernelWeights.get(i));
    // train svm
    svm.setKernel(tsk);
    svm.train(l);

    // 6. save example weights
    listOfExamples = new ArrayList<TrainingSample<T>>();
    listOfExamples.addAll(l);
    listOfExampleWeights.clear();
    for (double d : svm.getAlphas()) listOfExampleWeights.add(d);

    debug.println(1, "MKL trained in " + (System.currentTimeMillis() - tim) + " milis.");
  }
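  /**
   * Performs one analytical update of the kernel weights under the p-norm constraint
   * and returns the ratio between the new and old objective values. A negative return
   * value signals a normalization error.
   */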
  private double performMKLStep(
      double suma,
      double[] grad,
      ArrayList<SimpleCacheKernel<T>> kernels,
      ArrayList<Double> weights,
      List<TrainingSample<T>> l) {
    debug.print(2, ".");
    // compute objective function
    double oldObjective = suma;
    for (int i = 0; i < grad.length; i++) {
      oldObjective -= weights.get(i) * grad[i];
    }
    debug.println(3, "oldObjective : " + oldObjective + " sumAlpha : " + suma);

    // compute optimal step
    double[] newBeta = new double[grad.length];
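    // closed-form step before normalization: newBeta_i = (grad_i * beta_i^2 / p_norm)^(1 / (1 + p_norm))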

    for (int i = 0; i < grad.length; i++) {
      if (grad[i] >= 0 && weights.get(i) >= 0) {
        newBeta[i] = grad[i] * weights.get(i) * weights.get(i) / p_norm;
        newBeta[i] = Math.pow(newBeta[i], 1. / (1. + p_norm));
      } else newBeta[i] = 0;
    }

    // normalize to unit p-norm
    double norm = 0;
    for (int i = 0; i < newBeta.length; i++) norm += Math.pow(newBeta[i], p_norm);
    norm = Math.pow(norm, -1 / (double) p_norm);
    if (norm < 0) {
      debug.println(1, "Error normalization, norm < 0");
      return -1;
    }
    for (int i = 0; i < newBeta.length; i++) newBeta[i] *= norm;

    // regularize: shift each weight by R = eps_regul * sqrt(||beta_old - beta_new||^2 / p_norm), then renormalize
    double R = 0;
    for (int i = 0; i < kernels.size(); i++) R += Math.pow(weights.get(i) - newBeta[i], 2);
    R = Math.sqrt(R / (double) p_norm) * eps_regul;
    if (R < 0) {
      debug.println(1, "Error regularization, R < 0");
      return -1;
    }
    norm = 0;
    for (int i = 0; i < kernels.size(); i++) {
      newBeta[i] += R;
      if (newBeta[i] < num_cleaning) newBeta[i] = 0;
      norm += Math.pow(newBeta[i], p_norm);
    }
    norm = Math.pow(norm, -1 / (double) p_norm);
    if (norm < 0) {
      debug.println(1, "Error normalization, norm < 0");
      return -1;
    }
    for (int i = 0; i < newBeta.length; i++) newBeta[i] *= norm;

    // store new weights
    for (int i = 0; i < weights.size(); i++) weights.set(i, newBeta[i]);

    // compute objective function
    double objective = suma;
    for (int i = 0; i < grad.length; i++) {
      objective -= weights.get(i) * grad[i];
    }
    debug.println(3, "objective : " + objective + " sumAlpha : " + suma);

    // return the ratio between the new and the old objective values
    return objective / oldObjective;
  }