Ejemplo n.º 1
0
  /**
   * Exercises {@code Svm} with linear, polynomial and Gaussian kernels on two data sets and prints
   * the resulting error counts.
   *
   * <p>Data set 1: {@code M} points split into two clouds labeled +1 and -1, built from {@code
   * Ran(17)}. Each kernel is trained once with lambda = 10 and its training-set error count is
   * printed.
   *
   * <p>Data set 2 (overwrites the same {@code data}/{@code y} arrays): four noisy clusters around
   * (+-1, +-1) drawn from {@code Normaldev(0.0, 0.5, 17)}, labeled with the product of the signs of
   * the cluster center. For each kernel, lambda is scanned from 0.001 upward by factors of 10 while
   * below 10000. At each step the test prints the training-set error count and the error count on
   * {@code M} freshly drawn samples.
   *
   * <p>NOTE(review): none of the printed error counts is asserted on. {@code localflag} is
   * hard-wired to {@code false}, so this test can only fail by throwing.
   */
  @Test
  public void test() {
    final int M = 1000, N = 2;
    final double omega = 1.3;
    double[] y = new double[M];
    double[][] data = new double[M][N];
    boolean localflag, globalflag = false;

    // Test Svm
    System.out.println("Testing Svm");

    // Create two disjoint sets of points: first half labeled +1, second half labeled -1.
    Ran myran = new Ran(17);
    for (int i = 0; i < M; i++) {
      double a = myran.doub();
      double b = 2.0 * myran.doub() - 1.0;
      if (i < M / 2) {
        y[i] = 1.0;
        data[i][0] = 1.0 + (a - b);
        data[i][1] = 1.0 + (a + b);
      } else {
        y[i] = -1.0;
        data[i][0] = -1.0 - (a - b);
        data[i][1] = -1.0 - (a + b);
      }
    }

    // Linear kernel
    Svm linsvm = new Svm(new Svmlinkernel(data, y));
    svmRelaxUntilConverged(linsvm, 10.0, omega);
    System.out.printf("Errors: %d\n", svmTrainingErrors(linsvm, y));

    // Polynomial kernel
    Svm polysvm = new Svm(new Svmpolykernel(data, y, 1.0, 1.0, 2.0));
    svmRelaxUntilConverged(polysvm, 10.0, omega);
    System.out.printf("Errors: %d\n", svmTrainingErrors(polysvm, y));

    // Gaussian kernel
    Svm gausssvm = new Svm(new Svmgausskernel(data, y, 1.0));
    svmRelaxUntilConverged(gausssvm, 10.0, omega);
    System.out.printf("Errors: %d\n", svmTrainingErrors(gausssvm, y));

    // TODO: add tests for a harder case, and resolve the issue that the two support vectors give
    // an erroneous indication for two of the kernels above.

    // Example similar to the book: four quadrant clusters, written into the same arrays.
    Normaldev ndev = new Normaldev(0.0, 0.5, 17);
    final int perQuadrant = M / 4;
    for (int q = 0; q < 4; q++) {
      for (int i = 0; i < perQuadrant; i++) {
        int k = perQuadrant * q + i;
        y[k] = svmQuadrantSample(q, ndev, data[k]);
      }
    }

    // Linear kernel
    Svm linsvm2 = new Svm(new Svmlinkernel(data, y));
    svmLambdaScan(linsvm2, y, ndev, perQuadrant, omega);

    // Polynomial kernel
    Svm polysvm2 = new Svm(new Svmpolykernel(data, y, 1.0, 1.0, 4.0));
    svmLambdaScan(polysvm2, y, ndev, perQuadrant, omega);

    // Gaussian kernel
    Svm gausssvm2 = new Svm(new Svmgausskernel(data, y, 1.0));
    svmLambdaScan(gausssvm2, y, ndev, perQuadrant, omega);

    // TODO: turn the error counts printed above into real pass/fail criteria; until then this flag
    // is never set.
    localflag = false;
    globalflag = globalflag || localflag;
    if (localflag) {
      fail("*** Svm: *************************");
    }

    if (globalflag) System.out.println("Failed\n");
    else System.out.println("Passed\n");
  }

  /**
   * Repeatedly calls {@code svm.relax(lambda, omega)}. Stops once the returned value is at most
   * 1e-3, or after 100 calls, whichever comes first.
   */
  private static void svmRelaxUntilConverged(Svm svm, double lambda, double omega) {
    int sweeps = 0;
    double test;
    do {
      test = svm.relax(lambda, omega);
      sweeps++;
    } while (test > 1.e-3 && sweeps < 100);
  }

  /**
   * Counts training points whose label disagrees with the sign of {@code svm.predict(i)}.
   * A label of +1 means predict >= 0; any other label means predict < 0.
   */
  private static int svmTrainingErrors(Svm svm, double[] y) {
    int nerror = 0;
    for (int i = 0; i < y.length; i++) {
      if ((y[i] == 1.0) != (svm.predict(i) >= 0.0)) {
        nerror++;
      }
    }
    return nerror;
  }

  /**
   * Draws one noisy point around the center of {@code quadrant} (0: (+1,+1), 1: (-1,+1),
   * 2: (-1,-1), 3: (+1,-1)) into {@code point[0..1]} and returns its label.
   * The label is the product of the center's coordinates, i.e. +1 for quadrants 0 and 2 and -1
   * for quadrants 1 and 3.
   */
  private static double svmQuadrantSample(int quadrant, Normaldev ndev, double[] point) {
    double cx = (quadrant == 0 || quadrant == 3) ? 1.0 : -1.0;
    double cy = (quadrant <= 1) ? 1.0 : -1.0;
    // x-coordinate is drawn before y-coordinate to keep the Normaldev stream order fixed.
    point[0] = cx + ndev.dev();
    point[1] = cy + ndev.dev();
    return cx * cy;
  }

  /**
   * Draws {@code perQuadrant} fresh samples from each of the four quadrants.
   * Returns how many of them {@code svm.predict(x)} classifies on the wrong side of zero.
   */
  private static int svmFreshSampleErrors(Svm svm, Normaldev ndev, int perQuadrant) {
    double[] x = new double[2];
    int nerror = 0;
    for (int q = 0; q < 4; q++) {
      for (int i = 0; i < perQuadrant; i++) {
        double yy = svmQuadrantSample(q, ndev, x);
        if ((yy == 1.0) != (svm.predict(x) >= 0.0)) {
          nerror++;
        }
      }
    }
    return nerror;
  }

  /**
   * For lambda = 0.001, 0.01, ... while lambda < 10000, retrains {@code svm} (continuing from its
   * current state) and prints the training-set and fresh-sample error counts on one line.
   */
  private static void svmLambdaScan(
      Svm svm, double[] y, Normaldev ndev, int perQuadrant, double omega) {
    System.out.printf("Errors: ");
    for (double lambda = 0.001; lambda < 10000; lambda *= 10) {
      svmRelaxUntilConverged(svm, lambda, omega);
      System.out.printf("%d ", svmTrainingErrors(svm, y));
      System.out.printf("%d    ", svmFreshSampleErrors(svm, ndev, perQuadrant));
    }
    System.out.println();
  }
