-
Notifications
You must be signed in to change notification settings - Fork 48
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
77a4c60
commit 3cbaab3
Showing
5 changed files
with
391 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,260 @@ | ||
# Copied from CUDA.jl/src/device/random.jl | ||
|
||
|
||
## random number generation | ||
|
||
using Random | ||
import RandomNumbers | ||
|
||
|
||
# global state | ||
|
||
# Shared (LDS) memory holding the actual seed, one 32-bit slot per warp (32 slots in total).
# Slots are loaded lazily or overridden by calling `seed!`.
#
# The IR declares an external global in the `AS.Local` address space and returns a pointer
# to its first element; that pointer is wrapped in a 32-element `ROCDeviceArray{UInt32}`.
@eval @inline function global_random_keys()
    keys_ptr = Base.llvmcall(
        $("""@__zeroinit_global_random_keys = external addrspace($(AS.Local)) global [32 x i32], align 32
        define i8 addrspace($(AS.Local))* @entry() #0 {
        %ptr = getelementptr inbounds [32 x i32], [32 x i32] addrspace($(AS.Local))* @__zeroinit_global_random_keys, i64 0, i64 0
        %untyped_ptr = bitcast i32 addrspace($(AS.Local))* %ptr to i8 addrspace($(AS.Local))*
        ret i8 addrspace($(AS.Local))* %untyped_ptr
        }
        attributes #0 = { alwaysinline }
        """, "entry"), LLVMPtr{UInt32, AS.Local}, Tuple{})
    return ROCDeviceArray{UInt32,1,AS.Local}((32,), keys_ptr)
end
|
||
# Shared (LDS) memory holding the per-warp counters (32 slots of 32 bits), incremented
# whenever random numbers are generated.
#
# The IR declares an external global in the `AS.Local` address space and returns a pointer
# to its first element; that pointer is wrapped in a 32-element `ROCDeviceArray{UInt32}`.
@eval @inline function global_random_counters()
    counters_ptr = Base.llvmcall(
        $("""@__zeroinit_global_random_counters = external addrspace($(AS.Local)) global [32 x i32], align 32
        define i8 addrspace($(AS.Local))* @entry() #0 {
        %ptr = getelementptr inbounds [32 x i32], [32 x i32] addrspace($(AS.Local))* @__zeroinit_global_random_counters, i64 0, i64 0
        %untyped_ptr = bitcast i32 addrspace($(AS.Local))* %ptr to i8 addrspace($(AS.Local))*
        ret i8 addrspace($(AS.Local))* %untyped_ptr
        }
        attributes #0 = { alwaysinline }
        """, "entry"), LLVMPtr{UInt32, AS.Local}, Tuple{})
    return ROCDeviceArray{UInt32,1,AS.Local}((32,), counters_ptr)
end
|
||
# Device-side seed source: the value returned by `memrealtime()`, truncated to `UInt32`.
@device_override function Random.make_seed()
    return Base.unsafe_trunc(UInt32, memrealtime())
end
|
||
|
||
# generators | ||
|
||
using Random123: philox2x_round, philox2x_bumpkey | ||
|
||
# GPU-compatible/optimized version of the generator from Random123.jl
#
# The struct declares no fields; `key`, `ctr1` and `ctr2` are accessed as properties,
# which this file serves through `getproperty`/`setproperty!` methods for this type.
struct Philox2x32{R} <: RandomNumbers.AbstractRNG{UInt64}
    @inline function Philox2x32{R}() where R
        generator = new{R}()
        # A zero key means it has not been set yet: initialize it. This happens when
        # first accessing the (0-initialized) shared memory key from each block. If we
        # ever want to make the device seed controllable from the host, this would be
        # the place to read a global seed.
        #
        # Note however that it is undefined how shared memory persists across e.g.
        # launches, so we may not be able to rely on the zero initialization then.
        if iszero(generator.key)
            generator.key = Random.make_seed()
        end
        return generator
    end
end
|
||
# Convenience constructor: default to 7 rounds, which is enough to pass SmallCrush.
@inline function Philox2x32()
    return Philox2x32{7}()
end
|
||
# Property reads for `Philox2x32`. Nothing is stored in the object itself:
#   :key  -> this warp's slot of `global_random_keys()`
#   :ctr1 -> this warp's slot of `global_random_counters()`
#   :ctr2 -> the linear global work-item index, converted with `% UInt32`
# Any other field makes the `if` chain yield `nothing`, which fails the `::UInt32` assertion.
@inline function Base.getproperty(rng::Philox2x32, field::Symbol)
    item = workitemIdx()
    groupdim = workgroupDim()

    # linear index of this work-item inside its workgroup
    local_id = item.x + (item.y - Int32(1)) * groupdim.x +
               (item.z - Int32(1)) * groupdim.x * groupdim.y
    # index of the group of 32 consecutive work-items ("warp") it belongs to
    warp_slot = (local_id - Int32(1)) >> 0x5 + Int32(1)  # fld1

    value = if field === :seed
        # NOTE(review): `global_random_seed` is not defined in the visible part of this
        # file -- confirm it exists elsewhere, otherwise this branch cannot compile.
        @inbounds global_random_seed()[1]
    elseif field === :key
        @inbounds global_random_keys()[warp_slot]
    elseif field === :ctr1
        @inbounds global_random_counters()[warp_slot]
    elseif field === :ctr2
        group = workgroupIdx()
        griddim = gridGroupDim()
        # linear index of this workgroup inside the grid
        group_id = group.x + (group.y - Int32(1)) * griddim.x +
                   (group.z - Int32(1)) * griddim.x * griddim.y
        global_id = local_id +
                    (group_id - Int32(1)) * (groupdim.x * groupdim.y * groupdim.z)
        global_id % UInt32
    end
    return value::UInt32
