From c0158ead74a35b620781ae691f73e1c98550e5c6 Mon Sep 17 00:00:00 2001 From: AoifeHughes Date: Thu, 7 Aug 2025 09:08:25 +0100 Subject: [PATCH 01/22] Add GibbsConditional sampler and corresponding tests --- src/mcmc/Inference.jl | 2 + src/mcmc/gibbs_conditional.jl | 245 ++++++++++++++++++++++++++++++++++ test/mcmc/gibbs.jl | 114 ++++++++++++++++ test_gibbs_conditional.jl | 78 +++++++++++ 4 files changed, 439 insertions(+) create mode 100644 src/mcmc/gibbs_conditional.jl create mode 100644 test_gibbs_conditional.jl diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index 0951026aa..35cdc46b5 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -67,6 +67,7 @@ export InferenceAlgorithm, ESS, Emcee, Gibbs, # classic sampling + GibbsConditional, # conditional sampling HMC, SGLD, PolynomialStepsize, @@ -392,6 +393,7 @@ include("mh.jl") include("is.jl") include("particle_mcmc.jl") include("gibbs.jl") +include("gibbs_conditional.jl") include("sghmc.jl") include("emcee.jl") include("prior.jl") diff --git a/src/mcmc/gibbs_conditional.jl b/src/mcmc/gibbs_conditional.jl new file mode 100644 index 000000000..e01e17aff --- /dev/null +++ b/src/mcmc/gibbs_conditional.jl @@ -0,0 +1,245 @@ +using DynamicPPL: VarName +using Random: Random +import AbstractMCMC + +# These functions are defined in gibbs.jl which is loaded before this file + +""" + GibbsConditional(sym::Symbol, conditional) + +A Gibbs sampler component that samples a variable according to a user-provided +analytical conditional distribution. + +The `conditional` function should take a `NamedTuple` of conditioned variables and return +a `Distribution` from which to sample the variable `sym`. + +# Examples + +```julia +# Define a model +@model function inverse_gdemo(x) + λ ~ Gamma(2, 3) + m ~ Normal(0, sqrt(1 / λ)) + for i in 1:length(x) + x[i] ~ Normal(m, sqrt(1 / λ)) + end +end + +# Define analytical conditionals +function cond_λ(c::NamedTuple) + a = 2.0 + b = 3.0 + m = c.m + x = c.x + n = length(x) + a_new = a + (n + 1) / 2 + b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2 + return Gamma(a_new, 1 / b_new) +end + +function cond_m(c::NamedTuple) + λ = c.λ + x = c.x + n = length(x) + m_mean = sum(x) / (n + 1) + m_var = 1 / (λ * (n + 1)) + return Normal(m_mean, sqrt(m_var)) +end + +# Sample using GibbsConditional +model = inverse_gdemo([1.0, 2.0, 3.0]) +chain = sample(model, Gibbs( + :λ => GibbsConditional(:λ, cond_λ), + :m => GibbsConditional(:m, cond_m) +), 1000) +``` +""" +struct GibbsConditional{S,C} <: InferenceAlgorithm + conditional::C + + function GibbsConditional(sym::Symbol, conditional::C) where {C} + return new{sym,C}(conditional) + end +end + +# Mark GibbsConditional as a valid Gibbs component +isgibbscomponent(::GibbsConditional) = true + +""" + DynamicPPL.initialstep(rng, model, sampler::GibbsConditional, vi) + +Initialize the GibbsConditional sampler. +""" +function DynamicPPL.initialstep( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + sampler::DynamicPPL.Sampler{<:GibbsConditional}, + vi::DynamicPPL.AbstractVarInfo; + kwargs..., +) + # GibbsConditional doesn't need any special initialization + # Just return the initial state + return nothing, vi +end + +""" + AbstractMCMC.step(rng, model, sampler::GibbsConditional, state) + +Perform a step of GibbsConditional sampling. 
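+
+For intuition, one update amounts to the following sketch (reusing `cond_λ` from
+the example above; the `VarInfo` bookkeeping is omitted and the values are made
+up). Note that the `n + 1` in `cond_λ` counts the `n` observations plus `m`
+itself, since `m` also has precision `λ` under its prior:
+
+```julia
+condvals = (m = 0.1, x = [1.0, 2.0, 3.0])  # current values of all other variables
+conddist = cond_λ(condvals)                # user-supplied full conditional for λ
+λ_new = rand(conddist)                     # the Gibbs update for λ
+```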
+""" +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + sampler::DynamicPPL.Sampler{<:GibbsConditional{S}}, + state::DynamicPPL.AbstractVarInfo; + kwargs..., +) where {S} + alg = sampler.alg + + # For GibbsConditional within Gibbs, we need to get all variable values + # Check if we're in a Gibbs context + global_vi = if hasproperty(model, :context) && model.context isa GibbsContext + # We're in a Gibbs context, get the global varinfo + get_global_varinfo(model.context) + else + # We're not in a Gibbs context, use the current state + state + end + + # Extract conditioned values as a NamedTuple + # Include both random variables and observed data + condvals_vars = DynamicPPL.values_as(DynamicPPL.invlink(global_vi, model), NamedTuple) + condvals_obs = NamedTuple{keys(model.args)}(model.args) + condvals = merge(condvals_vars, condvals_obs) + + # Get the conditional distribution + conddist = alg.conditional(condvals) + + # Sample from the conditional distribution + updated = rand(rng, conddist) + + # Update the variable in state + # We need to get the actual VarName for this variable + # The symbol S tells us which variable to update + vn = VarName{S}() + + # Check if the variable needs to be a vector + new_vi = if haskey(state, vn) + # Update the existing variable + DynamicPPL.setindex!!(state, updated, vn) + else + # Try to find the variable with indices + # This handles cases where the variable might have indices + local updated_vi = state + found = false + for key in keys(state) + if DynamicPPL.getsym(key) == S + updated_vi = DynamicPPL.setindex!!(state, updated, key) + found = true + break + end + end + if !found + error("Could not find variable $S in VarInfo") + end + updated_vi + end + + # Update log joint probability + new_vi = last(DynamicPPL.evaluate!!(model, new_vi, DynamicPPL.DefaultContext())) + + return nothing, new_vi +end + +""" + setparams_varinfo!!(model, sampler::GibbsConditional, state, params::AbstractVarInfo) + +Update the variable info with new parameters for GibbsConditional. +""" +function setparams_varinfo!!( + model::DynamicPPL.Model, + sampler::DynamicPPL.Sampler{<:GibbsConditional}, + state, + params::DynamicPPL.AbstractVarInfo, +) + # For GibbsConditional, we just return the params as-is since + # the state is nothing and we don't need to update anything + return params +end + +""" + gibbs_initialstep_recursive( + rng, model, sampler::GibbsConditional, target_varnames, global_vi, prev_state + ) + +Initialize the GibbsConditional sampler. +""" +function gibbs_initialstep_recursive( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + sampler_wrapped::DynamicPPL.Sampler{<:GibbsConditional}, + target_varnames::AbstractVector{<:VarName}, + global_vi::DynamicPPL.AbstractVarInfo, + prev_state, +) + # GibbsConditional doesn't need any special initialization + # Just perform one sampling step + return gibbs_step_recursive( + rng, model, sampler_wrapped, target_varnames, global_vi, nothing + ) +end + +""" + gibbs_step_recursive( + rng, model, sampler::GibbsConditional, target_varnames, global_vi, state + ) + +Perform a single step of GibbsConditional sampling. 
+""" +function gibbs_step_recursive( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + sampler_wrapped::DynamicPPL.Sampler{<:GibbsConditional{S}}, + target_varnames::AbstractVector{<:VarName}, + global_vi::DynamicPPL.AbstractVarInfo, + state, +) where {S} + sampler = sampler_wrapped.alg + + # Extract conditioned values as a NamedTuple + # Include both random variables and observed data + condvals_vars = DynamicPPL.values_as(DynamicPPL.invlink(global_vi, model), NamedTuple) + condvals_obs = NamedTuple{keys(model.args)}(model.args) + condvals = merge(condvals_vars, condvals_obs) + + # Get the conditional distribution + conddist = sampler.conditional(condvals) + + # Sample from the conditional distribution + updated = rand(rng, conddist) + + # Update the variable in global_vi + # We need to get the actual VarName for this variable + # The symbol S tells us which variable to update + vn = VarName{S}() + + # Check if the variable needs to be a vector + if haskey(global_vi, vn) + # Update the existing variable + global_vi = DynamicPPL.setindex!!(global_vi, updated, vn) + else + # Try to find the variable with indices + # This handles cases where the variable might have indices + for key in keys(global_vi) + if DynamicPPL.getsym(key) == S + global_vi = DynamicPPL.setindex!!(global_vi, updated, key) + break + end + end + end + + # Update log joint probability + global_vi = last(DynamicPPL.evaluate!!(model, global_vi, DynamicPPL.DefaultContext())) + + return nothing, global_vi +end diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index f44a9fefc..a7884bb7e 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -882,6 +882,120 @@ end sampler = Gibbs(:w => HMC(0.05, 10)) @test (sample(model, sampler, 10); true) end + + @testset "GibbsConditional" begin + # Test with the inverse gamma example from the issue + @model function inverse_gdemo(x) + λ ~ Gamma(2, 3) + m ~ Normal(0, sqrt(1 / λ)) + for i in 1:length(x) + x[i] ~ Normal(m, sqrt(1 / λ)) + end + end + + # Define analytical conditionals + function cond_λ(c::NamedTuple) + a = 2.0 + b = 3.0 + m = c.m + x = c.x + n = length(x) + a_new = a + (n + 1) / 2 + b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2 + return Gamma(a_new, 1 / b_new) + end + + function cond_m(c::NamedTuple) + λ = c.λ + x = c.x + n = length(x) + m_mean = sum(x) / (n + 1) + m_var = 1 / (λ * (n + 1)) + return Normal(m_mean, sqrt(m_var)) + end + + # Test basic functionality + @testset "basic sampling" begin + Random.seed!(42) + x_obs = [1.0, 2.0, 3.0, 2.5, 1.5] + model = inverse_gdemo(x_obs) + + # Test that GibbsConditional works + sampler = Gibbs(GibbsConditional(:λ, cond_λ), GibbsConditional(:m, cond_m)) + chain = sample(model, sampler, 1000) + + # Check that we got the expected variables + @test :λ in names(chain) + @test :m in names(chain) + + # Check that the values are reasonable + λ_samples = vec(chain[:λ]) + m_samples = vec(chain[:m]) + + # Given the observed data, we expect certain behavior + @test mean(λ_samples) > 0 # λ should be positive + @test minimum(λ_samples) > 0 + @test std(m_samples) < 2.0 # m should be relatively well-constrained + end + + # Test mixing with other samplers + @testset "mixed samplers" begin + Random.seed!(42) + x_obs = [1.0, 2.0, 3.0] + model = inverse_gdemo(x_obs) + + # Mix GibbsConditional with standard samplers + sampler = Gibbs(GibbsConditional(:λ, cond_λ), :m => MH()) + chain = sample(model, sampler, 500) + + @test :λ in names(chain) + @test :m in names(chain) + @test size(chain, 1) == 500 + end + + # Test with a simpler 
model + @testset "simple normal model" begin + @model function simple_normal(x) + μ ~ Normal(0, 10) + σ ~ truncated(Normal(1, 1); lower=0.01) + for i in 1:length(x) + x[i] ~ Normal(μ, σ) + end + end + + # Conditional for μ given σ and x + function cond_μ(c::NamedTuple) + σ = c.σ + x = c.x + n = length(x) + # Prior: μ ~ Normal(0, 10) + # Likelihood: x[i] ~ Normal(μ, σ) + # Posterior: μ ~ Normal(μ_post, σ_post) + prior_var = 100.0 # 10^2 + likelihood_var = σ^2 / n + post_var = 1 / (1 / prior_var + n / σ^2) + post_mean = post_var * (0 / prior_var + sum(x) / σ^2) + return Normal(post_mean, sqrt(post_var)) + end + + Random.seed!(42) + x_obs = randn(10) .+ 2.0 # Data centered around 2 + model = simple_normal(x_obs) + + sampler = Gibbs(GibbsConditional(:μ, cond_μ), :σ => MH()) + + chain = sample(model, sampler, 1000) + + μ_samples = vec(chain[:μ]) + @test abs(mean(μ_samples) - 2.0) < 0.5 # Should be close to true mean + end + + # Test that GibbsConditional is marked as a valid component + @testset "isgibbscomponent" begin + gc = GibbsConditional(:x, c -> Normal(0, 1)) + @test Turing.Inference.isgibbscomponent(gc) + end + end end end diff --git a/test_gibbs_conditional.jl b/test_gibbs_conditional.jl new file mode 100644 index 000000000..d6466e537 --- /dev/null +++ b/test_gibbs_conditional.jl @@ -0,0 +1,78 @@ +using Turing +using Turing.Inference: GibbsConditional +using Distributions +using Random +using Statistics + +# Test with the inverse gamma example from the issue +@model function inverse_gdemo(x) + λ ~ Gamma(2, 3) + m ~ Normal(0, sqrt(1 / λ)) + for i in 1:length(x) + x[i] ~ Normal(m, sqrt(1 / λ)) + end +end + +# Define analytical conditionals +function cond_λ(c::NamedTuple) + a = 2.0 + b = 3.0 + m = c.m + x = c.x + n = length(x) + a_new = a + (n + 1) / 2 + b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2 + return Gamma(a_new, 1 / b_new) +end + +function cond_m(c::NamedTuple) + λ = c.λ + x = c.x + n = length(x) + m_mean = sum(x) / (n + 1) + m_var = 1 / (λ * (n + 1)) + return Normal(m_mean, sqrt(m_var)) +end + +# Generate some observed data +Random.seed!(42) +x_obs = [1.0, 2.0, 3.0, 2.5, 1.5] + +# Create the model +model = inverse_gdemo(x_obs) + +# Sample using GibbsConditional +println("Testing GibbsConditional sampler...") +sampler = Gibbs(:λ => GibbsConditional(:λ, cond_λ), :m => GibbsConditional(:m, cond_m)) + +# Run a short chain to test +chain = sample(model, sampler, 100) + +println("Sampling completed successfully!") +println("\nChain summary:") +println(chain) + +# Extract samples +λ_samples = vec(chain[:λ]) +m_samples = vec(chain[:m]) + +println("\nλ statistics:") +println(" Mean: ", mean(λ_samples)) +println(" Std: ", std(λ_samples)) +println(" Min: ", minimum(λ_samples)) +println(" Max: ", maximum(λ_samples)) + +println("\nm statistics:") +println(" Mean: ", mean(m_samples)) +println(" Std: ", std(m_samples)) +println(" Min: ", minimum(m_samples)) +println(" Max: ", maximum(m_samples)) + +# Test mixing with other samplers +println("\n\nTesting mixed samplers...") +sampler2 = Gibbs(:λ => GibbsConditional(:λ, cond_λ), :m => MH()) + +chain2 = sample(model, sampler2, 100) +println("Mixed sampling completed successfully!") +println("\nMixed chain summary:") +println(chain2) From a972b5a2b54216b2dfb5b5ee2fa807229be081bb Mon Sep 17 00:00:00 2001 From: AoifeHughes Date: Thu, 7 Aug 2025 09:13:15 +0100 Subject: [PATCH 02/22] clarified comment --- src/mcmc/gibbs_conditional.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcmc/gibbs_conditional.jl 
b/src/mcmc/gibbs_conditional.jl index e01e17aff..2bf7a7bb5 100644 --- a/src/mcmc/gibbs_conditional.jl +++ b/src/mcmc/gibbs_conditional.jl @@ -2,7 +2,7 @@ using DynamicPPL: VarName using Random: Random import AbstractMCMC -# These functions are defined in gibbs.jl which is loaded before this file +# These functions provide specialized methods for GibbsConditional that extend the generic implementations in gibbs.jl """ GibbsConditional(sym::Symbol, conditional) From c3cc7739cbcae0675ed50bff85dcdd34ddd29c3a Mon Sep 17 00:00:00 2001 From: AoifeHughes Date: Fri, 8 Aug 2025 11:11:32 +0100 Subject: [PATCH 03/22] add MHs suggestions --- src/mcmc/gibbs_conditional.jl | 148 +++++++++------------------------- 1 file changed, 37 insertions(+), 111 deletions(-) diff --git a/src/mcmc/gibbs_conditional.jl b/src/mcmc/gibbs_conditional.jl index 2bf7a7bb5..c2eba05ba 100644 --- a/src/mcmc/gibbs_conditional.jl +++ b/src/mcmc/gibbs_conditional.jl @@ -54,17 +54,45 @@ chain = sample(model, Gibbs( ), 1000) ``` """ -struct GibbsConditional{S,C} <: InferenceAlgorithm +struct GibbsConditional{C} <: InferenceAlgorithm conditional::C function GibbsConditional(sym::Symbol, conditional::C) where {C} - return new{sym,C}(conditional) + return new{C}(conditional) end end # Mark GibbsConditional as a valid Gibbs component isgibbscomponent(::GibbsConditional) = true +""" + find_global_varinfo(context, fallback_vi) + +Traverse the context stack to find global variable information from +GibbsContext, ConditionContext, FixedContext, etc. +""" +function find_global_varinfo(context, fallback_vi) + # Start with the given context and traverse down + current_context = context + + while current_context !== nothing + if current_context isa GibbsContext + # Found GibbsContext, return its global varinfo + return get_global_varinfo(current_context) + elseif hasproperty(current_context, :childcontext) && + isdefined(DynamicPPL, :childcontext) + # Move to child context if it exists + current_context = DynamicPPL.childcontext(current_context) + else + # No more child contexts + break + end + end + + # If no GibbsContext found, use the fallback + return fallback_vi +end + """ DynamicPPL.initialstep(rng, model, sampler::GibbsConditional, vi) @@ -97,12 +125,10 @@ function AbstractMCMC.step( alg = sampler.alg # For GibbsConditional within Gibbs, we need to get all variable values - # Check if we're in a Gibbs context - global_vi = if hasproperty(model, :context) && model.context isa GibbsContext - # We're in a Gibbs context, get the global varinfo - get_global_varinfo(model.context) + # Traverse the context stack to find all conditioned/fixed/Gibbs variables + global_vi = if hasproperty(model, :context) + find_global_varinfo(model.context, state) else - # We're not in a Gibbs context, use the current state state end @@ -119,34 +145,10 @@ function AbstractMCMC.step( updated = rand(rng, conddist) # Update the variable in state - # We need to get the actual VarName for this variable - # The symbol S tells us which variable to update - vn = VarName{S}() - - # Check if the variable needs to be a vector - new_vi = if haskey(state, vn) - # Update the existing variable - DynamicPPL.setindex!!(state, updated, vn) - else - # Try to find the variable with indices - # This handles cases where the variable might have indices - local updated_vi = state - found = false - for key in keys(state) - if DynamicPPL.getsym(key) == S - updated_vi = DynamicPPL.setindex!!(state, updated, key) - found = true - break - end - end - if !found - error("Could not find 
variable $S in VarInfo") - end - updated_vi - end - - # Update log joint probability - new_vi = last(DynamicPPL.evaluate!!(model, new_vi, DynamicPPL.DefaultContext())) + # The Gibbs sampler ensures that state only contains one variable + # Get the variable name from the keys + varname = first(keys(state)) + new_vi = DynamicPPL.setindex!!(state, updated, varname) return nothing, new_vi end @@ -167,79 +169,3 @@ function setparams_varinfo!!( return params end -""" - gibbs_initialstep_recursive( - rng, model, sampler::GibbsConditional, target_varnames, global_vi, prev_state - ) - -Initialize the GibbsConditional sampler. -""" -function gibbs_initialstep_recursive( - rng::Random.AbstractRNG, - model::DynamicPPL.Model, - sampler_wrapped::DynamicPPL.Sampler{<:GibbsConditional}, - target_varnames::AbstractVector{<:VarName}, - global_vi::DynamicPPL.AbstractVarInfo, - prev_state, -) - # GibbsConditional doesn't need any special initialization - # Just perform one sampling step - return gibbs_step_recursive( - rng, model, sampler_wrapped, target_varnames, global_vi, nothing - ) -end - -""" - gibbs_step_recursive( - rng, model, sampler::GibbsConditional, target_varnames, global_vi, state - ) - -Perform a single step of GibbsConditional sampling. -""" -function gibbs_step_recursive( - rng::Random.AbstractRNG, - model::DynamicPPL.Model, - sampler_wrapped::DynamicPPL.Sampler{<:GibbsConditional{S}}, - target_varnames::AbstractVector{<:VarName}, - global_vi::DynamicPPL.AbstractVarInfo, - state, -) where {S} - sampler = sampler_wrapped.alg - - # Extract conditioned values as a NamedTuple - # Include both random variables and observed data - condvals_vars = DynamicPPL.values_as(DynamicPPL.invlink(global_vi, model), NamedTuple) - condvals_obs = NamedTuple{keys(model.args)}(model.args) - condvals = merge(condvals_vars, condvals_obs) - - # Get the conditional distribution - conddist = sampler.conditional(condvals) - - # Sample from the conditional distribution - updated = rand(rng, conddist) - - # Update the variable in global_vi - # We need to get the actual VarName for this variable - # The symbol S tells us which variable to update - vn = VarName{S}() - - # Check if the variable needs to be a vector - if haskey(global_vi, vn) - # Update the existing variable - global_vi = DynamicPPL.setindex!!(global_vi, updated, vn) - else - # Try to find the variable with indices - # This handles cases where the variable might have indices - for key in keys(global_vi) - if DynamicPPL.getsym(key) == S - global_vi = DynamicPPL.setindex!!(global_vi, updated, key) - break - end - end - end - - # Update log joint probability - global_vi = last(DynamicPPL.evaluate!!(model, global_vi, DynamicPPL.DefaultContext())) - - return nothing, global_vi -end From 714c1e82979e5daa6cb1d005c113c967e9d4647a Mon Sep 17 00:00:00 2001 From: AoifeHughes Date: Fri, 8 Aug 2025 11:11:52 +0100 Subject: [PATCH 04/22] formatter --- src/mcmc/gibbs_conditional.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/mcmc/gibbs_conditional.jl b/src/mcmc/gibbs_conditional.jl index c2eba05ba..fe04b048d 100644 --- a/src/mcmc/gibbs_conditional.jl +++ b/src/mcmc/gibbs_conditional.jl @@ -74,13 +74,13 @@ GibbsContext, ConditionContext, FixedContext, etc. 
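As a sketch of what this helper does, one can walk a context stack by hand with DynamicPPL's `childcontext`/`NodeTrait` API (the same calls the function below relies on); `model` here is assumed to be any `DynamicPPL.Model`:

```julia
using DynamicPPL

# Print each layer of a model's context stack, outermost first.
ctx = model.context
while DynamicPPL.NodeTrait(ctx) isa DynamicPPL.IsParent
    @show typeof(ctx)
    ctx = DynamicPPL.childcontext(ctx)
end
@show typeof(ctx)  # the leaf context
```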
function find_global_varinfo(context, fallback_vi) # Start with the given context and traverse down current_context = context - + while current_context !== nothing if current_context isa GibbsContext # Found GibbsContext, return its global varinfo return get_global_varinfo(current_context) - elseif hasproperty(current_context, :childcontext) && - isdefined(DynamicPPL, :childcontext) + elseif hasproperty(current_context, :childcontext) && + isdefined(DynamicPPL, :childcontext) # Move to child context if it exists current_context = DynamicPPL.childcontext(current_context) else @@ -88,7 +88,7 @@ function find_global_varinfo(context, fallback_vi) break end end - + # If no GibbsContext found, use the fallback return fallback_vi end @@ -168,4 +168,3 @@ function setparams_varinfo!!( # the state is nothing and we don't need to update anything return params end - From 94b723da263927edfef7c20d8e56e543d0d84fc3 Mon Sep 17 00:00:00 2001 From: AoifeHughes Date: Fri, 8 Aug 2025 14:37:27 +0100 Subject: [PATCH 05/22] fixed exporting thing --- src/mcmc/gibbs_conditional.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/mcmc/gibbs_conditional.jl b/src/mcmc/gibbs_conditional.jl index fe04b048d..7415c5f3f 100644 --- a/src/mcmc/gibbs_conditional.jl +++ b/src/mcmc/gibbs_conditional.jl @@ -65,6 +65,9 @@ end # Mark GibbsConditional as a valid Gibbs component isgibbscomponent(::GibbsConditional) = true +# Required methods for Gibbs constructor +Base.length(::GibbsConditional) = 1 # Each GibbsConditional handles one variable + """ find_global_varinfo(context, fallback_vi) From 2058ae54e34111e17441c60ab001ba929284646c Mon Sep 17 00:00:00 2001 From: Aoife Date: Tue, 23 Sep 2025 13:09:12 +0100 Subject: [PATCH 06/22] Refactor Gibbs sampler to use inverse of parameters for Gamma distribution and improve context variable retrieval --- src/mcmc/gibbs_conditional.jl | 56 ++++++++++++--------- test/mcmc/gibbs.jl | 4 +- test_gibbs_conditional.jl | 93 ++++++++++++++++++++++------------- 3 files changed, 92 insertions(+), 61 deletions(-) diff --git a/src/mcmc/gibbs_conditional.jl b/src/mcmc/gibbs_conditional.jl index 7415c5f3f..74c0686b1 100644 --- a/src/mcmc/gibbs_conditional.jl +++ b/src/mcmc/gibbs_conditional.jl @@ -18,7 +18,7 @@ a `Distribution` from which to sample the variable `sym`. ```julia # Define a model @model function inverse_gdemo(x) - λ ~ Gamma(2, 3) + λ ~ Gamma(2, inv(3)) m ~ Normal(0, sqrt(1 / λ)) for i in 1:length(x) x[i] ~ Normal(m, sqrt(1 / λ)) @@ -28,7 +28,7 @@ end # Define analytical conditionals function cond_λ(c::NamedTuple) a = 2.0 - b = 3.0 + b = inv(3) m = c.m x = c.x n = length(x) @@ -75,25 +75,39 @@ Traverse the context stack to find global variable information from GibbsContext, ConditionContext, FixedContext, etc. 
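For context on the change above from `Gamma(2, 3)` to `Gamma(2, inv(3))`: Distributions.jl's `Gamma(α, θ)` takes shape and scale, so expressing a rate-β prior requires scale `inv(β)`. A quick check of the implied prior means:

```julia
using Distributions

mean(Gamma(2, 3))       # 6.0:     shape 2, scale 3
mean(Gamma(2, inv(3)))  # ≈ 0.667: shape 2, scale 1/3, i.e. rate 3
```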
""" function find_global_varinfo(context, fallback_vi) - # Start with the given context and traverse down + # Traverse the entire context stack to find relevant contexts current_context = context + gibbs_context = nothing + condition_context = nothing + fixed_context = nothing while current_context !== nothing - if current_context isa GibbsContext - # Found GibbsContext, return its global varinfo - return get_global_varinfo(current_context) - elseif hasproperty(current_context, :childcontext) && - isdefined(DynamicPPL, :childcontext) - # Move to child context if it exists + # Use NodeTrait for robust context checking + if DynamicPPL.NodeTrait(current_context) isa DynamicPPL.IsParent + if current_context isa GibbsContext + gibbs_context = current_context + elseif current_context isa DynamicPPL.ConditionContext + condition_context = current_context + elseif current_context isa DynamicPPL.FixedContext + fixed_context = current_context + end + # Move to child context current_context = DynamicPPL.childcontext(current_context) else - # No more child contexts break end end - # If no GibbsContext found, use the fallback - return fallback_vi + # Return the most relevant context's varinfo + if gibbs_context !== nothing + return get_global_varinfo(gibbs_context) + elseif condition_context !== nothing + return DynamicPPL.getvarinfo(condition_context) + elseif fixed_context !== nothing + return DynamicPPL.getvarinfo(fixed_context) + else + return fallback_vi + end end """ @@ -121,19 +135,15 @@ Perform a step of GibbsConditional sampling. function AbstractMCMC.step( rng::Random.AbstractRNG, model::DynamicPPL.Model, - sampler::DynamicPPL.Sampler{<:GibbsConditional{S}}, + sampler::DynamicPPL.Sampler{<:GibbsConditional}, state::DynamicPPL.AbstractVarInfo; kwargs..., -) where {S} +) alg = sampler.alg # For GibbsConditional within Gibbs, we need to get all variable values - # Traverse the context stack to find all conditioned/fixed/Gibbs variables - global_vi = if hasproperty(model, :context) - find_global_varinfo(model.context, state) - else - state - end + # Model always has a context field, so we can traverse it directly + global_vi = find_global_varinfo(model.context, state) # Extract conditioned values as a NamedTuple # Include both random variables and observed data @@ -147,11 +157,9 @@ function AbstractMCMC.step( # Sample from the conditional distribution updated = rand(rng, conddist) - # Update the variable in state + # Update the variable in state using unflatten for simplicity # The Gibbs sampler ensures that state only contains one variable - # Get the variable name from the keys - varname = first(keys(state)) - new_vi = DynamicPPL.setindex!!(state, updated, varname) + new_vi = DynamicPPL.unflatten(state, [updated]) return nothing, new_vi end diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index d7c41d70d..a825401a1 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -946,7 +946,7 @@ end @testset "GibbsConditional" begin # Test with the inverse gamma example from the issue @model function inverse_gdemo(x) - λ ~ Gamma(2, 3) + λ ~ Gamma(2, inv(3)) m ~ Normal(0, sqrt(1 / λ)) for i in 1:length(x) x[i] ~ Normal(m, sqrt(1 / λ)) @@ -956,7 +956,7 @@ end # Define analytical conditionals function cond_λ(c::NamedTuple) a = 2.0 - b = 3.0 + b = inv(3) m = c.m x = c.x n = length(x) diff --git a/test_gibbs_conditional.jl b/test_gibbs_conditional.jl index d6466e537..1a01fa9b2 100644 --- a/test_gibbs_conditional.jl +++ b/test_gibbs_conditional.jl @@ -3,10 +3,11 @@ using Turing.Inference: 
GibbsConditional using Distributions using Random using Statistics +using Test # Test with the inverse gamma example from the issue @model function inverse_gdemo(x) - λ ~ Gamma(2, 3) + λ ~ Gamma(2, inv(3)) m ~ Normal(0, sqrt(1 / λ)) for i in 1:length(x) x[i] ~ Normal(m, sqrt(1 / λ)) @@ -16,7 +17,7 @@ end # Define analytical conditionals function cond_λ(c::NamedTuple) a = 2.0 - b = 3.0 + b = inv(3) m = c.m x = c.x n = length(x) @@ -34,45 +35,67 @@ function cond_m(c::NamedTuple) return Normal(m_mean, sqrt(m_var)) end -# Generate some observed data -Random.seed!(42) -x_obs = [1.0, 2.0, 3.0, 2.5, 1.5] +@testset "GibbsConditional Integration Tests" begin + # Generate some observed data + Random.seed!(42) + x_obs = [1.0, 2.0, 3.0, 2.5, 1.5] -# Create the model -model = inverse_gdemo(x_obs) + # Create the model + model = inverse_gdemo(x_obs) -# Sample using GibbsConditional -println("Testing GibbsConditional sampler...") -sampler = Gibbs(:λ => GibbsConditional(:λ, cond_λ), :m => GibbsConditional(:m, cond_m)) + @testset "Basic GibbsConditional sampling" begin + # Sample using GibbsConditional + sampler = Gibbs(:λ => GibbsConditional(:λ, cond_λ), :m => GibbsConditional(:m, cond_m)) -# Run a short chain to test -chain = sample(model, sampler, 100) + # Run a short chain to test + chain = sample(model, sampler, 100) -println("Sampling completed successfully!") -println("\nChain summary:") -println(chain) + # Test that sampling completed successfully + @test chain isa MCMCChains.Chains + @test size(chain, 1) == 100 + @test :λ in names(chain) + @test :m in names(chain) + end + + @testset "Sample statistics" begin + # Generate samples for statistics testing + sampler = Gibbs(:λ => GibbsConditional(:λ, cond_λ), :m => GibbsConditional(:m, cond_m)) + chain = sample(model, sampler, 100) -# Extract samples -λ_samples = vec(chain[:λ]) -m_samples = vec(chain[:m]) + # Extract samples + λ_samples = vec(chain[:λ]) + m_samples = vec(chain[:m]) -println("\nλ statistics:") -println(" Mean: ", mean(λ_samples)) -println(" Std: ", std(λ_samples)) -println(" Min: ", minimum(λ_samples)) -println(" Max: ", maximum(λ_samples)) + # Test λ statistics + @test mean(λ_samples) > 0 # λ should be positive + @test minimum(λ_samples) > 0 # All λ samples should be positive + @test std(λ_samples) > 0 # Should have some variability + @test isfinite(mean(λ_samples)) + @test isfinite(std(λ_samples)) + + # Test m statistics + @test isfinite(mean(m_samples)) + @test isfinite(std(m_samples)) + @test std(m_samples) > 0 # Should have some variability + end -println("\nm statistics:") -println(" Mean: ", mean(m_samples)) -println(" Std: ", std(m_samples)) -println(" Min: ", minimum(m_samples)) -println(" Max: ", maximum(m_samples)) + @testset "Mixed samplers" begin + # Test mixing with other samplers + sampler2 = Gibbs(:λ => GibbsConditional(:λ, cond_λ), :m => MH()) -# Test mixing with other samplers -println("\n\nTesting mixed samplers...") -sampler2 = Gibbs(:λ => GibbsConditional(:λ, cond_λ), :m => MH()) + chain2 = sample(model, sampler2, 100) -chain2 = sample(model, sampler2, 100) -println("Mixed sampling completed successfully!") -println("\nMixed chain summary:") -println(chain2) + # Test that mixed sampling completed successfully + @test chain2 isa MCMCChains.Chains + @test size(chain2, 1) == 100 + @test :λ in names(chain2) + @test :m in names(chain2) + + # Test that values are reasonable + λ_samples2 = vec(chain2[:λ]) + m_samples2 = vec(chain2[:m]) + @test all(λ_samples2 .> 0) # All λ should be positive + @test 
all(isfinite.(λ_samples2)) + @test all(isfinite.(m_samples2)) + end +end From b0812a3bfc3be6e3be11fdcdb06b34a324b4c26c Mon Sep 17 00:00:00 2001 From: AoifeHughes Date: Thu, 25 Sep 2025 09:31:31 +0100 Subject: [PATCH 07/22] removed file added by mistake --- test_gibbs_conditional.jl | 101 -------------------------------------- 1 file changed, 101 deletions(-) delete mode 100644 test_gibbs_conditional.jl diff --git a/test_gibbs_conditional.jl b/test_gibbs_conditional.jl deleted file mode 100644 index 1a01fa9b2..000000000 --- a/test_gibbs_conditional.jl +++ /dev/null @@ -1,101 +0,0 @@ -using Turing -using Turing.Inference: GibbsConditional -using Distributions -using Random -using Statistics -using Test - -# Test with the inverse gamma example from the issue -@model function inverse_gdemo(x) - λ ~ Gamma(2, inv(3)) - m ~ Normal(0, sqrt(1 / λ)) - for i in 1:length(x) - x[i] ~ Normal(m, sqrt(1 / λ)) - end -end - -# Define analytical conditionals -function cond_λ(c::NamedTuple) - a = 2.0 - b = inv(3) - m = c.m - x = c.x - n = length(x) - a_new = a + (n + 1) / 2 - b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2 - return Gamma(a_new, 1 / b_new) -end - -function cond_m(c::NamedTuple) - λ = c.λ - x = c.x - n = length(x) - m_mean = sum(x) / (n + 1) - m_var = 1 / (λ * (n + 1)) - return Normal(m_mean, sqrt(m_var)) -end - -@testset "GibbsConditional Integration Tests" begin - # Generate some observed data - Random.seed!(42) - x_obs = [1.0, 2.0, 3.0, 2.5, 1.5] - - # Create the model - model = inverse_gdemo(x_obs) - - @testset "Basic GibbsConditional sampling" begin - # Sample using GibbsConditional - sampler = Gibbs(:λ => GibbsConditional(:λ, cond_λ), :m => GibbsConditional(:m, cond_m)) - - # Run a short chain to test - chain = sample(model, sampler, 100) - - # Test that sampling completed successfully - @test chain isa MCMCChains.Chains - @test size(chain, 1) == 100 - @test :λ in names(chain) - @test :m in names(chain) - end - - @testset "Sample statistics" begin - # Generate samples for statistics testing - sampler = Gibbs(:λ => GibbsConditional(:λ, cond_λ), :m => GibbsConditional(:m, cond_m)) - chain = sample(model, sampler, 100) - - # Extract samples - λ_samples = vec(chain[:λ]) - m_samples = vec(chain[:m]) - - # Test λ statistics - @test mean(λ_samples) > 0 # λ should be positive - @test minimum(λ_samples) > 0 # All λ samples should be positive - @test std(λ_samples) > 0 # Should have some variability - @test isfinite(mean(λ_samples)) - @test isfinite(std(λ_samples)) - - # Test m statistics - @test isfinite(mean(m_samples)) - @test isfinite(std(m_samples)) - @test std(m_samples) > 0 # Should have some variability - end - - @testset "Mixed samplers" begin - # Test mixing with other samplers - sampler2 = Gibbs(:λ => GibbsConditional(:λ, cond_λ), :m => MH()) - - chain2 = sample(model, sampler2, 100) - - # Test that mixed sampling completed successfully - @test chain2 isa MCMCChains.Chains - @test size(chain2, 1) == 100 - @test :λ in names(chain2) - @test :m in names(chain2) - - # Test that values are reasonable - λ_samples2 = vec(chain2[:λ]) - m_samples2 = vec(chain2[:m]) - @test all(λ_samples2 .> 0) # All λ should be positive - @test all(isfinite.(λ_samples2)) - @test all(isfinite.(m_samples2)) - end -end From d91031205ce308881f652f4f135734243d06eaa9 Mon Sep 17 00:00:00 2001 From: AoifeHughes Date: Mon, 29 Sep 2025 12:48:39 +0100 Subject: [PATCH 08/22] Add safety checks and error handling in find_global_varinfo and AbstractMCMC.step functions --- src/mcmc/gibbs_conditional.jl | 132 
++++++++++++++++++++++++---------- 1 file changed, 95 insertions(+), 37 deletions(-) diff --git a/src/mcmc/gibbs_conditional.jl b/src/mcmc/gibbs_conditional.jl index 74c0686b1..8401d3405 100644 --- a/src/mcmc/gibbs_conditional.jl +++ b/src/mcmc/gibbs_conditional.jl @@ -81,33 +81,56 @@ function find_global_varinfo(context, fallback_vi) condition_context = nothing fixed_context = nothing - while current_context !== nothing - # Use NodeTrait for robust context checking - if DynamicPPL.NodeTrait(current_context) isa DynamicPPL.IsParent - if current_context isa GibbsContext - gibbs_context = current_context - elseif current_context isa DynamicPPL.ConditionContext - condition_context = current_context - elseif current_context isa DynamicPPL.FixedContext - fixed_context = current_context + # Safety check: avoid infinite loops with a maximum depth + max_depth = 20 + depth = 0 + + while current_context !== nothing && depth < max_depth + depth += 1 + + try + # Use NodeTrait for robust context checking + if DynamicPPL.NodeTrait(current_context) isa DynamicPPL.IsParent + if current_context isa GibbsContext + gibbs_context = current_context + elseif current_context isa DynamicPPL.ConditionContext + condition_context = current_context + elseif current_context isa DynamicPPL.FixedContext + fixed_context = current_context + end + # Move to child context + current_context = DynamicPPL.childcontext(current_context) + else + break end - # Move to child context - current_context = DynamicPPL.childcontext(current_context) - else + catch e + # If there's an error traversing contexts, break and use fallback + @debug "Error traversing context at depth $depth: $e" break end end - # Return the most relevant context's varinfo - if gibbs_context !== nothing - return get_global_varinfo(gibbs_context) - elseif condition_context !== nothing - return DynamicPPL.getvarinfo(condition_context) - elseif fixed_context !== nothing - return DynamicPPL.getvarinfo(fixed_context) - else - return fallback_vi + # Return the most relevant context's varinfo with error handling + try + if gibbs_context !== nothing + return get_global_varinfo(gibbs_context) + elseif condition_context !== nothing + # Check if getvarinfo method exists for ConditionContext + if hasmethod(DynamicPPL.getvarinfo, (typeof(condition_context),)) + return DynamicPPL.getvarinfo(condition_context) + end + elseif fixed_context !== nothing + # Check if getvarinfo method exists for FixedContext + if hasmethod(DynamicPPL.getvarinfo, (typeof(fixed_context),)) + return DynamicPPL.getvarinfo(fixed_context) + end + end + catch e + @debug "Error accessing varinfo from context: $e" end + + # Fall back to the provided fallback_vi + return fallback_vi end """ @@ -141,27 +164,62 @@ function AbstractMCMC.step( ) alg = sampler.alg - # For GibbsConditional within Gibbs, we need to get all variable values - # Model always has a context field, so we can traverse it directly - global_vi = find_global_varinfo(model.context, state) + try + # For GibbsConditional within Gibbs, we need to get all variable values + # Model always has a context field, so we can traverse it directly + global_vi = find_global_varinfo(model.context, state) + + # Extract conditioned values as a NamedTuple + # Include both random variables and observed data + # Use a safe approach for invlink to avoid linking conflicts + invlinked_global_vi = try + DynamicPPL.invlink(global_vi, model) + catch e + @debug "Failed to invlink global_vi, using as-is: $e" + global_vi + end - # Extract conditioned values as a 
NamedTuple - # Include both random variables and observed data - condvals_vars = DynamicPPL.values_as(DynamicPPL.invlink(global_vi, model), NamedTuple) - condvals_obs = NamedTuple{keys(model.args)}(model.args) - condvals = merge(condvals_vars, condvals_obs) + condvals_vars = DynamicPPL.values_as(invlinked_global_vi, NamedTuple) + condvals_obs = NamedTuple{keys(model.args)}(model.args) + condvals = merge(condvals_vars, condvals_obs) - # Get the conditional distribution - conddist = alg.conditional(condvals) + # Get the conditional distribution + conddist = alg.conditional(condvals) - # Sample from the conditional distribution - updated = rand(rng, conddist) + # Sample from the conditional distribution + updated = rand(rng, conddist) - # Update the variable in state using unflatten for simplicity - # The Gibbs sampler ensures that state only contains one variable - new_vi = DynamicPPL.unflatten(state, [updated]) + # Update the variable in state, handling linking properly + # The Gibbs sampler ensures that state only contains one variable + state_is_linked = try + DynamicPPL.islinked(state, model) + catch e + @debug "Error checking if state is linked: $e" + false + end - return nothing, new_vi + if state_is_linked + # If state is linked, we need to unlink, update, then relink + try + unlinked_state = DynamicPPL.invlink(state, model) + updated_state = DynamicPPL.unflatten(unlinked_state, [updated]) + new_vi = DynamicPPL.link(updated_state, model) + catch e + @debug "Error in linked state update path: $e, falling back to direct update" + new_vi = DynamicPPL.unflatten(state, [updated]) + end + else + # State is not linked, we can update directly + new_vi = DynamicPPL.unflatten(state, [updated]) + end + + return nothing, new_vi + + catch e + # If there's any error in the step, log it and rethrow + @error "Error in GibbsConditional step: $e" + rethrow(e) + end end """ From 4b1dc2f03ef6b7957789364b6be4de37cc801ecf Mon Sep 17 00:00:00 2001 From: AoifeHughes Date: Thu, 9 Oct 2025 12:16:32 +0100 Subject: [PATCH 09/22] imports? 
--- test/mcmc/gibbs.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index a825401a1..ca60f27b6 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -19,7 +19,7 @@ using StableRNGs: StableRNG using Test: @inferred, @test, @test_broken, @test_throws, @testset using Turing using Turing: Inference -using Turing.Inference: AdvancedHMC, AdvancedMH +using Turing.Inference: AdvancedHMC, AdvancedMH, GibbsConditional using Turing.RandomMeasures: ChineseRestaurantProcess, DirichletProcess function check_transition_varnames(transition::Turing.Inference.Transition, parent_varnames) From a33d8a9c4e7c570fbfeb17f086e57d75e4eb59e4 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 17 Nov 2025 16:06:08 +0000 Subject: [PATCH 10/22] Fixes and improvements for GibbsConditional --- src/Turing.jl | 1 + src/mcmc/gibbs_conditional.jl | 235 ++++++++++++---------------------- test/mcmc/gibbs.jl | 62 +++++---- 3 files changed, 114 insertions(+), 184 deletions(-) diff --git a/src/Turing.jl b/src/Turing.jl index 58a58eb2a..d808ea5df 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -102,6 +102,7 @@ export Emcee, ESS, Gibbs, + GibbsConditional, HMC, SGLD, SGHMC, diff --git a/src/mcmc/gibbs_conditional.jl b/src/mcmc/gibbs_conditional.jl index 8401d3405..ed518b45d 100644 --- a/src/mcmc/gibbs_conditional.jl +++ b/src/mcmc/gibbs_conditional.jl @@ -2,16 +2,22 @@ using DynamicPPL: VarName using Random: Random import AbstractMCMC -# These functions provide specialized methods for GibbsConditional that extend the generic implementations in gibbs.jl - """ - GibbsConditional(sym::Symbol, conditional) + GibbsConditional(conditional) + +A Gibbs component sampler that samples variables according to user-provided +analytical conditional distributions. -A Gibbs sampler component that samples a variable according to a user-provided -analytical conditional distribution. +`conditional` should be a function that takes a `Dict{<:VarName}` of conditioned variables +and their values, and returns one of the following: +- A single `Distribution`, if only one variable is being sampled. +- An `AbstractDict{<:VarName,<:Distribution}` that maps the variables being sampled to their + `Distribution`s. +- A `NamedTuple` of `Distribution`s, which is like the `AbstractDict` case but can be used + if all the variable names are single `Symbol`s, and may be more performant. -The `conditional` function should take a `NamedTuple` of conditioned variables and return -a `Distribution` from which to sample the variable `sym`. +If a Gibbs component is created with `(:var1, :var2) => GibbsConditional(conditional)`, then +`var1` and `var2` should be in the keys of the return value of `conditional`. # Examples @@ -26,20 +32,20 @@ a `Distribution` from which to sample the variable `sym`. 
end # Define analytical conditionals -function cond_λ(c::NamedTuple) +function cond_λ(c) a = 2.0 b = inv(3) - m = c.m - x = c.x + m = c[@varname(m)] + x = c[@varname(x)] n = length(x) a_new = a + (n + 1) / 2 b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2 return Gamma(a_new, 1 / b_new) end -function cond_m(c::NamedTuple) - λ = c.λ - x = c.x +function cond_m(c) + λ = c[@varname(λ)] + x = c[@varname(x)] n = length(x) m_mean = sum(x) / (n + 1) m_var = 1 / (λ * (n + 1)) @@ -49,88 +55,52 @@ end # Sample using GibbsConditional model = inverse_gdemo([1.0, 2.0, 3.0]) chain = sample(model, Gibbs( - :λ => GibbsConditional(:λ, cond_λ), - :m => GibbsConditional(:m, cond_m) + :λ => GibbsConditional(cond_λ), + :m => GibbsConditional(cond_m) ), 1000) ``` """ -struct GibbsConditional{C} <: InferenceAlgorithm +struct GibbsConditional{C} <: AbstractSampler conditional::C - - function GibbsConditional(sym::Symbol, conditional::C) where {C} - return new{C}(conditional) - end end # Mark GibbsConditional as a valid Gibbs component isgibbscomponent(::GibbsConditional) = true -# Required methods for Gibbs constructor -Base.length(::GibbsConditional) = 1 # Each GibbsConditional handles one variable - """ - find_global_varinfo(context, fallback_vi) + build_variable_dict(model::DynamicPPL.Model) -Traverse the context stack to find global variable information from -GibbsContext, ConditionContext, FixedContext, etc. +Traverse the context stack of `model` and build a `Dict` of all the variable values that are +set in GibbsContext, ConditionContext, or FixedContext. """ -function find_global_varinfo(context, fallback_vi) - # Traverse the entire context stack to find relevant contexts - current_context = context - gibbs_context = nothing - condition_context = nothing - fixed_context = nothing - - # Safety check: avoid infinite loops with a maximum depth - max_depth = 20 - depth = 0 - - while current_context !== nothing && depth < max_depth - depth += 1 - - try - # Use NodeTrait for robust context checking - if DynamicPPL.NodeTrait(current_context) isa DynamicPPL.IsParent - if current_context isa GibbsContext - gibbs_context = current_context - elseif current_context isa DynamicPPL.ConditionContext - condition_context = current_context - elseif current_context isa DynamicPPL.FixedContext - fixed_context = current_context - end - # Move to child context - current_context = DynamicPPL.childcontext(current_context) - else - break - end - catch e - # If there's an error traversing contexts, break and use fallback - @debug "Error traversing context at depth $depth: $e" - break - end - end +function build_variable_dict(model::DynamicPPL.Model) + context = model.context + cond_nt = DynamicPPL.conditioned(context) + fixed_nt = DynamicPPL.fixed(context) + # TODO(mhauru) Can we avoid invlinking all the time? + global_vi = DynamicPPL.invlink(get_gibbs_global_varinfo(context), model) + # TODO(mhauru) Double-check that the ordered of precedence here is correct. Should we + # in fact error if there is any overlap in the keys? 
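+    # For example, if the model was conditioned as `model | (; x = [1.0, 2.0])` while
+    # the Gibbs-global varinfo currently holds λ == 0.5, the merged Dict below maps
+    # @varname(x) => [1.0, 2.0] and @varname(λ) => 0.5. On key collisions the second
+    # argument of `merge` wins, i.e. conditioned/fixed/argument values override the
+    # global varinfo.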
+ return merge( + DynamicPPL.values_as(global_vi, Dict), + Dict( + (DynamicPPL.VarName{sym}() => val for (sym, val) in pairs(cond_nt))..., + (DynamicPPL.VarName{sym}() => val for (sym, val) in pairs(fixed_nt))..., + (DynamicPPL.VarName{sym}() => val for (sym, val) in pairs(model.args))..., + ), + ) +end - # Return the most relevant context's varinfo with error handling - try - if gibbs_context !== nothing - return get_global_varinfo(gibbs_context) - elseif condition_context !== nothing - # Check if getvarinfo method exists for ConditionContext - if hasmethod(DynamicPPL.getvarinfo, (typeof(condition_context),)) - return DynamicPPL.getvarinfo(condition_context) - end - elseif fixed_context !== nothing - # Check if getvarinfo method exists for FixedContext - if hasmethod(DynamicPPL.getvarinfo, (typeof(fixed_context),)) - return DynamicPPL.getvarinfo(fixed_context) - end - end - catch e - @debug "Error accessing varinfo from context: $e" +function get_gibbs_global_varinfo(context::DynamicPPL.AbstractContext) + return if context isa GibbsContext + get_global_varinfo(context) + elseif DynamicPPL.NodeTrait(context) isa DynamicPPL.IsParent + get_gibbs_global_varinfo(DynamicPPL.childcontext(context)) + else + throw(ArgumentError("""No GibbsContext found in context stack. \ + Are you trying to use GibbsConditional outside of Gibbs? + """)) end - - # Fall back to the provided fallback_vi - return fallback_vi end """ @@ -138,16 +108,17 @@ end Initialize the GibbsConditional sampler. """ -function DynamicPPL.initialstep( - rng::Random.AbstractRNG, +function initialstep( + ::Random.AbstractRNG, model::DynamicPPL.Model, - sampler::DynamicPPL.Sampler{<:GibbsConditional}, + ::GibbsConditional, vi::DynamicPPL.AbstractVarInfo; kwargs..., ) - # GibbsConditional doesn't need any special initialization - # Just return the initial state - return nothing, vi + state = DynamicPPL.is_transformed(vi) ? DynamicPPL.invlink(vi, model) : vi + # Since GibbsConditional is only used within Gibbs, it does not need to return a + # transition. + return nothing, state end """ @@ -158,82 +129,44 @@ Perform a step of GibbsConditional sampling. function AbstractMCMC.step( rng::Random.AbstractRNG, model::DynamicPPL.Model, - sampler::DynamicPPL.Sampler{<:GibbsConditional}, + sampler::GibbsConditional, state::DynamicPPL.AbstractVarInfo; kwargs..., ) - alg = sampler.alg - - try - # For GibbsConditional within Gibbs, we need to get all variable values - # Model always has a context field, so we can traverse it directly - global_vi = find_global_varinfo(model.context, state) - - # Extract conditioned values as a NamedTuple - # Include both random variables and observed data - # Use a safe approach for invlink to avoid linking conflicts - invlinked_global_vi = try - DynamicPPL.invlink(global_vi, model) - catch e - @debug "Failed to invlink global_vi, using as-is: $e" - global_vi + # Get all the conditioned variable values from the model context. This is assumed to + # include a GibbsContext as part of the context stack. + condvals = build_variable_dict(model) + + # Get the conditional distributions + conddists = sampler.conditional(condvals) + + # We support three different kinds of return values for `sample.conditional`, to make + # life easier for the user. 
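+    # E.g. `Dict(@varname(m) => Normal(0, 1))`, `(m = Normal(0, 1),)`, or a bare
+    # `Normal(0, 1)` when the component updates a single variable.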
+ if conddists isa AbstractDict + for (vn, dist) in conddists + state = setindex!!(state, rand(rng, dist), vn) end - - condvals_vars = DynamicPPL.values_as(invlinked_global_vi, NamedTuple) - condvals_obs = NamedTuple{keys(model.args)}(model.args) - condvals = merge(condvals_vars, condvals_obs) - - # Get the conditional distribution - conddist = alg.conditional(condvals) - - # Sample from the conditional distribution - updated = rand(rng, conddist) - - # Update the variable in state, handling linking properly - # The Gibbs sampler ensures that state only contains one variable - state_is_linked = try - DynamicPPL.islinked(state, model) - catch e - @debug "Error checking if state is linked: $e" - false + elseif conddists isa NamedTuple + for (vn_sym => dist) in pairs(conddists) + vn = VarName{vn_sym}() + state = setindex!!(state, rand(rng, dist), vn) end - - if state_is_linked - # If state is linked, we need to unlink, update, then relink - try - unlinked_state = DynamicPPL.invlink(state, model) - updated_state = DynamicPPL.unflatten(unlinked_state, [updated]) - new_vi = DynamicPPL.link(updated_state, model) - catch e - @debug "Error in linked state update path: $e, falling back to direct update" - new_vi = DynamicPPL.unflatten(state, [updated]) - end - else - # State is not linked, we can update directly - new_vi = DynamicPPL.unflatten(state, [updated]) - end - - return nothing, new_vi - - catch e - # If there's any error in the step, log it and rethrow - @error "Error in GibbsConditional step: $e" - rethrow(e) + else + # Single variable case + vn = first(keys(state)) + state = setindex!!(state, rand(rng, conddists), vn) end -end -""" - setparams_varinfo!!(model, sampler::GibbsConditional, state, params::AbstractVarInfo) + # Since GibbsConditional is only used within Gibbs, it does not need to return a + # transition. + return nothing, state +end -Update the variable info with new parameters for GibbsConditional. -""" function setparams_varinfo!!( model::DynamicPPL.Model, - sampler::DynamicPPL.Sampler{<:GibbsConditional}, + sampler::GibbsConditional, state, params::DynamicPPL.AbstractVarInfo, ) - # For GibbsConditional, we just return the params as-is since - # the state is nothing and we don't need to update anything return params end diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 1b6291bf0..8ea2234bf 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -496,7 +496,7 @@ end @testset "dynamic model with analytical posterior" begin # A dynamic model where b ~ Bernoulli determines the dimensionality - # When b=0: single parameter θ₁ + # When b=0: single parameter θ₁ # When b=1: two parameters θ₁, θ₂ where we observe their sum @model function dynamic_bernoulli_normal(y_obs=2.0) b ~ Bernoulli(0.3) @@ -575,7 +575,7 @@ end # end # end # sample(f(), Gibbs(:a => PG(10), :x => MH()), 1000) - # + # # because the number of observations in each particle depends on the value # of `a`. 
# @@ -925,20 +925,20 @@ end end # Define analytical conditionals - function cond_λ(c::NamedTuple) + function cond_λ(c) a = 2.0 b = inv(3) - m = c.m - x = c.x + m = c[@varname(m)] + x = c[@varname(x)] n = length(x) a_new = a + (n + 1) / 2 b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2 return Gamma(a_new, 1 / b_new) end - function cond_m(c::NamedTuple) - λ = c.λ - x = c.x + function cond_m(c) + λ = c[@varname(λ)] + x = c[@varname(x)] n = length(x) m_mean = sum(x) / (n + 1) m_var = 1 / (λ * (n + 1)) @@ -952,7 +952,7 @@ end model = inverse_gdemo(x_obs) # Test that GibbsConditional works - sampler = Gibbs(GibbsConditional(:λ, cond_λ), GibbsConditional(:m, cond_m)) + sampler = Gibbs(:λ => GibbsConditional(cond_λ), :m => GibbsConditional(cond_m)) chain = sample(model, sampler, 1000) # Check that we got the expected variables @@ -971,12 +971,11 @@ end # Test mixing with other samplers @testset "mixed samplers" begin - Random.seed!(42) x_obs = [1.0, 2.0, 3.0] model = inverse_gdemo(x_obs) # Mix GibbsConditional with standard samplers - sampler = Gibbs(GibbsConditional(:λ, cond_λ), :m => MH()) + sampler = Gibbs(:λ => GibbsConditional(cond_λ), :m => MH()) chain = sample(model, sampler, 500) @test :λ in names(chain) @@ -986,44 +985,41 @@ end # Test with a simpler model @testset "simple normal model" begin - @model function simple_normal(x) + @model function simple_normal(dim) μ ~ Normal(0, 10) - σ ~ truncated(Normal(1, 1); lower=0.01) - for i in 1:length(x) - x[i] ~ Normal(μ, σ) - end + σ2 ~ truncated(Normal(1, 1); lower=0.01) + return x ~ MvNormal(fill(μ, dim), I * σ2) end # Conditional for μ given σ and x - function cond_μ(c::NamedTuple) - σ = c.σ - x = c.x + function cond_μ(c) + σ2 = c[@varname(σ2)] + x = c[@varname(x)] n = length(x) # Prior: μ ~ Normal(0, 10) # Likelihood: x[i] ~ Normal(μ, σ) # Posterior: μ ~ Normal(μ_post, σ_post) prior_var = 100.0 # 10^2 - likelihood_var = σ^2 / n - post_var = 1 / (1 / prior_var + n / σ^2) - post_mean = post_var * (0 / prior_var + sum(x) / σ^2) + post_var = 1 / (1 / prior_var + n / σ2) + post_mean = post_var * (0 / prior_var + sum(x) / σ2) return Normal(post_mean, sqrt(post_var)) end - Random.seed!(42) - x_obs = randn(10) .+ 2.0 # Data centered around 2 - model = simple_normal(x_obs) - - sampler = Gibbs(GibbsConditional(:μ, cond_μ), :σ => MH()) - - chain = sample(model, sampler, 1000) - - μ_samples = vec(chain[:μ]) - @test abs(mean(μ_samples) - 2.0) < 0.5 # Should be close to true mean + rng = StableRNG(23) + dim = 10_000 + true_mean = 2.0 + x_obs = randn(rng, dim) .+ true_mean + model = simple_normal(dim) | (; x=x_obs) + sampler = Gibbs(:μ => GibbsConditional(cond_μ), :σ2 => MH()) + chain = sample(rng, model, sampler, 1_000) + # The correct posterior mean isn't true_mean, but it is very close, because we + # have a lot of data. 
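+            # (Shrinkage toward the prior mean 0 is a factor of n / (n + σ2 / 100),
+            # negligible at n = 10_000, so the posterior mean essentially equals the
+            # sample mean, whose standard error is about 0.01; hence atol = 0.02.)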
+ @test mean(chain, :μ) ≈ true_mean atol = 0.02 end # Test that GibbsConditional is marked as a valid component @testset "isgibbscomponent" begin - gc = GibbsConditional(:x, c -> Normal(0, 1)) + gc = GibbsConditional(c -> Normal(0, 1)) @test Turing.Inference.isgibbscomponent(gc) end end From f41fc6e41c34574d7136c268e1c7b4a14154bcf6 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 17 Nov 2025 16:35:21 +0000 Subject: [PATCH 11/22] Move GibbsConditional tests to their own file --- test/mcmc/gibbs.jl | 110 ------------------------------ test/mcmc/gibbs_conditional.jl | 120 +++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 3 files changed, 121 insertions(+), 110 deletions(-) create mode 100644 test/mcmc/gibbs_conditional.jl diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 8ea2234bf..c2df731b0 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -913,116 +913,6 @@ end sampler = Gibbs(:w => HMC(0.05, 10)) @test (sample(model, sampler, 10); true) end - - @testset "GibbsConditional" begin - # Test with the inverse gamma example from the issue - @model function inverse_gdemo(x) - λ ~ Gamma(2, inv(3)) - m ~ Normal(0, sqrt(1 / λ)) - for i in 1:length(x) - x[i] ~ Normal(m, sqrt(1 / λ)) - end - end - - # Define analytical conditionals - function cond_λ(c) - a = 2.0 - b = inv(3) - m = c[@varname(m)] - x = c[@varname(x)] - n = length(x) - a_new = a + (n + 1) / 2 - b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2 - return Gamma(a_new, 1 / b_new) - end - - function cond_m(c) - λ = c[@varname(λ)] - x = c[@varname(x)] - n = length(x) - m_mean = sum(x) / (n + 1) - m_var = 1 / (λ * (n + 1)) - return Normal(m_mean, sqrt(m_var)) - end - - # Test basic functionality - @testset "basic sampling" begin - Random.seed!(42) - x_obs = [1.0, 2.0, 3.0, 2.5, 1.5] - model = inverse_gdemo(x_obs) - - # Test that GibbsConditional works - sampler = Gibbs(:λ => GibbsConditional(cond_λ), :m => GibbsConditional(cond_m)) - chain = sample(model, sampler, 1000) - - # Check that we got the expected variables - @test :λ in names(chain) - @test :m in names(chain) - - # Check that the values are reasonable - λ_samples = vec(chain[:λ]) - m_samples = vec(chain[:m]) - - # Given the observed data, we expect certain behavior - @test mean(λ_samples) > 0 # λ should be positive - @test minimum(λ_samples) > 0 - @test std(m_samples) < 2.0 # m should be relatively well-constrained - end - - # Test mixing with other samplers - @testset "mixed samplers" begin - x_obs = [1.0, 2.0, 3.0] - model = inverse_gdemo(x_obs) - - # Mix GibbsConditional with standard samplers - sampler = Gibbs(:λ => GibbsConditional(cond_λ), :m => MH()) - chain = sample(model, sampler, 500) - - @test :λ in names(chain) - @test :m in names(chain) - @test size(chain, 1) == 500 - end - - # Test with a simpler model - @testset "simple normal model" begin - @model function simple_normal(dim) - μ ~ Normal(0, 10) - σ2 ~ truncated(Normal(1, 1); lower=0.01) - return x ~ MvNormal(fill(μ, dim), I * σ2) - end - - # Conditional for μ given σ and x - function cond_μ(c) - σ2 = c[@varname(σ2)] - x = c[@varname(x)] - n = length(x) - # Prior: μ ~ Normal(0, 10) - # Likelihood: x[i] ~ Normal(μ, σ) - # Posterior: μ ~ Normal(μ_post, σ_post) - prior_var = 100.0 # 10^2 - post_var = 1 / (1 / prior_var + n / σ2) - post_mean = post_var * (0 / prior_var + sum(x) / σ2) - return Normal(post_mean, sqrt(post_var)) - end - - rng = StableRNG(23) - dim = 10_000 - true_mean = 2.0 - x_obs = randn(rng, dim) .+ true_mean - model = simple_normal(dim) | (; x=x_obs) - sampler = 
Gibbs(:μ => GibbsConditional(cond_μ), :σ2 => MH()) - chain = sample(rng, model, sampler, 1_000) - # The correct posterior mean isn't true_mean, but it is very close, because we - # have a lot of data. - @test mean(chain, :μ) ≈ true_mean atol = 0.02 - end - - # Test that GibbsConditional is marked as a valid component - @testset "isgibbscomponent" begin - gc = GibbsConditional(c -> Normal(0, 1)) - @test Turing.Inference.isgibbscomponent(gc) - end - end end end diff --git a/test/mcmc/gibbs_conditional.jl b/test/mcmc/gibbs_conditional.jl new file mode 100644 index 000000000..6868f4f2c --- /dev/null +++ b/test/mcmc/gibbs_conditional.jl @@ -0,0 +1,120 @@ +module GibbsConditionalTests + +using Distributions: InverseGamma, Normal +using Distributions: sample +using DynamicPPL: DynamicPPL +using Random: Random +using StableRNGs: StableRNG +using Test: @inferred, @test, @test_broken, @test_throws, @testset +using Turing + +@testset "GibbsConditional" begin + @model function inverse_gdemo(x) + λ ~ Gamma(2, inv(3)) + m ~ Normal(0, sqrt(1 / λ)) + for i in 1:length(x) + x[i] ~ Normal(m, sqrt(1 / λ)) + end + end + + # Define analytical conditionals + function cond_λ(c) + a = 2.0 + b = inv(3) + m = c[@varname(m)] + x = c[@varname(x)] + n = length(x) + a_new = a + (n + 1) / 2 + b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2 + return Gamma(a_new, 1 / b_new) + end + + function cond_m(c) + λ = c[@varname(λ)] + x = c[@varname(x)] + n = length(x) + m_mean = sum(x) / (n + 1) + m_var = 1 / (λ * (n + 1)) + return Normal(m_mean, sqrt(m_var)) + end + + # Test basic functionality + @testset "basic sampling" begin + Random.seed!(42) + x_obs = [1.0, 2.0, 3.0, 2.5, 1.5] + model = inverse_gdemo(x_obs) + + # Test that GibbsConditional works + sampler = Gibbs(:λ => GibbsConditional(cond_λ), :m => GibbsConditional(cond_m)) + chain = sample(model, sampler, 1000) + + # Check that we got the expected variables + @test :λ in names(chain) + @test :m in names(chain) + + # Check that the values are reasonable + λ_samples = vec(chain[:λ]) + m_samples = vec(chain[:m]) + + # Given the observed data, we expect certain behavior + @test mean(λ_samples) > 0 # λ should be positive + @test minimum(λ_samples) > 0 + @test std(m_samples) < 2.0 # m should be relatively well-constrained + end + + # Test mixing with other samplers + @testset "mixed samplers" begin + x_obs = [1.0, 2.0, 3.0] + model = inverse_gdemo(x_obs) + + # Mix GibbsConditional with standard samplers + sampler = Gibbs(:λ => GibbsConditional(cond_λ), :m => MH()) + chain = sample(model, sampler, 500) + + @test :λ in names(chain) + @test :m in names(chain) + @test size(chain, 1) == 500 + end + + # Test with a simpler model + @testset "simple normal model" begin + @model function simple_normal(dim) + μ ~ Normal(0, 10) + σ2 ~ truncated(Normal(1, 1); lower=0.01) + return x ~ MvNormal(fill(μ, dim), I * σ2) + end + + # Conditional for μ given σ and x + function cond_μ(c) + σ2 = c[@varname(σ2)] + x = c[@varname(x)] + n = length(x) + # Prior: μ ~ Normal(0, 10) + # Likelihood: x[i] ~ Normal(μ, σ) + # Posterior: μ ~ Normal(μ_post, σ_post) + prior_var = 100.0 # 10^2 + post_var = 1 / (1 / prior_var + n / σ2) + post_mean = post_var * (0 / prior_var + sum(x) / σ2) + return Normal(post_mean, sqrt(post_var)) + end + + rng = StableRNG(23) + dim = 10_000 + true_mean = 2.0 + x_obs = randn(rng, dim) .+ true_mean + model = simple_normal(dim) | (; x=x_obs) + sampler = Gibbs(:μ => GibbsConditional(cond_μ), :σ2 => MH()) + chain = sample(rng, model, sampler, 1_000) + # The correct posterior 
mean isn't true_mean, but it is very close, because we + # have a lot of data. + @test mean(chain, :μ) ≈ true_mean atol = 0.02 + end + + # Test that GibbsConditional is marked as a valid component + @testset "isgibbscomponent" begin + gc = GibbsConditional(c -> Normal(0, 1)) + @test Turing.Inference.isgibbscomponent(gc) + end +end + +end diff --git a/test/runtests.jl b/test/runtests.jl index 81b4bdde2..ce4c7d166 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -53,6 +53,7 @@ end @timeit TIMEROUTPUT "inference" begin @testset "inference with samplers" verbose = true begin @timeit_include("mcmc/gibbs.jl") + @timeit_include("mcmc/gibbs_conditional.jl") @timeit_include("mcmc/hmc.jl") @timeit_include("mcmc/Inference.jl") @timeit_include("mcmc/sghmc.jl") From bd5ff0b588305fbbd18c29f1d8eb0e4f81c1dd03 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 17 Nov 2025 17:34:52 +0000 Subject: [PATCH 12/22] More GibbsConditional tests --- src/mcmc/gibbs_conditional.jl | 5 +- test/mcmc/gibbs_conditional.jl | 288 ++++++++++++++++++++++++--------- 2 files changed, 219 insertions(+), 74 deletions(-) diff --git a/src/mcmc/gibbs_conditional.jl b/src/mcmc/gibbs_conditional.jl index ed518b45d..cee1dd66a 100644 --- a/src/mcmc/gibbs_conditional.jl +++ b/src/mcmc/gibbs_conditional.jl @@ -77,7 +77,8 @@ function build_variable_dict(model::DynamicPPL.Model) context = model.context cond_nt = DynamicPPL.conditioned(context) fixed_nt = DynamicPPL.fixed(context) - # TODO(mhauru) Can we avoid invlinking all the time? + # TODO(mhauru) Can we avoid invlinking all the time? Note that this causes a model + # evaluation, which may be expensive. global_vi = DynamicPPL.invlink(get_gibbs_global_varinfo(context), model) # TODO(mhauru) Double-check that the ordered of precedence here is correct. Should we # in fact error if there is any overlap in the keys? 
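The `Dict` assembled by `build_variable_dict` above is exactly what a user-supplied conditional receives: the current values of all other random variables, merged with any conditioned, fixed, and observed values. As a hedged sketch, for the `inverse_gdemo` example used throughout this series, the callback would be handed roughly the following (the numbers are illustrative only):

```julia
using DynamicPPL: @varname

# Roughly what cond_m would be called with for inverse_gdemo([1.0, 2.0, 3.0]):
c = Dict(
    @varname(λ) => 0.4,              # current value of the other random variable
    @varname(x) => [1.0, 2.0, 3.0],  # observed data, taken from the model's arguments
)
# cond_m(c) would then return Normal(sum(c[@varname(x)]) / 4, sqrt(1 / (0.4 * 4))),
# i.e. Normal(sum(x) / (n + 1), sqrt(1 / (λ * (n + 1)))) with n = 3.
```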
@@ -147,7 +148,7 @@ function AbstractMCMC.step( state = setindex!!(state, rand(rng, dist), vn) end elseif conddists isa NamedTuple - for (vn_sym => dist) in pairs(conddists) + for (vn_sym, dist) in pairs(conddists) vn = VarName{vn_sym}() state = setindex!!(state, rand(rng, dist), vn) end diff --git a/test/mcmc/gibbs_conditional.jl b/test/mcmc/gibbs_conditional.jl index 6868f4f2c..2371d806e 100644 --- a/test/mcmc/gibbs_conditional.jl +++ b/test/mcmc/gibbs_conditional.jl @@ -8,93 +8,106 @@ using StableRNGs: StableRNG using Test: @inferred, @test, @test_broken, @test_throws, @testset using Turing -@testset "GibbsConditional" begin - @model function inverse_gdemo(x) - λ ~ Gamma(2, inv(3)) - m ~ Normal(0, sqrt(1 / λ)) - for i in 1:length(x) - x[i] ~ Normal(m, sqrt(1 / λ)) +@testset "GibbsConditional" verbose = true begin + @testset "Gamma model tests" begin + @model function inverse_gdemo(x) + precision ~ Gamma(2, inv(3)) + std = sqrt(1 / precision) + m ~ Normal(0, std) + for i in 1:length(x) + x[i] ~ Normal(m, std) + end end - end - # Define analytical conditionals - function cond_λ(c) - a = 2.0 - b = inv(3) - m = c[@varname(m)] - x = c[@varname(x)] - n = length(x) - a_new = a + (n + 1) / 2 - b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2 - return Gamma(a_new, 1 / b_new) - end + # Define analytical conditionals + function cond_precision(c) + a = 2.0 + b = 3.0 + m = c[@varname(m)] + x = c[@varname(x)] + n = length(x) + a_new = a + (n + 1) / 2 + b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2 + return Gamma(a_new, 1 / b_new) + end - function cond_m(c) - λ = c[@varname(λ)] - x = c[@varname(x)] - n = length(x) - m_mean = sum(x) / (n + 1) - m_var = 1 / (λ * (n + 1)) - return Normal(m_mean, sqrt(m_var)) - end + function cond_m(c) + precision = c[@varname(precision)] + x = c[@varname(x)] + n = length(x) + m_mean = sum(x) / (n + 1) + m_var = 1 / (precision * (n + 1)) + return Normal(m_mean, sqrt(m_var)) + end - # Test basic functionality - @testset "basic sampling" begin - Random.seed!(42) + rng = StableRNG(23) x_obs = [1.0, 2.0, 3.0, 2.5, 1.5] model = inverse_gdemo(x_obs) - # Test that GibbsConditional works - sampler = Gibbs(:λ => GibbsConditional(cond_λ), :m => GibbsConditional(cond_m)) - chain = sample(model, sampler, 1000) + reference_sampler = NUTS() + reference_chain = sample(rng, model, reference_sampler, 10_000) - # Check that we got the expected variables - @test :λ in names(chain) - @test :m in names(chain) + # Use both conditionals + sampler = Gibbs( + :precision => GibbsConditional(cond_precision), :m => GibbsConditional(cond_m) + ) + chain = sample(rng, model, sampler, 10_000) - # Check that the values are reasonable - λ_samples = vec(chain[:λ]) - m_samples = vec(chain[:m]) + @test size(chain, 1) == 10_000 + @test mean(chain, :precision) ≈ mean(reference_chain, :precision) atol = 0.05 + @test mean(chain, :m) ≈ mean(reference_chain, :m) atol = 0.05 - # Given the observed data, we expect certain behavior - @test mean(λ_samples) > 0 # λ should be positive - @test minimum(λ_samples) > 0 - @test std(m_samples) < 2.0 # m should be relatively well-constrained - end + # Mix GibbsConditional with standard samplers + sampler = Gibbs(:precision => GibbsConditional(cond_precision), :m => MH()) + chain = sample(rng, model, sampler, 10_000) - # Test mixing with other samplers - @testset "mixed samplers" begin - x_obs = [1.0, 2.0, 3.0] - model = inverse_gdemo(x_obs) + @test size(chain, 1) == 10_000 + @test mean(chain, :precision) ≈ mean(reference_chain, :precision) atol = 0.05 + 
@test mean(chain, :m) ≈ mean(reference_chain, :m) atol = 0.05 - # Mix GibbsConditional with standard samplers - sampler = Gibbs(:λ => GibbsConditional(cond_λ), :m => MH()) - chain = sample(model, sampler, 500) + sampler = Gibbs(:m => GibbsConditional(cond_m), :precision => HMC(0.1, 10)) + chain = sample(rng, model, sampler, 10_000) + + @test size(chain, 1) == 10_000 + @test mean(chain, :precision) ≈ mean(reference_chain, :precision) atol = 0.05 + @test mean(chain, :m) ≈ mean(reference_chain, :m) atol = 0.05 - @test :λ in names(chain) - @test :m in names(chain) - @test size(chain, 1) == 500 + # Block sample, sampling the same variable with multiple component samplers. + sampler = Gibbs( + (:precision, :m) => HMC(0.1, 10), + :m => GibbsConditional(cond_m), + :precision => MH(), + :precision => GibbsConditional(cond_precision), + :precision => GibbsConditional(cond_precision), + :precision => HMC(0.1, 10), + :m => GibbsConditional(cond_m), + :m => PG(10), + ) + chain = sample(rng, model, sampler, 1_000) + + @test size(chain, 1) == 1_000 + @test mean(chain, :precision) ≈ mean(reference_chain, :precision) atol = 0.05 + @test mean(chain, :m) ≈ mean(reference_chain, :m) atol = 0.05 end - # Test with a simpler model - @testset "simple normal model" begin + @testset "Simple normal model" begin @model function simple_normal(dim) - μ ~ Normal(0, 10) - σ2 ~ truncated(Normal(1, 1); lower=0.01) - return x ~ MvNormal(fill(μ, dim), I * σ2) + mean ~ Normal(0, 10) + var ~ truncated(Normal(1, 1); lower=0.01) + return x ~ MvNormal(fill(mean, dim), I * var) end - # Conditional for μ given σ and x - function cond_μ(c) - σ2 = c[@varname(σ2)] + # Conditional for mean given var and x + function cond_mean(c) + var = c[@varname(var)] x = c[@varname(x)] n = length(x) - # Prior: μ ~ Normal(0, 10) - # Likelihood: x[i] ~ Normal(μ, σ) - # Posterior: μ ~ Normal(μ_post, σ_post) + # Prior: mean ~ Normal(0, 10) + # Likelihood: x[i] ~ Normal(mean, σ) + # Posterior: mean ~ Normal(μ_post, σ_post) prior_var = 100.0 # 10^2 - post_var = 1 / (1 / prior_var + n / σ2) - post_mean = post_var * (0 / prior_var + sum(x) / σ2) + post_var = 1 / (1 / prior_var + n / var) + post_mean = post_var * (0 / prior_var + sum(x) / var) return Normal(post_mean, sqrt(post_var)) end @@ -103,17 +116,148 @@ using Turing true_mean = 2.0 x_obs = randn(rng, dim) .+ true_mean model = simple_normal(dim) | (; x=x_obs) - sampler = Gibbs(:μ => GibbsConditional(cond_μ), :σ2 => MH()) + sampler = Gibbs(:mean => GibbsConditional(cond_mean), :var => MH()) chain = sample(rng, model, sampler, 1_000) # The correct posterior mean isn't true_mean, but it is very close, because we # have a lot of data. - @test mean(chain, :μ) ≈ true_mean atol = 0.02 + @test mean(chain, :mean) ≈ true_mean atol = 0.05 + end + + # Test that the different ways of returning values from the conditional function work. + @testset "Double simple normal" begin + # This is the same model as simple_normal above, but just doubled. 
+ @model function double_simple_normal(dim1, dim2) + prior_std1 = 10.0 + mean1 ~ Normal(0, prior_std1) + var1 ~ truncated(Normal(1, 1); lower=0.01) + x1 ~ MvNormal(fill(mean1, dim1), I * var1) + + prior_std2 = 20.0 + mean2 ~ Normal(0, prior_std2) + var2 ~ truncated(Normal(1, 1); lower=0.01) + x2 ~ MvNormal(fill(mean2, dim2), I * var2) + return nothing + end + + function cond_mean(var, x, prior_std) + n = length(x) + # Prior: mean ~ Normal(0, prior_std) + # Likelihood: x[i] ~ Normal(mean, σ) + # Posterior: mean ~ Normal(μ_post, σ_post) + prior_var = prior_std^2 + post_var = 1 / (1 / prior_var + n / var) + post_mean = post_var * (0 / prior_var + sum(x) / var) + return Normal(post_mean, sqrt(post_var)) + end + + rng = StableRNG(23) + dim1 = 10_000 + true_mean1 = -10.0 + x1_obs = randn(rng, dim1) .+ true_mean1 + dim2 = 20_000 + true_mean2 = -20.0 + x2_obs = randn(rng, dim2) .+ true_mean2 + base_model = double_simple_normal(dim1, dim2) + + @testset "conditionals return types" begin + # Test using GibbsConditional for both separately. + cond_mean1(c) = cond_mean(c[@varname(var1)], c[@varname(x1)], 10.0) + cond_mean2(c) = cond_mean(c[@varname(var2)], c[@varname(x2)], 20.0) + model = base_model | (; x1=x1_obs, x2=x2_obs) + sampler = Gibbs( + :mean1 => GibbsConditional(cond_mean1), + :mean2 => GibbsConditional(cond_mean2), + (:var1, :var2) => HMC(0.1, 10), + ) + chain = sample(StableRNG(24), model, sampler, 1_000) + # The correct posterior mean isn't true_mean, but it is very close, because we + # have a lot of data. + @test mean(chain, :mean1) ≈ true_mean1 atol = 0.05 + @test mean(chain, :mean2) ≈ true_mean2 atol = 0.05 + + # Test using GibbsConditional for both in a block, returning a Dict. + function cond_mean_dict(c) + return Dict( + @varname(mean1) => cond_mean(c[@varname(var1)], c[@varname(x1)], 10.0), + @varname(mean2) => cond_mean(c[@varname(var2)], c[@varname(x2)], 20.0), + ) + end + sampler = Gibbs( + (:mean1, :mean2) => GibbsConditional(cond_mean_dict), + (:var1, :var2) => HMC(0.1, 10), + ) + chain = sample(StableRNG(24), model, sampler, 1_000) + @test mean(chain, :mean1) ≈ true_mean1 atol = 0.05 + @test mean(chain, :mean2) ≈ true_mean2 atol = 0.05 + + # The same but with a NamedTuple rather than a Dict. + function cond_mean_dict(c) + return (; + mean1=cond_mean(c[@varname(var1)], c[@varname(x1)], 10.0), + mean2=cond_mean(c[@varname(var2)], c[@varname(x2)], 20.0), + ) + end + sampler = Gibbs( + (:mean1, :mean2) => GibbsConditional(cond_mean_dict), + (:var1, :var2) => HMC(0.1, 10), + ) + chain = sample(StableRNG(24), model, sampler, 1_000) + @test mean(chain, :mean1) ≈ true_mean1 atol = 0.05 + @test mean(chain, :mean2) ≈ true_mean2 atol = 0.05 + end + + @testset "condition and fix" begin + # Note that fixed variables don't contribute to the likelihood, and hence the + # conditional posterior changes to be just the prior. + model_condition_fix = condition(fix(base_model; x1=x1_obs); x2=x2_obs) + cond_mean1(_) = Normal(0.0, 10.0) + cond_mean2(c) = cond_mean(c[@varname(var2)], c[@varname(x2)], 20.0) + sampler = Gibbs( + :mean1 => GibbsConditional(cond_mean1), + :mean2 => GibbsConditional(cond_mean2), + (:var1, :var2) => HMC(0.1, 10), + ) + chain = sample(StableRNG(24), model_condition_fix, sampler, 10_000) + @test mean(chain, :mean1) ≈ 0.0 atol = 0.05 + @test mean(chain, :mean2) ≈ true_mean2 atol = 0.05 + + # As above, but reverse the order of condition and fix. 
+            model_fix_condition = fix(condition(base_model; x2=x2_obs); x1=x1_obs)
+            chain = sample(StableRNG(24), model_condition_fix, sampler, 10_000)
+            @test mean(chain, :mean1) ≈ 0.0 atol = 0.05
+            @test mean(chain, :mean2) ≈ true_mean2 atol = 0.05
+        end
     end
 
-    # Test that GibbsConditional is marked as a valid component
-    @testset "isgibbscomponent" begin
-        gc = GibbsConditional(c -> Normal(0, 1))
-        @test Turing.Inference.isgibbscomponent(gc)
+    @testset "Indexed VarNames" begin
+        # Check that GibbsConditional works with VarNames with IndexLenses.
+        # This example is statistically nonsense, it only tests that the values returned by
+        # `conditionals` are passed through correctly.
+        @model function f()
+            a = Vector{Float64}(undef, 3)
+            # These priors will be completely ignored in the sampling.
+            a[1] ~ Normal()
+            a[2] ~ Normal()
+            a[3] ~ Normal()
+            return nothing
+        end
+
+        m = f()
+        function conditionals(c)
+            d1 = Normal(0, 1)
+            d2 = Normal(c[@varname(a[1])] + 10, 1)
+            d3 = Normal(c[@varname(a[2])] + 10, 1)
+            return Dict(@varname(a[1]) => d1, @varname(a[2]) => d2, @varname(a[3]) => d3)
+        end
+
+        sampler = Gibbs(
+            (@varname(a[1]), @varname(a[2]), @varname(a[3])) =>
+                GibbsConditional(conditionals),
+        )
+        chain = sample(StableRNG(23), m, sampler, 1_000)
+        @test mean(chain, Symbol("a[1]")) ≈ 0.0 atol = 0.1
+        @test mean(chain, Symbol("a[2]")) ≈ 10.0 atol = 0.1
+        @test mean(chain, Symbol("a[3]")) ≈ 20.0 atol = 0.1
     end
 end
 
From 34acad7dbc105cc2b94cf1058afdb8876fc0d9d2 Mon Sep 17 00:00:00 2001
From: Markus Hauru
Date: Mon, 17 Nov 2025 17:41:22 +0000
Subject: [PATCH 13/22] Bump patch version to 0.41.2, add HISTORY.md entry

---
 HISTORY.md   | 8 ++++++++
 Project.toml | 2 +-
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/HISTORY.md b/HISTORY.md
index 039ac6bb9..d15d490be 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -1,3 +1,11 @@
+# 0.41.2
+
+Add `GibbsConditional`, a "sampler" that can be used to provide analytically known conditional posteriors in a Gibbs sampler.
+
+In Gibbs sampling, each variable is sampled by its own component sampler, while all other variables are held conditioned to their current values. Typically one takes turns sampling, for example, one variable with HMC and another with a particle sampler. However, sometimes the posterior distribution of one variable is known analytically, given the conditioned values of the other variables. `GibbsConditional` provides a way to implement these analytically known conditional posteriors and use them as component samplers for Gibbs. See the docstring of `GibbsConditional` for details.
+
+Note that `GibbsConditional` existed in Turing.jl until v0.36, at which point it was removed when the whole Gibbs sampler was rewritten. This reintroduces the same functionality, though with a slightly different interface.
+
 # 0.41.1
 
 The `ModeResult` struct returned by `maximum_a_posteriori` and `maximum_likelihood` can now be wrapped in `InitFromParams()`.
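To make the HISTORY entry above concrete, here is a minimal hedged sketch of the interface it describes, using a conjugate Beta-Bernoulli model; the model, data, and `cond_p` helper are illustrative and not part of the patch:

```julia
using Turing

@model function coin(y)
    p ~ Beta(2, 2)
    for i in eachindex(y)
        y[i] ~ Bernoulli(p)
    end
end

# Beta-Bernoulli conjugacy gives the conditional posterior in closed form:
# p | y ~ Beta(alpha + sum(y), beta + length(y) - sum(y)).
function cond_p(c)
    y = c[@varname(y)]
    return Beta(2 + sum(y), 2 + length(y) - sum(y))
end

chain = sample(coin([1, 0, 1, 1]), Gibbs(:p => GibbsConditional(cond_p)), 1_000)
```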
diff --git a/Project.toml b/Project.toml
index cb7b1cb72..2faa78ce5 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "Turing"
 uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
-version = "0.41.1"
+version = "0.41.2"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

From 4786a594015a079337fa9de3013eb4519fcaf8f8 Mon Sep 17 00:00:00 2001
From: Markus Hauru
Date: Mon, 17 Nov 2025 17:43:20 +0000
Subject: [PATCH 14/22] Remove spurious change

---
 test/mcmc/gibbs.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl
index c2df731b0..d02f94982 100644
--- a/test/mcmc/gibbs.jl
+++ b/test/mcmc/gibbs.jl
@@ -19,7 +19,7 @@ using StableRNGs: StableRNG
 using Test: @inferred, @test, @test_broken, @test_throws, @testset
 using Turing
 using Turing: Inference
-using Turing.Inference: AdvancedHMC, AdvancedMH, GibbsConditional
+using Turing.Inference: AdvancedHMC, AdvancedMH
 using Turing.RandomMeasures: ChineseRestaurantProcess, DirichletProcess
 
 function check_transition_varnames(transition::Turing.Inference.Transition, parent_varnames)

From 8951d981f766f6fdb3c8d4e4613063ee16e29ebd Mon Sep 17 00:00:00 2001
From: Markus Hauru
Date: Tue, 18 Nov 2025 11:12:20 +0000
Subject: [PATCH 15/22] Code style and documentation

---
 docs/src/api.md                |  1 +
 src/mcmc/gibbs_conditional.jl  | 87 ++++++++++++++++-------------------
 test/mcmc/gibbs_conditional.jl | 33 ++++++-------
 3 files changed, 55 insertions(+), 66 deletions(-)

diff --git a/docs/src/api.md b/docs/src/api.md
index 885d587ea..3e097da23 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -63,6 +63,7 @@ even though [`Prior()`](@ref) is actually defined in the `Turing.Inference` modu
 | `Emcee` | [`Turing.Inference.Emcee`](@ref) | Affine-invariant ensemble sampler |
 | `ESS` | [`Turing.Inference.ESS`](@ref) | Elliptical slice sampling |
 | `Gibbs` | [`Turing.Inference.Gibbs`](@ref) | Gibbs sampling |
+| `GibbsConditional` | [`Turing.Inference.GibbsConditional`](@ref) | Gibbs sampling with analytical conditional posterior distributions |
 | `HMC` | [`Turing.Inference.HMC`](@ref) | Hamiltonian Monte Carlo |
 | `SGLD` | [`Turing.Inference.SGLD`](@ref) | Stochastic gradient Langevin dynamics |
 | `SGHMC` | [`Turing.Inference.SGHMC`](@ref) | Stochastic gradient Hamiltonian Monte Carlo |

diff --git a/src/mcmc/gibbs_conditional.jl b/src/mcmc/gibbs_conditional.jl
index cee1dd66a..75c232a12 100644
--- a/src/mcmc/gibbs_conditional.jl
+++ b/src/mcmc/gibbs_conditional.jl
@@ -1,38 +1,48 @@
-using DynamicPPL: VarName
-using Random: Random
-import AbstractMCMC
-
 """
-    GibbsConditional(conditional)
+    GibbsConditional(get_cond_dists)
+
+A Gibbs component sampler that samples variables according to user-provided analytical
+conditional posterior distributions.
 
-A Gibbs component sampler that samples variables according to user-provided
-analytical conditional distributions.
+When using Gibbs sampling, sometimes one may know the analytical form of the posterior for
+a given variable, given the conditioned values of the other variables. In such cases one can
+use `GibbsConditional` as a component sampler to sample from these known conditionals
+directly, avoiding any MCMC methods. One does so with
+
+```julia
+sampler = Gibbs(
+    (@varname(var1), @varname(var2)) => GibbsConditional(get_cond_dists),
+    other samplers go here...
+) +``` -`conditional` should be a function that takes a `Dict{<:VarName}` of conditioned variables -and their values, and returns one of the following: +Here `get_cond_dists(c::Dict{<:VarName})` should be a function that takes a `Dict` mapping +the conditioned variables (anything other than `var1` and `var2`) to their values, and +returns the conditional posterior distributions for `var1` and `var2`. You may, of course, +have any number of variables being sampled as a block in this manner, we only use two as an +example. The return value of `get_cond_dists` should be one of the following: - A single `Distribution`, if only one variable is being sampled. - An `AbstractDict{<:VarName,<:Distribution}` that maps the variables being sampled to their - `Distribution`s. + conditional posteriors E.g. `Dict(@varname(var1) => dist1, @varname(var2) => dist2)`. - A `NamedTuple` of `Distribution`s, which is like the `AbstractDict` case but can be used - if all the variable names are single `Symbol`s, and may be more performant. - -If a Gibbs component is created with `(:var1, :var2) => GibbsConditional(conditional)`, then -`var1` and `var2` should be in the keys of the return value of `conditional`. + if all the variable names are single `Symbol`s, and may be more performant. E.g. + `(; var1=dist1, var2=dist2)`. # Examples ```julia # Define a model @model function inverse_gdemo(x) - λ ~ Gamma(2, inv(3)) - m ~ Normal(0, sqrt(1 / λ)) + precision ~ Gamma(2, inv(3)) + std = sqrt(1 / precision) + m ~ Normal(0, std) for i in 1:length(x) - x[i] ~ Normal(m, sqrt(1 / λ)) + x[i] ~ Normal(m, std) end end # Define analytical conditionals -function cond_λ(c) +function cond_precision(c) a = 2.0 b = inv(3) m = c[@varname(m)] @@ -44,27 +54,26 @@ function cond_λ(c) end function cond_m(c) - λ = c[@varname(λ)] + precision = c[@varname(precision)] x = c[@varname(x)] n = length(x) m_mean = sum(x) / (n + 1) - m_var = 1 / (λ * (n + 1)) + m_var = 1 / (precision * (n + 1)) return Normal(m_mean, sqrt(m_var)) end # Sample using GibbsConditional model = inverse_gdemo([1.0, 2.0, 3.0]) chain = sample(model, Gibbs( - :λ => GibbsConditional(cond_λ), + :precision => GibbsConditional(cond_precision), :m => GibbsConditional(cond_m) ), 1000) ``` """ struct GibbsConditional{C} <: AbstractSampler - conditional::C + get_cond_dists::C end -# Mark GibbsConditional as a valid Gibbs component isgibbscomponent(::GibbsConditional) = true """ @@ -80,8 +89,6 @@ function build_variable_dict(model::DynamicPPL.Model) # TODO(mhauru) Can we avoid invlinking all the time? Note that this causes a model # evaluation, which may be expensive. global_vi = DynamicPPL.invlink(get_gibbs_global_varinfo(context), model) - # TODO(mhauru) Double-check that the ordered of precedence here is correct. Should we - # in fact error if there is any overlap in the keys? return merge( DynamicPPL.values_as(global_vi, Dict), Dict( @@ -98,17 +105,13 @@ function get_gibbs_global_varinfo(context::DynamicPPL.AbstractContext) elseif DynamicPPL.NodeTrait(context) isa DynamicPPL.IsParent get_gibbs_global_varinfo(DynamicPPL.childcontext(context)) else - throw(ArgumentError("""No GibbsContext found in context stack. \ - Are you trying to use GibbsConditional outside of Gibbs? - """)) + msg = """No GibbsContext found in context stack. Are you trying to use \ + GibbsConditional outside of Gibbs? + """ + throw(ArgumentError(msg)) end end -""" - DynamicPPL.initialstep(rng, model, sampler::GibbsConditional, vi) - -Initialize the GibbsConditional sampler. 
-""" function initialstep( ::Random.AbstractRNG, model::DynamicPPL.Model, @@ -122,11 +125,6 @@ function initialstep( return nothing, state end -""" - AbstractMCMC.step(rng, model, sampler::GibbsConditional, state) - -Perform a step of GibbsConditional sampling. -""" function AbstractMCMC.step( rng::Random.AbstractRNG, model::DynamicPPL.Model, @@ -137,11 +135,9 @@ function AbstractMCMC.step( # Get all the conditioned variable values from the model context. This is assumed to # include a GibbsContext as part of the context stack. condvals = build_variable_dict(model) + conddists = sampler.get_cond_dists(condvals) - # Get the conditional distributions - conddists = sampler.conditional(condvals) - - # We support three different kinds of return values for `sample.conditional`, to make + # We support three different kinds of return values for `sample.get_cond_dists`, to make # life easier for the user. if conddists isa AbstractDict for (vn, dist) in conddists @@ -154,7 +150,7 @@ function AbstractMCMC.step( end else # Single variable case - vn = first(keys(state)) + vn = only(keys(state)) state = setindex!!(state, rand(rng, conddists), vn) end @@ -164,10 +160,7 @@ function AbstractMCMC.step( end function setparams_varinfo!!( - model::DynamicPPL.Model, - sampler::GibbsConditional, - state, - params::DynamicPPL.AbstractVarInfo, + ::DynamicPPL.Model, ::GibbsConditional, ::Any, params::DynamicPPL.AbstractVarInfo ) return params end diff --git a/test/mcmc/gibbs_conditional.jl b/test/mcmc/gibbs_conditional.jl index 2371d806e..b09b89361 100644 --- a/test/mcmc/gibbs_conditional.jl +++ b/test/mcmc/gibbs_conditional.jl @@ -1,14 +1,12 @@ module GibbsConditionalTests -using Distributions: InverseGamma, Normal -using Distributions: sample using DynamicPPL: DynamicPPL using Random: Random using StableRNGs: StableRNG -using Test: @inferred, @test, @test_broken, @test_throws, @testset +using Test: @test, @testset using Turing -@testset "GibbsConditional" verbose = true begin +@testset "GibbsConditional" begin @testset "Gamma model tests" begin @model function inverse_gdemo(x) precision ~ Gamma(2, inv(3)) @@ -47,27 +45,24 @@ using Turing reference_sampler = NUTS() reference_chain = sample(rng, model, reference_sampler, 10_000) - # Use both conditionals + # Use both conditionals, check results against reference sampler. 
sampler = Gibbs( :precision => GibbsConditional(cond_precision), :m => GibbsConditional(cond_m) ) chain = sample(rng, model, sampler, 10_000) - @test size(chain, 1) == 10_000 @test mean(chain, :precision) ≈ mean(reference_chain, :precision) atol = 0.05 @test mean(chain, :m) ≈ mean(reference_chain, :m) atol = 0.05 - # Mix GibbsConditional with standard samplers + # Mix GibbsConditional with an MCMC sampler sampler = Gibbs(:precision => GibbsConditional(cond_precision), :m => MH()) chain = sample(rng, model, sampler, 10_000) - @test size(chain, 1) == 10_000 @test mean(chain, :precision) ≈ mean(reference_chain, :precision) atol = 0.05 @test mean(chain, :m) ≈ mean(reference_chain, :m) atol = 0.05 sampler = Gibbs(:m => GibbsConditional(cond_m), :precision => HMC(0.1, 10)) chain = sample(rng, model, sampler, 10_000) - @test size(chain, 1) == 10_000 @test mean(chain, :precision) ≈ mean(reference_chain, :precision) atol = 0.05 @test mean(chain, :m) ≈ mean(reference_chain, :m) atol = 0.05 @@ -84,7 +79,6 @@ using Turing :m => PG(10), ) chain = sample(rng, model, sampler, 1_000) - @test size(chain, 1) == 1_000 @test mean(chain, :precision) ≈ mean(reference_chain, :precision) atol = 0.05 @test mean(chain, :m) ≈ mean(reference_chain, :m) atol = 0.05 @@ -97,7 +91,7 @@ using Turing return x ~ MvNormal(fill(mean, dim), I * var) end - # Conditional for mean given var and x + # Conditional posterior for mean given var and x function cond_mean(c) var = c[@varname(var)] x = c[@varname(x)] @@ -123,7 +117,6 @@ using Turing @test mean(chain, :mean) ≈ true_mean atol = 0.05 end - # Test that the different ways of returning values from the conditional function work. @testset "Double simple normal" begin # This is the same model as simple_normal above, but just doubled. @model function double_simple_normal(dim1, dim2) @@ -159,6 +152,7 @@ using Turing x2_obs = randn(rng, dim2) .+ true_mean2 base_model = double_simple_normal(dim1, dim2) + # Test different ways of returning values from the conditional function. @testset "conditionals return types" begin # Test using GibbsConditional for both separately. cond_mean1(c) = cond_mean(c[@varname(var1)], c[@varname(x1)], 10.0) @@ -190,15 +184,15 @@ using Turing @test mean(chain, :mean1) ≈ true_mean1 atol = 0.05 @test mean(chain, :mean2) ≈ true_mean2 atol = 0.05 - # The same but with a NamedTuple rather than a Dict. - function cond_mean_dict(c) + # As above but with a NamedTuple rather than a Dict. + function cond_mean_nt(c) return (; mean1=cond_mean(c[@varname(var1)], c[@varname(x1)], 10.0), mean2=cond_mean(c[@varname(var2)], c[@varname(x2)], 20.0), ) end sampler = Gibbs( - (:mean1, :mean2) => GibbsConditional(cond_mean_dict), + (:mean1, :mean2) => GibbsConditional(cond_mean_nt), (:var1, :var2) => HMC(0.1, 10), ) chain = sample(StableRNG(24), model, sampler, 1_000) @@ -206,6 +200,7 @@ using Turing @test mean(chain, :mean2) ≈ true_mean2 atol = 0.05 end + # Test simultaneously conditioning and fixing variables. @testset "condition and fix" begin # Note that fixed variables don't contribute to the likelihood, and hence the # conditional posterior changes to be just the prior. @@ -229,8 +224,8 @@ using Turing end end + # Check that GibbsConditional works with VarNames with IndexLenses. @testset "Indexed VarNames" begin - # Check that GibbsConditional works with VarNames with IndexLenses. # This example is statistically nonsense, it only tests that the values returned by # `conditionals` are passed through correctly. 
@model function f() @@ -255,9 +250,9 @@ using Turing GibbsConditional(conditionals), ) chain = sample(StableRNG(23), m, sampler, 1_000) - @test mean(chain, Symbol("a[1]")) ≈ 0.0 atol = 0.1 - @test mean(chain, Symbol("a[2]")) ≈ 10.0 atol = 0.1 - @test mean(chain, Symbol("a[3]")) ≈ 20.0 atol = 0.1 + @test mean(chain, Symbol("a[1]")) ≈ 0.0 atol = 0.05 + @test mean(chain, Symbol("a[2]")) ≈ 10.0 atol = 0.05 + @test mean(chain, Symbol("a[3]")) ≈ 20.0 atol = 0.05 end end From 45ab5f851bb735edd5f3f23e44d998f2f9f795e1 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 18 Nov 2025 11:38:46 +0000 Subject: [PATCH 16/22] Add one test_throws, tweak test thresholds and dimensions --- test/mcmc/gibbs_conditional.jl | 106 +++++++++++++++++++-------------- 1 file changed, 62 insertions(+), 44 deletions(-) diff --git a/test/mcmc/gibbs_conditional.jl b/test/mcmc/gibbs_conditional.jl index b09b89361..1fddfb3ea 100644 --- a/test/mcmc/gibbs_conditional.jl +++ b/test/mcmc/gibbs_conditional.jl @@ -3,7 +3,7 @@ module GibbsConditionalTests using DynamicPPL: DynamicPPL using Random: Random using StableRNGs: StableRNG -using Test: @test, @testset +using Test: @test, @test_throws, @testset using Turing @testset "GibbsConditional" begin @@ -49,23 +49,23 @@ using Turing sampler = Gibbs( :precision => GibbsConditional(cond_precision), :m => GibbsConditional(cond_m) ) - chain = sample(rng, model, sampler, 10_000) - @test size(chain, 1) == 10_000 - @test mean(chain, :precision) ≈ mean(reference_chain, :precision) atol = 0.05 - @test mean(chain, :m) ≈ mean(reference_chain, :m) atol = 0.05 + chain = sample(rng, model, sampler, 1_000) + @test size(chain, 1) == 1_000 + @test mean(chain, :precision) ≈ mean(reference_chain, :precision) atol = 0.1 + @test mean(chain, :m) ≈ mean(reference_chain, :m) atol = 0.1 # Mix GibbsConditional with an MCMC sampler sampler = Gibbs(:precision => GibbsConditional(cond_precision), :m => MH()) - chain = sample(rng, model, sampler, 10_000) - @test size(chain, 1) == 10_000 - @test mean(chain, :precision) ≈ mean(reference_chain, :precision) atol = 0.05 - @test mean(chain, :m) ≈ mean(reference_chain, :m) atol = 0.05 + chain = sample(rng, model, sampler, 1_000) + @test size(chain, 1) == 1_000 + @test mean(chain, :precision) ≈ mean(reference_chain, :precision) atol = 0.1 + @test mean(chain, :m) ≈ mean(reference_chain, :m) atol = 0.1 sampler = Gibbs(:m => GibbsConditional(cond_m), :precision => HMC(0.1, 10)) - chain = sample(rng, model, sampler, 10_000) - @test size(chain, 1) == 10_000 - @test mean(chain, :precision) ≈ mean(reference_chain, :precision) atol = 0.05 - @test mean(chain, :m) ≈ mean(reference_chain, :m) atol = 0.05 + chain = sample(rng, model, sampler, 1_000) + @test size(chain, 1) == 1_000 + @test mean(chain, :precision) ≈ mean(reference_chain, :precision) atol = 0.1 + @test mean(chain, :m) ≈ mean(reference_chain, :m) atol = 0.1 # Block sample, sampling the same variable with multiple component samplers. 
sampler = Gibbs( @@ -80,8 +80,8 @@ using Turing ) chain = sample(rng, model, sampler, 1_000) @test size(chain, 1) == 1_000 - @test mean(chain, :precision) ≈ mean(reference_chain, :precision) atol = 0.05 - @test mean(chain, :m) ≈ mean(reference_chain, :m) atol = 0.05 + @test mean(chain, :precision) ≈ mean(reference_chain, :precision) atol = 0.1 + @test mean(chain, :m) ≈ mean(reference_chain, :m) atol = 0.1 end @testset "Simple normal model" begin @@ -106,7 +106,7 @@ using Turing end rng = StableRNG(23) - dim = 10_000 + dim = 1_000 true_mean = 2.0 x_obs = randn(rng, dim) .+ true_mean model = simple_normal(dim) | (; x=x_obs) @@ -119,13 +119,13 @@ using Turing @testset "Double simple normal" begin # This is the same model as simple_normal above, but just doubled. + prior_std1 = 10.0 + prior_std2 = 20.0 @model function double_simple_normal(dim1, dim2) - prior_std1 = 10.0 mean1 ~ Normal(0, prior_std1) var1 ~ truncated(Normal(1, 1); lower=0.01) x1 ~ MvNormal(fill(mean1, dim1), I * var1) - prior_std2 = 20.0 mean2 ~ Normal(0, prior_std2) var2 ~ truncated(Normal(1, 1); lower=0.01) x2 ~ MvNormal(fill(mean2, dim2), I * var2) @@ -144,10 +144,10 @@ using Turing end rng = StableRNG(23) - dim1 = 10_000 + dim1 = 1_000 true_mean1 = -10.0 x1_obs = randn(rng, dim1) .+ true_mean1 - dim2 = 20_000 + dim2 = 2_000 true_mean2 = -20.0 x2_obs = randn(rng, dim2) .+ true_mean2 base_model = double_simple_normal(dim1, dim2) @@ -155,49 +155,51 @@ using Turing # Test different ways of returning values from the conditional function. @testset "conditionals return types" begin # Test using GibbsConditional for both separately. - cond_mean1(c) = cond_mean(c[@varname(var1)], c[@varname(x1)], 10.0) - cond_mean2(c) = cond_mean(c[@varname(var2)], c[@varname(x2)], 20.0) + cond_mean1(c) = cond_mean(c[@varname(var1)], c[@varname(x1)], prior_std1) + cond_mean2(c) = cond_mean(c[@varname(var2)], c[@varname(x2)], prior_std2) model = base_model | (; x1=x1_obs, x2=x2_obs) sampler = Gibbs( :mean1 => GibbsConditional(cond_mean1), :mean2 => GibbsConditional(cond_mean2), (:var1, :var2) => HMC(0.1, 10), ) - chain = sample(StableRNG(24), model, sampler, 1_000) + chain = sample(StableRNG(23), model, sampler, 1_000) # The correct posterior mean isn't true_mean, but it is very close, because we # have a lot of data. - @test mean(chain, :mean1) ≈ true_mean1 atol = 0.05 - @test mean(chain, :mean2) ≈ true_mean2 atol = 0.05 + @test mean(chain, :mean1) ≈ true_mean1 atol = 0.1 + @test mean(chain, :mean2) ≈ true_mean2 atol = 0.1 # Test using GibbsConditional for both in a block, returning a Dict. function cond_mean_dict(c) return Dict( - @varname(mean1) => cond_mean(c[@varname(var1)], c[@varname(x1)], 10.0), - @varname(mean2) => cond_mean(c[@varname(var2)], c[@varname(x2)], 20.0), + @varname(mean1) => + cond_mean(c[@varname(var1)], c[@varname(x1)], prior_std1), + @varname(mean2) => + cond_mean(c[@varname(var2)], c[@varname(x2)], prior_std2), ) end sampler = Gibbs( (:mean1, :mean2) => GibbsConditional(cond_mean_dict), (:var1, :var2) => HMC(0.1, 10), ) - chain = sample(StableRNG(24), model, sampler, 1_000) - @test mean(chain, :mean1) ≈ true_mean1 atol = 0.05 - @test mean(chain, :mean2) ≈ true_mean2 atol = 0.05 + chain = sample(StableRNG(23), model, sampler, 1_000) + @test mean(chain, :mean1) ≈ true_mean1 atol = 0.1 + @test mean(chain, :mean2) ≈ true_mean2 atol = 0.1 # As above but with a NamedTuple rather than a Dict. 
function cond_mean_nt(c) return (; - mean1=cond_mean(c[@varname(var1)], c[@varname(x1)], 10.0), - mean2=cond_mean(c[@varname(var2)], c[@varname(x2)], 20.0), + mean1=cond_mean(c[@varname(var1)], c[@varname(x1)], prior_std1), + mean2=cond_mean(c[@varname(var2)], c[@varname(x2)], prior_std2), ) end sampler = Gibbs( (:mean1, :mean2) => GibbsConditional(cond_mean_nt), (:var1, :var2) => HMC(0.1, 10), ) - chain = sample(StableRNG(24), model, sampler, 1_000) - @test mean(chain, :mean1) ≈ true_mean1 atol = 0.05 - @test mean(chain, :mean2) ≈ true_mean2 atol = 0.05 + chain = sample(StableRNG(23), model, sampler, 1_000) + @test mean(chain, :mean1) ≈ true_mean1 atol = 0.1 + @test mean(chain, :mean2) ≈ true_mean2 atol = 0.1 end # Test simultaneously conditioning and fixing variables. @@ -205,22 +207,27 @@ using Turing # Note that fixed variables don't contribute to the likelihood, and hence the # conditional posterior changes to be just the prior. model_condition_fix = condition(fix(base_model; x1=x1_obs); x2=x2_obs) - cond_mean1(_) = Normal(0.0, 10.0) - cond_mean2(c) = cond_mean(c[@varname(var2)], c[@varname(x2)], 20.0) + function cond_mean1(c) + @assert @varname(var1) in keys(c) + @assert @varname(x1) in keys(c) + return Normal(0.0, prior_std1) + end + cond_mean2(c) = cond_mean(c[@varname(var2)], c[@varname(x2)], prior_std2) sampler = Gibbs( :mean1 => GibbsConditional(cond_mean1), :mean2 => GibbsConditional(cond_mean2), - (:var1, :var2) => HMC(0.1, 10), + :var1 => HMC(0.1, 10), + :var2 => HMC(0.1, 10), ) - chain = sample(StableRNG(24), model_condition_fix, sampler, 10_000) - @test mean(chain, :mean1) ≈ 0.0 atol = 0.05 - @test mean(chain, :mean2) ≈ true_mean2 atol = 0.05 + chain = sample(StableRNG(23), model_condition_fix, sampler, 10_000) + @test mean(chain, :mean1) ≈ 0.0 atol = 0.1 + @test mean(chain, :mean2) ≈ true_mean2 atol = 0.1 # As above, but reverse the order of condition and fix. model_fix_condition = fix(condition(base_model; x2=x2_obs); x1=x1_obs) - chain = sample(StableRNG(24), model_condition_fix, sampler, 10_000) - @test mean(chain, :mean1) ≈ 0.0 atol = 0.05 - @test mean(chain, :mean2) ≈ true_mean2 atol = 0.05 + chain = sample(StableRNG(23), model_condition_fix, sampler, 10_000) + @test mean(chain, :mean1) ≈ 0.0 atol = 0.1 + @test mean(chain, :mean2) ≈ true_mean2 atol = 0.1 end end @@ -254,6 +261,17 @@ using Turing @test mean(chain, Symbol("a[2]")) ≈ 10.0 atol = 0.05 @test mean(chain, Symbol("a[3]")) ≈ 20.0 atol = 0.05 end + + @testset "Helpful error outside Gibbs" begin + @model f() = x ~ Normal() + m = f() + cond_x(_) = Normal() + sampler = GibbsConditional(cond_x) + @test_throws( + "Are you trying to use GibbsConditional outside of Gibbs?", + sample(m, sampler, 3), + ) + end end end From 805bc6072fba9d6539dfec02521b3df90061d52c Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 18 Nov 2025 16:54:25 +0000 Subject: [PATCH 17/22] Apply suggestions from code review Co-authored-by: Penelope Yong --- src/mcmc/gibbs_conditional.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/mcmc/gibbs_conditional.jl b/src/mcmc/gibbs_conditional.jl index 75c232a12..56e203289 100644 --- a/src/mcmc/gibbs_conditional.jl +++ b/src/mcmc/gibbs_conditional.jl @@ -36,7 +36,7 @@ example. 
The return value of `get_cond_dists` should be one of the following: precision ~ Gamma(2, inv(3)) std = sqrt(1 / precision) m ~ Normal(0, std) - for i in 1:length(x) + for i in eachindex(x) x[i] ~ Normal(m, std) end end @@ -44,12 +44,12 @@ end # Define analytical conditionals function cond_precision(c) a = 2.0 - b = inv(3) + b = 3.0 m = c[@varname(m)] x = c[@varname(x)] n = length(x) a_new = a + (n + 1) / 2 - b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2 + b_new = b + sum(abs2, x .- m) / 2 + m^2 / 2 return Gamma(a_new, 1 / b_new) end From d0c3cf4d6ba12e52c6f41305ca411bf27934c6e0 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 18 Nov 2025 16:59:36 +0000 Subject: [PATCH 18/22] Add links for where to get analytical posteriors --- src/mcmc/gibbs_conditional.jl | 3 ++- test/mcmc/gibbs_conditional.jl | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/mcmc/gibbs_conditional.jl b/src/mcmc/gibbs_conditional.jl index 56e203289..5ff761a8a 100644 --- a/src/mcmc/gibbs_conditional.jl +++ b/src/mcmc/gibbs_conditional.jl @@ -41,7 +41,8 @@ example. The return value of `get_cond_dists` should be one of the following: end end -# Define analytical conditionals +# Define analytical conditionals. See +# https://en.wikipedia.org/wiki/Conjugate_prior#When_likelihood_function_is_a_continuous_distribution function cond_precision(c) a = 2.0 b = 3.0 diff --git a/test/mcmc/gibbs_conditional.jl b/test/mcmc/gibbs_conditional.jl index 1fddfb3ea..a69ac4128 100644 --- a/test/mcmc/gibbs_conditional.jl +++ b/test/mcmc/gibbs_conditional.jl @@ -17,7 +17,8 @@ using Turing end end - # Define analytical conditionals + # Define analytical conditionals. See + # https://en.wikipedia.org/wiki/Conjugate_prior#When_likelihood_function_is_a_continuous_distribution function cond_precision(c) a = 2.0 b = 3.0 @@ -91,7 +92,8 @@ using Turing return x ~ MvNormal(fill(mean, dim), I * var) end - # Conditional posterior for mean given var and x + # Conditional posterior for mean given var and x. See + # https://en.wikipedia.org/wiki/Conjugate_prior#When_likelihood_function_is_a_continuous_distribution function cond_mean(c) var = c[@varname(var)] x = c[@varname(x)] From 98f4213a72e6b333142197e640d84031758ebf39 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 18 Nov 2025 17:00:41 +0000 Subject: [PATCH 19/22] Update TODO note --- src/mcmc/gibbs_conditional.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/mcmc/gibbs_conditional.jl b/src/mcmc/gibbs_conditional.jl index 5ff761a8a..0a5bcbe24 100644 --- a/src/mcmc/gibbs_conditional.jl +++ b/src/mcmc/gibbs_conditional.jl @@ -87,8 +87,7 @@ function build_variable_dict(model::DynamicPPL.Model) context = model.context cond_nt = DynamicPPL.conditioned(context) fixed_nt = DynamicPPL.fixed(context) - # TODO(mhauru) Can we avoid invlinking all the time? Note that this causes a model - # evaluation, which may be expensive. + # TODO(mhauru) Can we avoid invlinking all the time? 
global_vi = DynamicPPL.invlink(get_gibbs_global_varinfo(context), model) return merge( DynamicPPL.values_as(global_vi, Dict), From c74b0a001b24793b3e69768ed0bd69b08ac01beb Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 18 Nov 2025 17:21:43 +0000 Subject: [PATCH 20/22] Fix a GibbsConditional bug, add a test --- src/mcmc/gibbs_conditional.jl | 15 +++++------ test/mcmc/gibbs_conditional.jl | 46 +++++++++++++++++++++++----------- 2 files changed, 40 insertions(+), 21 deletions(-) diff --git a/src/mcmc/gibbs_conditional.jl b/src/mcmc/gibbs_conditional.jl index 0a5bcbe24..2a901a733 100644 --- a/src/mcmc/gibbs_conditional.jl +++ b/src/mcmc/gibbs_conditional.jl @@ -85,17 +85,18 @@ set in GibbsContext, ConditionContext, or FixedContext. """ function build_variable_dict(model::DynamicPPL.Model) context = model.context - cond_nt = DynamicPPL.conditioned(context) - fixed_nt = DynamicPPL.fixed(context) + cond_vals = DynamicPPL.conditioned(context) + fixed_vals = DynamicPPL.fixed(context) # TODO(mhauru) Can we avoid invlinking all the time? global_vi = DynamicPPL.invlink(get_gibbs_global_varinfo(context), model) + # TODO(mhauru) This creates a lot of Dicts, which are then immediately merged into one. + # Also, DynamicPPL.to_varname_dict is known to be inefficient. Make a more efficient + # implementation. return merge( DynamicPPL.values_as(global_vi, Dict), - Dict( - (DynamicPPL.VarName{sym}() => val for (sym, val) in pairs(cond_nt))..., - (DynamicPPL.VarName{sym}() => val for (sym, val) in pairs(fixed_nt))..., - (DynamicPPL.VarName{sym}() => val for (sym, val) in pairs(model.args))..., - ), + DynamicPPL.to_varname_dict(cond_vals), + DynamicPPL.to_varname_dict(fixed_vals), + DynamicPPL.to_varname_dict(model.args), ) end diff --git a/test/mcmc/gibbs_conditional.jl b/test/mcmc/gibbs_conditional.jl index a69ac4128..a226d8475 100644 --- a/test/mcmc/gibbs_conditional.jl +++ b/test/mcmc/gibbs_conditional.jl @@ -239,29 +239,47 @@ using Turing # `conditionals` are passed through correctly. @model function f() a = Vector{Float64}(undef, 3) + a[1] ~ Normal(0.0) + a[2] ~ Normal(10.0) + a[3] ~ Normal(20.0) + b = Vector{Float64}(undef, 3) # These priors will be completely ignored in the sampling. 
- a[1] ~ Normal() - a[2] ~ Normal() - a[3] ~ Normal() + b[1] ~ Normal() + b[2] ~ Normal() + b[3] ~ Normal() return nothing end m = f() - function conditionals(c) - d1 = Normal(0, 1) - d2 = Normal(c[@varname(a[1])] + 10, 1) - d3 = Normal(c[@varname(a[2])] + 10, 1) - return Dict(@varname(a[1]) => d1, @varname(a[2]) => d2, @varname(a[3]) => d3) + function conditionals_b(c) + d1 = Normal(c[@varname(a[1])], 1) + d2 = Normal(c[@varname(a[2])], 1) + d3 = Normal(c[@varname(a[3])], 1) + return Dict(@varname(b[1]) => d1, @varname(b[2]) => d2, @varname(b[3]) => d3) end sampler = Gibbs( - (@varname(a[1]), @varname(a[2]), @varname(a[3])) => - GibbsConditional(conditionals), + (@varname(b[1]), @varname(b[2]), @varname(b[3])) => + GibbsConditional(conditionals_b), + (@varname(a[1]), @varname(a[2]), @varname(a[3])) => ESS(), ) - chain = sample(StableRNG(23), m, sampler, 1_000) - @test mean(chain, Symbol("a[1]")) ≈ 0.0 atol = 0.05 - @test mean(chain, Symbol("a[2]")) ≈ 10.0 atol = 0.05 - @test mean(chain, Symbol("a[3]")) ≈ 20.0 atol = 0.05 + chain = sample(StableRNG(23), m, sampler, 10_000) + @test mean(chain, Symbol("b[1]")) ≈ 0.0 atol = 0.05 + @test mean(chain, Symbol("b[2]")) ≈ 10.0 atol = 0.05 + @test mean(chain, Symbol("b[3]")) ≈ 20.0 atol = 0.05 + + m_condfix = fix( + condition(m, Dict(@varname(a[1]) => 100.0)), Dict(@varname(a[2]) => 200.0) + ) + sampler = Gibbs( + (@varname(b[1]), @varname(b[2]), @varname(b[3])) => + GibbsConditional(conditionals_b), + @varname(a[3]) => ESS(), + ) + chain = sample(StableRNG(23), m_condfix, sampler, 10_000) + @test mean(chain, Symbol("b[1]")) ≈ 100.0 atol = 0.05 + @test mean(chain, Symbol("b[2]")) ≈ 200.0 atol = 0.05 + @test mean(chain, Symbol("b[3]")) ≈ 20.0 atol = 0.05 end @testset "Helpful error outside Gibbs" begin From 4a7d08cf1efa6524bcd638771ce37774f333fb82 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 18 Nov 2025 17:25:06 +0000 Subject: [PATCH 21/22] Set seeds better --- test/mcmc/gibbs_conditional.jl | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/test/mcmc/gibbs_conditional.jl b/test/mcmc/gibbs_conditional.jl index a226d8475..07d676df1 100644 --- a/test/mcmc/gibbs_conditional.jl +++ b/test/mcmc/gibbs_conditional.jl @@ -39,31 +39,30 @@ using Turing return Normal(m_mean, sqrt(m_var)) end - rng = StableRNG(23) x_obs = [1.0, 2.0, 3.0, 2.5, 1.5] model = inverse_gdemo(x_obs) reference_sampler = NUTS() - reference_chain = sample(rng, model, reference_sampler, 10_000) + reference_chain = sample(StableRNG(23), model, reference_sampler, 10_000) # Use both conditionals, check results against reference sampler. 
sampler = Gibbs( :precision => GibbsConditional(cond_precision), :m => GibbsConditional(cond_m) ) - chain = sample(rng, model, sampler, 1_000) + chain = sample(StableRNG(23), model, sampler, 1_000) @test size(chain, 1) == 1_000 @test mean(chain, :precision) ≈ mean(reference_chain, :precision) atol = 0.1 @test mean(chain, :m) ≈ mean(reference_chain, :m) atol = 0.1 # Mix GibbsConditional with an MCMC sampler sampler = Gibbs(:precision => GibbsConditional(cond_precision), :m => MH()) - chain = sample(rng, model, sampler, 1_000) + chain = sample(StableRNG(23), model, sampler, 1_000) @test size(chain, 1) == 1_000 @test mean(chain, :precision) ≈ mean(reference_chain, :precision) atol = 0.1 @test mean(chain, :m) ≈ mean(reference_chain, :m) atol = 0.1 sampler = Gibbs(:m => GibbsConditional(cond_m), :precision => HMC(0.1, 10)) - chain = sample(rng, model, sampler, 1_000) + chain = sample(StableRNG(23), model, sampler, 1_000) @test size(chain, 1) == 1_000 @test mean(chain, :precision) ≈ mean(reference_chain, :precision) atol = 0.1 @test mean(chain, :m) ≈ mean(reference_chain, :m) atol = 0.1 @@ -79,7 +78,7 @@ using Turing :m => GibbsConditional(cond_m), :m => PG(10), ) - chain = sample(rng, model, sampler, 1_000) + chain = sample(StableRNG(23), model, sampler, 1_000) @test size(chain, 1) == 1_000 @test mean(chain, :precision) ≈ mean(reference_chain, :precision) atol = 0.1 @test mean(chain, :m) ≈ mean(reference_chain, :m) atol = 0.1 @@ -107,13 +106,12 @@ using Turing return Normal(post_mean, sqrt(post_var)) end - rng = StableRNG(23) dim = 1_000 true_mean = 2.0 - x_obs = randn(rng, dim) .+ true_mean + x_obs = randn(StableRNG(23), dim) .+ true_mean model = simple_normal(dim) | (; x=x_obs) sampler = Gibbs(:mean => GibbsConditional(cond_mean), :var => MH()) - chain = sample(rng, model, sampler, 1_000) + chain = sample(StableRNG(23), model, sampler, 1_000) # The correct posterior mean isn't true_mean, but it is very close, because we # have a lot of data. @test mean(chain, :mean) ≈ true_mean atol = 0.05 @@ -145,13 +143,12 @@ using Turing return Normal(post_mean, sqrt(post_var)) end - rng = StableRNG(23) dim1 = 1_000 true_mean1 = -10.0 - x1_obs = randn(rng, dim1) .+ true_mean1 + x1_obs = randn(StableRNG(23), dim1) .+ true_mean1 dim2 = 2_000 true_mean2 = -20.0 - x2_obs = randn(rng, dim2) .+ true_mean2 + x2_obs = randn(StableRNG(24), dim2) .+ true_mean2 base_model = double_simple_normal(dim1, dim2) # Test different ways of returning values from the conditional function. From 744d2548aeba4e2eeb68637657c6cddd9b9687c7 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 19 Nov 2025 16:26:27 +0000 Subject: [PATCH 22/22] Use getvalue in docstring --- src/mcmc/gibbs_conditional.jl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/mcmc/gibbs_conditional.jl b/src/mcmc/gibbs_conditional.jl index 2a901a733..8586a002d 100644 --- a/src/mcmc/gibbs_conditional.jl +++ b/src/mcmc/gibbs_conditional.jl @@ -46,8 +46,12 @@ end function cond_precision(c) a = 2.0 b = 3.0 - m = c[@varname(m)] - x = c[@varname(x)] + # We use AbstractPPL.getvalue instead of indexing into `c` directly to guard against + # issues where e.g. you try to get `c[@varname(x[1])]` but only `@varname(x)` is present + # in `c`. `getvalue` handles that gracefully, `getindex` doesn't. In this case + # `getindex` would suffice, but `getvalue` is good practice. 
+ m = AbstractPPL.getvalue(c, @varname(m)) + x = AbstractPPL.getvalue(c, @varname(x)) n = length(x) a_new = a + (n + 1) / 2 b_new = b + sum(abs2, x .- m) / 2 + m^2 / 2 @@ -55,8 +59,8 @@ function cond_precision(c) end function cond_m(c) - precision = c[@varname(precision)] - x = c[@varname(x)] + precision = AbstractPPL.getvalue(c, @varname(precision)) + x = AbstractPPL.getvalue(c, @varname(x)) n = length(x) m_mean = sum(x) / (n + 1) m_var = 1 / (precision * (n + 1))
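The guard described in the comments above can be seen directly in a small hedged sketch; the values are illustrative, and only the `AbstractPPL.getvalue` call already used in this docstring is assumed:

```julia
import AbstractPPL
using DynamicPPL: @varname

c = Dict(@varname(x) => [1.0, 2.0, 3.0])
AbstractPPL.getvalue(c, @varname(x))     # [1.0, 2.0, 3.0]
AbstractPPL.getvalue(c, @varname(x[1]))  # 1.0: the index is resolved into the stored array
# c[@varname(x[1])]                      # KeyError: only @varname(x) is a key in the Dict
```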