Skip to content

Commit

Permalink
thread safety (#70)
Browse files Browse the repository at this point in the history
  • Loading branch information
stevengj authored Aug 30, 2024
1 parent 6003c00 commit 07e6c23
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 21 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
julia = "1.6"
Combinatorics = "1.0"
DataStructures = "0.15, 0.16, 0.17, 0.18"
QuadGK = "2"
QuadGK = "2.1"
StaticArrays = "1.6.4"
LinearAlgebra = "<0.0.1, 1"
Test = "<0.0.1, 1"
Expand Down
9 changes: 2 additions & 7 deletions src/gauss-kronrod.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,9 @@ struct GaussKronrod{T<:Real}
wg::Vector{T}
end

# cache the Gauss-Kronrod rules so that we don't
# call QuadGK.kronrod every time.
const gkcache = Dict{Type, GaussKronrod}()

function GaussKronrod(::Type{T}) where {T<:Real}
haskey(gkcache, T) && return gkcache[T]::GaussKronrod{T}
gkcache[T] = g = GaussKronrod{T}(QuadGK.kronrod(T,7)...)
return g
# use QuadGK's internal rule cache
return GaussKronrod{T}(QuadGK.cachedrule(T, 7)...)
end

# further speed up the common case of double precision (25% faster for a trivial integrand)
Expand Down
41 changes: 29 additions & 12 deletions src/genz-malik.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,11 @@ end
# cache the Genz-Malik rules so that we don't reconstruct them every time;
# this mainly matters for simple integrands (low-degree polynomials) that
# don't require refinement.
const gmcache = Dict{Tuple{Int,Type}, GenzMalik}()

"""
GenzMalik(Val{n}(), T=Float64)
Construct an n-dimensional Genz-Malik rule for coordinates of type `T`.
"""
function GenzMalik(v::Val{n}, ::Type{T}=Float64) where {n, T<:Real}
haskey(gmcache, (n,T)) && return gmcache[n,T]::GenzMalik{n,T}
const gmcache = Dict{Tuple{Int,Type,Int}, GenzMalik}()
const gmcache_lock = ReentrantLock() # thread-safety

# internal code to construct n-dimensional Genz-Malik rule for coordinates of type `T`.
function _GenzMalik(v::Val{n}, ::Type{T}) where {n, T<:Real}
n < 2 && throw(ArgumentError("invalid dimension $n: GenzMalik rule requires dimension > 2"))

λ₄ = sqrt(9/T(10))
Expand All @@ -98,11 +93,33 @@ function GenzMalik(v::Val{n}, ::Type{T}=Float64) where {n, T<:Real}
p₄ = signcombos(2, λ₄, v)
p₅ = signcombos(n, λ₅, v)

g = GenzMalik{n,T}((p₂,p₃,p₄,p₅), (w₁,w₂,w₃,w₄,w₅), (w₁′,w₂′,w₃′,w₄′))
gmcache[n,T] = g
return g
return GenzMalik{n,T}((p₂,p₃,p₄,p₅), (w₁,w₂,w₃,w₄,w₅), (w₁′,w₂′,w₃′,w₄′))
end

"""
GenzMalik(Val{n}(), T=Float64)
Construct an n-dimensional Genz-Malik rule for coordinates of type `T`.
"""
function GenzMalik(v::Val{n}, ::Type{T}=Float64) where {n, T<:Real}
lock(gmcache_lock)
try
p = precision(T)
haskey(gmcache, (n,T,p)) && return gmcache[n,T,p]::GenzMalik{n,T}
return gmcache[n,T,p] = _GenzMalik(v, T)
finally
unlock(gmcache_lock)
end
end

# speed up common low-dimensional Float64 case:
for n in 2:4
gm = Symbol(:_gm, n)
@eval const $gm = _GenzMalik(Val($n), Float64)
@eval GenzMalik(::Val{$n}, ::Type{Float64}) = $gm
end


countevals(g::GenzMalik{n}) where {n} = 1 + 4n + 2*n*(n-1) + (1<<n)

"""
Expand Down
7 changes: 6 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,18 @@ using Test

@testset "simple" begin
@test hcubature(x -> cos(x[1])*cos(x[2]), [0,0], [1,1])[1] sin(1)^2
@inferred(hcubature(x -> cos(x[1])*cos(x[2]), (0,0), (1,1)))[1]
@inferred(hcubature(x -> cos(x[1])*cos(x[2]), (0,0), (1,1)))[1]
@inferred(hcubature(x -> cos(x[1])*cos(x[2]), (0.0f0,0.0f0), (1.0f0,1.0f0)))[1]
@test @inferred(hcubature(x -> cos(x[1]), (0,), (1,)))[1] sin(1)
@inferred(hquadrature(cos, 0, 1))[1]
@test @inferred(hcubature(x -> cos(x[1]), (0.0f0,), (1.0f0,)))[1] sin(1.0f0)
@test @inferred(hcubature(x -> 1.7, SVector{0,Float64}(), SVector{0,Float64}()))[1] == 1.7
@test @inferred(hcubature(x -> 2, (0,0), (2pi, pi))[1]) 4pi^2
@test @inferred(hcubature(x -> 2, (0.0f0,0.0f0), (2.0f0*pi, 1.0f0*pi))[1]) 4pi^2
@test_throws DimensionMismatch hcubature(x -> 2, [0,0,0], [2,0])
for d in 1:5
@test hcubature(x -> 1, fill(0,d), fill(1,d))[1] 1 rtol=1e-13
end
end

# function wrapper for counting evaluations
Expand Down

0 comments on commit 07e6c23

Please sign in to comment.