end
|
||
# Property writes for `Philox2x32`: `:key` and `:ctr1` are stored in this warp's slot of
# the corresponding shared-memory table. Writes to any other field are ignored.
@inline function Base.setproperty!(rng::Philox2x32, field::Symbol, x)
    item = workitemIdx()
    groupdim = workgroupDim()

    # linear index of this work-item inside its workgroup
    local_id = item.x + (item.y - Int32(1)) * groupdim.x +
               (item.z - Int32(1)) * groupdim.x * groupdim.y
    # index of the group of 32 consecutive work-items ("warp") it belongs to
    warp_slot = (local_id - Int32(1)) >> 0x5 + Int32(1)  # fld1

    if field === :key
        @inbounds global_random_keys()[warp_slot] = x
    elseif field === :ctr1
        @inbounds global_random_counters()[warp_slot] = x
    end
end
|
||
# On the device, the default RNG is a `Philox2x32` built with the default round count.
@device_override @inline function Random.default_rng()
    return Philox2x32()
end
|
||
""" | ||
Random.seed!(rng::Philox2x32, seed::Integer, [counter::Integer=0]) | ||
Seed the on-device Philox2x32 generator with a UInt32 number.
Should be called by at least one thread per warp. | ||
""" | ||
function Random.seed!(rng::Philox2x32, seed::Integer, counter::Integer=0)
    # Both values end up in 32-bit slots. Truncate the counter with `% UInt32`, exactly
    # like the seed, so that a negative or too-large counter wraps around instead of
    # going through a throwing conversion inside device code.
    rng.key = seed % UInt32
    rng.ctr1 = counter % UInt32
    return
end
|
||
# On Julia 1.7 and newer, seeding the global RNG from device code is forwarded to the
# device default RNG.
if VERSION >= v"1.7-"
    @device_override function Random.seed!(::Random._GLOBAL_RNG, seed)
        return Random.seed!(Random.default_rng(), seed)
    end
end
|
||
""" | ||
    Random.rand(rng::Philox2x32, UInt64)

Generate 64 bits of random data using the on-device Philox2x32 generator.
""" | ||
function Random.rand(rng::Philox2x32{R},::Type{UInt64}) where {R}
    ctr1, ctr2, key = rng.ctr1, rng.ctr2, rng.key

    # The first round uses the key as-is; each further round (up to 16 in total) bumps
    # the key first. `@nexprs` expands to one `if R > i ... end` per i in 1:15, so the
    # rounds stay fully unrolled and are selected by the type parameter `R`.
    if R > 0
        ctr1, ctr2 = philox2x_round(ctr1, ctr2, key)
    end
    Base.Cartesian.@nexprs 15 i -> begin
        if R > i
            key = philox2x_bumpkey(key)
            ctr1, ctr2 = philox2x_round(ctr1, ctr2, key)
        end
    end

    # update the warp counter
    # NOTE: this performs the same update on every thread in the warp, but each warp writes
    #       to a unique location so the duplicate writes are innocuous
    # XXX: what if this overflows? we can't increment ctr2. bump the key?
    rng.ctr1 = rng.ctr1 + Int32(1)

    # NOTE: it's too expensive to keep both numbers around in case the user only wanted one,
    #       so just make our 2x32 generator return 64-bit numbers by default.
    upper_half = (ctr1 % UInt64) << 32
    lower_half = ctr2 % UInt64
    return upper_half | lower_half
end
|
||
|
||
|
||
# normally distributed random numbers using Ziggurat algorithm | ||
# | ||
# copied from Base because we don't support its global tables | ||
|
||
# A hacky method of exposing constant tables as constant GPU memory.
#
# Builds an LLVM function returning a pointer (in the `AS.Constant` address space) to an
# internal-linkage global initialized with `data`, and passes that function to
# `call_function`. `name` is only used to form the global's name, `gpu_<name>_data`.
function emit_constant_array(name::Symbol, data::AbstractArray{T}) where {T}
    @dispose ctx=Context() begin
        elem_type = convert(LLVMType, T; ctx)
        result_type = convert(LLVMType, LLVMPtr{T,AS.Constant}; ctx)

        # define the accessor function and fetch the LLVM module it lives in
        accessor, _ = create_function(result_type)
        llvm_mod = LLVM.parent(accessor)

        # create the global variable that holds the table
        # TODO: global_var alignment?
        table_type = LLVM.ArrayType(elem_type, length(data))
        # XXX: why can't we use a single name like emit_shmem
        table = GlobalVariable(llvm_mod, table_type, "gpu_$(name)_data", AS.Constant)
        linkage!(table, LLVM.API.LLVMInternalLinkage)
        initializer!(table, ConstantArray(data; ctx))

        # generate IR: return the address of the table's first element, cast to the
        # function's declared pointer type
        @dispose irb=Builder(ctx) begin
            entry_block = BasicBlock(accessor, "entry"; ctx)
            position!(irb, entry_block)

            first_elem = gep!(irb, table, [ConstantInt(0; ctx), ConstantInt(0; ctx)])

            casted = bitcast!(irb, first_elem, result_type)

            ret!(irb, casted)
        end

        call_function(accessor, LLVMPtr{T,AS.Constant})
    end
