コード例 #1
0
  /**
   * Adds the weights of another message into this one, element by element.
   *
   * <p>The size check is delegated to {@code assertSameSize}, which runs before any element of
   * this message is modified.
   *
   * @param other message whose weights are added to the corresponding weights of this message;
   *     expected to have the same {@link #size()} as this one.
   */
  public void addWeightsFrom(DiscreteMessage other) {
    assertSameSize(other.size());

    // Visit indices from the highest down to zero.
    int index = _message.length - 1;
    while (index >= 0) {
      final double sum = getWeight(index) + other.getWeight(index);
      setWeight(index, sum);
      index--;
    }
  }
コード例 #2
0
  /**
   * Copies the contents of another message of the same size into this one.
   *
   * <p>The values are copied in whichever representation {@code other} currently uses (weights or
   * energies), and its normalization energy is copied as well.
   *
   * @param other is another message with the same {@link #size()} as this one but not necessarily
   *     the same representation.
   * @throws IllegalArgumentException if {@code other} does not have the same size.
   * @since 0.08
   */
  public void setFrom(DiscreteMessage other) {
    final double[] values = other.representation();

    if (!other.storesWeights()) {
      setEnergies(values);
    } else {
      setWeights(values);
    }

    _normalizationEnergy = other._normalizationEnergy;
  }
コード例 #3
0
  /**
   * {@inheritDoc}
   *
   * <p>Discrete messages compute KL using:
   *
   * <blockquote>
   *
   * <big>&Sigma;</big> ln(P<sub>i</sub> / Q<sub>i</sub>) P<sub>i</sub>
   *
   * </blockquote>
   *
   * <p>where both P and Q are first normalized to sum to one. Terms with P<sub>i</sub> == 0
   * contribute nothing (0 ln 0 is taken as 0), but Q<sub>i</sub> at those indices still counts
   * toward Q's normalizer. If every weight of this message is zero the result is NaN.
   *
   * @throws IllegalArgumentException if {@code that} is not a {@code DiscreteMessage} or does not
   *     have the same size as this message.
   */
  @Override
  public double computeKLDivergence(IParameterizedMessage that) {
    if (that instanceof DiscreteMessage) {
      // KL(P|Q) == sum(log(Pi/Qi) * Pi)
      //
      // To normalize you need to divide Pi by Ps = sum(Pi) and Qi by Qs = sum(Qi), where both
      // sums run over ALL indices:
      //
      //  ==> sum(log((Pi/Ps)/(Qi/Qs)) * Pi/Ps)
      //
      //  ==> 1/Ps * sum(log(Pi/Qi) * Pi + log(Qs/Ps) * Pi)
      //
      //  ==> sum(Pi*(log(Pi) - log(Qi)))/Ps + log(Qs/Ps)
      //
      // This formulation allows you to perform the computation using a single loop.

      final DiscreteMessage P = this;
      final DiscreteMessage Q = (DiscreteMessage) that;

      final int size = P.size();

      if (size != Q.size()) {
        throw new IllegalArgumentException(
            String.format("Mismatched domain sizes '%d' and '%d'", P.size(), Q.size()));
      }

      double Ps = 0.0, Qs = 0.0, unnormalizedKL = 0.0;

      for (int i = 0; i < size; ++i) {
        // Qs is Q's normalizer and must include every entry, even where Pi is zero; skipping
        // those entries would under-estimate Qs and bias the log(Qs/Ps) term.
        Qs += Q.getWeight(i);

        final double pw = P.getWeight(i);
        if (pw == 0.0) {
          // Only the KL summand vanishes when Pi == 0 (0 * log 0 == 0 by convention).
          continue;
        }

        Ps += pw;

        // log(Pi) - log(Qi) is computed as Qe - Pe, which relies on energy == -log(weight).
        final double pe = P.getEnergy(i);
        final double qe = Q.getEnergy(i);

        unnormalizedKL += pw * (qe - pe);
      }

      return unnormalizedKL / Ps + Math.log(Qs / Ps);
    }

    throw new IllegalArgumentException(
        String.format("Expected '%s' but got '%s'", getClass(), that.getClass()));
  }