Skip to content

Commit

Permalink
factor out common code for Xoshiro and TaskLocalRNG (#51347)
Browse files Browse the repository at this point in the history
This makes more use of `setstate!`, and adds `getstate(rng)` which
returns the 5-tuple `(s0, s1, s2, s3, s4)`.

This is essentially "NFC", but it enables the useless
`copy!(TaskLocalRNG(), TaskLocalRNG())`.
  • Loading branch information
rfourquet authored Sep 18, 2023
1 parent b189bed commit 14119e0
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 119 deletions.
135 changes: 43 additions & 92 deletions stdlib/Random/src/Xoshiro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,53 +51,25 @@ mutable struct Xoshiro <: AbstractRNG
s4::UInt64 # internal splitmix state

Xoshiro(s0::Integer, s1::Integer, s2::Integer, s3::Integer, s4::Integer) = new(s0, s1, s2, s3, s4)
Xoshiro(s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64) = new(s0, s1, s2, s3, 1s0 + 3s1 + 5s2 + 7s3)
Xoshiro(s0::Integer, s1::Integer, s2::Integer, s3::Integer) = initstate!(new(), map(UInt64, (s0, s1, s2, s3)))
Xoshiro(seed=nothing) = seed!(new(), seed)
end

Xoshiro(s0::Integer, s1::Integer, s2::Integer, s3::Integer) = Xoshiro(UInt64(s0), UInt64(s1), UInt64(s2), UInt64(s3))

function setstate!(
x::Xoshiro,
s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64, # xoshiro256 state
s4::UInt64, # internal splitmix state
)
@inline function setstate!(x::Xoshiro, (s0, s1, s2, s3, s4))
x.s0 = s0
x.s1 = s1
x.s2 = s2
x.s3 = s3
x.s4 = s4
if s4 !== nothing
x.s4 = s4
end
x
end

copy(rng::Xoshiro) = Xoshiro(rng.s0, rng.s1, rng.s2, rng.s3, rng.s4)

function copy!(dst::Xoshiro, src::Xoshiro)
dst.s0, dst.s1, dst.s2, dst.s3, dst.s4 = src.s0, src.s1, src.s2, src.s3, src.s4
dst
end

function ==(a::Xoshiro, b::Xoshiro)
a.s0 == b.s0 && a.s1 == b.s1 && a.s2 == b.s2 && a.s3 == b.s3 && a.s4 == b.s4
end
@inline getstate(x::Xoshiro) = (x.s0, x.s1, x.s2, x.s3, x.s4)

rng_native_52(::Xoshiro) = UInt64

@inline function rand(rng::Xoshiro, ::SamplerType{UInt64})
s0, s1, s2, s3 = rng.s0, rng.s1, rng.s2, rng.s3
tmp = s0 + s3
res = ((tmp << 23) | (tmp >> 41)) + s0
t = s1 << 17
s2 = xor(s2, s0)
s3 = xor(s3, s1)
s1 = xor(s1, s2)
s0 = xor(s0, s3)
s2 = xor(s2, t)
s3 = s3 << 45 | s3 >> 19
rng.s0, rng.s1, rng.s2, rng.s3 = s0, s1, s2, s3
res
end


## Task local RNG

Expand All @@ -120,61 +92,70 @@ is undefined behavior: it will work most of the time, and may sometimes fail sil
"""
struct TaskLocalRNG <: AbstractRNG end
TaskLocalRNG(::Nothing) = TaskLocalRNG()
rng_native_52(::TaskLocalRNG) = UInt64

function setstate!(
x::TaskLocalRNG,
s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64, # xoshiro256 state
s4::UInt64, # internal splitmix state
)
@inline function setstate!(x::TaskLocalRNG, (s0, s1, s2, s3, s4))
t = current_task()
t.rngState0 = s0
t.rngState1 = s1
t.rngState2 = s2
t.rngState3 = s3
t.rngState4 = s4
if s4 !== nothing
t.rngState4 = s4
end
x
end

@inline function rand(::TaskLocalRNG, ::SamplerType{UInt64})
task = current_task()
s0, s1, s2, s3 = task.rngState0, task.rngState1, task.rngState2, task.rngState3
tmp = s0 + s3
res = ((tmp << 23) | (tmp >> 41)) + s0
t = s1 << 17
s2 ⊻= s0
s3 ⊻= s1
s1 ⊻= s2
s0 ⊻= s3
s2 ⊻= t
s3 = s3 << 45 | s3 >> 19
task.rngState0, task.rngState1, task.rngState2, task.rngState3 = s0, s1, s2, s3
res
@inline function getstate(::TaskLocalRNG)
t = current_task()
(t.rngState0, t.rngState1, t.rngState2, t.rngState3, t.rngState4)
end

# Shared implementation between Xoshiro and TaskLocalRNG -- seeding
rng_native_52(::TaskLocalRNG) = UInt64


## Shared implementation between Xoshiro and TaskLocalRNG

function seed!(rng::Union{TaskLocalRNG,Xoshiro})
# this variant of setstate! initializes the internal splitmix state, a.k.a. `s4`
@inline initstate!(x::Union{TaskLocalRNG, Xoshiro}, (s0, s1, s2, s3)::NTuple{4, UInt64}) =
setstate!(x, (s0, s1, s2, s3, 1s0 + 3s1 + 5s2 + 7s3))

copy(rng::Union{TaskLocalRNG, Xoshiro}) = Xoshiro(getstate(rng)...)
copy!(dst::Union{TaskLocalRNG, Xoshiro}, src::Union{TaskLocalRNG, Xoshiro}) = setstate!(dst, getstate(src))
==(x::Union{TaskLocalRNG, Xoshiro}, y::Union{TaskLocalRNG, Xoshiro}) = getstate(x) == getstate(y)

function seed!(rng::Union{TaskLocalRNG, Xoshiro})
# as we get good randomness from RandomDevice, we can skip hashing
rd = RandomDevice()
s0 = rand(rd, UInt64)
s1 = rand(rd, UInt64)
s2 = rand(rd, UInt64)
s3 = rand(rd, UInt64)
s4 = 1s0 + 3s1 + 5s2 + 7s3
setstate!(rng, s0, s1, s2, s3, s4)
initstate!(rng, (s0, s1, s2, s3))
end

function seed!(rng::Union{TaskLocalRNG,Xoshiro}, seed::Union{Vector{UInt32}, Vector{UInt64}})
function seed!(rng::Union{TaskLocalRNG, Xoshiro}, seed::Union{Vector{UInt32}, Vector{UInt64}})
c = SHA.SHA2_256_CTX()
SHA.update!(c, reinterpret(UInt8, seed))
s0, s1, s2, s3 = reinterpret(UInt64, SHA.digest!(c))
s4 = 1s0 + 3s1 + 5s2 + 7s3
setstate!(rng, s0, s1, s2, s3, s4)
initstate!(rng, (s0, s1, s2, s3))
end

seed!(rng::Union{TaskLocalRNG, Xoshiro}, seed::Integer) = seed!(rng, make_seed(seed))

@inline function rand(x::Union{TaskLocalRNG, Xoshiro}, ::SamplerType{UInt64})
s0, s1, s2, s3 = getstate(x)
tmp = s0 + s3
res = ((tmp << 23) | (tmp >> 41)) + s0
t = s1 << 17
s2 ⊻= s0
s3 ⊻= s1
s1 ⊻= s2
s0 ⊻= s3
s2 ⊻= t
s3 = s3 << 45 | s3 >> 19
setstate!(x, (s0, s1, s2, s3, nothing))
res
end

@inline function rand(rng::Union{TaskLocalRNG, Xoshiro}, ::SamplerType{UInt128})
first = rand(rng, UInt64)
Expand All @@ -191,36 +172,6 @@ end
(rand(rng, UInt64) >>> (64 - 8*sizeof(S))) % S
end

function copy(rng::TaskLocalRNG)
t = current_task()
Xoshiro(t.rngState0, t.rngState1, t.rngState2, t.rngState3, t.rngState4)
end

function copy!(dst::TaskLocalRNG, src::Xoshiro)
t = current_task()
setstate!(dst, src.s0, src.s1, src.s2, src.s3, src.s4)
return dst
end

function copy!(dst::Xoshiro, src::TaskLocalRNG)
t = current_task()
setstate!(dst, t.rngState0, t.rngState1, t.rngState2, t.rngState3, t.rngState4)
return dst
end

function ==(a::Xoshiro, b::TaskLocalRNG)
t = current_task()
(
a.s0 == t.rngState0 &&
a.s1 == t.rngState1 &&
a.s2 == t.rngState2 &&
a.s3 == t.rngState3 &&
a.s4 == t.rngState4
)
end

==(a::TaskLocalRNG, b::Xoshiro) = b == a

# for partial words, use upper bits from Xoshiro

rand(r::Union{TaskLocalRNG, Xoshiro}, ::SamplerTrivial{UInt52Raw{UInt64}}) = rand(r, UInt64) >>> 12
Expand Down
35 changes: 8 additions & 27 deletions stdlib/Random/src/XoshiroSimd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

module XoshiroSimd
# Getting the xoroshiro RNG to reliably vectorize is somewhat of a hassle without Simd.jl.
import ..Random: TaskLocalRNG, rand, rand!, Xoshiro, CloseOpen01, UnsafeView,
SamplerType, SamplerTrivial
import ..Random: rand!
using ..Random: TaskLocalRNG, rand, Xoshiro, CloseOpen01, UnsafeView, SamplerType, SamplerTrivial, getstate, setstate!
using Base: BitInteger_types
using Base.Libc: memcpy
using Core.Intrinsics: llvmcall
Expand Down Expand Up @@ -149,14 +149,9 @@ _id(x, T) = x
nothing
end

@noinline function xoshiro_bulk_nosimd(rng::Union{TaskLocalRNG, Xoshiro}, dst::Ptr{UInt8}, len::Int, ::Type{T}, f::F) where {T, F}
if rng isa TaskLocalRNG
task = current_task()
s0, s1, s2, s3 = task.rngState0, task.rngState1, task.rngState2, task.rngState3
else
(; s0, s1, s2, s3) = rng::Xoshiro
end

@noinline function xoshiro_bulk_nosimd(rng::Union{TaskLocalRNG, Xoshiro}, dst::Ptr{UInt8}, len::Int, ::Type{T}, f::F
) where {T, F}
s0, s1, s2, s3 = getstate(rng)
i = 0
while i+8 <= len
res = _plus(_rotl23(_plus(s0,s3)),s0)
Expand All @@ -183,22 +178,12 @@ end
# TODO: This may make the random-stream dependent on system endianness
GC.@preserve ref memcpy(dst+i, Base.unsafe_convert(Ptr{Cvoid}, ref), len-i)
end
if rng isa TaskLocalRNG
task.rngState0, task.rngState1, task.rngState2, task.rngState3 = s0, s1, s2, s3
else
rng.s0, rng.s1, rng.s2, rng.s3 = s0, s1, s2, s3
end
setstate!(rng, (s0, s1, s2, s3, nothing))
nothing
end

@noinline function xoshiro_bulk_nosimd(rng::Union{TaskLocalRNG, Xoshiro}, dst::Ptr{UInt8}, len::Int, ::Type{Bool}, f)
if rng isa TaskLocalRNG
task = current_task()
s0, s1, s2, s3 = task.rngState0, task.rngState1, task.rngState2, task.rngState3
else
(; s0, s1, s2, s3) = rng::Xoshiro
end

s0, s1, s2, s3 = getstate(rng)
i = 0
while i+8 <= len
res = _plus(_rotl23(_plus(s0,s3)),s0)
Expand Down Expand Up @@ -232,11 +217,7 @@ end
s2 = _xor(s2, t)
s3 = _rotl45(s3)
end
if rng isa TaskLocalRNG
task.rngState0, task.rngState1, task.rngState2, task.rngState3 = s0, s1, s2, s3
else
rng.s0, rng.s1, rng.s2, rng.s3 = s0, s1, s2, s3
end
setstate!(rng, (s0, s1, s2, s3, nothing))
nothing
end

Expand Down

0 comments on commit 14119e0

Please sign in to comment.