/**
 * Chooses the level whose multidimensional loss against {@code u} is smallest.
 *
 * <p>Every index {@code a} in {@code [0, u.length)} is treated as a candidate level;
 * {@code mloss(u, a, multi_loss)} is evaluated for each one and the position of the
 * minimum is delegated to {@code ArrayUtils.minIndex}.
 *
 * @param u          input vector; its length determines the number of candidate levels
 * @param multi_loss loss to evaluate; only {@code Categorical} and {@code Ordinal} are handled
 * @return the index reported by {@code ArrayUtils.minIndex} over the per-level losses
 *         (tie-breaking is whatever that helper does -- not visible from here)
 * @throws RuntimeException if {@code multi_loss} is neither {@code Categorical} nor {@code Ordinal}
 */
public static int mimpute(double[] u, Loss multi_loss) {
  assert multi_loss.isForCategorical() : "Loss function " + multi_loss + " not applicable to categoricals";

  // Guard: reject unsupported losses before doing any work on u.
  switch (multi_loss) {
    case Categorical:
    case Ordinal:
      break;
    default:
      throw new RuntimeException("Unknown multidimensional loss function " + multi_loss);
  }

  // Both supported losses are imputed the same way: score every level, keep the cheapest.
  final double[] lossPerLevel = new double[u.length];
  for (int level = 0; level < lossPerLevel.length; level++) {
    lossPerLevel[level] = mloss(u, level, multi_loss);
  }
  return ArrayUtils.minIndex(lossPerLevel);
}
/**
 * Builds the per-coordinate gradient vector of the multidimensional loss at {@code u}
 * for observed level {@code a}.
 *
 * <ul>
 *   <li>{@code Categorical}: entry {@code i != a} is {@code 1} when {@code 1 + u[i] > 0},
 *       otherwise {@code 0}; entry {@code a} is {@code -1} when {@code 1 - u[a] > 0},
 *       otherwise {@code 0}.</li>
 *   <li>{@code Ordinal}: entry {@code i < a} is {@code -1} when {@code 1 - u[i] > 0};
 *       every other entry (including the last one) is {@code 0}.</li>
 * </ul>
 *
 * @param u          input vector
 * @param a          observed level, must satisfy {@code 0 <= a <= u.length - 1}
 * @param multi_loss loss whose gradient is requested
 * @return a freshly allocated array of length {@code u.length}
 * @throws IllegalArgumentException if {@code a} is outside {@code [0, u.length - 1]}
 * @throws RuntimeException         if {@code multi_loss} is neither {@code Categorical} nor {@code Ordinal}
 */
public static double[] mlgrad(double[] u, int a, Loss multi_loss) {
  assert multi_loss.isForCategorical() : "Loss function " + multi_loss + " not applicable to categoricals";

  final boolean indexInRange = a >= 0 && a < u.length;
  if (!indexInRange) {
    throw new IllegalArgumentException("Index must be between 0 and " + (u.length - 1));
  }

  // Java zero-initialises the array, so only non-zero entries need to be written.
  final double[] grad = new double[u.length];
  switch (multi_loss) {
    case Categorical:
      for (int k = 0; k < grad.length; k++) {
        if (1 + u[k] > 0) {
          grad[k] = 1;
        }
      }
      // The observed level uses the opposite hinge; overwrite whatever the loop stored.
      if (1 - u[a] > 0) {
        grad[a] = -1;
      } else {
        grad[a] = 0;
      }
      break;

    case Ordinal:
      // Only positions below a can be non-zero; a <= u.length - 1 was validated above,
      // so this range never reaches the final entry.
      for (int k = 0; k < a; k++) {
        if (1 - u[k] > 0) {
          grad[k] = -1;
        }
      }
      break;

    default:
      throw new RuntimeException("Unknown multidimensional loss function " + multi_loss);
  }
  return grad;
}
/**
 * Evaluates the multidimensional loss of {@code u} for observed level {@code a}.
 *
 * <ul>
 *   <li>{@code Categorical}: sums {@code max(1 + u[i], 0)} over all {@code i}, then
 *       swaps the term at {@code a} for {@code max(1 - u[a], 0)}.</li>
 *   <li>{@code Ordinal}: over the first {@code u.length - 1} positions, adds
 *       {@code max(1 - u[i], 0)} when {@code i < a} and a constant {@code 1} otherwise.</li>
 * </ul>
 *
 * @param u          input vector
 * @param a          observed level, must satisfy {@code 0 <= a <= u.length - 1}
 * @param multi_loss loss to evaluate
 * @return the accumulated loss value
 * @throws IllegalArgumentException if {@code a} is outside {@code [0, u.length - 1]}
 * @throws RuntimeException         if {@code multi_loss} is neither {@code Categorical} nor {@code Ordinal}
 */
public static double mloss(double[] u, int a, Loss multi_loss) {
  assert multi_loss.isForCategorical() : "Loss function " + multi_loss + " not applicable to categoricals";

  final boolean indexInRange = a >= 0 && a < u.length;
  if (!indexInRange) {
    throw new IllegalArgumentException("Index must be between 0 and " + (u.length - 1));
  }

  double total = 0;
  switch (multi_loss) {
    case Categorical: {
      for (double entry : u) {
        total += Math.max(1 + entry, 0);
      }
      // Replace the generic hinge at position a by the hinge for the observed level.
      // The difference is formed first and then added, keeping the summation order fixed.
      final double observedHinge = Math.max(1 - u[a], 0);
      final double genericHinge = Math.max(1 + u[a], 0);
      total += observedHinge - genericHinge;
      break;
    }

    case Ordinal: {
      final int lastExclusive = u.length - 1;  // the final entry of u is never read here
      for (int k = 0; k < lastExclusive; k++) {
        if (a > k) {
          total += Math.max(1 - u[k], 0);
        } else {
          total += 1;
        }
      }
      break;
    }

    default:
      throw new RuntimeException("Unknown multidimensional loss function " + multi_loss);
  }
  return total;
}