public static void compute_gradient2(int iteration) { double[][][] tmp_grad_h_hat_prime_s = new double[T - t0][n][K]; /* * compute * nti[t][i] = \sum_{j} { n_{ij} } * and * nti_h[t][j][k] = \sum_{i} { n_{ij}^{t} h_{ik}^{t} } */ double[][] nti = new double[T - t0][n]; double[][][] nti_h = new double[T - t0][n][K]; for (int t = 0; t < T - t0; t++) { double[][] G_t = GS.get(t); double[][] h_t = h_s.get(t); // h^{t} for (int i = 0; i < n; i++) for (int j = 0; j < n; j++) { nti[t][i] += G_t[i][j]; for (int k = 0; k < K; k++) { nti_h[t][j][k] += G_t[i][j] * h_t[i][k]; } } } for (int t = 0; t < T - t0; t++) { double delta_t = delta_prime_s.get(t); double[][] h_t = h_s.get(t); // h^{t} double[][] h_hat_prime_t = h_hat_prime_s.get(t); // \hat{h}^{t} double[][] mu_hat_t = mu_hat_s.get(t); // \hat{\mu}^{t} double[][] mu_hat_prime_t = mu_hat_prime_s.get(t); // \hat{\mu}'^{t} double[][] h_prime_t = h_prime_s.get(t); if (t != 0) { Matrix a = new Matrix(AS.get(t - 1)); Matrix hprime_pre_t = new Matrix(h_prime_s.get(t - 1)); Matrix ave_neighbors = a.times(hprime_pre_t); double[][] G_pre_t = GS.get(t - 1); // G^{t-1} double[][] A_pre_t = AS.get(t - 1); // A^{t-1} double[][] h_pre_t = h_s.get(t - 1); // h^{t-1} double[][] mu_hat_prime_pre_t = mu_hat_prime_s.get(t - 1); // \hat{\mu}'^{t-1} [t] for (int s = 0; s < T - t0; s++) { double[][] grad_mu_hat_prime_t = grad_mu_hat_prime_s.get(t * (T - t0) + s); double[][] grad_mu_hat_prime_pre_t = grad_mu_hat_prime_s.get((t - 1) * (T - t0) + s); double[] h2delta2 = new double[n]; for (int i = 0; i < n; i++) for (int k = 0; k < K; k++) { h2delta2[i] += 0.5 * h_t[i][k] * h_t[i][k] * delta_t * delta_t; } /* compute weighted_exp for later use */ double[][][] weighted_exp_num = new double[K][n][n]; double[][] weighted_exp_den = new double[K][n]; double[][][] weighted_exp = new double[K][n][n]; for (int i = 0; i < n; i++) for (int j = 0; j < n; j++) { double h_muhp = Operations.inner_product(h_t[j], mu_hat_prime_t[i], K); for (int k = 0; k < K; 
k++) { weighted_exp_num[k][i][j] = h_t[j][k] * Math.exp(h_muhp + h2delta2[j]); weighted_exp_den[k][j] += Math.exp(h_muhp + h2delta2[j]); } } for (int i = 0; i < n; i++) for (int j = 0; j < n; j++) for (int k = 0; k < K; k++) { weighted_exp[k][i][j] = weighted_exp_num[k][i][j] / weighted_exp_den[k][j]; } /* compute sum_mu_hat_prime for later use */ double[] sum_mu_hat_prime = new double[K]; for (int i = 0; i < n; i++) for (int k = 0; k < K; k++) { sum_mu_hat_prime[k] += mu_hat_prime_pre_t[i][k]; } for (int i = 0; i < n; i++) for (int k = 0; k < K; k++) { /* first term */ double g1 = nti_h[t][i][k] * grad_mu_hat_prime_t[i][k]; tmp_grad_h_hat_prime_s[s][i][k] += g1; /* second term */ double g2 = 0; for (int j = 0; j < n; j++) { g2 -= nti[t][j] * weighted_exp[k][i][j] * grad_mu_hat_prime_t[i][k]; } tmp_grad_h_hat_prime_s[s][i][k] += g2; /* third term */ for (int j = 0; j < n; j++) if (G_pre_t[j][i] != 0) { // double g3 = ( h_t[j][k] - (1-lambda) * h_pre_t[j][k] - lambda * // A_pre_t[j][i] * sum_mu_hat_prime[k] ) double g3 = (h_t[j][k] - (1 - lambda) * h_pre_t[j][k] - lambda * A_pre_t[j][i] * mu_hat_prime_pre_t[i][k]) * lambda * A_pre_t[j][i] * grad_mu_hat_prime_pre_t[i][k] / (sigma * sigma); tmp_grad_h_hat_prime_s[s][j][k] += g3; // j instead of i! 
} } /* fourth term */ for (int i = 0; i < n; i++) for (int k = 0; k < K; k++) { double g4 = -(mu_hat_prime_t[i][k] - mu_hat_prime_pre_t[i][k]) * (grad_mu_hat_prime_t[i][k] - grad_mu_hat_prime_pre_t[i][k]) / (sigma * sigma); tmp_grad_h_hat_prime_s[s][i][k] += g4; } } } else { /* for (int s = 0; s < T-t0; s++) { double[] grad_mu_hat_prime_t = grad_mu_hat_prime_s.get(t * (T-t0) + s); for (int i = 0; i < n; i++) { // first term double g1 = nti_hp[t][i] * grad_mu_hat_prime_t[i]; tmp_grad_h_hat_prime_s[s][i] += g1; // second term double g2 = 0; for (int _j = 0; _j < NEG; _j++) { double weighted_exp_num = 0, weighted_exp_den = 0; int j = neg_samples.get(t)[i][_j]; double htj = h_t[j][0]; double muhti = mu_hat_t[i]; weighted_exp_num += htj * Math.exp(htj * muhti + 0.5 * htj * htj * delta_t * delta_t); for (int _k = 0; _k < NEG; _k++) { int k = neg_samples.get(t)[i][_k]; double muhtk = mu_hat_t[k]; weighted_exp_den += Math.exp(htj * muhtk + 0.5 * htj * htj * delta_t * delta_t); } g2 -= nti[t][j] * weighted_exp_num / weighted_exp_den * grad_mu_hat_prime_t[i]; } tmp_grad_h_hat_prime_s[s][i] += g2; } // fourth term (if any) if (s == t) for (int i = 0; i < n; i++) { double g4 = -h_hat_prime_t[i][0] / (sigma*sigma); tmp_grad_h_hat_prime_s[s][i] += g4; } } */ } } /* update global gradient */ for (int t = 0; t < T - t0; t++) { double[][] grad = new double[n][K]; for (int i = 0; i < n; i++) for (int k = 0; k < K; k++) { grad[i][k] = tmp_grad_h_hat_prime_s[t][i][k]; } grad_h_hat_prime_s.set(t, grad); } FileParser.output_2d(grad_h_hat_prime_s, "./grad/grad_prime_" + iteration + ".txt"); return; }
public static void compute_gradient1(int iteration) { double[][][] tmp_grad_h_hat_s = new double[T - t0][n][K]; for (int t = 0; t < T - t0; t++) { // System.out.println("compute gradient 1, t = " + t); double delta_t = delta_s.get(t); double[][] G_t = GS.get(t); double[][] h_prime_t = h_prime_s.get(t); double[][] mu_hat_t = mu_hat_s.get(t); if (t != 0) { double[][] mu_hat_pre_t = mu_hat_s.get(t - 1); Matrix a = new Matrix(AS.get(t - 1)); Matrix hprime_pre_t = new Matrix(h_prime_s.get(t - 1)); Matrix ave_neighbors = a.times(hprime_pre_t); /* TODO: check whether we can save computation by comparing s and t */ for (int s = 0; s < T - t0; s++) { double[][] grad_hat_t = grad_mu_hat_s.get(t * (T - t0) + s); double[][] grad_hat_pre_t = grad_mu_hat_s.get((t - 1) * (T - t0) + s); double[] hp2delta2 = new double[n]; for (int i = 0; i < n; i++) for (int k = 0; k < K; k++) { hp2delta2[i] += 0.5 * h_prime_t[i][k] * h_prime_t[i][k] * delta_t * delta_t; } for (int i = 0; i < n; i++) { /* first term */ double[] weighted_exp_num = new double[K]; double weighted_exp_den = 0; for (int l = 0; l < n; l++) { double hp_muh = Operations.inner_product(h_prime_t[l], mu_hat_t[i], K); double e = Math.exp(hp_muh + hp2delta2[l]); if (Double.isNaN(e)) { /* check if e explodes */ System.out.println("ERROR2"); Scanner sc = new Scanner(System.in); int gu; gu = sc.nextInt(); } for (int k = 0; k < K; k++) { weighted_exp_num[k] += h_prime_t[l][k] * e; weighted_exp_den += e; } } for (int j = 0; j < n; j++) for (int k = 0; k < K; k++) { double weighted_exp = weighted_exp_num[k] / weighted_exp_den; double gi1 = G_t[i][j] * grad_hat_t[i][k] * (h_prime_t[j][k] - weighted_exp); tmp_grad_h_hat_s[s][i][k] += gi1; } /* second term */ for (int k = 0; k < K; k++) { double gi2 = -(mu_hat_t[i][k] - (1 - lambda) * mu_hat_pre_t[i][k] - lambda * ave_neighbors.get(i, k)) * (grad_hat_t[i][k] - (1 - lambda) * grad_hat_pre_t[i][k]) / (sigma * sigma); tmp_grad_h_hat_s[s][i][k] += gi2; } } } } else { /* no such term (t=0) 
in ELBO */ /* for (int s = 0; s < T-t0; s++) { double[] grad_hat_t = grad_mu_hat_s.get(t * (T-t0) + s); for (int i = 0; i < n; i++) { double n_it = 0; for (int j = 0; j < n; j++) n_it += G_t[i][j]; // first term double gi1 = -mu_hat_t[i] * grad_hat_t[i] / (sigma * sigma); tmp_grad_h_hat_s[s][i] += gi1; // second term double gi2 = 0; double weighted_exp_num = 0, weighted_exp_den = 0; for (int j = 0; j < NEG; j++) { int l = neg_samples.get(t)[i][j]; double hpl = h_prime_t[l][0]; double muit = mu_hat_t[i]; double e = Math.exp(hpl * muit + 0.5 * hpl * hpl * delta_t * delta_t); // TODO: check if e explodes if (Double.isNaN(e)) { System.out.println("ERROR3"); Scanner sc = new Scanner(System.in); int gu; gu = sc.nextInt(); } weighted_exp_num += hpl * e; weighted_exp_den += e; } double weighted_exp = weighted_exp_num / weighted_exp_den; for (int j = 0; j < n; j++) { gi2 += G_t[i][j] * grad_hat_t[i] * (h_prime_t[j][0] - weighted_exp); } tmp_grad_h_hat_s[s][i] += gi2; } } */ } /* end if-else */ } /* update global gradient */ for (int t = 0; t < T - t0; t++) { double[][] grad = new double[n][K]; for (int i = 0; i < n; i++) for (int k = 0; k < K; k++) { grad[i][k] = tmp_grad_h_hat_s[t][i][k]; } grad_h_hat_s.set(t, grad); } FileParser.output_2d(grad_h_hat_s, "./grad/grad_" + iteration + ".txt"); return; }
/**
 * compute_objective2: return the lower bound when h is fixed.
 *
 * Sums, over every t >= 1 and every node i, three terms: a link term
 * G^{t}[i][j] * (h_i . mu'_j - lse_i), a quadratic penalty on
 * h^{t} - (1 - lambda) * h^{t-1} - lambda * (neighbour average of mu'^{t-1}),
 * and a quadratic penalty on mu'^{t} - mu'^{t-1}. The t = 0 slice contributes nothing.
 *
 * @return the accumulated objective value
 */
