Ejemplo n.º 1
0
  /**
   * Creates a BEAGLE-backed tree likelihood for {@code patternList} evaluated on {@code treeModel}.
   *
   * <p>Per-instance BEAGLE settings (resource order, preferred and required flags, rescaling
   * scheme, extra buffer count) may come from system properties. Each list is parsed only while its
   * field is still {@code null}. An instance picks its entry by indexing the list with {@code
   * instanceCount} modulo the list size. {@code instanceCount} is incremented just before the
   * BEAGLE instance is loaded.
   *
   * @param patternList site patterns; every tip taxon of the tree must be present in it
   * @param treeModel the tree whose tips are matched to {@code patternList} by taxon id
   * @param branchModel branch model, registered as a sub-model and handed to the substitution
   *     model delegate
   * @param siteModel site model, registered as a sub-model; supplies the rate category count
   * @param branchRateModel branch rate model, or {@code null} to use a {@code
   *     DefaultBranchRateModel}
   * @param tipStatesModel optional model supplying tip partials or tip states; may be {@code null}
   * @param useAmbiguities if {@code true}, no compact tip buffers are allocated and, when no {@code
   *     tipStatesModel} is given, tips are loaded as partials rather than compact states
   * @param rescalingScheme rescaling scheme requested by the parser; may be overridden by the
   *     per-instance scaling order, and {@code DEFAULT} is resolved to {@code
   *     DEFAULT_RESCALING_SCHEME}
   * @param partialsRestrictions restrictions on node partials; stored, but only acted on when
   *     {@code hasRestrictedPartials} is set
   * @throws RuntimeException wrapping (as its cause) a {@code TaxonList.MissingTaxonException} if a
   *     tip of the tree is not found in {@code patternList}
   */
  public NewBeagleTreeLikelihood(
      PatternList patternList,
      TreeModel treeModel,
      BranchModel branchModel,
      SiteModel siteModel,
      BranchRateModel branchRateModel,
      TipStatesModel tipStatesModel,
      boolean useAmbiguities,
      PartialsRescalingScheme rescalingScheme,
      Map<Set<String>, Parameter> partialsRestrictions) {

    super(BeagleTreeLikelihoodParser.TREE_LIKELIHOOD, patternList, treeModel);

    try {
      final Logger logger = Logger.getLogger("dr.evomodel");

      logger.info("Using BEAGLE TreeLikelihood");

      this.siteModel = siteModel;
      addModel(this.siteModel);

      this.branchModel = branchModel;
      addModel(this.branchModel);

      if (branchRateModel != null) {
        this.branchRateModel = branchRateModel;
        logger.info("  Branch rate model used: " + branchRateModel.getModelName());
      } else {
        this.branchRateModel = new DefaultBranchRateModel();
      }
      addModel(this.branchRateModel);

      this.tipStatesModel = tipStatesModel;

      this.categoryCount = this.siteModel.getCategoryCount();

      this.tipCount = treeModel.getExternalNodeCount();

      internalNodeCount = nodeCount - tipCount;

      // When ambiguities are used, tips are not stored as compact (state-only) buffers.
      int compactPartialsCount = useAmbiguities ? 0 : tipCount;

      // one partials buffer for each tip and two for each internal node (for store restore)
      partialBufferHelper = new BufferIndexHelper(nodeCount, tipCount);

      // one scaling buffer for each internal node plus an extra for the accumulation, then doubled
      // for store/restore
      scaleBufferHelper = new BufferIndexHelper(getScaleBufferCount(), 0);

      // Read per-instance overrides from system properties, but only while still unset.
      if (resourceOrder == null) {
        resourceOrder = parseSystemPropertyIntegerArray(RESOURCE_ORDER_PROPERTY);
      }
      if (preferredOrder == null) {
        preferredOrder = parseSystemPropertyIntegerArray(PREFERRED_FLAGS_PROPERTY);
      }
      if (requiredOrder == null) {
        requiredOrder = parseSystemPropertyIntegerArray(REQUIRED_FLAGS_PROPERTY);
      }
      if (scalingOrder == null) {
        scalingOrder = parseSystemPropertyStringArray(SCALING_PROPERTY);
      }
      if (extraBufferOrder == null) {
        extraBufferOrder = parseSystemPropertyIntegerArray(EXTRA_BUFFER_COUNT_PROPERTY);
      }

      int extraBufferCount = -1; // default
      if (extraBufferOrder.size() > 0) {
        extraBufferCount = extraBufferOrder.get(instanceCount % extraBufferOrder.size());
      }
      substitutionModelDelegate =
          new SubstitutionModelDelegate(treeModel, branchModel, extraBufferCount);

      // first set the rescaling scheme to use from the parser
      this.rescalingScheme = rescalingScheme;
      int[] resourceList = null;
      long preferenceFlags = 0;
      long requirementFlags = 0;

      if (scalingOrder.size() > 0) {
        this.rescalingScheme =
            PartialsRescalingScheme.parseFromString(
                scalingOrder.get(instanceCount % scalingOrder.size()));
      }

      if (resourceOrder.size() > 0) {
        // added the zero on the end so that a CPU is selected if requested resource fails
        resourceList = new int[] {resourceOrder.get(instanceCount % resourceOrder.size()), 0};
        if (resourceList[0] > 0) {
          preferenceFlags |=
              BeagleFlag.PROCESSOR_GPU.getMask(); // Add preference weight against CPU
        }
      }

      if (preferredOrder.size() > 0) {
        // An explicit preferred-flags value replaces (does not augment) any flags set above.
        preferenceFlags = preferredOrder.get(instanceCount % preferredOrder.size());
      }

      if (requiredOrder.size() > 0) {
        requirementFlags = requiredOrder.get(instanceCount % requiredOrder.size());
      }

      // Define default behaviour here. GPU and CPU resources currently share the same default:
      // on a GPU BEAST defaults to dynamic scaling, and on a CPU dynamic scaling should run as fast
      // as no scaling until the first underflow.
      if (this.rescalingScheme == PartialsRescalingScheme.DEFAULT) {
        this.rescalingScheme = DEFAULT_RESCALING_SCHEME;
      }

      if (this.rescalingScheme == PartialsRescalingScheme.AUTO) {
        preferenceFlags |= BeagleFlag.SCALING_AUTO.getMask();
        useAutoScaling = true;
      }

      String rescaleFrequencyValue = System.getProperty(RESCALE_FREQUENCY_PROPERTY);
      if (rescaleFrequencyValue != null) {
        rescalingFrequency = Integer.parseInt(rescaleFrequencyValue);
        if (rescalingFrequency < 1) {
          rescalingFrequency = RESCALE_FREQUENCY;
        }
      }

      // With no explicit preferences or resources, choose from the dataset characteristics:
      // small 4-state datasets prefer the CPU.
      if (preferenceFlags == 0 && resourceList == null) {
        if (stateCount == 4 && patternList.getPatternCount() < 10000) { // TODO determine good cut-off
          preferenceFlags |= BeagleFlag.PROCESSOR_CPU.getMask();
        }
      }

      if (BeagleFlag.VECTOR_SSE.isSet(preferenceFlags) && stateCount != 4) {
        // @todo SSE doesn't seem to work for larger state spaces so for now we override the
        // SSE option.
        preferenceFlags &= ~BeagleFlag.VECTOR_SSE.getMask();
        preferenceFlags |= BeagleFlag.VECTOR_NONE.getMask();

        if (stateCount > 4 && this.rescalingScheme == PartialsRescalingScheme.DYNAMIC) {
          this.rescalingScheme = PartialsRescalingScheme.DELAYED;
        }
      }

      if (!BeagleFlag.PRECISION_SINGLE.isSet(preferenceFlags)) {
        // if single precision not explicitly set then prefer double
        preferenceFlags |= BeagleFlag.PRECISION_DOUBLE.getMask();
      }

      if (substitutionModelDelegate.canReturnComplexDiagonalization()) {
        requirementFlags |= BeagleFlag.EIGEN_COMPLEX.getMask();
      }

      instanceCount++;

      beagle =
          BeagleFactory.loadBeagleInstance(
              tipCount,
              partialBufferHelper.getBufferCount(),
              compactPartialsCount,
              stateCount,
              patternCount,
              substitutionModelDelegate.getEigenBufferCount(),
              substitutionModelDelegate.getMatrixBufferCount(),
              categoryCount,
              scaleBufferHelper.getBufferCount(), // Always allocate; they may become necessary
              resourceList,
              preferenceFlags,
              requirementFlags);

      InstanceDetails instanceDetails = beagle.getDetails();
      ResourceDetails resourceDetails = null;

      if (instanceDetails != null) {
        resourceDetails = BeagleFactory.getResourceDetails(instanceDetails.getResourceNumber());
        if (resourceDetails != null) {
          StringBuilder sb = new StringBuilder("  Using BEAGLE resource ");
          sb.append(resourceDetails.getNumber()).append(": ");
          sb.append(resourceDetails.getName()).append("\n");
          if (resourceDetails.getDescription() != null) {
            // The description is a '|'-separated list; log each non-blank entry on its own line.
            String[] description = resourceDetails.getDescription().split("\\|");
            for (String desc : description) {
              if (desc.trim().length() > 0) {
                sb.append("    ").append(desc.trim()).append("\n");
              }
            }
          }
          sb.append("    with instance flags: ").append(instanceDetails.toString());
          logger.info(sb.toString());
        } else {
          logger.info(
              "  Error retrieving BEAGLE resource for instance: " + instanceDetails.toString());
        }
      } else {
        logger.info(
            "  No external BEAGLE resources available, or resource list/requirements not met, using Java implementation");
      }
      logger.info(
          "  " + (useAmbiguities ? "Using" : "Ignoring") + " ambiguities in tree likelihood.");
      logger.info("  With " + patternList.getPatternCount() + " unique site patterns.");

      if (tipStatesModel != null) {
        tipStatesModel.setTree(treeModel);

        if (tipStatesModel.getModelType() == TipStatesModel.Type.PARTIALS) {
          tipPartials = new double[patternCount * stateCount];
        } else {
          tipStates = new int[patternCount];
        }

        addModel(tipStatesModel);
      }

      for (int i = 0; i < tipCount; i++) {
        // Find the id of tip i in the patternList
        String id = treeModel.getTaxonId(i);
        int index = patternList.getTaxonIndex(id);

        if (index == -1) {
          throw new TaxonList.MissingTaxonException(
              "Taxon, "
                  + id
                  + ", in tree, "
                  + treeModel.getId()
                  + ", is not found in patternList, "
                  + patternList.getId());
        }

        if (tipStatesModel != null) {
          // using a tipPartials model.
          // First set the observed states:
          tipStatesModel.setStates(patternList, index, i, id);

          if (tipStatesModel.getModelType() == TipStatesModel.Type.PARTIALS) {
            // Then set the tip partials as determined by the model:
            setPartials(beagle, tipStatesModel, i);
          } else {
            // or the tip states:
            tipStatesModel.getTipStates(i, tipStates);
            beagle.setTipStates(i, tipStates);
          }
        } else if (useAmbiguities) {
          setPartials(beagle, patternList, index, i);
        } else {
          setStates(beagle, patternList, index, i);
        }
      }

      if (patternList instanceof AscertainedSitePatterns) {
        ascertainedSitePatterns = true;
      }

      this.partialsRestrictions = partialsRestrictions;
      // NOTE(review): the assignment `hasRestrictedPartials = (partialsRestrictions != null)` is
      // disabled here, so partialsRestrictions only take effect if hasRestrictedPartials is set
      // elsewhere — confirm this is intentional.
      if (hasRestrictedPartials) {
        numRestrictedPartials = partialsRestrictions.size();
        updateRestrictedNodePartials = true;
        partialsMap = new Parameter[treeModel.getNodeCount()];
        partials = new double[stateCount * patternCount * categoryCount];
      } else {
        numRestrictedPartials = 0;
        updateRestrictedNodePartials = false;
      }

      beagle.setPatternWeights(patternWeights);

      String rescaleMessage = "  Using rescaling scheme : " + this.rescalingScheme.getText();
      if (this.rescalingScheme == PartialsRescalingScheme.AUTO
          && resourceDetails != null
          && (resourceDetails.getFlags() & BeagleFlag.SCALING_AUTO.getMask()) == 0) {
        // If auto scaling in BEAGLE is not supported then do it here
        this.rescalingScheme = PartialsRescalingScheme.DYNAMIC;
        rescaleMessage =
            "  Auto rescaling not supported in BEAGLE, using : " + this.rescalingScheme.getText();
      }
      if (this.rescalingScheme == PartialsRescalingScheme.DYNAMIC) {
        rescaleMessage += " (rescaling every " + rescalingFrequency + " evaluations)";
      }
      logger.info(rescaleMessage);

      if (this.rescalingScheme == PartialsRescalingScheme.DYNAMIC) {
        everUnderflowed = false; // If false, BEAST does not rescale until first under-/over-flow.
      }

      updateSubstitutionModel = true;
      updateSiteModel = true;

    } catch (TaxonList.MissingTaxonException mte) {
      // Keep the original message but chain the cause so its stack trace is not lost.
      throw new RuntimeException(mte.toString(), mte);
    }
    this.useAmbiguities = useAmbiguities;
    hasInitialized = true;
  }
  /**
   * Creates a Java-core tree likelihood with an additional cenancestor node.
   *
   * <p>One node is allocated beyond those of {@code treeModel} for the cenancestor, and every node
   * is initially flagged for update. Only {@code TwoStates} and {@code GeneralDataType} patterns
   * are supported.
   *
   * @param patternList site patterns to evaluate
   * @param treeModel the tree whose tips are matched to {@code patternList} by taxon id
   * @param siteModel site model; it and its frequency model are registered as sub-models
   * @param branchRateModel branch rate model, or {@code null} to use a {@code
   *     DefaultCenancestorBranchRateModel}
   * @param tipStatesModel optional tip states model; when non-null, every tip taxon must be in
   *     {@code patternList} regardless of {@code allowMissingTaxa}
   * @param cenancestorHeight cenancestor height parameter; registered as a variable and given
   *     bounds [0, +infinity)
   * @param cenancestorBranch cenancestor branch parameter; registered as a variable and given
   *     bounds [0, +infinity)
   * @param useAmbiguities if {@code true}, tips are loaded as partials, otherwise as states
   * @param allowMissingTaxa if {@code true} (and no tip states model is given), tips absent from
   *     {@code patternList} are filled via setMissingPartials/setMissingStates instead of raising
   *     an error
   * @param storePartials stored in the {@code storePartials} field
   * @param forceJavaCore currently unused: only a Java cenancestor likelihood core exists
   * @param forceRescaling if {@code true}, partials rescaling is enabled on the likelihood core
   * @throws IllegalArgumentException if no likelihood core is available for the data type of
   *     {@code patternList}
   * @throws RuntimeException wrapping (as its cause) a {@code TaxonList.MissingTaxonException} if a
   *     tree tip is missing from {@code patternList} and missing taxa are not allowed
   */
  public CenancestorTreeLikelihood(
      PatternList patternList,
      TreeModel treeModel,
      SiteModel siteModel,
      CenancestorBranchRateModel branchRateModel,
      TipStatesModel tipStatesModel,
      Parameter cenancestorHeight,
      Parameter cenancestorBranch,
      boolean useAmbiguities,
      boolean allowMissingTaxa,
      boolean storePartials,
      boolean forceJavaCore,
      boolean forceRescaling) {

    super(CenancestorTreeLikelihoodParser.TREE_LIKELIHOOD, patternList, treeModel);

    this.storePartials = storePartials;
    // One extra node beyond the tree's own nodes, for the cenancestor.
    nodeCount = treeModel.getNodeCount() + 1;
    updateNode = new boolean[nodeCount];
    for (int i = 0; i < nodeCount; i++) {
      updateNode[i] = true;
    }

    try {
      this.siteModel = siteModel;
      addModel(siteModel);

      this.frequencyModel = siteModel.getFrequencyModel();
      addModel(frequencyModel);

      this.tipStatesModel = tipStatesModel;

      integrateAcrossCategories = siteModel.integrateAcrossCategories();

      this.categoryCount = siteModel.getCategoryCount();

      this.cenancestorHeight = cenancestorHeight;
      addVariable(cenancestorHeight);
      cenancestorHeight.addBounds(
          new Parameter.DefaultBounds(
              Double.POSITIVE_INFINITY,
              0.0,
              1)); // TODO The lower bound should be the maximum leaf height

      this.cenancestorBranch = cenancestorBranch;
      cenancestorBranch.addBounds(
          new Parameter.DefaultBounds(
              Double.POSITIVE_INFINITY,
              0.0,
              1)); // TODO The upper bound should be the maximum leaf height
      addVariable(cenancestorBranch);

      updateCenancestorHeight(); // Trying to avoid improper initial values

      final Logger logger = Logger.getLogger("dr.evomodel");
      String coreName = "Java general";

      // TODO: Check if is worthy to implement other datatypes.
      final DataType dataType = patternList.getDataType();

      if (dataType instanceof dr.evolution.datatype.TwoStates) {
        coreName = "Java cenancestor binary";
        cenancestorlikelihoodCore =
            new GeneralCenancestorLikelihoodCore(patternList.getStateCount());
      } else if (dataType instanceof dr.evolution.datatype.GeneralDataType) {
        coreName = "Java cenancestor CNV";
        cenancestorlikelihoodCore =
            new GeneralCenancestorLikelihoodCore(patternList.getStateCount());
      }
      if (cenancestorlikelihoodCore == null) {
        // Fail fast with a clear message instead of a NullPointerException at initialize(...).
        throw new IllegalArgumentException(
            "CenancestorTreeLikelihood has no likelihood core for data type: "
                + (dataType == null ? "null" : dataType.getClass().getName()));
      }

      final String likelihoodId = getId();
      logger.info(
          "TreeLikelihood("
              + ((likelihoodId != null) ? likelihoodId : treeModel.getId())
              + ") using "
              + coreName
              + " likelihood core");

      logger.info(
          "  " + (useAmbiguities ? "Using" : "Ignoring") + " ambiguities in tree likelihood.");
      logger.info("  With " + patternList.getPatternCount() + " unique site patterns.");

      if (branchRateModel != null) {
        this.branchRateModel = branchRateModel;
        logger.info("Branch rate model used: " + branchRateModel.getModelName());
      } else {
        this.branchRateModel = new DefaultCenancestorBranchRateModel();
      }
      addModel(this.branchRateModel);

      probabilities = new double[stateCount * stateCount];

      cenancestorlikelihoodCore.initialize(
          nodeCount, patternCount, categoryCount, integrateAcrossCategories);

      int extNodeCount = treeModel.getExternalNodeCount();
      int intNodeCount = treeModel.getInternalNodeCount();

      if (tipStatesModel != null) {
        tipStatesModel.setTree(treeModel);

        tipPartials = new double[patternCount * stateCount];

        for (int i = 0; i < extNodeCount; i++) {
          // Find the id of tip i in the patternList
          String id = treeModel.getTaxonId(i);
          int index = patternList.getTaxonIndex(id);

          if (index == -1) {
            throw new TaxonList.MissingTaxonException(
                "Taxon, "
                    + id
                    + ", in tree, "
                    + treeModel.getId()
                    + ", is not found in patternList, "
                    + patternList.getId());
          }

          tipStatesModel.setStates(patternList, index, i, id);
          cenancestorlikelihoodCore.createNodePartials(i);
        }

        addModel(tipStatesModel);
      } else {
        for (int i = 0; i < extNodeCount; i++) {
          // Find the id of tip i in the patternList
          String id = treeModel.getTaxonId(i);
          int index = patternList.getTaxonIndex(id);

          if (index == -1) {
            if (!allowMissingTaxa) {
              throw new TaxonList.MissingTaxonException(
                  "Taxon, "
                      + id
                      + ", in tree, "
                      + treeModel.getId()
                      + ", is not found in patternList, "
                      + patternList.getId());
            }
            if (useAmbiguities) {
              setMissingPartials((LikelihoodCore) cenancestorlikelihoodCore, i);
            } else {
              setMissingStates((LikelihoodCore) cenancestorlikelihoodCore, i);
            }
          } else if (useAmbiguities) {
            setPartials(
                (LikelihoodCore) cenancestorlikelihoodCore, patternList, categoryCount, index, i);
          } else {
            setStates((LikelihoodCore) cenancestorlikelihoodCore, patternList, index, i);
          }
        }
      }
      // Inclusive bound: one extra internal-node partials buffer for the cenancestor.
      for (int i = 0; i <= intNodeCount; i++) {
        cenancestorlikelihoodCore.createNodePartials(extNodeCount + i);
      }

      if (forceRescaling) {
        cenancestorlikelihoodCore.setUseScaling(true);
        logger.info("  Forcing use of partials rescaling.");
      }

    } catch (TaxonList.MissingTaxonException mte) {
      // Keep the original message but chain the cause so its stack trace is not lost.
      throw new RuntimeException(mte.toString(), mte);
    }

    addStatistic(new SiteLikelihoodsStatistic());
  }