@Override
  public int computeResiduals(AssociatedPair p, double[] residuals, int index) {
    /*
     * Writes four residuals for the pair starting at residuals[index] and returns the index just
     * past them. The residual vector is delta = J' (J J')^-1 (-err), i.e. the minimum-norm
     * solution of J * delta = -err, where err = (error1, error2) and J comes from
     * computeJacobian. If the solver rejects J J', four zeros are written instead.
     */

    // temp = H * p1 (temp is a field; kept in the same position as before the error evaluations).
    GeometryMath_F64.mult(H, p.p1, temp);

    final double firstError = error1(p.p1.x, p.p1.y, p.p2.x, p.p2.y);
    final double secondError = error2(p.p1.x, p.p1.y, p.p2.x, p.p2.y);

    computeJacobian(p.p1, p.p2);
    // JJ = J * J'
    CommonOps.multTransB(J, J, JJ);

    // Right-hand side of (J J') x = -err.
    e.data[0] = -firstError;
    e.data[1] = -secondError;

    final boolean solvable = solver.setA(JJ);
    if (solvable) {
      solver.solve(e, x);
      // error = J' * x = -J' (J J')^-1 err
      CommonOps.multTransA(J, x, error);
    }

    for (int k = 0; k < 4; k++) {
      residuals[index++] = solvable ? error.data[k] : 0;
    }
    return index;
  }
