{
   rules =
       new XMLSyntaxRule[] {
         AttributeRule.newBooleanRule(NON_INFORMATIVE, true),
         AttributeRule.newDoubleRule(DF, true),
         new ElementRule(
             SCALE_MATRIX,
             new XMLSyntaxRule[] {new ElementRule(MatrixParameter.class)},
             true),
         new ElementRule(
             DATA,
             new XMLSyntaxRule[] {
               new ElementRule(MatrixParameter.class, 1, Integer.MAX_VALUE)
             })
       };
 }
/** @author Marc Suchard */
public class MultivariateDistributionLikelihood extends AbstractDistributionLikelihood {

  public static final String MVN_PRIOR = "multivariateNormalPrior";
  public static final String MVN_MEAN = "meanParameter";
  public static final String MVN_PRECISION = "precisionParameter";
  public static final String MVN_CV = "coefficientOfVariation";
  public static final String WISHART_PRIOR = "multivariateWishartPrior";
  public static final String INV_WISHART_PRIOR = "multivariateInverseWishartPrior";
  public static final String DIRICHLET_PRIOR = "dirichletPrior";
  public static final String DF = "df";
  public static final String SCALE_MATRIX = "scaleMatrix";
  public static final String MVGAMMA_PRIOR = "multivariateGammaPrior";
  public static final String MVGAMMA_SHAPE = "shapeParameter";
  public static final String MVGAMMA_SCALE = "scaleParameter";
  public static final String COUNTS = "countsParameter";
  public static final String NON_INFORMATIVE = "nonInformative";
  public static final String MULTIVARIATE_LIKELIHOOD = "multivariateDistributionLikelihood";
  public static final String DATA_AS_MATRIX = "dataAsMatrix";
  // public static final String TREE_TRAIT = "treeTraitNormalDistribution";
  public static final String TREE_TRAIT = "treeTraitNormalDistributionLikelihood";
  public static final String TREE_TRAIT_NORMAL = "treeTraitNormalDistribution";
  public static final String ROOT_VALUE = "rootValue";
  public static final String CONDITION = "conditionOnRoot";

  public static final String DATA = "data";

  private final MultivariateDistribution distribution;
  private final Transform[] transforms;

  public MultivariateDistributionLikelihood(
      String name, ParametricMultivariateDistributionModel model) {
    this(name, model, null);
  }

  public MultivariateDistributionLikelihood(
      String name, ParametricMultivariateDistributionModel model, Transform[] transforms) {
    super(model);
    this.distribution = model;
    this.transforms = transforms;
  }

  public MultivariateDistributionLikelihood(String name, MultivariateDistribution distribution) {
    this(name, distribution, null);
  }

  public MultivariateDistributionLikelihood(
      String name, MultivariateDistribution distribution, Transform[] transforms) {
    super(new DefaultModel(name));
    this.distribution = distribution;
    this.transforms = transforms;
  }

  public MultivariateDistributionLikelihood(MultivariateDistribution distribution) {
    this(distribution, null);
  }

  public MultivariateDistributionLikelihood(
      MultivariateDistribution distribution, Transform[] transforms) {
    this(distribution.getType(), distribution, transforms);
  }

  public String toString() {
    return getClass().getName() + "(" + getLogLikelihood() + ")";
  }

  public double calculateLogLikelihood() {
    double logL = 0.0;

    for (Attribute<double[]> data : dataList) {
      double[] x = data.getAttributeValue();
      if (transforms != null) {
        double[] y = new double[x.length];
        for (int i = 0; i < x.length; ++i) {
          logL += transforms[i].getLogJacobian(x[i]);
          y[i] = transforms[i].transform(x[i]);
        }
        logL += distribution.logPdf(y);
      } else {
        logL += distribution.logPdf(x);
      }
    }
    return logL;
  }

  @Override
  public void addData(Attribute<double[]> data) {
    super.addData(data);

    if (data instanceof Variable && getModel() instanceof DefaultModel) {
      ((DefaultModel) getModel()).addVariable((Variable) data);
    }
  }

  public MultivariateDistribution getDistribution() {
    return distribution;
  }

