Ejemplo n.º 1
0
  /**
   * Computes the forward (tip-to-root) IBD quantities for every non-root node and stores them in
   * {@code ibdForward}.
   *
   * <p>For a tip, {@code ibdForward[tip][s] = exp(-diag[s] * t)}. Here {@code t} is the branch
   * rate times the height difference between the node's parent and the node. For an internal
   * node, the children's values are summed and then multiplied by the same per-state factor for
   * the node's own branch. The root has no branch above it and its entries are not written.
   *
   * <p>Nodes are visited in post-order from the root, so every child is finished before its
   * parent reads it. This does not rely on node numbers being ordered children-before-parents.
   */
  public void forwardIBD() {
    getDiagonalRates(diag);
    int stateCount = substitutionModel.getStateCount();
    NodeRef root = treeModel.getRoot();
    // Only the root's subtrees carry a branch, so start the post-order walk at its children.
    int rootChildCount = treeModel.getChildCount(root);
    for (int child = 0; child < rootChildCount; ++child) {
      forwardIBDPostOrder(treeModel.getChild(root, child), stateCount);
    }
  }

  /**
   * Post-order helper for {@link #forwardIBD()}. Fills {@code ibdForward} for the subtree below
   * {@code node} and then for {@code node} itself. {@code node} must have a parent, and
   * {@code diag} must already hold the current diagonal rates.
   */
  private void forwardIBDPostOrder(NodeRef node, int stateCount) {
    int nodeId = node.getNumber();
    boolean isTip = treeModel.isExternal(node);
    int childCount = isTip ? 0 : treeModel.getChildCount(node);

    // Children first, so their ibdForward entries are current when summed below.
    for (int child = 0; child < childCount; ++child) {
      forwardIBDPostOrder(treeModel.getChild(node, child), stateCount);
    }

    NodeRef parent = treeModel.getParent(node);
    double branchTime =
        branchRateModel.getBranchRate(treeModel, node)
            * (treeModel.getNodeHeight(parent) - treeModel.getNodeHeight(node));

    for (int state = 0; state < stateCount; ++state) {
      double sum;
      if (isTip) {
        sum = 1.0; // a tip contributes only its own branch factor
      } else {
        sum = 0;
        for (int child = 0; child < childCount; ++child) {
          int childNodeId = treeModel.getChild(node, child).getNumber();
          sum += ibdForward[childNodeId][state];
        }
      }
      ibdForward[nodeId][state] = sum * Math.exp(-diag[state] * branchTime);
    }
  }
Ejemplo n.º 2
0
 /**
  * Computes the backward (root-to-tip) IBD quantities for every descendant of {@code node} and
  * stores them in {@code ibdBackward}.
  *
  * <p>If {@code node} is {@code null}, the traversal starts at the root, and the root's entries
  * are first reset to zero. For each child {@code c} of a node {@code p} and each state {@code s}:
  *
  * <pre>
  * ibdBackward[c][s] = (ibdBackward[p][s] + sum of ibdForward[sib][s] over c's siblings)
  *                     * exp(-diag[s] * t_c)
  * </pre>
  *
  * <p>Here {@code t_c} is the branch rate of {@code c} times the height difference between
  * {@code p} and {@code c}. Because {@code ibdForward} is read, presumably {@link #forwardIBD()}
  * must have been run for the current tree first — TODO confirm with callers.
  *
  * @param node the node whose subtree is updated, or {@code null} to process the whole tree
  */
 public void backwardIBD(NodeRef node) {
   int stateCount = substitutionModel.getStateCount();
   if (node == null) {
     node = treeModel.getRoot();
     int rootId = node.getNumber();
     for (int state = 0; state < stateCount; ++state) {
       ibdBackward[rootId][state] = 0;
     }
   }
   // The diagonal rates do not change during the traversal, so fetch them once here rather than
   // once per visited node.
   getDiagonalRates(diag);
   backwardIBDPreOrder(node, stateCount);
 }

 /**
  * Pre-order helper for {@link #backwardIBD(NodeRef)}. Fills {@code ibdBackward} for each child
  * of {@code node} from {@code node}'s own value, then recurses into the children. Assumes
  * {@code diag} already holds the current diagonal rates.
  */
 private void backwardIBDPreOrder(NodeRef node, int stateCount) {
   int nodeId = node.getNumber();
   int childCount = treeModel.getChildCount(node);
   for (int child = 0; child < childCount; ++child) {
     NodeRef childNode = treeModel.getChild(node, child);
     int childNodeId = childNode.getNumber();
     double branchTime =
         branchRateModel.getBranchRate(treeModel, childNode)
             * (treeModel.getNodeHeight(node) - treeModel.getNodeHeight(childNode));
     for (int state = 0; state < stateCount; ++state) {
       double value = ibdBackward[nodeId][state];
       for (int sibling = 0; sibling < childCount; ++sibling) {
         if (sibling != child) {
           int siblingId = treeModel.getChild(node, sibling).getNumber();
           value += ibdForward[siblingId][state];
         }
       }
       ibdBackward[childNodeId][state] = value * Math.exp(-diag[state] * branchTime);
     }
   }
   // Parents are complete before children are visited, so recurse only after the loop above.
   for (int child = 0; child < childCount; ++child) {
     backwardIBDPreOrder(treeModel.getChild(node, child), stateCount);
   }
 }
