Skip to content

Commit

Permalink
make TaskLocal the default RNG
Browse files Browse the repository at this point in the history
  • Loading branch information
JeffBezanson committed Apr 20, 2021
1 parent 66051fd commit db62af9
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 45 deletions.
32 changes: 10 additions & 22 deletions stdlib/Random/src/RNGs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -554,45 +554,33 @@ function seed!(r::MersenneTwister, seed::Vector{UInt32})
return r
end

seed!(r::MersenneTwister=default_rng()) = seed!(r, make_seed())
seed!() = seed!(default_rng(), make_seed())
seed!(r::MersenneTwister) = seed!(r, make_seed())
seed!(r::MersenneTwister, n::Integer) = seed!(r, make_seed(n))
seed!(seed::Union{Integer,Vector{UInt32}}) = seed!(default_rng(), seed)
seed!(seed::Union{Integer,Vector{UInt32},Vector{UInt64}}) = seed!(default_rng(), seed)


### Global RNG

const THREAD_RNGs = MersenneTwister[]
@inline default_rng() = default_rng(Threads.threadid())
@noinline function default_rng(tid::Int)
0 < tid <= length(THREAD_RNGs) || _rng_length_assert()
if @inbounds isassigned(THREAD_RNGs, tid)
@inbounds MT = THREAD_RNGs[tid]
else
MT = MersenneTwister()
@inbounds THREAD_RNGs[tid] = MT
end
return MT
end
@noinline _rng_length_assert() = @assert false "0 < tid <= length(THREAD_RNGs)"
@inline default_rng() = TaskLocal()
@inline default_rng(tid::Int) = TaskLocal()

function __init__()
resize!(empty!(THREAD_RNGs), Threads.nthreads()) # ensures that we didn't save a bad object
seed!(TaskLocal())
end


struct _GLOBAL_RNG <: AbstractRNG
global const GLOBAL_RNG = _GLOBAL_RNG.instance
end

# GLOBAL_RNG currently represents a MersenneTwister
typeof_rng(::_GLOBAL_RNG) = MersenneTwister
# GLOBAL_RNG currently uses TaskLocal
typeof_rng(::_GLOBAL_RNG) = TaskLocal

copy!(dst::MersenneTwister, ::_GLOBAL_RNG) = copy!(dst, default_rng())
copy!(::_GLOBAL_RNG, src::MersenneTwister) = copy!(default_rng(), src)
copy!(dst::Xoshiro, ::_GLOBAL_RNG) = copy!(dst, default_rng())
copy!(::_GLOBAL_RNG, src::Xoshiro) = copy!(default_rng(), src)
copy(::_GLOBAL_RNG) = copy(default_rng())

seed!(::_GLOBAL_RNG, seed::Vector{UInt32}) = seed!(default_rng(), seed)
seed!(::_GLOBAL_RNG, seed::Union{Vector{UInt32}, Vector{UInt64}}) = seed!(default_rng(), seed)
seed!(::_GLOBAL_RNG, n::Integer) = seed!(default_rng(), n)
seed!(::_GLOBAL_RNG, ::Nothing) = seed!(default_rng(), nothing)
seed!(::_GLOBAL_RNG) = seed!(default_rng(), nothing)
Expand Down
37 changes: 18 additions & 19 deletions stdlib/Random/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,7 @@ guardseed() do
m = MersenneTwister(0)
@test Random.seed!() === g
@test Random.seed!(rand(UInt)) === g
@test Random.seed!(rand(UInt32, rand(1:10))) === g
@test Random.seed!(rand(UInt32, rand(1:8))) === g
@test Random.seed!(m) === m
@test Random.seed!(m, rand(UInt)) === m
@test Random.seed!(m, rand(UInt32, rand(1:10))) === m
Expand Down Expand Up @@ -751,28 +751,27 @@ end
@test Random.seed!(GLOBAL_RNG, 0) === LOCAL_RNG
@test Random.seed!(GLOBAL_RNG) === LOCAL_RNG

