Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Xoshiro(seed) constructor, and extend some tests to Xoshiro #41105

Merged
merged 1 commit into from
Jun 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions stdlib/Random/src/Xoshiro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ mutable struct Xoshiro <: AbstractRNG
s1::UInt64
s2::UInt64
s3::UInt64

Xoshiro(s0::Integer, s1::Integer, s2::Integer, s3::Integer) = new(s0, s1, s2, s3)
Xoshiro(seed) = seed!(new(), seed)
end

Xoshiro(::Nothing) = Xoshiro()
Expand Down
30 changes: 18 additions & 12 deletions stdlib/Random/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ let a = [rand(RandomDevice(), UInt128) for i=1:10]
end

# test all rand APIs
for rng in ([], [MersenneTwister(0)], [RandomDevice()])
for rng in ([], [MersenneTwister(0)], [RandomDevice()], [Xoshiro()])
ftypes = [Float16, Float32, Float64]
cftypes = [ComplexF16, ComplexF32, ComplexF64, ftypes...]
types = [Bool, Char, BigFloat, Base.BitInteger_types..., ftypes...]
Expand Down Expand Up @@ -433,7 +433,7 @@ function hist(X, n)
end

# test uniform distribution of floats
for rng in [MersenneTwister(), RandomDevice()],
for rng in [MersenneTwister(), RandomDevice(), Xoshiro()],
T in [Float16, Float32, Float64, BigFloat],
prec in (T == BigFloat ? [3, 53, 64, 100, 256, 1000] : [256])
setprecision(BigFloat, prec) do
Expand All @@ -454,7 +454,7 @@ end
# but also for 3 linear combinations of positions (for the array version)
lcs = unique!.([rand(1:n, 2), rand(1:n, 3), rand(1:n, 5)])
aslcs = zeros(Int, 3)
for rng = (MersenneTwister(), RandomDevice())
for rng = (MersenneTwister(), RandomDevice(), Xoshiro())
for scalar = [false, true]
fill!(a, 0)
fill!(as, 0)
Expand All @@ -478,8 +478,8 @@ end
end
end

# test reproducility of methods
let mta = MersenneTwister(42), mtb = MersenneTwister(42)
@testset "reproducility of methods for $RNG" for RNG=(MersenneTwister,Xoshiro)
mta, mtb = RNG(42), RNG(42)

@test rand(mta) == rand(mtb)
@test rand(mta,10) == rand(mtb,10)
Expand Down Expand Up @@ -664,7 +664,7 @@ end
# this shouldn't crash (#22403)
@test_throws ArgumentError rand!(Union{UInt,Int}[1, 2, 3])

@testset "$RNG() & Random.seed!(rng::$RNG) initializes randomly" for RNG in (MersenneTwister, RandomDevice)
@testset "$RNG() & Random.seed!(rng::$RNG) initializes randomly" for RNG in (MersenneTwister, RandomDevice, Xoshiro)
m = RNG()
a = rand(m, Int)
m = RNG()
Expand All @@ -685,11 +685,17 @@ end
@test rand(m, Int) ∉ (a, b, c, d)
end

@testset "MersenneTwister($seed) & Random.seed!(m::MersenneTwister, $seed) produce the same stream" for seed in [0:5; 10000:10005]
m = MersenneTwister(seed)
a = [rand(m) for _=1:100]
Random.seed!(m, seed)
@test a == [rand(m) for _=1:100]
@testset "$RNG(seed) & Random.seed!(m::$RNG, seed) produce the same stream" for RNG=(MersenneTwister,Xoshiro)
seeds = Any[0, 1, 2, 10000, 10001, rand(UInt32, 8), rand(UInt128, 3)...]
if RNG == Xoshiro
push!(seeds, rand(UInt64, rand(1:4)), Tuple(rand(UInt64, 4)))
end
for seed=seeds
m = RNG(seed)
a = [rand(m) for _=1:100]
Random.seed!(m, seed)
@test a == [rand(m) for _=1:100]
end
end

struct RandomStruct23964 end
Expand All @@ -698,7 +704,7 @@ struct RandomStruct23964 end
@test_throws ArgumentError rand(RandomStruct23964())
end

@testset "rand(::$(typeof(RNG)), ::UnitRange{$T}" for RNG ∈ (MersenneTwister(rand(UInt128)), RandomDevice()),
@testset "rand(::$(typeof(RNG)), ::UnitRange{$T}" for RNG ∈ (MersenneTwister(rand(UInt128)), RandomDevice(), Xoshiro()),
T ∈ (Int8, Int16, Int32, UInt32, Int64, Int128, UInt128)
for S in (SamplerRangeInt, SamplerRangeFast, SamplerRangeNDL)
S == SamplerRangeNDL && sizeof(T) > 8 && continue
Expand Down