private Matrix SqrtSPKF(Matrix PDash) { // Only works with a symmetric Positive Definite matrix // [V,D] = eig(A) // S = V*diag(sqrt(diag(D)))*V' // // //PDash in // System.out.println("PDash: "); // PDash.print(3, 2); // Matrix NegKeeper = Matrix.identity(9, 9); // //Testing Forced Compliance // for(int i = 0; i< 9; i++){ // if (PDash.get(i, i)< 0){ // NegKeeper.set(i,i,-1); // PDash.set(i, i, -PDash.get(i, i)); // } // } EigenvalueDecomposition eig = PDash.eig(); Matrix V = eig.getV(); Matrix D = eig.getD(); int iDSize = D.getRowDimension(); for (int i = 0; i < iDSize; i++) { D.set(i, i, Math.sqrt(D.get(i, i))); } Matrix S = V.times(D).times(V.inverse()); // S = S.times(NegKeeper); return S; }
/** toy example */ public static void test2() { int N = 500; double[][] m1 = new double[N][N]; double[][] m2 = new double[N][N]; double[][] m3 = new double[N][N]; // init Random rand = new Random(); for (int i = 0; i < N; i++) for (int j = 0; j < N; j++) { m1[i][j] = 10 * (rand.nextDouble() - 0.2); m2[i][j] = 20 * (rand.nextDouble() - 0.8); } // inverse System.out.println("Start"); Matrix mat1 = new Matrix(m1); Matrix mat2 = mat1.inverse(); Matrix mat3 = mat1.times(mat2); double[][] m4 = mat3.getArray(); /* for (int i = 0; i < m4.length; i++) { int ss = 10; for (int j = 0; j < ss; j++) { System.out.printf("%f ", m4[i][j]); } System.out.print("\n"); } */ System.out.println("Done"); /* // matrix * System.out.println("Start"); for (int i = 0; i < N; i++) for (int j = 0; j < N; j++) { double cell = 0; for (int k = 0; k < N; k++) cell += m1[i][k] * m2[k][j]; // System.out.printf("%f ", cell); m3[i][j] = cell; } System.out.println("Done"); */ }
private void RunSPKF() { // SPKF Steps: // 1) Generate Test Points // 2) Propagate Test Points // 3) Compute Predicted Mean and Covariance // 4) Compute Measurements // 5) Compute Innovations and Cross Covariance // 6) Compute corrections and update // Line up initial variables from the controller! Double dAlpha = dGreek.get(0); Double dBeta = dGreek.get(1); cController.setAlpha(dAlpha); cController.setBeta(dBeta); cController.setKappa(dGreek.get(2)); Double dGamma = cController.getGamma(); Double dLambda = cController.getLambda(); // // DEBUG - Print the Greeks // System.out.println("Greeks!"); // System.out.println("Alpha - " + dAlpha); // System.out.println("Beta - " + dBeta); // System.out.println("Kappa - " + dGreek.get(2)); // System.out.println("Lambda - " + dLambda); // System.out.println("Gamma - " + dGamma); // Let's get started: // Step 1: Generate Test Points Vector<Matrix> Chi = new Vector<Matrix>(); Vector<Matrix> UpChi = new Vector<Matrix>(); Vector<Matrix> UpY = new Vector<Matrix>(); Matrix UpPx = new Matrix(3, 3, 0.0); Matrix UpPy = new Matrix(3, 3, 0.0); Matrix UpPxy = new Matrix(3, 3, 0.0); Matrix K; Vector<Double> wc = new Vector<Double>(); Vector<Double> wm = new Vector<Double>(); Chi.add(X); // Add Chi_0 - the current state estimate (X, Y, Z) // Big P Matrix is LxL diagonal Matrix SqrtP = SqrtSPKF(P); SqrtP = SqrtP.times(dGamma); // Set up Sigma Points for (int i = 0; i <= 8; i++) { Matrix tempVec = SqrtP.getMatrix(0, 8, i, i); Matrix tempX = X; Matrix tempPlus = tempX.plus(tempVec); // System.out.println("TempPlus"); // tempPlus.print(3, 2); Matrix tempMinu = tempX.minus(tempVec); // System.out.println("TempMinus"); // tempMinu.print(3, 2); // tempX = X.copy(); // tempX.setMatrix(i, i, 0, 2, tempPlus); Chi.add(tempPlus); // tempX = X.copy(); // tempX.setMatrix(i, i, 0, 2, tempMinu); Chi.add(tempMinu); } // DEBUG Print the lines inside the Chi Matrix (2L x L) // for (int i = 0; i<=(2*L); i++){ // System.out.println("Chi Matrix Set: "+i); // 
Chi.get(i).print(5, 2); // } // Generate weights Double WeightZero = (dLambda / (L + dLambda)); Double OtherWeight = (1 / (2 * (L + dLambda))); Double TotalWeight = WeightZero; wm.add(WeightZero); wc.add(WeightZero + (1 - (dAlpha * dAlpha) + dBeta)); for (int i = 1; i <= (2 * L); i++) { TotalWeight = TotalWeight + OtherWeight; wm.add(OtherWeight); wc.add(OtherWeight); } // Weights MUST BE 1 in total for (int i = 0; i <= (2 * L); i++) { wm.set(i, wm.get(i) / TotalWeight); wc.set(i, wc.get(i) / TotalWeight); } // //DEBUG Print the weights // System.out.println("Total Weight:"); // System.out.println(TotalWeight); // for (int i = 0; i<=(2*L); i++){ // System.out.println("Weight M for "+i+" Entry"); // System.out.println(wm.get(i)); // System.out.println("Weight C for "+i+" Entry"); // System.out.println(wc.get(i)); // } // Step 2: Propagate Test Points // This will also handle computing the mean Double ux = dControl.elementAt(0); Double uy = dControl.elementAt(1); Double uz = dControl.elementAt(2); Matrix XhatMean = new Matrix(3, 1, 0.0); for (int i = 0; i < Chi.size(); i++) { Matrix ChiOne = Chi.get(i); Matrix Chixminus = new Matrix(3, 1, 0.0); Double Xhat = ChiOne.get(0, 0); Double Yhat = ChiOne.get(1, 0); Double Zhat = ChiOne.get(2, 0); Double Xerr = ChiOne.get(3, 0); Double Yerr = ChiOne.get(4, 0); Double Zerr = ChiOne.get(5, 0); Xhat = Xhat + ux + Xerr; Yhat = Yhat + uy + Yerr; Zhat = Zhat + uz + Zerr; Chixminus.set(0, 0, Xhat); Chixminus.set(1, 0, Yhat); Chixminus.set(2, 0, Zhat); // System.out.println("ChixMinus:"); // Chixminus.print(3, 2); UpChi.add(Chixminus); XhatMean = XhatMean.plus(Chixminus.times(wm.get(i))); } // Mean is right! 
// System.out.println("XhatMean: "); // XhatMean.print(3, 2); // Step 3: Compute Predicted Mean and Covariance // Welp, we already solved the mean - let's do the covariance now for (int i = 0; i <= (2 * L); i++) { Matrix tempP = UpChi.get(i).minus(XhatMean); Matrix tempPw = tempP.times(wc.get(i)); tempP = tempPw.times(tempP.transpose()); UpPx = UpPx.plus(tempP); } // New Steps! // Step 4: Compute Measurements! (and Y mean!) Matrix YhatMean = new Matrix(3, 1, 0.0); for (int i = 0; i <= (2 * L); i++) { Matrix ChiOne = Chi.get(i); Matrix Chiyminus = new Matrix(3, 1, 0.0); Double Xhat = UpChi.get(i).get(0, 0); Double Yhat = UpChi.get(i).get(1, 0); Double Zhat = UpChi.get(i).get(2, 0); Double Xerr = ChiOne.get(6, 0); Double Yerr = ChiOne.get(7, 0); Double Zerr = ChiOne.get(8, 0); Xhat = Xhat + Xerr; Yhat = Yhat + Yerr; Zhat = Zhat + Zerr; Chiyminus.set(0, 0, Xhat); Chiyminus.set(1, 0, Yhat); Chiyminus.set(2, 0, Zhat); UpY.add(Chiyminus); YhatMean = YhatMean.plus(Chiyminus.times(wm.get(i))); } // // Welp, we already solved the mean - let's do the covariances // now // System.out.println("XHatMean and YHatMean = "); // XhatMean.print(3, 2); // YhatMean.print(3, 2); for (int i = 0; i <= (2 * L); i++) { Matrix tempPx = UpChi.get(i).minus(XhatMean); Matrix tempPy = UpY.get(i).minus(YhatMean); // System.out.println("ChiX - XhatMean and ChiY-YhatMean"); // tempPx.print(3, 2); // tempPy.print(3, 2); Matrix tempPxw = tempPx.times(wc.get(i)); Matrix tempPyw = tempPy.times(wc.get(i)); tempPx = tempPxw.times(tempPy.transpose()); tempPy = tempPyw.times(tempPy.transpose()); UpPy = UpPy.plus(tempPy); UpPxy = UpPxy.plus(tempPx); } // Step 6: Compute Corrections and Update // Compute Kalman Gain! 
// System.out.println("Updated Px"); // UpPx.print(5, 2); // System.out.println("Updated Py"); // UpPy.print(5, 2); // System.out.println("Updated Pxy"); // UpPxy.print(5, 2); K = UpPxy.times(UpPy.inverse()); // System.out.println("Kalman"); // K.print(5, 2); Matrix Mea = new Matrix(3, 1, 0.0); Mea.set(0, 0, dMeasure.get(0)); Mea.set(1, 0, dMeasure.get(1)); Mea.set(2, 0, dMeasure.get(2)); Matrix Out = K.times(Mea.minus(YhatMean)); Out = Out.plus(XhatMean); // System.out.println("Out:"); // Out.print(3, 2); Matrix Px = UpPx.minus(K.times(UpPy.times(K.transpose()))); // Update Stuff! // Push the P to the controller Matrix OutP = P.copy(); OutP.setMatrix(0, 2, 0, 2, Px); X.setMatrix(0, 2, 0, 0, Out); Residual = XhatMean.minus(Out); cController.inputState(OutP, Residual); // cController.setL(L); cController.startProcess(); while (!cController.finishedProcess()) { try { Thread.sleep(10); } catch (InterruptedException e) { e.printStackTrace(); } } // System.out.println("Post Greeks: " + cController.getAlpha() + " , // "+ cController.getBeta()); dGreek.set(0, cController.getAlpha()); dGreek.set(1, cController.getBeta()); dGreek.set(2, cController.getKappa()); P = cController.getP(); // System.out.println("P is post Process:"); // P.print(3, 2); StepDone = true; }
public static void main(String argv[]) { Matrix A, B, C, Z, O, I, R, S, X, SUB, M, T, SQ, DEF, SOL; // Uncomment this to test IO in a different locale. // Locale.setDefault(Locale.GERMAN); int errorCount = 0; int warningCount = 0; double tmp, s; double[] columnwise = {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}; double[] rowwise = {1., 4., 7., 10., 2., 5., 8., 11., 3., 6., 9., 12.}; double[][] avals = {{1., 4., 7., 10.}, {2., 5., 8., 11.}, {3., 6., 9., 12.}}; double[][] rankdef = avals; double[][] tvals = {{1., 2., 3.}, {4., 5., 6.}, {7., 8., 9.}, {10., 11., 12.}}; double[][] subavals = {{5., 8., 11.}, {6., 9., 12.}}; double[][] rvals = {{1., 4., 7.}, {2., 5., 8., 11.}, {3., 6., 9., 12.}}; double[][] pvals = {{4., 1., 1.}, {1., 2., 3.}, {1., 3., 6.}}; double[][] ivals = {{1., 0., 0., 0.}, {0., 1., 0., 0.}, {0., 0., 1., 0.}}; double[][] evals = { {0., 1., 0., 0.}, {1., 0., 2.e-7, 0.}, {0., -2.e-7, 0., 1.}, {0., 0., 1., 0.} }; double[][] square = {{166., 188., 210.}, {188., 214., 240.}, {210., 240., 270.}}; double[][] sqSolution = {{13.}, {15.}}; double[][] condmat = {{1., 3.}, {7., 9.}}; double[][] badeigs = { {0, 0, 0, 0, 0}, {0, 0, 0, 0, 1}, {0, 0, 0, 1, 0}, {1, 1, 0, 0, 1}, {1, 0, 1, 0, 1} }; int rows = 3, cols = 4; int invalidld = 5; /* should trigger bad shape for construction with val */ int raggedr = 0; /* (raggedr,raggedc) should be out of bounds in ragged array */ int raggedc = 4; int validld = 3; /* leading dimension of intended test Matrices */ int nonconformld = 4; /* leading dimension which is valid, but nonconforming */ int ib = 1, ie = 2, jb = 1, je = 3; /* index ranges for sub Matrix */ int[] rowindexset = {1, 2}; int[] badrowindexset = {1, 3}; int[] columnindexset = {1, 2, 3}; int[] badcolumnindexset = {1, 2, 4}; double columnsummax = 33.; double rowsummax = 30.; double sumofdiagonals = 15; double sumofsquares = 650; /** * Constructors and constructor-like methods: double[], int double[][] int, int int, int, double * int, int, double[][] 
constructWithCopy(double[][]) random(int,int) identity(int) */ print("\nTesting constructors and constructor-like methods...\n"); try { /** check that exception is thrown in packed constructor with invalid length * */ A = new Matrix(columnwise, invalidld); errorCount = try_failure( errorCount, "Catch invalid length in packed constructor... ", "exception not thrown for invalid input"); } catch (IllegalArgumentException e) { try_success("Catch invalid length in packed constructor... ", e.getMessage()); } try { /** check that exception is thrown in default constructor if input array is 'ragged' * */ A = new Matrix(rvals); tmp = A.get(raggedr, raggedc); } catch (IllegalArgumentException e) { try_success("Catch ragged input to default constructor... ", e.getMessage()); } catch (java.lang.ArrayIndexOutOfBoundsException e) { errorCount = try_failure( errorCount, "Catch ragged input to constructor... ", "exception not thrown in construction...ArrayIndexOutOfBoundsException thrown later"); } try { /** check that exception is thrown in constructWithCopy if input array is 'ragged' * */ A = Matrix.constructWithCopy(rvals); tmp = A.get(raggedr, raggedc); } catch (IllegalArgumentException e) { try_success("Catch ragged input to constructWithCopy... ", e.getMessage()); } catch (java.lang.ArrayIndexOutOfBoundsException e) { errorCount = try_failure( errorCount, "Catch ragged input to constructWithCopy... ", "exception not thrown in construction...ArrayIndexOutOfBoundsException thrown later"); } A = new Matrix(columnwise, validld); B = new Matrix(avals); tmp = B.get(0, 0); avals[0][0] = 0.0; C = B.minus(A); avals[0][0] = tmp; B = Matrix.constructWithCopy(avals); tmp = B.get(0, 0); avals[0][0] = 0.0; if ((tmp - B.get(0, 0)) != 0.0) { /** check that constructWithCopy behaves properly * */ errorCount = try_failure( errorCount, "constructWithCopy... ", "copy not effected... data visible outside"); } else { try_success("constructWithCopy... 
", ""); } avals[0][0] = columnwise[0]; I = new Matrix(ivals); try { check(I, Matrix.identity(3, 4)); try_success("identity... ", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure(errorCount, "identity... ", "identity Matrix not successfully created"); } /** * Access Methods: getColumnDimension() getRowDimension() getArray() getArrayCopy() * getColumnPackedCopy() getRowPackedCopy() get(int,int) getMatrix(int,int,int,int) * getMatrix(int,int,int[]) getMatrix(int[],int,int) getMatrix(int[],int[]) set(int,int,double) * setMatrix(int,int,int,int,Matrix) setMatrix(int,int,int[],Matrix) * setMatrix(int[],int,int,Matrix) setMatrix(int[],int[],Matrix) */ print("\nTesting access methods...\n"); /** Various get methods: */ B = new Matrix(avals); if (B.getRowDimension() != rows) { errorCount = try_failure(errorCount, "getRowDimension... ", ""); } else { try_success("getRowDimension... ", ""); } if (B.getColumnDimension() != cols) { errorCount = try_failure(errorCount, "getColumnDimension... ", ""); } else { try_success("getColumnDimension... ", ""); } B = new Matrix(avals); double[][] barray = B.getArray(); if (barray != avals) { errorCount = try_failure(errorCount, "getArray... ", ""); } else { try_success("getArray... ", ""); } barray = B.getArrayCopy(); if (barray == avals) { errorCount = try_failure(errorCount, "getArrayCopy... ", "data not (deep) copied"); } try { check(barray, avals); try_success("getArrayCopy... ", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure(errorCount, "getArrayCopy... ", "data not successfully (deep) copied"); } double[] bpacked = B.getColumnPackedCopy(); try { check(bpacked, columnwise); try_success("getColumnPackedCopy... ", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure( errorCount, "getColumnPackedCopy... ", "data not successfully (deep) copied by columns"); } bpacked = B.getRowPackedCopy(); try { check(bpacked, rowwise); try_success("getRowPackedCopy... 
", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure( errorCount, "getRowPackedCopy... ", "data not successfully (deep) copied by rows"); } try { tmp = B.get(B.getRowDimension(), B.getColumnDimension() - 1); errorCount = try_failure( errorCount, "get(int,int)... ", "OutOfBoundsException expected but not thrown"); } catch (java.lang.ArrayIndexOutOfBoundsException e) { try { tmp = B.get(B.getRowDimension() - 1, B.getColumnDimension()); errorCount = try_failure( errorCount, "get(int,int)... ", "OutOfBoundsException expected but not thrown"); } catch (java.lang.ArrayIndexOutOfBoundsException e1) { try_success("get(int,int)... OutofBoundsException... ", ""); } } catch (java.lang.IllegalArgumentException e1) { errorCount = try_failure( errorCount, "get(int,int)... ", "OutOfBoundsException expected but not thrown"); } try { if (B.get(B.getRowDimension() - 1, B.getColumnDimension() - 1) != avals[B.getRowDimension() - 1][B.getColumnDimension() - 1]) { errorCount = try_failure( errorCount, "get(int,int)... ", "Matrix entry (i,j) not successfully retreived"); } else { try_success("get(int,int)... ", ""); } } catch (java.lang.ArrayIndexOutOfBoundsException e) { errorCount = try_failure(errorCount, "get(int,int)... ", "Unexpected ArrayIndexOutOfBoundsException"); } SUB = new Matrix(subavals); try { M = B.getMatrix(ib, ie + B.getRowDimension() + 1, jb, je); errorCount = try_failure( errorCount, "getMatrix(int,int,int,int)... ", "ArrayIndexOutOfBoundsException expected but not thrown"); } catch (java.lang.ArrayIndexOutOfBoundsException e) { try { M = B.getMatrix(ib, ie, jb, je + B.getColumnDimension() + 1); errorCount = try_failure( errorCount, "getMatrix(int,int,int,int)... ", "ArrayIndexOutOfBoundsException expected but not thrown"); } catch (java.lang.ArrayIndexOutOfBoundsException e1) { try_success("getMatrix(int,int,int,int)... ArrayIndexOutOfBoundsException... 
", ""); } } catch (java.lang.IllegalArgumentException e1) { errorCount = try_failure( errorCount, "getMatrix(int,int,int,int)... ", "ArrayIndexOutOfBoundsException expected but not thrown"); } try { M = B.getMatrix(ib, ie, jb, je); try { check(SUB, M); try_success("getMatrix(int,int,int,int)... ", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure( errorCount, "getMatrix(int,int,int,int)... ", "submatrix not successfully retreived"); } } catch (java.lang.ArrayIndexOutOfBoundsException e) { errorCount = try_failure( errorCount, "getMatrix(int,int,int,int)... ", "Unexpected ArrayIndexOutOfBoundsException"); } try { M = B.getMatrix(ib, ie, badcolumnindexset); errorCount = try_failure( errorCount, "getMatrix(int,int,int[])... ", "ArrayIndexOutOfBoundsException expected but not thrown"); } catch (java.lang.ArrayIndexOutOfBoundsException e) { try { M = B.getMatrix(ib, ie + B.getRowDimension() + 1, columnindexset); errorCount = try_failure( errorCount, "getMatrix(int,int,int[])... ", "ArrayIndexOutOfBoundsException expected but not thrown"); } catch (java.lang.ArrayIndexOutOfBoundsException e1) { try_success("getMatrix(int,int,int[])... ArrayIndexOutOfBoundsException... ", ""); } } catch (java.lang.IllegalArgumentException e1) { errorCount = try_failure( errorCount, "getMatrix(int,int,int[])... ", "ArrayIndexOutOfBoundsException expected but not thrown"); } try { M = B.getMatrix(ib, ie, columnindexset); try { check(SUB, M); try_success("getMatrix(int,int,int[])... ", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure( errorCount, "getMatrix(int,int,int[])... ", "submatrix not successfully retreived"); } } catch (java.lang.ArrayIndexOutOfBoundsException e) { errorCount = try_failure( errorCount, "getMatrix(int,int,int[])... ", "Unexpected ArrayIndexOutOfBoundsException"); } try { M = B.getMatrix(badrowindexset, jb, je); errorCount = try_failure( errorCount, "getMatrix(int[],int,int)... 
", "ArrayIndexOutOfBoundsException expected but not thrown"); } catch (java.lang.ArrayIndexOutOfBoundsException e) { try { M = B.getMatrix(rowindexset, jb, je + B.getColumnDimension() + 1); errorCount = try_failure( errorCount, "getMatrix(int[],int,int)... ", "ArrayIndexOutOfBoundsException expected but not thrown"); } catch (java.lang.ArrayIndexOutOfBoundsException e1) { try_success("getMatrix(int[],int,int)... ArrayIndexOutOfBoundsException... ", ""); } } catch (java.lang.IllegalArgumentException e1) { errorCount = try_failure( errorCount, "getMatrix(int[],int,int)... ", "ArrayIndexOutOfBoundsException expected but not thrown"); } try { M = B.getMatrix(rowindexset, jb, je); try { check(SUB, M); try_success("getMatrix(int[],int,int)... ", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure( errorCount, "getMatrix(int[],int,int)... ", "submatrix not successfully retreived"); } } catch (java.lang.ArrayIndexOutOfBoundsException e) { errorCount = try_failure( errorCount, "getMatrix(int[],int,int)... ", "Unexpected ArrayIndexOutOfBoundsException"); } try { M = B.getMatrix(badrowindexset, columnindexset); errorCount = try_failure( errorCount, "getMatrix(int[],int[])... ", "ArrayIndexOutOfBoundsException expected but not thrown"); } catch (java.lang.ArrayIndexOutOfBoundsException e) { try { M = B.getMatrix(rowindexset, badcolumnindexset); errorCount = try_failure( errorCount, "getMatrix(int[],int[])... ", "ArrayIndexOutOfBoundsException expected but not thrown"); } catch (java.lang.ArrayIndexOutOfBoundsException e1) { try_success("getMatrix(int[],int[])... ArrayIndexOutOfBoundsException... ", ""); } } catch (java.lang.IllegalArgumentException e1) { errorCount = try_failure( errorCount, "getMatrix(int[],int[])... ", "ArrayIndexOutOfBoundsException expected but not thrown"); } try { M = B.getMatrix(rowindexset, columnindexset); try { check(SUB, M); try_success("getMatrix(int[],int[])... 
", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure( errorCount, "getMatrix(int[],int[])... ", "submatrix not successfully retreived"); } } catch (java.lang.ArrayIndexOutOfBoundsException e) { errorCount = try_failure( errorCount, "getMatrix(int[],int[])... ", "Unexpected ArrayIndexOutOfBoundsException"); } /** Various set methods: */ try { B.set(B.getRowDimension(), B.getColumnDimension() - 1, 0.); errorCount = try_failure( errorCount, "set(int,int,double)... ", "OutOfBoundsException expected but not thrown"); } catch (java.lang.ArrayIndexOutOfBoundsException e) { try { B.set(B.getRowDimension() - 1, B.getColumnDimension(), 0.); errorCount = try_failure( errorCount, "set(int,int,double)... ", "OutOfBoundsException expected but not thrown"); } catch (java.lang.ArrayIndexOutOfBoundsException e1) { try_success("set(int,int,double)... OutofBoundsException... ", ""); } } catch (java.lang.IllegalArgumentException e1) { errorCount = try_failure( errorCount, "set(int,int,double)... ", "OutOfBoundsException expected but not thrown"); } try { B.set(ib, jb, 0.); tmp = B.get(ib, jb); try { check(tmp, 0.); try_success("set(int,int,double)... ", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure( errorCount, "set(int,int,double)... ", "Matrix element not successfully set"); } } catch (java.lang.ArrayIndexOutOfBoundsException e1) { errorCount = try_failure( errorCount, "set(int,int,double)... ", "Unexpected ArrayIndexOutOfBoundsException"); } M = new Matrix(2, 3, 0.); try { B.setMatrix(ib, ie + B.getRowDimension() + 1, jb, je, M); errorCount = try_failure( errorCount, "setMatrix(int,int,int,int,Matrix)... ", "ArrayIndexOutOfBoundsException expected but not thrown"); } catch (java.lang.ArrayIndexOutOfBoundsException e) { try { B.setMatrix(ib, ie, jb, je + B.getColumnDimension() + 1, M); errorCount = try_failure( errorCount, "setMatrix(int,int,int,int,Matrix)... 
", "ArrayIndexOutOfBoundsException expected but not thrown"); } catch (java.lang.ArrayIndexOutOfBoundsException e1) { try_success("setMatrix(int,int,int,int,Matrix)... ArrayIndexOutOfBoundsException... ", ""); } } catch (java.lang.IllegalArgumentException e1) { errorCount = try_failure( errorCount, "setMatrix(int,int,int,int,Matrix)... ", "ArrayIndexOutOfBoundsException expected but not thrown"); } try { B.setMatrix(ib, ie, jb, je, M); try { check(M.minus(B.getMatrix(ib, ie, jb, je)), M); try_success("setMatrix(int,int,int,int,Matrix)... ", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure( errorCount, "setMatrix(int,int,int,int,Matrix)... ", "submatrix not successfully set"); } B.setMatrix(ib, ie, jb, je, SUB); } catch (java.lang.ArrayIndexOutOfBoundsException e1) { errorCount = try_failure( errorCount, "setMatrix(int,int,int,int,Matrix)... ", "Unexpected ArrayIndexOutOfBoundsException"); } try { B.setMatrix(ib, ie + B.getRowDimension() + 1, columnindexset, M); errorCount = try_failure( errorCount, "setMatrix(int,int,int[],Matrix)... ", "ArrayIndexOutOfBoundsException expected but not thrown"); } catch (java.lang.ArrayIndexOutOfBoundsException e) { try { B.setMatrix(ib, ie, badcolumnindexset, M); errorCount = try_failure( errorCount, "setMatrix(int,int,int[],Matrix)... ", "ArrayIndexOutOfBoundsException expected but not thrown"); } catch (java.lang.ArrayIndexOutOfBoundsException e1) { try_success("setMatrix(int,int,int[],Matrix)... ArrayIndexOutOfBoundsException... ", ""); } } catch (java.lang.IllegalArgumentException e1) { errorCount = try_failure( errorCount, "setMatrix(int,int,int[],Matrix)... ", "ArrayIndexOutOfBoundsException expected but not thrown"); } try { B.setMatrix(ib, ie, columnindexset, M); try { check(M.minus(B.getMatrix(ib, ie, columnindexset)), M); try_success("setMatrix(int,int,int[],Matrix)... ", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure( errorCount, "setMatrix(int,int,int[],Matrix)... 
", "submatrix not successfully set"); } B.setMatrix(ib, ie, jb, je, SUB); } catch (java.lang.ArrayIndexOutOfBoundsException e1) { errorCount = try_failure( errorCount, "setMatrix(int,int,int[],Matrix)... ", "Unexpected ArrayIndexOutOfBoundsException"); } try { B.setMatrix(rowindexset, jb, je + B.getColumnDimension() + 1, M); errorCount = try_failure( errorCount, "setMatrix(int[],int,int,Matrix)... ", "ArrayIndexOutOfBoundsException expected but not thrown"); } catch (java.lang.ArrayIndexOutOfBoundsException e) { try { B.setMatrix(badrowindexset, jb, je, M); errorCount = try_failure( errorCount, "setMatrix(int[],int,int,Matrix)... ", "ArrayIndexOutOfBoundsException expected but not thrown"); } catch (java.lang.ArrayIndexOutOfBoundsException e1) { try_success("setMatrix(int[],int,int,Matrix)... ArrayIndexOutOfBoundsException... ", ""); } } catch (java.lang.IllegalArgumentException e1) { errorCount = try_failure( errorCount, "setMatrix(int[],int,int,Matrix)... ", "ArrayIndexOutOfBoundsException expected but not thrown"); } try { B.setMatrix(rowindexset, jb, je, M); try { check(M.minus(B.getMatrix(rowindexset, jb, je)), M); try_success("setMatrix(int[],int,int,Matrix)... ", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure( errorCount, "setMatrix(int[],int,int,Matrix)... ", "submatrix not successfully set"); } B.setMatrix(ib, ie, jb, je, SUB); } catch (java.lang.ArrayIndexOutOfBoundsException e1) { errorCount = try_failure( errorCount, "setMatrix(int[],int,int,Matrix)... ", "Unexpected ArrayIndexOutOfBoundsException"); } try { B.setMatrix(rowindexset, badcolumnindexset, M); errorCount = try_failure( errorCount, "setMatrix(int[],int[],Matrix)... ", "ArrayIndexOutOfBoundsException expected but not thrown"); } catch (java.lang.ArrayIndexOutOfBoundsException e) { try { B.setMatrix(badrowindexset, columnindexset, M); errorCount = try_failure( errorCount, "setMatrix(int[],int[],Matrix)... 
", "ArrayIndexOutOfBoundsException expected but not thrown"); } catch (java.lang.ArrayIndexOutOfBoundsException e1) { try_success("setMatrix(int[],int[],Matrix)... ArrayIndexOutOfBoundsException... ", ""); } } catch (java.lang.IllegalArgumentException e1) { errorCount = try_failure( errorCount, "setMatrix(int[],int[],Matrix)... ", "ArrayIndexOutOfBoundsException expected but not thrown"); } try { B.setMatrix(rowindexset, columnindexset, M); try { check(M.minus(B.getMatrix(rowindexset, columnindexset)), M); try_success("setMatrix(int[],int[],Matrix)... ", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure( errorCount, "setMatrix(int[],int[],Matrix)... ", "submatrix not successfully set"); } } catch (java.lang.ArrayIndexOutOfBoundsException e1) { errorCount = try_failure( errorCount, "setMatrix(int[],int[],Matrix)... ", "Unexpected ArrayIndexOutOfBoundsException"); } /** * Array-like methods: minus minusEquals plus plusEquals arrayLeftDivide arrayLeftDivideEquals * arrayRightDivide arrayRightDivideEquals arrayTimes arrayTimesEquals uminus */ print("\nTesting array-like methods...\n"); S = new Matrix(columnwise, nonconformld); R = Matrix.random(A.getRowDimension(), A.getColumnDimension()); A = R; try { S = A.minus(S); errorCount = try_failure(errorCount, "minus conformance check... ", "nonconformance not raised"); } catch (IllegalArgumentException e) { try_success("minus conformance check... ", ""); } if (A.minus(R).norm1() != 0.) { errorCount = try_failure( errorCount, "minus... ", "(difference of identical Matrices is nonzero,\nSubsequent use of minus should be suspect)"); } else { try_success("minus... ", ""); } A = R.copy(); A.minusEquals(R); Z = new Matrix(A.getRowDimension(), A.getColumnDimension()); try { A.minusEquals(S); errorCount = try_failure(errorCount, "minusEquals conformance check... ", "nonconformance not raised"); } catch (IllegalArgumentException e) { try_success("minusEquals conformance check... 
", ""); } if (A.minus(Z).norm1() != 0.) { errorCount = try_failure( errorCount, "minusEquals... ", "(difference of identical Matrices is nonzero,\nSubsequent use of minus should be suspect)"); } else { try_success("minusEquals... ", ""); } A = R.copy(); B = Matrix.random(A.getRowDimension(), A.getColumnDimension()); C = A.minus(B); try { S = A.plus(S); errorCount = try_failure(errorCount, "plus conformance check... ", "nonconformance not raised"); } catch (IllegalArgumentException e) { try_success("plus conformance check... ", ""); } try { check(C.plus(B), A); try_success("plus... ", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure(errorCount, "plus... ", "(C = A - B, but C + B != A)"); } C = A.minus(B); C.plusEquals(B); try { A.plusEquals(S); errorCount = try_failure(errorCount, "plusEquals conformance check... ", "nonconformance not raised"); } catch (IllegalArgumentException e) { try_success("plusEquals conformance check... ", ""); } try { check(C, A); try_success("plusEquals... ", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure(errorCount, "plusEquals... ", "(C = A - B, but C = C + B != A)"); } A = R.uminus(); try { check(A.plus(R), Z); try_success("uminus... ", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure(errorCount, "uminus... ", "(-A + A != zeros)"); } A = R.copy(); O = new Matrix(A.getRowDimension(), A.getColumnDimension(), 1.0); C = A.arrayLeftDivide(R); try { S = A.arrayLeftDivide(S); errorCount = try_failure( errorCount, "arrayLeftDivide conformance check... ", "nonconformance not raised"); } catch (IllegalArgumentException e) { try_success("arrayLeftDivide conformance check... ", ""); } try { check(C, O); try_success("arrayLeftDivide... ", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure(errorCount, "arrayLeftDivide... ", "(M.\\M != ones)"); } try { A.arrayLeftDivideEquals(S); errorCount = try_failure( errorCount, "arrayLeftDivideEquals conformance check... 
", "nonconformance not raised"); } catch (IllegalArgumentException e) { try_success("arrayLeftDivideEquals conformance check... ", ""); } A.arrayLeftDivideEquals(R); try { check(A, O); try_success("arrayLeftDivideEquals... ", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure(errorCount, "arrayLeftDivideEquals... ", "(M.\\M != ones)"); } A = R.copy(); try { A.arrayRightDivide(S); errorCount = try_failure( errorCount, "arrayRightDivide conformance check... ", "nonconformance not raised"); } catch (IllegalArgumentException e) { try_success("arrayRightDivide conformance check... ", ""); } C = A.arrayRightDivide(R); try { check(C, O); try_success("arrayRightDivide... ", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure(errorCount, "arrayRightDivide... ", "(M./M != ones)"); } try { A.arrayRightDivideEquals(S); errorCount = try_failure( errorCount, "arrayRightDivideEquals conformance check... ", "nonconformance not raised"); } catch (IllegalArgumentException e) { try_success("arrayRightDivideEquals conformance check... ", ""); } A.arrayRightDivideEquals(R); try { check(A, O); try_success("arrayRightDivideEquals... ", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure(errorCount, "arrayRightDivideEquals... ", "(M./M != ones)"); } A = R.copy(); B = Matrix.random(A.getRowDimension(), A.getColumnDimension()); try { S = A.arrayTimes(S); errorCount = try_failure(errorCount, "arrayTimes conformance check... ", "nonconformance not raised"); } catch (IllegalArgumentException e) { try_success("arrayTimes conformance check... ", ""); } C = A.arrayTimes(B); try { check(C.arrayRightDivideEquals(B), A); try_success("arrayTimes... ", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure(errorCount, "arrayTimes... ", "(A = R, C = A.*B, but C./B != A)"); } try { A.arrayTimesEquals(S); errorCount = try_failure( errorCount, "arrayTimesEquals conformance check... 
", "nonconformance not raised"); } catch (IllegalArgumentException e) { try_success("arrayTimesEquals conformance check... ", ""); } A.arrayTimesEquals(B); try { check(A.arrayRightDivideEquals(B), R); try_success("arrayTimesEquals... ", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure(errorCount, "arrayTimesEquals... ", "(A = R, A = A.*B, but A./B != R)"); } /** I/O methods: read print serializable: writeObject readObject */ print("\nTesting I/O methods...\n"); try { DecimalFormat fmt = new DecimalFormat("0.0000E00"); fmt.setDecimalFormatSymbols(new DecimalFormatSymbols(Locale.US)); PrintWriter FILE = new PrintWriter(new FileOutputStream("JamaTestMatrix.out")); A.print(FILE, fmt, 10); FILE.close(); R = Matrix.read(new BufferedReader(new FileReader("JamaTestMatrix.out"))); if (A.minus(R).norm1() < .001) { try_success("print()/read()...", ""); } else { errorCount = try_failure( errorCount, "print()/read()...", "Matrix read from file does not match Matrix printed to file"); } } catch (java.io.IOException ioe) { warningCount = try_warning( warningCount, "print()/read()...", "unexpected I/O error, unable to run print/read test; check write permission in current directory and retry"); } catch (Exception e) { try { e.printStackTrace(System.out); warningCount = try_warning( warningCount, "print()/read()...", "Formatting error... 
will try JDK1.1 reformulation..."); DecimalFormat fmt = new DecimalFormat("0.0000"); PrintWriter FILE = new PrintWriter(new FileOutputStream("JamaTestMatrix.out")); A.print(FILE, fmt, 10); FILE.close(); R = Matrix.read(new BufferedReader(new FileReader("JamaTestMatrix.out"))); if (A.minus(R).norm1() < .001) { try_success("print()/read()...", ""); } else { errorCount = try_failure( errorCount, "print()/read() (2nd attempt) ...", "Matrix read from file does not match Matrix printed to file"); } } catch (java.io.IOException ioe) { warningCount = try_warning( warningCount, "print()/read()...", "unexpected I/O error, unable to run print/read test; check write permission in current directory and retry"); } } R = Matrix.random(A.getRowDimension(), A.getColumnDimension()); String tmpname = "TMPMATRIX.serial"; try { @SuppressWarnings("resource") ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(tmpname)); out.writeObject(R); @SuppressWarnings("resource") ObjectInputStream sin = new ObjectInputStream(new FileInputStream(tmpname)); A = (Matrix) sin.readObject(); try { check(A, R); try_success("writeObject(Matrix)/readObject(Matrix)...", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure( errorCount, "writeObject(Matrix)/readObject(Matrix)...", "Matrix not serialized correctly"); } } catch (java.io.IOException ioe) { warningCount = try_warning( warningCount, "writeObject()/readObject()...", "unexpected I/O error, unable to run serialization test; check write permission in current directory and retry"); } catch (Exception e) { errorCount = try_failure( errorCount, "writeObject(Matrix)/readObject(Matrix)...", "unexpected error in serialization test"); } /** * LA methods: transpose times cond rank det trace norm1 norm2 normF normInf solve * solveTranspose inverse chol eig lu qr svd */ print("\nTesting linear algebra methods...\n"); A = new Matrix(columnwise, 3); T = new Matrix(tvals); T = A.transpose(); try { check(A.transpose(), T); 
try_success("transpose...", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure(errorCount, "transpose()...", "transpose unsuccessful"); } A.transpose(); try { check(A.norm1(), columnsummax); try_success("norm1...", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure(errorCount, "norm1()...", "incorrect norm calculation"); } try { check(A.normInf(), rowsummax); try_success("normInf()...", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure(errorCount, "normInf()...", "incorrect norm calculation"); } try { check(A.normF(), Math.sqrt(sumofsquares)); try_success("normF...", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure(errorCount, "normF()...", "incorrect norm calculation"); } try { check(A.trace(), sumofdiagonals); try_success("trace()...", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure(errorCount, "trace()...", "incorrect trace calculation"); } try { check(A.getMatrix(0, A.getRowDimension() - 1, 0, A.getRowDimension() - 1).det(), 0.); try_success("det()...", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure(errorCount, "det()...", "incorrect determinant calculation"); } SQ = new Matrix(square); try { check(A.times(A.transpose()), SQ); try_success("times(Matrix)...", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure( errorCount, "times(Matrix)...", "incorrect Matrix-Matrix product calculation"); } try { check(A.times(0.), Z); try_success("times(double)...", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure( errorCount, "times(double)...", "incorrect Matrix-scalar product calculation"); } A = new Matrix(columnwise, 4); QRDecomposition QR = A.qr(); R = QR.getR(); try { check(A, QR.getQ().times(R)); try_success("QRDecomposition...", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure(errorCount, "QRDecomposition...", "incorrect QR decomposition calculation"); } 
SingularValueDecomposition SVD = A.svd(); try { check(A, SVD.getU().times(SVD.getS().times(SVD.getV().transpose()))); try_success("SingularValueDecomposition...", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure( errorCount, "SingularValueDecomposition...", "incorrect singular value decomposition calculation"); } DEF = new Matrix(rankdef); try { check(DEF.rank(), Math.min(DEF.getRowDimension(), DEF.getColumnDimension()) - 1); try_success("rank()...", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure(errorCount, "rank()...", "incorrect rank calculation"); } B = new Matrix(condmat); SVD = B.svd(); double[] singularvalues = SVD.getSingularValues(); try { check( B.cond(), singularvalues[0] / singularvalues[Math.min(B.getRowDimension(), B.getColumnDimension()) - 1]); try_success("cond()...", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure(errorCount, "cond()...", "incorrect condition number calculation"); } int n = A.getColumnDimension(); A = A.getMatrix(0, n - 1, 0, n - 1); A.set(0, 0, 0.); LUDecomposition LU = A.lu(); try { check(A.getMatrix(LU.getPivot(), 0, n - 1), LU.getL().times(LU.getU())); try_success("LUDecomposition...", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure(errorCount, "LUDecomposition...", "incorrect LU decomposition calculation"); } X = A.inverse(); try { check(A.times(X), Matrix.identity(3, 3)); try_success("inverse()...", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure(errorCount, "inverse()...", "incorrect inverse calculation"); } O = new Matrix(SUB.getRowDimension(), 1, 1.0); SOL = new Matrix(sqSolution); SQ = SUB.getMatrix(0, SUB.getRowDimension() - 1, 0, SUB.getRowDimension() - 1); try { check(SQ.solve(SOL), O); try_success("solve()...", ""); } catch (java.lang.IllegalArgumentException e1) { errorCount = try_failure(errorCount, "solve()...", e1.getMessage()); } catch (java.lang.RuntimeException e) { errorCount = 
try_failure(errorCount, "solve()...", e.getMessage()); } A = new Matrix(pvals); CholeskyDecomposition Chol = A.chol(); Matrix L = Chol.getL(); try { check(A, L.times(L.transpose())); try_success("CholeskyDecomposition...", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure( errorCount, "CholeskyDecomposition...", "incorrect Cholesky decomposition calculation"); } X = Chol.solve(Matrix.identity(3, 3)); try { check(A.times(X), Matrix.identity(3, 3)); try_success("CholeskyDecomposition solve()...", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure( errorCount, "CholeskyDecomposition solve()...", "incorrect Choleskydecomposition solve calculation"); } EigenvalueDecomposition Eig = A.eig(); Matrix D = Eig.getD(); Matrix V = Eig.getV(); try { check(A.times(V), V.times(D)); try_success("EigenvalueDecomposition (symmetric)...", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure( errorCount, "EigenvalueDecomposition (symmetric)...", "incorrect symmetric Eigenvalue decomposition calculation"); } A = new Matrix(evals); Eig = A.eig(); D = Eig.getD(); V = Eig.getV(); try { check(A.times(V), V.times(D)); try_success("EigenvalueDecomposition (nonsymmetric)...", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure( errorCount, "EigenvalueDecomposition (nonsymmetric)...", "incorrect nonsymmetric Eigenvalue decomposition calculation"); } try { print("\nTesting Eigenvalue; If this hangs, we've failed\n"); Matrix bA = new Matrix(badeigs); EigenvalueDecomposition bEig = bA.eig(); try_success("EigenvalueDecomposition (hang)...", ""); } catch (java.lang.RuntimeException e) { errorCount = try_failure(errorCount, "EigenvalueDecomposition (hang)...", "incorrect termination"); } print("\nTestMatrix completed.\n"); print("Total errors reported: " + Integer.toString(errorCount) + "\n"); print("Total warnings reported: " + Integer.toString(warningCount) + "\n"); }
/* * backward pass 1: update * (1) \hat{mu} (mu_hat_s) * (2) \hat{grad_mu} (grad_mu_hat_s) * (3) \hat{V} (v_hat_s) */ public static void backward1(boolean update_grad) { for (int t1 = T - 1; t1 > t0; t1--) { int t = t1 - t0; // System.out.println("backward 1;\tt = " + t1); if (t != T - 1 - t0) { double V_pre_t = v_s.get(t - 1); // V^{t-1} double V_hat_t = v_hat_s.get(t); // \hat{V}^{t} double[][] mu_pre_t = mu_s.get(t - 1); // \mu^{t-1} double[][] mu_hat_t = mu_hat_s.get(t); // \hat{\mu}^{t} [t-1] Matrix A_pre_t = new Matrix(AS.get(t - 1)); // A^{t-1} Matrix hprime_pre_t = new Matrix(h_prime_s.get(t - 1)); // h'^{t-1} Matrix ave_neighbors = A_pre_t.times(hprime_pre_t); // n * 1 /* calculate \hat{\mu} at time t-1 */ double factor_1 = (1 - lambda) * V_pre_t / (sigma * sigma + (1 - lambda) * (1 - lambda) * V_pre_t); double factor_2 = (sigma * sigma) / (sigma * sigma + (1 - lambda) * (1 - lambda) * V_pre_t); double[][] mu_hat_pre_t = new double[n][K]; for (int i = 0; i < n; i++) for (int k = 0; k < K; k++) { mu_hat_pre_t[i][k] = factor_1 * (mu_hat_t[i][k] - lambda * ave_neighbors.get(i, k)) + factor_2 * mu_pre_t[i][k]; } /* calculate \hat{V} at time t-1 */ double V_hat_pre_t = V_pre_t + factor_1 * factor_1 * (V_hat_t - (1 - lambda) * (1 - lambda) * V_pre_t - (sigma * sigma)); /* update \mu and V */ mu_hat_s.set(t - 1, mu_hat_pre_t); v_hat_s.set(t - 1, V_hat_pre_t); /* calculate and update grad_mu_hat at time t-1 */ if (update_grad) for (int s = 0; s < T - t0; s++) { double[][] grad_hat_t_s = grad_mu_hat_s.get(t * (T - t0) + s); double[][] grad_pre_t_s = grad_mu_s.get((t - 1) * (T - t0) + s); double[][] grad_hat_pre_t_s = new double[n][K]; for (int i = 0; i < n; i++) for (int k = 0; k < K; k++) { grad_hat_pre_t_s[i][k] = factor_1 * grad_hat_t_s[i][k] + factor_2 * grad_pre_t_s[i][k]; } grad_mu_hat_s.set((t - 1) * (T - t0) + s, grad_hat_pre_t_s); } } else { /* * initial condition for backward pass: * (1) \hat{mu}^{T} = mu^{T} * (2) \hat{V}^{T} = V^{T} * (3) 
\hat{grad_mu}^{T/s} = grad_mu^{T/s}, \forall s */ mu_hat_s.set(t, mu_s.get(t)); v_hat_s.set(t, v_s.get(t)); if (update_grad) for (int s = 0; s < T - t0; s++) { grad_mu_hat_s.set(t * (T - t0) + s, grad_mu_s.get(t * (T - t0) + s)); } } // Scanner sc = new Scanner(System.in); // int gu; gu = sc.nextInt(); /* end for each t */ } }
/** * forward pass 1: update intrinsic features (1) mu (mu_s) (2) grad_mu (grad_mu_s) (3) variance V * (v_s) */ public static void forward1(boolean update_grad, int iter) { /* if (iter == 4) { int t = 15; double[][] h_t = new double[n][K]; double[][] h_hat_t = new double[n][K]; double[][] h_prime_t = new double[n][K]; double[][] h_hat_prime_t = new double[n][K]; for (int i = 0; i < n; i++) for (int k = 0; k < K; k++) { h_t[i][k] = h_s.get(t-1)[i][k]; h_hat_t[i][k] = h_hat_s.get(t-1)[i][k]; h_prime_t[i][k] = h_prime_s.get(t-1)[i][k]; h_hat_prime_t[i][k] = h_hat_prime_s.get(t-1)[i][k]; } h_s.set(t, h_t); h_hat_s.set(t, h_hat_t); h_prime_s.set(t, h_prime_t); h_hat_prime_s.set(t, h_hat_prime_t); } */ for (int t = 0; t < T - t0; t++) { // System.out.println("forward 1;\tt = " + t1); if (t != 0) { double delta_t = delta_s.get(t); // delta_t double[][] h_hat_t = h_hat_s.get(t); // \hat{h}^t [t] double[][] mu_pre_t = mu_s.get(t - 1); // mu^{t-1} (N*1) double V_pre_t = v_s.get(t - 1); // V^{t-1} Matrix a = new Matrix(AS.get(t - 1)); // A^{t-1} Matrix hprime_pre_t = new Matrix(h_prime_s.get(t - 1)); // h'^{t-1} Matrix ave_neighbors = a.times(hprime_pre_t); /* calculate \mu */ double[][] mu_t = new double[n][K]; double factor_1 = (delta_t * delta_t) / (delta_t * delta_t + sigma * sigma + (1 - lambda) * (1 - lambda) * V_pre_t); double factor_2 = (sigma * sigma + (1 - lambda) * (1 - lambda) * V_pre_t) / (delta_t * delta_t + sigma * sigma + (1 - lambda) * (1 - lambda) * V_pre_t); for (int i = 0; i < n; i++) for (int k = 0; k < K; k++) { mu_t[i][k] = factor_1 * ((1 - lambda) * mu_pre_t[i][k] + lambda * ave_neighbors.get(i, k)) + factor_2 * h_hat_t[i][k]; } /* calculate V */ double V_t = factor_2 * delta_t * delta_t; /* update \mu and V */ mu_s.set(t, mu_t); v_s.set(t, V_t); /* calculate and update grad_mu */ if (update_grad) for (int s = 0; s < T - t0; s++) { double[][] grad_pre_t_s = grad_mu_s.get((t - 1) * (T - t0) + s); double[][] grad_t_s = new double[n][K]; for (int i = 0; i 
< n; i++) for (int k = 0; k < K; k++) { grad_t_s[i][k] = factor_1 * (1 - lambda) * grad_pre_t_s[i][k]; if (t == s) { grad_t_s[i][k] += factor_2; } } grad_mu_s.set(t * (T - t0) + s, grad_t_s); } } else { /* mu, V: random init (keep unchanged) */ /* grad_mu: set to 0 (keep unchanged) */ } // Scanner sc = new Scanner(System.in); // int gu; gu = sc.nextInt(); /* end for each t */ } }
public static void compute_gradient2(int iteration) { double[][][] tmp_grad_h_hat_prime_s = new double[T - t0][n][K]; /* * compute * nti[t][i] = \sum_{j} { n_{ij} } * and * nti_h[t][j][k] = \sum_{i} { n_{ij}^{t} h_{ik}^{t} } */ double[][] nti = new double[T - t0][n]; double[][][] nti_h = new double[T - t0][n][K]; for (int t = 0; t < T - t0; t++) { double[][] G_t = GS.get(t); double[][] h_t = h_s.get(t); // h^{t} for (int i = 0; i < n; i++) for (int j = 0; j < n; j++) { nti[t][i] += G_t[i][j]; for (int k = 0; k < K; k++) { nti_h[t][j][k] += G_t[i][j] * h_t[i][k]; } } } for (int t = 0; t < T - t0; t++) { double delta_t = delta_prime_s.get(t); double[][] h_t = h_s.get(t); // h^{t} double[][] h_hat_prime_t = h_hat_prime_s.get(t); // \hat{h}^{t} double[][] mu_hat_t = mu_hat_s.get(t); // \hat{\mu}^{t} double[][] mu_hat_prime_t = mu_hat_prime_s.get(t); // \hat{\mu}'^{t} double[][] h_prime_t = h_prime_s.get(t); if (t != 0) { Matrix a = new Matrix(AS.get(t - 1)); Matrix hprime_pre_t = new Matrix(h_prime_s.get(t - 1)); Matrix ave_neighbors = a.times(hprime_pre_t); double[][] G_pre_t = GS.get(t - 1); // G^{t-1} double[][] A_pre_t = AS.get(t - 1); // A^{t-1} double[][] h_pre_t = h_s.get(t - 1); // h^{t-1} double[][] mu_hat_prime_pre_t = mu_hat_prime_s.get(t - 1); // \hat{\mu}'^{t-1} [t] for (int s = 0; s < T - t0; s++) { double[][] grad_mu_hat_prime_t = grad_mu_hat_prime_s.get(t * (T - t0) + s); double[][] grad_mu_hat_prime_pre_t = grad_mu_hat_prime_s.get((t - 1) * (T - t0) + s); double[] h2delta2 = new double[n]; for (int i = 0; i < n; i++) for (int k = 0; k < K; k++) { h2delta2[i] += 0.5 * h_t[i][k] * h_t[i][k] * delta_t * delta_t; } /* compute weighted_exp for later use */ double[][][] weighted_exp_num = new double[K][n][n]; double[][] weighted_exp_den = new double[K][n]; double[][][] weighted_exp = new double[K][n][n]; for (int i = 0; i < n; i++) for (int j = 0; j < n; j++) { double h_muhp = Operations.inner_product(h_t[j], mu_hat_prime_t[i], K); for (int k = 0; k < K; 
k++) { weighted_exp_num[k][i][j] = h_t[j][k] * Math.exp(h_muhp + h2delta2[j]); weighted_exp_den[k][j] += Math.exp(h_muhp + h2delta2[j]); } } for (int i = 0; i < n; i++) for (int j = 0; j < n; j++) for (int k = 0; k < K; k++) { weighted_exp[k][i][j] = weighted_exp_num[k][i][j] / weighted_exp_den[k][j]; } /* compute sum_mu_hat_prime for later use */ double[] sum_mu_hat_prime = new double[K]; for (int i = 0; i < n; i++) for (int k = 0; k < K; k++) { sum_mu_hat_prime[k] += mu_hat_prime_pre_t[i][k]; } for (int i = 0; i < n; i++) for (int k = 0; k < K; k++) { /* first term */ double g1 = nti_h[t][i][k] * grad_mu_hat_prime_t[i][k]; tmp_grad_h_hat_prime_s[s][i][k] += g1; /* second term */ double g2 = 0; for (int j = 0; j < n; j++) { g2 -= nti[t][j] * weighted_exp[k][i][j] * grad_mu_hat_prime_t[i][k]; } tmp_grad_h_hat_prime_s[s][i][k] += g2; /* third term */ for (int j = 0; j < n; j++) if (G_pre_t[j][i] != 0) { // double g3 = ( h_t[j][k] - (1-lambda) * h_pre_t[j][k] - lambda * // A_pre_t[j][i] * sum_mu_hat_prime[k] ) double g3 = (h_t[j][k] - (1 - lambda) * h_pre_t[j][k] - lambda * A_pre_t[j][i] * mu_hat_prime_pre_t[i][k]) * lambda * A_pre_t[j][i] * grad_mu_hat_prime_pre_t[i][k] / (sigma * sigma); tmp_grad_h_hat_prime_s[s][j][k] += g3; // j instead of i! 
} } /* fourth term */ for (int i = 0; i < n; i++) for (int k = 0; k < K; k++) { double g4 = -(mu_hat_prime_t[i][k] - mu_hat_prime_pre_t[i][k]) * (grad_mu_hat_prime_t[i][k] - grad_mu_hat_prime_pre_t[i][k]) / (sigma * sigma); tmp_grad_h_hat_prime_s[s][i][k] += g4; } } } else { /* for (int s = 0; s < T-t0; s++) { double[] grad_mu_hat_prime_t = grad_mu_hat_prime_s.get(t * (T-t0) + s); for (int i = 0; i < n; i++) { // first term double g1 = nti_hp[t][i] * grad_mu_hat_prime_t[i]; tmp_grad_h_hat_prime_s[s][i] += g1; // second term double g2 = 0; for (int _j = 0; _j < NEG; _j++) { double weighted_exp_num = 0, weighted_exp_den = 0; int j = neg_samples.get(t)[i][_j]; double htj = h_t[j][0]; double muhti = mu_hat_t[i]; weighted_exp_num += htj * Math.exp(htj * muhti + 0.5 * htj * htj * delta_t * delta_t); for (int _k = 0; _k < NEG; _k++) { int k = neg_samples.get(t)[i][_k]; double muhtk = mu_hat_t[k]; weighted_exp_den += Math.exp(htj * muhtk + 0.5 * htj * htj * delta_t * delta_t); } g2 -= nti[t][j] * weighted_exp_num / weighted_exp_den * grad_mu_hat_prime_t[i]; } tmp_grad_h_hat_prime_s[s][i] += g2; } // fourth term (if any) if (s == t) for (int i = 0; i < n; i++) { double g4 = -h_hat_prime_t[i][0] / (sigma*sigma); tmp_grad_h_hat_prime_s[s][i] += g4; } } */ } } /* update global gradient */ for (int t = 0; t < T - t0; t++) { double[][] grad = new double[n][K]; for (int i = 0; i < n; i++) for (int k = 0; k < K; k++) { grad[i][k] = tmp_grad_h_hat_prime_s[t][i][k]; } grad_h_hat_prime_s.set(t, grad); } FileParser.output_2d(grad_h_hat_prime_s, "./grad/grad_prime_" + iteration + ".txt"); return; }
public static void compute_gradient1(int iteration) { double[][][] tmp_grad_h_hat_s = new double[T - t0][n][K]; for (int t = 0; t < T - t0; t++) { // System.out.println("compute gradient 1, t = " + t); double delta_t = delta_s.get(t); double[][] G_t = GS.get(t); double[][] h_prime_t = h_prime_s.get(t); double[][] mu_hat_t = mu_hat_s.get(t); if (t != 0) { double[][] mu_hat_pre_t = mu_hat_s.get(t - 1); Matrix a = new Matrix(AS.get(t - 1)); Matrix hprime_pre_t = new Matrix(h_prime_s.get(t - 1)); Matrix ave_neighbors = a.times(hprime_pre_t); /* TODO: check whether we can save computation by comparing s and t */ for (int s = 0; s < T - t0; s++) { double[][] grad_hat_t = grad_mu_hat_s.get(t * (T - t0) + s); double[][] grad_hat_pre_t = grad_mu_hat_s.get((t - 1) * (T - t0) + s); double[] hp2delta2 = new double[n]; for (int i = 0; i < n; i++) for (int k = 0; k < K; k++) { hp2delta2[i] += 0.5 * h_prime_t[i][k] * h_prime_t[i][k] * delta_t * delta_t; } for (int i = 0; i < n; i++) { /* first term */ double[] weighted_exp_num = new double[K]; double weighted_exp_den = 0; for (int l = 0; l < n; l++) { double hp_muh = Operations.inner_product(h_prime_t[l], mu_hat_t[i], K); double e = Math.exp(hp_muh + hp2delta2[l]); if (Double.isNaN(e)) { /* check if e explodes */ System.out.println("ERROR2"); Scanner sc = new Scanner(System.in); int gu; gu = sc.nextInt(); } for (int k = 0; k < K; k++) { weighted_exp_num[k] += h_prime_t[l][k] * e; weighted_exp_den += e; } } for (int j = 0; j < n; j++) for (int k = 0; k < K; k++) { double weighted_exp = weighted_exp_num[k] / weighted_exp_den; double gi1 = G_t[i][j] * grad_hat_t[i][k] * (h_prime_t[j][k] - weighted_exp); tmp_grad_h_hat_s[s][i][k] += gi1; } /* second term */ for (int k = 0; k < K; k++) { double gi2 = -(mu_hat_t[i][k] - (1 - lambda) * mu_hat_pre_t[i][k] - lambda * ave_neighbors.get(i, k)) * (grad_hat_t[i][k] - (1 - lambda) * grad_hat_pre_t[i][k]) / (sigma * sigma); tmp_grad_h_hat_s[s][i][k] += gi2; } } } } else { /* no such term (t=0) 
in ELBO */ /* for (int s = 0; s < T-t0; s++) { double[] grad_hat_t = grad_mu_hat_s.get(t * (T-t0) + s); for (int i = 0; i < n; i++) { double n_it = 0; for (int j = 0; j < n; j++) n_it += G_t[i][j]; // first term double gi1 = -mu_hat_t[i] * grad_hat_t[i] / (sigma * sigma); tmp_grad_h_hat_s[s][i] += gi1; // second term double gi2 = 0; double weighted_exp_num = 0, weighted_exp_den = 0; for (int j = 0; j < NEG; j++) { int l = neg_samples.get(t)[i][j]; double hpl = h_prime_t[l][0]; double muit = mu_hat_t[i]; double e = Math.exp(hpl * muit + 0.5 * hpl * hpl * delta_t * delta_t); // TODO: check if e explodes if (Double.isNaN(e)) { System.out.println("ERROR3"); Scanner sc = new Scanner(System.in); int gu; gu = sc.nextInt(); } weighted_exp_num += hpl * e; weighted_exp_den += e; } double weighted_exp = weighted_exp_num / weighted_exp_den; for (int j = 0; j < n; j++) { gi2 += G_t[i][j] * grad_hat_t[i] * (h_prime_t[j][0] - weighted_exp); } tmp_grad_h_hat_s[s][i] += gi2; } } */ } /* end if-else */ } /* update global gradient */ for (int t = 0; t < T - t0; t++) { double[][] grad = new double[n][K]; for (int i = 0; i < n; i++) for (int k = 0; k < K; k++) { grad[i][k] = tmp_grad_h_hat_s[t][i][k]; } grad_h_hat_s.set(t, grad); } FileParser.output_2d(grad_h_hat_s, "./grad/grad_" + iteration + ".txt"); return; }
/**
 * compute_objective1: returns the lower bound when h' is fixed.
 *
 * Sums over window steps t >= 1 (step 0 adds nothing). For each node i at step t it adds
 *   sum_{j : G_ij != 0} G_ij * (h'_j . \hat{mu}_i - lse_i),
 *   lse_i = log_sum_exp over l of (h'_l . \hat{mu}_i + 0.5 |h'_l|^2 delta_t^2),
 * and subtracts, for every coordinate k,
 *   0.5 * (\hat{mu}^t_ik - (1-lambda) \hat{mu}^{t-1}_ik - lambda (A^{t-1} h'^{t-1})_ik)^2 / sigma^2.
 *
 * @return the accumulated objective value
 */
public static double compute_objective1() {
  double objective = 0;
  final int steps = T - t0;
  for (int t = 1; t < steps; t++) {
    double[][] G_t = GS.get(t);
    double[][] hp = h_prime_s.get(t);
    double[][] muHat = mu_hat_s.get(t);
    double[][] muHatPrev = mu_hat_s.get(t - 1);
    double delta = delta_s.get(t);
    Matrix adjacency = new Matrix(AS.get(t - 1));
    Matrix prevHPrime = new Matrix(h_prime_s.get(t - 1));
    Matrix drift = adjacency.times(prevHPrime);

    double[] halfNormSq = new double[n]; // 0.5 * |h'_l|^2 * delta^2
    for (int l = 0; l < n; l++) {
      for (int k = 0; k < K; k++) {
        halfNormSq[l] += 0.5 * hp[l][k] * hp[l][k] * delta * delta;
      }
    }

    for (int i = 0; i < n; i++) {
      List<Double> exponents = new ArrayList<Double>(n);
      for (int l = 0; l < n; l++) {
        exponents.add(Operations.inner_product(hp[l], muHat[i], K) + halfNormSq[l]);
      }
      double logNormalizer = log_sum_exp(exponents);

      for (int j = 0; j < n; j++) {
        if (G_t[i][j] == 0) {
          continue; // absent edges contribute nothing
        }
        double score = Operations.inner_product(hp[j], muHat[i], K);
        objective += G_t[i][j] * (score - logNormalizer);
      }

      for (int k = 0; k < K; k++) {
        double residual = muHat[i][k] - (1 - lambda) * muHatPrev[i][k] - lambda * drift.get(i, k);
        objective -= 0.5 * residual * residual / (sigma * sigma);
      }
    }
  }
  return objective;
}