From e62fba1e8f2e6620bb33eb0f0ac2fd9953db4d61 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 18 Apr 2024 20:19:46 +0100 Subject: [PATCH 01/11] added `extract_realizations` providing a safe way to obtain the values corresponding to a varname from a varinfo as seen in the model --- docs/src/api.md | 6 ++ src/DynamicPPL.jl | 1 + src/contexts.jl | 173 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 180 insertions(+) diff --git a/docs/src/api.md b/docs/src/api.md index 9b98f9dc6..dfd609ac1 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -143,6 +143,12 @@ Sometimes it can be useful to extract the priors of a model. This is the possibl extract_priors ``` +Safe extraction of realizations from a given [`AbstractVarInfo`](@ref) can be done using [`extract_realizations`](@ref). + +```@docs +extract_realizations +``` + ```@docs NamedDist ``` diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index ce6605250..055a1581a 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -93,6 +93,7 @@ export AbstractVarInfo, getargnames, generated_quantities, extract_priors, + extract_realizations, # Samplers Sampler, SampleFromPrior, diff --git a/src/contexts.jl b/src/contexts.jl index 83da5d929..0f10904b3 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -664,3 +664,176 @@ function fixed(context::FixedContext) # precedence over decendants of `context`. return merge(context.values, fixed(childcontext(context))) end + +""" + RealizationExtractorContext + +A context that is used to extract realizations from a model. + +This is particularly useful when working in unconstrained space, but one +wants to extract the realization of a model in a constrained space. + +# Fields +$(TYPEDFIELDS) +""" +struct RealizationExtractorContext{T,C<:AbstractContext} <: AbstractContext + "values that are extracted from the model" + values::T + "child context" + context::C +end + +RealizationExtractorContext(values) = RealizationExtractorContext(values, DefaultContext()) +function RealizationExtractorContext(context::AbstractContext) + return RealizationExtractorContext(OrderedDict(), context) +end + +NodeTrait(::RealizationExtractorContext) = IsParent() +childcontext(context::RealizationExtractorContext) = context.context +function setchildcontext(context::RealizationExtractorContext, child) + return RealizationExtractorContext(context.values, child) +end + +function Base.push!(context::RealizationExtractorContext, vn::VarName, value) + return setindex!(context.values, value, vn) +end + +function broadcast_push!(context::RealizationExtractorContext, vns, dists, values) + return push!.((context,), vns, values) +end + +# This will be hit if we're broadcasting an `AbstractMatrix` over a `MultivariateDistribution`. +function broadcast_push!( + context::RealizationExtractorContext, vns::AbstractVector, values::AbstractMatrix +) + for (vn, col) in zip(vns, eachcol(values)) + push!(context, vn, col) + end +end + +# `tilde_asssume` +function tilde_assume(context::RealizationExtractorContext, right, vn, vi) + value, logp, vi = tilde_assume(childcontext(context), right, vn, vi) + # Save the value. + push!(context, vn, value) + # Pass on. + return value, logp, vi +end +function tilde_assume( + rng::Random.AbstractRNG, context::RealizationExtractorContext, sampler, right, vn, vi +) + value, logp, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi) + # Save the value. + push!(context, vn, value) + # Pass on. + return value, logp, vi +end + +# `dot_tilde_assume` +function dot_tilde_assume(context::RealizationExtractorContext, right, left, vn, vi) + value, logp, vi = dot_tilde_assume(childcontext(context), right, left, vn, vi) + + # Save the value. + # FIXME: This is not going to work for arbitrary broadcasting. + _right, _left, _vns = unwrap_right_left_vns(right, var, vn) + broadcast_push!(context, _vns, value) + + return value, logp, vi +end +function dot_tilde_assume( + rng::Random.AbstractRNG, + context::RealizationExtractorContext, + sampler, + right, + left, + vn, + vi, +) + value, logp, vi = dot_tilde_assume( + rng, childcontext(context), sampler, right, left, vn, vi + ) + # Save the value. + # FIXME: This is not going to work for arbitrary broadcasting. + _right, _left, _vns = unwrap_right_left_vns(right, left, vn) + broadcast_push!(context, _vns, value) + + return value, logp, vi +end + +""" + extract_realizations([rng::Random.AbstractRNG, ]model::Model[, varinfo::AbstractVarInfo]) + +Extract realizations from the `model` for a given `varinfo` through a evaluation of the model. + +If no `varinfo` is provided, then this is effectively the same as +[`Base.rand(rng::Random.AbstractRNG, model::Model)`]. + +More specifically, this method attempts to extract the realization _as seen in the model_. +For example, `x[1] ~ truncated(Normal(); lower=0)` will result in a realization compatible +with `truncated(Normal(); lower=0)` regardless of whether `varinfo` is working in unconstrained +space. + +Hence this method is a "safe" way of obtaining realizations in constrained space at the cost +of additional model evaluations. + +# Examples + +## When `VarInfo` fails + +The following demonstrates a common pitfall when working with [`VarInfo`](@ref) and constrained variables. + +```jldoctest +julia> using Distributions, StableRNGs + +julia> rng = StableRNG(42); + +julia> @model function model_changing_support() + x ~ Bernoulli(0.5) + y ~ x == 1 ? Uniform(0, 1) : Uniform(11, 12) + end; + +julia> model = model_changing_support(); + +julia> # Construct initial type-stable `VarInfo`. + varinfo = VarInfo(rng, model); + +julia> # Link it so it works in unconstrained space. + varinfo_linked = DynamicPPL.link(varinfo, model); + +julia> # Perform computations in unconstrained space, e.g. changing the values of `θ`. + # Flip `x` so we hit the other support of `y`. + θ = [!varinfo[@varname(x)], rand(rng)]; + +julia> # Update the `VarInfo` with the new values. + varinfo_linked = DynamicPPL.unflatten(varinfo_linked, θ); + +julia> # Determine the expected support of `y`. + lb, ub = θ[1] == 1 ? (0, 1) : (11, 12) +(0, 1) + +julia> # Approach 1: Convert back to constrained space using `invlink` and extract. + varinfo_invlinked = DynamicPPL.invlink(varinfo_linked, model); + +julia> # (×) Fails! Because `VarInfo` _saves_ the original distributions + # used in the very first model evaluation, hence the support of `y` + # is not updated even though `x` has changed. + lb ≤ varinfo_invlinked[@varname(y)] ≤ ub +false + +julia> # Approach 2: Extract realizations using `extract_realizations`. + # (✓) `extract_realizations` will re-run the model and extract + # the correct realization of `y` given the new values of `x`. + lb ≤ extract_realizations(model, varinfo_linked)[@varname(y)] ≤ ub +true +``` +""" +function extract_realizations(model::Model, varinfo::AbstractVarInfo=VarInfo()) + return extract_realizations(Random.default_rng(), model, varinfo) +end +function extract_realizations( + rng::Random.AbstractRNG, model::Model, varinfo::AbstractVarInfo=VarInfo() +) + context = RealizationExtractorContext(DefaultContext()) + evaluate!!(model, varinfo, context) + return context.values +end From 7bc27fa599b2fc9ac6ac94fc154c7aefff25b70c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 18 Apr 2024 20:21:33 +0100 Subject: [PATCH 02/11] use `SamplingContext` instead of `DefaultContext` for `extract_realizations` so we sample variables not present in the varinfo --- src/contexts.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/contexts.jl b/src/contexts.jl index 0f10904b3..a12d6c6e4 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -833,7 +833,7 @@ end function extract_realizations( rng::Random.AbstractRNG, model::Model, varinfo::AbstractVarInfo=VarInfo() ) - context = RealizationExtractorContext(DefaultContext()) + context = RealizationExtractorContext(SamplingContext(rng)) evaluate!!(model, varinfo, context) return context.values end From 52df230aaefdc3847810e99ce623e2af8a7e6867 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 18 Apr 2024 20:56:28 +0100 Subject: [PATCH 03/11] added isstatic to model which indicates whether a model can be considered to be static --- src/model.jl | 50 ++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 46 insertions(+), 4 deletions(-) diff --git a/src/model.jl b/src/model.jl index c0cc2f26f..1ba1a412e 100644 --- a/src/model.jl +++ b/src/model.jl @@ -31,12 +31,21 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) ``` """ -struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} <: - AbstractProbabilisticProgram +struct Model{ + F, + argnames, + defaultnames, + missings, + Targs, + Tdefaults, + Ctx<:AbstractContext, + IsStatic<:Union{Val{false},Val{true}}, +} <: AbstractProbabilisticProgram f::F args::NamedTuple{argnames,Targs} defaults::NamedTuple{defaultnames,Tdefaults} context::Ctx + isstatic::IsStatic @doc """ Model{missings}(f, args::NamedTuple, defaults::NamedTuple) @@ -49,9 +58,10 @@ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractConte args::NamedTuple{argnames,Targs}, defaults::NamedTuple{defaultnames,Tdefaults}, context::Ctx=DefaultContext(), + isstatic::Union{Val{false},Val{true}}=Val{false}(), ) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx} - return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx}( - f, args, defaults, context + return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,typeof(isstatic)}( + f, args, defaults, context, isstatic ) end end @@ -78,6 +88,38 @@ function Model(f, args::NamedTuple, context::AbstractContext=DefaultContext(); k return Model(f, args, NamedTuple(kwargs), context) end +""" + is_static(model::Model) + +Return `true` if `model` has static support. +""" +is_static(model::Model) = model.isstatic isa Val{true} + +""" + set_static(model::Model, isstatic::Val{true},Val{false}) + +Set `model` to have static support if `isstatic` is `true`, otherwise not. +""" +function set_static(model::Model, isstatic::Union{Val{true},Val{false}}) + return Model{getmissings(model)}( + model.f, model.args, model.defaults, model.context, isstatic + ) +end + +""" + mark_as_static(model::Model) + +Mark `model` as having static support. +""" +mark_as_static(model::Model) = set_static(model, Val{true}()) + +""" + mark_as_dynamic(model::Model) + +Mark `model` as not having static support. +""" +mark_as_dynamic(model::Model) = set_static(model, Val{false}()) + function contextualize(model::Model, context::AbstractContext) return Model(model.f, model.args, model.defaults, context) end From d4a6446e2bb9e8cb06f0c0c8d19fec228fc7b832 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 18 Apr 2024 22:38:26 +0100 Subject: [PATCH 04/11] added tests for `extract_realizations` in addition to allowing specifying whether we should sample or just evaluate using a context arg --- src/contexts.jl | 33 +++++++++++++++++++++++---------- test/model.jl | 20 +++++++++++++++++++- 2 files changed, 42 insertions(+), 11 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index a12d6c6e4..7a6e9ff06 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -695,10 +695,10 @@ function setchildcontext(context::RealizationExtractorContext, child) end function Base.push!(context::RealizationExtractorContext, vn::VarName, value) - return setindex!(context.values, value, vn) + return setindex!(context.values, copy(value), vn) end -function broadcast_push!(context::RealizationExtractorContext, vns, dists, values) +function broadcast_push!(context::RealizationExtractorContext, vns, values) return push!.((context,), vns, values) end @@ -716,6 +716,7 @@ function tilde_assume(context::RealizationExtractorContext, right, vn, vi) value, logp, vi = tilde_assume(childcontext(context), right, vn, vi) # Save the value. push!(context, vn, value) + # Save the value. # Pass on. return value, logp, vi end @@ -734,7 +735,6 @@ function dot_tilde_assume(context::RealizationExtractorContext, right, left, vn, value, logp, vi = dot_tilde_assume(childcontext(context), right, left, vn, vi) # Save the value. - # FIXME: This is not going to work for arbitrary broadcasting. _right, _left, _vns = unwrap_right_left_vns(right, var, vn) broadcast_push!(context, _vns, value) @@ -753,7 +753,6 @@ function dot_tilde_assume( rng, childcontext(context), sampler, right, left, vn, vi ) # Save the value. - # FIXME: This is not going to work for arbitrary broadcasting. _right, _left, _vns = unwrap_right_left_vns(right, left, vn) broadcast_push!(context, _vns, value) @@ -761,7 +760,8 @@ function dot_tilde_assume( end """ - extract_realizations([rng::Random.AbstractRNG, ]model::Model[, varinfo::AbstractVarInfo]) + extract_realizations(model::Model[, varinfo::AbstractVarInfo, context::AbstractContext]) + extract_realizations(rng::Random.AbstractRNG, model::Model[, varinfo::AbstractVarInfo, context::AbstractContext]) Extract realizations from the `model` for a given `varinfo` through a evaluation of the model. @@ -776,6 +776,12 @@ space. Hence this method is a "safe" way of obtaining realizations in constrained space at the cost of additional model evaluations. +# Arguments +- `model::Model`: model to extract realizations from. +- `varinfo::AbstractVarInfo`: variable information to use for the extraction. +- `context::AbstractContext`: context to use for the extraction. If `rng` is specified, then `context` + will be wrapped in a [`SamplingContext`](@ref) with the provided `rng`. + # Examples ## When `VarInfo` fails @@ -827,13 +833,20 @@ julia> # Approach 2: Extract realizations using `extract_realizations`. true ``` """ -function extract_realizations(model::Model, varinfo::AbstractVarInfo=VarInfo()) - return extract_realizations(Random.default_rng(), model, varinfo) -end function extract_realizations( - rng::Random.AbstractRNG, model::Model, varinfo::AbstractVarInfo=VarInfo() + model::Model, + varinfo::AbstractVarInfo=VarInfo(), + context::AbstractContext=DefaultContext(), ) - context = RealizationExtractorContext(SamplingContext(rng)) + context = RealizationExtractorContext(context) evaluate!!(model, varinfo, context) return context.values end +function extract_realizations( + rng::Random.AbstractRNG, + model::Model, + varinfo::AbstractVarInfo=VarInfo(), + context::AbstractContext=DefaultContext(), +) + return extract_realizations(model, varinfo, SamplingContext(rng, context)) +end diff --git a/test/model.jl b/test/model.jl index f8303e260..7f4caff24 100644 --- a/test/model.jl +++ b/test/model.jl @@ -356,7 +356,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true ] @testset "$(model.f)" for model in models_to_test vns = DynamicPPL.TestUtils.varnames(model) - example_values = DynamicPPL.TestUtils.rand(model) + example_values = DynamicPPL.TestUtils.rand_prior_true(model) varinfos = filter( is_typed_varinfo, DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns), @@ -375,4 +375,22 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true end end end + + @testset "extract_realizations" begin + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + vns = DynamicPPL.TestUtils.varnames(model) + example_values = DynamicPPL.TestUtils.rand_prior_true(model) + varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns) + @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos + realizations = extract_realizations(model, varinfo) + # Ensure that all variables are found. + vns_found = collect(keys(realizations)) + @test vns ∩ vns_found == vns ∪ vns_found + # Ensure that the values are the same. + for vn in vns + @test realizations[vn] == varinfo[vn] + end + end + end + end end From ecc702442d31514dc2793a7126599f356cd57960 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 19 Apr 2024 14:29:08 +0100 Subject: [PATCH 05/11] fix doc reference --- src/contexts.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/contexts.jl b/src/contexts.jl index 672d5f21a..89aa120d5 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -766,7 +766,7 @@ end Extract realizations from the `model` for a given `varinfo` through a evaluation of the model. If no `varinfo` is provided, then this is effectively the same as -[`Base.rand(rng::Random.AbstractRNG, model::Model)`]. +[`Base.rand(rng::Random.AbstractRNG, model::Model)`](@ref). More specifically, this method attempts to extract the realization _as seen in the model_. For example, `x[1] ~ truncated(Normal(); lower=0)` will result in a realization compatible From a7f0656b5882b144b1fe5c7810169b893f032725 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 19 Apr 2024 16:02:49 +0100 Subject: [PATCH 06/11] renamed `isstatic` model fiel to `has_static_support` which is more accurate --- src/model.jl | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/model.jl b/src/model.jl index 247529b7f..b02309ab4 100644 --- a/src/model.jl +++ b/src/model.jl @@ -45,7 +45,7 @@ struct Model{ args::NamedTuple{argnames,Targs} defaults::NamedTuple{defaultnames,Tdefaults} context::Ctx - isstatic::IsStatic + has_static_support::IsStatic @doc """ Model{missings}(f, args::NamedTuple, defaults::NamedTuple) @@ -58,10 +58,10 @@ struct Model{ args::NamedTuple{argnames,Targs}, defaults::NamedTuple{defaultnames,Tdefaults}, context::Ctx=DefaultContext(), - isstatic::Union{Val{false},Val{true}}=Val{false}(), + has_static_support::Union{Val{false},Val{true}}=Val{false}(), ) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx} return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,typeof(isstatic)}( - f, args, defaults, context, isstatic + f, args, defaults, context, has_static_support ) end end @@ -89,36 +89,36 @@ function Model(f, args::NamedTuple, context::AbstractContext=DefaultContext(); k end """ - is_static(model::Model) + has_static_support(model::Model) Return `true` if `model` has static support. """ -is_static(model::Model) = model.isstatic isa Val{true} +has_static_support(model::Model) = model.has_static_support isa Val{true} """ - set_static(model::Model, isstatic::Val{true},Val{false}) + set_static_support(model::Model, isstatic::Val{true},Val{false}) Set `model` to have static support if `isstatic` is `true`, otherwise not. """ -function set_static(model::Model, isstatic::Union{Val{true},Val{false}}) +function set_static_support(model::Model, isstatic::Union{Val{true},Val{false}}) return Model{getmissings(model)}( model.f, model.args, model.defaults, model.context, isstatic ) end """ - mark_as_static(model::Model) + mark_as_static_support(model::Model) Mark `model` as having static support. """ -mark_as_static(model::Model) = set_static(model, Val{true}()) +mark_as_static_support(model::Model) = set_static_support(model, Val{true}()) """ - mark_as_dynamic(model::Model) + mark_as_dynamic_support(model::Model) Mark `model` as not having static support. """ -mark_as_dynamic(model::Model) = set_static(model, Val{false}()) +mark_as_dynamic_support(model::Model) = set_static_support(model, Val{false}()) function contextualize(model::Model, context::AbstractContext) return Model(model.f, model.args, model.defaults, context) From 7da8390bd34441024872931c72b423c42285af01 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 19 Apr 2024 16:03:51 +0100 Subject: [PATCH 07/11] Update src/model.jl --- src/model.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/model.jl b/src/model.jl index b02309ab4..5b94ec5dc 100644 --- a/src/model.jl +++ b/src/model.jl @@ -60,7 +60,7 @@ struct Model{ context::Ctx=DefaultContext(), has_static_support::Union{Val{false},Val{true}}=Val{false}(), ) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx} - return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,typeof(isstatic)}( + return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,typeof(has_static_support)}( f, args, defaults, context, has_static_support ) end From 1cd3673c80b71533c0b2277e790100fb2748cb03 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 19 Apr 2024 16:43:37 +0100 Subject: [PATCH 08/11] Update src/model.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/model.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/model.jl b/src/model.jl index 5b94ec5dc..863a32682 100644 --- a/src/model.jl +++ b/src/model.jl @@ -60,7 +60,9 @@ struct Model{ context::Ctx=DefaultContext(), has_static_support::Union{Val{false},Val{true}}=Val{false}(), ) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx} - return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,typeof(has_static_support)}( + return new{ + F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,typeof(has_static_support) + }( f, args, defaults, context, has_static_support ) end From 671620b5a8c93b47138dbfce430046ae55a26f95 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 19 Apr 2024 19:11:36 +0100 Subject: [PATCH 09/11] removed `has_static_support` and everything related --- src/model.jl | 50 ++++---------------------------------------------- 1 file changed, 4 insertions(+), 46 deletions(-) diff --git a/src/model.jl b/src/model.jl index b02309ab4..8c10ed36e 100644 --- a/src/model.jl +++ b/src/model.jl @@ -31,21 +31,12 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) ``` """ -struct Model{ - F, - argnames, - defaultnames, - missings, - Targs, - Tdefaults, - Ctx<:AbstractContext, - IsStatic<:Union{Val{false},Val{true}}, -} <: AbstractProbabilisticProgram +struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} <: + AbstractProbabilisticProgram f::F args::NamedTuple{argnames,Targs} defaults::NamedTuple{defaultnames,Tdefaults} context::Ctx - has_static_support::IsStatic @doc """ Model{missings}(f, args::NamedTuple, defaults::NamedTuple) @@ -58,10 +49,9 @@ struct Model{ args::NamedTuple{argnames,Targs}, defaults::NamedTuple{defaultnames,Tdefaults}, context::Ctx=DefaultContext(), - has_static_support::Union{Val{false},Val{true}}=Val{false}(), ) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx} - return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,typeof(isstatic)}( - f, args, defaults, context, has_static_support + return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx}( + f, args, defaults, context ) end end @@ -88,38 +78,6 @@ function Model(f, args::NamedTuple, context::AbstractContext=DefaultContext(); k return Model(f, args, NamedTuple(kwargs), context) end -""" - has_static_support(model::Model) - -Return `true` if `model` has static support. -""" -has_static_support(model::Model) = model.has_static_support isa Val{true} - -""" - set_static_support(model::Model, isstatic::Val{true},Val{false}) - -Set `model` to have static support if `isstatic` is `true`, otherwise not. -""" -function set_static_support(model::Model, isstatic::Union{Val{true},Val{false}}) - return Model{getmissings(model)}( - model.f, model.args, model.defaults, model.context, isstatic - ) -end - -""" - mark_as_static_support(model::Model) - -Mark `model` as having static support. -""" -mark_as_static_support(model::Model) = set_static_support(model, Val{true}()) - -""" - mark_as_dynamic_support(model::Model) - -Mark `model` as not having static support. -""" -mark_as_dynamic_support(model::Model) = set_static_support(model, Val{false}()) - function contextualize(model::Model, context::AbstractContext) return Model(model.f, model.args, model.defaults, context) end From 7f0ff386c136cddb382d20c01a540d72995d8122 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 19 Apr 2024 19:25:58 +0100 Subject: [PATCH 10/11] renamed `extract_realizations` to `values_as_in_model` to be a bit more descriptive (and similarly for the corresponding context) --- docs/src/api.md | 4 ++-- src/DynamicPPL.jl | 2 +- src/contexts.jl | 61 ++++++++++++++++++++++------------------------- test/model.jl | 4 ++-- 4 files changed, 33 insertions(+), 38 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index dfd609ac1..773d00b9d 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -143,10 +143,10 @@ Sometimes it can be useful to extract the priors of a model. This is the possibl extract_priors ``` -Safe extraction of realizations from a given [`AbstractVarInfo`](@ref) can be done using [`extract_realizations`](@ref). +Safe extraction of values from a given [`AbstractVarInfo`](@ref) as they are seen in the model can be done using [`values_as_in_model`](@ref). ```@docs -extract_realizations +values_as_in_model ``` ```@docs diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index d499bf594..1615ba229 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -93,7 +93,7 @@ export AbstractVarInfo, getargnames, generated_quantities, extract_priors, - extract_realizations, + values_as_in_model, # Samplers Sampler, SampleFromPrior, diff --git a/src/contexts.jl b/src/contexts.jl index 89aa120d5..0d835d4ce 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -666,9 +666,10 @@ function fixed(context::FixedContext) end """ - RealizationExtractorContext + ValuesAsInModelContext -A context that is used to extract realizations from a model. +A context that is used by [`values_as_in_model`](@ref) to obtain values +of the model parameters as they are in the model. This is particularly useful when working in unconstrained space, but one wants to extract the realization of a model in a constrained space. @@ -676,35 +677,35 @@ wants to extract the realization of a model in a constrained space. # Fields $(TYPEDFIELDS) """ -struct RealizationExtractorContext{T,C<:AbstractContext} <: AbstractContext +struct ValuesAsInModelContext{T,C<:AbstractContext} <: AbstractContext "values that are extracted from the model" values::T "child context" context::C end -RealizationExtractorContext(values) = RealizationExtractorContext(values, DefaultContext()) -function RealizationExtractorContext(context::AbstractContext) - return RealizationExtractorContext(OrderedDict(), context) +ValuesAsInModelContext(values) = ValuesAsInModelContext(values, DefaultContext()) +function ValuesAsInModelContext(context::AbstractContext) + return ValuesAsInModelContext(OrderedDict(), context) end -NodeTrait(::RealizationExtractorContext) = IsParent() -childcontext(context::RealizationExtractorContext) = context.context -function setchildcontext(context::RealizationExtractorContext, child) - return RealizationExtractorContext(context.values, child) +NodeTrait(::ValuesAsInModelContext) = IsParent() +childcontext(context::ValuesAsInModelContext) = context.context +function setchildcontext(context::ValuesAsInModelContext, child) + return ValuesAsInModelContext(context.values, child) end -function Base.push!(context::RealizationExtractorContext, vn::VarName, value) +function Base.push!(context::ValuesAsInModelContext, vn::VarName, value) return setindex!(context.values, copy(value), vn) end -function broadcast_push!(context::RealizationExtractorContext, vns, values) +function broadcast_push!(context::ValuesAsInModelContext, vns, values) return push!.((context,), vns, values) end # This will be hit if we're broadcasting an `AbstractMatrix` over a `MultivariateDistribution`. function broadcast_push!( - context::RealizationExtractorContext, vns::AbstractVector, values::AbstractMatrix + context::ValuesAsInModelContext, vns::AbstractVector, values::AbstractMatrix ) for (vn, col) in zip(vns, eachcol(values)) push!(context, vn, col) @@ -712,7 +713,7 @@ function broadcast_push!( end # `tilde_asssume` -function tilde_assume(context::RealizationExtractorContext, right, vn, vi) +function tilde_assume(context::ValuesAsInModelContext, right, vn, vi) value, logp, vi = tilde_assume(childcontext(context), right, vn, vi) # Save the value. push!(context, vn, value) @@ -721,7 +722,7 @@ function tilde_assume(context::RealizationExtractorContext, right, vn, vi) return value, logp, vi end function tilde_assume( - rng::Random.AbstractRNG, context::RealizationExtractorContext, sampler, right, vn, vi + rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, vn, vi ) value, logp, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi) # Save the value. @@ -731,7 +732,7 @@ function tilde_assume( end # `dot_tilde_assume` -function dot_tilde_assume(context::RealizationExtractorContext, right, left, vn, vi) +function dot_tilde_assume(context::ValuesAsInModelContext, right, left, vn, vi) value, logp, vi = dot_tilde_assume(childcontext(context), right, left, vn, vi) # Save the value. @@ -741,13 +742,7 @@ function dot_tilde_assume(context::RealizationExtractorContext, right, left, vn, return value, logp, vi end function dot_tilde_assume( - rng::Random.AbstractRNG, - context::RealizationExtractorContext, - sampler, - right, - left, - vn, - vi, + rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, left, vn, vi ) value, logp, vi = dot_tilde_assume( rng, childcontext(context), sampler, right, left, vn, vi @@ -760,10 +755,10 @@ function dot_tilde_assume( end """ - extract_realizations(model::Model[, varinfo::AbstractVarInfo, context::AbstractContext]) - extract_realizations(rng::Random.AbstractRNG, model::Model[, varinfo::AbstractVarInfo, context::AbstractContext]) + values_as_in_model(model::Model[, varinfo::AbstractVarInfo, context::AbstractContext]) + values_as_in_model(rng::Random.AbstractRNG, model::Model[, varinfo::AbstractVarInfo, context::AbstractContext]) -Extract realizations from the `model` for a given `varinfo` through a evaluation of the model. +Get the values of `varinfo` as they would be seen in the model. If no `varinfo` is provided, then this is effectively the same as [`Base.rand(rng::Random.AbstractRNG, model::Model)`](@ref). @@ -826,27 +821,27 @@ julia> # (×) Fails! Because `VarInfo` _saves_ the original distributions lb ≤ varinfo_invlinked[@varname(y)] ≤ ub false -julia> # Approach 2: Extract realizations using `extract_realizations`. - # (✓) `extract_realizations` will re-run the model and extract +julia> # Approach 2: Extract realizations using `values_as_in_model`. + # (✓) `values_as_in_model` will re-run the model and extract # the correct realization of `y` given the new values of `x`. - lb ≤ extract_realizations(model, varinfo_linked)[@varname(y)] ≤ ub + lb ≤ values_as_in_model(model, varinfo_linked)[@varname(y)] ≤ ub true ``` """ -function extract_realizations( +function values_as_in_model( model::Model, varinfo::AbstractVarInfo=VarInfo(), context::AbstractContext=DefaultContext(), ) - context = RealizationExtractorContext(context) + context = ValuesAsInModelContext(context) evaluate!!(model, varinfo, context) return context.values end -function extract_realizations( +function values_as_in_model( rng::Random.AbstractRNG, model::Model, varinfo::AbstractVarInfo=VarInfo(), context::AbstractContext=DefaultContext(), ) - return extract_realizations(model, varinfo, SamplingContext(rng, context)) + return values_as_in_model(model, varinfo, SamplingContext(rng, context)) end diff --git a/test/model.jl b/test/model.jl index 7f4caff24..b9e62827d 100644 --- a/test/model.jl +++ b/test/model.jl @@ -376,13 +376,13 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true end end - @testset "extract_realizations" begin + @testset "values_as_in_model" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS vns = DynamicPPL.TestUtils.varnames(model) example_values = DynamicPPL.TestUtils.rand_prior_true(model) varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - realizations = extract_realizations(model, varinfo) + realizations = values_as_in_model(model, varinfo) # Ensure that all variables are found. vns_found = collect(keys(realizations)) @test vns ∩ vns_found == vns ∪ vns_found From af643560e104d85e12ae165c9c71af756d459aa3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 19 Apr 2024 19:30:20 +0100 Subject: [PATCH 11/11] moved impl of `values_as_in_model` to separate file due to size of impl --- src/DynamicPPL.jl | 1 + src/contexts.jl | 181 -------------------------------------- src/values_as_in_model.jl | 181 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 182 insertions(+), 181 deletions(-) create mode 100644 src/values_as_in_model.jl diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 1615ba229..03e3ee308 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -179,6 +179,7 @@ include("transforming.jl") include("logdensityfunction.jl") include("model_utils.jl") include("extract_priors.jl") +include("values_as_in_model.jl") if !isdefined(Base, :get_extension) using Requires diff --git a/src/contexts.jl b/src/contexts.jl index 0d835d4ce..2018b9155 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -664,184 +664,3 @@ function fixed(context::FixedContext) # precedence over decendants of `context`. return merge(context.values, fixed(childcontext(context))) end - -""" - ValuesAsInModelContext - -A context that is used by [`values_as_in_model`](@ref) to obtain values -of the model parameters as they are in the model. - -This is particularly useful when working in unconstrained space, but one -wants to extract the realization of a model in a constrained space. - -# Fields -$(TYPEDFIELDS) -""" -struct ValuesAsInModelContext{T,C<:AbstractContext} <: AbstractContext - "values that are extracted from the model" - values::T - "child context" - context::C -end - -ValuesAsInModelContext(values) = ValuesAsInModelContext(values, DefaultContext()) -function ValuesAsInModelContext(context::AbstractContext) - return ValuesAsInModelContext(OrderedDict(), context) -end - -NodeTrait(::ValuesAsInModelContext) = IsParent() -childcontext(context::ValuesAsInModelContext) = context.context -function setchildcontext(context::ValuesAsInModelContext, child) - return ValuesAsInModelContext(context.values, child) -end - -function Base.push!(context::ValuesAsInModelContext, vn::VarName, value) - return setindex!(context.values, copy(value), vn) -end - -function broadcast_push!(context::ValuesAsInModelContext, vns, values) - return push!.((context,), vns, values) -end - -# This will be hit if we're broadcasting an `AbstractMatrix` over a `MultivariateDistribution`. -function broadcast_push!( - context::ValuesAsInModelContext, vns::AbstractVector, values::AbstractMatrix -) - for (vn, col) in zip(vns, eachcol(values)) - push!(context, vn, col) - end -end - -# `tilde_asssume` -function tilde_assume(context::ValuesAsInModelContext, right, vn, vi) - value, logp, vi = tilde_assume(childcontext(context), right, vn, vi) - # Save the value. - push!(context, vn, value) - # Save the value. - # Pass on. - return value, logp, vi -end -function tilde_assume( - rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, vn, vi -) - value, logp, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi) - # Save the value. - push!(context, vn, value) - # Pass on. - return value, logp, vi -end - -# `dot_tilde_assume` -function dot_tilde_assume(context::ValuesAsInModelContext, right, left, vn, vi) - value, logp, vi = dot_tilde_assume(childcontext(context), right, left, vn, vi) - - # Save the value. - _right, _left, _vns = unwrap_right_left_vns(right, var, vn) - broadcast_push!(context, _vns, value) - - return value, logp, vi -end -function dot_tilde_assume( - rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, left, vn, vi -) - value, logp, vi = dot_tilde_assume( - rng, childcontext(context), sampler, right, left, vn, vi - ) - # Save the value. - _right, _left, _vns = unwrap_right_left_vns(right, left, vn) - broadcast_push!(context, _vns, value) - - return value, logp, vi -end - -""" - values_as_in_model(model::Model[, varinfo::AbstractVarInfo, context::AbstractContext]) - values_as_in_model(rng::Random.AbstractRNG, model::Model[, varinfo::AbstractVarInfo, context::AbstractContext]) - -Get the values of `varinfo` as they would be seen in the model. - -If no `varinfo` is provided, then this is effectively the same as -[`Base.rand(rng::Random.AbstractRNG, model::Model)`](@ref). - -More specifically, this method attempts to extract the realization _as seen in the model_. -For example, `x[1] ~ truncated(Normal(); lower=0)` will result in a realization compatible -with `truncated(Normal(); lower=0)` regardless of whether `varinfo` is working in unconstrained -space. - -Hence this method is a "safe" way of obtaining realizations in constrained space at the cost -of additional model evaluations. - -# Arguments -- `model::Model`: model to extract realizations from. -- `varinfo::AbstractVarInfo`: variable information to use for the extraction. -- `context::AbstractContext`: context to use for the extraction. If `rng` is specified, then `context` - will be wrapped in a [`SamplingContext`](@ref) with the provided `rng`. - -# Examples - -## When `VarInfo` fails - -The following demonstrates a common pitfall when working with [`VarInfo`](@ref) and constrained variables. - -```jldoctest -julia> using Distributions, StableRNGs - -julia> rng = StableRNG(42); - -julia> @model function model_changing_support() - x ~ Bernoulli(0.5) - y ~ x == 1 ? Uniform(0, 1) : Uniform(11, 12) - end; - -julia> model = model_changing_support(); - -julia> # Construct initial type-stable `VarInfo`. - varinfo = VarInfo(rng, model); - -julia> # Link it so it works in unconstrained space. - varinfo_linked = DynamicPPL.link(varinfo, model); - -julia> # Perform computations in unconstrained space, e.g. changing the values of `θ`. - # Flip `x` so we hit the other support of `y`. - θ = [!varinfo[@varname(x)], rand(rng)]; - -julia> # Update the `VarInfo` with the new values. - varinfo_linked = DynamicPPL.unflatten(varinfo_linked, θ); - -julia> # Determine the expected support of `y`. - lb, ub = θ[1] == 1 ? (0, 1) : (11, 12) -(0, 1) - -julia> # Approach 1: Convert back to constrained space using `invlink` and extract. - varinfo_invlinked = DynamicPPL.invlink(varinfo_linked, model); - -julia> # (×) Fails! Because `VarInfo` _saves_ the original distributions - # used in the very first model evaluation, hence the support of `y` - # is not updated even though `x` has changed. - lb ≤ varinfo_invlinked[@varname(y)] ≤ ub -false - -julia> # Approach 2: Extract realizations using `values_as_in_model`. - # (✓) `values_as_in_model` will re-run the model and extract - # the correct realization of `y` given the new values of `x`. - lb ≤ values_as_in_model(model, varinfo_linked)[@varname(y)] ≤ ub -true -``` -""" -function values_as_in_model( - model::Model, - varinfo::AbstractVarInfo=VarInfo(), - context::AbstractContext=DefaultContext(), -) - context = ValuesAsInModelContext(context) - evaluate!!(model, varinfo, context) - return context.values -end -function values_as_in_model( - rng::Random.AbstractRNG, - model::Model, - varinfo::AbstractVarInfo=VarInfo(), - context::AbstractContext=DefaultContext(), -) - return values_as_in_model(model, varinfo, SamplingContext(rng, context)) -end diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl new file mode 100644 index 000000000..dcf68c15c --- /dev/null +++ b/src/values_as_in_model.jl @@ -0,0 +1,181 @@ + +""" + ValuesAsInModelContext + +A context that is used by [`values_as_in_model`](@ref) to obtain values +of the model parameters as they are in the model. + +This is particularly useful when working in unconstrained space, but one +wants to extract the realization of a model in a constrained space. + +# Fields +$(TYPEDFIELDS) +""" +struct ValuesAsInModelContext{T,C<:AbstractContext} <: AbstractContext + "values that are extracted from the model" + values::T + "child context" + context::C +end + +ValuesAsInModelContext(values) = ValuesAsInModelContext(values, DefaultContext()) +function ValuesAsInModelContext(context::AbstractContext) + return ValuesAsInModelContext(OrderedDict(), context) +end + +NodeTrait(::ValuesAsInModelContext) = IsParent() +childcontext(context::ValuesAsInModelContext) = context.context +function setchildcontext(context::ValuesAsInModelContext, child) + return ValuesAsInModelContext(context.values, child) +end + +function Base.push!(context::ValuesAsInModelContext, vn::VarName, value) + return setindex!(context.values, copy(value), vn) +end + +function broadcast_push!(context::ValuesAsInModelContext, vns, values) + return push!.((context,), vns, values) +end + +# This will be hit if we're broadcasting an `AbstractMatrix` over a `MultivariateDistribution`. +function broadcast_push!( + context::ValuesAsInModelContext, vns::AbstractVector, values::AbstractMatrix +) + for (vn, col) in zip(vns, eachcol(values)) + push!(context, vn, col) + end +end + +# `tilde_asssume` +function tilde_assume(context::ValuesAsInModelContext, right, vn, vi) + value, logp, vi = tilde_assume(childcontext(context), right, vn, vi) + # Save the value. + push!(context, vn, value) + # Save the value. + # Pass on. + return value, logp, vi +end +function tilde_assume( + rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, vn, vi +) + value, logp, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi) + # Save the value. + push!(context, vn, value) + # Pass on. + return value, logp, vi +end + +# `dot_tilde_assume` +function dot_tilde_assume(context::ValuesAsInModelContext, right, left, vn, vi) + value, logp, vi = dot_tilde_assume(childcontext(context), right, left, vn, vi) + + # Save the value. + _right, _left, _vns = unwrap_right_left_vns(right, var, vn) + broadcast_push!(context, _vns, value) + + return value, logp, vi +end +function dot_tilde_assume( + rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, left, vn, vi +) + value, logp, vi = dot_tilde_assume( + rng, childcontext(context), sampler, right, left, vn, vi + ) + # Save the value. + _right, _left, _vns = unwrap_right_left_vns(right, left, vn) + broadcast_push!(context, _vns, value) + + return value, logp, vi +end + +""" + values_as_in_model(model::Model[, varinfo::AbstractVarInfo, context::AbstractContext]) + values_as_in_model(rng::Random.AbstractRNG, model::Model[, varinfo::AbstractVarInfo, context::AbstractContext]) + +Get the values of `varinfo` as they would be seen in the model. + +If no `varinfo` is provided, then this is effectively the same as +[`Base.rand(rng::Random.AbstractRNG, model::Model)`](@ref). + +More specifically, this method attempts to extract the realization _as seen in the model_. +For example, `x[1] ~ truncated(Normal(); lower=0)` will result in a realization compatible +with `truncated(Normal(); lower=0)` regardless of whether `varinfo` is working in unconstrained +space. + +Hence this method is a "safe" way of obtaining realizations in constrained space at the cost +of additional model evaluations. + +# Arguments +- `model::Model`: model to extract realizations from. +- `varinfo::AbstractVarInfo`: variable information to use for the extraction. +- `context::AbstractContext`: context to use for the extraction. If `rng` is specified, then `context` + will be wrapped in a [`SamplingContext`](@ref) with the provided `rng`. + +# Examples + +## When `VarInfo` fails + +The following demonstrates a common pitfall when working with [`VarInfo`](@ref) and constrained variables. + +```jldoctest +julia> using Distributions, StableRNGs + +julia> rng = StableRNG(42); + +julia> @model function model_changing_support() + x ~ Bernoulli(0.5) + y ~ x == 1 ? Uniform(0, 1) : Uniform(11, 12) + end; + +julia> model = model_changing_support(); + +julia> # Construct initial type-stable `VarInfo`. + varinfo = VarInfo(rng, model); + +julia> # Link it so it works in unconstrained space. + varinfo_linked = DynamicPPL.link(varinfo, model); + +julia> # Perform computations in unconstrained space, e.g. changing the values of `θ`. + # Flip `x` so we hit the other support of `y`. + θ = [!varinfo[@varname(x)], rand(rng)]; + +julia> # Update the `VarInfo` with the new values. + varinfo_linked = DynamicPPL.unflatten(varinfo_linked, θ); + +julia> # Determine the expected support of `y`. + lb, ub = θ[1] == 1 ? (0, 1) : (11, 12) +(0, 1) + +julia> # Approach 1: Convert back to constrained space using `invlink` and extract. + varinfo_invlinked = DynamicPPL.invlink(varinfo_linked, model); + +julia> # (×) Fails! Because `VarInfo` _saves_ the original distributions + # used in the very first model evaluation, hence the support of `y` + # is not updated even though `x` has changed. + lb ≤ varinfo_invlinked[@varname(y)] ≤ ub +false + +julia> # Approach 2: Extract realizations using `values_as_in_model`. + # (✓) `values_as_in_model` will re-run the model and extract + # the correct realization of `y` given the new values of `x`. + lb ≤ values_as_in_model(model, varinfo_linked)[@varname(y)] ≤ ub +true +``` +""" +function values_as_in_model( + model::Model, + varinfo::AbstractVarInfo=VarInfo(), + context::AbstractContext=DefaultContext(), +) + context = ValuesAsInModelContext(context) + evaluate!!(model, varinfo, context) + return context.values +end +function values_as_in_model( + rng::Random.AbstractRNG, + model::Model, + varinfo::AbstractVarInfo=VarInfo(), + context::AbstractContext=DefaultContext(), +) + return values_as_in_model(model, varinfo, SamplingContext(rng, context)) +end