@Override public void train(List<TrainingSample<T>> l) { long tim = System.currentTimeMillis(); debug.println( 2, "training on " + listOfKernels.size() + " kernels and " + l.size() + " examples"); // 1. init kernels ArrayList<SimpleCacheKernel<T>> kernels = new ArrayList<SimpleCacheKernel<T>>(); ArrayList<Double> weights = new ArrayList<Double>(); // normalize to cst trace and init weights to 1/N for (int i = 0; i < listOfKernels.size(); i++) { SimpleCacheKernel<T> sck = new SimpleCacheKernel<T>(listOfKernels.get(i), l); sck.setName(listOfKernels.get(i).toString()); double[][] matrix = sck.getKernelMatrix(l); // compute trace double trace = 0.; for (int x = 0; x < matrix.length; x++) { trace += matrix[x][x]; } // divide by trace for (int x = 0; x < matrix.length; x++) for (int y = x; y < matrix.length; y++) { matrix[x][y] *= matrix.length / (double) trace; matrix[y][x] = matrix[x][y]; } kernels.add(sck); weights.add(Math.pow(1 / (double) listOfKernels.size(), 1 / (double) p_norm)); debug.println(3, "kernel : " + sck + " weight : " + weights.get(i)); } // 1 train first svm ThreadedSumKernel<T> tsk = new ThreadedSumKernel<T>(); for (int i = 0; i < kernels.size(); i++) tsk.addKernel(kernels.get(i), weights.get(i)); if (svm == null) svm = new SMOSVM<T>(null); svm.setKernel(tsk); svm.setC(C); svm.train(l); // 2. big loop double gap = 0; long max = 100000; // less than 10k iterations do { debug.println(3, "weights : " + weights); // compute sum kernel tsk = new ThreadedSumKernel<T>(); for (int i = 0; i < kernels.size(); i++) tsk.addKernel(kernels.get(i), weights.get(i)); // train svm svm.setKernel(tsk); svm.train(l); // compute sum of example weights and gradient direction double suma = computeSumAlpha(); double[] grad = computeGradBeta(kernels, l); // perform one step double objEvol = performMKLStep(suma, grad, kernels, weights, l); if (objEvol < 0) { debug.println(1, "Error, performMKLStep return wrong value"); System.exit(0); ; } gap = 1 - objEvol; // compute norm double norm = 0; for (int i = 0; i < weights.size(); i++) norm += Math.pow(weights.get(i), p_norm); norm = Math.pow(norm, -1 / (double) p_norm); debug.println(1, "objective_gap : " + gap + " norm : " + norm); max--; } while (gap >= stopGap && max > 0); // 3. save weights listOfKernelWeights.clear(); listOfKernelWeights.addAll(weights); // 4. retrain svm // compute sum kernel tsk = new ThreadedSumKernel<T>(); for (int i = 0; i < kernels.size(); i++) tsk.addKernel(listOfKernels.get(i), listOfKernelWeights.get(i)); // train svm svm.setKernel(tsk); svm.train(l); // 5. save examples weights listOfExamples = new ArrayList<TrainingSample<T>>(); listOfExamples.addAll(l); listOfExampleWeights.clear(); for (double d : svm.getAlphas()) listOfExampleWeights.add(d); debug.println(1, "MKL trained in " + (System.currentTimeMillis() - tim) + " milis."); }