Skip to content

Commit

Permalink
Improve ROCm library handle integrations
Browse files Browse the repository at this point in the history
Switch library handles to be allocated per-task
Copy HandleCache mechanism from CUDA.jl
Always include library wrapper code
  • Loading branch information
jpsamaroo committed Feb 23, 2023
1 parent bc3cc37 commit 290abc8
Show file tree
Hide file tree
Showing 10 changed files with 273 additions and 75 deletions.
65 changes: 31 additions & 34 deletions src/AMDGPU.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,19 @@ end

# Load HSA Runtime
const libhsaruntime = "libhsa-runtime64.so.1"
include(joinpath(@__DIR__, "hsa", "HSA.jl"))
include(joinpath("hsa", "HSA.jl"))
import .HSA: Agent, Queue, Executable, Status, Signal

# Load binary dependencies
include(joinpath(dirname(@__DIR__), "deps", "bindeps.jl"))

# Utilities
include("utils.jl")
include("cache.jl")

# Load HIP
const libhip = "libamdhip64.so"
include(joinpath(@__DIR__, "hip", "HIP.jl"))
include(joinpath("hip", "HIP.jl"))
import .HIP: HIPContext, HIPDevice, HIPStream

module Runtime
Expand All @@ -70,24 +71,24 @@ module Runtime
const RT_LOCK = Threads.ReentrantLock()
const RT_EXITING = Ref{Bool}(false)

include("runtime/logging.jl")
include("runtime/error.jl")
include("runtime/thread-utils.jl")
include("runtime/device.jl")
include("runtime/queue.jl")
include("runtime/signal.jl")
include("runtime/dims.jl")
include(joinpath("runtime", "logging.jl"))
include(joinpath("runtime", "error.jl"))
include(joinpath("runtime", "thread-utils.jl"))
include(joinpath("runtime", "device.jl"))
include(joinpath("runtime", "queue.jl"))
include(joinpath("runtime", "signal.jl"))
include(joinpath("runtime", "dims.jl"))
module Mem
include("runtime/memory.jl")
include(joinpath("runtime", "memory.jl"))
end
include("runtime/executable.jl")
include("runtime/hashing.jl")
include("runtime/kernel.jl")
include("runtime/kernel-signal.jl")
include("runtime/launch.jl")
include("runtime/execution.jl")
include("runtime/sync.jl")
include("runtime/fault.jl")
include(joinpath("runtime", "executable.jl"))
include(joinpath("runtime", "hashing.jl"))
include(joinpath("runtime", "kernel.jl"))
include(joinpath("runtime", "kernel-signal.jl"))
include(joinpath("runtime", "launch.jl"))
include(joinpath("runtime", "execution.jl"))
include(joinpath("runtime", "sync.jl"))
include(joinpath("runtime", "fault.jl"))
end # module Runtime
import .Runtime: Mem
import .Runtime: ROCDevice, ROCQueue
Expand Down Expand Up @@ -145,11 +146,11 @@ module Compiler
import .Runtime: Adaptor
import .Runtime: Mem

include("compiler/device-libs.jl")
include("compiler/utils.jl")
include("compiler/global-hooks.jl")
include("compiler/codegen.jl")
include("compiler/occupancy.jl")
include(joinpath("compiler", "device-libs.jl"))
include(joinpath("compiler", "utils.jl"))
include(joinpath("compiler", "global-hooks.jl"))
include(joinpath("compiler", "codegen.jl"))
include(joinpath("compiler", "occupancy.jl"))
end # module Compiler

include("tls.jl")
Expand Down Expand Up @@ -185,17 +186,13 @@ function hsaunref!()
end

# Load ROCm external libraries
if functional(:hip)
functional(:rocblas) && include(joinpath(@__DIR__, "blas", "rocBLAS.jl"))
#functional(:rocsparse) && include("sparse/rocSPARSE.jl")
#functional(:rocsolver) && include("solver/rocSOLVER.jl")
#functional(:rocalution) && include("solver/rocALUTION.jl")
if functional(:rocrand)
include(joinpath(@__DIR__, "rand", "rocRAND.jl"))
end
functional(:rocfft) && include(joinpath(@__DIR__, "fft", "rocFFT.jl"))
functional(:MIOpen) && include("dnn/MIOpen.jl")
end
include(joinpath("blas", "rocBLAS.jl"))
#include(joinpath("sparse", "rocSPARSE.jl"))
#include(joinpath("solver", "rocSOLVER.jl"))
#include(joinpath("solver", "rocALUTION.jl"))
include(joinpath("rand", "rocRAND.jl"))
include(joinpath("fft", "rocFFT.jl"))
include(joinpath("dnn", "MIOpen.jl"))