public static double compute_objective2() {
    double res = 0;
    for (int t = 1; t < T - t0; t++) {
        double[][] G_t = GS.get(t);
        double[][] G_t_pre = GS.get(t - 1);
        double[][] h_t = h_s.get(t);
        double[][] h_pre_t = h_s.get(t - 1);
        double[][] mu_hat_prime_t = mu_hat_prime_s.get(t);
        double[][] mu_hat_prime_pre_t = mu_hat_prime_s.get(t - 1);
        /*
         * Uses delta_prime_s so that the 0.5 * |h|^2 * delta^2 term matches the one
         * in compute_gradient2, which also reads delta_prime_s. The previous code read
         * delta_s here.
         */
        double delta_t = delta_prime_s.get(t);
        double[][] a_pre = AS.get(t - 1);

        /* ave_neighbors[i][k] = \sum_{j : G^{t-1}[i][j] != 0} A^{t-1}[i][j] * mu'^{t-1}[j][k] */
        double[][] ave_neighbors = new double[n][K];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                if (G_t_pre[i][j] != 0) {
                    for (int k = 0; k < K; k++) {
                        ave_neighbors[i][k] += a_pre[i][j] * mu_hat_prime_pre_t[j][k];
                    }
                }
            }
        }

        for (int i = 0; i < n; i++) {
            /* first term */
            double h2delta2 = 0;
            for (int k = 0; k < K; k++) {
                h2delta2 += 0.5 * h_t[i][k] * h_t[i][k] * delta_t * delta_t;
            }
            /* cache h_i . mu'_l so it is not recomputed for every linked j */
            double[] h_muhp = new double[n];
            List<Double> powers = new ArrayList<Double>(n);
            for (int l = 0; l < n; l++) {
                h_muhp[l] = Operations.inner_product(h_t[i], mu_hat_prime_t[l], K);
                powers.add(h_muhp[l] + h2delta2);
            }
            double lse = log_sum_exp(powers);
            for (int j = 0; j < n; j++) {
                if (G_t[i][j] != 0) {
                    res += G_t[i][j] * (h_muhp[j] - lse);
                }
            }

            /* second term */
            for (int k = 0; k < K; k++) {
                double diff = h_t[i][k] - (1 - lambda) * h_pre_t[i][k] - lambda * ave_neighbors[i][k];
                res -= 0.5 * diff * diff / (sigma * sigma);
            }

            /* third term */
            for (int k = 0; k < K; k++) {
                double diff_3 = mu_hat_prime_t[i][k] - mu_hat_prime_pre_t[i][k];
                res -= 0.5 * diff_3 * diff_3 / (sigma * sigma);
            }
        }
    }
    return res;
}
/**
 * compute_objective1: return the lower bound when h' is fixed.
 *
 * For each t >= 1 and each node i, adds a link term
 * G^{t}[i][j] * (h'_j . mu_i - lse_i) and subtracts a quadratic penalty on
 * mu^{t} - (1 - lambda) * mu^{t-1} - lambda * (A^{t-1} * h'^{t-1}).
 * Nothing is added for t = 0.
 *
 * @return the accumulated objective value
 */
