Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
ea69430
update vi interface to match AdvancedVI@0.5
Red-Portal Oct 22, 2025
86ee6dd
revert unintended commit of `runtests.jl`
Red-Portal Oct 22, 2025
3e30e04
Merge branch 'breaking' of github.com:TuringLang/Turing.jl into bump_…
Red-Portal Oct 24, 2025
d870045
update docs for `vi`
Red-Portal Oct 24, 2025
2d928e0
add history entry for `AdvancedVI@0.5`
Red-Portal Oct 24, 2025
5211b37
remove export for removed symbol
Red-Portal Oct 24, 2025
f0d615d
fix formatting
Red-Portal Oct 24, 2025
1b2351f
fix formatting
Red-Portal Oct 24, 2025
2be31b4
tidy tests advi
Red-Portal Oct 24, 2025
e48ae42
fix rename file `advi.jl` to `vi.jl` to reflect naming changes
Red-Portal Oct 24, 2025
44f7762
fix docs
Red-Portal Oct 25, 2025
fd0e928
fix HISTORY.md
Red-Portal Oct 25, 2025
77276bd
fix HISTORY.md
Red-Portal Oct 25, 2025
cb1620c
Merge branch 'main' of github.com:TuringLang/Turing.jl into bump_adva…
Red-Portal Oct 25, 2025
e70ddb4
update history
Red-Portal Oct 25, 2025
115802d
Merge branch 'bump_advancedvi_0.5' of github.com:TuringLang/Turing.jl…
Red-Portal Oct 25, 2025
25b5087
Merge branch 'main' of github.com:TuringLang/Turing.jl into bump_adva…
Red-Portal Nov 19, 2025
4c02f7b
bump AdvancedVI version
Red-Portal Nov 19, 2025
6518b82
add exports new algorithms, modify `vi` to operate in unconstrained
Red-Portal Nov 19, 2025
5bd6978
Merge branch 'breaking' of github.com:TuringLang/Turing.jl into bump_…
Red-Portal Nov 19, 2025
874a0b2
add clarification on initializing unconstrained algorithms
Red-Portal Nov 19, 2025
e021eb7
update api
Red-Portal Nov 19, 2025
eec7ef2
run formatter
Red-Portal Nov 19, 2025
b6d8202
run formatter
Red-Portal Nov 19, 2025
b900ab4
run formatter
Red-Portal Nov 19, 2025
e71b07b
run formatter
Red-Portal Nov 19, 2025
c08de12
run formatter
Red-Portal Nov 19, 2025
ae80f1e
run formatter
Red-Portal Nov 19, 2025
73bd309
run formatter
Red-Portal Nov 19, 2025
eaac4c3
run formatter
Red-Portal Nov 19, 2025
757ebb4
revert changes to README
Red-Portal Nov 19, 2025
05ab711
fix wrong use of transformation in vi
Red-Portal Nov 20, 2025
91606b5
change inital value for scale matrices to 0.6*I and update docs
Red-Portal Nov 20, 2025
722153a
run formatter
Red-Portal Nov 20, 2025
65bfaa3
fix rename advi to vi
Red-Portal Nov 21, 2025
61e59a6
Merge branch 'bump_advancedvi_0.5' of github.com:TuringLang/Turing.jl…
Red-Portal Nov 21, 2025
4a039dd
add batch-and-match
Red-Portal Nov 21, 2025
f782f56
fix format api table
Red-Portal Nov 21, 2025
5251ee7
fix use fullrank Gaussian in tests since all algorithm support it
Red-Portal Nov 21, 2025
b665f96
fix tweak step sizes, remove unused kwargs
Red-Portal Nov 21, 2025
15af544
fix increase budgets for failing algorithms
Red-Portal Nov 21, 2025
7711e42
run formatter
Red-Portal Nov 21, 2025
a319962
rename main variational inference file to match module name
Red-Portal Nov 22, 2025
6e76afa
run formatter
Red-Portal Nov 27, 2025
2e78774
fix docstring
Red-Portal Nov 27, 2025
98d07f8
run formatter
Red-Portal Nov 27, 2025
cc2cac9
update docstring
Red-Portal Nov 27, 2025
2d88ddd
update docstring
Red-Portal Nov 27, 2025
025202f
fix missing docstring for `unconstrained` keyword argument
Red-Portal Nov 27, 2025
59fe8dc
fix relax assert transformed dist to an exception and add test
Red-Portal Nov 27, 2025
b1aa675
Update HISTORY.md
Red-Portal Nov 27, 2025
22895b0
run formatter
Red-Portal Nov 27, 2025
733f0be
run formatter
Red-Portal Nov 27, 2025
2530f05
run formatter
Red-Portal Nov 27, 2025
cb5a798
fix document the change in output of `vi` in `HISTORY`
Red-Portal Nov 27, 2025
8bd0ba7
run formatter
Red-Portal Nov 27, 2025
ed25920
run formatter
Red-Portal Nov 27, 2025
6be562d
run formatter
Red-Portal Nov 27, 2025
df1b832
run formatter
Red-Portal Nov 27, 2025
d83148b
add missing namespace
Red-Portal Nov 27, 2025
a686cae
fix add missing import for `@test_throws`
Red-Portal Nov 29, 2025
31256e3
Merge branch 'breaking' into bump_advancedvi_0.5
Red-Portal Nov 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ Accessors = "0.1"
AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6, 0.7, 0.8"
AdvancedMH = "0.8"
AdvancedPS = "0.7"
AdvancedVI = "0.4"
AdvancedVI = "0.5"
BangBang = "0.4.2"
Bijectors = "0.14, 0.15"
Compat = "4.15.0"
Expand Down
3 changes: 3 additions & 0 deletions src/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@ export
q_locationscale,
q_meanfield_gaussian,
q_fullrank_gaussian,
KLMinRepGradProxDescent,
KLMinRepGradDescent,
KLMinScoreGradDescent,
# ADTypes
AutoForwardDiff,
AutoReverseDiff,
Expand Down
80 changes: 34 additions & 46 deletions src/variational/VariationalInference.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@

