diff --git a/Project.toml b/Project.toml
index de15d32e2..785ad7a14 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "Turing"
 uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
-version = "0.31.3"
+version = "0.31.4"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl
index 7e429eefa..311632a3b 100644
--- a/src/mcmc/Inference.jl
+++ b/src/mcmc/Inference.jl
@@ -317,11 +317,15 @@ Return a named tuple of parameters.
 """
 getparams(model, t) = t.θ
 function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo)
-    # Want the end-user to receive parameters in constrained space, so we `link`.
-    vi = DynamicPPL.invlink(vi, model)
-
-    # Extract parameter values in a simple form from the `VarInfo`.
-    vals = DynamicPPL.values_as(vi, OrderedDict)
+    # NOTE: In the past, `invlink(vi, model)` + `values_as(vi, OrderedDict)` was used.
+    # Unfortunately, using `invlink` can cause issues in scenarios where the constraints
+    # of the parameters change depending on the realizations. Hence we have to use
+    # `values_as_in_model`, which re-runs the model and extracts the parameters
+    # as they are seen in the model, i.e. in the constrained space. Moreover,
+    # this means that the code below will work for both linked and invlinked `vi`.
+    # Ref: https://github.com/TuringLang/Turing.jl/issues/2195
+    # NOTE: We need to `deepcopy` here to avoid modifying the original `vi`.
+    vals = DynamicPPL.values_as_in_model(model, deepcopy(vi))
 
     # Obtain an iterator over the flattened parameter names and values.
     iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
diff --git a/test/Project.toml b/test/Project.toml
index a02b14a94..f743f3bbc 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -10,6 +10,7 @@ DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
 DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
 FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
 LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
@@ -43,6 +44,7 @@ DynamicHMC = "2.1.6, 3.0"
 DynamicPPL = "0.25.1"
 FiniteDifferences = "0.10.8, 0.11, 0.12"
 ForwardDiff = "0.10.12 - 0.10.32, 0.10"
+HypothesisTests = "0.11"
 LogDensityProblems = "2"
 LogDensityProblemsAD = "1.4"
 MCMCChains = "5, 6"
diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl
index 7206a977f..755fa4b45 100644
--- a/test/mcmc/hmc.jl
+++ b/test/mcmc/hmc.jl
@@ -261,4 +261,46 @@
             @test mean(Array(chain)) ≈ 0.2
         end
     end
+
+    @turing_testset "issue: #2195" begin
+        @model function buggy_model()
+            lb ~ Uniform(0, 1)
+            ub ~ Uniform(1.5, 2)
+
+            # HACK: Necessary to avoid NUTS failing during adaptation.
+            try
+                x ~ transformed(Normal(0, 1), inverse(Bijectors.Logit(lb, ub)))
+            catch e
+                if e isa DomainError
+                    Turing.@addlogprob! -Inf
+                    return nothing
+                else
+                    rethrow()
+                end
+            end
+        end
+
+        model = buggy_model()
+        num_samples = 1_000
+
+        chain = sample(
+            model,
+            NUTS(),
+            num_samples;
+            initial_params=[0.5, 1.75, 1.0]
+        )
+        chain_prior = sample(model, Prior(), num_samples)
+
+        # Extract `x` like this because running `generated_quantities` was how
+        # the issue was discovered; hence we also want to make sure that it works.
+        results = generated_quantities(model, chain)
+        results_prior = generated_quantities(model, chain_prior)
+
+        # Make sure none of the samples in the chains resulted in errors.
+        @test all(!isnothing, results)
+
+        # The discrepancies between the chains are in the tails, so we can't just compare means.
+        # A KS test compares the empirical CDFs, which seems like a reasonable approach here.
+        @test pvalue(ApproximateTwoSampleKSTest(vec(results), vec(results_prior))) > 0.01
+    end
 end
diff --git a/test/runtests.jl b/test/runtests.jl
index 1e82a3f57..8f9db1d1d 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -17,6 +17,7 @@ using ReverseDiff
 using SpecialFunctions
 using StatsBase
 using StatsFuns
+using HypothesisTests
 using Tracker
 using Turing
 using Turing.Inference
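
Note (not part of the patch): a minimal sketch of the extraction change in
`getparams`, reusing the `buggy_model` from the test above. The support of `x`
depends on the realizations of `lb` and `ub`, so the transformation recorded in
the `VarInfo` can disagree with the constraints of the current draw; the old
`invlink` + `values_as` path relies on that recorded transformation, while
`values_as_in_model` re-runs the model instead.

    using Turing, Bijectors, DynamicPPL, OrderedCollections

    @model function buggy_model()
        lb ~ Uniform(0, 1)
        ub ~ Uniform(1.5, 2)
        x ~ transformed(Normal(0, 1), inverse(Bijectors.Logit(lb, ub)))
    end

    model = buggy_model()
    vi = DynamicPPL.VarInfo(model)  # one realization of (lb, ub, x)

    # Old path: trusts the (inv)link transformations stored in `vi`, which can
    # be stale when the support of `x` changes with the sampled `lb` and `ub`.
    vals_old = DynamicPPL.values_as(DynamicPPL.invlink(vi, model), OrderedDict)

    # New path: re-executes the model, so the extracted values always respect
    # the constraints implied by the current `lb` and `ub`.
    vals_new = DynamicPPL.values_as_in_model(model, deepcopy(vi))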
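
Note (not part of the patch): a small usage sketch of the `HypothesisTests`
check added in `test/mcmc/hmc.jl`, with `xs` and `ys` as hypothetical stand-ins
for the NUTS-derived and prior-derived `x` samples. The KS statistic compares
whole empirical CDFs, so it picks up tail discrepancies that a comparison of
means would miss; a p-value above the threshold means the test found no
evidence that the two samples come from different distributions.

    using HypothesisTests

    xs = randn(1_000)  # stand-in for vec(results)
    ys = randn(1_000)  # stand-in for vec(results_prior)

    test = ApproximateTwoSampleKSTest(xs, ys)
    pvalue(test) > 0.01  # holds with probability ~0.99 when xs and ys match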