public static double compute_objective1() {
    double total = 0;
    for (int t = 1; t < T - t0; t++) {
        double[][] links = GS.get(t);
        double[][] hPrimeNow = h_prime_s.get(t);
        double[][] muNow = mu_hat_s.get(t);
        double[][] muBefore = mu_hat_s.get(t - 1);
        double delta_t = delta_s.get(t);

        /* neighbourAvg = A^{t-1} * h'^{t-1} */
        Matrix adjacency = new Matrix(AS.get(t - 1));
        Matrix hPrimeBefore = new Matrix(h_prime_s.get(t - 1));
        Matrix neighbourAvg = adjacency.times(hPrimeBefore);

        /* halfSqNorm[l] = 0.5 * |h'_l|^2 * delta_t^2 */
        double[] halfSqNorm = new double[n];
        for (int row = 0; row < n; row++) {
            for (int dim = 0; dim < K; dim++) {
                halfSqNorm[row] += 0.5 * hPrimeNow[row][dim] * hPrimeNow[row][dim] * delta_t * delta_t;
            }
        }

        for (int i = 0; i < n; i++) {
            /* first term */
            List<Double> exponents = new ArrayList<Double>();
            for (int l = 0; l < n; l++) {
                double dot = Operations.inner_product(hPrimeNow[l], muNow[i], K);
                exponents.add(dot + halfSqNorm[l]);
            }
            double lse = log_sum_exp(exponents);
            for (int j = 0; j < n; j++) {
                if (links[i][j] == 0) {
                    continue; // only linked pairs contribute
                }
                double dot = Operations.inner_product(hPrimeNow[j], muNow[i], K);
                total += links[i][j] * (dot - lse);
            }

            /* second term */
            for (int dim = 0; dim < K; dim++) {
                double residual = muNow[i][dim]
                        - (1 - lambda) * muBefore[i][dim]
                        - lambda * neighbourAvg.get(i, dim);
                total -= 0.5 * residual * residual / (sigma * sigma);
            }
        }
    }
    return total;
}