Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use values_as_in_model to extract the parameters from a Transition #2202

Merged
merged 10 commits into from
May 7, 2024
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.31.3"
version = "0.31.4"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
14 changes: 9 additions & 5 deletions src/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -317,11 +317,15 @@ Return a named tuple of parameters.
"""
getparams(model, t) = t.θ
function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo)
# Want the end-user to receive parameters in constrained space, so we `link`.
vi = DynamicPPL.invlink(vi, model)

# Extract parameter values in a simple form from the `VarInfo`.
vals = DynamicPPL.values_as(vi, OrderedDict)
# NOTE: In the past, `invlink(vi, model)` + `values_as(vi, OrderedDict)` was used.
# Unfortunately, using `invlink` can cause issues in scenarios where the constraints
# of the parameters change depending on the realizations. Hence we have to use
# `values_as_in_model`, which re-runs the model and extracts the parameters
# as they are seen in the model, i.e. in the constrained space. Moreover,
# this means that the code below will work both of linked and invlinked `vi`.
# Ref: https://github.com/TuringLang/Turing.jl/issues/2195
# NOTE: We need to `deepcopy` here to avoid modifying the original `vi`.
vals = DynamicPPL.values_as_in_model(model, deepcopy(vi))

# Obtain an iterator over the flattened parameter names and values.
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
Expand Down Expand Up @@ -43,6 +44,7 @@ DynamicHMC = "2.1.6, 3.0"
DynamicPPL = "0.25.1"
FiniteDifferences = "0.10.8, 0.11, 0.12"
ForwardDiff = "0.10.12 - 0.10.32, 0.10"
HypothesisTests = "0.11"
LogDensityProblems = "2"
LogDensityProblemsAD = "1.4"
MCMCChains = "5, 6"
Expand Down
22 changes: 22 additions & 0 deletions test/mcmc/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -261,4 +261,26 @@
@test mean(Array(chain)) ≈ 0.2
end
end

@turing_testset "issue: #2195" begin
@model function buggy_model()
lb ~ Uniform(0, 0.1)
ub ~ Uniform(0.11, 0.2)
x ~ transformed(Normal(0, 1), inverse(Bijectors.Logit(lb, ub)))
end

model = buggy_model();

chain = sample(model, NUTS(), 1000);
chain_prior = sample(model, Prior(), 1000);

# Extract the `x` like this because running `generated_quantities` was how
# the issue was discovered, hence we also want to make sure that it works.
results = generated_quantities(model, chain);
results_prior = generated_quantities(model, chain_prior);

# The discrepancies in the chains are in the tails, so we can't just compare the mean, etc.
# KS will compare the empirical CDFs, which seems like a reasonable thing to do here.
@test pvalue(ApproximateTwoSampleKSTest(vec(results), vec(results_prior))) > 0.05
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ using ReverseDiff
using SpecialFunctions
using StatsBase
using StatsFuns
using HypothesisTests
using Tracker
using Turing
using Turing.Inference
Expand Down
Loading