module Variational

using DynamicPPL
using AdvancedVI:
AdvancedVI, KLMinRepGradDescent, KLMinRepGradProxDescent, KLMinScoreGradDescent
using ADTypes
using Bijectors: Bijectors
using Distributions
using DynamicPPL
using LinearAlgebra
using LogDensityProblems
using Random
using ..Turing: DEFAULT_ADTYPE, PROGRESS

import ..Turing: DEFAULT_ADTYPE, PROGRESS

import AdvancedVI
import Bijectors

export vi, q_locationscale, q_meanfield_gaussian, q_fullrank_gaussian

include("deprecated.jl")
export vi,
q_locationscale,
q_meanfield_gaussian,
q_fullrank_gaussian,
KLMinRepGradProxDescent,
KLMinRepGradDescent,
KLMinScoreGradDescent

"""
q_initialize_scale(
Expand Down Expand Up @@ -248,76 +251,61 @@ end
"""
vi(
[rng::Random.AbstractRNG,]
model::DynamicPPL.Model;
model::DynamicPPL.Model,
q,
n_iterations::Int;
objective::AdvancedVI.AbstractVariationalObjective = AdvancedVI.RepGradELBO(
10; entropy = AdvancedVI.ClosedFormEntropyZeroGradient()
),
max_iter::Int;
algorithm::AdvancedVI.AbstractVariationalAlgorithm = KLMinRepGradProxDescent(DEFAULT_ADTYPE; n_samples=10),
show_progress::Bool = Turing.PROGRESS[],
optimizer::Optimisers.AbstractRule = AdvancedVI.DoWG(),
averager::AdvancedVI.AbstractAverager = AdvancedVI.PolynomialAveraging(),
operator::AdvancedVI.AbstractOperator = AdvancedVI.ProximalLocationScaleEntropy(),
adtype::ADTypes.AbstractADType = Turing.DEFAULT_ADTYPE,
kwargs...
)

Approximating the target `model` via variational inference by optimizing `objective` with the initialization `q`.
Approximate the target `model` via the variational inference algorithm `algorithm` by starting from the initial variational approximation `q`.
This is a thin wrapper around `AdvancedVI.optimize`.
The default `algorithm` assumes `q` uses `AdvancedVI.MvLocationScale`, which can be constructed by invoking `q_fullrank_gaussian` or `q_meanfield_gaussian`.
For other variational families, refer to `AdvancedVI` to determine the best algorithm and options.

