예제 #1
0
 /**
  * Picks the candidate index whose multidimensional loss against {@code u} is smallest.
  *
  * <p>Evaluates {@code mloss(u, a, multi_loss)} for every index {@code a} of {@code u} and returns
  * whatever position {@code ArrayUtils.minIndex} reports for those values.
  *
  * @param u vector the candidates are scored against; one candidate per entry
  * @param multi_loss loss to evaluate; only {@code Categorical} and {@code Ordinal} are handled
  * @return index selected by {@code ArrayUtils.minIndex} over the per-candidate losses
  * @throws RuntimeException if {@code multi_loss} is any other constant
  */
 public static int mimpute(double[] u, Loss multi_loss) {
   assert multi_loss.isForCategorical()
       : "Loss function " + multi_loss + " not applicable to categoricals";
   switch (multi_loss) {
     case Categorical:
     case Ordinal: {
       final int numCandidates = u.length;
       final double[] lossPerCandidate = new double[numCandidates];
       int candidate = 0;
       while (candidate < numCandidates) {
         lossPerCandidate[candidate] = mloss(u, candidate, multi_loss);
         candidate++;
       }
       return ArrayUtils.minIndex(lossPerCandidate);
     }
     default:
       throw new RuntimeException("Unknown multidimensional loss function " + multi_loss);
   }
 }
예제 #2
0
    /**
     * Builds the per-entry gradient of the multidimensional loss for index {@code a}.
     *
     * <ul>
     *   <li>{@code Categorical}: entry {@code a} is {@code -1} when {@code 1 - u[a] > 0}, else
     *       {@code 0}; every other entry {@code i} is {@code 1} when {@code 1 + u[i] > 0}, else
     *       {@code 0}.
     *   <li>{@code Ordinal}: entry {@code i < a} is {@code -1} when {@code 1 - u[i] > 0}, else
     *       {@code 0}; all remaining entries are {@code 0}.
     * </ul>
     *
     * @param u vector at which the gradient is evaluated
     * @param a index into {@code u}; must lie in {@code [0, u.length - 1]}
     * @param multi_loss loss whose gradient is wanted; only {@code Categorical} and
     *     {@code Ordinal} are handled
     * @return a new array of the same length as {@code u}
     * @throws IllegalArgumentException if {@code a} is outside {@code [0, u.length - 1]}
     * @throws RuntimeException if {@code multi_loss} is any other constant
     */
    public static double[] mlgrad(double[] u, int a, Loss multi_loss) {
      assert multi_loss.isForCategorical()
          : "Loss function " + multi_loss + " not applicable to categoricals";
      final int lastIndex = u.length - 1;
      if (a < 0 || a > lastIndex) {
        throw new IllegalArgumentException(
            "Index must be between 0 and " + String.valueOf(lastIndex));
      }

      final double[] gradient = new double[u.length];
      switch (multi_loss) {
        case Categorical:
          for (int k = 0; k < gradient.length; k++) {
            if (k == a) {
              gradient[k] = (1 - u[k] > 0) ? -1 : 0;
            } else {
              gradient[k] = (1 + u[k] > 0) ? 1 : 0;
            }
          }
          return gradient;
        case Ordinal:
          // Only entries strictly below a can be non-zero. Since a <= lastIndex, stopping at a
          // also keeps the final entry untouched; everything else keeps its initial 0.
          for (int k = 0; k < a; k++) {
            if (1 - u[k] > 0) {
              gradient[k] = -1;
            }
          }
          return gradient;
        default:
          throw new RuntimeException("Unknown multidimensional loss function " + multi_loss);
      }
    }
예제 #3
0
    /**
     * Evaluates the multidimensional loss of {@code u} for index {@code a}.
     *
     * <ul>
     *   <li>{@code Categorical}: {@code max(1 - u[a], 0)} plus {@code max(1 + u[i], 0)} for every
     *       other index {@code i}.
     *   <li>{@code Ordinal}: for each {@code i} in {@code [0, u.length - 2]}, {@code max(1 - u[i], 0)}
     *       when {@code i < a} and the constant {@code 1} otherwise.
     * </ul>
     *
     * @param u vector the loss is evaluated on
     * @param a index into {@code u}; must lie in {@code [0, u.length - 1]}
     * @param multi_loss loss to evaluate; only {@code Categorical} and {@code Ordinal} are handled
     * @return the accumulated loss value
     * @throws IllegalArgumentException if {@code a} is outside {@code [0, u.length - 1]}
     * @throws RuntimeException if {@code multi_loss} is any other constant
     */
    public static double mloss(double[] u, int a, Loss multi_loss) {
      assert multi_loss.isForCategorical()
          : "Loss function " + multi_loss + " not applicable to categoricals";
      if (a < 0 || a > u.length - 1) {
        throw new IllegalArgumentException(
            "Index must be between 0 and " + String.valueOf(u.length - 1));
      }

      double sum = 0;
      switch (multi_loss) {
        case Categorical:
          // Accumulate the right term for each index directly. Adding max(1 + u[a], 0) to the
          // running sum and subtracting it afterwards swallows the other terms when u[a] is very
          // large and produces NaN (Inf - Inf) when u[a] is +Infinity.
          for (int i = 0; i < u.length; i++) {
            sum += Math.max(i == a ? 1 - u[i] : 1 + u[i], 0);
          }
          return sum;
        case Ordinal:
          // The last entry of u is deliberately not visited.
          for (int i = 0; i < u.length - 1; i++) {
            sum += Math.max(a > i ? 1 - u[i] : 1, 0);
          }
          return sum;
        default:
          throw new RuntimeException("Unknown multidimensional loss function " + multi_loss);
      }
    }