From 41b41abb3ce87dc22fca70475d8a46e19feba4de Mon Sep 17 00:00:00 2001 From: Nathan Zimmerberg <39104088+nhz2@users.noreply.github.com> Date: Fri, 15 Sep 2023 18:28:26 -0400 Subject: [PATCH] [Random] Add s4 field to Xoshiro type (#51332) This PR adds an optional field to the existing `Xoshiro` struct to be able to faithfully copy the task-local RNG state. Fixes #51255 Redo of #51271 Background context: #49110 added an additional state to the task-local RNG. However, before this PR `copy(default_rng())` did not include this extra state, causing subtle errors in `Test` where `copy(default_rng())` is assumed to contain the full task-local RNG state. --- stdlib/Random/src/Xoshiro.jl | 45 ++++++++++++++++++++++++++---------- 1 file changed, 33 insertions(+), 12 deletions(-) diff --git a/stdlib/Random/src/Xoshiro.jl b/stdlib/Random/src/Xoshiro.jl index 3be276ad23754..a25e2c1077e04 100644 --- a/stdlib/Random/src/Xoshiro.jl +++ b/stdlib/Random/src/Xoshiro.jl @@ -48,28 +48,37 @@ mutable struct Xoshiro <: AbstractRNG s1::UInt64 s2::UInt64 s3::UInt64 + s4::UInt64 # internal splitmix state - Xoshiro(s0::Integer, s1::Integer, s2::Integer, s3::Integer) = new(s0, s1, s2, s3) + Xoshiro(s0::Integer, s1::Integer, s2::Integer, s3::Integer, s4::Integer) = new(s0, s1, s2, s3, s4) + Xoshiro(s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64) = new(s0, s1, s2, s3, 1s0 + 3s1 + 5s2 + 7s3) Xoshiro(seed=nothing) = seed!(new(), seed) end -function setstate!(x::Xoshiro, s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64) +Xoshiro(s0::Integer, s1::Integer, s2::Integer, s3::Integer) = Xoshiro(UInt64(s0), UInt64(s1), UInt64(s2), UInt64(s3)) + +function setstate!( + x::Xoshiro, + s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64, # xoshiro256 state + s4::UInt64, # internal splitmix state +) x.s0 = s0 x.s1 = s1 x.s2 = s2 x.s3 = s3 + x.s4 = s4 x end -copy(rng::Xoshiro) = Xoshiro(rng.s0, rng.s1, rng.s2, rng.s3) +copy(rng::Xoshiro) = Xoshiro(rng.s0, rng.s1, rng.s2, rng.s3, rng.s4) function copy!(dst::Xoshiro, src::Xoshiro) - dst.s0, dst.s1, dst.s2, dst.s3 = src.s0, src.s1, src.s2, src.s3 + dst.s0, dst.s1, dst.s2, dst.s3, dst.s4 = src.s0, src.s1, src.s2, src.s3, src.s4 dst end function ==(a::Xoshiro, b::Xoshiro) - a.s0 == b.s0 && a.s1 == b.s1 && a.s2 == b.s2 && a.s3 == b.s3 + a.s0 == b.s0 && a.s1 == b.s1 && a.s2 == b.s2 && a.s3 == b.s3 && a.s4 == b.s4 end rng_native_52(::Xoshiro) = UInt64 @@ -116,7 +125,7 @@ rng_native_52(::TaskLocalRNG) = UInt64 function setstate!( x::TaskLocalRNG, s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64, # xoshiro256 state - s4::UInt64 = 1s0 + 3s1 + 5s2 + 7s3, # internal splitmix state + s4::UInt64, # internal splitmix state ) t = current_task() t.rngState0 = s0 @@ -148,14 +157,20 @@ end function seed!(rng::Union{TaskLocalRNG,Xoshiro}) # as we get good randomness from RandomDevice, we can skip hashing rd = RandomDevice() - setstate!(rng, rand(rd, UInt64), rand(rd, UInt64), rand(rd, UInt64), rand(rd, UInt64)) + s0 = rand(rd, UInt64) + s1 = rand(rd, UInt64) + s2 = rand(rd, UInt64) + s3 = rand(rd, UInt64) + s4 = 1s0 + 3s1 + 5s2 + 7s3 + setstate!(rng, s0, s1, s2, s3, s4) end function seed!(rng::Union{TaskLocalRNG,Xoshiro}, seed::Union{Vector{UInt32}, Vector{UInt64}}) c = SHA.SHA2_256_CTX() SHA.update!(c, reinterpret(UInt8, seed)) s0, s1, s2, s3 = reinterpret(UInt64, SHA.digest!(c)) - setstate!(rng, s0, s1, s2, s3) + s4 = 1s0 + 3s1 + 5s2 + 7s3 + setstate!(rng, s0, s1, s2, s3, s4) end seed!(rng::Union{TaskLocalRNG, Xoshiro}, seed::Integer) = seed!(rng, make_seed(seed)) @@ -178,24 +193,30 @@ end function copy(rng::TaskLocalRNG) t = current_task() - Xoshiro(t.rngState0, t.rngState1, t.rngState2, t.rngState3) + Xoshiro(t.rngState0, t.rngState1, t.rngState2, t.rngState3, t.rngState4) end function copy!(dst::TaskLocalRNG, src::Xoshiro) t = current_task() - setstate!(dst, src.s0, src.s1, src.s2, src.s3) + setstate!(dst, src.s0, src.s1, src.s2, src.s3, src.s4) return dst end function copy!(dst::Xoshiro, src::TaskLocalRNG) t = current_task() - setstate!(dst, t.rngState0, t.rngState1, t.rngState2, t.rngState3) + setstate!(dst, t.rngState0, t.rngState1, t.rngState2, t.rngState3, t.rngState4) return dst end function ==(a::Xoshiro, b::TaskLocalRNG) t = current_task() - a.s0 == t.rngState0 && a.s1 == t.rngState1 && a.s2 == t.rngState2 && a.s3 == t.rngState3 + ( + a.s0 == t.rngState0 && + a.s1 == t.rngState1 && + a.s2 == t.rngState2 && + a.s3 == t.rngState3 && + a.s4 == t.rngState4 + ) end ==(a::TaskLocalRNG, b::Xoshiro) = b == a