Ejemplo n.º 2
0
  /**
   * Tests {@code Fitab}, the straight-line fit y = a + b*x, in three stages.
   *
   * <ol>
   *   <li>Noise-free data y = sqrt(2) + pi*x with unit errors. The test requires that:
   *       <ul>
   *         <li>a and b are recovered to 1e-12;
   *         <li>chi^2 is zero;
   *         <li>q equals 1;
   *         <li>siga/sigb equals sqrt(sum(x^2)/N).
   *       </ul>
   *   <li>The same noise-free data with a constant error of 0.1*(j+1) for j = 0..M-1. The fit must
   *       be unchanged, and siga and sigb must grow by exactly the factor (j+1) relative to j == 0.
   *   <li>M trials of the line plus {@code Normaldev(0.0, 1.0, 17)} noise with unit errors. The test
   *       requires that:
   *       <ul>
   *         <li>a and b lie within 3 of their reported standard errors;
   *         <li>chi^2 does not exceed 1.3*N;
   *         <li>q is not below 0.1.
   *       </ul>
   * </ol>
   */
  @Test
  public void test() {
    int i, j, N = 100, M = 10;
    double pi = acos(-1.0), sumx2, sa = 0, sb = 0, sbeps;
    double[] x = new double[N], y = new double[N], yy = new double[N], sig = new double[N];
    boolean localflag, globalflag = false;

    // Test Fitab
    System.out.println("Testing Fitab");

    Ran myran = new Ran(17);
    sumx2 = 0;
    for (i = 0; i < N; i++) {
      x[i] = 10.0 * myran.doub();
      y[i] = sqrt(2.0) + pi * x[i]; // exact line: a = sqrt(2), b = pi
      sig[i] = 1.0;
      sumx2 += SQR(x[i]);
    }

    Fitab fit1 = new Fitab(x, y, sig); // Perfect fit, no noise
    sbeps = 1.e-12;
    localflag = abs(fit1.a - sqrt(2.0)) > sbeps;
    globalflag = globalflag || localflag;
    if (localflag) {
      fail("*** Fitab: Fitted constant term a has incorrect value");
    }

    localflag = abs(fit1.b - pi) > sbeps;
    globalflag = globalflag || localflag;
    if (localflag) {
      fail("*** Fitab: Fitted slope b has incorrect value");
    }

    localflag = fit1.chi2 > sbeps;
    globalflag = globalflag || localflag;
    if (localflag) {
      fail("*** Fitab: Chi^2 not zero for perfect linear data");
    }

    localflag = abs(fit1.q - 1.0) > sbeps;
    globalflag = globalflag || localflag;
    if (localflag) {
      fail("*** Fitab: Probability not 1.0 for perfect linear data");
    }

    localflag = abs(fit1.siga / fit1.sigb - sqrt(sumx2 / N)) > sbeps;
    globalflag = globalflag || localflag;
    if (localflag) {
      fail("*** Fitab: Ratio of siga/sigb incorrect for special case");
    }

    // Test 2: same exact data, constant per-point error that grows with j
    for (j = 0; j < M; j++) {
      for (i = 0; i < N; i++) {
        sig[i] = 0.1 * (j + 1);
      }
      Fitab fit = new Fitab(x, y, sig);

      localflag = abs(fit.a - sqrt(2.0)) > sbeps;
      globalflag = globalflag || localflag;
      if (localflag) {
        fail("*** Fitab,Test2: Fitted constant term a has incorrect value");
      }

      localflag = abs(fit.b - pi) > sbeps;
      globalflag = globalflag || localflag;
      if (localflag) {
        fail("*** Fitab,Test2: Fitted slope b has incorrect value");
      }

      localflag = fit.chi2 > sbeps;
      globalflag = globalflag || localflag;
      if (localflag) {
        fail("*** Fitab,Test2: Chi^2 not zero for perfect linear data");
      }

      localflag = abs(fit.q - 1.0) > sbeps;
      globalflag = globalflag || localflag;
      if (localflag) {
        fail("*** Fitab,Test2: Probability not 1.0 for perfect linear data");
      }

      if (j == 0) {
        // Baseline errors for sig = 0.1; later iterations use (j+1) times this sig.
        sa = fit.siga;
        sb = fit.sigb;
      } else {
        // Two-sided comparison: a ratio that is too small must be caught as well as one that is
        // too large.
        localflag = abs(fit.siga / sa - (j + 1)) > sbeps;
        globalflag = globalflag || localflag;
        if (localflag) {
          fail("*** Fitab,Test2: siga did not scale properly with data errors");
        }

        localflag = abs(fit.sigb / sb - (j + 1)) > sbeps;
        globalflag = globalflag || localflag;
        if (localflag) {
          fail("*** Fitab,Test2: sigb did not scale properly with data errors");
        }
      }
    }

    // Test 3: noisy data with unit errors (presumably matching the Normaldev sigma of 1.0)
    Normaldev ndev = new Normaldev(0.0, 1.0, 17);
    for (j = 0; j < M; j++) {
      for (i = 0; i < N; i++) {
        yy[i] = y[i] + ndev.dev();
        sig[i] = 1.0;
      }
      Fitab fit3 = new Fitab(x, yy, sig);

      localflag = abs(fit3.a - sqrt(2.0)) > 3.0 * fit3.siga;
      globalflag = globalflag || localflag;
      if (localflag) {
        fail("*** Fitab,Test3: Fitted constant term a, or error siga, may be incorrect");
      }

      localflag = abs(fit3.b - pi) > 3.0 * fit3.sigb;
      globalflag = globalflag || localflag;
      if (localflag) {
        fail("*** Fitab,Test3: Fitted slope b, or error sigb, may be incorrect");
      }

      localflag = fit3.chi2 > 1.3 * N;
      globalflag = globalflag || localflag;
      if (localflag) {
        fail("*** Fitab,Test3: Chi^2 is unexpectedly high");
      }

      localflag = abs(fit3.q) < 0.1;
      globalflag = globalflag || localflag;
      if (localflag) {
        fail("*** Fitab,Test3: Probability q suggests a possibly bad fit");
      }
    }

    if (globalflag) System.out.println("Failed\n");
    else System.out.println("Passed\n");
  }