end
|
||
# For each table `Random.<name>` listed here, define a zero-argument generated function
# `gpu_<name>()` that wraps the pointer produced by `emit_constant_array` in a
# `ROCDeviceArray` with the table's element type, dimensionality and size.
for sym in (:ki, :wi, :fi, :ke, :we, :fe)
    table = getfield(Random, sym)
    accessor = Symbol("gpu_$sym")
    array_type = :(ROCDeviceArray{$(eltype(table)),$(ndims(table)),AS.Constant})
    @eval @inline @generated function $accessor()
        ptr = emit_constant_array($(QuoteNode(sym)), $table)
        return Expr(:call, $array_type, $(size(table)), ptr)
    end
end
|
||
## randn | ||
|
||
# Device `randn`: draw one `UInt52Raw` sample from `rng` and pass it to `_randn`.
@device_override @inline function Random.randn(rng::AbstractRNG)
    raw = Random.rand(rng, Random.UInt52Raw())
    return _randn(rng, raw)
end
|
||
# Fast path of the Ziggurat sampler for normally distributed numbers. Only the low 52
# bits of `r` are used; everything that is not accepted here goes to `randn_unlikely`.
@inline function _randn(rng::AbstractRNG, r::UInt64)
    @inbounds begin
        bits = r & 0x000fffffffffffff
        magnitude = Int64(bits >> 1)  # One bit for the sign
        layer = magnitude & 0xFF
        signed_magnitude = ifelse(bits % Bool, -magnitude, magnitude)
        x = signed_magnitude * gpu_wi()[layer + 1]
        if magnitude < gpu_ki()[layer + 1]
            return x  # 99.3% of the time we return here 1st try
        end
        return randn_unlikely(rng, layer, magnitude, x)
    end
end
|
||
# Slow path of `_randn`, kept in a separate function for better efficiency.
# FIXME: we can't do @noinline because this accesses LDS
@inline function randn_unlikely(rng, idx, rabs, x)
    @inbounds begin
        if idx == 0
            # loop until a candidate is accepted; bit 8 of `rabs` picks the sign
            while true
                tail = -Random.ziggurat_nor_inv_r * log(Random.rand(rng))
                bound = -log(Random.rand(rng))
                if bound + bound > tail * tail
                    if (rabs >> 8) % Bool
                        return -Random.ziggurat_nor_r - tail
                    else
                        return Random.ziggurat_nor_r + tail
                    end
                end
            end
        end

        upper = gpu_fi()[idx]
        lower = gpu_fi()[idx + 1]
        if (upper - lower) * Random.rand(rng) + lower < exp(-0.5 * x * x)
            return x  # return from the triangular area
        end
        # rejected: start over with a fresh sample
        return Random.randn(rng)
    end
end
|
||
## randexp | ||
|
||
# Device `randexp`: draw one `UInt52Raw` sample from `rng` and pass it to `_randexp`.
# Marked `@inline` like the matching `randn` override.
@device_override @inline Random.randexp(rng::AbstractRNG) =
    _randexp(rng, Random.rand(rng, Random.UInt52Raw()))
|
||
# Fast path of the Ziggurat sampler for exponentially distributed numbers. Only the low
# 52 bits of `ri` are used; everything that is not accepted here goes to
# `randexp_unlikely`.
#
# Marked `@inline` for consistency with `_randn` and `randexp_unlikely`: the RNG state
# lives in LDS, and this file notes that functions touching it cannot be `@noinline`.
@inline function _randexp(rng::AbstractRNG, ri::UInt64)
    @inbounds begin
        ri &= 0x000fffffffffffff
        idx = ri & 0xFF
        x = ri*gpu_we()[idx+1]
        ri < gpu_ke()[idx+1] && return x # 98.9% of the time we return here 1st try
        return randexp_unlikely(rng, idx, x)
    end
end
|
||
# Slow path of `_randexp`.
# FIXME: we can't do @noinline because this accesses LDS
@inline function randexp_unlikely(rng, idx, x)
    @inbounds begin
        if idx == 0
            return Random.ziggurat_exp_r - log(Random.rand(rng))
        end

        upper = gpu_fe()[idx]
        lower = gpu_fe()[idx + 1]
        if (upper - lower) * Random.rand(rng) + lower < exp(-x)
            return x  # return from the triangular area
        end
        # rejected: start over with a fresh sample
        return Random.randexp(rng)
    end
end
Oops, something went wrong.