Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds @is_post_processing macro #589

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ function DynamicPPL.generated_quantities(

# TODO: Some of the variables can be a view into the `varinfo`, so we need to
# `deepcopy` the `varinfo` before passing it to `model`.
model(deepcopy(varinfo))
model(deepcopy(varinfo), DynamicPPL.PostProcessingContext())
end
end

Expand Down
1 change: 1 addition & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ export AbstractVarInfo,
unfix,
# Convenience macros
@addlogprob!,
@is_post_processing,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@is_post_processing,
@is_generated_quantities,

@submodel,
value_iterator_from_chain

Expand Down
22 changes: 22 additions & 0 deletions src/compiler.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,27 @@
const INTERNALNAMES = (:__model__, :__context__, :__varinfo__)

"""
check_if_in_model_block_expr(name)

Return an expression that can be evaluated to check if we're inside a model block.

# Arguments
- `name`: The name of the variable or method that can only be used inside a model block.
Error message will include this name.
"""
function check_if_in_model_block_expr(name)
return Expr(
:||,
Expr(
:&&,
Expr(:isdefined, esc(:__model__)),
Expr(:call, :isa, esc(:__model__), Model),
),
# Otherwise, throw error.
:(error($(string(name)) * " can only be used inside a model block")),
)
end

"""
need_concretize(expr)

Expand Down
63 changes: 63 additions & 0 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -664,3 +664,66 @@ function fixed(context::FixedContext)
# precedence over decendants of `context`.
return merge(context.values, fixed(childcontext(context)))
end

""""
Copy link
Member

@yebai yebai Apr 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be better to move these new codes and the existing generated_quantities functions into a new file generated_quantities.jl, so it is self-contained.

PostProcessingContext

Simple context used to indicate that the model is being evaluated with the aim
of post-processing the inference results, e.g. making predictions or computing
generated quantities.
"""
struct PostProcessingContext{Ctx} <: AbstractContext
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
struct PostProcessingContext{Ctx} <: AbstractContext
struct GeneratedQuantitiesContext{Ctx} <: AbstractContext

context::AbstractContext
end

function PostProcessingContext(context::AbstractContext)
return PostProcessingContext{typeof(context)}(context)
end
PostProcessingContext() = PostProcessingContext(DefaultContext())

NodeTrait(::PostProcessingContext) = IsParent()
childcontext(context::PostProcessingContext) = context.context
function setchildcontext(context::PostProcessingContext, child)
return PostProcessingContext(child)
end

function is_post_processing(context::AbstractContext)
return is_post_processing(NodeTrait(is_post_processing, context), context)
end
is_post_processing(::IsLeaf, context) = false
is_post_processing(::IsParent, context) = is_post_processing(childcontext(context))
is_post_processing(context::PostProcessingContext) = true

"""
@is_post_processing

Return `true` if the model is being evaluated with the aim of post-processing
inference results, e.g. making predictions or computing generated quantities.

# Examples

```jldoctest; setup = :(using Distributions)
julia> @model function demo()
x ~ Normal(0, 1)
return if @is_post_processing
x
else
nothing
end
end
demo (generic function with 2 methods)

julia> model = demo();

julia> model() # (✓) Returns nothing

julia> generated_quantities(model, (x = 1,)) # (✓) Returns 1.0
1.0
```
"""
macro is_post_processing()
return quote
$(check_if_in_model_block_expr("@is_post_processing"))
$(is_post_processing)($(esc(:__context__)))
end
end
6 changes: 3 additions & 3 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1253,7 +1253,7 @@ function generated_quantities(model::Model, chain::AbstractChains)
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
return map(iters) do (sample_idx, chain_idx)
setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
model(varinfo)
model(varinfo, PostProcessingContext())
end
end

Expand Down Expand Up @@ -1295,11 +1295,11 @@ julia> generated_quantities(model, values(parameters), keys(parameters))
function generated_quantities(model::Model, parameters::NamedTuple)
varinfo = VarInfo(model)
setval_and_resample!(varinfo, values(parameters), keys(parameters))
return model(varinfo)
return model(varinfo, PostProcessingContext())
end

function generated_quantities(model::Model, values, keys)
varinfo = VarInfo(model)
setval_and_resample!(varinfo, values, keys)
return model(varinfo)
return model(varinfo, PostProcessingContext())
end
1 change: 1 addition & 0 deletions src/submodel_macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ function submodel(prefix_expr, expr, ctx=esc(:__context__))
)
end
quote
$(check_if_in_model_block_expr("@submodel"))
$retval, $(esc(:__varinfo__)) = $(_evaluate!!)(
$(esc(R)), $(esc(:__varinfo__)), $(ctx)
)
Expand Down
1 change: 1 addition & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ true
"""
macro addlogprob!(ex)
return quote
$(check_if_in_model_block_expr("@addlogprob!"))
$(esc(:(__varinfo__))) = acclogp!!(
$(esc(:(__context__))), $(esc(:(__varinfo__))), $(esc(ex))
)
Expand Down
Loading