Skip to content

Commit

Permalink
Add device/random.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
utkarsh530 authored and jpsamaroo committed Mar 29, 2023
1 parent 77a4c60 commit 3cbaab3
Show file tree
Hide file tree
Showing 5 changed files with 391 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ROCmDeviceLibs_jll = "873c0968-716b-5aa7-bb8d-d1e2e2aeff2d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
RandomNumbers = "e6cf234a-135c-5ec9-84dd-332b85af5143"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Expand Down
1 change: 1 addition & 0 deletions src/AMDGPU.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ module Device
include(joinpath("device", "gcn.jl"))
include(joinpath("device", "runtime.jl"))
include(joinpath("device", "quirks.jl"))
include(joinpath("device", "random.jl"))
end
import .Device: malloc, signal_exception, report_exception, report_oom, report_exception_frame
import .Device: ROCDeviceArray, AS, HostCall, hostcall!
Expand Down
260 changes: 260 additions & 0 deletions src/device/random.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,260 @@
# Copied from CUDA.jl/src/device/random.jl


## random number generation

using Random
import RandomNumbers


# global state

# shared memory with the actual seed, per warp, loaded lazily or overridden by calling `seed!`
@eval @inline function global_random_keys()
ptr = Base.llvmcall(
$("""@__zeroinit_global_random_keys = external addrspace($(AS.Local)) global [32 x i32], align 32
define i8 addrspace($(AS.Local))* @entry() #0 {
%ptr = getelementptr inbounds [32 x i32], [32 x i32] addrspace($(AS.Local))* @__zeroinit_global_random_keys, i64 0, i64 0
%untyped_ptr = bitcast i32 addrspace($(AS.Local))* %ptr to i8 addrspace($(AS.Local))*
ret i8 addrspace($(AS.Local))* %untyped_ptr
}
attributes #0 = { alwaysinline }
""", "entry"), LLVMPtr{UInt32, AS.Local}, Tuple{})
ROCDeviceArray{UInt32,1,AS.Local}((32,), ptr)
end

# shared memory with per-warp counters, incremented when generating numbers
@eval @inline function global_random_counters()
ptr = Base.llvmcall(
$("""@__zeroinit_global_random_counters = external addrspace($(AS.Local)) global [32 x i32], align 32
define i8 addrspace($(AS.Local))* @entry() #0 {
%ptr = getelementptr inbounds [32 x i32], [32 x i32] addrspace($(AS.Local))* @__zeroinit_global_random_counters, i64 0, i64 0
%untyped_ptr = bitcast i32 addrspace($(AS.Local))* %ptr to i8 addrspace($(AS.Local))*
ret i8 addrspace($(AS.Local))* %untyped_ptr
}
attributes #0 = { alwaysinline }
""", "entry"), LLVMPtr{UInt32, AS.Local}, Tuple{})
ROCDeviceArray{UInt32,1,AS.Local}((32,), ptr)
end

@device_override Random.make_seed() = Base.unsafe_trunc(UInt32, memrealtime())


# generators

using Random123: philox2x_round, philox2x_bumpkey

# GPU-compatible/optimized version of the generator from Random123.jl
struct Philox2x32{R} <: RandomNumbers.AbstractRNG{UInt64}
@inline function Philox2x32{R}() where R
rng = new{R}()
if rng.key == 0
# initialize the key. this happens when first accessing the (0-initialized)
# shared memory key from each block. if we ever want to make the device seed
# controlable from the host, this would be the place to read a global seed.
#
# note however that it is undefined how shared memory persists across e.g.
# launches, so we may not be able to rely on the zero initalization then.
rng.key = Random.make_seed()
end
return rng
end
end

# default to 7 rounds; enough to pass SmallCrush
@inline Philox2x32() = Philox2x32{7}()

@inline function Base.getproperty(rng::Philox2x32, field::Symbol)
threadId = workitemIdx().x + (workitemIdx().y - Int32(1)) * workgroupDim().x +
(workitemIdx().z - Int32(1)) * workgroupDim().x * workgroupDim().y
warpId = (threadId - Int32(1)) >> 0x5 + Int32(1) # fld1

if field === :seed
@inbounds global_random_seed()[1]
elseif field === :key
@inbounds global_random_keys()[warpId]
elseif field === :ctr1
@inbounds global_random_counters()[warpId]
elseif field === :ctr2
blockId = workgroupIdx().x + (workgroupIdx().y - Int32(1)) * gridGroupDim().x +
(workgroupIdx().z - Int32(1)) * gridGroupDim().x * gridGroupDim().y
globalId = threadId + (blockId - Int32(1)) * (workgroupDim().x * workgroupDim().y * workgroupDim().z)
globalId%UInt32
end::UInt32
end

@inline function Base.setproperty!(rng::Philox2x32, field::Symbol, x)
threadId = workitemIdx().x + (workitemIdx().y - Int32(1)) * workgroupDim().x +
(workitemIdx().z - Int32(1)) * workgroupDim().x * workgroupDim().y
warpId = (threadId - Int32(1)) >> 0x5 + Int32(1) # fld1

if field === :key
@inbounds global_random_keys()[warpId] = x
elseif field === :ctr1
@inbounds global_random_counters()[warpId] = x
end
end

@device_override @inline Random.default_rng() = Philox2x32()

"""
Random.seed!(rng::Philox2x32, seed::Integer, [counter::Integer=0])
Seed the on-device Philox2x32 generator with an UInt32 number.
Should be called by at least one thread per warp.
"""
function Random.seed!(rng::Philox2x32, seed::Integer, counter::Integer=0)
rng.key = seed % UInt32
rng.ctr1 = counter
return
end