  public static Transform[] parseListOfTransforms(XMLObject xo, int maxDim)
      throws XMLParseException {
    Transform[] transforms = null;

    boolean anyTransforms = false;
    for (int i = 0; i < xo.getChildCount(); ++i) {
      if (xo.getChild(i) instanceof Transform.ParsedTransform) {
        Transform.ParsedTransform t = (Transform.ParsedTransform) xo.getChild(i);
        if (transforms == null) {
          transforms = Transform.Util.getListOfNoTransforms(maxDim);
        }

        t.end = Math.max(t.end, maxDim);
        if (t.start < 0 || t.end < 0 || t.start > t.end) {
          throw new XMLParseException("Invalid bounds for transform in " + xo.getId());
        }
        for (int j = t.start; j < t.end; j += t.every) {
          transforms[j] = t.transform;
          anyTransforms = true;
        }
      }
    }
    if (anyTransforms) {
      StringBuilder sb =
          new StringBuilder("Using distributional transforms in " + xo.getId() + "\n");
      for (int i = 0; i < transforms.length; ++i) {
        if (transforms[i] != Transform.NONE) {
          sb.append("\t")
              .append(transforms[i].getTransformName())
              .append(" on index ")
              .append(i + 1)
              .append("\n");
        }
      }
      sb.append("Please cite:\n").append(Citable.Utils.getCitationString(Transform.LOG));
      Logger.getLogger("dr.utils.Transform").info(sb.toString());
    }
    return transforms;
  }

  public static XMLObjectParser DIRICHLET_PRIOR_PARSER =
      new AbstractXMLObjectParser() {

        public String getParserName() {
          return DIRICHLET_PRIOR;
        }

        public Object parseXMLObject(XMLObject xo) throws XMLParseException {

          XMLObject cxo = xo.getChild(COUNTS);
          Parameter counts = (Parameter) cxo.getChild(Parameter.class);

          DirichletDistribution dirichlet = new DirichletDistribution(counts.getParameterValues());

          MultivariateDistributionLikelihood likelihood =
              new MultivariateDistributionLikelihood(dirichlet);

          cxo = xo.getChild(DATA);
          for (int j = 0; j < cxo.getChildCount(); j++) {
            if (cxo.getChild(j) instanceof Parameter) {
              likelihood.addData((Parameter) cxo.getChild(j));
            } else {
              throw new XMLParseException(
                  "illegal element in " + xo.getName() + " element " + cxo.getName());
            }
          }

          return likelihood;
        }

        public XMLSyntaxRule[] getSyntaxRules() {
          return rules;
        }

        private final XMLSyntaxRule[] rules = {
          new ElementRule(COUNTS, new XMLSyntaxRule[] {new ElementRule(Parameter.class)}),
          new ElementRule(
              DATA, new XMLSyntaxRule[] {new ElementRule(Parameter.class)}, 1, Integer.MAX_VALUE),
        };

        public String getParserDescription() {
          return "Calculates the likelihood of some data under a Dirichlet distribution.";
        }

        public Class getReturnType() {
          return Likelihood.class;
        }
      };

  public static XMLObjectParser INV_WISHART_PRIOR_PARSER =
      new AbstractXMLObjectParser() {

        public String getParserName() {
          return INV_WISHART_PRIOR;
        }

        public Object parseXMLObject(XMLObject xo) throws XMLParseException {

          int df = xo.getIntegerAttribute(DF);

          XMLObject cxo = xo.getChild(SCALE_MATRIX);
          MatrixParameter scaleMatrix = (MatrixParameter) cxo.getChild(MatrixParameter.class);
          InverseWishartDistribution invWishart =
              new InverseWishartDistribution(df, scaleMatrix.getParameterAsMatrix());

          MultivariateDistributionLikelihood likelihood =
              new MultivariateDistributionLikelihood(invWishart);

          cxo = xo.getChild(DATA);
          for (int j = 0; j < cxo.getChildCount(); j++) {
            if (cxo.getChild(j) instanceof MatrixParameter) {
              likelihood.addData((MatrixParameter) cxo.getChild(j));
            } else {
              throw new XMLParseException(
                  "illegal element in " + xo.getName() + " element " + cxo.getName());
            }
          }

          return likelihood;
        }

        public XMLSyntaxRule[] getSyntaxRules() {
          return rules;
        }

        private final XMLSyntaxRule[] rules = {
          AttributeRule.newDoubleRule(DF),
          new ElementRule(
              SCALE_MATRIX, new XMLSyntaxRule[] {new ElementRule(MatrixParameter.class)}),
        };

        public String getParserDescription() {
          return "Calculates the likelihood of some data under an Inverse-Wishart distribution.";
        }

        public Class getReturnType() {
          return Likelihood.class;
        }
      };

