Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some clean up of contexts #711

Open
wants to merge 19 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 7 additions & 106 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,49 +77,11 @@ function tilde_assume(
return tilde_assume(rng, childcontext(context), args...)
end

function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, vi)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, tovec(get(context.vars, vn)), vn)
settrans!!(vi, false, vn)
end
return tilde_assume(PriorContext(), right, vn, vi)
end
function tilde_assume(
rng::Random.AbstractRNG, context::PriorContext{<:NamedTuple}, sampler, right, vn, vi
)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, tovec(get(context.vars, vn)), vn)
settrans!!(vi, false, vn)
end
return tilde_assume(rng, PriorContext(), sampler, right, vn, vi)
end

function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, vi)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, tovec(get(context.vars, vn)), vn)
settrans!!(vi, false, vn)
end
return tilde_assume(LikelihoodContext(), right, vn, vi)
end
function tilde_assume(
rng::Random.AbstractRNG,
context::LikelihoodContext{<:NamedTuple},
sampler,
right,
vn,
vi,
)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, tovec(get(context.vars, vn)), vn)
settrans!!(vi, false, vn)
end
return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, vi)
end
function tilde_assume(::LikelihoodContext, right, vn, vi)
return assume(NoDist(right), vn, vi)
return assume(nodist(right), vn, vi)
end
function tilde_assume(rng::Random.AbstractRNG, ::LikelihoodContext, sampler, right, vn, vi)
return assume(rng, sampler, NoDist(right), vn, vi)
return assume(rng, sampler, nodist(right), vn, vi)
end

function tilde_assume(context::PrefixContext, right, vn, vi)
Expand Down Expand Up @@ -328,37 +290,6 @@ function dot_tilde_assume(
end

# `LikelihoodContext`
function dot_tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, left, vn, vi)
return if haskey(context.vars, getsym(vn))
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!!.((vi,), false, _vns)
dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, vi)
else
dot_tilde_assume(LikelihoodContext(), right, left, vn, vi)
end
end
function dot_tilde_assume(
rng::Random.AbstractRNG,
context::LikelihoodContext{<:NamedTuple},
sampler,
right,
left,
vn,
vi,
)
return if haskey(context.vars, getsym(vn))
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!!.((vi,), false, _vns)
dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, vi)
else
dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, vi)
end
end

function dot_tilde_assume(context::LikelihoodContext, right, left, vn, vi)
return dot_assume(nodist(right), left, vn, vi)
end
Expand All @@ -368,46 +299,16 @@ function dot_tilde_assume(
return dot_assume(rng, sampler, nodist(right), vn, left, vi)
end

# `PriorContext`
function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, vi)
return if haskey(context.vars, getsym(vn))
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!!.((vi,), false, _vns)
dot_tilde_assume(PriorContext(), _right, _left, _vns, vi)
else
dot_tilde_assume(PriorContext(), right, left, vn, vi)
end
end
function dot_tilde_assume(
rng::Random.AbstractRNG,
context::PriorContext{<:NamedTuple},
sampler,
right,
left,
vn,
vi,
)
return if haskey(context.vars, getsym(vn))
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!!.((vi,), false, _vns)
dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, vi)
else
dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, vi)
end
end

# `PrefixContext`
function dot_tilde_assume(context::PrefixContext, right, left, vn, vi)
return dot_tilde_assume(context.context, right, prefix.(Ref(context), vn), vi)
return dot_tilde_assume(context.context, right, left, prefix.(Ref(context), vn), vi)
end

function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, vi)
function dot_tilde_assume(
rng::Random.AbstractRNG, context::PrefixContext, sampler, right, left, vn, vi
)
return dot_tilde_assume(
rng, context.context, sampler, right, prefix.(Ref(context), vn), vi
rng, context.context, sampler, right, left, prefix.(Ref(context), vn), vi
)
end

Expand Down
29 changes: 8 additions & 21 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ DefaultContext()
julia> ctx_prior = DynamicPPL.setchildcontext(ctx, PriorContext()); # only compute the logprior