Ejemplo n.º 3
0
  /**
   * Creates a BEAGLE-backed tree likelihood for {@code patternList} on {@code treeModel}.
   *
   * <p>The constructor:
   *
   * <ul>
   *   <li>registers the site, branch and branch-rate models as sub-models;
   *   <li>sizes the partials and scale buffers;
   *   <li>chooses BEAGLE resource, preference and requirement flags and a rescaling scheme, from
   *       the arguments and from system properties;
   *   <li>loads a BEAGLE instance and loads the tip data into it.
   * </ul>
   *
   * @param patternList the site patterns to evaluate
   * @param treeModel the tree
   * @param branchModel the branch model; also used to build the {@code SubstitutionModelDelegate}
   * @param siteModel the site model; supplies the category count
   * @param branchRateModel the branch rate model, or {@code null} to use a
   *     {@code DefaultBranchRateModel}
   * @param tipStatesModel optional model providing tip states or partials, or {@code null}
   * @param useAmbiguities if true and no tip-states model is given, tips are loaded as partials;
   *     otherwise they are loaded as states
   * @param rescalingScheme requested rescaling scheme; may be overridden by the scaling system
   *     property and by the defaults chosen below
   * @param partialsRestrictions restrictions on partials keyed by taxon sets
   * @throws RuntimeException if a taxon in the tree is not found in {@code patternList}; the
   *     {@code TaxonList.MissingTaxonException} is attached as the cause
   */
  public NewBeagleTreeLikelihood(
      PatternList patternList,
      TreeModel treeModel,
      BranchModel branchModel,
      SiteModel siteModel,
      BranchRateModel branchRateModel,
      TipStatesModel tipStatesModel,
      boolean useAmbiguities,
      PartialsRescalingScheme rescalingScheme,
      Map<Set<String>, Parameter> partialsRestrictions) {

    super(BeagleTreeLikelihoodParser.TREE_LIKELIHOOD, patternList, treeModel);

    try {
      final Logger logger = Logger.getLogger("dr.evomodel");

      logger.info("Using BEAGLE TreeLikelihood");

      this.siteModel = siteModel;
      addModel(this.siteModel);

      this.branchModel = branchModel;
      addModel(this.branchModel);

      if (branchRateModel != null) {
        this.branchRateModel = branchRateModel;
        logger.info("  Branch rate model used: " + branchRateModel.getModelName());
      } else {
        this.branchRateModel = new DefaultBranchRateModel();
      }
      addModel(this.branchRateModel);

      this.tipStatesModel = tipStatesModel;

      this.categoryCount = this.siteModel.getCategoryCount();

      this.tipCount = treeModel.getExternalNodeCount();

      internalNodeCount = nodeCount - tipCount;

      // if we are using ambiguities then we don't use tip partials
      int compactPartialsCount = useAmbiguities ? 0 : tipCount;

      // one partials buffer for each tip and two for each internal node (for store restore)
      partialBufferHelper = new BufferIndexHelper(nodeCount, tipCount);

      // one scaling buffer for each internal node plus an extra for the accumulation, then doubled
      // for store/restore
      scaleBufferHelper = new BufferIndexHelper(getScaleBufferCount(), 0);

      // Attempt to get the resource order from the System Property. Each list is only parsed
      // while still null; entries are picked per instance via instanceCount modulo list size.
      if (resourceOrder == null) {
        resourceOrder = parseSystemPropertyIntegerArray(RESOURCE_ORDER_PROPERTY);
      }
      if (preferredOrder == null) {
        preferredOrder = parseSystemPropertyIntegerArray(PREFERRED_FLAGS_PROPERTY);
      }
      if (requiredOrder == null) {
        requiredOrder = parseSystemPropertyIntegerArray(REQUIRED_FLAGS_PROPERTY);
      }
      if (scalingOrder == null) {
        scalingOrder = parseSystemPropertyStringArray(SCALING_PROPERTY);
      }
      if (extraBufferOrder == null) {
        extraBufferOrder = parseSystemPropertyIntegerArray(EXTRA_BUFFER_COUNT_PROPERTY);
      }

      int extraBufferCount = -1; // default
      if (extraBufferOrder.size() > 0) {
        extraBufferCount = extraBufferOrder.get(instanceCount % extraBufferOrder.size());
      }
      substitutionModelDelegate =
          new SubstitutionModelDelegate(treeModel, branchModel, extraBufferCount);

      // first set the rescaling scheme to use from the parser
      this.rescalingScheme = rescalingScheme;
      int[] resourceList = null;
      long preferenceFlags = 0;
      long requirementFlags = 0;

      if (scalingOrder.size() > 0) {
        this.rescalingScheme =
            PartialsRescalingScheme.parseFromString(
                scalingOrder.get(instanceCount % scalingOrder.size()));
      }

      if (resourceOrder.size() > 0) {
        // added the zero on the end so that a CPU is selected if requested resource fails
        resourceList = new int[] {resourceOrder.get(instanceCount % resourceOrder.size()), 0};
        if (resourceList[0] > 0) {
          preferenceFlags |=
              BeagleFlag.PROCESSOR_GPU.getMask(); // Add preference weight against CPU
        }
      }

      // An explicit preferred-flags property replaces (not augments) the flags set above.
      if (preferredOrder.size() > 0) {
        preferenceFlags = preferredOrder.get(instanceCount % preferredOrder.size());
      }

      if (requiredOrder.size() > 0) {
        requirementFlags = requiredOrder.get(instanceCount % requiredOrder.size());
      }

      // Define default behaviour here. GPU and CPU currently get the same scheme: on GPU the
      // default is dynamic scaling in BEAST, and on CPU dynamic should run as fast as none until
      // the first underflow.
      if (this.rescalingScheme == PartialsRescalingScheme.DEFAULT) {
        this.rescalingScheme = DEFAULT_RESCALING_SCHEME;
      }

      // Manual scaling (BeagleFlag.SCALING_MANUAL) is not requested for non-AUTO schemes.
      if (this.rescalingScheme == PartialsRescalingScheme.AUTO) {
        preferenceFlags |= BeagleFlag.SCALING_AUTO.getMask();
        useAutoScaling = true;
      }
      String r = System.getProperty(RESCALE_FREQUENCY_PROPERTY);
      if (r != null) {
        rescalingFrequency = Integer.parseInt(r);
        if (rescalingFrequency < 1) {
          rescalingFrequency = RESCALE_FREQUENCY;
        }
      }

      if (preferenceFlags == 0 && resourceList == null) { // else determine dataset characteristics
        // TODO determine good cut-off
        if (stateCount == 4 && patternList.getPatternCount() < 10000) {
          preferenceFlags |= BeagleFlag.PROCESSOR_CPU.getMask();
        }
      }

      if (BeagleFlag.VECTOR_SSE.isSet(preferenceFlags) && stateCount != 4) {
        // @todo SSE doesn't seem to work for larger state spaces so for now we override the
        // SSE option.
        preferenceFlags &= ~BeagleFlag.VECTOR_SSE.getMask();
        preferenceFlags |= BeagleFlag.VECTOR_NONE.getMask();

        if (stateCount > 4 && this.rescalingScheme == PartialsRescalingScheme.DYNAMIC) {
          this.rescalingScheme = PartialsRescalingScheme.DELAYED;
        }
      }

      if (!BeagleFlag.PRECISION_SINGLE.isSet(preferenceFlags)) {
        // if single precision not explicitly set then prefer double
        preferenceFlags |= BeagleFlag.PRECISION_DOUBLE.getMask();
      }

      if (substitutionModelDelegate.canReturnComplexDiagonalization()) {
        requirementFlags |= BeagleFlag.EIGEN_COMPLEX.getMask();
      }

      instanceCount++;

      beagle =
          BeagleFactory.loadBeagleInstance(
              tipCount,
              partialBufferHelper.getBufferCount(),
              compactPartialsCount,
              stateCount,
              patternCount,
              substitutionModelDelegate.getEigenBufferCount(),
              substitutionModelDelegate.getMatrixBufferCount(),
              categoryCount,
              scaleBufferHelper.getBufferCount(), // Always allocate; they may become necessary
              resourceList,
              preferenceFlags,
              requirementFlags);

      InstanceDetails instanceDetails = beagle.getDetails();
      ResourceDetails resourceDetails = null;

      if (instanceDetails != null) {
        resourceDetails = BeagleFactory.getResourceDetails(instanceDetails.getResourceNumber());
        if (resourceDetails != null) {
          StringBuilder sb = new StringBuilder("  Using BEAGLE resource ");
          sb.append(resourceDetails.getNumber()).append(": ");
          sb.append(resourceDetails.getName()).append("\n");
          if (resourceDetails.getDescription() != null) {
            String[] description = resourceDetails.getDescription().split("\\|");
            for (String desc : description) {
              if (desc.trim().length() > 0) {
                sb.append("    ").append(desc.trim()).append("\n");
              }
            }
          }
          sb.append("    with instance flags: ").append(instanceDetails.toString());
          logger.info(sb.toString());
        } else {
          logger.info(
              "  Error retrieving BEAGLE resource for instance: " + instanceDetails.toString());
        }
      } else {
        logger.info(
            "  No external BEAGLE resources available, or resource list/requirements not met, using Java implementation");
      }
      logger.info(
          "  " + (useAmbiguities ? "Using" : "Ignoring") + " ambiguities in tree likelihood.");
      logger.info("  With " + patternList.getPatternCount() + " unique site patterns.");

      if (tipStatesModel != null) {
        tipStatesModel.setTree(treeModel);

        if (tipStatesModel.getModelType() == TipStatesModel.Type.PARTIALS) {
          tipPartials = new double[patternCount * stateCount];
        } else {
          tipStates = new int[patternCount];
        }

        addModel(tipStatesModel);
      }

      for (int i = 0; i < tipCount; i++) {
        // Find the id of tip i in the patternList
        String id = treeModel.getTaxonId(i);
        int index = patternList.getTaxonIndex(id);

        if (index == -1) {
          throw new TaxonList.MissingTaxonException(
              "Taxon, "
                  + id
                  + ", in tree, "
                  + treeModel.getId()
                  + ", is not found in patternList, "
                  + patternList.getId());
        } else {
          if (tipStatesModel != null) {
            // using a tipPartials model.
            // First set the observed states:
            tipStatesModel.setStates(patternList, index, i, id);

            if (tipStatesModel.getModelType() == TipStatesModel.Type.PARTIALS) {
              // Then set the tip partials as determined by the model:
              setPartials(beagle, tipStatesModel, i);
            } else {
              // or the tip states:
              tipStatesModel.getTipStates(i, tipStates);
              beagle.setTipStates(i, tipStates);
            }

          } else {
            if (useAmbiguities) {
              setPartials(beagle, patternList, index, i);
            } else {
              setStates(beagle, patternList, index, i);
            }
          }
        }
      }

      if (patternList instanceof AscertainedSitePatterns) {
        ascertainedSitePatterns = true;
      }

      this.partialsRestrictions = partialsRestrictions;
      // NOTE(review): the assignment below is commented out, so hasRestrictedPartials keeps
      // whatever value it is given outside this constructor (not visible here). Confirm this is
      // intended; otherwise partialsRestrictions is stored but not applied.
      //            hasRestrictedPartials = (partialsRestrictions != null);
      if (hasRestrictedPartials) {
        numRestrictedPartials = partialsRestrictions.size();
        updateRestrictedNodePartials = true;
        partialsMap = new Parameter[treeModel.getNodeCount()];
        partials = new double[stateCount * patternCount * categoryCount];
      } else {
        numRestrictedPartials = 0;
        updateRestrictedNodePartials = false;
      }

      beagle.setPatternWeights(patternWeights);

      String rescaleMessage = "  Using rescaling scheme : " + this.rescalingScheme.getText();
      if (this.rescalingScheme == PartialsRescalingScheme.AUTO
          && resourceDetails != null
          && (resourceDetails.getFlags() & BeagleFlag.SCALING_AUTO.getMask()) == 0) {
        // If auto scaling in BEAGLE is not supported then do it here
        this.rescalingScheme = PartialsRescalingScheme.DYNAMIC;
        rescaleMessage =
            "  Auto rescaling not supported in BEAGLE, using : " + this.rescalingScheme.getText();
      }
      if (this.rescalingScheme == PartialsRescalingScheme.DYNAMIC) {
        rescaleMessage += " (rescaling every " + rescalingFrequency + " evaluations)";
      }
      logger.info(rescaleMessage);

      if (this.rescalingScheme == PartialsRescalingScheme.DYNAMIC) {
        everUnderflowed = false; // If false, BEAST does not rescale until first under-/over-flow.
      }

      updateSubstitutionModel = true;
      updateSiteModel = true;

    } catch (TaxonList.MissingTaxonException mte) {
      // Attach the original exception as the cause so its stack trace is not lost.
      throw new RuntimeException(mte.toString(), mte);
    }
    this.useAmbiguities = useAmbiguities;
    hasInitialized = true;
  }