  public static XMLObjectParser WISHART_PRIOR_PARSER =
      new AbstractXMLObjectParser() {

        public String getParserName() {
          return WISHART_PRIOR;
        }

        public Object parseXMLObject(XMLObject xo) throws XMLParseException {

          MultivariateDistributionLikelihood likelihood;

          if (xo.hasAttribute(NON_INFORMATIVE) && xo.getBooleanAttribute(NON_INFORMATIVE)) {
            // Make non-informative settings
            XMLObject cxo = xo.getChild(DATA);
            int dim = ((MatrixParameter) cxo.getChild(0)).getColumnDimension();
            likelihood = new MultivariateDistributionLikelihood(new WishartDistribution(dim));
          } else {
            if (!xo.hasAttribute(DF) || !xo.hasChildNamed(SCALE_MATRIX)) {
              throw new XMLParseException("Must specify both a df and scaleMatrix");
            }

            double df = xo.getDoubleAttribute(DF);

            XMLObject cxo = xo.getChild(SCALE_MATRIX);
            MatrixParameter scaleMatrix = (MatrixParameter) cxo.getChild(MatrixParameter.class);

            likelihood =
                new MultivariateDistributionLikelihood(
                    new WishartDistribution(df, scaleMatrix.getParameterAsMatrix()));
          }

          XMLObject cxo = xo.getChild(DATA);
          for (int j = 0; j < cxo.getChildCount(); j++) {
            if (cxo.getChild(j) instanceof MatrixParameter) {
              likelihood.addData((MatrixParameter) cxo.getChild(j));
            } else {
              throw new XMLParseException(
                  "illegal element in " + xo.getName() + " element " + cxo.getName());
            }
          }

          return likelihood;
        }

        public XMLSyntaxRule[] getSyntaxRules() {
          return rules;
        }

        private final XMLSyntaxRule[] rules;

        {
          rules =
              new XMLSyntaxRule[] {
                AttributeRule.newBooleanRule(NON_INFORMATIVE, true),
                AttributeRule.newDoubleRule(DF, true),
                new ElementRule(
                    SCALE_MATRIX,
                    new XMLSyntaxRule[] {new ElementRule(MatrixParameter.class)},
                    true),
                new ElementRule(
                    DATA,
                    new XMLSyntaxRule[] {
                      new ElementRule(MatrixParameter.class, 1, Integer.MAX_VALUE)
                    })
              };
        }

        public String getParserDescription() {
          return "Calculates the likelihood of some data under a Wishart distribution.";
        }

        public Class getReturnType() {
          return Likelihood.class;
        }
      };