julia> DynamicPPL.childcontext(ctx_prior)
PriorContext{Nothing}(nothing)
PriorContext()
```
"""
setchildcontext
Expand Down Expand Up @@ -97,7 +97,7 @@ ParentContext(ParentContext(DefaultContext()))

julia> # Replace the leaf context with another leaf.
leafcontext(setleafcontext(ctx, PriorContext()))
PriorContext{Nothing}(nothing)
PriorContext()

julia> # Append another parent context.
setleafcontext(ctx, ParentContext(DefaultContext()))
Expand Down Expand Up @@ -195,32 +195,19 @@ struct DefaultContext <: AbstractContext end
NodeTrait(context::DefaultContext) = IsLeaf()

"""
struct PriorContext{Tvars} <: AbstractContext
vars::Tvars
end
PriorContext <: AbstractContext

The `PriorContext` enables the computation of the log prior of the parameters `vars` when
running the model.
A leaf context resulting in the exclusion of likelihood terms when running the model.
"""
struct PriorContext{Tvars} <: AbstractContext
vars::Tvars
end
PriorContext() = PriorContext(nothing)
struct PriorContext <: AbstractContext end
NodeTrait(context::PriorContext) = IsLeaf()

"""
struct LikelihoodContext{Tvars} <: AbstractContext
vars::Tvars
end
LikelihoodContext <: AbstractContext

The `LikelihoodContext` enables the computation of the log likelihood of the parameters when
running the model. `vars` can be used to evaluate the log likelihood for specific values
of the model's parameters. If `vars` is `nothing`, the parameter values inside the `VarInfo` will be used by default.
A leaf context resulting in the exclusion of prior terms when running the model.
"""
struct LikelihoodContext{Tvars} <: AbstractContext
vars::Tvars
end
LikelihoodContext() = LikelihoodContext(nothing)
struct LikelihoodContext <: AbstractContext end
NodeTrait(context::LikelihoodContext) = IsLeaf()

"""
Expand Down
8 changes: 8 additions & 0 deletions src/distribution_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ Base.length(dist::NoDist) = Base.length(dist.dist)
Base.size(dist::NoDist) = Base.size(dist.dist)

Distributions.rand(rng::Random.AbstractRNG, d::NoDist) = rand(rng, d.dist)
# NOTE(torfjelde): Need this to avoid stack overflow.
function Distributions.rand!(
rng::Random.AbstractRNG,
d::NoDist{Distributions.ArrayLikeVariate{N}},
x::AbstractArray{<:Real,N},
) where {N}
return Distributions.rand!(rng, d.dist, x)
end
Distributions.logpdf(d::NoDist{<:Univariate}, ::Real) = 0
Distributions.logpdf(d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}) = 0
function Distributions.logpdf(d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real})
Expand Down
87 changes: 69 additions & 18 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1039,24 +1039,7 @@ function test_sampler_continuous(sampler::AbstractMCMC.AbstractSampler, args...;
return test_sampler_on_demo_models(sampler, args...; kwargs...)
end

"""
test_context_interface(context)

Test that `context` implements the `AbstractContext` interface.
"""
function test_context_interface(context)
# Is a subtype of `AbstractContext`.
@test context isa DynamicPPL.AbstractContext
# Should implement `NodeTrait.`
@test DynamicPPL.NodeTrait(context) isa Union{DynamicPPL.IsParent,DynamicPPL.IsLeaf}
# If it's a parent.
if DynamicPPL.NodeTrait(context) == DynamicPPL.IsParent
# Should implement `childcontext` and `setchildcontext`
@test DynamicPPL.setchildcontext(context, DynamicPPL.childcontext(context)) ==
context
end
end

# Testing for contexts.
"""
Context that multiplies each log-prior by mod
used to test whether varwise_logpriors respects child-context.
Expand Down Expand Up @@ -1097,4 +1080,72 @@ function DynamicPPL.dot_tilde_observe(
return logp * context.mod, vi
end

# Dummy context to test nested behaviors.
struct TestParentContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext
context::C
end
TestParentContext() = TestParentContext(DefaultContext())
DynamicPPL.NodeTrait(::TestParentContext) = DynamicPPL.IsParent()
DynamicPPL.childcontext(context::TestParentContext) = context.context
DynamicPPL.setchildcontext(::TestParentContext, child) = TestParentContext(child)
function Base.show(io::IO, c::TestParentContext)
return print(io, "TestParentContext(", DynamicPPL.childcontext(c), ")")
end

"""
test_context(context::AbstractContext, model::Model)

Test that `context` correctly implements the `AbstractContext` interface for `model`.

This method ensures that `context`
- Correctly implements the `AbstractContext` interface.
- Correctly implements the tilde-pipeline.
"""
function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model)
# `NodeTrait`.
node_trait = DynamicPPL.NodeTrait(context)
# Throw error immediately if it it's missing a `NodeTrait` implementation.
node_trait isa Union{DynamicPPL.IsLeaf,DynamicPPL.IsParent} ||
throw(ValueError("Invalid NodeTrait: $node_trait"))

# The interface methods.
if node_trait isa DynamicPPL.IsParent
# `childcontext` and `setchildcontext`
# With new child context
childcontext_new = TestParentContext()
@test DynamicPPL.childcontext(
DynamicPPL.setchildcontext(context, childcontext_new)
) == childcontext_new
end

# To see change, let's make sure we're using a different leaf context than the current.
leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext
PriorContext()
else
DefaultContext()
end
@test DynamicPPL.leafcontext(DynamicPPL.setleafcontext(context, leafcontext_new)) ==
leafcontext_new

# Setting the child context to a leaf should now change the leafcontext accordingly.
context_with_new_leaf = DynamicPPL.setleafcontext(context, leafcontext_new)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
@test childcontext(context_with_new_leaf) ===
leafcontext(context_with_new_leaf) ===
leafcontext_new

# Make sure that the we can evaluate the model with the context (i.e. that none of the tilde-functions are incorrectly overloaded).
# The tilde-pipeline contains two different paths: with `SamplingContext` as a parent, and without it.
# NOTE(torfjelde): Need to sample with the untyped varinfo _using_ the context, since the
# context might alter which variables are present, their names, etc., e.g. `PrefixContext`.
# TODO(torfjelde): Make the `varinfo` used for testing a kwarg once it makes sense for other varinfos.
# Untyped varinfo.
varinfo_untyped = DynamicPPL.VarInfo()
@test (DynamicPPL.evaluate!!(model, varinfo_untyped, SamplingContext(context)); true)
@test (DynamicPPL.evaluate!!(model, varinfo_untyped, context); true)
# Typed varinfo.
varinfo_typed = DynamicPPL.TypedVarInfo(varinfo_untyped)
@test (DynamicPPL.evaluate!!(model, varinfo_typed, SamplingContext(context)); true)
@test (DynamicPPL.evaluate!!(model, varinfo_typed, context); true)
end

end
Loading
Loading