{ rules = new XMLSyntaxRule[] { AttributeRule.newBooleanRule(NON_INFORMATIVE, true), AttributeRule.newDoubleRule(DF, true), new ElementRule( SCALE_MATRIX, new XMLSyntaxRule[] {new ElementRule(MatrixParameter.class)}, true), new ElementRule( DATA, new XMLSyntaxRule[] { new ElementRule(MatrixParameter.class, 1, Integer.MAX_VALUE) }) }; }
/** @author Marc Suchard */ public class MultivariateDistributionLikelihood extends AbstractDistributionLikelihood { public static final String MVN_PRIOR = "multivariateNormalPrior"; public static final String MVN_MEAN = "meanParameter"; public static final String MVN_PRECISION = "precisionParameter"; public static final String MVN_CV = "coefficientOfVariation"; public static final String WISHART_PRIOR = "multivariateWishartPrior"; public static final String INV_WISHART_PRIOR = "multivariateInverseWishartPrior"; public static final String DIRICHLET_PRIOR = "dirichletPrior"; public static final String DF = "df"; public static final String SCALE_MATRIX = "scaleMatrix"; public static final String MVGAMMA_PRIOR = "multivariateGammaPrior"; public static final String MVGAMMA_SHAPE = "shapeParameter"; public static final String MVGAMMA_SCALE = "scaleParameter"; public static final String COUNTS = "countsParameter"; public static final String NON_INFORMATIVE = "nonInformative"; public static final String MULTIVARIATE_LIKELIHOOD = "multivariateDistributionLikelihood"; public static final String DATA_AS_MATRIX = "dataAsMatrix"; // public static final String TREE_TRAIT = "treeTraitNormalDistribution"; public static final String TREE_TRAIT = "treeTraitNormalDistributionLikelihood"; public static final String TREE_TRAIT_NORMAL = "treeTraitNormalDistribution"; public static final String ROOT_VALUE = "rootValue"; public static final String CONDITION = "conditionOnRoot"; public static final String DATA = "data"; private final MultivariateDistribution distribution; private final Transform[] transforms; public MultivariateDistributionLikelihood( String name, ParametricMultivariateDistributionModel model) { this(name, model, null); } public MultivariateDistributionLikelihood( String name, ParametricMultivariateDistributionModel model, Transform[] transforms) { super(model); this.distribution = model; this.transforms = transforms; } public MultivariateDistributionLikelihood(String name, MultivariateDistribution distribution) { this(name, distribution, null); } public MultivariateDistributionLikelihood( String name, MultivariateDistribution distribution, Transform[] transforms) { super(new DefaultModel(name)); this.distribution = distribution; this.transforms = transforms; } public MultivariateDistributionLikelihood(MultivariateDistribution distribution) { this(distribution, null); } public MultivariateDistributionLikelihood( MultivariateDistribution distribution, Transform[] transforms) { this(distribution.getType(), distribution, transforms); } public String toString() { return getClass().getName() + "(" + getLogLikelihood() + ")"; } public double calculateLogLikelihood() { double logL = 0.0; for (Attribute<double[]> data : dataList) { double[] x = data.getAttributeValue(); if (transforms != null) { double[] y = new double[x.length]; for (int i = 0; i < x.length; ++i) { logL += transforms[i].getLogJacobian(x[i]); y[i] = transforms[i].transform(x[i]); } logL += distribution.logPdf(y); } else { logL += distribution.logPdf(x); } } return logL; } @Override public void addData(Attribute<double[]> data) { super.addData(data); if (data instanceof Variable && getModel() instanceof DefaultModel) { ((DefaultModel) getModel()).addVariable((Variable) data); } } public MultivariateDistribution getDistribution() { return distribution; } public static Transform[] parseListOfTransforms(XMLObject xo, int maxDim) throws XMLParseException { Transform[] transforms = null; boolean anyTransforms = false; for (int i = 0; i < xo.getChildCount(); ++i) { if (xo.getChild(i) instanceof Transform.ParsedTransform) { Transform.ParsedTransform t = (Transform.ParsedTransform) xo.getChild(i); if (transforms == null) { transforms = Transform.Util.getListOfNoTransforms(maxDim); } t.end = Math.max(t.end, maxDim); if (t.start < 0 || t.end < 0 || t.start > t.end) { throw new XMLParseException("Invalid bounds for transform in " + xo.getId()); } for (int j = t.start; j < t.end; j += t.every) { transforms[j] = t.transform; anyTransforms = true; } } } if (anyTransforms) { StringBuilder sb = new StringBuilder("Using distributional transforms in " + xo.getId() + "\n"); for (int i = 0; i < transforms.length; ++i) { if (transforms[i] != Transform.NONE) { sb.append("\t") .append(transforms[i].getTransformName()) .append(" on index ") .append(i + 1) .append("\n"); } } sb.append("Please cite:\n").append(Citable.Utils.getCitationString(Transform.LOG)); Logger.getLogger("dr.utils.Transform").info(sb.toString()); } return transforms; } public static XMLObjectParser DIRICHLET_PRIOR_PARSER = new AbstractXMLObjectParser() { public String getParserName() { return DIRICHLET_PRIOR; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { XMLObject cxo = xo.getChild(COUNTS); Parameter counts = (Parameter) cxo.getChild(Parameter.class); DirichletDistribution dirichlet = new DirichletDistribution(counts.getParameterValues()); MultivariateDistributionLikelihood likelihood = new MultivariateDistributionLikelihood(dirichlet); cxo = xo.getChild(DATA); for (int j = 0; j < cxo.getChildCount(); j++) { if (cxo.getChild(j) instanceof Parameter) { likelihood.addData((Parameter) cxo.getChild(j)); } else { throw new XMLParseException( "illegal element in " + xo.getName() + " element " + cxo.getName()); } } return likelihood; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { new ElementRule(COUNTS, new XMLSyntaxRule[] {new ElementRule(Parameter.class)}), new ElementRule( DATA, new XMLSyntaxRule[] {new ElementRule(Parameter.class)}, 1, Integer.MAX_VALUE), }; public String getParserDescription() { return "Calculates the likelihood of some data under a Dirichlet distribution."; } public Class getReturnType() { return Likelihood.class; } }; public static XMLObjectParser INV_WISHART_PRIOR_PARSER = new AbstractXMLObjectParser() { public String getParserName() { return INV_WISHART_PRIOR; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { int df = xo.getIntegerAttribute(DF); XMLObject cxo = xo.getChild(SCALE_MATRIX); MatrixParameter scaleMatrix = (MatrixParameter) cxo.getChild(MatrixParameter.class); InverseWishartDistribution invWishart = new InverseWishartDistribution(df, scaleMatrix.getParameterAsMatrix()); MultivariateDistributionLikelihood likelihood = new MultivariateDistributionLikelihood(invWishart); cxo = xo.getChild(DATA); for (int j = 0; j < cxo.getChildCount(); j++) { if (cxo.getChild(j) instanceof MatrixParameter) { likelihood.addData((MatrixParameter) cxo.getChild(j)); } else { throw new XMLParseException( "illegal element in " + xo.getName() + " element " + cxo.getName()); } } return likelihood; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { AttributeRule.newDoubleRule(DF), new ElementRule( SCALE_MATRIX, new XMLSyntaxRule[] {new ElementRule(MatrixParameter.class)}), }; public String getParserDescription() { return "Calculates the likelihood of some data under an Inverse-Wishart distribution."; } public Class getReturnType() { return Likelihood.class; } }; public static XMLObjectParser WISHART_PRIOR_PARSER = new AbstractXMLObjectParser() { public String getParserName() { return WISHART_PRIOR; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { MultivariateDistributionLikelihood likelihood; if (xo.hasAttribute(NON_INFORMATIVE) && xo.getBooleanAttribute(NON_INFORMATIVE)) { // Make non-informative settings XMLObject cxo = xo.getChild(DATA); int dim = ((MatrixParameter) cxo.getChild(0)).getColumnDimension(); likelihood = new MultivariateDistributionLikelihood(new WishartDistribution(dim)); } else { if (!xo.hasAttribute(DF) || !xo.hasChildNamed(SCALE_MATRIX)) { throw new XMLParseException("Must specify both a df and scaleMatrix"); } double df = xo.getDoubleAttribute(DF); XMLObject cxo = xo.getChild(SCALE_MATRIX); MatrixParameter scaleMatrix = (MatrixParameter) cxo.getChild(MatrixParameter.class); likelihood = new MultivariateDistributionLikelihood( new WishartDistribution(df, scaleMatrix.getParameterAsMatrix())); } XMLObject cxo = xo.getChild(DATA); for (int j = 0; j < cxo.getChildCount(); j++) { if (cxo.getChild(j) instanceof MatrixParameter) { likelihood.addData((MatrixParameter) cxo.getChild(j)); } else { throw new XMLParseException( "illegal element in " + xo.getName() + " element " + cxo.getName()); } } return likelihood; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules; { rules = new XMLSyntaxRule[] { AttributeRule.newBooleanRule(NON_INFORMATIVE, true), AttributeRule.newDoubleRule(DF, true), new ElementRule( SCALE_MATRIX, new XMLSyntaxRule[] {new ElementRule(MatrixParameter.class)}, true), new ElementRule( DATA, new XMLSyntaxRule[] { new ElementRule(MatrixParameter.class, 1, Integer.MAX_VALUE) }) }; } public String getParserDescription() { return "Calculates the likelihood of some data under a Wishart distribution."; } public Class getReturnType() { return Likelihood.class; } }; public static XMLObjectParser MULTIVARIATE_LIKELIHOOD_PARSER = new AbstractXMLObjectParser() { public String getParserName() { return MULTIVARIATE_LIKELIHOOD; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { XMLObject cxo = xo.getChild(DistributionLikelihoodParser.DISTRIBUTION); ParametricMultivariateDistributionModel distribution = (ParametricMultivariateDistributionModel) cxo.getChild(ParametricMultivariateDistributionModel.class); // Parse transforms here int maxDim = distribution.getMean().length; Transform[] transforms = parseListOfTransforms(xo, maxDim); MultivariateDistributionLikelihood likelihood = new MultivariateDistributionLikelihood(xo.getId(), distribution, transforms); boolean dataAsMatrix = xo.getAttribute(DATA_AS_MATRIX, false); cxo = xo.getChild(DATA); if (cxo != null) { for (int j = 0; j < cxo.getChildCount(); j++) { if (cxo.getChild(j) instanceof Parameter) { Parameter data = (Parameter) cxo.getChild(j); if (data instanceof MatrixParameter) { MatrixParameter matrix = (MatrixParameter) data; if (dataAsMatrix) { likelihood.addData(matrix); } else { if (matrix.getParameter(0).getDimension() != distribution.getMean().length) throw new XMLParseException( "dim(" + data.getStatisticName() + ") = " + matrix.getParameter(0).getDimension() + " is not equal to dim(" + distribution.getType() + ") = " + distribution.getMean().length + " in " + xo.getName() + "element"); for (int i = 0; i < matrix.getParameterCount(); i++) { likelihood.addData(matrix.getParameter(i)); } } } else { if (data.getDimension() != distribution.getMean().length) throw new XMLParseException( "dim(" + data.getStatisticName() + ") = " + data.getDimension() + " is not equal to dim(" + distribution.getType() + ") = " + distribution.getMean().length + " in " + xo.getName() + "element"); likelihood.addData(data); } } else { throw new XMLParseException("illegal element in " + xo.getName() + " element"); } } } return likelihood; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { new ElementRule( DistributionLikelihoodParser.DISTRIBUTION, new XMLSyntaxRule[] {new ElementRule(ParametricMultivariateDistributionModel.class)}), AttributeRule.newBooleanRule(DATA_AS_MATRIX, true), new ElementRule(Transform.ParsedTransform.class, 0, Integer.MAX_VALUE), new ElementRule( DATA, new XMLSyntaxRule[] {new ElementRule(Parameter.class, 1, Integer.MAX_VALUE)}, true) }; public String getParserDescription() { return "Calculates the likelihood of some data under a given multivariate distribution."; } public Class getReturnType() { return MultivariateDistributionLikelihood.class; } }; public static XMLObjectParser MVN_PRIOR_PARSER = new AbstractXMLObjectParser() { public String getParserName() { return MVN_PRIOR; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { XMLObject cxo = xo.getChild(MVN_MEAN); Parameter mean = (Parameter) cxo.getChild(Parameter.class); cxo = xo.getChild(MVN_PRECISION); MatrixParameter precision = (MatrixParameter) cxo.getChild(MatrixParameter.class); if (mean.getDimension() != precision.getRowDimension() || mean.getDimension() != precision.getColumnDimension()) throw new XMLParseException( "Mean and precision have wrong dimensions in " + xo.getName() + " element"); Transform[] transforms = parseListOfTransforms(xo, mean.getDimension()); MultivariateDistributionLikelihood likelihood = new MultivariateDistributionLikelihood( new MultivariateNormalDistribution( mean.getParameterValues(), precision.getParameterAsMatrix()), transforms); cxo = xo.getChild(DATA); if (cxo != null) { for (int j = 0; j < cxo.getChildCount(); j++) { if (cxo.getChild(j) instanceof Parameter) { Parameter data = (Parameter) cxo.getChild(j); if (data instanceof MatrixParameter) { MatrixParameter matrix = (MatrixParameter) data; if (matrix.getParameter(0).getDimension() != mean.getDimension()) throw new XMLParseException( "dim(" + data.getStatisticName() + ") = " + matrix.getParameter(0).getDimension() + " is not equal to dim(" + mean.getStatisticName() + ") = " + mean.getDimension() + " in " + xo.getName() + "element"); for (int i = 0; i < matrix.getParameterCount(); i++) { likelihood.addData(matrix.getParameter(i)); } } else { if (data.getDimension() != mean.getDimension()) throw new XMLParseException( "dim(" + data.getStatisticName() + ") = " + data.getDimension() + " is not equal to dim(" + mean.getStatisticName() + ") = " + mean.getDimension() + " in " + xo.getName() + "element"); likelihood.addData(data); } } else { throw new XMLParseException("illegal element in " + xo.getName() + " element"); } } } return likelihood; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { new ElementRule(MVN_MEAN, new XMLSyntaxRule[] {new ElementRule(Parameter.class)}), new ElementRule( MVN_PRECISION, new XMLSyntaxRule[] {new ElementRule(MatrixParameter.class)}), new ElementRule(Transform.ParsedTransform.class, 0, Integer.MAX_VALUE), new ElementRule( DATA, new XMLSyntaxRule[] {new ElementRule(Parameter.class, 1, Integer.MAX_VALUE)}, true) }; public String getParserDescription() { return "Calculates the likelihood of some data under a given multivariate-normal distribution."; } public Class getReturnType() { return MultivariateDistributionLikelihood.class; } }; public static XMLObjectParser MVGAMMA_PRIOR_PARSER = new AbstractXMLObjectParser() { public String getParserName() { return MVGAMMA_PRIOR; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { double[] shape; double[] scale; if (xo.hasChildNamed(MVGAMMA_SHAPE)) { XMLObject cxo = xo.getChild(MVGAMMA_SHAPE); shape = ((Parameter) cxo.getChild(Parameter.class)).getParameterValues(); cxo = xo.getChild(MVGAMMA_SCALE); scale = ((Parameter) cxo.getChild(Parameter.class)).getParameterValues(); if (shape.length != scale.length) throw new XMLParseException( "Shape and scale have wrong dimensions in " + xo.getName() + " element"); } else { XMLObject cxo = xo.getChild(MVN_MEAN); double[] mean = ((Parameter) cxo.getChild(Parameter.class)).getParameterValues(); cxo = xo.getChild(MVN_CV); double[] cv = ((Parameter) cxo.getChild(Parameter.class)).getParameterValues(); if (mean.length != cv.length) throw new XMLParseException( "Mean and CV have wrong dimensions in " + xo.getName() + " element"); final int dim = mean.length; shape = new double[dim]; scale = new double[dim]; for (int i = 0; i < dim; i++) { double c2 = cv[i] * cv[i]; shape[i] = 1.0 / c2; scale[i] = c2 * mean[i]; } } MultivariateDistributionLikelihood likelihood = new MultivariateDistributionLikelihood( new MultivariateGammaDistribution(shape, scale)); XMLObject cxo = xo.getChild(DATA); for (int j = 0; j < cxo.getChildCount(); j++) { if (cxo.getChild(j) instanceof Parameter) { Parameter data = (Parameter) cxo.getChild(j); likelihood.addData(data); if (data.getDimension() != shape.length) throw new XMLParseException( "dim(" + data.getStatisticName() + ") != " + shape.length + " in " + xo.getName() + "element"); } else { throw new XMLParseException("illegal element in " + xo.getName() + " element"); } } return likelihood; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { new XORRule( new ElementRule( MVGAMMA_SHAPE, new XMLSyntaxRule[] {new ElementRule(Parameter.class)}), new ElementRule(MVN_MEAN, new XMLSyntaxRule[] {new ElementRule(Parameter.class)})), new XORRule( new ElementRule( MVGAMMA_SCALE, new XMLSyntaxRule[] {new ElementRule(Parameter.class)}), new ElementRule(MVN_CV, new XMLSyntaxRule[] {new ElementRule(Parameter.class)})), new ElementRule( DATA, new XMLSyntaxRule[] {new ElementRule(Parameter.class, 1, Integer.MAX_VALUE)}) }; public String getParserDescription() { return "Calculates the likelihood of some data under a given multivariate-gamma distribution."; } public Class getReturnType() { return MultivariateDistributionLikelihood.class; } }; public static XMLObjectParser TREE_TRAIT_MODEL = new AbstractXMLObjectParser() { public String getParserName() { return TREE_TRAIT_NORMAL; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { boolean conditionOnRoot = xo.getAttribute(CONDITION, false); FullyConjugateMultivariateTraitLikelihood traitModel = (FullyConjugateMultivariateTraitLikelihood) xo.getChild(FullyConjugateMultivariateTraitLikelihood.class); TreeTraitNormalDistributionModel treeTraitModel; if (xo.getChild(ROOT_VALUE) != null) { XMLObject cxo = xo.getChild(ROOT_VALUE); Parameter rootValue = (Parameter) cxo.getChild(Parameter.class); treeTraitModel = new TreeTraitNormalDistributionModel(traitModel, rootValue, conditionOnRoot); } else { treeTraitModel = new TreeTraitNormalDistributionModel(traitModel, conditionOnRoot); } return treeTraitModel; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { AttributeRule.newBooleanRule(CONDITION, true), new ElementRule(FullyConjugateMultivariateTraitLikelihood.class) }; public String getParserDescription() { return "Parses TreeTraitNormalDistributionModel"; } public Class getReturnType() { return TreeTraitNormalDistributionModel.class; } }; public static XMLObjectParser TREE_TRAIT_DISTRIBUTION = new AbstractXMLObjectParser() { public String getParserName() { return TREE_TRAIT; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { /* boolean conditionOnRoot = xo.getAttribute(CONDITION, false); FullyConjugateMultivariateTraitLikelihood traitModel = (FullyConjugateMultivariateTraitLikelihood) xo.getChild(FullyConjugateMultivariateTraitLikelihood.class); */ TreeTraitNormalDistributionModel treeTraitModel = (TreeTraitNormalDistributionModel) xo.getChild(TreeTraitNormalDistributionModel.class); MultivariateDistributionLikelihood likelihood = new MultivariateDistributionLikelihood( // new TreeTraitNormalDistributionModel(traitModel, conditionOnRoot) treeTraitModel); XMLObject cxo = xo.getChild(DATA); for (int j = 0; j < cxo.getChildCount(); j++) { if (cxo.getChild(j) instanceof Parameter) { likelihood.addData((Parameter) cxo.getChild(j)); } else { throw new XMLParseException( "illegal element in " + xo.getName() + " element " + cxo.getName()); } } return likelihood; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { // AttributeRule.newBooleanRule(CONDITION, true), // new ElementRule(FullyConjugateMultivariateTraitLikelihood.class), new ElementRule(TreeTraitNormalDistributionModel.class), new ElementRule( DATA, new XMLSyntaxRule[] {new ElementRule(Parameter.class, 1, Integer.MAX_VALUE)}) }; public String getParserDescription() { return "Calculates the likelihood of some data under a given multivariate-gamma distribution."; } public Class getReturnType() { return MultivariateDistributionLikelihood.class; } }; }