Example #1
0
 /**
  * Fills {@code ibdBackward} for every node below {@code node}, parent before children.
  *
  * <p>For each child and state, the value is the parent's backward term plus the forward terms of
  * all of the child's siblings, multiplied by {@code exp(-diag[state] * t)}. Here {@code t} is the
  * branch rate times the height difference between parent and child.
  *
  * @param node node whose children are filled next; pass {@code null} to start at the root, whose
  *     backward terms are reset to zero
  */
 public void backwardIBD(NodeRef node) {
   final int numStates = substitutionModel.getStateCount();
   if (node == null) {
     // Top-level call: begin at the root and zero its backward term.
     node = treeModel.getRoot();
     final int rootNum = node.getNumber();
     for (int s = 0; s < numStates; s++) {
       ibdBackward[rootNum][s] = 0;
     }
   }
   getDiagonalRates(diag);

   final int nChildren = treeModel.getChildCount(node);
   final int parentNum = node.getNumber();
   for (int c = 0; c < nChildren; c++) {
     final NodeRef kid = treeModel.getChild(node, c);
     final int kidNum = kid.getNumber();
     final double elapsed =
         branchRateModel.getBranchRate(treeModel, kid)
             * (treeModel.getNodeHeight(node) - treeModel.getNodeHeight(kid));
     for (int s = 0; s < numStates; s++) {
       ibdBackward[kidNum][s] = ibdBackward[parentNum][s];
       for (int other = 0; other < nChildren; other++) {
         if (other == c) {
           continue;
         }
         ibdBackward[kidNum][s] += ibdForward[treeModel.getChild(node, other).getNumber()][s];
       }
       ibdBackward[kidNum][s] *= Math.exp(-diag[s] * elapsed);
     }
   }

   // Recurse only after every child of this node has been filled in.
   for (int c = 0; c < nChildren; c++) {
     backwardIBD(treeModel.getChild(node, c));
   }
 }
 /**
  * Rebuilds {@code partialsMap}. All entries are cleared, then the entry of each restricted taxon
  * set's most recent common ancestor is set to that set's restriction.
  */
 private void computeNodeToRestrictionMap() {
   Arrays.fill(partialsMap, null);
   for (Set<String> taxa : partialsRestrictions.keySet()) {
     final int mrcaNumber = Tree.Utils.getCommonAncestorNode(treeModel, taxa).getNumber();
     partialsMap[mrcaNumber] = partialsRestrictions.get(taxa);
   }
 }
 /**
  * Lists the node numbers on the path from node {@code i} up to the root.
  *
  * @param i node number to start from
  * @param b if true, {@code i} itself is included first
  * @return ancestor node numbers, nearest first, ending with the root
  */
 @Override
 public List<Integer> getAncestors(int i, boolean b) {
   final List<Integer> result = new ArrayList<Integer>();
   if (b) {
     result.add(i);
   }
   NodeRef current = tree.getParent(tree.getNode(i));
   while (current != null) {
     result.add(current.getNumber());
     current = tree.getParent(current);
   }
   return result;
 }
 /**
  * Lists the node numbers of the external nodes below node {@code i}.
  *
  * @param i node number whose leaves are collected
  * @param b when false, {@code i} itself is prepended to the result
  * @return node numbers, optionally {@code i} first, then the leaves
  */
 @Override
 public List<Integer> getDescendantLeaves(int i, boolean b) {
   final List<Integer> leaves = new ArrayList<Integer>();
   // NOTE(review): i is added when b is FALSE, the opposite of getAncestors, which adds it when b
   // is TRUE. Confirm the intended convention. If i is itself a leaf, getExternalNodes may also
   // return it, giving a duplicate. TODO confirm.
   if (!b) {
     leaves.add(i);
   }
   final NodeRef start = tree.getNode(i);
   for (NodeRef leaf : Tree.Utils.getExternalNodes(tree, start)) {
     leaves.add(leaf.getNumber());
   }
   return leaves;
 }
 /**
  * Returns the node number of the other child of node {@code i}'s parent.
  *
  * @param i node number
  * @return sibling's node number, or {@code RootedTree.NULL} if {@code i} is the root
  */
 @Override
 public int getSibling(int i) {
   final NodeRef self = tree.getNode(i);
   if (tree.isRoot(self)) {
     return RootedTree.NULL;
   }
   final NodeRef parent = tree.getParent(self);
   // Only children 0 and 1 are consulted, so the tree is assumed to be binary.
   final int first = tree.getChild(parent, 0).getNumber();
   final int second = tree.getChild(parent, 1).getNumber();
   if (self.getNumber() == second) {
     return first;
   }
   return second;
 }
 /**
  * Rewires {@code node} so that its left/right children are {@code lft}/{@code rgt}, and points
  * both new children's ancestor link back at {@code node}.
  */
 @Override
 public void replaceSlidableChildren(NodeRef node, NodeRef lft, NodeRef rgt) {
   final int parentIdx = node.getNumber();
   final int leftIdx = lft.getNumber();
   final int rightIdx = rgt.getNumber();
   // Only internal nodes (non-negative left-child index) may have their children replaced.
   assert pionodes[parentIdx].lft >= 0;
   pionodes[parentIdx].lft = leftIdx;
   pionodes[parentIdx].rgt = rightIdx;
   pionodes[leftIdx].anc = pionodes[parentIdx].nodeNumber;
   pionodes[rightIdx].anc = pionodes[parentIdx].nodeNumber;
 }
  /**
   * Simulates a (latitude, longitude) pair for every node.
   *
   * <p>The root is drawn uniformly inside [minLat, maxLat] x [minLong, maxLong]; all other nodes
   * are filled by {@code traverse}.
   *
   * @return array indexed by node number, each row holding the node's coordinates
   */
  public double[][] simulateLocations() {
    final NodeRef root = m_tree.getRoot();
    final double[][] latLongs = new double[m_tree.getNodeCount()][2];

    // Uniform root location; latitude is drawn before longitude.
    final double[] rootLocation = latLongs[root.getNumber()];
    rootLocation[LATITUDE_INDEX] = minLat + MathUtils.nextDouble() * (maxLat - minLat);
    rootLocation[LONGITUDE_INDEX] = minLong + MathUtils.nextDouble() * (maxLong - minLong);

    traverse(root, rootLocation, latLongs);
    return latLongs;
  }
  /**
   * Simulates node locations and wraps each tip's coordinate pair in a {@link Parameter}.
   *
   * <p>A line per tip (taxon id, latitude, longitude) is also printed to standard output.
   *
   * @return one parameter per external node, in external-node order
   */
  public ArrayList<Parameter> simulateGeoAttr() {
    final double[][] locations = simulateLocations();
    final int tipCount = m_tree.getExternalNodeCount();

    final ArrayList<Parameter> locationList = new ArrayList<Parameter>(tipCount);
    for (int i = 0; i < tipCount; i++) {
      final NodeRef node = m_tree.getNode(i);
      final int nodeNumber = node.getNumber();
      // simulateLocations() indexes its rows by node number, so look the row up the same way.
      final double[] location = locations[nodeNumber];
      final String taxaName = m_tree.getTaxon(nodeNumber).getId();
      final Parameter parameter = new Parameter.Default(location);
      System.out.println(
          "taxon: " + taxaName + ", lat: " + location[0] + ", long: " + location[1]);
      locationList.add(parameter);
    }

    return locationList;
  }
  /**
   * Sums the branch log-likelihoods of the subtree rooted at {@code node}, using and refreshing
   * the per-node cache.
   *
   * @param parentTrait trait of {@code node}'s parent, or null to have it fetched from the tree
   * @param node subtree root
   * @return the summed log-likelihood of all non-root branches in the subtree
   */
  private double traitCachedLogLikelihood(double[] parentTrait, NodeRef node) {
    final int nodeNumber = node.getNumber();
    double logL = 0.0;
    // This node's trait, handed to the children. It stays null when it was not loaded (root, or
    // cached branch); each child then fetches its parent's trait itself.
    double[] nodeTrait = null;

    if (!treeModel.isRoot(node)) {
      if (validLogLikelihoods[nodeNumber]) {
        logL = cachedLogLikelihoods[nodeNumber];
      } else {
        nodeTrait = treeModel.getMultivariateNodeTrait(node, traitName);
        final double time = getRescaledBranchLengthForPrecision(node);
        final double[] above =
            (parentTrait != null)
                ? parentTrait
                : treeModel.getMultivariateNodeTrait(treeModel.getParent(node), traitName);
        logL = diffusionModel.getLogLikelihood(above, nodeTrait, time);
        cachedLogLikelihoods[nodeNumber] = logL;
        validLogLikelihoods[nodeNumber] = true;
      }
    }

    for (int c = 0, n = treeModel.getChildCount(node); c < n; c++) {
      logL += traitCachedLogLikelihood(nodeTrait, treeModel.getChild(node, c));
    }
    return logL;
  }
