Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Turing v0.33 #189

Merged
merged 20 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Pathfinder"
uuid = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454"
authors = ["Seth Axen <[email protected]> and contributors"]
version = "0.9.0-DEV"
version = "0.9.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -41,8 +41,8 @@ ADTypes = "0.2, 1"
Accessors = "0.1.12"
Distributions = "0.25.87"
DynamicHMC = "3.4.0"
DynamicPPL = "0.24.7, 0.25, 0.27"
Folds = "0.2.2"
DynamicPPL = "0.25.2, 0.27"
Folds = "0.2.9"
ForwardDiff = "0.10.19"
IrrationalConstants = "0.1.1, 0.2"
LinearAlgebra = "1.6"
Expand All @@ -61,8 +61,8 @@ ReverseDiff = "1.4.5"
SciMLBase = "1.95.0, 2"
Statistics = "1.6"
StatsBase = "0.33.7, 0.34"
Transducers = "0.4.66"
Turing = "0.30.5, 0.31, 0.32"
Transducers = "0.4.81"
Turing = "0.31.4, 0.32, 0.33"
UnPack = "1"
julia = "1.6"

Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ StatsFuns = "1"
StatsPlots = "0.14.21, 0.15"
TransformVariables = "0.6.2, 0.7, 0.8"
TransformedLogDensities = "1.0.2"
Turing = "0.30.5, 0.31, 0.32"
Turing = "0.31.4, 0.32, 0.33"
173 changes: 55 additions & 118 deletions ext/PathfinderTuringExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,149 +2,86 @@ module PathfinderTuringExt

if isdefined(Base, :get_extension)
using Accessors: Accessors
using ADTypes: ADTypes
using DynamicPPL: DynamicPPL
using MCMCChains: MCMCChains
using Pathfinder: Pathfinder
using Random: Random
using Turing: Turing
import Pathfinder: flattened_varnames_list
else # using Requires
using ..Accessors: Accessors
using ..ADTypes: ADTypes
using ..DynamicPPL: DynamicPPL
using ..MCMCChains: MCMCChains
using ..Pathfinder: Pathfinder
using ..Random: Random
using ..Turing: Turing
import ..Pathfinder: flattened_varnames_list
end

# utilities for working with Turing model parameter names using only the DynamicPPL API

function Pathfinder.flattened_varnames_list(model::DynamicPPL.Model)
varnames_ranges = varnames_to_ranges(model)
nsyms = maximum(maximum, values(varnames_ranges))
syms = Vector{Symbol}(undef, nsyms)
for (var_name, range) in varnames_to_ranges(model)
sym = Symbol(var_name)
if length(range) == 1
syms[range[begin]] = sym
continue
end
for i in eachindex(range)
syms[range[i]] = Symbol("$sym[$i]")
end
end
return syms
end

# code snippet shared by @torfjelde
"""
varnames_to_ranges(model::DynamicPPL.Model)
varnames_to_ranges(model::DynamicPPL.VarInfo)
varnames_to_ranges(model::DynamicPPL.Metadata)

Get `Dict` mapping variable names in model to their ranges in a corresponding parameter vector.
create_log_density_problem(model::DynamicPPL.Model)

# Examples
Create a log density problem from a `model`.

```julia
julia> @model function demo()
s ~ Dirac(1)
x = Matrix{Float64}(undef, 2, 4)
x[1, 1] ~ Dirac(2)
x[2, 1] ~ Dirac(3)
x[3] ~ Dirac(4)
y ~ Dirac(5)
x[4] ~ Dirac(6)
x[:, 3] ~ arraydist([Dirac(7), Dirac(8)])
x[[2, 1], 4] ~ arraydist([Dirac(9), Dirac(10)])
return s, x, y
end
demo (generic function with 2 methods)

julia> demo()()
(1, Any[2.0 4.0 7 10; 3.0 6.0 8 9], 5)
The return value is an object implementing the LogDensityProblems API whose log-density is
that of the `model` transformed to unconstrained space with the appropriate log-density
adjustment due to change of variables.
"""
function create_log_density_problem(model::DynamicPPL.Model)
# create an unconstrained VarInfo
varinfo = DynamicPPL.link(DynamicPPL.VarInfo(model), model)
# DefaultContext ensures that the log-density adjustment is computed
prob = DynamicPPL.LogDensityFunction(varinfo, model, DynamicPPL.DefaultContext())
return prob
end

