示例#1
0
 /**
  * Evaluates the loss L(u, a) of the prediction {@code u} against the numeric observation
  * {@code a}.
  *
  * <p>For Hinge and Logistic, Booleans are coded as {0, 1} instead of {-1, 1}: {@code a == 0}
  * is treated as the negative label (y = -1) and any other value as the positive label (y = 1).
  *
  * @param u the predicted value
  * @param a the observed value; must be {@code >= 0} for Poisson
  * @param loss the loss function to evaluate; expected to satisfy {@code loss.isForNumeric()}
  * @return the loss value
  * @throws RuntimeException if {@code loss} is not handled by this method
  */
 public final double loss(double u, double a, Loss loss) {
   assert loss.isForNumeric() : "Loss function " + loss + " not applicable to numerics";
   final double diff = u - a;
   switch (loss) {
     case Quadratic:
       return diff * diff;
     case Absolute:
       return Math.abs(diff);
     case Huber:
       // Quadratic for |u - a| <= 1, linear beyond; both pieces equal 0.5 at |u - a| = 1.
       return Math.abs(diff) <= 1 ? 0.5 * diff * diff : Math.abs(diff) - 0.5;
     case Poisson:
       assert a >= 0 : "Poisson loss L(u,a) requires variable a >= 0";
       // Since \lim_{a->0} a*log(a) = 0, the a-dependent terms vanish when a == 0.
       return Math.exp(u) + (a == 0 ? 0 : -a * u + a * Math.log(a) - a);
     case Hinge:
       // max(1 - y*u, 0) with y = -1 when a == 0 and y = 1 otherwise.
       return Math.max(1 - (a == 0 ? -u : u), 0);
     case Logistic: {
       // log(1 + exp(-y*u)) with y = -1 when a == 0 and y = 1 otherwise.
       final double z = a == 0 ? u : -u;
       // Numerically stable softplus: evaluating log(1 + exp(z)) directly overflows to Infinity
       // for z > ~709 and rounds to 0 for very negative z. Using the identity
       // log(1 + e^z) = z + log(1 + e^-z) for z > 0 together with log1p avoids both problems.
       return z > 0 ? z + Math.log1p(Math.exp(-z)) : Math.log1p(Math.exp(z));
     }
     case Periodic:
       return 1 - Math.cos((a - u) * (2 * Math.PI) / _period);
     default:
       throw new RuntimeException("Unknown loss function " + loss);
   }
 }
示例#2
0
    /**
     * Evaluates the regularization penalty r(u) of the vector {@code u}.
     *
     * <p>Quadratic, L2 and L1 return the squared Euclidean norm, the Euclidean norm and the sum of
     * absolute values respectively. The constraint-style regularizers (NonNegative, OneSparse,
     * UnitOneSparse, Simplex) return 0 when {@code u} satisfies the constraint and
     * {@link Double#POSITIVE_INFINITY} otherwise.
     *
     * @param u the vector to penalize; {@code null} yields a penalty of 0
     * @param regularization the regularizer to apply
     * @return the penalty value
     * @throws RuntimeException if {@code regularization} is not handled by this method
     */
    public final double regularize(double[] u, Regularizer regularization) {
      if (u == null) {
        return 0;
      }
      switch (regularization) {
        case None:
          return 0;
        case Quadratic:
        case L2: {
          // Both penalties start from the sum of squares; L2 additionally takes the root.
          double squares = 0;
          for (double v : u) {
            squares += v * v;
          }
          return regularization == Regularizer.L2 ? Math.sqrt(squares) : squares;
        }
        case L1: {
          double total = 0;
          for (double v : u) {
            total += Math.abs(v);
          }
          return total;
        }
        case NonNegative:
          for (double v : u) {
            if (v < 0) {
              return Double.POSITIVE_INFINITY;
            }
          }
          return 0;
        case OneSparse: {
          // Feasible iff all entries are non-negative and exactly one is strictly positive.
          int positives = 0;
          for (double v : u) {
            if (v < 0) {
              return Double.POSITIVE_INFINITY;
            }
            if (v > 0) {
              positives++;
            }
          }
          return positives == 1 ? 0 : Double.POSITIVE_INFINITY;
        }
        case UnitOneSparse: {
          // Every entry must be exactly 0 or 1; anything else (including NaN) is infeasible.
          // Once that holds, "exactly one 1" already implies "all the others are 0".
          int ones = 0;
          for (double v : u) {
            if (v == 1) {
              ones++;
            } else if (v != 0) {
              return Double.POSITIVE_INFINITY;
            }
          }
          return ones == 1 ? 0 : Double.POSITIVE_INFINITY;
        }
        case Simplex: {
          // Feasible iff all entries are non-negative and they sum to 1 within rounding error.
          double total = 0;
          double absTotal = 0;
          for (double v : u) {
            if (v < 0) {
              return Double.POSITIVE_INFINITY;
            }
            total += v;
            absTotal += Math.abs(v);
          }
          boolean onSimplex = MathUtils.equalsWithinRecSumErr(total, 1.0, u.length, absTotal);
          return onSimplex ? 0 : Double.POSITIVE_INFINITY;
        }
        default:
          throw new RuntimeException("Unknown regularization function " + regularization);
      }
    }
示例#3
0
 /**
  * Evaluates the derivative dL(u, a)/du of the loss with respect to the prediction {@code u}.
  *
  * <p>Where the loss is not differentiable (Absolute at u == a, Hinge at its kink), a valid
  * subgradient is returned. For Hinge and Logistic, Booleans are coded as {0, 1} instead of
  * {-1, 1}: {@code a == 0} is treated as the negative label (y = -1) and any other value as the
  * positive label (y = 1).
  *
  * @param u the predicted value
  * @param a the observed value; must be {@code >= 0} for Poisson
  * @param loss the loss function to differentiate; expected to satisfy
  *     {@code loss.isForNumeric()}
  * @return the (sub)gradient of the loss at {@code u}
  * @throws RuntimeException if {@code loss} is not handled by this method
  */
 public final double lgrad(double u, double a, Loss loss) {
   assert loss.isForNumeric() : "Loss function " + loss + " not applicable to numerics";
   switch (loss) {
     case Quadratic:
       return 2 * (u - a);
     case Absolute:
       return Math.signum(u - a);
     case Huber:
       return Math.abs(u - a) <= 1 ? u - a : Math.signum(u - a);
     case Poisson:
       assert a >= 0 : "Poisson loss L(u,a) requires variable a >= 0";
       return Math.exp(u) - a;
     case Hinge:
       // Subgradient of max(1 - y*u, 0): -y while y*u <= 1, otherwise 0.
       return a == 0
           ? (-u <= 1 ? 1 : 0)
           : (u <= 1 ? -1 : 0);
     case Logistic:
       // Derivative of log(1 + exp(-y*u)) is -y / (1 + exp(y*u)).
       return a == 0
           ? 1 / (1 + Math.exp(-u))
           : -1 / (1 + Math.exp(u));
     case Periodic:
       // L = 1 - cos((a - u) * f) with f = 2*pi/_period, so by the chain rule
       // dL/du = sin((a - u) * f) * d/du[(a - u) * f] = -f * sin((a - u) * f).
       // (Returning +f * sin(...) would point the gradient in the wrong direction.)
       return -((2 * Math.PI) / _period) * Math.sin((a - u) * (2 * Math.PI) / _period);
     default:
       throw new RuntimeException("Unknown loss function " + loss);
   }
 }