/**
 * Adds the weights of {@code other} into this message, element by element.
 *
 * <p>Each entry becomes {@code getWeight(i) + other.getWeight(i)}; entries are visited from the
 * last index down to zero.
 *
 * @param other message whose weights are added to this one; it is checked via
 *     {@code assertSameSize} before anything is modified.
 */
public void addWeightsFrom(DiscreteMessage other) {
  assertSameSize(other.size());
  int index = _message.length - 1;
  while (index >= 0) {
    final double combined = getWeight(index) + other.getWeight(index);
    setWeight(index, combined);
    index--;
  }
}
/**
 * Copies the contents of another message of the same size into this one.
 *
 * <p>The other message's raw representation is copied using whichever form it currently
 * stores: weights or energies. Its normalization energy is then copied as well.
 *
 * @param other is another message with the same {@link #size()} as this one but not necessarily
 *     the same representation.
 * @since 0.08
 * @throws IllegalArgumentException if {@code other} does not have the same size.
 */
public void setFrom(DiscreteMessage other) {
  final double[] source = other.representation();
  if (!other.storesWeights()) {
    setEnergies(source);
  } else {
    setWeights(source);
  }
  _normalizationEnergy = other._normalizationEnergy;
}
/** * {@inheritDoc} * * <p>Discrete messages compute KL using: * * <blockquote> * * <big>Σ</big> ln(P<sub>i</sub> / Q<sub>i</sub>) P<sub>i</sub> * * </blockquote> */ @Override public double computeKLDivergence(IParameterizedMessage that) { if (that instanceof DiscreteMessage) { // KL(P|Q) == sum(log(Pi/Qi) * Pi) // // To normalize you need to divide Pi by sum(Pi) and Qi by sum(Qi), denote these // by Ps and Qs: // // ==> sum(log((Pi/Ps)/(Qi/Qs)) * Pi/Ps) // // ==> 1/Ps * sum(log(Pi/Qi) * Pi + log(Qs/Ps) * Pi) // // ==> sum(Pi*(log(Pi) - log(Qi)))/Ps + log(Qs/Ps) // // This formulation allows you to perform the computation using a single loop. final DiscreteMessage P = this; final DiscreteMessage Q = (DiscreteMessage) that; final int size = P.size(); if (size != Q.size()) { throw new IllegalArgumentException( String.format("Mismatched domain sizes '%d' and '%d'", P.size(), Q.size())); } double Ps = 0.0, Qs = 0.0, unnormalizedKL = 0.0; for (int i = 0; i < size; ++i) { final double pw = P.getWeight(i); if (pw == 0.0) continue; final double qw = Q.getWeight(i); Ps += pw; Qs += qw; final double pe = P.getEnergy(i); final double qe = Q.getEnergy(i); unnormalizedKL += pw * (qe - pe); } return unnormalizedKL / Ps + Math.log(Qs / Ps); } throw new IllegalArgumentException( String.format("Expected '%s' but got '%s'", getClass(), that.getClass())); }