julia> varnames_to_ranges(demo())
Dict{AbstractPPL.VarName, UnitRange{Int64}} with 8 entries:
s => 1:1
x[4] => 5:5
x[:,3] => 6:7
x[1,1] => 2:2
x[2,1] => 3:3
x[[2, 1],4] => 8:9
x[3] => 4:4
y => 10:10
```
"""
function varnames_to_ranges end
draws_to_chains(model::DynamicPPL.Model, draws) -> MCMCChains.Chains

varnames_to_ranges(model::DynamicPPL.Model) = varnames_to_ranges(DynamicPPL.VarInfo(model))
function varnames_to_ranges(varinfo::DynamicPPL.UntypedVarInfo)
return varnames_to_ranges(varinfo.metadata)
end
function varnames_to_ranges(varinfo::DynamicPPL.TypedVarInfo)
offset = 0
dicts = map(varinfo.metadata) do md
vns2ranges = varnames_to_ranges(md)
vals = collect(values(vns2ranges))
vals_offset = map(r -> offset .+ r, vals)
offset += reduce((curr, r) -> max(curr, r[end]), vals; init=0)
Dict(zip(keys(vns2ranges), vals_offset))
Convert a `(nparams, ndraws)` matrix of unconstrained `draws` to an `MCMCChains.Chains`
object with corresponding constrained draws and names according to `model`.
"""
function draws_to_chains(model::DynamicPPL.Model, draws::AbstractMatrix)
varinfo = DynamicPPL.link(DynamicPPL.VarInfo(model), model)
draw_con_varinfos = map(eachcol(draws)) do draw
# this re-evaluates the model but allows supporting dynamic bijectors
# https://github.com/TuringLang/Turing.jl/issues/2195
return Turing.Inference.getparams(model, DynamicPPL.unflatten(varinfo, draw))
end

return reduce(merge, dicts)
end
function varnames_to_ranges(metadata::DynamicPPL.Metadata)
idcs = map(Base.Fix1(getindex, metadata.idcs), metadata.vns)
ranges = metadata.ranges[idcs]
return Dict(zip(metadata.vns, ranges))
param_con_names = map(first, first(draw_con_varinfos))
draws_con = reduce(
vcat, Iterators.map(transpose ∘ Base.Fix1(map, last), draw_con_varinfos)
)
chns = MCMCChains.Chains(draws_con, param_con_names)
return chns
end

function Pathfinder.pathfinder(
model::DynamicPPL.Model;
rng=Random.GLOBAL_RNG,
init_scale=2,
init_sampler=Pathfinder.UniformSampler(init_scale),
init=nothing,
adtype::ADTypes.AbstractADType=Pathfinder.default_ad(),
kwargs...,
)
var_names = flattened_varnames_list(model)
prob = Turing.optim_problem(
model, Turing.MAP(); constrained=false, init_theta=init, adtype
)
init_sampler(rng, prob.prob.u0)
result = Pathfinder.pathfinder(prob.prob; rng, input=model, kwargs...)
draws = reduce(vcat, transpose.(prob.transform.(eachcol(result.draws))))
chns = MCMCChains.Chains(draws, var_names; info=(; pathfinder_result=result))
result_new = Accessors.@set result.draws_transformed = chns
function Pathfinder.pathfinder(model::DynamicPPL.Model; kwargs...)
log_density_problem = create_log_density_problem(model)
result = Pathfinder.pathfinder(log_density_problem; input=model, kwargs...)

# add transformed draws as Chains
chains_info = (; pathfinder_result=result)
chains = Accessors.@set draws_to_chains(model, result.draws).info = chains_info
result_new = Accessors.@set result.draws_transformed = chains
return result_new
end

function Pathfinder.multipathfinder(
model::DynamicPPL.Model,
ndraws::Int;
rng=Random.GLOBAL_RNG,
init_scale=2,
init_sampler=Pathfinder.UniformSampler(init_scale),
nruns::Int,
adtype=Pathfinder.default_ad(),
kwargs...,
)
var_names = flattened_varnames_list(model)
fun = Turing.optim_function(model, Turing.MAP(); constrained=false, adtype)
init1 = fun.init()
init = [init_sampler(rng, init1)]
for _ in 2:nruns
push!(init, init_sampler(rng, deepcopy(init1)))
function Pathfinder.multipathfinder(model::DynamicPPL.Model, ndraws::Int; kwargs...)
log_density_problem = create_log_density_problem(model)
result = Pathfinder.multipathfinder(log_density_problem, ndraws; input=model, kwargs...)

# add transformed draws as Chains
chains_info = (; pathfinder_result=result)
chains = Accessors.@set draws_to_chains(model, result.draws).info = chains_info

# add transformed draws as Chains for each individual path
single_path_results_new = map(result.pathfinder_results) do r
single_chains_info = (; pathfinder_result=r)
single_chains = Accessors.@set draws_to_chains(model, r.draws).info =
single_chains_info
r_new = Accessors.@set r.draws_transformed = single_chains
return r_new
end
result = Pathfinder.multipathfinder(fun.func, ndraws; rng, input=model, init, kwargs...)
draws = reduce(vcat, transpose.(fun.transform.(eachcol(result.draws))))
chns = MCMCChains.Chains(draws, var_names; info=(; pathfinder_result=result))
result_new = Accessors.@set result.draws_transformed = chns

result_new = Accessors.@set (Accessors.@set result.draws_transformed =
chains).pathfinder_results = single_path_results_new
return result_new
end

Expand Down
2 changes: 0 additions & 2 deletions src/Pathfinder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,6 @@ include("resample.jl")
include("singlepath.jl")
include("multipath.jl")

include("integration/turing.jl")

function __init__()
Requires.@require AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" begin
include("integration/advancedhmc.jl")
Expand Down
41 changes: 0 additions & 41 deletions src/integration/turing.jl

This file was deleted.

2 changes: 1 addition & 1 deletion test/integration/AdvancedHMC/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ TransformVariables = "84d833dd-6860-57f9-a1a7-6da5db126cff"
TransformedLogDensities = "f9bc47f6-f3f8-4f3b-ab21-f8bc73906f26"

[compat]
AdvancedHMC = "0.4, 0.5.2, 0.6"
AdvancedHMC = "0.6"
Distributions = "0.25.87"
ForwardDiff = "0.10.19"
LogDensityProblems = "2.1.0"
Expand Down
5 changes: 4 additions & 1 deletion test/integration/Turing/Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Pathfinder = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[compat]
LogDensityProblems = "2.1.0"
Pathfinder = "0.9"
Turing = "0.30.5, 0.31, 0.32"
Turing = "0.31.4, 0.32, 0.33"
julia = "1.6"
Loading
Loading