  public static XMLObjectParser MULTIVARIATE_LIKELIHOOD_PARSER =
      new AbstractXMLObjectParser() {

        public String getParserName() {
          return MULTIVARIATE_LIKELIHOOD;
        }

        public Object parseXMLObject(XMLObject xo) throws XMLParseException {

          XMLObject cxo = xo.getChild(DistributionLikelihoodParser.DISTRIBUTION);
          ParametricMultivariateDistributionModel distribution =
              (ParametricMultivariateDistributionModel)
                  cxo.getChild(ParametricMultivariateDistributionModel.class);

          // Parse transforms here
          int maxDim = distribution.getMean().length;
          Transform[] transforms = parseListOfTransforms(xo, maxDim);

          MultivariateDistributionLikelihood likelihood =
              new MultivariateDistributionLikelihood(xo.getId(), distribution, transforms);

          boolean dataAsMatrix = xo.getAttribute(DATA_AS_MATRIX, false);

          cxo = xo.getChild(DATA);
          if (cxo != null) {
            for (int j = 0; j < cxo.getChildCount(); j++) {
              if (cxo.getChild(j) instanceof Parameter) {
                Parameter data = (Parameter) cxo.getChild(j);
                if (data instanceof MatrixParameter) {
                  MatrixParameter matrix = (MatrixParameter) data;
                  if (dataAsMatrix) {
                    likelihood.addData(matrix);
                  } else {
                    if (matrix.getParameter(0).getDimension() != distribution.getMean().length)
                      throw new XMLParseException(
                          "dim("
                              + data.getStatisticName()
                              + ") = "
                              + matrix.getParameter(0).getDimension()
                              + " is not equal to dim("
                              + distribution.getType()
                              + ") = "
                              + distribution.getMean().length
                              + " in "
                              + xo.getName()
                              + "element");

                    for (int i = 0; i < matrix.getParameterCount(); i++) {
                      likelihood.addData(matrix.getParameter(i));
                    }
                  }
                } else {
                  if (data.getDimension() != distribution.getMean().length)
                    throw new XMLParseException(
                        "dim("
                            + data.getStatisticName()
                            + ") = "
                            + data.getDimension()
                            + " is not equal to dim("
                            + distribution.getType()
                            + ") = "
                            + distribution.getMean().length
                            + " in "
                            + xo.getName()
                            + "element");
                  likelihood.addData(data);
                }
              } else {
                throw new XMLParseException("illegal element in " + xo.getName() + " element");
              }
            }
          }

          return likelihood;
        }

        public XMLSyntaxRule[] getSyntaxRules() {
          return rules;
        }

        private final XMLSyntaxRule[] rules = {
          new ElementRule(
              DistributionLikelihoodParser.DISTRIBUTION,
              new XMLSyntaxRule[] {new ElementRule(ParametricMultivariateDistributionModel.class)}),
          AttributeRule.newBooleanRule(DATA_AS_MATRIX, true),
          new ElementRule(Transform.ParsedTransform.class, 0, Integer.MAX_VALUE),
          new ElementRule(
              DATA,
              new XMLSyntaxRule[] {new ElementRule(Parameter.class, 1, Integer.MAX_VALUE)},
              true)
        };

        public String getParserDescription() {
          return "Calculates the likelihood of some data under a given multivariate distribution.";
        }

        public Class getReturnType() {
          return MultivariateDistributionLikelihood.class;
        }
      };

  public static XMLObjectParser MVN_PRIOR_PARSER =
      new AbstractXMLObjectParser() {

        public String getParserName() {
          return MVN_PRIOR;
        }

        public Object parseXMLObject(XMLObject xo) throws XMLParseException {

          XMLObject cxo = xo.getChild(MVN_MEAN);
          Parameter mean = (Parameter) cxo.getChild(Parameter.class);

          cxo = xo.getChild(MVN_PRECISION);
          MatrixParameter precision = (MatrixParameter) cxo.getChild(MatrixParameter.class);

          if (mean.getDimension() != precision.getRowDimension()
              || mean.getDimension() != precision.getColumnDimension())
            throw new XMLParseException(
                "Mean and precision have wrong dimensions in " + xo.getName() + " element");

          Transform[] transforms = parseListOfTransforms(xo, mean.getDimension());

          MultivariateDistributionLikelihood likelihood =
              new MultivariateDistributionLikelihood(
                  new MultivariateNormalDistribution(
                      mean.getParameterValues(), precision.getParameterAsMatrix()),
                  transforms);
          cxo = xo.getChild(DATA);
          if (cxo != null) {
            for (int j = 0; j < cxo.getChildCount(); j++) {
              if (cxo.getChild(j) instanceof Parameter) {
                Parameter data = (Parameter) cxo.getChild(j);
                if (data instanceof MatrixParameter) {
                  MatrixParameter matrix = (MatrixParameter) data;
                  if (matrix.getParameter(0).getDimension() != mean.getDimension())
                    throw new XMLParseException(
                        "dim("
                            + data.getStatisticName()
                            + ") = "
                            + matrix.getParameter(0).getDimension()
                            + " is not equal to dim("
                            + mean.getStatisticName()
                            + ") = "
                            + mean.getDimension()
                            + " in "
                            + xo.getName()
                            + "element");

                  for (int i = 0; i < matrix.getParameterCount(); i++) {
                    likelihood.addData(matrix.getParameter(i));
                  }
                } else {
                  if (data.getDimension() != mean.getDimension())
                    throw new XMLParseException(
                        "dim("
                            + data.getStatisticName()
                            + ") = "
                            + data.getDimension()
                            + " is not equal to dim("
                            + mean.getStatisticName()
                            + ") = "
                            + mean.getDimension()
                            + " in "
                            + xo.getName()
                            + "element");
                  likelihood.addData(data);
                }
              } else {
                throw new XMLParseException("illegal element in " + xo.getName() + " element");
              }
            }
          }

          return likelihood;
        }

        public XMLSyntaxRule[] getSyntaxRules() {
          return rules;
        }

        private final XMLSyntaxRule[] rules = {
          new ElementRule(MVN_MEAN, new XMLSyntaxRule[] {new ElementRule(Parameter.class)}),
          new ElementRule(
              MVN_PRECISION, new XMLSyntaxRule[] {new ElementRule(MatrixParameter.class)}),
          new ElementRule(Transform.ParsedTransform.class, 0, Integer.MAX_VALUE),
          new ElementRule(
              DATA,
              new XMLSyntaxRule[] {new ElementRule(Parameter.class, 1, Integer.MAX_VALUE)},
              true)
        };

        public String getParserDescription() {
          return "Calculates the likelihood of some data under a given multivariate-normal distribution.";
        }

        public Class getReturnType() {
          return MultivariateDistributionLikelihood.class;
        }
      };

