@Override public EF_LearningBayesianNetwork initModel( EF_LearningBayesianNetwork bayesianNetwork, PlateuStructure plateuStructure) { for (Variable paramVariable : bayesianNetwork.getParametersVariables().getListOfParamaterVariables()) { if (!paramVariable.isNormalParameter()) continue; // if (paramVariable.getName().contains("_Beta0_")) // continue; EF_Normal prior = bayesianNetwork.getDistribution(paramVariable); double varPrior = 1; double precisionPrior = 1 / varPrior; double meanPrior = 0; prior.getNaturalParameters().set(0, precisionPrior * meanPrior); prior.getNaturalParameters().set(1, -0.5 * precisionPrior); prior.fixNumericalInstability(); prior.updateMomentFromNaturalParameters(); } for (Variable localVar : this.localHiddenVars) { EF_NormalGamma normal = bayesianNetwork.getDistribution(localVar); Variable gammaVar = normal.getGammaParameterVariable(); EF_Gamma gamma = bayesianNetwork.getDistribution(gammaVar); int initVariance = 1; double alpha = 1000; double beta = alpha * initVariance; gamma.getNaturalParameters().set(0, alpha - 1); gamma.getNaturalParameters().set(1, -beta); gamma.fixNumericalInstability(); gamma.updateMomentFromNaturalParameters(); Variable meanVar = normal.getMeanParameterVariable(); EF_Normal meanDist = bayesianNetwork.getDistribution(meanVar); double mean = meanStart; double var = initVariance; meanDist.getNaturalParameters().set(0, mean / (var)); meanDist.getNaturalParameters().set(1, -1 / (2 * var)); meanDist.fixNumericalInstability(); meanDist.updateMomentFromNaturalParameters(); } return bayesianNetwork; }
/**
 * Rewrites the priors of the given learning network from the posteriors held by the plateau
 * structure, then optionally applies fading, and returns that same network.
 *
 * <p>For each variable in {@code localHiddenVars}, the posterior obtained through
 * {@code plateuStructure.getEFVariablePosterior(variable, 0)} is converted to a univariate
 * {@code Normal}. Its variance plus {@code transtionVariance} becomes the new variance, and its
 * mean becomes the new mean, for the gamma and mean parameter distributions of the variable's
 * {@code EF_NormalGamma}.
 *
 * <p>When {@code fading < 1.0}, the natural parameters of base distribution 0 of every
 * parameter variable are multiplied by {@code fading}.
 *
 * @param bayesianNetwork the learning network whose distributions are modified in place
 * @param plateuStructure source of the posterior for each local hidden variable
 * @return the same {@code bayesianNetwork} instance that was passed in
 */
@Override
public EF_LearningBayesianNetwork transitionModel(
    EF_LearningBayesianNetwork bayesianNetwork, PlateuStructure plateuStructure) {

  for (Variable hiddenVar : this.localHiddenVars) {
    Normal previousPosterior =
        plateuStructure.getEFVariablePosterior(hiddenVar, 0).toUnivariateDistribution();

    EF_NormalGamma hiddenDist = bayesianNetwork.getDistribution(hiddenVar);

    Variable gammaParam = hiddenDist.getGammaParameterVariable();
    EF_Gamma gammaDist = bayesianNetwork.getDistribution(gammaParam);

    // Previous posterior variance widened by the configured transition variance.
    double nextVariance = previousPosterior.getVariance() + this.transtionVariance;
    double shape = 1000;
    double rate = shape * nextVariance;

    // Natural parameters written as (alpha - 1, -beta).
    gammaDist.getNaturalParameters().set(0, shape - 1);
    gammaDist.getNaturalParameters().set(1, -rate);
    gammaDist.fixNumericalInstability();
    gammaDist.updateMomentFromNaturalParameters();

    Variable meanParam = hiddenDist.getMeanParameterVariable();
    EF_Normal meanParamDist = bayesianNetwork.getDistribution(meanParam);

    double nextMean = previousPosterior.getMean();

    // Natural parameters written as (mean / variance, -1 / (2 * variance)).
    meanParamDist.getNaturalParameters().set(0, nextMean / nextVariance);
    meanParamDist.getNaturalParameters().set(1, -1 / (2 * nextVariance));
    meanParamDist.fixNumericalInstability();
    meanParamDist.updateMomentFromNaturalParameters();
  }

  // Fading: scale the natural parameters of each parameter variable's base distribution 0.
  if (fading < 1.0) {
    for (Variable paramVar :
        bayesianNetwork.getParametersVariables().getListOfParamaterVariables()) {
      EF_BaseDistribution_MultinomialParents conditional =
          bayesianNetwork.getDistribution(paramVar);
      EF_UnivariateDistribution basePrior = conditional.getBaseEFUnivariateDistribution(0);

      NaturalParameters scaled = basePrior.getNaturalParameters();
      scaled.multiplyBy(fading);

      basePrior.setNaturalParameters(scaled);
      conditional.setBaseEFDistribution(0, basePrior);
    }
  }

  return bayesianNetwork;
}