diff --git a/Project.toml b/Project.toml index bba4aaafe..12a154a97 100644 --- a/Project.toml +++ b/Project.toml @@ -23,6 +23,8 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" ROCmDeviceLibs_jll = "873c0968-716b-5aa7-bb8d-d1e2e2aeff2d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Random123 = "74087812-796a-5b5d-8853-05524746bad3" +RandomNumbers = "e6cf234a-135c-5ec9-84dd-332b85af5143" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" diff --git a/src/device/Device.jl b/src/device/Device.jl index 1ddad193d..463c9177b 100644 --- a/src/device/Device.jl +++ b/src/device/Device.jl @@ -22,5 +22,6 @@ include("exceptions.jl") include("gcn.jl") include("runtime.jl") include("quirks.jl") +include("random.jl") end diff --git a/src/device/random.jl b/src/device/random.jl new file mode 100644 index 000000000..7dd8cbde7 --- /dev/null +++ b/src/device/random.jl @@ -0,0 +1,250 @@ +# Copied from CUDA.jl/src/device/random.jl + + +## random number generation + +using Random +import RandomNumbers + + +# global state + +# shared memory with the actual seed, per warp, loaded lazily or overridden by calling `seed!` +@eval @inline function global_random_keys() + ptr = Base.llvmcall( + $("""@__zeroinit_global_random_keys = external addrspace($(AS.Local)) global [32 x i32], align 32 + define i8 addrspace($(AS.Local))* @entry() #0 { + %ptr = getelementptr inbounds [32 x i32], [32 x i32] addrspace($(AS.Local))* @__zeroinit_global_random_keys, i64 0, i64 0 + %untyped_ptr = bitcast i32 addrspace($(AS.Local))* %ptr to i8 addrspace($(AS.Local))* + ret i8 addrspace($(AS.Local))* %untyped_ptr + } + attributes #0 = { alwaysinline } + """, "entry"), LLVMPtr{UInt32, AS.Local}, Tuple{}) + ROCDeviceArray{UInt32,1,AS.Local}((32,), ptr) +end + +# shared memory with per-warp counters, incremented when generating numbers +@eval @inline function global_random_counters() + ptr = Base.llvmcall( + $("""@__zeroinit_global_random_counters = external addrspace($(AS.Local)) global [32 x i32], align 32 + define i8 addrspace($(AS.Local))* @entry() #0 { + %ptr = getelementptr inbounds [32 x i32], [32 x i32] addrspace($(AS.Local))* @__zeroinit_global_random_counters, i64 0, i64 0 + %untyped_ptr = bitcast i32 addrspace($(AS.Local))* %ptr to i8 addrspace($(AS.Local))* + ret i8 addrspace($(AS.Local))* %untyped_ptr + } + attributes #0 = { alwaysinline } + """, "entry"), LLVMPtr{UInt32, AS.Local}, Tuple{}) + ROCDeviceArray{UInt32,1,AS.Local}((32,), ptr) +end + +@device_override Random.make_seed() = Base.unsafe_trunc(UInt32, memrealtime()) + + +# generators + +using Random123: philox2x_round, philox2x_bumpkey + +# GPU-compatible/optimized version of the generator from Random123.jl +struct Philox2x32{R} <: RandomNumbers.AbstractRNG{UInt64} + @inline function Philox2x32{R}() where R + rng = new{R}() + if rng.key == 0 + # initialize the key. this happens when first accessing the (0-initialized) + # shared memory key from each block. if we ever want to make the device seed + # controlable from the host, this would be the place to read a global seed. + # + # note however that it is undefined how shared memory persists across e.g. + # launches, so we may not be able to rely on the zero initalization then. + rng.key = Random.make_seed() + end + return rng + end +end + +# default to 7 rounds; enough to pass SmallCrush +@inline Philox2x32() = Philox2x32{7}() + +@inline function Base.getproperty(rng::Philox2x32, field::Symbol) + threadId = workitemIdx().x + (workitemIdx().y - Int32(1)) * workgroupDim().x + + (workitemIdx().z - Int32(1)) * workgroupDim().x * workgroupDim().y + warpId = (threadId - Int32(1)) >> 0x6 + Int32(1) # fld1 by 64 + + if field === :seed + @inbounds global_random_seed()[1] + elseif field === :key + @inbounds global_random_keys()[warpId] + elseif field === :ctr1 + @inbounds global_random_counters()[warpId] + elseif field === :ctr2 + blockId = workgroupIdx().x + (workgroupIdx().y - Int32(1)) * gridGroupDim().x + + (workgroupIdx().z - Int32(1)) * gridGroupDim().x * gridGroupDim().y + globalId = threadId + (blockId - Int32(1)) * (workgroupDim().x * workgroupDim().y * workgroupDim().z) + globalId%UInt32 + end::UInt32 +end + +@inline function Base.setproperty!(rng::Philox2x32, field::Symbol, x) + threadId = workitemIdx().x + (workitemIdx().y - Int32(1)) * workgroupDim().x + + (workitemIdx().z - Int32(1)) * workgroupDim().x * workgroupDim().y + warpId = (threadId - Int32(1)) >> 0x6 + Int32(1) # fld1 by 64 + + if field === :key + @inbounds global_random_keys()[warpId] = x + elseif field === :ctr1 + @inbounds global_random_counters()[warpId] = x + end +end + +@device_override @inline Random.default_rng() = Philox2x32() + +""" + Random.seed!(rng::Philox2x32, seed::Integer, [counter::Integer=0]) + +Seed the on-device Philox2x32 generator with an UInt32 number. +Should be called by at least one thread per warp. +""" +function Random.seed!(rng::Philox2x32, seed::Integer, counter::Integer=0) + rng.key = seed % UInt32 + rng.ctr1 = counter + return +end + +if VERSION >= v"1.7-" +@device_override Random.seed!(::Random._GLOBAL_RNG, seed) = + Random.seed!(Random.default_rng(), seed) +end + +""" + Random.rand(rng::Philox2x32, UInt32) + +Generate a byte of random data using the on-device Tausworthe generator. +""" +function Random.rand(rng::Philox2x32{R},::Type{UInt64}) where {R} + ctr1, ctr2, key = rng.ctr1, rng.ctr2, rng.key + + if R > 0 ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end + if R > 1 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end + if R > 2 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end + if R > 3 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end + if R > 4 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end + if R > 5 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end + if R > 6 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end + if R > 7 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end + if R > 8 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end + if R > 9 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end + if R > 10 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end + if R > 11 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end + if R > 12 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end + if R > 13 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end + if R > 14 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end + if R > 15 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end + + # update the warp counter + # NOTE: this performs the same update on every thread in the warp, but each warp writes + # to a unique location so the duplicate writes are innocuous + # XXX: what if this overflows? we can't increment ctr2. bump the key? + rng.ctr1 += Int32(1) + + # NOTE: it's too expensive to keep both numbers around in case the user only wanted one, + # so just make our 2x32 generator return 64-bit numbers by default. + return (ctr1 % UInt64) << 32 | (ctr2 % UInt64) +end + + + +# normally distributed random numbers using Ziggurat algorithm +# +# copied from Base because we don't support its global tables + +# a hacky method of exposing constant tables as constant GPU memory +function emit_constant_array(name::Symbol, data::AbstractArray{T}) where {T} + @dispose ctx=Context() begin + T_val = convert(LLVMType, T) + T_ptr = convert(LLVMType, LLVMPtr{T,AS.Constant}) + + # define function and get LLVM module + llvm_f, _ = create_function(T_ptr) + mod = LLVM.parent(llvm_f) + + # create a global memory global variable + # TODO: global_var alignment? + T_global = LLVM.ArrayType(T_val, length(data)) + # XXX: why can't we use a single name like emit_shmem + gv = GlobalVariable(mod, T_global, "gpu_$(name)_data", AS.Constant) + alignment!(gv, 16) + linkage!(gv, LLVM.API.LLVMInternalLinkage) + initializer!(gv, ConstantArray(data)) + + # generate IR + @dispose builder=IRBuilder() begin + entry = BasicBlock(llvm_f, "entry") + position!(builder, entry) + + ptr = gep!(builder, T_global, gv, [ConstantInt(0), ConstantInt(0)]) + + untyped_ptr = bitcast!(builder, ptr, T_ptr) + + ret!(builder, untyped_ptr) + end + + call_function(llvm_f, LLVMPtr{T,AS.Constant}) + end +end + +for var in [:ki, :wi, :fi, :ke, :we, :fe] + val = getfield(Random, var) + gpu_var = Symbol("gpu_$var") + arr_typ = :(ROCDeviceArray{$(eltype(val)),$(ndims(val)),AS.Constant}) + @eval @inline @generated function $gpu_var() + ptr = emit_constant_array($(QuoteNode(var)), $val) + Expr(:call, $arr_typ, $(size(val)), ptr) + end +end + +## randn + +@device_override @inline function Random.randn(rng::AbstractRNG) + @label retry + r = Random.rand(rng, Random.UInt52Raw()) + @inbounds begin + r &= 0x000fffffffffffff + rabs = Int64(r>>1) # One bit for the sign + idx = rabs & 0xFF + x = ifelse(r % Bool, -rabs, rabs)*gpu_wi()[idx+1] + rabs < gpu_ki()[idx+1] && return x # 99.3% of the time we return here 1st try + # TODO: This code could be outlined once LLVM supports LDS access in recursively-called functions + @inbounds if idx == 0 + while true + xx = -Random.ziggurat_nor_inv_r*log(Random.rand(rng)) + yy = -log(Random.rand(rng)) + yy+yy > xx*xx && + return (rabs >> 8) % Bool ? -Random.ziggurat_nor_r-xx : Random.ziggurat_nor_r+xx + end + elseif (gpu_fi()[idx] - gpu_fi()[idx+1])*Random.rand(rng) + gpu_fi()[idx+1] < exp(-0.5*x*x) + return x # return from the triangular area + else + @goto retry + end + end +end + +## randexp + +@device_override @inline function Random.randexp(rng::AbstractRNG) + @label retry + ri = Random.rand(rng, Random.UInt52Raw()) + @inbounds begin + ri &= 0x000fffffffffffff + idx = ri & 0xFF + x = ri*gpu_we()[idx+1] + ri < gpu_ke()[idx+1] && return x # 98.9% of the time we return here 1st try + # TODO: This code could be outlined once LLVM supports LDS access in recursively-called functions + @inbounds if idx == 0 + return Random.ziggurat_exp_r - log(Random.rand(rng)) + elseif (gpu_fe()[idx] - gpu_fe()[idx+1])*Random.rand(rng) + gpu_fe()[idx+1] < exp(-x) + return x # return from the triangular area + else + @goto retry + end + end +end diff --git a/test/device/random.jl b/test/device/random.jl new file mode 100644 index 000000000..b75e0320c --- /dev/null +++ b/test/device/random.jl @@ -0,0 +1,130 @@ +using Random + +n = 256 + +function apply_seed(seed) + if seed === missing + # should result in different numbers across launches + Random.seed!() + # XXX: this currently doesn't work, because of the definition in Base, + # `seed!(r::MersenneTwister=default_rng())`, which breaks overriding + # `default_rng` with a non-MersenneTwister RNG. + elseif seed !== nothing + # should result in the same numbers + Random.seed!(seed) + elseif seed === nothing + # should result in different numbers across launches, + # as determined by the seed set during module loading. + end +end + +@testset "rand($T), seed $seed" for T in (Int32, UInt32, Int64, UInt64, Int128, UInt128, + Float16, Float32, Float64), + seed in (nothing, #=missing,=# 1234) + # different kernel invocations should get different numbers + @testset "across launches" begin + function kernel(A::AbstractArray{T}, seed) where {T} + apply_seed(seed) + tid = workitemIdx().x + A[tid] = rand(T) + return nothing + end + + a = AMDGPU.zeros(T, n) + b = AMDGPU.zeros(T, n) + + @roc groupsize=n kernel(a, seed) + @roc groupsize=n kernel(b, seed) + + if seed === nothing || seed === missing + @test Array(a) != Array(b) + else + @test Array(a) == Array(b) + end + end + + # multiple calls to rand should get different numbers + @testset "across calls" begin + function kernel(A::AbstractArray{T}, B::AbstractArray{T}, seed) where {T} + apply_seed(seed) + tid = workitemIdx().x + A[tid] = rand(T) + B[tid] = rand(T) + return nothing + end + + a = AMDGPU.zeros(T, n) + b = AMDGPU.zeros(T, n) + + @roc groupsize=n kernel(a, b, seed) + + @test Array(a) != Array(b) + end + + # different threads should get different numbers + @testset "across threads" for active_dim in 1:6 + function kernel(A::AbstractArray{T}, seed) where {T} + apply_seed(seed) + id = workitemIdx().x*workitemIdx().y*workitemIdx().z*workgroupIdx().x*workgroupIdx().y*workgroupIdx().z + if 1 <= id <= length(A) + A[id] = rand(T) + end + return nothing + end + + tx, ty, tz, bx, by, bz = [dim == active_dim ? 3 : 1 for dim in 1:6] + gx, gy, gz = tx*bx, ty*by, tz*bz + a = AMDGPU.zeros(T, 3) + + @roc groupsize=(tx, ty, tz) gridsize=(gx, gy, gz) kernel(a, seed) + + # NOTE: we don't just generate two numbers and compare them, instead generating a + # couple more and checking they're not all the same, in order to avoid + # occasional collisions with lower-precision types (i.e., Float16). + @test length(unique(Array(a))) > 1 + end +end + +@testset "basic randn($T), seed $seed" for T in (Float16, Float32, Float64), + seed in (nothing, #=missing,=# 1234) + function kernel(A::AbstractArray{T}, seed) where {T} + apply_seed(seed) + tid = workitemIdx().x + A[tid] = randn(T) + return + end + + a = AMDGPU.zeros(T, n) + b = AMDGPU.zeros(T, n) + + @roc groupsize=n kernel(a, seed) + @roc groupsize=n kernel(b, seed) + + if seed === nothing || seed === missing + @test Array(a) != Array(b) + else + @test Array(a) == Array(b) + end +end + +@testset "basic randexp($T), seed $seed" for T in (Float16, Float32, Float64), + seed in (nothing, #=missing,=# 1234) + function kernel(A::AbstractArray{T}, seed) where {T} + apply_seed(seed) + tid = workitemIdx().x + A[tid] = randexp(T) + return + end + + a = AMDGPU.zeros(T, n) + b = AMDGPU.zeros(T, n) + + @roc groupsize=n kernel(a, seed) + @roc groupsize=n kernel(b, seed) + + if seed === nothing || seed === missing + @test Array(a) != Array(b) + else + @test Array(a) == Array(b) + end +end diff --git a/test/device_tests.jl b/test/device_tests.jl index 11bb5b2d9..6e779fc8e 100644 --- a/test/device_tests.jl +++ b/test/device_tests.jl @@ -15,6 +15,7 @@ include("device/wavefront.jl") include("device/synchronization.jl") include("device/execution_control.jl") include("device/exceptions.jl") +include("device/random.jl") # TODO https://github.com/JuliaGPU/AMDGPU.jl/issues/546 include("device/math.jl")