From 9ab01111aea6cd90cb4883ac25288356c6dee8fb Mon Sep 17 00:00:00 2001 From: Rafael Fourquet Date: Sun, 6 Jun 2021 12:12:10 +0200 Subject: [PATCH] add Xoshiro(seed) constructor, and extend some tests to Xoshiro Usually, a seed should be equally valid for an RNG constructor or for a call to `seed!`, as `seed!`'s docstring mentions: > After the call to seed!, rng is equivalent to a newly created object > initialized with the same seed. --- stdlib/Random/src/Xoshiro.jl | 3 +++ stdlib/Random/test/runtests.jl | 30 ++++++++++++++++++------------ 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/stdlib/Random/src/Xoshiro.jl b/stdlib/Random/src/Xoshiro.jl index a55c9688f9dda..c1c37b3c1c134 100644 --- a/stdlib/Random/src/Xoshiro.jl +++ b/stdlib/Random/src/Xoshiro.jl @@ -23,6 +23,9 @@ mutable struct Xoshiro <: AbstractRNG s1::UInt64 s2::UInt64 s3::UInt64 + + Xoshiro(s0::Integer, s1::Integer, s2::Integer, s3::Integer) = new(s0, s1, s2, s3) + Xoshiro(seed) = seed!(new(), seed) end Xoshiro(::Nothing) = Xoshiro() diff --git a/stdlib/Random/test/runtests.jl b/stdlib/Random/test/runtests.jl index 1d591ce7493b9..2c1230c0753d3 100644 --- a/stdlib/Random/test/runtests.jl +++ b/stdlib/Random/test/runtests.jl @@ -308,7 +308,7 @@ let a = [rand(RandomDevice(), UInt128) for i=1:10] end # test all rand APIs -for rng in ([], [MersenneTwister(0)], [RandomDevice()]) +for rng in ([], [MersenneTwister(0)], [RandomDevice()], [Xoshiro()]) ftypes = [Float16, Float32, Float64] cftypes = [ComplexF16, ComplexF32, ComplexF64, ftypes...] types = [Bool, Char, BigFloat, Base.BitInteger_types..., ftypes...] @@ -433,7 +433,7 @@ function hist(X, n) end # test uniform distribution of floats -for rng in [MersenneTwister(), RandomDevice()], +for rng in [MersenneTwister(), RandomDevice(), Xoshiro()], T in [Float16, Float32, Float64, BigFloat], prec in (T == BigFloat ? [3, 53, 64, 100, 256, 1000] : [256]) setprecision(BigFloat, prec) do @@ -454,7 +454,7 @@ end # but also for 3 linear combinations of positions (for the array version) lcs = unique!.([rand(1:n, 2), rand(1:n, 3), rand(1:n, 5)]) aslcs = zeros(Int, 3) - for rng = (MersenneTwister(), RandomDevice()) + for rng = (MersenneTwister(), RandomDevice(), Xoshiro()) for scalar = [false, true] fill!(a, 0) fill!(as, 0) @@ -478,8 +478,8 @@ end end end -# test reproducility of methods -let mta = MersenneTwister(42), mtb = MersenneTwister(42) +@testset "reproducility of methods for $RNG" for RNG=(MersenneTwister,Xoshiro) + mta, mtb = RNG(42), RNG(42) @test rand(mta) == rand(mtb) @test rand(mta,10) == rand(mtb,10) @@ -664,7 +664,7 @@ end # this shouldn't crash (#22403) @test_throws ArgumentError rand!(Union{UInt,Int}[1, 2, 3]) -@testset "$RNG() & Random.seed!(rng::$RNG) initializes randomly" for RNG in (MersenneTwister, RandomDevice) +@testset "$RNG() & Random.seed!(rng::$RNG) initializes randomly" for RNG in (MersenneTwister, RandomDevice, Xoshiro) m = RNG() a = rand(m, Int) m = RNG() @@ -685,11 +685,17 @@ end @test rand(m, Int) ∉ (a, b, c, d) end -@testset "MersenneTwister($seed) & Random.seed!(m::MersenneTwister, $seed) produce the same stream" for seed in [0:5; 10000:10005] - m = MersenneTwister(seed) - a = [rand(m) for _=1:100] - Random.seed!(m, seed) - @test a == [rand(m) for _=1:100] +@testset "$RNG(seed) & Random.seed!(m::$RNG, seed) produce the same stream" for RNG=(MersenneTwister,Xoshiro) + seeds = Any[0, 1, 2, 10000, 10001, rand(UInt32, 8), rand(UInt128, 3)...] + if RNG == Xoshiro + push!(seeds, rand(UInt64, rand(1:4)), Tuple(rand(UInt64, 4))) + end + for seed=seeds + m = RNG(seed) + a = [rand(m) for _=1:100] + Random.seed!(m, seed) + @test a == [rand(m) for _=1:100] + end end struct RandomStruct23964 end @@ -698,7 +704,7 @@ struct RandomStruct23964 end @test_throws ArgumentError rand(RandomStruct23964()) end -@testset "rand(::$(typeof(RNG)), ::UnitRange{$T}" for RNG ∈ (MersenneTwister(rand(UInt128)), RandomDevice()), +@testset "rand(::$(typeof(RNG)), ::UnitRange{$T}" for RNG ∈ (MersenneTwister(rand(UInt128)), RandomDevice(), Xoshiro()), T ∈ (Int8, Int16, Int32, UInt32, Int64, Int128, UInt128) for S in (SamplerRangeInt, SamplerRangeFast, SamplerRangeNDL) S == SamplerRangeNDL && sizeof(T) > 8 && continue