  public static XMLObjectParser MVGAMMA_PRIOR_PARSER =
      new AbstractXMLObjectParser() {

        public String getParserName() {
          return MVGAMMA_PRIOR;
        }

        public Object parseXMLObject(XMLObject xo) throws XMLParseException {

          double[] shape;
          double[] scale;

          if (xo.hasChildNamed(MVGAMMA_SHAPE)) {

            XMLObject cxo = xo.getChild(MVGAMMA_SHAPE);
            shape = ((Parameter) cxo.getChild(Parameter.class)).getParameterValues();

            cxo = xo.getChild(MVGAMMA_SCALE);
            scale = ((Parameter) cxo.getChild(Parameter.class)).getParameterValues();

            if (shape.length != scale.length)
              throw new XMLParseException(
                  "Shape and scale have wrong dimensions in " + xo.getName() + " element");

          } else {

            XMLObject cxo = xo.getChild(MVN_MEAN);
            double[] mean = ((Parameter) cxo.getChild(Parameter.class)).getParameterValues();

            cxo = xo.getChild(MVN_CV);
            double[] cv = ((Parameter) cxo.getChild(Parameter.class)).getParameterValues();

            if (mean.length != cv.length)
              throw new XMLParseException(
                  "Mean and CV have wrong dimensions in " + xo.getName() + " element");

            final int dim = mean.length;
            shape = new double[dim];
            scale = new double[dim];

            for (int i = 0; i < dim; i++) {
              double c2 = cv[i] * cv[i];
              shape[i] = 1.0 / c2;
              scale[i] = c2 * mean[i];
            }
          }

          MultivariateDistributionLikelihood likelihood =
              new MultivariateDistributionLikelihood(
                  new MultivariateGammaDistribution(shape, scale));
          XMLObject cxo = xo.getChild(DATA);
          for (int j = 0; j < cxo.getChildCount(); j++) {
            if (cxo.getChild(j) instanceof Parameter) {
              Parameter data = (Parameter) cxo.getChild(j);
              likelihood.addData(data);
              if (data.getDimension() != shape.length)
                throw new XMLParseException(
                    "dim("
                        + data.getStatisticName()
                        + ") != "
                        + shape.length
                        + " in "
                        + xo.getName()
                        + "element");
            } else {
              throw new XMLParseException("illegal element in " + xo.getName() + " element");
            }
          }
          return likelihood;
        }

        public XMLSyntaxRule[] getSyntaxRules() {
          return rules;
        }

        private final XMLSyntaxRule[] rules = {
          new XORRule(
              new ElementRule(
                  MVGAMMA_SHAPE, new XMLSyntaxRule[] {new ElementRule(Parameter.class)}),
              new ElementRule(MVN_MEAN, new XMLSyntaxRule[] {new ElementRule(Parameter.class)})),
          new XORRule(
              new ElementRule(
                  MVGAMMA_SCALE, new XMLSyntaxRule[] {new ElementRule(Parameter.class)}),
              new ElementRule(MVN_CV, new XMLSyntaxRule[] {new ElementRule(Parameter.class)})),
          new ElementRule(
              DATA, new XMLSyntaxRule[] {new ElementRule(Parameter.class, 1, Integer.MAX_VALUE)})
        };

        public String getParserDescription() {
          return "Calculates the likelihood of some data under a given multivariate-gamma distribution.";
        }

        public Class getReturnType() {
          return MultivariateDistributionLikelihood.class;
        }
      };

