/**
 * Accumulates one peeled node's contribution into the Wishart scale-matrix sufficient statistic.
 *
 * <p>For every datum and every trait pair {@code (row, col)} the entry {@code [row][col]} of
 * {@code wishartStatistics.getScaleMatrix()} is increased by
 * {@code p0*m0[row]*m0[col] + p1*m1[row]*m1[col] - (p0*m0[row] + p1*m1[row])*m[col]}. Here
 * {@code m0}/{@code m1} are the children's entries in the corrected mean cache and {@code m} is
 * this node's entry in the uncorrected mean cache. The degrees of freedom are then incremented by one.
 *
 * @param thisOffset start of this node's block in the mean cache
 * @param childOffset0 start of the first child's block in the corrected mean cache
 * @param childOffset1 start of the second child's block in the corrected mean cache
 * @param precision0 scalar weight applied to the first child's mean
 * @param precision1 scalar weight applied to the second child's mean
 */
private void incrementOuterProducts(
      int thisOffset, int childOffset0, int childOffset1, double precision0, double precision1) {

    final double[][] scale = wishartStatistics.getScaleMatrix();

    for (int datum = 0; datum < numData; datum++) {
      // Each datum occupies a contiguous run of dimTrait entries inside every node block.
      final int datumShift = datum * dimTrait;
      final int base0 = childOffset0 + datumShift;
      final int base1 = childOffset1 + datumShift;
      final int baseThis = thisOffset + datumShift;

      for (int row = 0; row < dimTrait; row++) {
        // Precision-weighted child means, read from the corrected cache.
        final double weighted0 = cacheHelper.getCorrectedMeanCache()[base0 + row] * precision0;
        final double weighted1 = cacheHelper.getCorrectedMeanCache()[base1 + row] * precision1;

        for (int col = 0; col < dimTrait; col++) {
          final double mean0 = cacheHelper.getCorrectedMeanCache()[base0 + col];
          final double mean1 = cacheHelper.getCorrectedMeanCache()[base1 + col];

          // Three separate updates preserve the original floating-point rounding sequence.
          scale[row][col] += weighted0 * mean0;
          scale[row][col] += weighted1 * mean1;
          // This node's own mean is read from the uncorrected cache.
          scale[row][col] -=
              (weighted0 + weighted1) * cacheHelper.getMeanCache()[baseThis + col];
        }
      }
    }
    wishartStatistics.incrementDf(1); // one more node has been peeled
  }
  /**
   * Recursively draws trait states for {@code node} and writes them into {@code drawnStates}.
   * When {@code peel()} is true and the node is internal, it then descends into children 0 and 1.
   *
   * <p>Cases:
   *
   * <ul>
   *   <li>Root: for each datum, draws from a multivariate normal. Its mean and variance come from
   *       {@code computeMarginalRootMeanAndVariance}, applied to the root's mean-cache entry.
   *   <li>Node flagged as completely or partially missing: partially missing nodes raise a
   *       {@link RuntimeException}. Otherwise, for each datum, draws from a normal whose mean
   *       averages the parent's drawn state and this node's mean-cache entry, weighted by their
   *       precisions. Its variance is {@code treeVariance / totalPrecision}.
   *   <li>Any other node: its mean-cache block is copied into {@code drawnStates} unchanged.
   * </ul>
   *
   * @param treeModel tree being traversed
   * @param node node whose state is drawn
   * @param parentIndex node number of {@code node}'s parent; not read when {@code node} is the root
   * @param treePrecision forwarded to {@code computeMarginalRootMeanAndVariance} and to the recursion
   * @param treeVariance trait variance, divided by the node's total precision for conditional draws
   */
  private void preOrderTraverseSample(
      MultivariateTraitTree treeModel,
      NodeRef node,
      int parentIndex,
      double[][] treePrecision,
      double[][] treeVariance) {

    final int nodeIndex = node.getNumber();

    if (treeModel.isRoot(node)) {
      // Root: sample each datum from the marginal root distribution.
      final double[] rootMean = new double[dimTrait]; // refilled for every datum
      final int rootNumber = treeModel.getRoot().getNumber();
      final double rootPrecision = lowerPrecisionCache[rootNumber];

      for (int datum = 0; datum < numData; datum++) {
        final int datumShift = datum * dimTrait;
        System.arraycopy(
            cacheHelper.getMeanCache(), nodeIndex * dim + datumShift, rootMean, 0, dimTrait);

        final double[][] rootVariance =
            computeMarginalRootMeanAndVariance(
                rootMean, treePrecision, treeVariance, rootPrecision);
        final double[] sample =
            MultivariateNormalDistribution.nextMultivariateNormalVariance(rootMean, rootVariance);

        if (DEBUG_PREORDER) {
          Arrays.fill(sample, 1.0);
        }

        System.arraycopy(sample, 0, drawnStates, rootNumber * dim + datumShift, dimTrait);

        if (DEBUG) {
          System.err.println("Root mean: " + new Vector(rootMean));
          System.err.println("Root var : " + new Matrix(rootVariance));
          System.err.println("Root draw: " + new Vector(sample));
        }
      }
    } else if (missingTraits.isCompletelyMissing(nodeIndex)
        || missingTraits.isPartiallyMissing(nodeIndex)) {
      // Unobserved node: draw conditional on the parent's (already drawn) state.
      if (missingTraits.isPartiallyMissing(nodeIndex)) {
        throw new RuntimeException("Partially missing values are not yet implemented");
      }
      // The original author notes this path should also work for a missing tip trait,
      // but it still needs testing.

      final double parentPrecision = 1.0 / getRescaledBranchLengthForPrecision(node);
      final double nodePrecision = lowerPrecisionCache[nodeIndex];
      final double totalPrecision = nodePrecision + parentPrecision;

      final double[] condMean = Ay; // shared scratch buffer
      final double[][] condVariance = tmpM; // shared scratch buffer

      for (int datum = 0; datum < numData; datum++) {
        final int parentOffset = parentIndex * dim + datum * dimTrait;
        final int nodeOffset = nodeIndex * dim + datum * dimTrait;

        if (DEBUG) {
          final double[] parentValue = new double[dimTrait];
          System.arraycopy(drawnStates, parentOffset, parentValue, 0, dimTrait);
          System.err.println("Parent draw: " + new Vector(parentValue));
          if (parentValue[0] != drawnStates[parentOffset]) {
            throw new RuntimeException("Error in setting indices");
          }
        }

        for (int row = 0; row < dimTrait; row++) {
          final double parentTerm = drawnStates[parentOffset + row] * parentPrecision;
          final double nodeTerm = cacheHelper.getMeanCache()[nodeOffset + row] * nodePrecision;
          condMean[row] = (parentTerm + nodeTerm) / totalPrecision;
          for (int col = 0; col < dimTrait; col++) {
            condVariance[row][col] = treeVariance[row][col] / totalPrecision;
          }
        }

        final double[] sample =
            MultivariateNormalDistribution.nextMultivariateNormalVariance(condMean, condVariance);
        System.arraycopy(sample, 0, drawnStates, nodeOffset, dimTrait);

        if (DEBUG) {
          System.err.println("Int prec: " + totalPrecision);
          System.err.println("Int mean: " + new Vector(condMean));
          System.err.println("Int var : " + new Matrix(condVariance));
          System.err.println("Int draw: " + new Vector(sample));
          System.err.println("");
        }
      }
    } else {
      // Neither completely nor partially missing: copy the node's mean-cache block as-is.
      System.arraycopy(
          cacheHelper.getMeanCache(), nodeIndex * dim, drawnStates, nodeIndex * dim, dim);
    }

    if (peel() && !treeModel.isExternal(node)) {
      for (int childSlot = 0; childSlot < 2; childSlot++) {
        preOrderTraverseSample(
            treeModel, treeModel.getChild(node, childSlot), nodeIndex, treePrecision, treeVariance);
      }
    }
  }
  /**
   * Computes the log likelihood of the trait data under the current diffusion model.
   *
   * <p>Steps:
   *
   * <ol>
   *   <li>{@code postOrderTraverse} fills the per-node caches.
   *   <li>Each datum then adds a multivariate-normal term at the root with precision
   *       {@code conditionalRootPrecision * traitPrecision}. That term is skipped when the scalar
   *       precision is zero.
   *   <li>When {@code integrateRoot} is set, the result of {@code integrateLogLikelihoodAtRoot} is
   *       also added.
   *   <li>Finally {@code sumLogRemainders()} is added.
   * </ol>
   *
   * <p>Side effects:
   *
   * <ul>
   *   <li>When Wishart sufficient statistics are requested, a new {@code WishartSufficientStatistics}
   *       is assigned to {@code wishartStatistics}.
   *   <li>{@code tmp2} is overwritten.
   *   <li>{@code Ay} is filled and then destroyed by the root helpers. {@code tmpM} is also passed
   *       to a helper, presumably as scratch — confirm.
   *   <li>{@code areStatesRedrawn} is reset to {@code false}.
   * </ul>
   *
   * @return the total log likelihood
   */
  public double calculateLogLikelihood() {

    double logLikelihood = 0;
    double[][] traitPrecision = diffusionModel.getPrecisionmatrix();
    double logDetTraitPrecision = Math.log(diffusionModel.getDeterminantPrecisionMatrix());
    double[] conditionalRootMean = tmp2; // scratch buffer, refilled for every datum

    final boolean computeWishartStatistics = getComputeWishartSufficientStatistics();

    if (computeWishartStatistics) {
      // A new accumulator is created on every evaluation.
      wishartStatistics = new WishartSufficientStatistics(dimTrait);
    }

    // Use dynamic programming to compute conditional likelihoods at each internal node
    postOrderTraverse(
        treeModel,
        treeModel.getRoot(),
        traitPrecision,
        logDetTraitPrecision,
        computeWishartStatistics);

    if (DEBUG) {
      System.err.println("mean: " + new Vector(cacheHelper.getMeanCache()));
      System.err.println("correctedMean: " + new Vector(cacheHelper.getCorrectedMeanCache()));
      System.err.println("upre: " + new Vector(upperPrecisionCache));
      System.err.println("lpre: " + new Vector(lowerPrecisionCache));
      System.err.println("cach: " + new Vector(logRemainderDensityCache));
    }

    // Compute the contribution of each datum at the root
    final int rootIndex = treeModel.getRoot().getNumber();

    // Precision scalar of datum conditional on root
    double conditionalRootPrecision = lowerPrecisionCache[rootIndex];

    for (int datum = 0; datum < numData; datum++) {

      double thisLogLikelihood = 0;

      // Get conditional mean of datum conditional on root
      System.arraycopy(
          cacheHelper.getMeanCache(),
          rootIndex * dim + datum * dimTrait,
          conditionalRootMean,
          0,
          dimTrait);

      if (DEBUG) {
        System.err.println("Datum #" + datum);
        System.err.println("root mean: " + new Vector(conditionalRootMean));
        System.err.println("root prec: " + conditionalRootPrecision);
        System.err.println("diffusion prec: " + new Matrix(traitPrecision));
      }

      // B = root prior precision
      // z = root prior mean
      // A = likelihood precision
      // y = likelihood mean

      // y'Ay
      double yAy =
          computeWeightedAverageAndSumOfSquares(
              conditionalRootMean,
              Ay,
              traitPrecision,
              dimTrait,
              conditionalRootPrecision); // Also fills in Ay

      if (conditionalRootPrecision != 0) {
        // log MVN density with precision c*P: -d*log(sqrt(2*pi)) + 0.5*(log|P| + d*log(c) - yAy)
        thisLogLikelihood +=
            -LOG_SQRT_2_PI * dimTrait
                + 0.5
                    * (logDetTraitPrecision + dimTrait * Math.log(conditionalRootPrecision) - yAy);
      }

      if (DEBUG) {
        double[][] T = new double[dimTrait][dimTrait];
        for (int i = 0; i < dimTrait; i++) {
          for (int j = 0; j < dimTrait; j++) {
            T[i][j] = traitPrecision[i][j] * conditionalRootPrecision;
          }
        }
        System.err.println("Conditional root MVN precision = \n" + new Matrix(T));
        System.err.println(
            "Conditional root MVN density = "
                + MultivariateNormalDistribution.logPdf(
                    conditionalRootMean,
                    new double[dimTrait],
                    T,
                    Math.log(MultivariateNormalDistribution.calculatePrecisionMatrixDeterminate(T)),
                    1.0));
      }

      if (integrateRoot) {
        // Integrate root trait out against rootPrior
        thisLogLikelihood +=
            integrateLogLikelihoodAtRoot(
                conditionalRootMean,
                Ay,
                tmpM,
                traitPrecision,
                conditionalRootPrecision); // Ay is destroyed
      }

      if (DEBUG) {
        System.err.println("yAy = " + yAy);
        System.err.println(
            "logLikelihood (before remainders) = "
                + thisLogLikelihood
                + " (should match conditional root MVN density when root not integrated out)");
      }

      logLikelihood += thisLogLikelihood;
    }

    logLikelihood += sumLogRemainders();

    if (DEBUG) {
      // Diagnostics go to stderr (as everywhere else here) so they never mix with stdout output.
      System.err.println("logLikelihood (final) = " + logLikelihood);
    }

    if (DEBUG_PNAS) {
      checkLogLikelihood(
          logLikelihood,
          sumLogRemainders(),
          conditionalRootMean,
          conditionalRootPrecision,
          traitPrecision);
    }

    areStatesRedrawn = false; // Should redraw internal node states when needed
    return logLikelihood;
  }
  /**
   * Adds the log "remainder" density left over when two children are merged into their parent
   * during peeling. The result is accumulated into {@code logRemainderDensityCache[thisIndex]}.
   *
   * <p>Per datum the contribution is
   * {@code -d*LOG_SQRT_2_PI + 0.5*(d*log(r) + logDetPrecisionMatrix) - 0.5*(SS0 + SS1 - cross)
   * - d*(log(OUFactor0) + log(OUFactor1))}, where:
   *
   * <ul>
   *   <li>{@code d = dimTrait};
   *   <li>{@code r = precision0*precision1/(precision0+precision1)};
   *   <li>{@code SSc} is the quadratic form of child {@code c}'s corrected mean under
   *       {@code precisionMatrix}, weighted by {@code precisionC};
   *   <li>{@code cross} pairs the summed weighted child means with this node's uncorrected mean.
   * </ul>
   *
   * @param precisionMatrix trait precision matrix used in the quadratic forms
   * @param logDetPrecisionMatrix log determinant of {@code precisionMatrix}
   * @param thisIndex slot in {@code logRemainderDensityCache} to increment
   * @param thisOffset start of this node's block in the mean cache
   * @param childOffset0 start of the first child's block in the corrected mean cache
   * @param childOffset1 start of the second child's block in the corrected mean cache
   * @param precision0 scalar weight of the first child
   * @param precision1 scalar weight of the second child
   * @param OUFactor0 factor whose log, times {@code dimTrait}, is subtracted (first child)
   * @param OUFactor1 factor whose log, times {@code dimTrait}, is subtracted (second child)
   * @param cacheOuterProducts when true, also updates the Wishart outer-product statistics
   */
  private void incrementRemainderDensities(
      double[][] precisionMatrix,
      double logDetPrecisionMatrix,
      int thisIndex,
      int thisOffset,
      int childOffset0,
      int childOffset1,
      double precision0,
      double precision1,
      double OUFactor0,
      double OUFactor1,
      boolean cacheOuterProducts) {

    // Combined scalar precision of the two child contributions.
    final double remainderPrecision = precision0 * precision1 / (precision0 + precision1);

    if (cacheOuterProducts) {
      incrementOuterProducts(thisOffset, childOffset0, childOffset1, precision0, precision1);
    }

    for (int datum = 0; datum < numData; datum++) {
      final int datumShift = datum * dimTrait;
      final int base0 = childOffset0 + datumShift;
      final int base1 = childOffset1 + datumShift;
      final int baseThis = thisOffset + datumShift;

      double quad0 = 0;
      double quad1 = 0;
      double cross = 0;

      for (int row = 0; row < dimTrait; row++) {
        // Per the original note, with no drift getCorrectedMeanCache() is just the mean cache.
        final double weighted0 = cacheHelper.getCorrectedMeanCache()[base0 + row] * precision0;
        final double weighted1 = cacheHelper.getCorrectedMeanCache()[base1 + row] * precision1;

        for (int col = 0; col < dimTrait; col++) {
          final double mean0 = cacheHelper.getCorrectedMeanCache()[base0 + col];
          final double mean1 = cacheHelper.getCorrectedMeanCache()[base1 + col];

          quad0 += weighted0 * precisionMatrix[row][col] * mean0;
          quad1 += weighted1 * precisionMatrix[row][col] * mean1;

          // This node's mean must come from the uncorrected cache.
          cross +=
              (weighted0 + weighted1)
                  * precisionMatrix[row][col]
                  * cacheHelper.getMeanCache()[baseThis + col];
        }
      }

      final double logRemainder =
          -dimTrait * LOG_SQRT_2_PI
              + 0.5 * (dimTrait * Math.log(remainderPrecision) + logDetPrecisionMatrix)
              - 0.5 * (quad0 + quad1 - cross)
              - dimTrait * (Math.log(OUFactor0) + Math.log(OUFactor1));
      logRemainderDensityCache[thisIndex] += logRemainder;
    }
  }
 /**
  * Returns a copy of the {@code dim} mean-cache values stored in block {@code index}.
  *
  * <p>The returned array is freshly allocated, so callers may modify it without affecting the
  * cache. An index whose block lies outside the cache causes {@code System.arraycopy} to throw.
  *
  * @param index block number; the copied range starts at {@code dim * index}
  * @return a new array of length {@code dim}
  */
 public double[] getTipDataValues(int index) {
   final int start = dim * index;
   final double[] copy = new double[dim];
   System.arraycopy(cacheHelper.getMeanCache(), start, copy, 0, dim);
   return copy;
 }