/**
   * Builds the inverse square root of the diagonal of {@code X}; used by the CCA computation to
   * whiten the covariance blocks.
   *
   * <p>Only the diagonal of {@code X} is read. Off-diagonal entries are ignored, so the result is
   * {@code diag(X)^{-1/2}}, which equals {@code X^{-1/2}} only when {@code X} is diagonal. A zero
   * diagonal entry (stored or implicit) would give {@code 1/sqrt(0) = Infinity}, so it is replaced
   * by {@code 10000} as a "large" stand-in. A negative diagonal entry yields {@code NaN}.
   *
   * <p>Prints a start and a finish progress message to standard output.
   *
   * @param X the matrix whose diagonal is inverted; presumably a square covariance matrix with a
   *     non-negative diagonal (TODO confirm for all callers)
   * @return a new {@code X.numRows() x X.numColumns()} matrix with {@code 1/sqrt(X[i][i])} (or the
   *     zero stand-in) on its diagonal and nothing stored elsewhere
   */
  public static FlexCompRowMatrix computeSparseInverseSqRoot(FlexCompRowMatrix X) {
    return computeSparseInverseSqRoot(X, 10000);
  }

  /**
   * Same as {@link #computeSparseInverseSqRoot(FlexCompRowMatrix)}, but with a caller-chosen
   * substitute for zero diagonal entries.
   *
   * @param X the matrix whose diagonal is inverted
   * @param zeroDiagonalValue value written to the result's diagonal wherever {@code X[i][i] == 0}
   *     (including positions that are not stored in the sparse matrix at all)
   * @return a new {@code X.numRows() x X.numColumns()} matrix holding the inverse square roots of
   *     the diagonal of {@code X} on its own diagonal
   */
  public static FlexCompRowMatrix computeSparseInverseSqRoot(
      FlexCompRowMatrix X, double zeroDiagonalValue) {

    FlexCompRowMatrix diagInvEntries = new FlexCompRowMatrix(X.numRows(), X.numColumns());

    System.out.println("++Beginning Sparse Inverse Sq. Root++");

    // Walk the diagonal directly instead of iterating every stored entry. This needs one lookup
    // per diagonal position rather than a pass over all non-zeros. It also treats diagonal
    // positions that are not stored (implicit zeros) the same as explicitly stored zeros; the
    // entry iterator would skip those and leave them as 0 instead of the large stand-in.
    int diagLength = Math.min(X.numRows(), X.numColumns());
    for (int i = 0; i < diagLength; i++) {
      double diag = X.get(i, i);
      if (diag != 0) {
        diagInvEntries.set(i, i, 1 / Math.sqrt(diag));
      } else {
        diagInvEntries.set(i, i, zeroDiagonalValue);
      }
    }

    System.out.println("++Finished Sparse Inverse Sq. Root++");

    return diagInvEntries;
  }
  /**
   * Computes the CCA projection matrices for x = \Psi (outside features) and y = \Phi (inside
   * features), then serializes them.
   *
   * <p>The method forms {@code C_{xx}^{-1/2} C_{xy} C_{yy}^{-1/2}} and
   * {@code C_{yy}^{-1/2} C_{yx} C_{xx}^{-1/2}}. It runs {@code svdTC.computeSVD_Tropp} with
   * {@code hiddenStates} components on each and multiplies the resulting singular-vector matrices
   * by the corresponding inverse square root. This yields the outside projection {@code phiL}
   * (dprime x k) and the inside projection {@code phiR} (d x k).
   *
   * <p>Side effects:
   *
   * <ul>
   *   <li>Assigns the class fields {@code phiLCSU}, {@code s}, {@code phiL}, {@code phiRCSU} and
   *       {@code phiR}.
   *   <li>Writes the first SVD's singular values via {@code VSMUtil.writeSingularValuesSem(s,
   *       "NNS")}.
   *   <li>Calls {@code serializeCCAVariantsRun(directoryName)}.
   *   <li>Resets {@code phiL}, {@code phiLCSU}, {@code phiRCSU} and {@code phiR} to {@code null}.
   * </ul>
   *
   * @param xty C_{xy}; per the inline comments, dprime x d
   * @param ytx C_{yx}; d x dprime
   * @param yty C_{yy}; d x d
   * @param xtx C_{xx}; dprime x dprime
   * @param svdTC supplies the randomized SVD, its singular values and a sparse inverse square root
   * @param _cpcaR2 not used by this method
   * @param twoStageFlag not used by this method
   * @param hiddenStates number of CCA components k to keep
   * @param directoryName passed to {@code serializeCCAVariantsRun}
   */
  private static void computeCCA2(
      FlexCompRowMatrix xty,
      FlexCompRowMatrix ytx,
      FlexCompRowMatrix yty,
      FlexCompRowMatrix xtx,
      SVDTemplates1 svdTC,
      ContextPCARepresentation _cpcaR2,
      int twoStageFlag,
      int hiddenStates,
      String directoryName) {

    System.out.println("+++Entering CCA Compute Function+++");
    DenseDoubleMatrix2D phiLCOLT, phiRCOLT;

    // x is Psi, the outside feature matrix, so the outside projection is dprime x k.
    System.out.println("***Creating the dense matrix, Memory Consuming Step****");
    printUsedHeapBytes("Total memory (bytes) currently used: ");

    // dprime x k
    phiLCOLT = new DenseDoubleMatrix2D(xtx.numRows(), hiddenStates);
    // d x k
    phiRCOLT = new DenseDoubleMatrix2D(yty.numRows(), hiddenStates);

    System.out.println("****Memory Consuming Step Done, Loaded two huge matrices in Memory****");
    printUsedHeapBytes("Total memory (bytes) used currently by JVM: ");

    // d: the dimensionality of the inside feature matrix
    int dim1 = ytx.numRows();
    // dprime: the dimensionality of the outside feature matrix
    int dim2 = xty.numRows();

    System.out.println("+++Initialized auxiliary matrices+++");

    // NOTE(review): the inverse square root of xtx is computed three times in this method (once
    // with this class's static computeSparseInverseSqRoot, twice with svdTC's), and likewise for
    // yty. If the two implementations agree and are side-effect free, computing each once and
    // reusing it would save time. Confirm before changing.

    // auxMat1 = C_{xx}^{-1/2} C_{xy}  (dprime x d)
    FlexCompRowMatrix auxMat1 =
        MatrixFormatConversion.multLargeSparseMatricesJEIGEN(computeSparseInverseSqRoot(xtx), xty);

    // auxMat3 = C_{xx}^{-1/2} C_{xy} C_{yy}^{-1/2}  (dprime x d)
    FlexCompRowMatrix auxMat3 =
        MatrixFormatConversion.multLargeSparseMatricesJEIGEN(
            auxMat1, computeSparseInverseSqRoot(yty));

    System.out.println("+++Computed 1 inverse+++");

    // auxMat2 = C_{yy}^{-1/2} C_{yx}  (d x dprime)
    FlexCompRowMatrix auxMat2 =
        MatrixFormatConversion.multLargeSparseMatricesJEIGEN(
            svdTC.computeSparseInverseSqRoot(yty), ytx);

    // auxMat4 = C_{yy}^{-1/2} C_{yx} C_{xx}^{-1/2}  (d x dprime)
    FlexCompRowMatrix auxMat4 =
        MatrixFormatConversion.multLargeSparseMatricesJEIGEN(
            auxMat2, svdTC.computeSparseInverseSqRoot(xtx));

    System.out.println("+++Computed Inverses+++");

    System.out.println("+++Entering SVD computation+++");

    // Unnormalized outside (Z) projection: randomized SVD of auxMat3.
    phiLCSU =
        svdTC.computeSVD_Tropp(
            MatrixFormatConversion.createSparseMatrixCOLT(auxMat3),
            getOmegaMatrix(auxMat3.numColumns(), hiddenStates),
            dim1);

    s = svdTC.getSingularVals();

    // Write the singular values to a file for inspection.
    VSMUtil.writeSingularValuesSem(s, "NNS");

    // Outside projection (dprime x k): the product C_{xx}^{-1/2} * phiLCSU is stored into
    // phiLCOLT by zMult. TODO (from original): confirm this is the intended outside projection.
    MatrixFormatConversion.createSparseMatrixCOLT(svdTC.computeSparseInverseSqRoot(xtx))
        .zMult(MatrixFormatConversion.createDenseMatrixCOLT(phiLCSU), phiLCOLT);

    printUsedHeapBytes("Total memory (bytes) currently used: ");

    phiL = MatrixFormatConversion.createDenseMatrixJAMA(phiLCOLT);

    // Unnormalized inside (Y) projection: randomized SVD of auxMat4.
    phiRCSU =
        svdTC.computeSVD_Tropp(
            MatrixFormatConversion.createSparseMatrixCOLT(auxMat4),
            getOmegaMatrix(auxMat4.numColumns(), hiddenStates),
            dim2);

    // Inside projection (d x k): the product C_{yy}^{-1/2} * phiRCSU is stored into phiRCOLT.
    MatrixFormatConversion.createSparseMatrixCOLT(svdTC.computeSparseInverseSqRoot(yty))
        .zMult(MatrixFormatConversion.createDenseMatrixCOLT(phiRCSU), phiRCOLT);

    printUsedHeapBytes("Total memory (bytes) currently used: ");

    phiR = MatrixFormatConversion.createDenseMatrixJAMA(phiRCOLT);

    System.out.println("***Serializing***");
    serializeCCAVariantsRun(directoryName);

    System.out.println("Freeing up the memory");
    phiLCOLT = null;
    phiRCOLT = null;
    phiL = null;
    phiLCSU = null;
    phiRCSU = null;
    phiR = null;

    printUsedHeapBytes("Total memory (bytes) currently used: ");
  }

  /**
   * Prints {@code label} followed by the number of heap bytes the JVM currently has in use.
   *
   * <p>The value is {@code totalMemory() - freeMemory()}. {@link Runtime#totalMemory()} on its
   * own is the heap the JVM has reserved, not the amount occupied by objects.
   *
   * @param label text printed before the byte count
   */
  private static void printUsedHeapBytes(String label) {
    Runtime runtime = Runtime.getRuntime();
    System.out.println(label + (runtime.totalMemory() - runtime.freeMemory()));
  }