# Arguments
- `model`: The target `DynamicPPL.Model`.
- `q`: The initial variational approximation.
- `n_iterations`: Number of optimization steps.
- `max_iter`: Maximum number of steps.

# Keyword Arguments
- `objective`: Variational objective to be optimized.
- `algorithm`: Variational inference algorithm.
- `show_progress`: Whether to show the progress bar.
- `optimizer`: Optimization algorithm.
- `averager`: Parameter averaging strategy.
- `operator`: Operator applied after each optimization step.
- `adtype`: Automatic differentiation backend.
- `adtype`: Automatic differentiation backend to be applied to the log-density. The default value for `algorithm` also uses this backend for differentiating the variational objective.

See the docs of `AdvancedVI.optimize` for additional keyword arguments.

# Returns
- `q`: Variational distribution formed by the last iterate of the optimization run.
- `q_avg`: Variational distribution formed by the averaged iterates according to `averager`.
- `state`: Collection of states used for optimization. This can be used to resume from a past call to `vi`.
- `info`: Information generated during the optimization run.
- `q`: Output variational distribution of `algorithm`.
- `state`: Collection of states used by `algorithm`. This can be used to resume from a past call to `vi`.
- `info`: Information generated while executing `algorithm`.
"""
function vi(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
q,
n_iterations::Int;
objective=AdvancedVI.RepGradELBO(
10; entropy=AdvancedVI.ClosedFormEntropyZeroGradient()
),
show_progress::Bool=PROGRESS[],
optimizer=AdvancedVI.DoWG(),
averager=AdvancedVI.PolynomialAveraging(),
operator=AdvancedVI.ProximalLocationScaleEntropy(),
max_iter::Int,
args...;
adtype::ADTypes.AbstractADType=DEFAULT_ADTYPE,
algorithm=KLMinRepGradProxDescent(adtype; n_samples=10),
show_progress::Bool=PROGRESS[],
kwargs...,
)
return AdvancedVI.optimize(
rng,
LogDensityFunction(model),
objective,
algorithm,
max_iter,
LogDensityFunction(model; adtype),
q,
n_iterations;
args...;
show_progress=show_progress,
adtype,
optimizer,
averager,
operator,
kwargs...,
)
end

function vi(model::DynamicPPL.Model, q, n_iterations::Int; kwargs...)
return vi(Random.default_rng(), model, q, n_iterations; kwargs...)
function vi(model::DynamicPPL.Model, q, max_iter::Int; kwargs...)
return vi(Random.default_rng(), model, q, max_iter; kwargs...)
end

end
61 changes: 0 additions & 61 deletions src/variational/deprecated.jl

This file was deleted.

2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ AbstractMCMC = "5"
AbstractPPL = "0.11, 0.12, 0.13"
AdvancedMH = "0.6, 0.7, 0.8"
AdvancedPS = "0.7"
AdvancedVI = "0.4"
AdvancedVI = "0.5"
Aqua = "0.8"
BangBang = "0.4"
Bijectors = "0.14, 0.15"
Expand Down
72 changes: 26 additions & 46 deletions test/variational/advi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,16 @@ using Distributions: Dirichlet, Normal
using LinearAlgebra
using MCMCChains: Chains
using Random
using ReverseDiff
using StableRNGs: StableRNG
using Test: @test, @testset
using Turing
using Turing.Variational

@testset "ADVI" begin
adtype = AutoReverseDiff()
operator = AdvancedVI.ClipScale()

@testset "q initialization" begin
m = gdemo_default
d = length(Turing.DynamicPPL.VarInfo(m)[:])
Expand All @@ -41,86 +45,62 @@ using Turing.Variational

@testset "default interface" begin
for q0 in [q_meanfield_gaussian(gdemo_default), q_fullrank_gaussian(gdemo_default)]
_, q, _, _ = vi(gdemo_default, q0, 100; show_progress=Turing.PROGRESS[])
q, _, _ = vi(gdemo_default, q0, 100; show_progress=Turing.PROGRESS[], adtype)
c1 = rand(q, 10)
end
end

@testset "custom interface $name" for (name, objective, operator, optimizer) in [
(
"ADVI with closed-form entropy",
AdvancedVI.RepGradELBO(10),
AdvancedVI.ProximalLocationScaleEntropy(),
AdvancedVI.DoG(),
),
@testset "custom algorithm $name" for (name, algorithm) in [
(
"ADVI with proximal entropy",
AdvancedVI.RepGradELBO(10; entropy=AdvancedVI.ClosedFormEntropyZeroGradient()),
AdvancedVI.ClipScale(),
AdvancedVI.DoG(),
"KLMinRepGradProxDescent",
KLMinRepGradProxDescent(AutoReverseDiff(); n_samples=10),
),
(
"ADVI with STL entropy",
AdvancedVI.RepGradELBO(10; entropy=AdvancedVI.StickingTheLandingEntropy()),
AdvancedVI.ClipScale(),
AdvancedVI.DoG(),
"KLMinRepGradDescent",
KLMinRepGradDescent(AutoReverseDiff(); operator, n_samples=10),
),
]
T = 1000
q, q_avg, _, _ = vi(
q, _, _ = vi(
gdemo_default,
q_meanfield_gaussian(gdemo_default),
T;
objective,
optimizer,
operator,
algorithm,
adtype,
show_progress=Turing.PROGRESS[],
)

N = 1000
c1 = rand(q_avg, N)
c2 = rand(q, N)
end

@testset "inference $name" for (name, objective, operator, optimizer) in [
@testset "inference $name" for (name, algorithm) in [
(
"ADVI with closed-form entropy",
AdvancedVI.RepGradELBO(10),
AdvancedVI.ProximalLocationScaleEntropy(),
AdvancedVI.DoG(),
"KLMinRepGradProxDescent",
KLMinRepGradProxDescent(AutoReverseDiff(); n_samples=10),
),
(
"ADVI with proximal entropy",
RepGradELBO(10; entropy=AdvancedVI.ClosedFormEntropyZeroGradient()),
AdvancedVI.ClipScale(),
AdvancedVI.DoG(),
),
(
"ADVI with STL entropy",
AdvancedVI.RepGradELBO(10; entropy=AdvancedVI.StickingTheLandingEntropy()),
AdvancedVI.ClipScale(),
AdvancedVI.DoG(),
"KLMinRepGradDescent",
KLMinRepGradDescent(AutoReverseDiff(); operator, n_samples=10),
),
]
rng = StableRNG(0x517e1d9bf89bf94f)

T = 1000
q, q_avg, _, _ = vi(
q, _, _ = vi(
rng,
gdemo_default,
q_meanfield_gaussian(gdemo_default),
T;
optimizer,
algorithm,
adtype,
show_progress=Turing.PROGRESS[],
)

N = 1000
for q_out in [q_avg, q]
samples = transpose(rand(rng, q_out, N))
chn = Chains(reshape(samples, size(samples)..., 1), ["s", "m"])
samples = transpose(rand(rng, q, N))
chn = Chains(reshape(samples, size(samples)..., 1), ["s", "m"])

check_gdemo(chn; atol=0.5)
end
check_gdemo(chn; atol=0.5)
end

# regression test for:
Expand All @@ -143,7 +123,7 @@ using Turing.Variational
@test all(x0 .≈ x0_inv)

# And regression for https://github.com/TuringLang/Turing.jl/issues/2160.
_, q, _, _ = vi(rng, m, q_meanfield_gaussian(m), 1000)
q, _, _ = vi(rng, m, q_meanfield_gaussian(m), 1000; adtype)
x = rand(rng, q, 1000)
@test mean(eachcol(x)) ≈ [0.5, 0.5] atol = 0.1
end
Expand All @@ -158,7 +138,7 @@ using Turing.Variational
end

model = demo_issue2205() | (y=1.0,)
_, q, _, _ = vi(rng, model, q_meanfield_gaussian(model), 1000)
q, _, _ = vi(rng, model, q_meanfield_gaussian(model), 1000; adtype)
# True mean.
mean_true = 1 / 2
var_true = 1 / 2
Expand Down
Loading