public void backwardIBD(NodeRef node) {
    int stateCount = substitutionModel.getStateCount();
    if (node == null) {
        // Start at the root: its backward values are initialised to zero.
        node = treeModel.getRoot();
        int rootId = node.getNumber();
        for (int state = 0; state < stateCount; ++state) {
            ibdBackward[rootId][state] = 0;
        }
    }
    getDiagonalRates(diag);
    int childCount = treeModel.getChildCount(node);
    int nodeId = node.getNumber();
    for (int child = 0; child < childCount; ++child) {
        NodeRef childNode = treeModel.getChild(node, child);
        int childNodeId = childNode.getNumber();
        double branchTime = branchRateModel.getBranchRate(treeModel, childNode)
                * (treeModel.getNodeHeight(node) - treeModel.getNodeHeight(childNode));
        for (int state = 0; state < stateCount; ++state) {
            // Backward value at the child: the parent's backward value plus the
            // forward contributions of all siblings, decayed along the branch.
            ibdBackward[childNodeId][state] = ibdBackward[nodeId][state];
            for (int sibling = 0; sibling < childCount; ++sibling) {
                if (sibling != child) {
                    int siblingId = treeModel.getChild(node, sibling).getNumber();
                    ibdBackward[childNodeId][state] += ibdForward[siblingId][state];
                }
            }
            ibdBackward[childNodeId][state] *= Math.exp(-diag[state] * branchTime);
        }
    }
    for (int child = 0; child < childCount; ++child) {
        backwardIBD(treeModel.getChild(node, child));
    }
}
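// A minimal sketch (an interpretation, not from the source) of the decay factor applied above,
// assuming diag holds each state's total exit rate: exp(-diag[state] * branchTime) is the
// probability that a lineage in `state` undergoes no substitution along the branch, which is
// what lets identity-by-descent survive the branch.
static double noEventProbability(double exitRate, double rateScaledBranchLength) {
    return Math.exp(-exitRate * rateScaledBranchLength);
}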
private void computeNodeToRestrictionMap() {
    Arrays.fill(partialsMap, null);
    for (Set<String> taxonNames : partialsRestrictions.keySet()) {
        NodeRef node = Tree.Utils.getCommonAncestorNode(treeModel, taxonNames);
        partialsMap[node.getNumber()] = partialsRestrictions.get(taxonNames);
    }
}
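// A minimal usage sketch (hypothetical taxon names): restrictions are keyed by taxon set, and
// the rebuild above re-resolves each clade's MRCA, so a restriction follows its clade through
// topology changes rather than sticking to a node number.
Set<String> clade = new HashSet<String>(Arrays.asList("taxonA", "taxonB"));
NodeRef mrca = Tree.Utils.getCommonAncestorNode(treeModel, clade);
System.out.println("restricted node: " + mrca.getNumber());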
@Override
public List<Integer> getAncestors(int i, boolean b) {
    // b == true includes node i itself in the returned list.
    List<Integer> ancestors = new ArrayList<Integer>();
    if (b) {
        ancestors.add(i);
    }
    for (NodeRef n = tree.getParent(tree.getNode(i)); n != null; n = tree.getParent(n)) {
        ancestors.add(n.getNumber());
    }
    return ancestors;
}
@Override
public List<Integer> getDescendantLeaves(int i, boolean b) {
    // b == false includes node i itself (note the inverted sense relative to getAncestors).
    List<Integer> descendants = new ArrayList<Integer>();
    if (!b) {
        descendants.add(i);
    }
    for (NodeRef n : Tree.Utils.getExternalNodes(tree, tree.getNode(i))) {
        descendants.add(n.getNumber());
    }
    return descendants;
}
@Override
public int getSibling(int i) {
    NodeRef n = tree.getNode(i);
    if (tree.isRoot(n)) {
        return RootedTree.NULL;
    }
    // Assumes a strictly binary tree: a node's sibling is its parent's other child.
    NodeRef p = tree.getParent(n);
    int c1 = tree.getChild(p, 0).getNumber();
    int c2 = tree.getChild(p, 1).getNumber();
    return n.getNumber() == c2 ? c1 : c2;
}
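// A minimal usage sketch (hypothetical wrapper `rootedTree` and starting leaf) combining the
// index-based navigation methods of this adapter: walk from a leaf to the root, printing each
// node's sibling (RootedTree.NULL at the root).
void printPathToRoot(int leafIndex) {
    for (int node = leafIndex; node != RootedTree.NULL; node = rootedTree.getParent(node)) {
        System.out.println("node " + node + " sibling " + rootedTree.getSibling(node));
    }
}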
@Override
public void replaceSlidableChildren(NodeRef node, NodeRef lft, NodeRef rgt) {
    int nn = node.getNumber();
    int lftn = lft.getNumber();
    int rgtn = rgt.getNumber();
    assert pionodes[nn].lft >= 0; // must be an internal node
    pionodes[nn].lft = lftn;
    pionodes[nn].rgt = rgtn;
    pionodes[lftn].anc = pionodes[nn].nodeNumber;
    pionodes[rgtn].anc = pionodes[nn].nodeNumber;
}
public double[][] simulateLocations() {
    NodeRef root = m_tree.getRoot();
    double[][] latLongs = new double[m_tree.getNodeCount()][2];
    // Draw the root location uniformly over the bounding box.
    double rootLat = MathUtils.nextDouble() * (maxLat - minLat) + minLat;
    double rootLong = MathUtils.nextDouble() * (maxLong - minLong) + minLong;
    int rootNum = root.getNumber();
    latLongs[rootNum][LATITUDE_INDEX] = rootLat;
    latLongs[rootNum][LONGITUDE_INDEX] = rootLong;
    traverse(root, latLongs[rootNum], latLongs);
    return latLongs;
}
/** Simulate geographic attributes (latitude/longitude parameters) for the external taxa. */
public ArrayList<Parameter> simulateGeoAttr() {
    double[][] locations = simulateLocations();
    ArrayList<Parameter> locationList = new ArrayList<Parameter>();
    for (int i = 0; i < m_tree.getExternalNodeCount(); i++) {
        NodeRef node = m_tree.getNode(i);
        String taxaName = m_tree.getTaxon(node.getNumber()).getId();
        Parameter location = new Parameter.Default(locations[i]);
        System.out.println("taxon: " + taxaName
                + ", lat: " + locations[i][0]
                + ", long: " + locations[i][1]);
        locationList.add(location);
    }
    return locationList;
}
private double traitCachedLogLikelihood(double[] parentTrait, NodeRef node) {
    double logL = 0.0;
    double[] childTrait = null;
    final int nodeNumber = node.getNumber();
    if (!treeModel.isRoot(node)) {
        if (!validLogLikelihoods[nodeNumber]) { // recompute
            childTrait = treeModel.getMultivariateNodeTrait(node, traitName);
            double time = getRescaledBranchLengthForPrecision(node);
            if (parentTrait == null) {
                parentTrait = treeModel.getMultivariateNodeTrait(treeModel.getParent(node), traitName);
            }
            logL = diffusionModel.getLogLikelihood(parentTrait, childTrait, time);
            cachedLogLikelihoods[nodeNumber] = logL;
            validLogLikelihoods[nodeNumber] = true;
        } else {
            logL = cachedLogLikelihoods[nodeNumber];
        }
    }
    int childCount = treeModel.getChildCount(node);
    for (int i = 0; i < childCount; i++) {
        logL += traitCachedLogLikelihood(childTrait, treeModel.getChild(node, i));
    }
    return logL;
}
private double getIBDWeight(Tree tree, NodeRef node) {
    if (!weightsKnown) {
        expectedIBD();
        weightsKnown = true;
    }
    if (tree.isExternal(node)) {
        int nodeNum = node.getNumber();
        return ibdweights[nodeNum] + 1;
    }
    return 0;
}
public final double getLogDataLikelihood() {
    double logLikelihood = 0;
    for (int i = 0; i < treeModel.getExternalNodeCount(); i++) {
        NodeRef tip = treeModel.getExternalNode(i);
        // TODO Do not include integrated tips; how to check???
        if (cacheBranches && validLogLikelihoods[tip.getNumber()]) {
            logLikelihood += cachedLogLikelihoods[tip.getNumber()];
        } else {
            NodeRef parent = treeModel.getParent(tip);
            double[] tipTrait = treeModel.getMultivariateNodeTrait(tip, traitName);
            double[] parentTrait = treeModel.getMultivariateNodeTrait(parent, traitName);
            double time = getRescaledBranchLengthForPrecision(tip);
            logLikelihood += diffusionModel.getLogLikelihood(parentTrait, tipTrait, time);
        }
    }
    return logLikelihood;
}
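// A univariate sketch (assumed form, not the diffusionModel API) of the per-branch term
// accumulated above: under Brownian motion with unit diffusion rate, the tip trait is
// Gaussian around the parent trait with variance equal to the rescaled branch length.
static double branchLogDensity(double parentTrait, double tipTrait, double time) {
    double variance = time;
    double d = tipTrait - parentTrait;
    return -0.5 * (Math.log(2.0 * Math.PI * variance) + d * d / variance);
}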
private void writeNode(Tree tree, NodeRef node, boolean attributes, Map<String, Integer> idMap) {
    if (tree.isExternal(node)) {
        int k = node.getNumber() + 1;
        if (idMap != null) {
            k = idMap.get(tree.getTaxonId(k - 1));
        }
        out.print(k);
    } else {
        out.print("(");
        writeNode(tree, tree.getChild(node, 0), attributes, idMap);
        for (int i = 1; i < tree.getChildCount(node); i++) {
            out.print(",");
            writeNode(tree, tree.getChild(node, i), attributes, idMap);
        }
        out.print(")");
    }
    if (writeAttributesAs == AttributeType.BRANCH_ATTRIBUTES && !tree.isRoot(node)) {
        out.print(":");
    }
    if (attributes) {
        Iterator<?> iter = tree.getNodeAttributeNames(node);
        if (iter != null) {
            boolean first = true;
            while (iter.hasNext()) {
                if (first) {
                    out.print("[&");
                    first = false;
                } else {
                    out.print(",");
                }
                String name = (String) iter.next();
                out.print(name + "=");
                Object value = tree.getNodeAttribute(node, name);
                printValue(value);
            }
            out.print("]");
        }
    }
    if (writeAttributesAs == AttributeType.NODE_ATTRIBUTES && !tree.isRoot(node)) {
        out.print(":");
    }
    if (!tree.isRoot(node)) {
        double length = tree.getBranchLength(node);
        if (formatter != null) {
            out.print(formatter.format(length));
        } else {
            out.print(length);
        }
    }
}
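// Illustration (assumed inputs, not from the source): with attributes == false and no idMap,
// a three-taxon tree prints as 1-based taxon numbers with branch lengths, e.g.
//   ((1:0.1,2:0.2):0.3,3:0.4)
// With attributes == true, each node additionally carries an [&name=value,...] block, placed
// before the colon for NODE_ATTRIBUTES and after it for BRANCH_ATTRIBUTES.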
private void setTipDataValuesForNode(NodeRef node) {
    // Set tip data values
    int index = node.getNumber();
    double[] traitValue = traitParameter.getParameter(index).getParameterValues();
    if (traitValue.length < dim) {
        throw new RuntimeException(
                "The trait parameter for the tip with index, " + index + ", is too short");
    }
    cacheHelper.setTipMeans(traitValue, dim, index, node);
    // System.arraycopy(traitValue, 0, meanCache /* cacheHelper.getMeanCache() */, dim * index, dim);
}
private boolean eligibleForMove(NodeRef node, TreeModel tree, BranchMapModel branchMap) {
    // To be eligible for this move, the node's parent and grandparent, or the parent and its
    // other child, must be in the same partition (so removing the parent has no effect on the
    // remaining links of the transmission tree), and the node and its parent must be in
    // different partitions (so the move does not disconnect anything).
    NodeRef parent = tree.getParent(node);
    NodeRef grandparent = tree.getParent(parent);
    NodeRef otherChild = getOtherChild(tree, parent, node);

    boolean parentMatchesGrandparent = grandparent != null
            && branchMap.get(parent.getNumber()) == branchMap.get(grandparent.getNumber());
    boolean parentMatchesOtherChild =
            branchMap.get(parent.getNumber()) == branchMap.get(otherChild.getNumber());
    boolean nodeDiffersFromParent =
            branchMap.get(parent.getNumber()) != branchMap.get(node.getNumber());

    return (parentMatchesGrandparent || parentMatchesOtherChild) && nodeDiffersFromParent;
}
void traverse(NodeRef node, double[] parentLatLong, double[][] latLongs) {
    for (int iChild = 0; iChild < m_tree.getChildCount(node); iChild++) {
        NodeRef child = m_tree.getChild(node, iChild);
        // find the branch length
        final double branchRate = m_branchRateModel.getBranchRate(m_tree, child);
        final double branchLength =
                branchRate * (m_tree.getNodeHeight(node) - m_tree.getNodeHeight(child));
        if (branchLength < 0.0) {
            throw new RuntimeException("Negative branch length: " + branchLength);
        }
        // Brownian step: displace each coordinate by a Gaussian with variance = branch length.
        double childLat = MathUtils.nextGaussian() * Math.sqrt(branchLength)
                + parentLatLong[LATITUDE_INDEX];
        double childLong = MathUtils.nextGaussian() * Math.sqrt(branchLength)
                + parentLatLong[LONGITUDE_INDEX];
        int childNum = child.getNumber();
        latLongs[childNum][LATITUDE_INDEX] = childLat;
        latLongs[childNum][LONGITUDE_INDEX] = childLong;
        traverse(child, latLongs[childNum], latLongs);
    }
}
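// A minimal standalone sketch of the simulation step above: under Brownian motion, a child
// coordinate is the parent coordinate plus a N(0, branchLength) displacement. `rng` replaces
// the MathUtils static generator used in the source.
static double brownianStep(double parentCoordinate, double branchLength, java.util.Random rng) {
    return parentCoordinate + rng.nextGaussian() * Math.sqrt(branchLength);
}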
public double[] getTraitForNode(Tree tree, NodeRef node, String traitName) {
    // if (tree != treeModel) {
    //     throw new RuntimeException("Can only reconstruct states on treeModel given to constructor");
    // }
    getLogLikelihood();
    if (!areStatesRedrawn) {
        redrawAncestralStates();
    }
    int index = node.getNumber();
    double[] trait = new double[dim];
    System.arraycopy(drawnStates, index * dim, trait, 0, dim);
    return trait;
}
private void getDescendants(NodeRef n, List<Integer> descendants) {
    descendants.add(n.getNumber());
    if (tree.isExternal(n)) {
        return;
    }
    for (int i = 0; i < tree.getChildCount(n); ++i) {
        getDescendants(tree.getChild(n, i), descendants);
    }
}
/**
 * Traverse the tree calculating partial likelihoods.
 *
 * @param tree tree
 * @param node node
 * @param operatorNumber operatorNumber
 * @param flip flip
 * @return boolean
 */
private boolean traverse(Tree tree, NodeRef node, int[] operatorNumber, boolean flip) {
    boolean update = false;
    int nodeNum = node.getNumber();
    NodeRef parent = tree.getParent(node);

    if (operatorNumber != null) {
        operatorNumber[0] = -1;
    }

    // First update the transition probability matrix(ices) for this branch
    if (parent != null && updateNode[nodeNum]) {
        final double branchRate = branchRateModel.getBranchRate(tree, node);
        final double parentHeight = tree.getNodeHeight(parent);
        final double nodeHeight = tree.getNodeHeight(node);

        // Get the operational time of the branch
        final double branchLength = branchRate * (parentHeight - nodeHeight);
        if (branchLength < 0.0) {
            throw new RuntimeException("Negative branch length: " + branchLength);
        }

        if (flip) {
            substitutionModelDelegate.flipMatrixBuffer(nodeNum);
        }
        branchUpdateIndices[branchUpdateCount] = nodeNum;
        branchLengths[branchUpdateCount] = branchLength;
        branchUpdateCount++;

        update = true;
    }

    // If the node is internal, update the partial likelihoods.
    if (!tree.isExternal(node)) {
        // Traverse down the two child nodes
        NodeRef child1 = tree.getChild(node, 0);
        final int[] op1 = {-1};
        final boolean update1 = traverse(tree, child1, op1, flip);

        NodeRef child2 = tree.getChild(node, 1);
        final int[] op2 = {-1};
        final boolean update2 = traverse(tree, child2, op2, flip);

        // If either child node was updated then update this node too
        if (update1 || update2) {
            int x = operationCount[operationListCount] * Beagle.OPERATION_TUPLE_SIZE;

            if (flip) {
                // first flip the partialBufferHelper
                partialBufferHelper.flipOffset(nodeNum);
            }

            final int[] operations = this.operations[operationListCount];

            operations[x] = partialBufferHelper.getOffsetIndex(nodeNum); // destination partials

            if (useScaleFactors) {
                // get the index of this scaling buffer
                int n = nodeNum - tipCount;

                if (recomputeScaleFactors) {
                    // flip the indicator: can take either n or (internalNodeCount + 1) - n
                    scaleBufferHelper.flipOffset(n);

                    // store the index
                    scaleBufferIndices[n] = scaleBufferHelper.getOffsetIndex(n);

                    operations[x + 1] = scaleBufferIndices[n]; // Write new scaleFactor
                    operations[x + 2] = Beagle.NONE;
                } else {
                    operations[x + 1] = Beagle.NONE;
                    operations[x + 2] = scaleBufferIndices[n]; // Read existing scaleFactor
                }
            } else {
                if (useAutoScaling) {
                    scaleBufferIndices[nodeNum - tipCount] =
                            partialBufferHelper.getOffsetIndex(nodeNum);
                }
                operations[x + 1] = Beagle.NONE; // Not using scaleFactors
                operations[x + 2] = Beagle.NONE;
            }

            operations[x + 3] = partialBufferHelper.getOffsetIndex(child1.getNumber()); // source node 1
            operations[x + 4] = substitutionModelDelegate.getMatrixIndex(child1.getNumber()); // source matrix 1
            operations[x + 5] = partialBufferHelper.getOffsetIndex(child2.getNumber()); // source node 2
            operations[x + 6] = substitutionModelDelegate.getMatrixIndex(child2.getNumber()); // source matrix 2

            operationCount[operationListCount]++;

            update = true;

            if (hasRestrictedPartials) {
                // Test if this set of partials should be restricted
                if (updateRestrictedNodePartials) {
                    // Recompute map
                    computeNodeToRestrictionMap();
                    updateRestrictedNodePartials = false;
                }
                if (partialsMap[nodeNum] != null) {
                    // Placeholder: handling of restricted partials is not implemented here.
                }
            }
        }
    }

    return update;
}
/**
 * Calculate the log likelihood of the current state.
 *
 * @return the log likelihood.
 */
protected double calculateLogLikelihood() {

    if (patternLogLikelihoods == null) {
        patternLogLikelihoods = new double[patternCount];
    }

    if (branchUpdateIndices == null) {
        branchUpdateIndices = new int[nodeCount];
        branchLengths = new double[nodeCount];
        scaleBufferIndices = new int[internalNodeCount];
        storedScaleBufferIndices = new int[internalNodeCount];
    }

    if (operations == null) {
        operations = new int[numRestrictedPartials + 1][internalNodeCount * Beagle.OPERATION_TUPLE_SIZE];
        operationCount = new int[numRestrictedPartials + 1];
    }

    recomputeScaleFactors = false;

    if (this.rescalingScheme == PartialsRescalingScheme.ALWAYS) {
        useScaleFactors = true;
        recomputeScaleFactors = true;
    } else if (this.rescalingScheme == PartialsRescalingScheme.DYNAMIC && everUnderflowed) {
        useScaleFactors = true;
        if (rescalingCountInner < RESCALE_TIMES) {
            recomputeScaleFactors = true;
            makeDirty();
            // System.err.println("Recomputing scale factors");
        }
        rescalingCountInner++;
        rescalingCount++;
        if (rescalingCount > rescalingFrequency) {
            rescalingCount = 0;
            rescalingCountInner = 0;
        }
    } else if (this.rescalingScheme == PartialsRescalingScheme.DELAYED && everUnderflowed) {
        useScaleFactors = true;
        recomputeScaleFactors = true;
        rescalingCount++;
    }

    if (tipStatesModel != null) {
        int tipCount = treeModel.getExternalNodeCount();
        for (int index = 0; index < tipCount; index++) {
            if (updateNode[index]) {
                if (tipStatesModel.getModelType() == TipStatesModel.Type.PARTIALS) {
                    tipStatesModel.getTipPartials(index, tipPartials);
                    beagle.setTipPartials(index, tipPartials);
                } else {
                    tipStatesModel.getTipStates(index, tipStates);
                    beagle.setTipStates(index, tipStates);
                }
            }
        }
    }

    branchUpdateCount = 0;
    operationListCount = 0;

    if (hasRestrictedPartials) {
        for (int i = 0; i <= numRestrictedPartials; i++) {
            operationCount[i] = 0;
        }
    } else {
        operationCount[0] = 0;
    }

    final NodeRef root = treeModel.getRoot();
    traverse(treeModel, root, null, true);

    if (updateSubstitutionModel) {
        // TODO More efficient to update only the substitution model that changed, instead of all
        substitutionModelDelegate.updateSubstitutionModels(beagle);
        // we are currently assuming a no-category model...
    }

    if (updateSiteModel) {
        double[] categoryRates = this.siteModel.getCategoryRates();
        beagle.setCategoryRates(categoryRates);
    }

    if (branchUpdateCount > 0) {
        substitutionModelDelegate.updateTransitionMatrices(
                beagle, branchUpdateIndices, branchLengths, branchUpdateCount);
    }

    if (COUNT_TOTAL_OPERATIONS) {
        totalMatrixUpdateCount += branchUpdateCount;
        for (int i = 0; i <= numRestrictedPartials; i++) {
            totalOperationCount += operationCount[i];
        }
    }

    double logL;
    boolean done;
    boolean firstRescaleAttempt = true;

    do {
        if (hasRestrictedPartials) {
            for (int i = 0; i <= numRestrictedPartials; i++) {
                beagle.updatePartials(operations[i], operationCount[i], Beagle.NONE);
                if (i < numRestrictedPartials) {
                    // restrictNodePartials(restrictedIndices[i]);
                }
            }
        } else {
            beagle.updatePartials(operations[0], operationCount[0], Beagle.NONE);
        }

        int rootIndex = partialBufferHelper.getOffsetIndex(root.getNumber());

        double[] categoryWeights = this.siteModel.getCategoryProportions();

        // This should probably explicitly be the state frequencies for the root node...
        double[] frequencies = substitutionModelDelegate.getRootStateFrequencies();

        int cumulateScaleBufferIndex = Beagle.NONE;
        if (useScaleFactors) {
            if (recomputeScaleFactors) {
                scaleBufferHelper.flipOffset(internalNodeCount);
                cumulateScaleBufferIndex = scaleBufferHelper.getOffsetIndex(internalNodeCount);
                beagle.resetScaleFactors(cumulateScaleBufferIndex);
                beagle.accumulateScaleFactors(
                        scaleBufferIndices, internalNodeCount, cumulateScaleBufferIndex);
            } else {
                cumulateScaleBufferIndex = scaleBufferHelper.getOffsetIndex(internalNodeCount);
            }
        } else if (useAutoScaling) {
            beagle.accumulateScaleFactors(scaleBufferIndices, internalNodeCount, Beagle.NONE);
        }

        // these could be set only when they change but store/restore would need to be considered
        beagle.setCategoryWeights(0, categoryWeights);
        beagle.setStateFrequencies(0, frequencies);

        double[] sumLogLikelihoods = new double[1];

        beagle.calculateRootLogLikelihoods(
                new int[] {rootIndex},
                new int[] {0},
                new int[] {0},
                new int[] {cumulateScaleBufferIndex},
                1,
                sumLogLikelihoods);

        logL = sumLogLikelihoods[0];

        if (ascertainedSitePatterns) {
            // Need to correct for ascertainedSitePatterns
            beagle.getSiteLogLikelihoods(patternLogLikelihoods);
            logL = getAscertainmentCorrectedLogLikelihood(
                    (AscertainedSitePatterns) patternList, patternLogLikelihoods, patternWeights);
        }

        if (Double.isNaN(logL) || Double.isInfinite(logL)) {
            everUnderflowed = true;
            logL = Double.NEGATIVE_INFINITY;

            if (firstRescaleAttempt
                    && (rescalingScheme == PartialsRescalingScheme.DYNAMIC
                            || rescalingScheme == PartialsRescalingScheme.DELAYED)) {
                // we have had a potential under/over flow so attempt a rescaling
                if (rescalingScheme == PartialsRescalingScheme.DYNAMIC || (rescalingCount == 0)) {
                    Logger.getLogger("dr.evomodel")
                            .info("Underflow calculating likelihood. Attempting a rescaling...");
                }
                useScaleFactors = true;
                recomputeScaleFactors = true;

                branchUpdateCount = 0;

                if (hasRestrictedPartials) {
                    for (int i = 0; i <= numRestrictedPartials; i++) {
                        operationCount[i] = 0;
                    }
                } else {
                    operationCount[0] = 0;
                }

                // traverse again but without flipping partials indices as we
                // just want to overwrite the last attempt. We will flip the
                // scale buffer indices though as we are recomputing them.
                traverse(treeModel, root, null, false);

                done = false; // Run through do-while loop again
                firstRescaleAttempt = false; // Only try to rescale once
            } else {
                // we have already tried a rescale, not rescaling or always rescaling,
                // so just return the likelihood...
                done = true;
            }
        } else {
            done = true; // No under-/over-flow, then done
        }
    } while (!done);

    // If these are needed...
    // beagle.getSiteLogLikelihoods(patternLogLikelihoods);

    // ********************************************************************
    // after traverse all nodes and patterns have been updated --
    // so change flags to reflect this.
    for (int i = 0; i < nodeCount; i++) {
        updateNode[i] = false;
    }
    updateSubstitutionModel = false;
    updateSiteModel = false;
    // ********************************************************************

    return logL;
}
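// A minimal sketch (hypothetical helpers, not the source API) of the retry-on-underflow
// pattern used above: compute once without scaling; if the result under/overflows, enable
// rescaling, rebuild the operation queue without flipping buffers, and recompute exactly once.
double computeWithUnderflowRetry() {
    double logL;
    boolean firstRescaleAttempt = true;
    while (true) {
        logL = computeRootLogLikelihood(); // hypothetical helper
        if ((Double.isNaN(logL) || Double.isInfinite(logL)) && firstRescaleAttempt) {
            enableRescalingAndRequeue(); // hypothetical: set scale flags, re-traverse without flipping
            firstRescaleAttempt = false; // only retry once
        } else {
            break;
        }
    }
    return logL;
}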
@Override
public Taxon getSlidableNodeTaxon(NodeRef node) {
    assert node == pionodes[node.getNumber()];
    return ((PopsIONode) node).getTaxon();
}

@Override
public NodeRef getSlidableChild(NodeRef node, int j) {
    int n = node.getNumber();
    return j == 0 ? pionodes[pionodes[n].lft] : pionodes[pionodes[n].rgt];
}

@Override
public boolean isExternalSlidable(NodeRef node) {
    // A negative lft index marks a tip.
    return (pionodes[node.getNumber()].lft < 0);
}

@Override
public void setSlidableNodeHeight(NodeRef node, double height) {
    assert node == pionodes[node.getNumber()];
    ((PopsIONode) node).height = height;
}

@Override
public double getSlidableNodeHeight(NodeRef node) {
    assert node == pionodes[node.getNumber()];
    return ((PopsIONode) node).getHeight();
}
public void proposeTree() throws OperatorFailedException {
    TreeModel tree = c2cLikelihood.getTreeModel();
    BranchMapModel branchMap = c2cLikelihood.getBranchMap();
    NodeRef i;
    double oldMinAge, newMinAge, newRange, oldRange, newAge, q;

    // choose a random node, avoiding the root and nodes that are ineligible for this move
    // because they have nowhere to go
    final int nodeCount = tree.getNodeCount();
    do {
        i = tree.getNode(MathUtils.nextInt(nodeCount));
    } while (tree.getRoot() == i || !eligibleForMove(i, tree, branchMap));
    final NodeRef iP = tree.getParent(i);

    // this one can go anywhere
    NodeRef j = tree.getNode(MathUtils.nextInt(tree.getNodeCount()));
    NodeRef k = tree.getParent(j);

    while ((k != null && tree.getNodeHeight(k) <= tree.getNodeHeight(i)) || (i == j)) {
        j = tree.getNode(MathUtils.nextInt(tree.getNodeCount()));
        k = tree.getParent(j);
    }

    if (iP == tree.getRoot() || j == tree.getRoot()) {
        throw new OperatorFailedException("Root changes not allowed!");
    }

    if (k == iP || j == iP || k == i) {
        throw new OperatorFailedException("move failed");
    }

    final NodeRef CiP = getOtherChild(tree, iP, i);
    NodeRef PiP = tree.getParent(iP);

    newMinAge = Math.max(tree.getNodeHeight(i), tree.getNodeHeight(j));
    newRange = tree.getNodeHeight(k) - newMinAge;
    newAge = newMinAge + (MathUtils.nextDouble() * newRange);
    oldMinAge = Math.max(tree.getNodeHeight(i), tree.getNodeHeight(CiP));
    oldRange = tree.getNodeHeight(PiP) - oldMinAge;
    q = newRange / Math.abs(oldRange);

    // need to account for the random repainting of iP
    if (branchMap.get(PiP.getNumber()) != branchMap.get(CiP.getNumber())) {
        q *= 0.5;
    }
    if (branchMap.get(k.getNumber()) != branchMap.get(j.getNumber())) {
        q *= 2;
    }

    tree.beginTreeEdit();

    if (j == tree.getRoot()) {
        // 1. remove edges <iP, CiP>, <PiP, iP>
        tree.removeChild(iP, CiP);
        tree.removeChild(PiP, iP);

        // 2. add edges <iP, j>, <PiP, CiP>
        tree.addChild(iP, j);
        tree.addChild(PiP, CiP);

        // iP is the new root
        tree.setRoot(iP);
    } else if (iP == tree.getRoot()) {
        // 1. remove edges <k, j>, <iP, CiP>
        tree.removeChild(k, j);
        tree.removeChild(iP, CiP);

        // 2. add edges <k, iP>, <iP, j>
        tree.addChild(iP, j);
        tree.addChild(k, iP);

        // CiP is the new root
        tree.setRoot(CiP);
    } else {
        // 1. remove edges <k, j>, <iP, CiP>, <PiP, iP>
        tree.removeChild(k, j);
        tree.removeChild(iP, CiP);
        tree.removeChild(PiP, iP);

        // 2. add edges <k, iP>, <iP, j>, <PiP, CiP>
        tree.addChild(iP, j);
        tree.addChild(k, iP);
        tree.addChild(PiP, CiP);
    }

    tree.setNodeHeight(iP, newAge);

    tree.endTreeEdit();

    // logq = Math.log(q);

    // repaint the parent to match either its new parent or its new child (50% chance of each).
    if (MathUtils.nextInt(2) == 0) {
        branchMap.set(iP.getNumber(), branchMap.get(k.getNumber()), true);
    } else {
        branchMap.set(iP.getNumber(), branchMap.get(j.getNumber()), true);
    }

    if (DEBUG) {
        c2cLikelihood.checkPartitions();
    }
}
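// A minimal sketch (hypothetical helper, not from the source) of the proposal-density ratio
// assembled above: the ratio of the two uniform age ranges, halved or doubled when the random
// repainting of iP is informative at the old or new attachment point, respectively.
static double logHastingsRatio(double newRange, double oldRange,
                               boolean oldNeighboursDiffer, boolean newNeighboursDiffer) {
    double q = newRange / Math.abs(oldRange);
    if (oldNeighboursDiffer) q *= 0.5;
    if (newNeighboursDiffer) q *= 2;
    return Math.log(q);
}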
public double getBranchRate(final Tree tree, final NodeRef node) {
    assert !tree.isRoot(node) : "root node doesn't have a rate!";
    // int rateCategory = (int) Math.round(rateCategories.getNodeValue(tree, node));
    return rates[node.getNumber()] * scaleFactor;
}
void postOrderTraverse(
        MultivariateTraitTree treeModel,
        NodeRef node,
        double[][] precisionMatrix,
        double logDetPrecisionMatrix,
        boolean cacheOuterProducts) {

    final int thisNumber = node.getNumber();

    if (treeModel.isExternal(node)) {
        // Fill in precision scalar; traitValues are already filled in
        if (missingTraits.isCompletelyMissing(thisNumber)) {
            upperPrecisionCache[thisNumber] = 0;
            lowerPrecisionCache[thisNumber] = 0; // Needed in the pre-order traversal
        } else { // not missing tip trait
            upperPrecisionCache[thisNumber] =
                    (1.0 / getRescaledBranchLengthForPrecision(node))
                            * Math.pow(cacheHelper.getOUFactor(node), 2);
            lowerPrecisionCache[thisNumber] = Double.POSITIVE_INFINITY;
        }
        return;
    }

    final NodeRef childNode0 = treeModel.getChild(node, 0);
    final NodeRef childNode1 = treeModel.getChild(node, 1);

    postOrderTraverse(treeModel, childNode0, precisionMatrix, logDetPrecisionMatrix, cacheOuterProducts);
    postOrderTraverse(treeModel, childNode1, precisionMatrix, logDetPrecisionMatrix, cacheOuterProducts);

    final int childNumber0 = childNode0.getNumber();
    final int childNumber1 = childNode1.getNumber();
    final int meanOffset0 = dim * childNumber0;
    final int meanOffset1 = dim * childNumber1;
    final int meanThisOffset = dim * thisNumber;

    final double precision0 = upperPrecisionCache[childNumber0];
    final double precision1 = upperPrecisionCache[childNumber1];
    final double totalPrecision = precision0 + precision1;

    lowerPrecisionCache[thisNumber] = totalPrecision;

    // Multiply child0 and child1 densities. Delegate this!
    cacheHelper.computeMeanCaches(
            meanThisOffset, meanOffset0, meanOffset1,
            totalPrecision, precision0, precision1,
            missingTraits, node, childNode0, childNode1);
    // if (totalPrecision == 0) {
    //     System.arraycopy(zeroDimVector, 0, meanCache, meanThisOffset, dim);
    // } else {
    //     // Delegate in case either child is partially missing: computeCorrectedWeightedAverage
    //     missingTraits.computeWeightedAverage(meanCache,
    //             meanOffset0, precision0, meanOffset1, precision1, meanThisOffset, dim);
    // }
    // In this delegation, you can call getShiftForBranchLength(node);

    if (!treeModel.isRoot(node)) {
        // Integrate out trait value at this node
        double thisPrecision = 1.0 / getRescaledBranchLengthForPrecision(node);
        if (Double.isInfinite(thisPrecision)) {
            upperPrecisionCache[thisNumber] = totalPrecision;
        } else {
            upperPrecisionCache[thisNumber] =
                    (totalPrecision * thisPrecision / (totalPrecision + thisPrecision))
                            * Math.pow(cacheHelper.getOUFactor(node), 2);
        }
    }

    // Compute logRemainderDensity
    logRemainderDensityCache[thisNumber] = 0;
    if (precision0 != 0 && precision1 != 0) {
        incrementRemainderDensities(
                precisionMatrix,
                logDetPrecisionMatrix,
                thisNumber,
                meanThisOffset,
                meanOffset0,
                meanOffset1,
                precision0,
                precision1,
                cacheHelper.getOUFactor(childNode0),
                cacheHelper.getOUFactor(childNode1),
                cacheOuterProducts);
    }
}
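// A minimal sketch (assumed helper, not the source API) of the scalar precision bookkeeping
// above: two independent child messages with precisions p0 and p1 multiply to precision
// p0 + p1; integrating the node's own value out over a branch with precision
// pb = 1 / branchLength gives the series combination p * pb / (p + pb).
static double integratedPrecision(double childPrecisionSum, double branchLength) {
    double branchPrecision = 1.0 / branchLength; // infinite for zero-length branches
    if (Double.isInfinite(branchPrecision)) {
        return childPrecisionSum;
    }
    return childPrecisionSum * branchPrecision / (childPrecisionSum + branchPrecision);
}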
@Override
public int getParent(int i) {
    NodeRef n = tree.getParent(tree.getNode(i));
    return n != null ? n.getNumber() : RootedTree.NULL;
}
private void preOrderTraverseSample(
        MultivariateTraitTree treeModel,
        NodeRef node,
        int parentIndex,
        double[][] treePrecision,
        double[][] treeVariance) {

    final int thisIndex = node.getNumber();

    if (treeModel.isRoot(node)) {
        // draw root
        double[] rootMean = new double[dimTrait];
        final int rootIndex = treeModel.getRoot().getNumber();
        double rootPrecision = lowerPrecisionCache[rootIndex];

        for (int datum = 0; datum < numData; datum++) {
            // System.arraycopy(meanCache, thisIndex * dim + datum * dimTrait, rootMean, 0, dimTrait);
            System.arraycopy(
                    cacheHelper.getMeanCache(), thisIndex * dim + datum * dimTrait,
                    rootMean, 0, dimTrait);

            double[][] variance = computeMarginalRootMeanAndVariance(
                    rootMean, treePrecision, treeVariance, rootPrecision);

            double[] draw =
                    MultivariateNormalDistribution.nextMultivariateNormalVariance(rootMean, variance);

            if (DEBUG_PREORDER) {
                Arrays.fill(draw, 1.0);
            }

            System.arraycopy(draw, 0, drawnStates, rootIndex * dim + datum * dimTrait, dimTrait);

            if (DEBUG) {
                System.err.println("Root mean: " + new Vector(rootMean));
                System.err.println("Root var : " + new Matrix(variance));
                System.err.println("Root draw: " + new Vector(draw));
            }
        }
    } else { // draw conditional on parentState
        if (!missingTraits.isCompletelyMissing(thisIndex)
                && !missingTraits.isPartiallyMissing(thisIndex)) {
            // System.arraycopy(meanCache, thisIndex * dim, drawnStates, thisIndex * dim, dim);
            System.arraycopy(
                    cacheHelper.getMeanCache(), thisIndex * dim, drawnStates, thisIndex * dim, dim);
        } else {
            if (missingTraits.isPartiallyMissing(thisIndex)) {
                throw new RuntimeException("Partially missing values are not yet implemented");
            }
            // This code should work for sampling a missing tip trait as well, but needs testing

            // parent trait at drawnStates[parentOffset]
            double precisionToParent = 1.0 / getRescaledBranchLengthForPrecision(node);
            double precisionOfNode = lowerPrecisionCache[thisIndex];
            double totalPrecision = precisionOfNode + precisionToParent;

            double[] mean = Ay; // temporary storage
            double[][] var = tmpM; // temporary storage

            for (int datum = 0; datum < numData; datum++) {
                int parentOffset = parentIndex * dim + datum * dimTrait;
                int thisOffset = thisIndex * dim + datum * dimTrait;

                if (DEBUG) {
                    double[] parentValue = new double[dimTrait];
                    System.arraycopy(drawnStates, parentOffset, parentValue, 0, dimTrait);
                    System.err.println("Parent draw: " + new Vector(parentValue));
                    if (parentValue[0] != drawnStates[parentOffset]) {
                        throw new RuntimeException("Error in setting indices");
                    }
                }

                for (int i = 0; i < dimTrait; i++) {
                    mean[i] = (drawnStates[parentOffset + i] * precisionToParent
                            // + meanCache[thisOffset + i] * precisionOfNode) / totalPrecision;
                            + cacheHelper.getMeanCache()[thisOffset + i] * precisionOfNode)
                            / totalPrecision;
                    for (int j = 0; j < dimTrait; j++) {
                        var[i][j] = treeVariance[i][j] / totalPrecision;
                    }
                }
                double[] draw =
                        MultivariateNormalDistribution.nextMultivariateNormalVariance(mean, var);
                System.arraycopy(draw, 0, drawnStates, thisOffset, dimTrait);

                if (DEBUG) {
                    System.err.println("Int prec: " + totalPrecision);
                    System.err.println("Int mean: " + new Vector(mean));
                    System.err.println("Int var : " + new Matrix(var));
                    System.err.println("Int draw: " + new Vector(draw));
                    System.err.println("");
                }
            }
        }
    }

    if (peel() && !treeModel.isExternal(node)) {
        preOrderTraverseSample(
                treeModel, treeModel.getChild(node, 0), thisIndex, treePrecision, treeVariance);
        preOrderTraverseSample(
                treeModel, treeModel.getChild(node, 1), thisIndex, treePrecision, treeVariance);
    }
}
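// A minimal univariate sketch (assumed names, not the source API) of the conditional draw
// above: combining the parent's drawn value (precision = precisionToParent) with the node's
// post-order message (precision = precisionOfNode) gives a Gaussian whose mean is the
// precision-weighted average and whose variance is the inverse total precision.
static double[] conditionalMeanAndVariance(
        double parentValue, double precisionToParent,
        double nodeMessageMean, double precisionOfNode) {
    double totalPrecision = precisionToParent + precisionOfNode;
    double mean = (parentValue * precisionToParent + nodeMessageMean * precisionOfNode)
            / totalPrecision;
    return new double[] {mean, 1.0 / totalPrecision};
}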
@Override
public void replaceSlidableRoot(NodeRef newroot) {
    rootn = newroot.getNumber();
    pionodes[rootn].anc = -1; // the root has no ancestor
}