Skip to content

Commit

Permalink
Device RNG: Use LLVM to build random LDS arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
jpsamaroo committed Dec 1, 2023
1 parent d7e3012 commit 4172591
Showing 1 changed file with 39 additions and 22 deletions.
61 changes: 39 additions & 22 deletions src/device/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,32 +9,49 @@ import RandomNumbers

# global state

# Generated helper that uses the LLVM API to build a small wrapper function returning a
# pointer to the first element of the 32-element `UInt32` array global
# `__zeroinit_global_random_<name>`, declared in the `AS.Local` address space.
# `name` is taken from the `Val` type parameter, so each name gets its own global.
@inline @generated function emit_global_random_values(::Val{name}) where name
    @dispose ctx=Context() begin
        elem_ty = convert(LLVMType, UInt32)
        ret_ty = convert(LLVMType, LLVMPtr{UInt32,AS.Local})

        # create the wrapper function; its parent module will hold the global
        wrapper, _ = create_function(ret_ty)

        # TODO: we need alwaysinline to ensure we don't access LDS in a non-kernel function
        push!(function_attributes(wrapper), EnumAttribute("alwaysinline"))

        # declare the array global in the local address space with external linkage
        array_ty = LLVM.ArrayType(elem_ty, 32)
        symbol = "__zeroinit_global_random_$(name)"
        storage = GlobalVariable(LLVM.parent(wrapper), array_ty, symbol, AS.Local)
        linkage!(storage, LLVM.API.LLVMExternalLinkage)

        # emit the body: address of element 0, cast to the returned pointer type
        @dispose builder=IRBuilder() begin
            position!(builder, BasicBlock(wrapper, "entry"))
            first_elem = gep!(builder, array_ty, storage, [ConstantInt(0), ConstantInt(0)])
            ret!(builder, bitcast!(builder, first_elem, ret_ty))
        end

        call_function(wrapper, LLVMPtr{UInt32,AS.Local})
    end
end

# shared memory with the actual seed, per warp, loaded lazily or overridden by calling `seed!`
#
# Returns a 32-element `UInt32` device array in the `AS.Local` address space, backed by the
# `__zeroinit_global_random_keys` global that `emit_global_random_values` declares.
# The pointer is produced through the LLVM API instead of a hand-written IR string.
@inline function global_random_keys()
    ptr = emit_global_random_values(Val{:keys}())
    return ROCDeviceArray{UInt32,1,AS.Local}((32,), ptr)
end

# shared memory with per-warp counters, incremented when generating numbers
#
# Returns a 32-element `UInt32` device array in the `AS.Local` address space, backed by the
# `__zeroinit_global_random_counters` global that `emit_global_random_values` declares.
# The pointer is produced through the LLVM API instead of a hand-written IR string.
@inline function global_random_counters()
    ptr = emit_global_random_values(Val{:counters}())
    return ROCDeviceArray{UInt32,1,AS.Local}((32,), ptr)
end

# Override of `Random.make_seed` registered through `@device_override`: the seed is the
# value returned by `memrealtime()`, truncated to `UInt32` without a range check.
# NOTE(review): presumably `memrealtime()` is a device timer reading — confirm.
@device_override Random.make_seed() = Base.unsafe_trunc(UInt32, memrealtime())
Expand Down

0 comments on commit 4172591

Please sign in to comment.