mt = MersenneTwister(1)
@test copy!(mt, GLOBAL_RNG) === mt
@test mt == LOCAL_RNG
Random.seed!(mt, 2)
@test mt != LOCAL_RNG
@test copy!(GLOBAL_RNG, mt) === LOCAL_RNG
@test mt == LOCAL_RNG
mt2 = copy(GLOBAL_RNG)
@test mt2 isa typeof(LOCAL_RNG)
@test mt2 !== LOCAL_RNG
@test mt2 == LOCAL_RNG
xo = Xoshiro()
@test copy!(xo, GLOBAL_RNG) === xo
@test xo == LOCAL_RNG
Random.seed!(xo, 2)
@test xo != LOCAL_RNG
@test copy!(GLOBAL_RNG, xo) === LOCAL_RNG
@test xo == LOCAL_RNG
xo2 = copy(GLOBAL_RNG)
@test xo2 !== LOCAL_RNG
@test xo2 == LOCAL_RNG

for T in (Random.UInt52Raw{UInt64},
Random.UInt2x52Raw{UInt128},
Random.UInt104Raw{UInt128},
Random.CloseOpen12_64)
x = Random.SamplerTrivial(T())
@test rand(GLOBAL_RNG, x) === rand(mt, x)
@test rand(GLOBAL_RNG, x) === rand(xo, x)
end
for T in (Int64, UInt64, Int128, UInt128, Bool, Int8, UInt8, Int16, UInt16, Int32, UInt32)
x = Random.SamplerType{T}()
@test rand(GLOBAL_RNG, x) === rand(mt, x)
@test rand(GLOBAL_RNG, x) === rand(xo, x)
end

A = fill(0.0, 100, 100)
Expand All @@ -781,21 +780,21 @@ end
vB = view(B, :, :)
I1 = Random.SamplerTrivial(Random.CloseOpen01{Float64}())
I2 = Random.SamplerTrivial(Random.CloseOpen12{Float64}())
@test rand!(GLOBAL_RNG, A, I1) === A == rand!(mt, B, I1) === B
@test rand!(GLOBAL_RNG, A, I1) === A == rand!(xo, B, I1) === B
B = fill!(B, 1.0)
@test rand!(GLOBAL_RNG, vA, I1) === vA
rand!(mt, vB, I1)
rand!(xo, vB, I1)
@test A == B
for T in (Float16, Float32)
B = fill!(B, 1.0)
@test rand!(GLOBAL_RNG, A, I2) === A == rand!(mt, B, I2) === B
@test rand!(GLOBAL_RNG, A, I2) === A == rand!(xo, B, I2) === B
B = fill!(B, 1.0)
@test rand!(GLOBAL_RNG, A, I1) === A == rand!(mt, B, I1) === B
@test rand!(GLOBAL_RNG, A, I1) === A == rand!(xo, B, I1) === B
end
for T in Base.BitInteger_types
x = Random.SamplerType{T}()
B = fill!(B, 1.0)
@test rand!(GLOBAL_RNG, A, x) === A == rand!(mt, B, x) === B
@test rand!(GLOBAL_RNG, A, x) === A == rand!(xo, B, x) === B
end
# issue #33170
@test Sampler(GLOBAL_RNG, 2:4, Val(1)) isa SamplerRangeNDL
Expand Down
5 changes: 1 addition & 4 deletions stdlib/Test/src/Test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1227,8 +1227,6 @@ function testset_beginend(args, tests, source)
local RNG = default_rng()
local oldrng = copy(RNG)
try
# RNG is re-seeded with its own seed to ease reproduce a failed test
Random.seed!(RNG.seed)
let
$(esc(tests))
end
Expand Down Expand Up @@ -1319,7 +1317,6 @@ function testset_forloop(args, testloop, source)
local ts
local RNG = default_rng()
local oldrng = copy(RNG)
Random.seed!(RNG.seed)
local tmprng = copy(RNG)
try
let
Expand Down Expand Up @@ -1790,7 +1787,7 @@ end

"`guardseed(f, seed)` is equivalent to running `Random.seed!(seed); f()` and
then restoring the state of the global RNG as it was before."
guardseed(f::Function, seed::Union{Vector{UInt32},Integer}) = guardseed() do
guardseed(f::Function, seed::Union{Vector{UInt64},Vector{UInt32},Integer}) = guardseed() do
Random.seed!(seed)
f()
end
Expand Down

0 comments on commit db62af9

Please sign in to comment.