Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix type stability of sampling from Chisq, TDist, Gamma #1885

Merged
merged 12 commits into from
Aug 23, 2024
Merged
2 changes: 1 addition & 1 deletion src/samplers/gamma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,6 @@ end

function rand(rng::AbstractRNG, s::GammaIPSampler)
x = rand(rng, s.s)
e = randexp(rng)
e = randexp(rng, typeof(x))
x*exp(s.nia*e)
end
2 changes: 1 addition & 1 deletion src/univariate/continuous/exponential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ cf(d::Exponential, t::Real) = 1/(1 - t * im * scale(d))


#### Sampling
rand(rng::AbstractRNG, d::Exponential) = xval(d, randexp(rng))
rand(rng::AbstractRNG, d::Exponential{T}) where {T} = xval(d, randexp(rng, float(T)))


#### Fit model
Expand Down
2 changes: 1 addition & 1 deletion src/univariate/continuous/tdist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ end
function rand(rng::AbstractRNG, d::TDist)
ν = d.ν
z = sqrt(rand(rng, Chisq{typeof(ν)}(ν)) / ν)
return randn(rng) / (isinf(ν) ? one(z) : z)
return randn(rng, typeof(z)) / (isinf(ν) ? one(z) : z)
end

function cf(d::TDist{T}, t::Real) where T <: Real
Expand Down
11 changes: 9 additions & 2 deletions test/univariate/continuous/chisq.jl
Original file line number Diff line number Diff line change
@@ -1,2 +1,9 @@
test_cgf(Chisq(1), (0.49, -1, -100, -1f6))
test_cgf(Chisq(3), (0.49, -1, -100, -1f6))

@testset "Chisq" begin
test_cgf(Chisq(1), (0.49, -1, -100, -1.0f6))
test_cgf(Chisq(3), (0.49, -1, -100, -1.0f6))

for T in (Float32, Float64)
@test @inferred(rand(Chisq(T(1)))) isa T
end
end
12 changes: 9 additions & 3 deletions test/univariate/continuous/exponential.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@

test_cgf(Exponential(1), (0.9, -1, -100f0, -1e6))
test_cgf(Exponential(0.91), (0.9, -1, -100f0, -1e6))
test_cgf(Exponential(10 ), (0.08, -1, -100f0, -1e6))
@testset "Exponential" begin
test_cgf(Exponential(1), (0.9, -1, -100f0, -1e6))
test_cgf(Exponential(0.91), (0.9, -1, -100f0, -1e6))
test_cgf(Exponential(10 ), (0.08, -1, -100f0, -1e6))

for T in (Float32, Float64)
@test @inferred(rand(Exponential(T(1)))) isa T
end
end
38 changes: 23 additions & 15 deletions test/univariate/continuous/gamma.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,33 @@
using Test, Distributions, OffsetArrays

test_cgf(Gamma(1 ,1 ), (0.9, -1, -100f0, -1e6))
test_cgf(Gamma(10 ,1 ), (0.9, -1, -100f0, -1e6))
test_cgf(Gamma(0.2, 10), (0.08, -1, -100f0, -1e6))
@testset "Gamma" begin
test_cgf(Gamma(1, 1), (0.9, -1, -100.0f0, -1e6))
test_cgf(Gamma(10, 1), (0.9, -1, -100.0f0, -1e6))
test_cgf(Gamma(0.2, 10), (0.08, -1, -100.0f0, -1e6))

@testset "Gamma suffstats and OffsetArrays" begin
a = rand(Gamma(), 11)
wa = 1.0:11.0
@testset "Gamma suffstats and OffsetArrays" begin
a = rand(Gamma(), 11)
wa = 1.0:11.0

resulta = @inferred(suffstats(Gamma, a))
resulta = @inferred(suffstats(Gamma, a))

resultwa = @inferred(suffstats(Gamma, a, wa))
resultwa = @inferred(suffstats(Gamma, a, wa))

b = OffsetArray(a, -5:5)
wb = OffsetArray(wa, -5:5)
b = OffsetArray(a, -5:5)
wb = OffsetArray(wa, -5:5)

resultb = @inferred(suffstats(Gamma, b))
@test resulta == resultb
resultb = @inferred(suffstats(Gamma, b))
@test resulta == resultb

resultwb = @inferred(suffstats(Gamma, b, wb))
@test resultwa == resultwb
resultwb = @inferred(suffstats(Gamma, b, wb))
@test resultwa == resultwb

@test_throws DimensionMismatch suffstats(Gamma, a, wb)
@test_throws DimensionMismatch suffstats(Gamma, a, wb)
end

for T in (Float32, Float64)
@test @inferred(rand(Gamma(T(1), T(1)))) isa T
@test @inferred(rand(Gamma(1/T(2), T(1)))) isa T
@test @inferred(rand(Gamma(T(2), T(1)))) isa T
end
end
17 changes: 12 additions & 5 deletions test/univariate/continuous/tdist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,17 @@ using ForwardDiff

using Test

@testset "Type stability of `rand` (#1614)" begin
if VERSION >= v"1.9.0-DEV.348"
# randn(::BigFloat) was only added in https://github.com/JuliaLang/julia/pull/44714
@inferred(rand(TDist(big"1.0")))
@testset "TDist" begin
@testset "Type stability of `rand` (#1614)" begin
if VERSION >= v"1.9.0-DEV.348"
# randn(::BigFloat) was only added in https://github.com/JuliaLang/julia/pull/44714
@inferred(rand(TDist(big"1.0")))
end
@inferred(rand(TDist(ForwardDiff.Dual(1.0))))

end

for T in (Float32, Float64)
@test @inferred(rand(TDist(T(1)))) isa T
end
@inferred(rand(TDist(ForwardDiff.Dual(1.0))))
end
Loading