/**
 * Recursively samples trait values for {@code node} and (optionally) its subtree in pre-order,
 * writing results into {@code drawnStates} at offset {@code nodeNumber * dim + datum * dimTrait}.
 *
 * <p>At the root, each datum is drawn from a multivariate normal whose mean starts from the cached
 * root mean and whose variance comes from {@code computeMarginalRootMeanAndVariance}. At a non-root
 * node that is neither completely nor partially missing, the cached mean is copied verbatim. At a
 * completely-missing non-root node, each datum is drawn from a normal whose mean is the
 * precision-weighted average of the parent's drawn state and the node's cached mean, and whose
 * variance is {@code treeVariance} divided by the summed precision.
 *
 * <p>NOTE(review): only children 0 and 1 are visited, so a bifurcating tree is assumed.
 *
 * @param treeModel tree being traversed
 * @param node node whose state is drawn in this call
 * @param parentIndex number of the parent node; locates the parent's draw in {@code drawnStates}
 *     (not read when {@code node} is the root)
 * @param treePrecision forwarded to {@code computeMarginalRootMeanAndVariance} for the root draw
 * @param treeVariance used for the root's marginal variance and scaled for conditional draws
 * @throws RuntimeException if a non-root node is partially missing (not yet implemented)
 */
private void preOrderTraverseSample(
      MultivariateTraitTree treeModel,
      NodeRef node,
      int parentIndex,
      double[][] treePrecision,
      double[][] treeVariance) {

    final int nodeIndex = node.getNumber();

    if (treeModel.isRoot(node)) {
      // Root: one independent draw per datum from the marginal root distribution.
      final double[] marginalMean = new double[dimTrait];
      final int rootNumber = treeModel.getRoot().getNumber();
      final double rootPrecision = lowerPrecisionCache[rootNumber];

      for (int d = 0; d < numData; d++) {
        System.arraycopy(
            cacheHelper.getMeanCache(), nodeIndex * dim + d * dimTrait, marginalMean, 0, dimTrait);

        // NOTE(review): marginalMean is presumably updated in place by this call before being
        // used as the draw mean -- confirm against computeMarginalRootMeanAndVariance.
        final double[][] marginalVariance =
            computeMarginalRootMeanAndVariance(
                marginalMean, treePrecision, treeVariance, rootPrecision);

        final double[] rootDraw =
            MultivariateNormalDistribution.nextMultivariateNormalVariance(
                marginalMean, marginalVariance);

        if (DEBUG_PREORDER) {
          // Deterministic placeholder draw for debugging the traversal.
          Arrays.fill(rootDraw, 1.0);
        }

        System.arraycopy(rootDraw, 0, drawnStates, rootNumber * dim + d * dimTrait, dimTrait);

        if (DEBUG) {
          System.err.println("Root mean: " + new Vector(marginalMean));
          System.err.println("Root var : " + new Matrix(marginalVariance));
          System.err.println("Root draw: " + new Vector(rootDraw));
        }
      }
    } else {
      final boolean fullyObserved =
          !(missingTraits.isCompletelyMissing(nodeIndex)
              || missingTraits.isPartiallyMissing(nodeIndex));

      if (fullyObserved) {
        // Observed node: its state is taken directly from the mean cache; nothing is drawn.
        System.arraycopy(
            cacheHelper.getMeanCache(), nodeIndex * dim, drawnStates, nodeIndex * dim, dim);
      } else if (missingTraits.isPartiallyMissing(nodeIndex)) {
        throw new RuntimeException("Partially missing values are not yet implemented");
      } else {
        // Missing node: draw conditional on the parent's already-drawn state.
        // Intended to also cover a completely-missing tip, but that path still needs testing.
        final double parentPrecision = 1.0 / getRescaledBranchLengthForPrecision(node);
        final double nodePrecision = lowerPrecisionCache[nodeIndex];
        final double combinedPrecision = nodePrecision + parentPrecision;

        // Shared scratch buffers (instance fields); overwritten on every datum.
        final double[] condMean = Ay;
        final double[][] condVariance = tmpM;

        for (int d = 0; d < numData; d++) {
          final int parentOffset = parentIndex * dim + d * dimTrait;
          final int nodeOffset = nodeIndex * dim + d * dimTrait;

          if (DEBUG) {
            final double[] parentState = new double[dimTrait];
            System.arraycopy(drawnStates, parentOffset, parentState, 0, dimTrait);
            System.err.println("Parent draw: " + new Vector(parentState));
            if (parentState[0] != drawnStates[parentOffset]) {
              throw new RuntimeException("Error in setting indices");
            }
          }

          for (int r = 0; r < dimTrait; r++) {
            // Precision-weighted average of the parent draw and this node's cached mean.
            condMean[r] =
                (drawnStates[parentOffset + r] * parentPrecision
                        + cacheHelper.getMeanCache()[nodeOffset + r] * nodePrecision)
                    / combinedPrecision;
            for (int c = 0; c < dimTrait; c++) {
              condVariance[r][c] = treeVariance[r][c] / combinedPrecision;
            }
          }

          final double[] conditionalDraw =
              MultivariateNormalDistribution.nextMultivariateNormalVariance(condMean, condVariance);
          System.arraycopy(conditionalDraw, 0, drawnStates, nodeOffset, dimTrait);

          if (DEBUG) {
            System.err.println("Int prec: " + combinedPrecision);
            System.err.println("Int mean: " + new Vector(condMean));
            System.err.println("Int var : " + new Matrix(condVariance));
            System.err.println("Int draw: " + new Vector(conditionalDraw));
            System.err.println("");
          }
        }
      }
    }

    // Descend into children 0 and 1 (in that order) only for internal nodes when peeling.
    if (peel() && !treeModel.isExternal(node)) {
      for (int child = 0; child < 2; child++) {
        preOrderTraverseSample(
            treeModel, treeModel.getChild(node, child), nodeIndex, treePrecision, treeVariance);
      }
    }
  }