include("random.jl")

Expand Down
70 changes: 59 additions & 11 deletions src/blas/rocBLAS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,13 @@ module rocBLAS

using ..AMDGPU
import AMDGPU: wait!, mark!, librocblas, AnyROCArray
import AMDGPU: HandleCache
import AMDGPU.HIP: HIPContext, HIPStream

using LinearAlgebra

include("librocblas_types.jl")
include("error.jl")

const _handle = Ref{rocblas_handle}(C_NULL)

# Return the process-wide rocBLAS handle stored in `_handle`, creating it on first use.
# The handle is created with `rocblas_create_handle()` and an `atexit` hook is registered
# that passes it to `rocblas_destroy_handle`.
function handle()
    # Fast path: a handle has already been created and stored.
    _handle[] == C_NULL || return _handle[]

    created = rocblas_create_handle()
    atexit(() -> rocblas_destroy_handle(created))
    _handle[] = created
    return _handle[]
end

include("librocblas.jl")
include("wrappers.jl")
include("highlevel.jl")
Expand All @@ -26,4 +17,61 @@ function version()
VersionNumber(join(split(rocblas_get_version_string(), '.')[1:3], '.'))
end

# Copied from CUDA.jl/lib/cublas/CUBLAS.jl

# cache for created, but unused handles
const idle_handles = HandleCache{HIPContext, rocblas_handle}()

# Return the rocBLAS handle to use from the current task for `tls.context`.
#
# Library state (a handle plus the stream it was last bound to) is kept in the task-local
# storage under the `:rocBLAS` key, in a dictionary keyed by context. A missing entry is
# filled by taking a handle from `idle_handles` (or creating one), and an entry whose
# stream differs from the task-local stream is re-bound before the handle is returned.
function handle()
# throws an `ArgumentError` when rocBLAS is not functional, does nothing otherwise
rocblas_check_functional()

tls = AMDGPU.task_local_state()

# every task maintains library state per device
LibraryState = @NamedTuple{handle::rocblas_handle, stream::HIPStream}
# the trailing type assertion narrows the value coming out of the task-local storage
states = get!(task_local_storage(), :rocBLAS) do
Dict{HIPContext,LibraryState}()
end::Dict{HIPContext,LibraryState}

# get library state
# NOTE(review): `@noinline` presumably keeps this slow path out of the caller's body,
# as in the CUDA.jl code this was copied from -- confirm.
@noinline function new_state(tls)
# reuse an idle handle cached for this context, or create a fresh one
new_handle = pop!(idle_handles, tls.context) do
rocblas_create_handle()
end

# when the current task object is finalized, offer the handle back to the cache;
# the innermost block is the destructor `push!` runs if it does not keep the handle
# NOTE(review): `context!` is not among the names explicitly imported in the visible
# part of this module; presumably it is provided by `using ..AMDGPU` -- confirm.
finalizer(current_task()) do task
push!(idle_handles, tls.context, new_handle) do
context!(tls.context) do
rocblas_destroy_handle(new_handle)
end
end
end

# bind the handle to the task-local stream before first use
rocblas_set_stream(new_handle, tls.stream)

(; handle=new_handle, tls.stream)
end
state = get!(states, tls.context) do
new_state(tls)
end

# update stream
@noinline function update_stream(tls, state)
rocblas_set_stream(state.handle, tls.stream)
(; state.handle, stream=tls.stream)
end
# the task-local stream differs from the one recorded in the state: re-bind the handle
# and store the refreshed state
if state.stream != tls.stream
states[tls.context] = state = update_stream(tls, state)
end

return state.handle
end

# `rocblas_check_functional()` is defined exactly once, when this module is loaded,
# depending on what `AMDGPU.functional(:rocblas)` reports at that point: either a no-op
# returning `nothing`, or a function that always throws an `ArgumentError`.
if !AMDGPU.functional(:rocblas)
    function rocblas_check_functional()
        throw(ArgumentError("rocBLAS is not functional"))
    end
else
    function rocblas_check_functional()
        return nothing
    end
end

end
86 changes: 86 additions & 0 deletions src/cache.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# a cache for library handles
# Copied from CUDA.jl/lib/utils/cache.jl

# TODO:
# - keep track of the (estimated?) size of cache contents
# - clean the caches when memory is needed. this will require registering the destructor
# upfront, so that it can set the environment (e.g. switch to the appropriate context).
# alternatively, register the `unsafe_free!` methods with the pool instead of the cache.

export HandleCache

