Example #1
 public double getObjective(
     AbstractStochasticCachingDiffUpdateFunction function,
     double[] w,
     double wscale,
     int[] sample) {
   double wnorm = getNorm(w) * wscale * wscale; // squared norm of the effective weights wscale * w
   double obj = function.valueAt(w, wscale, sample);
   // Calculate objective with L2 regularization
   return obj + 0.5 * sample.length * lambda * wnorm;
 }
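
Example #1 never materializes the scaled weight vector: the effective weights are wscale * w, so their squared norm is wscale^2 * getNorm(w), which is exactly the wnorm term above. A standalone sketch of the same arithmetic (squaredNorm, dataLoss, and all constants are hypothetical stand-ins; it assumes getNorm returns the squared L2 norm):

public class ObjectiveSketch {
  // squared L2 norm, playing the role of getNorm(w)
  static double squaredNorm(double[] w) {
    double s = 0;
    for (double v : w) s += v * v;
    return s;
  }

  public static void main(String[] args) {
    double[] w = {0.5, -1.0, 2.0};
    double wscale = 0.25;   // weights are stored as w[i] * wscale
    double lambda = 0.1;    // regularization strength
    int sampleSize = 100;   // |sample| in getObjective
    double dataLoss = 3.7;  // stand-in for function.valueAt(w, wscale, sample)

    // same formula as getObjective: data loss + 0.5 * |sample| * lambda * ||wscale * w||^2
    double wnorm = squaredNorm(w) * wscale * wscale;
    System.out.println("objective = " + (dataLoss + 0.5 * sampleSize * lambda * wnorm));
  }
}
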
Example #2
 /**
  * Finds a good learning rate to start with: eta = 1 / (lambda * (t0 + t)), so we search for a
  * good t0.
  *
  * @param function the objective to tune on
  * @param initial the initial weight vector
  * @param sampleSize the number of data points to sample for tuning
  * @param seta the starting eta for the search
  * @return the chosen starting learning rate
  */
 public double tune(
     AbstractStochasticCachingDiffUpdateFunction function,
     double[] initial,
     int sampleSize,
     double seta) {
   Timing timer = new Timing();
   int[] sample = function.getSample(sampleSize);
   double sobj = getObjective(function, initial, 1, sample);
   double besteta = 1;
   double bestobj = sobj;
   double eta = seta;
   int totest = 10;
   double factor = 2;
   boolean phase2 = false;
   while (totest > 0 || !phase2) {
     double obj = tryEta(function, initial, sample, eta);
     boolean okay = (obj < sobj);
      sayln("  Trying eta=" + eta + "  obj=" + obj + (okay ? " (possible)" : " (too large)"));
     if (okay) {
       totest -= 1;
       if (obj < bestobj) {
         bestobj = obj;
         besteta = eta;
       }
     }
     if (!phase2) {
       if (okay) {
         eta = eta * factor;
       } else {
         phase2 = true;
         eta = seta;
       }
     }
     if (phase2) {
       eta = eta / factor;
     }
   }
   // err on the safe side (acts as implicit regularization)
   besteta /= factor;
   // choose t0 so that the schedule eta = 1 / (lambda * (t0 + t)) starts at besteta
   t0 = (int) (1 / (besteta * lambda));
   sayln("  Taking eta=" + besteta + " t0=" + t0);
   sayln("  Tuning completed in: " + Timing.toSecondsString(timer.report()) + " s");
   return besteta;
 }
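
The schedule the tuner is calibrating is eta(t) = 1 / (lambda * (t0 + t)): tune() picks besteta, and t0 is then chosen so the very first step uses roughly besteta, after which eta decays like 1/t. A standalone sketch with hypothetical lambda and besteta values:

public class EtaScheduleSketch {
  public static void main(String[] args) {
    double lambda = 1e-4;  // assumed regularization strength
    double besteta = 0.5;  // assumed result of tune()
    int t0 = (int) (1 / (besteta * lambda)); // same initialization as tune()
    for (int t = 0; t <= 100_000; t += 20_000) {
      double eta = 1.0 / (lambda * (t0 + t));
      System.out.printf("t=%6d  eta=%.6f%n", t, eta);
    }
  }
}
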
Example #3
 public double tryEta(
     AbstractStochasticCachingDiffUpdateFunction function,
     double[] initial,
     int[] sample,
     double eta) {
   int numBatches = sample.length / bSize;
   double[] w = new double[initial.length];
   double wscale = 1;
   System.arraycopy(initial, 0, w, 0, w.length);
   int[] sampleBatch = new int[bSize];
   int sampleIndex = 0;
   for (int batch = 0; batch < numBatches; batch++) {
     for (int i = 0; i < bSize; i++) {
       sampleBatch[i] = sample[(sampleIndex + i) % sample.length];
     }
     sampleIndex += bSize;
      // divide out wscale so the step taken in effective-weight space is exactly eta
      double gain = eta / wscale;
      function.calculateStochasticUpdate(w, wscale, sampleBatch, gain);
      // L2 weight decay, folded into the scale factor instead of touching every weight
      wscale *= (1 - eta * lambda * bSize);
   }
   double obj = getObjective(function, w, wscale, sample);
   return obj;
 }
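
The wscale bookkeeping in tryEta is a lazy form of L2 weight decay: instead of multiplying every coordinate by (1 - eta * lambda * bSize) each batch, a single scalar is shrunk and the true weights are read as w[i] * wscale. A standalone sketch (hypothetical sizes) confirming that the two updates agree:

public class WeightDecaySketch {
  public static void main(String[] args) {
    double eta = 0.1, lambda = 0.01;
    int bSize = 10, batches = 5;
    double[] dense = {1.0, -2.0, 3.0}; // decay applied to every entry, O(dim) per batch
    double[] lazy = dense.clone();     // decay folded into wscale, O(1) per batch
    double wscale = 1;
    for (int b = 0; b < batches; b++) {
      double decay = 1 - eta * lambda * bSize;
      for (int i = 0; i < dense.length; i++) {
        dense[i] *= decay;
      }
      wscale *= decay;
    }
    for (int i = 0; i < dense.length; i++) {
      System.out.println(dense[i] + " == " + lazy[i] * wscale);
    }
  }
}
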
Example #4
  @Override
  public double[] minimize(
      Function f, double functionTolerance, double[] initial, int maxIterations) {
    if (!(f instanceof AbstractStochasticCachingDiffUpdateFunction)) {
      throw new UnsupportedOperationException(
          "minimize() requires an AbstractStochasticCachingDiffUpdateFunction");
    }
    AbstractStochasticCachingDiffUpdateFunction function =
        (AbstractStochasticCachingDiffUpdateFunction) f;
    if (function instanceof LogConditionalObjectiveFunction) {
      if (((LogConditionalObjectiveFunction) function).parallelGradientCalculation) {
        System.err.println(
            "\n*********\n"
                + "Noting that HogWild optimization was requested.\n"
                + "Setting batch size = data size to minimize thread creation overhead.\n"
                + "Results *should* be identical on sparse problems.\n"
                + "Disable the parallelGradientCalculation flag in LogConditionalObjectiveFunction,"
                + " or run with -threads 1, to disable it.\n"
                + "Another Minimizer can also be used if parallel computation is desired, but"
                + " HogWild isn't delivering good results.\n"
                + "*********\n");
        bSize = function.dataDimension();
      }
    }
    int totalSamples = function.dataDimension();
    int tuneSampleSize = Math.min(totalSamples, tuningSamples);
    if (tuneSampleSize < tuningSamples) {
      System.err.println(
          "WARNING: Total number of samples="
              + totalSamples
              + " is smaller than requested tuning sample size="
              + tuningSamples
              + "!!!");
    }
    lambda = 1.0 / (sigma * totalSamples);
    sayln("Using sigma=" + sigma + " lambda=" + lambda + " tuning sample size " + tuneSampleSize);
    // tune(function, initial, tuneSampleSize, 0.1);
    // tuning is disabled: assume a starting eta of 0.1, so t0 = 1 / (0.1 * lambda)
    t0 = (int) (1 / (0.1 * lambda));

    x = new double[initial.length];
    System.arraycopy(initial, 0, x, 0, x.length);
    xscale = 1;
    xnorm = getNorm(x);
    int numBatches = totalSamples / bSize;

    init(function);

    boolean have_max = (maxIterations > 0 || numPasses > 0);

    if (!have_max) {
      throw new UnsupportedOperationException(
          "No maximum number of iterations has been specified.");
    } else {
      maxIterations = Math.max(maxIterations, numPasses) * numBatches;
    }

    sayln("       Batch size: " + bSize);
    sayln("       Data dimension: " + totalSamples);
    sayln("       Batches per pass through data: " + numBatches);
    sayln("       Number of passes: " + numPasses);
    sayln("       Max iterations: " + maxIterations);

    // -------------------------------------------------------------------------------
    //                                Main optimization loop
    // -------------------------------------------------------------------------------

    Timing total = new Timing();
    Timing current = new Timing();
    total.start();
    current.start();
    int t = t0;
    int iters = 0;
    for (int pass = 0; pass < numPasses; pass++) {
      boolean doEval = (pass > 0 && evaluateIters > 0 && pass % evaluateIters == 0);
      if (doEval) {
        rescale();
        doEvaluation(x);
      }

      double totalValue = 0;
      double lastValue = 0;
      say("Iter: " + iters + " pass " + pass + " batch 1 ... ");
      for (int batch = 0; batch < numBatches; batch++) {
        iters++;

        // Get the next X
        double eta = 1 / (lambda * t);
        double gain = eta / xscale;
        lastValue = function.calculateStochasticUpdate(x, xscale, bSize, gain);
        totalValue += lastValue;
        // weight decay (for L2 regularization)
        xscale *= (1 - eta * lambda * bSize);
        t += bSize;
      }
      // fold the scale factor back into x before it underflows
      if (xscale < 1e-6) {
        rescale();
      }
      try {
        ArrayMath.assertFinite(x, "x");
      } catch (ArrayMath.InvalidElementException e) {
        System.err.println(e.toString());
        for (int i = 0; i < x.length; i++) {
          x[i] = Double.NaN;
        }
        break;
      }
      xnorm = getNorm(x) * xscale * xscale;
      // loss = summed batch values + L2 penalty: 0.5 * lambda * n * ||xscale * x||^2
      double loss = totalValue + 0.5 * xnorm * lambda * totalSamples;
      say(String.valueOf(numBatches));
      say("[" + (total.report()) / 1000.0 + " s ");
      say("{" + (current.restart() / 1000.0) + " s}] ");
      sayln(" " + lastValue + ' ' + totalValue + ' ' + loss);

      if (iters >= maxIterations) {
        sayln("Stochastic Optimization complete.  Stopped after max iterations");
        break;
      }

      if (total.report() >= maxTime) {
        sayln("Stochastic Optimization complete.  Stopped after max time");
        break;
      }
    }
    rescale();

    if (evaluateIters > 0) {
      // do final evaluation
      doEvaluation(x);
    }

    sayln("Completed in: " + Timing.toSecondsString(total.report()) + " s");

    return x;
  }
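
To see the core of minimize() in isolation, here is a standalone toy (all names and data hypothetical) that traces the same inner loop on a 1-D ridge-regression problem: per batch, take a gradient step with gain = eta / xscale, then fold the L2 decay into xscale; the effective weight is always x * xscale:

public class MinimizeLoopSketch {
  public static void main(String[] args) {
    double[] y = {1.0, 2.0, 3.0, 4.0}; // targets; the model is a single weight
    double sigma = 1.0;
    double lambda = 1.0 / (sigma * y.length); // same lambda as minimize()
    int bSize = 1;
    int t0 = (int) (1 / (0.1 * lambda)); // same t0 initialization: eta starts at 0.1
    int t = t0;
    double x = 0, xscale = 1;
    for (int pass = 0; pass < 50; pass++) {
      for (double target : y) {
        double eta = 1 / (lambda * t);
        double gain = eta / xscale;
        double grad = x * xscale - target; // d/dw of 0.5 * (w - target)^2 at w = x * xscale
        x -= gain * grad;                  // step in effective-weight space is exactly eta * grad
        xscale *= (1 - eta * lambda * bSize); // L2 weight decay via the scale factor
        t += bSize;
      }
    }
    // the regularized optimum is mean(y) / (1 + lambda) = 2.0; x * xscale should land nearby
    System.out.println("x = " + (x * xscale));
  }
}
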