if VERSION >= v"1.7-"
@device_override Random.seed!(::Random._GLOBAL_RNG, seed) =
Random.seed!(Random.default_rng(), seed)
end

"""
Random.rand(rng::Philox2x32, UInt32)
Generate a byte of random data using the on-device Tausworthe generator.
"""
function Random.rand(rng::Philox2x32{R},::Type{UInt64}) where {R}
ctr1, ctr2, key = rng.ctr1, rng.ctr2, rng.key

if R > 0 ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
if R > 1 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
if R > 2 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
if R > 3 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
if R > 4 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
if R > 5 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
if R > 6 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
if R > 7 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
if R > 8 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
if R > 9 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
if R > 10 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
if R > 11 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
if R > 12 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
if R > 13 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
if R > 14 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
if R > 15 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end

# update the warp counter
# NOTE: this performs the same update on every thread in the warp, but each warp writes
# to a unique location so the duplicate writes are innocuous
# XXX: what if this overflows? we can't increment ctr2. bump the key?
rng.ctr1 += Int32(1)

# NOTE: it's too expensive to keep both numbers around in case the user only wanted one,
# so just make our 2x32 generator return 64-bit numbers by default.
return (ctr1 % UInt64) << 32 | (ctr2 % UInt64)
end



# normally distributed random numbers using Ziggurat algorithm
#
# copied from Base because we don't support its global tables

# a hacky method of exposing constant tables as constant GPU memory
function emit_constant_array(name::Symbol, data::AbstractArray{T}) where {T}
@dispose ctx=Context() begin
T_val = convert(LLVMType, T; ctx)
T_ptr = convert(LLVMType, LLVMPtr{T,AS.Constant}; ctx)

# define function and get LLVM module
llvm_f, _ = create_function(T_ptr)
mod = LLVM.parent(llvm_f)

# create a global memory global variable
# TODO: global_var alignment?
T_global = LLVM.ArrayType(T_val, length(data))
# XXX: why can't we use a single name like emit_shmem
gv = GlobalVariable(mod, T_global, "gpu_$(name)_data", AS.Constant)
linkage!(gv, LLVM.API.LLVMInternalLinkage)
initializer!(gv, ConstantArray(data; ctx))

# generate IR
@dispose builder=Builder(ctx) begin
entry = BasicBlock(llvm_f, "entry"; ctx)
position!(builder, entry)

ptr = gep!(builder, gv, [ConstantInt(0; ctx), ConstantInt(0; ctx)])

untyped_ptr = bitcast!(builder, ptr, T_ptr)

ret!(builder, untyped_ptr)
end

call_function(llvm_f, LLVMPtr{T,AS.Constant})
end
end

for var in [:ki, :wi, :fi, :ke, :we, :fe]
val = getfield(Random, var)
gpu_var = Symbol("gpu_$var")
arr_typ = :(ROCDeviceArray{$(eltype(val)),$(ndims(val)),AS.Constant})
@eval @inline @generated function $gpu_var()
ptr = emit_constant_array($(QuoteNode(var)), $val)
Expr(:call, $arr_typ, $(size(val)), ptr)
end
end

## randn

@device_override @inline Random.randn(rng::AbstractRNG) =
_randn(rng, Random.rand(rng, Random.UInt52Raw()))

@inline function _randn(rng::AbstractRNG, r::UInt64)
@inbounds begin
r &= 0x000fffffffffffff
rabs = Int64(r>>1) # One bit for the sign
idx = rabs & 0xFF
x = ifelse(r % Bool, -rabs, rabs)*gpu_wi()[idx+1]
rabs < gpu_ki()[idx+1] && return x # 99.3% of the time we return here 1st try
return randn_unlikely(rng, idx, rabs, x)
end
end

# this unlikely branch is put in a separate function for better efficiency
# FIXME: we can't do @noinline because this accesses LDS
@inline function randn_unlikely(rng, idx, rabs, x)
@inbounds if idx == 0
while true
xx = -Random.ziggurat_nor_inv_r*log(Random.rand(rng))
yy = -log(Random.rand(rng))
yy+yy > xx*xx &&
return (rabs >> 8) % Bool ? -Random.ziggurat_nor_r-xx : Random.ziggurat_nor_r+xx
end
elseif (gpu_fi()[idx] - gpu_fi()[idx+1])*Random.rand(rng) + gpu_fi()[idx+1] < exp(-0.5*x*x)
return x # return from the triangular area
else
return Random.randn(rng)
end
end

## randexp

@device_override Random.randexp(rng::AbstractRNG) =
_randexp(rng, Random.rand(rng, Random.UInt52Raw()))

function _randexp(rng::AbstractRNG, ri::UInt64)
@inbounds begin
ri &= 0x000fffffffffffff
idx = ri & 0xFF
x = ri*gpu_we()[idx+1]
ri < gpu_ke()[idx+1] && return x # 98.9% of the time we return here 1st try
return randexp_unlikely(rng, idx, x)
end
end

# FIXME: we can't do @noinline because this accesses LDS
@inline function randexp_unlikely(rng, idx, x)
@inbounds if idx == 0
return Random.ziggurat_exp_r - log(Random.rand(rng))
elseif (gpu_fe()[idx] - gpu_fe()[idx+1])*Random.rand(rng) + gpu_fe()[idx+1] < exp(-x)
return x # return from the triangular area
else
return Random.randexp(rng)
end
end
Loading

0 comments on commit 3cbaab3

Please sign in to comment.