Skip to content

Commit

Permalink
Make DynamicHMC and Turing support into extensions (#114)
Browse files Browse the repository at this point in the history
* Move integrations to ext folder

* Make extensions modules

* Use Requires for older Julia versions

* Always load Requires

* Specify weakdeps and extensions

* Increment patch number

* Run CI on nightly as well

* Only run with 2 threads on latest version

* Revert "Only run with 2 threads on latest version"

This reverts commit 1b84007.

* Revert "Run CI on nightly as well"

This reverts commit 1646855.

* Simplify job matrix

* Run DynamicHMC and Turing integration tests on nightly

* Load modules using syntax required by Requires

* Add Julia version to name

* Require Turing dependencies used

* Use double-dot syntax only for conditionally loaded deps

* Increment version number
  • Loading branch information
sethaxen authored Feb 1, 2023
1 parent b5d1917 commit 86b8285
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 29 deletions.
13 changes: 8 additions & 5 deletions .github/workflows/IntegrationTests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,27 @@ on:
pull_request:
jobs:
test:
name: ${{ matrix.package }}
runs-on: ${{ matrix.os }}
name: ${{ matrix.package }} - Julia ${{ matrix.version }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
version: [1]
os: [ubuntu-latest]
arch: [x64]
package:
- DynamicHMC
- AdvancedHMC
- Turing
include:
- package: DynamicHMC
version: 'nightly'
- package: Turing
version: 'nightly'
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
arch: x64
- uses: julia-actions/julia-buildpkg@v1
- run: |
julia --code-coverage=user -e '
Expand Down
19 changes: 18 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Pathfinder"
uuid = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454"
authors = ["Seth Axen <[email protected]> and contributors"]
version = "0.7.0"
version = "0.7.1"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand All @@ -25,13 +25,25 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[weakdeps]
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[extensions]
DynamicHMCExt = "DynamicHMC"
TuringExt = ["DynamicPPL", "MCMCChains", "Turing"]

[compat]
Accessors = "0.1"
Distributions = "0.25"
DynamicPPL = "0.20, 0.21"
Folds = "0.2"
ForwardDiff = "0.10"
IrrationalConstants = "0.1.1"
LogDensityProblems = "2"
MCMCChains = "5"
Optim = "1.4"
Optimization = "3"
OptimizationOptimJL = "0.1"
Expand All @@ -42,12 +54,17 @@ Requires = "1"
SciMLBase = "1.8.1"
StatsBase = "0.33"
Transducers = "0.4.5"
Turing = "0.21, 0.22, 0.23"
UnPack = "1"
julia = "1.6"

[extras]
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
OptimizationNLopt = "4e6fcdb7-1186-4e1f-a706-475e75c168bb"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[targets]
test = ["OptimizationNLopt", "Test"]
22 changes: 22 additions & 0 deletions ext/DynamicHMCExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
module DynamicHMCExt

using PDMats: PDMats
if isdefined(Base, :get_extension)
using Pathfinder: Pathfinder
using DynamicHMC: DynamicHMC
else # using Requires
using ..Pathfinder: Pathfinder
using ..DynamicHMC: DynamicHMC
end

function DynamicHMC.GaussianKineticEnergy(M⁻¹::Pathfinder.WoodburyPDMat)
return DynamicHMC.GaussianKineticEnergy(M⁻¹, inv(Pathfinder.pdfactorize(M⁻¹).R))
end

function DynamicHMC.kinetic_energy(
κ::DynamicHMC.GaussianKineticEnergy{<:Pathfinder.WoodburyPDMat}, p, q=nothing
)
return PDMats.quad.M⁻¹, p) / 2
end

end # module
29 changes: 22 additions & 7 deletions src/integration/turing.jl → ext/TuringExt.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
using .Turing: Turing, DynamicPPL, MCMCChains
module TuringExt

using Accessors: Accessors
using Random: Random
if isdefined(Base, :get_extension)
using DynamicPPL: DynamicPPL
using MCMCChains: MCMCChains
using Pathfinder: Pathfinder
using Turing: Turing
else # using Requires
using ..DynamicPPL: DynamicPPL
using ..MCMCChains: MCMCChains
using ..Pathfinder: Pathfinder
using ..Turing: Turing
end

# utilities for working with Turing model parameter names using only the DynamicPPL API

Expand Down Expand Up @@ -119,30 +132,30 @@ function varnames_to_ranges(metadata::DynamicPPL.Metadata)
return Dict(zip(metadata.vns, ranges))
end

function pathfinder(
function Pathfinder.pathfinder(
model::DynamicPPL.Model;
rng=Random.GLOBAL_RNG,
init_scale=2,
init_sampler=UniformSampler(init_scale),
init_sampler=Pathfinder.UniformSampler(init_scale),
init=nothing,
kwargs...,
)
var_names = flattened_varnames_list(model)
prob = Turing.optim_problem(model, Turing.MAP(); constrained=false, init_theta=init)
init_sampler(rng, prob.prob.u0)
result = pathfinder(prob.prob; rng, input=model, kwargs...)
result = Pathfinder.pathfinder(prob.prob; rng, input=model, kwargs...)
draws = reduce(vcat, transpose.(prob.transform.(eachcol(result.draws))))
chns = MCMCChains.Chains(draws, var_names; info=(; pathfinder_result=result))
result_new = Accessors.@set result.draws_transformed = chns
return result_new
end

function multipathfinder(
function Pathfinder.multipathfinder(
model::DynamicPPL.Model,
ndraws::Int;
rng=Random.GLOBAL_RNG,
init_scale=2,
init_sampler=UniformSampler(init_scale),
init_sampler=Pathfinder.UniformSampler(init_scale),
nruns::Int,
kwargs...,
)
Expand All @@ -153,9 +166,11 @@ function multipathfinder(
for _ in 2:nruns
push!(init, init_sampler(rng, deepcopy(init1)))
end
result = multipathfinder(fun.func, ndraws; rng, input=model, init, kwargs...)
result = Pathfinder.multipathfinder(fun.func, ndraws; rng, input=model, init, kwargs...)
draws = reduce(vcat, transpose.(fun.transform.(eachcol(result.draws))))
chns = MCMCChains.Chains(draws, var_names; info=(; pathfinder_result=result))
result_new = Accessors.@set result.draws_transformed = chns
return result_new
end

end # module
16 changes: 11 additions & 5 deletions src/Pathfinder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,20 @@ include("singlepath.jl")
include("multipath.jl")

function __init__()
Requires.@require DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb" begin
include("integration/dynamichmc.jl")
end
Requires.@require AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" begin
include("integration/advancedhmc.jl")
end
Requires.@require Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" begin
include("integration/turing.jl")
@static if !isdefined(Base, :get_extension)
Requires.@require DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb" begin
include("../ext/DynamicHMCExt.jl")
end
Requires.@require Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" begin
Requires.@require DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" begin
Requires.@require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" begin
include("../ext/TuringExt.jl")
end
end
end
end
end

Expand Down
11 changes: 0 additions & 11 deletions src/integration/dynamichmc.jl

This file was deleted.

2 comments on commit 86b8285

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/76806

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.7.1 -m "<description of version>" 86b8285ac2ae048f4dd7a6deb48c9b843721b31b
git push origin v0.7.1

Please sign in to comment.