"""
    HandleCache{K,V}(max_entries::Int=32)

A lock-protected cache of library handles of type `V`, grouped by keys of type `K`.

Handles that are currently handed out are tracked in `active_handles`; handles that were
given back are stored per key in `idle_handles` so that they can be reused. `max_entries`
is the threshold `push!` compares the per-key idle list against when deciding whether a
returned handle is kept or destroyed.
"""
struct HandleCache{K,V}
    # for debugging, and to prevent handle finalization
    active_handles::Set{Pair{K,V}}
    # handles that were given back and can be reused, per key
    idle_handles::Dict{K,Vector{V}}
    # taken by `pop!`/`push!` around every access to the two collections above
    lock::ReentrantLock

    max_entries::Int

    function HandleCache{K,V}(max_entries::Int=32) where {K,V}
        active = Set{Pair{K,V}}()
        idle = Dict{K,Vector{V}}()
        return new{K,V}(active, idle, ReentrantLock(), max_entries)
    end
end

"""
    pop!(f::Function, cache::HandleCache, key)

Remove a handle for `key` from `cache`, or create a new one by calling `f()`.

An idle handle stored under `key` is reused when one exists. Otherwise a quick
(`GC.gc(false)`) collection is performed to free up old handles, the cache is checked
again, and only then is `f()` called. The handle that is returned is recorded in
`cache.active_handles` and is asserted to be of type `V`.
"""
function Base.pop!(f::Function, cache::HandleCache{K,V}, key) where {K,V}
    # Look for an idle handle while holding the lock, falling back to `create()`.
    # Returns `nothing` when no handle was found and `create` did not produce one.
    function check_cache(create::Function=()->nothing)
        lock(cache.lock) do
            idle = get(cache.idle_handles, key, nothing)
            # This local deliberately does not reuse the name `handle`: assigning to a
            # variable of the enclosing function from inside a closure makes Julia
            # capture (and box) that variable.
            candidate = if idle === nothing || isempty(idle)
                create()
            else
                pop!(idle)
            end

            if candidate !== nothing
                push!(cache.active_handles, key => candidate)
            end

            return candidate
        end
    end

    handle = check_cache()

    if handle === nothing
        # if we didn't find anything, perform a quick GC collection to free up old handles.
        GC.gc(false)

        handle = check_cache(f)
    end

    return handle::V
end

"""
    push!(f::Function, cache::HandleCache{K,V}, key::K, handle::V)

Put `handle` back in `cache` under `key`, or destroy it by calling `f()` if it doesn't fit.

The handle is first removed from `cache.active_handles`. It is then appended to the idle
list of `key` unless that list already holds `cache.max_entries` handles, in which case
`f()` is called instead and the handle is not stored. The limit is applied uniformly,
including for the first handle pushed under a key, so no key ever holds more than
`max_entries` idle handles.
"""
function Base.push!(f::Function, cache::HandleCache{K,V}, key::K, handle::V) where {K,V}
    lock(cache.lock) do
        delete!(cache.active_handles, key => handle)

        # fetch the idle list for this key, creating an empty one on first use
        idle = get!(() -> V[], cache.idle_handles, key)
        # `>=` rather than `>`: the list is full once it holds `max_entries` handles
        if length(idle) >= cache.max_entries
            f()
        else
            push!(idle, handle)
        end
    end
end

"""
    push!(f::Function, cache::HandleCache{K,V}, handle::V)

Shorthand to put `handle` back in `cache` without having to remember its key.

The key is recovered by searching `cache.active_handles` for an entry holding `handle`;
an error is raised when there is none. The call is then forwarded to the keyed `push!`
method with the same `f`.
"""
function Base.push!(f::Function, cache::HandleCache{K,V}, handle::V) where {K,V}
    lock(cache.lock) do
        # linear scan over the active set; `nothing` marks "no owner found"
        owner = nothing
        for (k, v) in cache.active_handles
            if v == handle
                owner = k
                break
            end
        end
        owner === nothing &&
            error("Attempt to cache handle $handle that was not created by the handle cache")
        push!(f, cache, owner, handle)
    end
end
3 changes: 1 addition & 2 deletions src/dnn/MIOpen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@ import AMDGPU: ROCArray, ROCDevice, LockedObject
using CEnum
using GPUArrays

@static if AMDGPU.use_artifacts
if AMDGPU.use_artifacts && AMDGPU.functional(:MIOpen)
using MIOpen_jll
const libMIOpen_path = MIOpen_jll.libMIOpen_path
else
const libMIOpen_path = AMDGPU.libMIOpen
end


include("low_level.jl")

const HANDLE = LockedObject(Ref{miopenHandle_t}(C_NULL))
Expand Down
Loading

0 comments on commit 290abc8

Please sign in to comment.