Ejemplo n.º 1
0
  /**
   * Runs {@link GoldenSectionLineSearch} on two one-dimensional test functions and prints the
   * value returned by each search.
   *
   * @param args unused
   */
  public static void main(String[] args) {
    // Smooth unimodal case: log(x^2 - x + 1) is minimized at x = 0.5.
    GoldenSectionLineSearch search = new GoldenSectionLineSearch(true, 0.00001, 0.001, 121.0);
    Function<Double, Double> logQuadratic = x -> Math.log(x * x - x + 1);
    System.out.println(search.minimize(logQuadratic));
    System.out.println();

    // Flat-with-a-notch case: zero everywhere except on [0.1, 0.2], where it dips to -0.0025
    // at x = 0.15. This function fails if the search does not find an initial bracketing first.
    //
    // Other functions that have been tried here:
    //   -x * (2 * x - 1) * (x - 0.8)   (used to fail in Galen's version; min should be 0.2)
    //   -Math.sin(x * Math.PI)
    //   -(3 + 6 * x - 4 * x * x)
    search = new GoldenSectionLineSearch(false, 0.00001, 0.0, 1.0);
    Function<Double, Double> notch =
        x -> {
          if (x < 0.1 || x > 0.2) {
            return 0.0;
          }
          return (x - 0.1) * (x - 0.2);
        };
    System.out.println(search.minimize(notch));
  }
  /**
   * Picks a value for the {@code sigma} field by cross-validation, then prints it.
   *
   * <p>For each trial sigma, the data is split into {@code folds} contiguous folds. If there are
   * fewer examples than folds, leave-one-out is used instead. For every fold a
   * {@link LinearClassifier} is built from {@code weights(...)}. The objective is the sum, over all
   * folds, of the negative log-probability assigned to each held-out example's true label.
   * {@link GoldenSectionLineSearch} minimizes this objective over sigma.
   *
   * <p>NOTE(review): {@code weights(...)} presumably trains on every example outside
   * {@code [testMin, testMax)} — confirm against its definition. When {@code data.length} is not a
   * multiple of {@code folds}, the last {@code data.length % folds} examples are never held out.
   *
   * @param data feature-index arrays, one per example
   * @param labels label index of each example, parallel to {@code data}
   */
  private void tuneSigma(final int[][] data, final int[] labels) {

    Function<Double, Double> cvNegLogLikelihood =
        trialSigma -> {
          System.err.println("Trying sigma = " + trialSigma);

          final int foldSize;
          final int nbCV;
          if (data.length >= folds) {
            // Enough data for the configured number of folds.
            foldSize = data.length / folds;
            nbCV = folds;
          } else {
            // Too few examples: fall back to leave-one-out.
            foldSize = 1;
            nbCV = data.length;
          }

          double sumScore = 0.0;
          for (int j = 0; j < nbCV; j++) {
            int testMin = j * foldSize;
            int testMax = testMin + foldSize;

            LinearClassifier<L, F> c =
                new LinearClassifier<>(
                    weights(data, labels, testMin, testMax, trialSigma, foldSize),
                    featureIndex,
                    labelIndex);

            // Held-out loss for this fold only. It is reset per fold so that each fold is counted
            // exactly once. Previously a single running total was re-added after every fold,
            // which weighted early folds more heavily.
            double foldScore = 0.0;
            for (int i = testMin; i < testMax; i++) {
              foldScore -=
                  c.logProbabilityOf(new BasicDatum<>(featureIndex.objects(data[i])))
                      .getCount(labelIndex.get(labels[i]));
            }
            sumScore += foldScore;
          }
          System.err.printf(": %8g%n", sumScore);
          return sumScore;
        };

    GoldenSectionLineSearch gsls = new GoldenSectionLineSearch(true);
    // Presumably (function, tolerance, low, high) — TODO confirm against GoldenSectionLineSearch.
    sigma = gsls.minimize(cvNegLogLikelihood, 0.01, 0.0001, 2.0);
    System.out.println("Sigma used: " + sigma);
  }