@Override public Label predict(Instance instance) { Label l = null; if (instance.getLabel() instanceof ClassificationLabel || instance.getLabel() == null) { // ----------------- declare variables ------------------ double lambda = 0.0; RealVector x_instance = new ArrayRealVector(matrixX.getColumnDimension(), 0); double result = 0.0; // -------------------------- initialize xi ------------------------- for (int idx = 0; idx < matrixX.getColumnDimension(); idx++) { x_instance.setEntry(idx, instance.getFeatureVector().get(idx + 1)); } // ------------------ get lambda ----------------------- for (int j = 0; j < alpha.getDimension(); j++) { lambda += alpha.getEntry(j) * kernelFunction(matrixX.getRowVector(j), x_instance); } // ----------------- make prediction ----------------- Sigmoid g = new Sigmoid(); // helper function result = g.value(lambda); l = new ClassificationLabel(result < 0.5 ? 0 : 1); } else { System.out.println("label type error!"); } return l; }
@Override public void train(List<Instance> instances) { // ------------------------ initialize rows and columns --------------------- int rows = instances.size(); int columns = 0; // get max columns for (Instance i : instances) { int localColumns = Collections.max(i.getFeatureVector().getFeatureMap().keySet()); if (localColumns > columns) columns = localColumns; } // ------------------------ initialize alpha vector ----------------------- alpha = new ArrayRealVector(rows, 0); // ------------------------ initialize base X and Y for use -------------------------- double[][] X = new double[rows][columns]; double[] Y = new double[rows]; for (int i = 0; i < rows; i++) { Y[i] = ((ClassificationLabel) instances.get(i).getLabel()).getLabelValue(); for (int j = 0; j < columns; j++) { X[i][j] = instances.get(i).getFeatureVector().get(j + 1); } } // ---------------------- gram matrix ------------------- matrixX = new Array2DRowRealMatrix(X); RealMatrix gram = new Array2DRowRealMatrix(rows, rows); for (int i = 0; i < rows; i++) { for (int j = 0; j < rows; j++) { gram.setEntry(i, j, kernelFunction(matrixX.getRowVector(i), matrixX.getRowVector(j))); } } // ---------------------- gradient ascent -------------------------- Sigmoid g = new Sigmoid(); // helper function System.out.println("Training start..."); System.out.println( "Learning rate: " + _learning_rate + " Training times: " + _training_iterations); for (int idx = 0; idx < _training_iterations; idx++) { System.out.println("Training iteration: " + (idx + 1)); for (int k = 0; k < rows; k++) { double gradient_ascent = 0.0; RealVector alpha_gram = gram.operate(alpha); for (int i = 0; i < rows; i++) { double lambda = alpha_gram.getEntry(i); double kernel = gram.getEntry(i, k); gradient_ascent = gradient_ascent + Y[i] * g.value(-lambda) * kernel + (1 - Y[i]) * g.value(lambda) * (-kernel); } alpha.setEntry(k, alpha.getEntry(k) + _learning_rate * gradient_ascent); } } System.out.println("Training done!"); }