Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parallelize multi-path Pathfinder with multi-threading #11

Closed
wants to merge 12 commits
7 changes: 6 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ on:
- pull_request
jobs:
test:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ matrix.num_threads }} threads - ${{ github.event_name }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
Expand All @@ -15,6 +15,9 @@ jobs:
- ubuntu-latest
arch:
- x64
num_threads:
- 1
- 2
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
Expand All @@ -33,6 +36,8 @@ jobs:
${{ runner.os }}-
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
env:
JULIA_NUM_THREADS: ${{ matrix.num_threads }}
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v1
with:
Expand Down
38 changes: 26 additions & 12 deletions src/multipath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,28 +59,42 @@ function multipathfinder(
end

# run pathfinder independently from each starting point
# TODO: allow to be parallelized
res = map(θ₀s) do θ₀
return pathfinder(logp, ∇logp, θ₀, ndraws_per_run; rng=rng, kwargs...)
nruns = length(θ₀s)
q₁, ϕ₁, logqϕ₁ = pathfinder(logp, ∇logp, θ₀s[1], ndraws_per_run; rng=rng, kwargs...)
qs = Vector{typeof(q₁)}(undef, nruns)
ϕs = Vector{typeof(ϕ₁)}(undef, nruns)
logqϕs = Vector{typeof(logqϕ₁)}(undef, nruns)
qs[1], ϕs[1], logqϕs[1] = q₁, ϕ₁, logqϕ₁

thread_range = 1:min(Threads.nthreads(), nruns)
rngs = [deepcopy(rng) for _ in thread_range]
logps = [deepcopy(logp) for _ in thread_range]
∇logps = [deepcopy(∇logp) for _ in thread_range]
seeds = rand(rng, UInt, nruns - 1)

Threads.@threads for i in 2:nruns
id = Threads.threadid()
rngᵢ = rngs[id]
Random.seed!(rngᵢ, seeds[i - 1])
qs[i], ϕs[i], logqϕs[i] = pathfinder(
logps[id], ∇logps[id], θ₀s[i], ndraws_per_run; rng=rngᵢ, kwargs...
)
end
qs = reduce(vcat, first.(res))
ϕs = reduce(hcat, getindex.(res, 2))
ϕ = reduce(hcat, ϕs)

# draw samples from augmented mixture model
inds = axes(ϕs, 2)
inds = axes(ϕ, 2)
sample_inds = if importance
logqϕs = reduce(vcat, last.(res))
log_ratios = map(((ϕ, logqϕ),) -> logp(ϕ) - logqϕ, zip(eachcol(ϕs), logqϕs))
logqϕ = reduce(vcat, logqϕs)
log_ratios = logp.(eachcol(ϕ)) .- logqϕ
resample(rng, inds, log_ratios, ndraws)
else
resample(rng, inds, ndraws)
end

q = Distributions.MixtureModel(qs)
ϕ = ϕs[:, sample_inds]

qmix = Distributions.MixtureModel(qs)
# get component ids (k) of draws in ϕ
component_ids = cld.(sample_inds, ndraws_per_run)

return q, ϕ, component_ids
return qmix, ϕ[:, sample_inds], component_ids
end
6 changes: 4 additions & 2 deletions test/multipath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ using ForwardDiff
using Pathfinder
using Test

include("test_utils.jl")

@testset "multi path pathfinder" begin
@testset "MvNormal" begin
n = 10
nruns = 20
ndraws = 1000_000
ndraws = 1_000_000
ndraws_per_run = ndraws ÷ nruns
Σ = rand_pd_mat(Float64, n)
μ = randn(n)
Expand All @@ -30,7 +32,7 @@ using Test
μ_hat = mean(ϕ; dims=2)
Σ_hat = cov(ϕ .- μ_hat; dims=2, corrected=false)
# adapted from the MvNormal tests
# allow for 10x disagreement in atol, since this method is approximate
# allow for 10x more disagreement in atol, since this method is approximate
multiplier = 10
for i in 1:n
atol = sqrt(Σ[i, i] / ndraws) * 8 * multiplier
Expand Down