Exemplo n.º 2
0
Arquivo: CPM.java Projeto: fpl/s1tbx
  /**
   * Estimates the coregistration polynomial (CPM) for azimuth (line) and range (pixel) offsets by
   * iterative weighted least squares with data snooping (w-test outlier removal).
   *
   * <p>Each iteration fits a polynomial of degree {@code cpmDegree} in the normalized master
   * coordinates to {@code yOffset} (azimuth) and {@code xOffset} (range). Observations are
   * weighted according to {@code cpmWeight}. A w-test statistic is then computed for every
   * observation. When the loop continues, the observation with the largest summed squared w-test
   * (azimuth plus range) is removed at the start of the next iteration. It is removed from the
   * local normalized coordinates and from the fields {@code index}, {@code yOffset},
   * {@code xOffset}, {@code yMaster}, {@code xMaster}, {@code ySlave}, {@code xSlave},
   * {@code coherence} and {@code slaveGCPList}.
   *
   * <p>The loop stops when any of these holds:
   *
   * <ul>
   *   <li>the largest w-test is {@code <= criticalValue};
   *   <li>no redundancy is left;
   *   <li>{@code maxIterations} is reached.
   * </ul>
   *
   * <p>On return, {@code yCoef}/{@code xCoef} hold the estimated coefficients and
   * {@code yError}/{@code xError} hold the residuals of the final fit.
   *
   * @throws ArithmeticException if fewer observations than unknowns remain
   * @throws IllegalArgumentException if the linear solver rejects the normal matrix
   * @throws IllegalStateException if max |N * inv(N) - I| exceeds 0.01
   */
  private void estimateCPM() {

    logger.info("Start EJML Estimation");

    numIterations = 0;
    boolean estimationDone = false;

    DenseMatrix64F eL_hat = null;
    DenseMatrix64F eP_hat = null;
    DenseMatrix64F rhsL = null;
    DenseMatrix64F rhsP = null;

    // normalize master coordinates for stability -- only master!
    TDoubleArrayList yMasterNorm = new TDoubleArrayList();
    TDoubleArrayList xMasterNorm = new TDoubleArrayList();
    for (int i = 0; i < yMaster.size(); i++) {
      yMasterNorm.add(PolyUtils.normalize2(yMaster.getQuick(i), normWin.linelo, normWin.linehi));
      xMasterNorm.add(PolyUtils.normalize2(xMaster.getQuick(i), normWin.pixlo, normWin.pixhi));
    }

    // Index of the observation selected for removal; consumed at the start of the next iteration.
    int maxWSum_idx = 0;

    while (!estimationDone) {

      String codeBlockMessage = "LS ESTIMATION PROCEDURE";
      StopWatch stopWatch = new StopWatch();
      StopWatch clock = new StopWatch();
      clock.start();
      stopWatch.setTag(codeBlockMessage);
      stopWatch.start();

      // The logger API used here (info/warning/severe) is java.util.logging-style, which does not
      // expand SLF4J "{}" placeholders, so values are appended to the message text directly.
      logger.info("Start iteration: " + numIterations);

      // Remove the outlier identified in the previous iteration from every parallel list.
      if (numIterations != 0) {
        logger.info(
            "Removing observation "
                + index.getQuick(maxWSum_idx)
                + ", idxList "
                + maxWSum_idx
                + ", from observation vector.");
        index.removeAt(maxWSum_idx);
        yMasterNorm.removeAt(maxWSum_idx);
        xMasterNorm.removeAt(maxWSum_idx);
        yOffset.removeAt(maxWSum_idx);
        xOffset.removeAt(maxWSum_idx);

        // only for outlier removal
        yMaster.removeAt(maxWSum_idx);
        xMaster.removeAt(maxWSum_idx);
        ySlave.removeAt(maxWSum_idx);
        xSlave.removeAt(maxWSum_idx);
        coherence.removeAt(maxWSum_idx);

        // also take care of slave pins
        slaveGCPList.remove(maxWSum_idx);

        //                if (demRefinement) {
        //                    ySlaveGeometry.removeAt(maxWSum_idx);
        //                    xSlaveGeometry.removeAt(maxWSum_idx);
        //                }
      }

      // Check redundancy.
      numObservations = index.size(); // Number of points > threshold
      if (numObservations < numUnknowns) {
        logger.severe(
            "coregpm: Number of windows > threshold is smaller than parameters solved for.");
        throw new ArithmeticException(
            "coregpm: Number of windows > threshold is smaller than parameters solved for.");
      }

      // work with normalized values
      DenseMatrix64F A =
          new DenseMatrix64F(
              SystemOfEquations.constructDesignMatrix_loop(
                  yMasterNorm.toArray(), xMasterNorm.toArray(), cpmDegree));

      logger.info("TIME FOR SETUP of SYSTEM : " + stopWatch.lap("setup"));

      RowD1Matrix64F Qy_1; // vector of observation weights
      double meanValue;
      switch (cpmWeight) {
        case "linear":
          logger.info("Using sqrt(coherence) as weights");
          Qy_1 = DenseMatrix64F.wrap(numObservations, 1, coherence.toArray());
          // Normalize weights to avoid influence on estimated var.factor
          logger.info("Normalizing covariance matrix for LS estimation");
          meanValue = CommonOps.elementSum(Qy_1) / numObservations;
          // Divide every weight by the mean. scale() makes the direction of the division explicit.
          // NOTE(review): CommonOps.divide(double, matrix), used previously, computes
          // alpha / a_ij in EJML versions where alpha is the numerator.
          CommonOps.scale(1.0 / meanValue, Qy_1);
          break;
        case "quadratic":
          logger.info("Using coherence as weights.");
          Qy_1 = DenseMatrix64F.wrap(numObservations, 1, coherence.toArray());
          CommonOps.elementMult(Qy_1, Qy_1);
          // Normalize weights to avoid influence on estimated var.factor
          meanValue = CommonOps.elementSum(Qy_1) / numObservations;
          logger.info("Normalizing covariance matrix for LS estimation.");
          CommonOps.scale(1.0 / meanValue, Qy_1); // divide every weight by the mean
          break;
        case "bamler":
          // TODO: see Bamler papers IGARSS 2000 and 2004
          logger.warning("Bamler weighting method NOT IMPLEMENTED, falling back to None.");
          Qy_1 = onesEJML(numObservations);
          break;
        case "none":
          logger.info("No weighting.");
          Qy_1 = onesEJML(numObservations);
          break;
        default:
          Qy_1 = onesEJML(numObservations);
          break;
      }

      logger.info("TIME FOR SETUP of VC diag matrix: " + stopWatch.lap("diag VC matrix"));

      // Observation vectors (azimuth and range offsets).
      final DenseMatrix64F yL_matrix = DenseMatrix64F.wrap(numObservations, 1, yOffset.toArray());
      final DenseMatrix64F yP_matrix = DenseMatrix64F.wrap(numObservations, 1, xOffset.toArray());
      logger.info("TIME FOR SETUP of TEMP MATRICES: " + stopWatch.lap("Temp matrices"));

      // Normal matrix N = A' * diag(Qy_1) * A.
      final DenseMatrix64F N = new DenseMatrix64F(numUnknowns, numUnknowns);
      CommonOps.multAddTransA(A, diagxmat(Qy_1, A), N);
      DenseMatrix64F Qx_hat = N.copy();

      logger.info("TIME FOR SETUP of NORMAL MATRIX: " + stopWatch.lap("Normal matrix"));

      // Right-hand sides: A' * diag(Qy_1) * y.
      // azimuth
      rhsL = new DenseMatrix64F(numUnknowns, 1);
      CommonOps.multAddTransA(1d, A, diagxmat(Qy_1, yL_matrix), rhsL);
      // range
      rhsP = new DenseMatrix64F(numUnknowns, 1);
      CommonOps.multAddTransA(1d, A, diagxmat(Qy_1, yP_matrix), rhsP);
      logger.info("TIME FOR SETUP of RightHand Side: " + stopWatch.lap("Right-hand-side"));

      LinearSolver<DenseMatrix64F> solver = LinearSolverFactory.leastSquares(100, 100);
      // Compute the solution in place: rhsL/rhsP are overwritten with the coefficients.
      if (!solver.setA(Qx_hat)) {
        throw new IllegalArgumentException("Singular Matrix");
      }
      solver.solve(rhsL, rhsL);
      solver.solve(rhsP, rhsP);
      logger.info("TIME FOR SOLVING of System: " + stopWatch.lap("Solving System"));

      // Invert N into Qx_hat for the stability check.
      solver.invert(Qx_hat);

      logger.info("TIME FOR INVERSION OF N: " + stopWatch.lap("Inversion of N"));

      // Test the inversion and check stability: max(abs(N*inv(N) - E)) ?= 0.
      DenseMatrix64F tempMatrix_1 = new DenseMatrix64F(N.numRows, N.numCols);
      CommonOps.mult(N, Qx_hat, tempMatrix_1);
      CommonOps.subEquals(
          tempMatrix_1, CommonOps.identity(tempMatrix_1.numRows, tempMatrix_1.numCols));
      double maxDeviation = CommonOps.elementMaxAbs(tempMatrix_1);
      if (maxDeviation > .01) {
        logger.severe(
            "COREGPM: maximum deviation N*inv(N) from unity = "
                + maxDeviation
                + ". This is larger than 0.01");
        throw new IllegalStateException(
            "COREGPM: maximum deviation N*inv(N) from unity = " + maxDeviation);
      } else if (maxDeviation > .001) {
        logger.warning(
            "COREGPM: maximum deviation N*inv(N) from unity = "
                + maxDeviation
                + ". This is between 0.01 and 0.001");
      }
      logger.info("TIME FOR STABILITY CHECK: " + stopWatch.lap("Stability Check"));

      logger.info("Coeffs in Azimuth direction: " + rhsL.toString());
      logger.info("Coeffs in Range direction: " + rhsP.toString());
      logger.info("Max Deviation: " + maxDeviation);
      logger.info("System Quality: " + solver.quality());

      // Qe_hat = -A * Qx_hat * A', then adjusted on the diagonal by scaleInputDiag using Qy_1.
      DenseMatrix64F Qe_hat = new DenseMatrix64F(numObservations, numObservations);
      DenseMatrix64F tempMatrix_2 = new DenseMatrix64F(numObservations, numUnknowns);

      CommonOps.mult(A, Qx_hat, tempMatrix_2);
      CommonOps.multTransB(-1, tempMatrix_2, A, Qe_hat);
      scaleInputDiag(Qe_hat, Qy_1);

      // solution: Azimuth (residuals eL_hat = y - A*x)
      DenseMatrix64F yL_hat = new DenseMatrix64F(numObservations, 1);
      eL_hat = new DenseMatrix64F(numObservations, 1);
      CommonOps.mult(A, rhsL, yL_hat);
      CommonOps.sub(yL_matrix, yL_hat, eL_hat);

      // solution: Range (residuals eP_hat = y - A*x)
      DenseMatrix64F yP_hat = new DenseMatrix64F(numObservations, 1);
      eP_hat = new DenseMatrix64F(numObservations, 1);
      CommonOps.mult(A, rhsP, yP_hat);
      CommonOps.sub(yP_matrix, yP_hat, eP_hat);

      logger.info("TIME FOR DATA preparation for TESTING: " + stopWatch.lap("Testing Setup"));

      // Overall model test (variance factor).
      double overAllModelTest_L = 0;
      double overAllModelTest_P = 0;

      for (int i = 0; i < numObservations; i++) {
        overAllModelTest_L += FastMath.pow(eL_hat.get(i), 2) * Qy_1.get(i);
        overAllModelTest_P += FastMath.pow(eP_hat.get(i), 2) * Qy_1.get(i);
      }

      overAllModelTest_L =
          (overAllModelTest_L / FastMath.pow(SIGMA_L, 2)) / (numObservations - numUnknowns);
      overAllModelTest_P =
          (overAllModelTest_P / FastMath.pow(SIGMA_P, 2)) / (numObservations - numUnknowns);

      logger.info("Overall Model Test Lines: " + overAllModelTest_L);
      logger.info("Overall Model Test Pixels: " + overAllModelTest_P);

      logger.info("TIME FOR OMT: " + stopWatch.lap("OMT"));

      // ---------------------- DATA SNOOPING (assumes diagonal Qy) ----------------------
      DenseMatrix64F wTest_L = new DenseMatrix64F(numObservations, 1);
      DenseMatrix64F wTest_P = new DenseMatrix64F(numObservations, 1);

      for (int i = 0; i < numObservations; i++) {
        wTest_L.set(i, eL_hat.get(i) / (Math.sqrt(Qe_hat.get(i, i)) * SIGMA_L));
        wTest_P.set(i, eP_hat.get(i) / (Math.sqrt(Qe_hat.get(i, i)) * SIGMA_P));
      }

      // Find maxima.
      // azimuth
      final int winL = absArgmax(wTest_L);
      double maxWinL = Math.abs(wTest_L.get(winL));
      logger.info(
          "maximum wtest statistic azimuth = "
              + maxWinL
              + " for window number: "
              + index.getQuick(winL));

      // range
      final int winP = absArgmax(wTest_P);
      double maxWinP = Math.abs(wTest_P.get(winP));
      logger.info(
          "maximum wtest statistic range = "
              + maxWinP
              + " for window number: "
              + index.getQuick(winP));

      // Use the summed squared wTest of azimuth and range for outlier detection.
      // An explicit n x 1 shape is required: the single-int DenseMatrix64F constructor only
      // pre-allocates storage and leaves the matrix with zero rows and columns.
      DenseMatrix64F wTestSum = new DenseMatrix64F(numObservations, 1);
      for (int i = 0; i < numObservations; i++) {
        wTestSum.set(i, FastMath.pow(wTest_L.get(i), 2) + FastMath.pow(wTest_P.get(i), 2));
      }

      maxWSum_idx = absArgmax(wTestSum);
      double maxWSum = wTestSum.get(maxWSum_idx);
      logger.info(
          "Detected outlier: summed sqr.wtest = "
              + maxWSum
              + "; observation: "
              + index.getQuick(maxWSum_idx));

      // Test whether estimation is done.
      // check on number of observations
      if (numObservations <= numUnknowns) {
        logger.warning("NO redundancy!  Exiting iterations.");
        estimationDone = true; // cannot remove more than this
      }

      // check on test k_alpha
      final double maxWTest = Math.max(maxWinL, maxWinP);
      if (maxWTest <= criticalValue) {
        // all tests accepted?
        logger.info("All outlier tests accepted! (final solution computed)");
        estimationDone = true;
      }

      if (numIterations >= maxIterations) {
        logger.info("max. number of iterations reached (exiting loop).");
        estimationDone = true; // we reached max. (or no max_iter specified)
      }

      // Only warn once the last iteration is done.
      if (estimationDone) {
        if (overAllModelTest_L > 10) {
          logger.warning(
              "COREGPM: Overall Model Test, Lines = "
                  + overAllModelTest_L
                  + " is larger than 10. (Suggest model or a priori sigma not correct.)");
        }
        if (overAllModelTest_P > 10) {
          logger.warning(
              "COREGPM: Overall Model Test, Pixels = "
                  + overAllModelTest_P
                  + " is larger than 10. (Suggest model or a priori sigma not correct.)");
        }

        // if a priori sigma is correct, max wtest should be something like 1.96
        if (maxWTest > 200.0) {
          // Report the window that actually produced the largest statistic.
          final int worstWindow = maxWinL >= maxWinP ? winL : winP;
          logger.warning(
              "Recommendation: remove window number: "
                  + index.get(worstWindow)
                  + " and re-run step COREGPM.  max. wtest is: "
                  + maxWTest
                  + ".");
        }
      }

      logger.info("TIME FOR wTestStatistics: " + stopWatch.lap("WTEST"));
      logger.info("Total Estimation TIME: " + clock.getElapsedTime());

      numIterations++; // update counter here!
    }

    // The loop body always runs at least once, so these are non-null here.
    yError = eL_hat.getData();
    xError = eP_hat.getData();

    yCoef = rhsL.getData();
    xCoef = rhsP.getData();
  }