From d11a33f6ed4f968da6f0a1ff9bd4b334fcb9ae63 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 20 Apr 2024 15:18:40 +0100 Subject: [PATCH] Revert "reverted merge with torfjelde/extract-realizations" (#590) This reverts commit 33a84c7ec215672b6300e975f5005e86dd5fb5b9. --- docs/src/api.md | 6 ++ src/DynamicPPL.jl | 2 + src/values_as_in_model.jl | 181 ++++++++++++++++++++++++++++++++++++++ test/model.jl | 20 ++++- 4 files changed, 208 insertions(+), 1 deletion(-) create mode 100644 src/values_as_in_model.jl diff --git a/docs/src/api.md b/docs/src/api.md index 8576a2e0e..2e8081a88 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -143,6 +143,12 @@ Sometimes it can be useful to extract the priors of a model. This is the possibl extract_priors ``` +Safe extraction of values from a given [`AbstractVarInfo`](@ref) as they are seen in the model can be done using [`values_as_in_model`](@ref). + +```@docs +values_as_in_model +``` + ```@docs NamedDist ``` diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index be7da6f8f..0d934a09a 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -92,6 +92,7 @@ export AbstractVarInfo, getargnames, generated_quantities, extract_priors, + values_as_in_model, # Samplers Sampler, SampleFromPrior, @@ -179,6 +180,7 @@ include("transforming.jl") include("logdensityfunction.jl") include("model_utils.jl") include("extract_priors.jl") +include("values_as_in_model.jl") include("debug_utils.jl") using .DebugUtils diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl new file mode 100644 index 000000000..dcf68c15c --- /dev/null +++ b/src/values_as_in_model.jl @@ -0,0 +1,181 @@ + +""" + ValuesAsInModelContext + +A context that is used by [`values_as_in_model`](@ref) to obtain values +of the model parameters as they are in the model. + +This is particularly useful when working in unconstrained space, but one +wants to extract the realization of a model in a constrained space. + +# Fields +$(TYPEDFIELDS) +""" +struct ValuesAsInModelContext{T,C<:AbstractContext} <: AbstractContext + "values that are extracted from the model" + values::T + "child context" + context::C +end + +ValuesAsInModelContext(values) = ValuesAsInModelContext(values, DefaultContext()) +function ValuesAsInModelContext(context::AbstractContext) + return ValuesAsInModelContext(OrderedDict(), context) +end + +NodeTrait(::ValuesAsInModelContext) = IsParent() +childcontext(context::ValuesAsInModelContext) = context.context +function setchildcontext(context::ValuesAsInModelContext, child) + return ValuesAsInModelContext(context.values, child) +end + +function Base.push!(context::ValuesAsInModelContext, vn::VarName, value) + return setindex!(context.values, copy(value), vn) +end + +function broadcast_push!(context::ValuesAsInModelContext, vns, values) + return push!.((context,), vns, values) +end + +# This will be hit if we're broadcasting an `AbstractMatrix` over a `MultivariateDistribution`. +function broadcast_push!( + context::ValuesAsInModelContext, vns::AbstractVector, values::AbstractMatrix +) + for (vn, col) in zip(vns, eachcol(values)) + push!(context, vn, col) + end +end + +# `tilde_asssume` +function tilde_assume(context::ValuesAsInModelContext, right, vn, vi) + value, logp, vi = tilde_assume(childcontext(context), right, vn, vi) + # Save the value. + push!(context, vn, value) + # Save the value. + # Pass on. + return value, logp, vi +end +function tilde_assume( + rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, vn, vi +) + value, logp, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi) + # Save the value. + push!(context, vn, value) + # Pass on. + return value, logp, vi +end + +# `dot_tilde_assume` +function dot_tilde_assume(context::ValuesAsInModelContext, right, left, vn, vi) + value, logp, vi = dot_tilde_assume(childcontext(context), right, left, vn, vi) + + # Save the value. + _right, _left, _vns = unwrap_right_left_vns(right, var, vn) + broadcast_push!(context, _vns, value) + + return value, logp, vi +end +function dot_tilde_assume( + rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, left, vn, vi +) + value, logp, vi = dot_tilde_assume( + rng, childcontext(context), sampler, right, left, vn, vi + ) + # Save the value. + _right, _left, _vns = unwrap_right_left_vns(right, left, vn) + broadcast_push!(context, _vns, value) + + return value, logp, vi +end + +""" + values_as_in_model(model::Model[, varinfo::AbstractVarInfo, context::AbstractContext]) + values_as_in_model(rng::Random.AbstractRNG, model::Model[, varinfo::AbstractVarInfo, context::AbstractContext]) + +Get the values of `varinfo` as they would be seen in the model. + +If no `varinfo` is provided, then this is effectively the same as +[`Base.rand(rng::Random.AbstractRNG, model::Model)`](@ref). + +More specifically, this method attempts to extract the realization _as seen in the model_. +For example, `x[1] ~ truncated(Normal(); lower=0)` will result in a realization compatible +with `truncated(Normal(); lower=0)` regardless of whether `varinfo` is working in unconstrained +space. + +Hence this method is a "safe" way of obtaining realizations in constrained space at the cost +of additional model evaluations. + +# Arguments +- `model::Model`: model to extract realizations from. +- `varinfo::AbstractVarInfo`: variable information to use for the extraction. +- `context::AbstractContext`: context to use for the extraction. If `rng` is specified, then `context` + will be wrapped in a [`SamplingContext`](@ref) with the provided `rng`. + +# Examples + +## When `VarInfo` fails + +The following demonstrates a common pitfall when working with [`VarInfo`](@ref) and constrained variables. + +```jldoctest +julia> using Distributions, StableRNGs + +julia> rng = StableRNG(42); + +julia> @model function model_changing_support() + x ~ Bernoulli(0.5) + y ~ x == 1 ? Uniform(0, 1) : Uniform(11, 12) + end; + +julia> model = model_changing_support(); + +julia> # Construct initial type-stable `VarInfo`. + varinfo = VarInfo(rng, model); + +julia> # Link it so it works in unconstrained space. + varinfo_linked = DynamicPPL.link(varinfo, model); + +julia> # Perform computations in unconstrained space, e.g. changing the values of `θ`. + # Flip `x` so we hit the other support of `y`. + θ = [!varinfo[@varname(x)], rand(rng)]; + +julia> # Update the `VarInfo` with the new values. + varinfo_linked = DynamicPPL.unflatten(varinfo_linked, θ); + +julia> # Determine the expected support of `y`. + lb, ub = θ[1] == 1 ? (0, 1) : (11, 12) +(0, 1) + +julia> # Approach 1: Convert back to constrained space using `invlink` and extract. + varinfo_invlinked = DynamicPPL.invlink(varinfo_linked, model); + +julia> # (×) Fails! Because `VarInfo` _saves_ the original distributions + # used in the very first model evaluation, hence the support of `y` + # is not updated even though `x` has changed. + lb ≤ varinfo_invlinked[@varname(y)] ≤ ub +false + +julia> # Approach 2: Extract realizations using `values_as_in_model`. + # (✓) `values_as_in_model` will re-run the model and extract + # the correct realization of `y` given the new values of `x`. + lb ≤ values_as_in_model(model, varinfo_linked)[@varname(y)] ≤ ub +true +``` +""" +function values_as_in_model( + model::Model, + varinfo::AbstractVarInfo=VarInfo(), + context::AbstractContext=DefaultContext(), +) + context = ValuesAsInModelContext(context) + evaluate!!(model, varinfo, context) + return context.values +end +function values_as_in_model( + rng::Random.AbstractRNG, + model::Model, + varinfo::AbstractVarInfo=VarInfo(), + context::AbstractContext=DefaultContext(), +) + return values_as_in_model(model, varinfo, SamplingContext(rng, context)) +end diff --git a/test/model.jl b/test/model.jl index f8303e260..b9e62827d 100644 --- a/test/model.jl +++ b/test/model.jl @@ -356,7 +356,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true ] @testset "$(model.f)" for model in models_to_test vns = DynamicPPL.TestUtils.varnames(model) - example_values = DynamicPPL.TestUtils.rand(model) + example_values = DynamicPPL.TestUtils.rand_prior_true(model) varinfos = filter( is_typed_varinfo, DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns), @@ -375,4 +375,22 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true end end end + + @testset "values_as_in_model" begin + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + vns = DynamicPPL.TestUtils.varnames(model) + example_values = DynamicPPL.TestUtils.rand_prior_true(model) + varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns) + @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos + realizations = values_as_in_model(model, varinfo) + # Ensure that all variables are found. + vns_found = collect(keys(realizations)) + @test vns ∩ vns_found == vns ∪ vns_found + # Ensure that the values are the same. + for vn in vns + @test realizations[vn] == varinfo[vn] + end + end + end + end end