Skip to content

Commit

Permalink
Revert "reverted merge with torfjelde/extract-realizations" (#590)
Browse files Browse the repository at this point in the history
This reverts commit 33a84c7.
  • Loading branch information
torfjelde authored Apr 20, 2024
1 parent 8c432b6 commit d11a33f
Show file tree
Hide file tree
Showing 4 changed files with 208 additions and 1 deletion.
6 changes: 6 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,12 @@ Sometimes it can be useful to extract the priors of a model. This is the possibl
extract_priors
```

Safe extraction of values from a given [`AbstractVarInfo`](@ref) as they are seen in the model can be done using [`values_as_in_model`](@ref).

```@docs
values_as_in_model
```

```@docs
NamedDist
```
Expand Down
2 changes: 2 additions & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ export AbstractVarInfo,
getargnames,
generated_quantities,
extract_priors,
values_as_in_model,
# Samplers
Sampler,
SampleFromPrior,
Expand Down Expand Up @@ -179,6 +180,7 @@ include("transforming.jl")
include("logdensityfunction.jl")
include("model_utils.jl")
include("extract_priors.jl")
include("values_as_in_model.jl")

include("debug_utils.jl")
using .DebugUtils
Expand Down
181 changes: 181 additions & 0 deletions src/values_as_in_model.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@

"""
ValuesAsInModelContext
A context that is used by [`values_as_in_model`](@ref) to obtain values
of the model parameters as they are in the model.
This is particularly useful when working in unconstrained space, but one
wants to extract the realization of a model in a constrained space.
# Fields
$(TYPEDFIELDS)
"""
struct ValuesAsInModelContext{T,C<:AbstractContext} <: AbstractContext
"values that are extracted from the model"
values::T
"child context"
context::C
end

ValuesAsInModelContext(values) = ValuesAsInModelContext(values, DefaultContext())
function ValuesAsInModelContext(context::AbstractContext)
return ValuesAsInModelContext(OrderedDict(), context)
end

NodeTrait(::ValuesAsInModelContext) = IsParent()
childcontext(context::ValuesAsInModelContext) = context.context
function setchildcontext(context::ValuesAsInModelContext, child)
return ValuesAsInModelContext(context.values, child)
end

function Base.push!(context::ValuesAsInModelContext, vn::VarName, value)
return setindex!(context.values, copy(value), vn)
end

function broadcast_push!(context::ValuesAsInModelContext, vns, values)
return push!.((context,), vns, values)
end

# This will be hit if we're broadcasting an `AbstractMatrix` over a `MultivariateDistribution`.
function broadcast_push!(
context::ValuesAsInModelContext, vns::AbstractVector, values::AbstractMatrix
)
for (vn, col) in zip(vns, eachcol(values))
push!(context, vn, col)
end
end

# `tilde_asssume`
function tilde_assume(context::ValuesAsInModelContext, right, vn, vi)
value, logp, vi = tilde_assume(childcontext(context), right, vn, vi)
# Save the value.
push!(context, vn, value)
# Save the value.
# Pass on.
return value, logp, vi
end
function tilde_assume(
rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, vn, vi
)
value, logp, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi)
# Save the value.
push!(context, vn, value)
# Pass on.
return value, logp, vi
end

# `dot_tilde_assume`
function dot_tilde_assume(context::ValuesAsInModelContext, right, left, vn, vi)
value, logp, vi = dot_tilde_assume(childcontext(context), right, left, vn, vi)

# Save the value.
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
broadcast_push!(context, _vns, value)

return value, logp, vi
end
function dot_tilde_assume(
rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, left, vn, vi
)
value, logp, vi = dot_tilde_assume(
rng, childcontext(context), sampler, right, left, vn, vi
)
# Save the value.
_right, _left, _vns = unwrap_right_left_vns(right, left, vn)
broadcast_push!(context, _vns, value)

return value, logp, vi
end

"""
values_as_in_model(model::Model[, varinfo::AbstractVarInfo, context::AbstractContext])
values_as_in_model(rng::Random.AbstractRNG, model::Model[, varinfo::AbstractVarInfo, context::AbstractContext])
Get the values of `varinfo` as they would be seen in the model.
If no `varinfo` is provided, then this is effectively the same as
[`Base.rand(rng::Random.AbstractRNG, model::Model)`](@ref).
More specifically, this method attempts to extract the realization _as seen in the model_.
For example, `x[1] ~ truncated(Normal(); lower=0)` will result in a realization compatible
with `truncated(Normal(); lower=0)` regardless of whether `varinfo` is working in unconstrained
space.
Hence this method is a "safe" way of obtaining realizations in constrained space at the cost
of additional model evaluations.
# Arguments
- `model::Model`: model to extract realizations from.
- `varinfo::AbstractVarInfo`: variable information to use for the extraction.
- `context::AbstractContext`: context to use for the extraction. If `rng` is specified, then `context`
will be wrapped in a [`SamplingContext`](@ref) with the provided `rng`.
# Examples
## When `VarInfo` fails
The following demonstrates a common pitfall when working with [`VarInfo`](@ref) and constrained variables.
```jldoctest
julia> using Distributions, StableRNGs
julia> rng = StableRNG(42);
julia> @model function model_changing_support()
x ~ Bernoulli(0.5)
y ~ x == 1 ? Uniform(0, 1) : Uniform(11, 12)
end;
julia> model = model_changing_support();
julia> # Construct initial type-stable `VarInfo`.
varinfo = VarInfo(rng, model);
julia> # Link it so it works in unconstrained space.
varinfo_linked = DynamicPPL.link(varinfo, model);
julia> # Perform computations in unconstrained space, e.g. changing the values of `θ`.
# Flip `x` so we hit the other support of `y`.
θ = [!varinfo[@varname(x)], rand(rng)];
julia> # Update the `VarInfo` with the new values.
varinfo_linked = DynamicPPL.unflatten(varinfo_linked, θ);
julia> # Determine the expected support of `y`.
lb, ub = θ[1] == 1 ? (0, 1) : (11, 12)
(0, 1)
julia> # Approach 1: Convert back to constrained space using `invlink` and extract.
varinfo_invlinked = DynamicPPL.invlink(varinfo_linked, model);
julia> # (×) Fails! Because `VarInfo` _saves_ the original distributions
# used in the very first model evaluation, hence the support of `y`
# is not updated even though `x` has changed.
lb ≤ varinfo_invlinked[@varname(y)] ≤ ub
false
julia> # Approach 2: Extract realizations using `values_as_in_model`.
# (✓) `values_as_in_model` will re-run the model and extract
# the correct realization of `y` given the new values of `x`.
lb ≤ values_as_in_model(model, varinfo_linked)[@varname(y)] ≤ ub
true
```
"""
function values_as_in_model(
model::Model,
varinfo::AbstractVarInfo=VarInfo(),
context::AbstractContext=DefaultContext(),
)
context = ValuesAsInModelContext(context)
evaluate!!(model, varinfo, context)
return context.values
end
function values_as_in_model(
rng::Random.AbstractRNG,
model::Model,
varinfo::AbstractVarInfo=VarInfo(),
context::AbstractContext=DefaultContext(),
)
return values_as_in_model(model, varinfo, SamplingContext(rng, context))
end
20 changes: 19 additions & 1 deletion test/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
]
@testset "$(model.f)" for model in models_to_test
vns = DynamicPPL.TestUtils.varnames(model)
example_values = DynamicPPL.TestUtils.rand(model)
example_values = DynamicPPL.TestUtils.rand_prior_true(model)
varinfos = filter(
is_typed_varinfo,
DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns),
Expand All @@ -375,4 +375,22 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
end
end
end

@testset "values_as_in_model" begin
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
vns = DynamicPPL.TestUtils.varnames(model)
example_values = DynamicPPL.TestUtils.rand_prior_true(model)
varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns)
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
realizations = values_as_in_model(model, varinfo)
# Ensure that all variables are found.
vns_found = collect(keys(realizations))
@test vns vns_found == vns vns_found
# Ensure that the values are the same.
for vn in vns
@test realizations[vn] == varinfo[vn]
end
end
end
end
end

2 comments on commit d11a33f

@torfjelde
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error while trying to register: Version 0.25.1 already exists

Please sign in to comment.