diff --git a/HISTORY.md b/HISTORY.md
index 039ac6bb9..d15d490be 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -1,3 +1,11 @@
+# 0.41.2
+
+Add `GibbsConditional`, a "sampler" that can be used to provide analytically known conditional posteriors in a Gibbs sampler.
+
+In Gibbs sampling, each variable is sampled by a component sampler while the other variables are held fixed at their current values. Typically one might, for example, take turns sampling one variable with HMC and another with a particle sampler. Sometimes, however, the posterior distribution of a variable is known analytically, given the current values of the other variables. `GibbsConditional` provides a way to implement such analytically known conditional posteriors and use them as component samplers for `Gibbs`. See the docstring of `GibbsConditional` for details.
+
+Note that `GibbsConditional` existed in Turing.jl until v0.36, at which point it was removed when the whole Gibbs sampler was rewritten. This release reintroduces the same functionality, though with a slightly different interface.
+
 # 0.41.1

 The `ModeResult` struct returned by `maximum_a_posteriori` and `maximum_likelihood` can now be wrapped in `InitFromParams()`.
diff --git a/Project.toml b/Project.toml
index cb7b1cb72..2faa78ce5 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "Turing"
 uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
-version = "0.41.1"
+version = "0.41.2"

 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
diff --git a/docs/src/api.md b/docs/src/api.md
index 885d587ea..3e097da23 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -63,6 +63,7 @@ even though [`Prior()`](@ref) is actually defined in the `Turing.Inference` modu
 | `Emcee` | [`Turing.Inference.Emcee`](@ref) | Affine-invariant ensemble sampler |
 | `ESS` | [`Turing.Inference.ESS`](@ref) | Elliptical slice sampling |
 | `Gibbs` | [`Turing.Inference.Gibbs`](@ref) | Gibbs sampling |
+| `GibbsConditional` | [`Turing.Inference.GibbsConditional`](@ref) | Gibbs sampling with analytical conditional posterior distributions |
 | `HMC` | [`Turing.Inference.HMC`](@ref) | Hamiltonian Monte Carlo |
 | `SGLD` | [`Turing.Inference.SGLD`](@ref) | Stochastic gradient Langevin dynamics |
 | `SGHMC` | [`Turing.Inference.SGHMC`](@ref) | Stochastic gradient Hamiltonian Monte Carlo |
diff --git a/src/Turing.jl b/src/Turing.jl
index 58a58eb2a..d808ea5df 100644
--- a/src/Turing.jl
+++ b/src/Turing.jl
@@ -102,6 +102,7 @@ export
     Emcee,
     ESS,
     Gibbs,
+    GibbsConditional,
     HMC,
     SGLD,
     SGHMC,
diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl
index 7d25ecd7e..0147bca08 100644
--- a/src/mcmc/Inference.jl
+++ b/src/mcmc/Inference.jl
@@ -56,6 +56,7 @@ export Hamiltonian,
     ESS,
     Emcee,
     Gibbs, # classic sampling
+    GibbsConditional, # conditional sampling
     HMC,
     SGLD,
     PolynomialStepsize,
@@ -430,6 +431,7 @@ include("mh.jl")
 include("is.jl")
 include("particle_mcmc.jl")
 include("gibbs.jl")
+include("gibbs_conditional.jl")
 include("sghmc.jl")
 include("emcee.jl")
 include("prior.jl")
diff --git a/src/mcmc/gibbs_conditional.jl b/src/mcmc/gibbs_conditional.jl
new file mode 100644
index 000000000..8586a002d
--- /dev/null
+++ b/src/mcmc/gibbs_conditional.jl
@@ -0,0 +1,171 @@
+"""
+    GibbsConditional(get_cond_dists)
+
+A Gibbs component sampler that samples variables according to user-provided analytical
+conditional posterior distributions.
+
+When using Gibbs sampling, sometimes one may know the analytical form of the posterior for
+a given variable, given the conditioned values of the other variables. In such cases one can
+use `GibbsConditional` as a component sampler to sample from these known conditionals
+directly, avoiding any MCMC methods. One does so with
+
+```julia
+sampler = Gibbs(
+    (@varname(var1), @varname(var2)) => GibbsConditional(get_cond_dists),
+    # other samplers go here...
+)
+```
+
+Here `get_cond_dists(c::Dict{<:VarName})` should be a function that takes a `Dict` mapping
+the conditioned variables (anything other than `var1` and `var2`) to their values, and
+returns the conditional posterior distributions for `var1` and `var2`. You may, of course,
+have any number of variables being sampled as a block in this manner; we use two only as an
+example. The return value of `get_cond_dists` should be one of the following:
+- A single `Distribution`, if only one variable is being sampled.
+- An `AbstractDict{<:VarName,<:Distribution}` that maps the variables being sampled to their
+  conditional posteriors, e.g. `Dict(@varname(var1) => dist1, @varname(var2) => dist2)`.
+- A `NamedTuple` of `Distribution`s, which is like the `AbstractDict` case but can be used
+  if all the variable names are single `Symbol`s, and may be more performant. E.g.
+  `(; var1=dist1, var2=dist2)`.
+
+# Examples
+
+```julia
+# Define a model
+@model function inverse_gdemo(x)
+    precision ~ Gamma(2, inv(3))
+    std = sqrt(1 / precision)
+    m ~ Normal(0, std)
+    for i in eachindex(x)
+        x[i] ~ Normal(m, std)
+    end
+end
+
+# Define analytical conditionals. See
+# https://en.wikipedia.org/wiki/Conjugate_prior#When_likelihood_function_is_a_continuous_distribution
+function cond_precision(c)
+    a = 2.0
+    b = 3.0
+    # We use AbstractPPL.getvalue instead of indexing into `c` directly to guard against
+    # issues where e.g. you try to get `c[@varname(x[1])]` but only `@varname(x)` is present
+    # in `c`. `getvalue` handles that gracefully; `getindex` doesn't. In this case
+    # `getindex` would suffice, but `getvalue` is good practice.
+    m = AbstractPPL.getvalue(c, @varname(m))
+    x = AbstractPPL.getvalue(c, @varname(x))
+    n = length(x)
+    a_new = a + (n + 1) / 2
+    b_new = b + sum(abs2, x .- m) / 2 + m^2 / 2
+    return Gamma(a_new, 1 / b_new)
+end
+
+function cond_m(c)
+    precision = AbstractPPL.getvalue(c, @varname(precision))
+    x = AbstractPPL.getvalue(c, @varname(x))
+    n = length(x)
+    m_mean = sum(x) / (n + 1)
+    m_var = 1 / (precision * (n + 1))
+    return Normal(m_mean, sqrt(m_var))
+end
+
+# Sample using GibbsConditional
+model = inverse_gdemo([1.0, 2.0, 3.0])
+chain = sample(model, Gibbs(
+    :precision => GibbsConditional(cond_precision),
+    :m => GibbsConditional(cond_m)
+), 1000)
+```
+"""
+struct GibbsConditional{C} <: AbstractSampler
+    get_cond_dists::C
+end
+
+isgibbscomponent(::GibbsConditional) = true
+
+"""
+    build_variable_dict(model::DynamicPPL.Model)
+
+Traverse the context stack of `model` and build a `Dict` of all the variable values that are
+set in GibbsContext, ConditionContext, or FixedContext.
+"""
+function build_variable_dict(model::DynamicPPL.Model)
+    context = model.context
+    cond_vals = DynamicPPL.conditioned(context)
+    fixed_vals = DynamicPPL.fixed(context)
+    # TODO(mhauru) Can we avoid invlinking all the time?
+    global_vi = DynamicPPL.invlink(get_gibbs_global_varinfo(context), model)
+    # TODO(mhauru) This creates a lot of Dicts, which are then immediately merged into one.
+    # Also, DynamicPPL.to_varname_dict is known to be inefficient. Make a more efficient
+    # implementation.
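+    # The merged Dict maps VarNames to everything the user's conditionals may depend on:
+    # values from the Gibbs global varinfo (variables handled by other component samplers),
+    # conditioned values, fixed values, and the model arguments.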
+ return merge( + DynamicPPL.values_as(global_vi, Dict), + DynamicPPL.to_varname_dict(cond_vals), + DynamicPPL.to_varname_dict(fixed_vals), + DynamicPPL.to_varname_dict(model.args), + ) +end + +function get_gibbs_global_varinfo(context::DynamicPPL.AbstractContext) + return if context isa GibbsContext + get_global_varinfo(context) + elseif DynamicPPL.NodeTrait(context) isa DynamicPPL.IsParent + get_gibbs_global_varinfo(DynamicPPL.childcontext(context)) + else + msg = """No GibbsContext found in context stack. Are you trying to use \ + GibbsConditional outside of Gibbs? + """ + throw(ArgumentError(msg)) + end +end + +function initialstep( + ::Random.AbstractRNG, + model::DynamicPPL.Model, + ::GibbsConditional, + vi::DynamicPPL.AbstractVarInfo; + kwargs..., +) + state = DynamicPPL.is_transformed(vi) ? DynamicPPL.invlink(vi, model) : vi + # Since GibbsConditional is only used within Gibbs, it does not need to return a + # transition. + return nothing, state +end + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + sampler::GibbsConditional, + state::DynamicPPL.AbstractVarInfo; + kwargs..., +) + # Get all the conditioned variable values from the model context. This is assumed to + # include a GibbsContext as part of the context stack. + condvals = build_variable_dict(model) + conddists = sampler.get_cond_dists(condvals) + + # We support three different kinds of return values for `sample.get_cond_dists`, to make + # life easier for the user. + if conddists isa AbstractDict + for (vn, dist) in conddists + state = setindex!!(state, rand(rng, dist), vn) + end + elseif conddists isa NamedTuple + for (vn_sym, dist) in pairs(conddists) + vn = VarName{vn_sym}() + state = setindex!!(state, rand(rng, dist), vn) + end + else + # Single variable case + vn = only(keys(state)) + state = setindex!!(state, rand(rng, conddists), vn) + end + + # Since GibbsConditional is only used within Gibbs, it does not need to return a + # transition. + return nothing, state +end + +function setparams_varinfo!!( + ::DynamicPPL.Model, ::GibbsConditional, ::Any, params::DynamicPPL.AbstractVarInfo +) + return params +end diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 1e3d5856c..d02f94982 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -496,7 +496,7 @@ end @testset "dynamic model with analytical posterior" begin # A dynamic model where b ~ Bernoulli determines the dimensionality - # When b=0: single parameter θ₁ + # When b=0: single parameter θ₁ # When b=1: two parameters θ₁, θ₂ where we observe their sum @model function dynamic_bernoulli_normal(y_obs=2.0) b ~ Bernoulli(0.3) @@ -575,7 +575,7 @@ end # end # end # sample(f(), Gibbs(:a => PG(10), :x => MH()), 1000) - # + # # because the number of observations in each particle depends on the value # of `a`. # diff --git a/test/mcmc/gibbs_conditional.jl b/test/mcmc/gibbs_conditional.jl new file mode 100644 index 000000000..07d676df1 --- /dev/null +++ b/test/mcmc/gibbs_conditional.jl @@ -0,0 +1,294 @@ +module GibbsConditionalTests + +using DynamicPPL: DynamicPPL +using Random: Random +using StableRNGs: StableRNG +using Test: @test, @test_throws, @testset +using Turing + +@testset "GibbsConditional" begin + @testset "Gamma model tests" begin + @model function inverse_gdemo(x) + precision ~ Gamma(2, inv(3)) + std = sqrt(1 / precision) + m ~ Normal(0, std) + for i in 1:length(x) + x[i] ~ Normal(m, std) + end + end + + # Define analytical conditionals. 
See + # https://en.wikipedia.org/wiki/Conjugate_prior#When_likelihood_function_is_a_continuous_distribution + function cond_precision(c) + a = 2.0 + b = 3.0 + m = c[@varname(m)] + x = c[@varname(x)] + n = length(x) + a_new = a + (n + 1) / 2 + b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2 + return Gamma(a_new, 1 / b_new) + end + + function cond_m(c) + precision = c[@varname(precision)] + x = c[@varname(x)] + n = length(x) + m_mean = sum(x) / (n + 1) + m_var = 1 / (precision * (n + 1)) + return Normal(m_mean, sqrt(m_var)) + end + + x_obs = [1.0, 2.0, 3.0, 2.5, 1.5] + model = inverse_gdemo(x_obs) + + reference_sampler = NUTS() + reference_chain = sample(StableRNG(23), model, reference_sampler, 10_000) + + # Use both conditionals, check results against reference sampler. + sampler = Gibbs( + :precision => GibbsConditional(cond_precision), :m => GibbsConditional(cond_m) + ) + chain = sample(StableRNG(23), model, sampler, 1_000) + @test size(chain, 1) == 1_000 + @test mean(chain, :precision) ≈ mean(reference_chain, :precision) atol = 0.1 + @test mean(chain, :m) ≈ mean(reference_chain, :m) atol = 0.1 + + # Mix GibbsConditional with an MCMC sampler + sampler = Gibbs(:precision => GibbsConditional(cond_precision), :m => MH()) + chain = sample(StableRNG(23), model, sampler, 1_000) + @test size(chain, 1) == 1_000 + @test mean(chain, :precision) ≈ mean(reference_chain, :precision) atol = 0.1 + @test mean(chain, :m) ≈ mean(reference_chain, :m) atol = 0.1 + + sampler = Gibbs(:m => GibbsConditional(cond_m), :precision => HMC(0.1, 10)) + chain = sample(StableRNG(23), model, sampler, 1_000) + @test size(chain, 1) == 1_000 + @test mean(chain, :precision) ≈ mean(reference_chain, :precision) atol = 0.1 + @test mean(chain, :m) ≈ mean(reference_chain, :m) atol = 0.1 + + # Block sample, sampling the same variable with multiple component samplers. + sampler = Gibbs( + (:precision, :m) => HMC(0.1, 10), + :m => GibbsConditional(cond_m), + :precision => MH(), + :precision => GibbsConditional(cond_precision), + :precision => GibbsConditional(cond_precision), + :precision => HMC(0.1, 10), + :m => GibbsConditional(cond_m), + :m => PG(10), + ) + chain = sample(StableRNG(23), model, sampler, 1_000) + @test size(chain, 1) == 1_000 + @test mean(chain, :precision) ≈ mean(reference_chain, :precision) atol = 0.1 + @test mean(chain, :m) ≈ mean(reference_chain, :m) atol = 0.1 + end + + @testset "Simple normal model" begin + @model function simple_normal(dim) + mean ~ Normal(0, 10) + var ~ truncated(Normal(1, 1); lower=0.01) + return x ~ MvNormal(fill(mean, dim), I * var) + end + + # Conditional posterior for mean given var and x. See + # https://en.wikipedia.org/wiki/Conjugate_prior#When_likelihood_function_is_a_continuous_distribution + function cond_mean(c) + var = c[@varname(var)] + x = c[@varname(x)] + n = length(x) + # Prior: mean ~ Normal(0, 10) + # Likelihood: x[i] ~ Normal(mean, σ) + # Posterior: mean ~ Normal(μ_post, σ_post) + prior_var = 100.0 # 10^2 + post_var = 1 / (1 / prior_var + n / var) + post_mean = post_var * (0 / prior_var + sum(x) / var) + return Normal(post_mean, sqrt(post_var)) + end + + dim = 1_000 + true_mean = 2.0 + x_obs = randn(StableRNG(23), dim) .+ true_mean + model = simple_normal(dim) | (; x=x_obs) + sampler = Gibbs(:mean => GibbsConditional(cond_mean), :var => MH()) + chain = sample(StableRNG(23), model, sampler, 1_000) + # The correct posterior mean isn't true_mean, but it is very close, because we + # have a lot of data. 
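+        # (Concretely, the conjugate posterior mean computed in `cond_mean` is
+        # post_var * sum(x_obs) / var, which is essentially the sample mean of x_obs
+        # when n is this large.)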
+ @test mean(chain, :mean) ≈ true_mean atol = 0.05 + end + + @testset "Double simple normal" begin + # This is the same model as simple_normal above, but just doubled. + prior_std1 = 10.0 + prior_std2 = 20.0 + @model function double_simple_normal(dim1, dim2) + mean1 ~ Normal(0, prior_std1) + var1 ~ truncated(Normal(1, 1); lower=0.01) + x1 ~ MvNormal(fill(mean1, dim1), I * var1) + + mean2 ~ Normal(0, prior_std2) + var2 ~ truncated(Normal(1, 1); lower=0.01) + x2 ~ MvNormal(fill(mean2, dim2), I * var2) + return nothing + end + + function cond_mean(var, x, prior_std) + n = length(x) + # Prior: mean ~ Normal(0, prior_std) + # Likelihood: x[i] ~ Normal(mean, σ) + # Posterior: mean ~ Normal(μ_post, σ_post) + prior_var = prior_std^2 + post_var = 1 / (1 / prior_var + n / var) + post_mean = post_var * (0 / prior_var + sum(x) / var) + return Normal(post_mean, sqrt(post_var)) + end + + dim1 = 1_000 + true_mean1 = -10.0 + x1_obs = randn(StableRNG(23), dim1) .+ true_mean1 + dim2 = 2_000 + true_mean2 = -20.0 + x2_obs = randn(StableRNG(24), dim2) .+ true_mean2 + base_model = double_simple_normal(dim1, dim2) + + # Test different ways of returning values from the conditional function. + @testset "conditionals return types" begin + # Test using GibbsConditional for both separately. + cond_mean1(c) = cond_mean(c[@varname(var1)], c[@varname(x1)], prior_std1) + cond_mean2(c) = cond_mean(c[@varname(var2)], c[@varname(x2)], prior_std2) + model = base_model | (; x1=x1_obs, x2=x2_obs) + sampler = Gibbs( + :mean1 => GibbsConditional(cond_mean1), + :mean2 => GibbsConditional(cond_mean2), + (:var1, :var2) => HMC(0.1, 10), + ) + chain = sample(StableRNG(23), model, sampler, 1_000) + # The correct posterior mean isn't true_mean, but it is very close, because we + # have a lot of data. + @test mean(chain, :mean1) ≈ true_mean1 atol = 0.1 + @test mean(chain, :mean2) ≈ true_mean2 atol = 0.1 + + # Test using GibbsConditional for both in a block, returning a Dict. + function cond_mean_dict(c) + return Dict( + @varname(mean1) => + cond_mean(c[@varname(var1)], c[@varname(x1)], prior_std1), + @varname(mean2) => + cond_mean(c[@varname(var2)], c[@varname(x2)], prior_std2), + ) + end + sampler = Gibbs( + (:mean1, :mean2) => GibbsConditional(cond_mean_dict), + (:var1, :var2) => HMC(0.1, 10), + ) + chain = sample(StableRNG(23), model, sampler, 1_000) + @test mean(chain, :mean1) ≈ true_mean1 atol = 0.1 + @test mean(chain, :mean2) ≈ true_mean2 atol = 0.1 + + # As above but with a NamedTuple rather than a Dict. + function cond_mean_nt(c) + return (; + mean1=cond_mean(c[@varname(var1)], c[@varname(x1)], prior_std1), + mean2=cond_mean(c[@varname(var2)], c[@varname(x2)], prior_std2), + ) + end + sampler = Gibbs( + (:mean1, :mean2) => GibbsConditional(cond_mean_nt), + (:var1, :var2) => HMC(0.1, 10), + ) + chain = sample(StableRNG(23), model, sampler, 1_000) + @test mean(chain, :mean1) ≈ true_mean1 atol = 0.1 + @test mean(chain, :mean2) ≈ true_mean2 atol = 0.1 + end + + # Test simultaneously conditioning and fixing variables. + @testset "condition and fix" begin + # Note that fixed variables don't contribute to the likelihood, and hence the + # conditional posterior changes to be just the prior. 
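+            # Below, x1 is fixed and x2 is conditioned, so mean1's conditional reduces to
+            # its prior while mean2's remains the usual conjugate posterior.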
+            model_condition_fix = condition(fix(base_model; x1=x1_obs); x2=x2_obs)
+            function cond_mean1(c)
+                @assert @varname(var1) in keys(c)
+                @assert @varname(x1) in keys(c)
+                return Normal(0.0, prior_std1)
+            end
+            cond_mean2(c) = cond_mean(c[@varname(var2)], c[@varname(x2)], prior_std2)
+            sampler = Gibbs(
+                :mean1 => GibbsConditional(cond_mean1),
+                :mean2 => GibbsConditional(cond_mean2),
+                :var1 => HMC(0.1, 10),
+                :var2 => HMC(0.1, 10),
+            )
+            chain = sample(StableRNG(23), model_condition_fix, sampler, 10_000)
+            @test mean(chain, :mean1) ≈ 0.0 atol = 0.1
+            @test mean(chain, :mean2) ≈ true_mean2 atol = 0.1
+
+            # As above, but reverse the order of condition and fix.
+            model_fix_condition = fix(condition(base_model; x2=x2_obs); x1=x1_obs)
+            chain = sample(StableRNG(23), model_fix_condition, sampler, 10_000)
+            @test mean(chain, :mean1) ≈ 0.0 atol = 0.1
+            @test mean(chain, :mean2) ≈ true_mean2 atol = 0.1
+        end
+    end
+
+    # Check that GibbsConditional works with VarNames with IndexLenses.
+    @testset "Indexed VarNames" begin
+        # This example is statistically nonsense, it only tests that the values returned by
+        # `conditionals` are passed through correctly.
+        @model function f()
+            a = Vector{Float64}(undef, 3)
+            a[1] ~ Normal(0.0)
+            a[2] ~ Normal(10.0)
+            a[3] ~ Normal(20.0)
+            b = Vector{Float64}(undef, 3)
+            # These priors will be completely ignored in the sampling.
+            b[1] ~ Normal()
+            b[2] ~ Normal()
+            b[3] ~ Normal()
+            return nothing
+        end
+
+        m = f()
+        function conditionals_b(c)
+            d1 = Normal(c[@varname(a[1])], 1)
+            d2 = Normal(c[@varname(a[2])], 1)
+            d3 = Normal(c[@varname(a[3])], 1)
+            return Dict(@varname(b[1]) => d1, @varname(b[2]) => d2, @varname(b[3]) => d3)
+        end
+
+        sampler = Gibbs(
+            (@varname(b[1]), @varname(b[2]), @varname(b[3])) =>
+                GibbsConditional(conditionals_b),
+            (@varname(a[1]), @varname(a[2]), @varname(a[3])) => ESS(),
+        )
+        chain = sample(StableRNG(23), m, sampler, 10_000)
+        @test mean(chain, Symbol("b[1]")) ≈ 0.0 atol = 0.05
+        @test mean(chain, Symbol("b[2]")) ≈ 10.0 atol = 0.05
+        @test mean(chain, Symbol("b[3]")) ≈ 20.0 atol = 0.05
+
+        m_condfix = fix(
+            condition(m, Dict(@varname(a[1]) => 100.0)), Dict(@varname(a[2]) => 200.0)
+        )
+        sampler = Gibbs(
+            (@varname(b[1]), @varname(b[2]), @varname(b[3])) =>
+                GibbsConditional(conditionals_b),
+            @varname(a[3]) => ESS(),
+        )
+        chain = sample(StableRNG(23), m_condfix, sampler, 10_000)
+        @test mean(chain, Symbol("b[1]")) ≈ 100.0 atol = 0.05
+        @test mean(chain, Symbol("b[2]")) ≈ 200.0 atol = 0.05
+        @test mean(chain, Symbol("b[3]")) ≈ 20.0 atol = 0.05
+    end
+
+    @testset "Helpful error outside Gibbs" begin
+        @model f() = x ~ Normal()
+        m = f()
+        cond_x(_) = Normal()
+        sampler = GibbsConditional(cond_x)
+        @test_throws(
+            "Are you trying to use GibbsConditional outside of Gibbs?",
+            sample(m, sampler, 3),
+        )
+    end
+end
+
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 81b4bdde2..ce4c7d166 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -53,6 +53,7 @@ end
 @timeit TIMEROUTPUT "inference" begin
     @testset "inference with samplers" verbose = true begin
         @timeit_include("mcmc/gibbs.jl")
+        @timeit_include("mcmc/gibbs_conditional.jl")
         @timeit_include("mcmc/hmc.jl")
         @timeit_include("mcmc/Inference.jl")
         @timeit_include("mcmc/sghmc.jl")