Example #10
0
 /**
  * Returns the IBD weight of a node: {@code ibdweights[node] + 1} for tips and 0 for internal
  * nodes. The expected IBD values are computed on first use.
  */
 private double getIBDWeight(Tree tree, NodeRef node) {
   if (!weightsKnown) {
     expectedIBD();
     weightsKnown = true;
   }
   return tree.isExternal(node) ? ibdweights[node.getNumber()] + 1 : 0;
 }
  /**
   * Sums the log-likelihood of every tip branch (parent trait to tip trait).
   *
   * <p>When {@code cacheBranches} is on and a tip's cached value is valid, the cached value is
   * used instead of recomputing it.
   *
   * @return summed tip-branch log-likelihood
   */
  public final double getLogDataLikelihood() {
    double total = 0;
    final int tipCount = treeModel.getExternalNodeCount();
    for (int t = 0; t < tipCount; t++) {
      // TODO Do not include integrated tips; how to check???
      final NodeRef tip = treeModel.getExternalNode(t);
      final int tipNumber = tip.getNumber();

      if (cacheBranches && validLogLikelihoods[tipNumber]) {
        total += cachedLogLikelihoods[tipNumber];
        continue;
      }

      final NodeRef parent = treeModel.getParent(tip);
      final double[] tipTrait = treeModel.getMultivariateNodeTrait(tip, traitName);
      final double[] parentTrait = treeModel.getMultivariateNodeTrait(parent, traitName);
      final double time = getRescaledBranchLengthForPrecision(tip);
      total += diffusionModel.getLogLikelihood(parentTrait, tipTrait, time);
    }
    return total;
  }