  public static XMLObjectParser TREE_TRAIT_MODEL =
      new AbstractXMLObjectParser() {

        public String getParserName() {
          return TREE_TRAIT_NORMAL;
        }

        public Object parseXMLObject(XMLObject xo) throws XMLParseException {

          boolean conditionOnRoot = xo.getAttribute(CONDITION, false);

          FullyConjugateMultivariateTraitLikelihood traitModel =
              (FullyConjugateMultivariateTraitLikelihood)
                  xo.getChild(FullyConjugateMultivariateTraitLikelihood.class);

          TreeTraitNormalDistributionModel treeTraitModel;

          if (xo.getChild(ROOT_VALUE) != null) {
            XMLObject cxo = xo.getChild(ROOT_VALUE);
            Parameter rootValue = (Parameter) cxo.getChild(Parameter.class);
            treeTraitModel =
                new TreeTraitNormalDistributionModel(traitModel, rootValue, conditionOnRoot);
          } else {
            treeTraitModel = new TreeTraitNormalDistributionModel(traitModel, conditionOnRoot);
          }
          return treeTraitModel;
        }

        public XMLSyntaxRule[] getSyntaxRules() {
          return rules;
        }

        private final XMLSyntaxRule[] rules = {
          AttributeRule.newBooleanRule(CONDITION, true),
          new ElementRule(FullyConjugateMultivariateTraitLikelihood.class)
        };

        public String getParserDescription() {
          return "Parses TreeTraitNormalDistributionModel";
        }

        public Class getReturnType() {
          return TreeTraitNormalDistributionModel.class;
        }
      };

  public static XMLObjectParser TREE_TRAIT_DISTRIBUTION =
      new AbstractXMLObjectParser() {

        public String getParserName() {
          return TREE_TRAIT;
        }

        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
          /*
          boolean conditionOnRoot = xo.getAttribute(CONDITION, false);

          FullyConjugateMultivariateTraitLikelihood traitModel = (FullyConjugateMultivariateTraitLikelihood)
                  xo.getChild(FullyConjugateMultivariateTraitLikelihood.class);
          */

          TreeTraitNormalDistributionModel treeTraitModel =
              (TreeTraitNormalDistributionModel)
                  xo.getChild(TreeTraitNormalDistributionModel.class);

          MultivariateDistributionLikelihood likelihood =
              new MultivariateDistributionLikelihood(
                  //  new TreeTraitNormalDistributionModel(traitModel, conditionOnRoot)
                  treeTraitModel);

          XMLObject cxo = xo.getChild(DATA);
          for (int j = 0; j < cxo.getChildCount(); j++) {
            if (cxo.getChild(j) instanceof Parameter) {
              likelihood.addData((Parameter) cxo.getChild(j));
            } else {
              throw new XMLParseException(
                  "illegal element in " + xo.getName() + " element " + cxo.getName());
            }
          }
          return likelihood;
        }

        public XMLSyntaxRule[] getSyntaxRules() {
          return rules;
        }

        private final XMLSyntaxRule[] rules = {
          //  AttributeRule.newBooleanRule(CONDITION, true),
          //   new ElementRule(FullyConjugateMultivariateTraitLikelihood.class),
          new ElementRule(TreeTraitNormalDistributionModel.class),
          new ElementRule(
              DATA, new XMLSyntaxRule[] {new ElementRule(Parameter.class, 1, Integer.MAX_VALUE)})
        };

        public String getParserDescription() {
          return "Calculates the likelihood of some data under a given multivariate-gamma distribution.";
        }

        public Class getReturnType() {
          return MultivariateDistributionLikelihood.class;
        }
      };
}