Example #12
0
  /**
   * Writes the subtree rooted at {@code node} to {@code out} in Newick form.
   *
   * <p>Tips are written as their 1-based node number, or as {@code idMap.get(taxonId)} when
   * {@code idMap} is non-null. Internal nodes are written as a parenthesised, comma-separated list
   * of their children. When {@code attributes} is true and the node has attributes, they are
   * written as a {@code [&name=value,...]} comment. With BRANCH_ATTRIBUTES the comment follows
   * the {@code :} separator; with NODE_ATTRIBUTES it precedes it. Every non-root node ends with
   * its branch length, formatted with {@code formatter} when one is set.
   *
   * @param tree tree being written
   * @param node root of the subtree to write
   * @param attributes whether to write node attributes
   * @param idMap optional taxon-id to printed-number map; may be null. NOTE(review): a taxon
   *     missing from the map makes the unboxing below throw NullPointerException.
   */
  private void writeNode(Tree tree, NodeRef node, boolean attributes, Map<String, Integer> idMap) {
    if (tree.isExternal(node)) {
      int k = node.getNumber() + 1;
      if (idMap != null) {
        k = idMap.get(tree.getTaxonId(k - 1));
      }
      out.print(k);
    } else {
      out.print("(");
      writeNode(tree, tree.getChild(node, 0), attributes, idMap);
      for (int i = 1; i < tree.getChildCount(node); i++) {
        out.print(",");
        writeNode(tree, tree.getChild(node, i), attributes, idMap);
      }
      out.print(")");
    }

    if (writeAttributesAs == AttributeType.BRANCH_ATTRIBUTES && !tree.isRoot(node)) {
      out.print(":");
    }

    if (attributes) {
      Iterator<?> iter = tree.getNodeAttributeNames(node);
      if (iter != null) {
        boolean first = true;
        while (iter.hasNext()) {
          if (first) {
            out.print("[&");
            first = false;
          } else {
            out.print(",");
          }
          String name = (String) iter.next();
          out.print(name + "=");
          Object value = tree.getNodeAttribute(node, name);
          printValue(value);
        }
        // Close the comment only if one was opened; an empty iterator must not emit a stray "]".
        if (!first) {
          out.print("]");
        }
      }
    }

    if (writeAttributesAs == AttributeType.NODE_ATTRIBUTES && !tree.isRoot(node)) {
      out.print(":");
    }

    if (!tree.isRoot(node)) {
      double length = tree.getBranchLength(node);
      if (formatter != null) {
        out.print(formatter.format(length));
      } else {
        out.print(length);
      }
    }
  }
  /**
   * Copies the tip's trait values into the mean cache via {@code cacheHelper}.
   *
   * @param node tip node; its node number selects the trait sub-parameter
   * @throws RuntimeException if the tip's trait vector has fewer than {@code dim} values
   */
  private void setTipDataValuesForNode(NodeRef node) {
    final int index = node.getNumber();
    final double[] traitValue = traitParameter.getParameter(index).getParameterValues();
    if (traitValue.length >= dim) {
      cacheHelper.setTipMeans(traitValue, dim, index, node);
      return;
    }
    throw new RuntimeException(
        "The trait parameter for the tip with index, " + index + ", is too short");
  }
  /**
   * Decides whether {@code node} may be pruned together with its parent. Both conditions must
   * hold:
   *
   * <ol>
   *   <li>Removing the parent leaves the transmission-tree links intact: the parent shares a
   *       partition with the grandparent, or with the node's sibling.
   *   <li>The node and its parent are in different partitions, so the move does not disconnect
   *       anything.
   * </ol>
   *
   * <p>{@code node} must not be the root.
   */
  private boolean eligibleForMove(NodeRef node, TreeModel tree, BranchMapModel branchMap) {
    final NodeRef parent = tree.getParent(node);
    final NodeRef grandparent = tree.getParent(parent);
    final int parentNumber = parent.getNumber();

    boolean linkSurvives =
        grandparent != null
            && branchMap.get(parentNumber) == branchMap.get(grandparent.getNumber());
    if (!linkSurvives) {
      // The sibling is consulted only when the grandparent test fails.
      linkSurvives =
          branchMap.get(parentNumber)
              == branchMap.get(getOtherChild(tree, parent, node).getNumber());
    }
    return linkSurvives && branchMap.get(parentNumber) != branchMap.get(node.getNumber());
  }
  /**
   * Recursively simulates child locations. Each coordinate is the parent's value plus a Gaussian
   * step with variance equal to the rate-scaled branch length.
   *
   * @param node node whose children are simulated
   * @param parentSequence location of {@code node}
   * @param latLongs output array indexed by node number
   * @throws RuntimeException if a rate-scaled branch length is negative
   */
  void traverse(NodeRef node, double[] parentSequence, double[][] latLongs) {
    final int childCount = m_tree.getChildCount(node);
    for (int c = 0; c < childCount; c++) {
      final NodeRef child = m_tree.getChild(node, c);

      final double branchLength =
          m_branchRateModel.getBranchRate(m_tree, child)
              * (m_tree.getNodeHeight(node) - m_tree.getNodeHeight(child));
      if (branchLength < 0.0) {
        throw new RuntimeException("Negative branch length: " + branchLength);
      }

      final double sd = Math.sqrt(branchLength);
      final double[] childLocation = latLongs[child.getNumber()];
      // Latitude's Gaussian is drawn before longitude's.
      childLocation[LATITUDE_INDEX] =
          MathUtils.nextGaussian() * sd + parentSequence[LATITUDE_INDEX];
      childLocation[LONGITUDE_INDEX] =
          MathUtils.nextGaussian() * sd + parentSequence[LONGITUDE_INDEX];

      traverse(child, childLocation, latLongs);
    }
  }
  /**
   * Returns a copy of the sampled ancestral trait for {@code node}.
   *
   * <p>The likelihood is evaluated first (presumably to bring the caches up to date; confirm).
   * The states are redrawn if they have not been yet. {@code tree} and {@code traitName} are not
   * used; the commented-out guard suggested only the constructor's tree model was intended.
   *
   * @return the {@code dim} values stored for the node in {@code drawnStates}
   */
  public double[] getTraitForNode(Tree tree, NodeRef node, String traitName) {
    getLogLikelihood();
    if (!areStatesRedrawn) {
      redrawAncestralStates();
    }
    // drawnStates holds dim consecutive values per node, ordered by node number.
    final int offset = node.getNumber() * dim;
    final double[] trait = new double[dim];
    System.arraycopy(drawnStates, offset, trait, 0, dim);
    return trait;
  }
 /**
  * Appends {@code n} and every node below it to {@code descendants}: each node first, then its
  * children from left to right.
  */
 private void getDescendants(NodeRef n, List<Integer> descendants) {
   descendants.add(n.getNumber());
   if (!tree.isExternal(n)) {
     final int childCount = tree.getChildCount(n);
     for (int c = 0; c < childCount; c++) {
       getDescendants(tree.getChild(n, c), descendants);
     }
   }
 }
  /**
   * Recursively walks the subtree rooted at {@code node} and queues work for BEAGLE.
   *
   * <p>Two kinds of work are queued:
   *
   * <ul>
   *   <li>For a non-root node whose {@code updateNode} flag is set, the branch's transition matrix
   *       update is appended to {@code branchUpdateIndices}/{@code branchLengths}, and
   *       {@code branchUpdateCount} is incremented.
   *   <li>For an internal node with at least one updated child, a partials operation tuple is
   *       appended to {@code operations[operationListCount]}, and
   *       {@code operationCount[operationListCount]} is incremented.
   * </ul>
   *
   * <p>Nothing is sent to BEAGLE here. Only children 0 and 1 are visited, so the tree is assumed
   * to be binary.
   *
   * @param tree tree to traverse
   * @param node root of the subtree to process
   * @param operatorNumber if non-null, element 0 is set to -1; not otherwise used here
   * @param flip whether to flip matrix and partials buffer offsets before writing. The rescaling
   *     retry passes false so that it overwrites the previous attempt.
   * @return true if this node's branch or partials were queued for update
   * @throws RuntimeException if a rate-scaled branch length is negative
   */
  private boolean traverse(Tree tree, NodeRef node, int[] operatorNumber, boolean flip) {

    boolean update = false;

    int nodeNum = node.getNumber();

    NodeRef parent = tree.getParent(node);

    if (operatorNumber != null) {
      operatorNumber[0] = -1;
    }

    // First update the transition probability matrix(ices) for this branch
    if (parent != null && updateNode[nodeNum]) {

      final double branchRate = branchRateModel.getBranchRate(tree, node);

      final double parentHeight = tree.getNodeHeight(parent);
      final double nodeHeight = tree.getNodeHeight(node);

      // Get the operational time of the branch
      final double branchLength = branchRate * (parentHeight - nodeHeight);
      if (branchLength < 0.0) {
        throw new RuntimeException("Negative branch length: " + branchLength);
      }

      if (flip) {
        substitutionModelDelegate.flipMatrixBuffer(nodeNum);
      }
      // Queue the matrix update; it is applied later by updateTransitionMatrices.
      branchUpdateIndices[branchUpdateCount] = nodeNum;
      branchLengths[branchUpdateCount] = branchLength;
      branchUpdateCount++;

      update = true;
    }

    // If the node is internal, update the partial likelihoods.
    if (!tree.isExternal(node)) {

      // Traverse down the two child nodes
      NodeRef child1 = tree.getChild(node, 0);
      final int[] op1 = {-1};
      final boolean update1 = traverse(tree, child1, op1, flip);

      NodeRef child2 = tree.getChild(node, 1);
      final int[] op2 = {-1};
      final boolean update2 = traverse(tree, child2, op2, flip);

      // If either child node was updated then update this node too
      if (update1 || update2) {

        // Start of this node's tuple in the flat operations array.
        int x = operationCount[operationListCount] * Beagle.OPERATION_TUPLE_SIZE;

        if (flip) {
          // first flip the partialBufferHelper
          partialBufferHelper.flipOffset(nodeNum);
        }

        final int[] operations = this.operations[operationListCount];

        // Tuple slot 0: destination partials buffer.
        operations[x] = partialBufferHelper.getOffsetIndex(nodeNum);

        // Tuple slots 1-2: scale buffer to write / scale buffer to read.
        if (useScaleFactors) {
          // get the index of this scaling buffer
          int n = nodeNum - tipCount;

          if (recomputeScaleFactors) {
            // flip the indicator: can take either n or (internalNodeCount + 1) - n
            scaleBufferHelper.flipOffset(n);

            // store the index
            scaleBufferIndices[n] = scaleBufferHelper.getOffsetIndex(n);

            operations[x + 1] = scaleBufferIndices[n]; // Write new scaleFactor
            operations[x + 2] = Beagle.NONE;

          } else {
            operations[x + 1] = Beagle.NONE;
            operations[x + 2] = scaleBufferIndices[n]; // Read existing scaleFactor
          }

        } else {

          if (useAutoScaling) {
            scaleBufferIndices[nodeNum - tipCount] = partialBufferHelper.getOffsetIndex(nodeNum);
          }
          operations[x + 1] = Beagle.NONE; // Not using scaleFactors
          operations[x + 2] = Beagle.NONE;
        }

        // Tuple slots 3-6: each child's partials buffer and transition matrix.
        operations[x + 3] = partialBufferHelper.getOffsetIndex(child1.getNumber()); // source node 1
        operations[x + 4] =
            substitutionModelDelegate.getMatrixIndex(child1.getNumber()); // source matrix 1
        operations[x + 5] = partialBufferHelper.getOffsetIndex(child2.getNumber()); // source node 2
        operations[x + 6] =
            substitutionModelDelegate.getMatrixIndex(child2.getNumber()); // source matrix 2

        operationCount[operationListCount]++;

        update = true;

        if (hasRestrictedPartials) {
          // Test if this set of partials should be restricted
          if (updateRestrictedNodePartials) {
            // Recompute map
            computeNodeToRestrictionMap();
            updateRestrictedNodePartials = false;
          }
          // NOTE(review): the body below is empty, so no restriction is applied here; only the
          // map refresh above has any effect. This looks like unfinished code. TODO confirm.
          if (partialsMap[nodeNum] != null) {}
        }
      }
    }

    return update;
  }
  /**
   * Calculate the log likelihood of the current state.
   *
   * <p>Steps, in order:
   *
   * <ol>
   *   <li>Lazily allocate the work arrays.
   *   <li>Use {@code rescalingScheme} to decide whether scale factors are used and recomputed.
   *   <li>When a {@code tipStatesModel} is set, push tip partials/states for tips flagged in
   *       {@code updateNode}.
   *   <li>Traverse the tree to queue matrix and partials updates.
   *   <li>Push any substitution/site model changes and the queued transition matrices to BEAGLE.
   *   <li>Evaluate the root log likelihood.
   * </ol>
   *
   * <p>If the result is NaN or infinite and the scheme is DYNAMIC or DELAYED, one retry is made
   * with scale factors recomputed. Finally every {@code updateNode} flag and the model-update
   * flags are cleared.
   *
   * @return the log likelihood, or {@code Double.NEGATIVE_INFINITY} if it was still NaN/infinite
   */
  protected double calculateLogLikelihood() {

    // Lazy allocation of work buffers on first use.
    if (patternLogLikelihoods == null) {
      patternLogLikelihoods = new double[patternCount];
    }

    if (branchUpdateIndices == null) {
      branchUpdateIndices = new int[nodeCount];
      branchLengths = new double[nodeCount];
      scaleBufferIndices = new int[internalNodeCount];
      storedScaleBufferIndices = new int[internalNodeCount];
    }

    if (operations == null) {
      operations =
          new int[numRestrictedPartials + 1][internalNodeCount * Beagle.OPERATION_TUPLE_SIZE];
      operationCount = new int[numRestrictedPartials + 1];
    }

    // Decide whether this evaluation uses and/or recomputes scale factors.
    recomputeScaleFactors = false;

    if (this.rescalingScheme == PartialsRescalingScheme.ALWAYS) {
      useScaleFactors = true;
      recomputeScaleFactors = true;
    } else if (this.rescalingScheme == PartialsRescalingScheme.DYNAMIC && everUnderflowed) {
      useScaleFactors = true;
      // Recompute (and mark everything dirty) for the first RESCALE_TIMES calls of each
      // rescalingFrequency-sized window.
      if (rescalingCountInner < RESCALE_TIMES) {
        recomputeScaleFactors = true;
        makeDirty();
        //                System.err.println("Recomputing scale factors");
      }

      rescalingCountInner++;
      rescalingCount++;
      if (rescalingCount > rescalingFrequency) {
        rescalingCount = 0;
        rescalingCountInner = 0;
      }
    } else if (this.rescalingScheme == PartialsRescalingScheme.DELAYED && everUnderflowed) {
      useScaleFactors = true;
      recomputeScaleFactors = true;
      rescalingCount++;
    }

    // Push changed tip data to BEAGLE.
    if (tipStatesModel != null) {
      int tipCount = treeModel.getExternalNodeCount();
      for (int index = 0; index < tipCount; index++) {
        if (updateNode[index]) {
          if (tipStatesModel.getModelType() == TipStatesModel.Type.PARTIALS) {
            tipStatesModel.getTipPartials(index, tipPartials);
            beagle.setTipPartials(index, tipPartials);
          } else {
            tipStatesModel.getTipStates(index, tipStates);
            beagle.setTipStates(index, tipStates);
          }
        }
      }
    }

    branchUpdateCount = 0;
    // NOTE(review): operationListCount is reset here and not advanced anywhere in the visible
    // code, so traverse() only fills operations[0]. TODO confirm for restricted partials.
    operationListCount = 0;

    if (hasRestrictedPartials) {
      for (int i = 0; i <= numRestrictedPartials; i++) {
        operationCount[i] = 0;
      }
    } else {
      operationCount[0] = 0;
    }

    // Queue matrix and partials updates, flipping buffers.
    final NodeRef root = treeModel.getRoot();
    traverse(treeModel, root, null, true);

    if (updateSubstitutionModel) { // TODO More efficient to update only the substitution model that
                                   // changed, instead of all
      substitutionModelDelegate.updateSubstitutionModels(beagle);

      // we are currently assuming a no-category model...
    }

    if (updateSiteModel) {
      double[] categoryRates = this.siteModel.getCategoryRates();
      beagle.setCategoryRates(categoryRates);
    }

    if (branchUpdateCount > 0) {
      substitutionModelDelegate.updateTransitionMatrices(
          beagle, branchUpdateIndices, branchLengths, branchUpdateCount);
    }

    if (COUNT_TOTAL_OPERATIONS) {
      totalMatrixUpdateCount += branchUpdateCount;

      for (int i = 0; i <= numRestrictedPartials; i++) {
        totalOperationCount += operationCount[i];
      }
    }

    double logL;
    boolean done;
    boolean firstRescaleAttempt = true;

    // Runs once, or twice when the first result under/overflows and a rescaling retry applies.
    do {

      if (hasRestrictedPartials) {
        for (int i = 0; i <= numRestrictedPartials; i++) {
          beagle.updatePartials(operations[i], operationCount[i], Beagle.NONE);
          if (i < numRestrictedPartials) {
            //                        restrictNodePartials(restrictedIndices[i]);
          }
        }
      } else {
        beagle.updatePartials(operations[0], operationCount[0], Beagle.NONE);
      }

      int rootIndex = partialBufferHelper.getOffsetIndex(root.getNumber());

      double[] categoryWeights = this.siteModel.getCategoryProportions();

      // This should probably explicitly be the state frequencies for the root node...
      double[] frequencies = substitutionModelDelegate.getRootStateFrequencies();

      // Buffer at index internalNodeCount holds the accumulated (summed) scale factors.
      int cumulateScaleBufferIndex = Beagle.NONE;
      if (useScaleFactors) {

        if (recomputeScaleFactors) {
          scaleBufferHelper.flipOffset(internalNodeCount);
          cumulateScaleBufferIndex = scaleBufferHelper.getOffsetIndex(internalNodeCount);
          beagle.resetScaleFactors(cumulateScaleBufferIndex);
          beagle.accumulateScaleFactors(
              scaleBufferIndices, internalNodeCount, cumulateScaleBufferIndex);
        } else {
          cumulateScaleBufferIndex = scaleBufferHelper.getOffsetIndex(internalNodeCount);
        }
      } else if (useAutoScaling) {
        beagle.accumulateScaleFactors(scaleBufferIndices, internalNodeCount, Beagle.NONE);
      }

      // these could be set only when they change but store/restore would need to be considered
      beagle.setCategoryWeights(0, categoryWeights);
      beagle.setStateFrequencies(0, frequencies);

      double[] sumLogLikelihoods = new double[1];

      beagle.calculateRootLogLikelihoods(
          new int[] {rootIndex},
          new int[] {0},
          new int[] {0},
          new int[] {cumulateScaleBufferIndex},
          1,
          sumLogLikelihoods);

      logL = sumLogLikelihoods[0];

      if (ascertainedSitePatterns) {
        // Need to correct for ascertainedSitePatterns
        beagle.getSiteLogLikelihoods(patternLogLikelihoods);
        logL =
            getAscertainmentCorrectedLogLikelihood(
                (AscertainedSitePatterns) patternList, patternLogLikelihoods, patternWeights);
      }

      if (Double.isNaN(logL) || Double.isInfinite(logL)) {
        everUnderflowed = true;
        logL = Double.NEGATIVE_INFINITY;

        if (firstRescaleAttempt
            && (rescalingScheme == PartialsRescalingScheme.DYNAMIC
                || rescalingScheme == PartialsRescalingScheme.DELAYED)) {
          // we have had a potential under/over flow so attempt a rescaling
          if (rescalingScheme == PartialsRescalingScheme.DYNAMIC || (rescalingCount == 0)) {
            Logger.getLogger("dr.evomodel")
                .info("Underflow calculating likelihood. Attempting a rescaling...");
          }
          useScaleFactors = true;
          recomputeScaleFactors = true;

          // Transition matrices were already pushed above; only the queue counters are reset.
          branchUpdateCount = 0;

          if (hasRestrictedPartials) {
            for (int i = 0; i <= numRestrictedPartials; i++) {
              operationCount[i] = 0;
            }
          } else {
            operationCount[0] = 0;
          }

          // traverse again but without flipping partials indices as we
          // just want to overwrite the last attempt. We will flip the
          // scale buffer indices though as we are recomputing them.
          traverse(treeModel, root, null, false);

          done = false; // Run through do-while loop again
          firstRescaleAttempt = false; // Only try to rescale once
        } else {
          // we have already tried a rescale, not rescaling or always rescaling
          // so just return the likelihood...
          done = true;
        }
      } else {
        done = true; // No under-/over-flow, then done
      }

    } while (!done);

    // If these are needed...
    // beagle.getSiteLogLikelihoods(patternLogLikelihoods);

    // ********************************************************************
    // after traverse all nodes and patterns have been updated --
    // so change flags to reflect this.
    for (int i = 0; i < nodeCount; i++) {
      updateNode[i] = false;
    }

    updateSubstitutionModel = false;
    updateSiteModel = false;
    // ********************************************************************

    return logL;
  }
 /** Returns the taxon attached to {@code node}, which must be the object stored in pionodes. */
 @Override
 public Taxon getSlidableNodeTaxon(NodeRef node) {
   assert node == pionodes[node.getNumber()];
   final PopsIONode pioNode = (PopsIONode) node;
   return pioNode.getTaxon();
 }
 /**
  * Returns a child of {@code node}: the left child when {@code j == 0}, otherwise the right
  * child.
  */
 @Override
 public NodeRef getSlidableChild(NodeRef node, int j) {
   final int parentIdx = node.getNumber();
   if (j == 0) {
     return pionodes[pionodes[parentIdx].lft];
   }
   return pionodes[pionodes[parentIdx].rgt];
 }
 /** Reports whether {@code node} is a tip, i.e. has a negative left-child index. */
 @Override
 public boolean isExternalSlidable(NodeRef node) {
   final int leftChild = pionodes[node.getNumber()].lft;
   return leftChild < 0;
 }
 /** Sets the height of {@code node}, which must be the object stored in pionodes. */
 @Override
 public void setSlidableNodeHeight(NodeRef node, double height) {
   assert node == pionodes[node.getNumber()];
   final PopsIONode target = (PopsIONode) node;
   target.height = height;
 }
 /** Returns the height of {@code node}, which must be the object stored in pionodes. */
 @Override
 public double getSlidableNodeHeight(NodeRef node) {
   assert node == pionodes[node.getNumber()];
   final PopsIONode source = (PopsIONode) node;
   return source.getHeight();
 }
  /**
   * Proposes a prune-and-regraft move that respects the branch-to-partition map.
   *
   * <p>Steps:
   *
   * <ol>
   *   <li>Choose a non-root node {@code i} that passes {@code eligibleForMove}. Its parent
   *       {@code iP} is pruned, and {@code iP}'s other child {@code CiP} is reattached to the
   *       grandparent {@code PiP}.
   *   <li>Regraft {@code iP} on the branch between a random node {@code j} and its parent
   *       {@code k}. The new height is drawn uniformly between max(height(i), height(j)) and
   *       height(k).
   *   <li>Repaint {@code iP} to match {@code k} or {@code j} with equal probability.
   * </ol>
   *
   * <p>The log of the proposal ratio is stored in {@code logq}.
   *
   * @throws OperatorFailedException if iP or j is the root, or if the regraft point coincides
   *     with the pruned structure
   */
  public void proposeTree() throws OperatorFailedException {
    TreeModel tree = c2cLikelihood.getTreeModel();
    BranchMapModel branchMap = c2cLikelihood.getBranchMap();
    final int nodeCount = tree.getNodeCount();

    // Choose a random non-root node that is eligible for this move.
    // NOTE(review): this loops forever if no node is eligible.
    NodeRef i;
    do {
      i = tree.getNode(MathUtils.nextInt(nodeCount));
    } while (tree.getRoot() == i || !eligibleForMove(i, tree, branchMap));
    final NodeRef iP = tree.getParent(i);

    // The regraft point j can be anywhere whose parent k lies above i.
    NodeRef j = tree.getNode(MathUtils.nextInt(nodeCount));
    NodeRef k = tree.getParent(j);
    while ((k != null && tree.getNodeHeight(k) <= tree.getNodeHeight(i)) || (i == j)) {
      j = tree.getNode(MathUtils.nextInt(nodeCount));
      k = tree.getParent(j);
    }

    if (iP == tree.getRoot() || j == tree.getRoot()) {
      throw new OperatorFailedException("Root changes not allowed!");
    }

    if (k == iP || j == iP || k == i) throw new OperatorFailedException("move failed");

    final NodeRef CiP = getOtherChild(tree, iP, i);
    final NodeRef PiP = tree.getParent(iP);

    final double newMinAge = Math.max(tree.getNodeHeight(i), tree.getNodeHeight(j));
    final double newRange = tree.getNodeHeight(k) - newMinAge;
    final double newAge = newMinAge + (MathUtils.nextDouble() * newRange);
    final double oldMinAge = Math.max(tree.getNodeHeight(i), tree.getNodeHeight(CiP));
    final double oldRange = tree.getNodeHeight(PiP) - oldMinAge;
    double q = newRange / Math.abs(oldRange);

    // Account for the random repainting of iP in the forward and reverse moves.
    if (branchMap.get(PiP.getNumber()) != branchMap.get(CiP.getNumber())) {
      q *= 0.5;
    }

    if (branchMap.get(k.getNumber()) != branchMap.get(j.getNumber())) {
      q *= 2;
    }

    // Neither iP nor j is the root (rejected above), so the root never changes. The former
    // "j is root" and "iP is root" branches were unreachable and have been removed.
    tree.beginTreeEdit();

    // 1. remove edges <k, j>, <iP, CiP>, <PiP, iP>
    tree.removeChild(k, j);
    tree.removeChild(iP, CiP);
    tree.removeChild(PiP, iP);

    // 2. add edges <k, iP>, <iP, j>, <PiP, CiP>
    tree.addChild(iP, j);
    tree.addChild(k, iP);
    tree.addChild(PiP, CiP);

    tree.setNodeHeight(iP, newAge);

    tree.endTreeEdit();

    logq = Math.log(q);

    // Repaint the parent to match either its new parent or its new child (50% chance of each).
    if (MathUtils.nextInt(2) == 0) {
      branchMap.set(iP.getNumber(), branchMap.get(k.getNumber()), true);
    } else {
      branchMap.set(iP.getNumber(), branchMap.get(j.getNumber()), true);
    }

    if (DEBUG) {
      c2cLikelihood.checkPartitions();
    }
  }
 /**
  * Returns the rate of the branch above {@code node}: its entry in {@code rates}, indexed by node
  * number, multiplied by {@code scaleFactor}. The root has no branch and must not be passed.
  */
 public double getBranchRate(final Tree tree, final NodeRef node) {
   assert !tree.isRoot(node) : "root node doesn't have a rate!";
   final int index = node.getNumber();
   return rates[index] * scaleFactor;
 }
  /**
   * Post-order pass that fills the per-node precision, mean and remainder-density caches.
   *
   * <p>For a tip:
   *
   * <ul>
   *   <li>If its trait is completely missing, both {@code upperPrecisionCache} and
   *       {@code lowerPrecisionCache} are set to 0.
   *   <li>Otherwise the upper precision is (1 / rescaled branch length) times the squared OU
   *       factor, and the lower precision is +infinity.
   * </ul>
   *
   * <p>For an internal node (assumed binary; only children 0 and 1 are visited):
   *
   * <ul>
   *   <li>Recurse into both children.
   *   <li>Set the lower precision to the sum of the children's upper precisions.
   *   <li>Delegate the combined mean to {@code cacheHelper.computeMeanCaches}.
   *   <li>For non-root nodes, set the upper precision.
   *   <li>Reset {@code logRemainderDensityCache}, then increment it when both child precisions are
   *       non-zero.
   * </ul>
   *
   * @param treeModel tree being traversed
   * @param node root of the subtree
   * @param precisionMatrix passed through to {@code incrementRemainderDensities}
   * @param logDetPrecisionMatrix passed through to {@code incrementRemainderDensities}
   * @param cacheOuterProducts passed through to {@code incrementRemainderDensities}
   */
  void postOrderTraverse(
      MultivariateTraitTree treeModel,
      NodeRef node,
      double[][] precisionMatrix,
      double logDetPrecisionMatrix,
      boolean cacheOuterProducts) {

    final int thisNumber = node.getNumber();

    if (treeModel.isExternal(node)) {

      // Fill in precision scalar, traitValues already filled in

      if (missingTraits.isCompletelyMissing(thisNumber)) {
        upperPrecisionCache[thisNumber] = 0;
        lowerPrecisionCache[thisNumber] = 0; // Needed in the pre-order traversal
      } else { // not missing tip trait
        upperPrecisionCache[thisNumber] =
            (1.0 / getRescaledBranchLengthForPrecision(node))
                * Math.pow(cacheHelper.getOUFactor(node), 2);
        lowerPrecisionCache[thisNumber] = Double.POSITIVE_INFINITY;
      }
      return;
    }

    final NodeRef childNode0 = treeModel.getChild(node, 0);
    final NodeRef childNode1 = treeModel.getChild(node, 1);

    postOrderTraverse(
        treeModel, childNode0, precisionMatrix, logDetPrecisionMatrix, cacheOuterProducts);
    postOrderTraverse(
        treeModel, childNode1, precisionMatrix, logDetPrecisionMatrix, cacheOuterProducts);

    final int childNumber0 = childNode0.getNumber();
    final int childNumber1 = childNode1.getNumber();
    // Mean-cache offsets: dim consecutive values per node, ordered by node number.
    final int meanOffset0 = dim * childNumber0;
    final int meanOffset1 = dim * childNumber1;
    final int meanThisOffset = dim * thisNumber;

    final double precision0 = upperPrecisionCache[childNumber0];
    final double precision1 = upperPrecisionCache[childNumber1];
    final double totalPrecision = precision0 + precision1;

    lowerPrecisionCache[thisNumber] = totalPrecision;

    // Multiply child0 and child1 densities

    // Delegate this!
    cacheHelper.computeMeanCaches(
        meanThisOffset,
        meanOffset0,
        meanOffset1,
        totalPrecision,
        precision0,
        precision1,
        missingTraits,
        node,
        childNode0,
        childNode1);
    //        if (totalPrecision == 0) {
    //            System.arraycopy(zeroDimVector, 0, meanCache, meanThisOffset, dim);
    //        } else {
    //            // Delegate in case either child is partially missing
    //            // computeCorrectedWeightedAverage
    //            missingTraits.computeWeightedAverage(meanCache,
    //                    meanOffset0, precision0,
    //                    meanOffset1, precision1,
    //                    meanThisOffset, dim);
    //        }
    // In this delegation, you can call
    // getShiftForBranchLength(node);

    if (!treeModel.isRoot(node)) {
      // Integrate out trait value at this node
      double thisPrecision = 1.0 / getRescaledBranchLengthForPrecision(node);
      if (Double.isInfinite(thisPrecision)) {
        // Zero-length branch: pass the children's combined precision straight up.
        upperPrecisionCache[thisNumber] = totalPrecision;
      } else {
        // Series combination of the subtree precision and the branch precision, scaled by the
        // squared OU factor.
        upperPrecisionCache[thisNumber] =
            (totalPrecision * thisPrecision / (totalPrecision + thisPrecision))
                * Math.pow(cacheHelper.getOUFactor(node), 2);
      }
    }

    // Compute logRemainderDensity

    logRemainderDensityCache[thisNumber] = 0;

    // Skipped when either child contributes no information (zero precision).
    if (precision0 != 0 && precision1 != 0) {

      incrementRemainderDensities(
          precisionMatrix,
          logDetPrecisionMatrix,
          thisNumber,
          meanThisOffset,
          meanOffset0,
          meanOffset1,
          precision0,
          precision1,
          cacheHelper.getOUFactor(childNode0),
          cacheHelper.getOUFactor(childNode1),
          cacheOuterProducts);
    }
  }
 /**
  * Returns the node number of node {@code i}'s parent, or {@code RootedTree.NULL} for the root.
  */
 @Override
 public int getParent(int i) {
   final NodeRef parent = tree.getParent(tree.getNode(i));
   if (parent == null) {
     return RootedTree.NULL;
   }
   return parent.getNumber();
 }
  /**
   * Pre-order pass that writes a trait value for every node into {@code drawnStates}.
   *
   * <p>At the root, for each of the {@code numData} data blocks, a value is drawn from a
   * multivariate normal. Its mean and variance come from
   * {@code computeMarginalRootMeanAndVariance}.
   *
   * <p>At any other node, {@code missingTraits} decides what happens:
   *
   * <ul>
   *   <li>Neither completely nor partially missing: the node's {@code dim} values are copied from
   *       {@code cacheHelper.getMeanCache()}.
   *   <li>Partially missing: a RuntimeException is thrown.
   *   <li>Otherwise: a value is drawn conditional on the parent's already-drawn value, averaging
   *       parent and node means weighted by their precisions.
   * </ul>
   *
   * <p>Children 0 and 1 are visited only when {@code peel()} is true and the node is internal.
   * Offsets have the form {@code index * dim + datum * dimTrait}, so presumably
   * {@code dim == numData * dimTrait}. TODO confirm.
   *
   * @param treeModel tree being traversed
   * @param node node to sample
   * @param parentIndex node number of {@code node}'s parent (unused at the root)
   * @param treePrecision passed to {@code computeMarginalRootMeanAndVariance}
   * @param treeVariance passed to the root computation; divided by the total precision for
   *     conditional draws
   * @throws RuntimeException if a non-root node's trait is partially missing
   */
  private void preOrderTraverseSample(
      MultivariateTraitTree treeModel,
      NodeRef node,
      int parentIndex,
      double[][] treePrecision,
      double[][] treeVariance) {

    final int thisIndex = node.getNumber();

    if (treeModel.isRoot(node)) {
      // draw root

      double[] rootMean = new double[dimTrait];
      final int rootIndex = treeModel.getRoot().getNumber();
      double rootPrecision = lowerPrecisionCache[rootIndex];

      for (int datum = 0; datum < numData; datum++) {
        // System.arraycopy(meanCache, thisIndex * dim + datum * dimTrait, rootMean, 0, dimTrait);
        System.arraycopy(
            cacheHelper.getMeanCache(), thisIndex * dim + datum * dimTrait, rootMean, 0, dimTrait);

        double[][] variance =
            computeMarginalRootMeanAndVariance(
                rootMean, treePrecision, treeVariance, rootPrecision);

        double[] draw =
            MultivariateNormalDistribution.nextMultivariateNormalVariance(rootMean, variance);

        if (DEBUG_PREORDER) {
          Arrays.fill(draw, 1.0);
        }

        System.arraycopy(draw, 0, drawnStates, rootIndex * dim + datum * dimTrait, dimTrait);

        if (DEBUG) {
          System.err.println("Root mean: " + new Vector(rootMean));
          System.err.println("Root var : " + new Matrix(variance));
          System.err.println("Root draw: " + new Vector(draw));
        }
      }
    } else { // draw conditional on parentState

      if (!missingTraits.isCompletelyMissing(thisIndex)
          && !missingTraits.isPartiallyMissing(thisIndex)) {

        // Fully observed: take the cached mean values as the node's state.
        // System.arraycopy(meanCache, thisIndex * dim, drawnStates, thisIndex * dim, dim);
        System.arraycopy(
            cacheHelper.getMeanCache(), thisIndex * dim, drawnStates, thisIndex * dim, dim);

      } else {

        if (missingTraits.isPartiallyMissing(thisIndex)) {
          throw new RuntimeException("Partially missing values are not yet implemented");
        }
        // This code should work for sampling a missing tip trait as well, but needs testing

        // parent trait at drawnStates[parentOffset]
        double precisionToParent = 1.0 / getRescaledBranchLengthForPrecision(node);
        double precisionOfNode = lowerPrecisionCache[thisIndex];
        double totalPrecision = precisionOfNode + precisionToParent;

        double[] mean = Ay; // temporary storage
        double[][] var = tmpM; // temporary storage

        for (int datum = 0; datum < numData; datum++) {

          int parentOffset = parentIndex * dim + datum * dimTrait;
          int thisOffset = thisIndex * dim + datum * dimTrait;

          if (DEBUG) {
            double[] parentValue = new double[dimTrait];
            System.arraycopy(drawnStates, parentOffset, parentValue, 0, dimTrait);
            System.err.println("Parent draw: " + new Vector(parentValue));
            if (parentValue[0] != drawnStates[parentOffset]) {
              throw new RuntimeException("Error in setting indices");
            }
          }

          // Conditional mean: precision-weighted average of the parent draw and the node's cached
          // mean. Conditional variance: treeVariance / totalPrecision.
          for (int i = 0; i < dimTrait; i++) {
            mean[i] =
                (drawnStates[parentOffset + i] * precisionToParent
                        //  + meanCache[thisOffset + i] * precisionOfNode) / totalPrecision;
                        + cacheHelper.getMeanCache()[thisOffset + i] * precisionOfNode)
                    / totalPrecision;
            for (int j = 0; j < dimTrait; j++) {
              var[i][j] = treeVariance[i][j] / totalPrecision;
            }
          }
          double[] draw = MultivariateNormalDistribution.nextMultivariateNormalVariance(mean, var);
          System.arraycopy(draw, 0, drawnStates, thisOffset, dimTrait);

          if (DEBUG) {
            System.err.println("Int prec: " + totalPrecision);
            System.err.println("Int mean: " + new Vector(mean));
            System.err.println("Int var : " + new Matrix(var));
            System.err.println("Int draw: " + new Vector(draw));
            System.err.println("");
          }
        }
      }
    }

    // Only children 0 and 1 are visited, so the tree is assumed to be binary.
    if (peel() && !treeModel.isExternal(node)) {
      preOrderTraverseSample(
          treeModel, treeModel.getChild(node, 0), thisIndex, treePrecision, treeVariance);
      preOrderTraverseSample(
          treeModel, treeModel.getChild(node, 1), thisIndex, treePrecision, treeVariance);
    }
  }
 /** Makes {@code newroot} the root: records its index and clears its ancestor link (-1). */
 @Override
 public void replaceSlidableRoot(NodeRef newroot) {
   final int newRootIndex = newroot.getNumber();
   rootn = newRootIndex;
   pionodes[newRootIndex].anc = -1;
 }