Skip to content
This repository has been archived by the owner on May 27, 2021. It is now read-only.

Commit

Permalink
Merge #557
Browse files Browse the repository at this point in the history
557: Use AddrSpacePtr to call WMMA intrinsics r=thomasfaingnaert a=thomasfaingnaert

Closes #548 
Depends on JuliaLang/julia#34760

Co-authored-by: Thomas Faingnaert <[email protected]>
  • Loading branch information
bors[bot] and thomasfaingnaert authored Feb 24, 2020
2 parents 250aaa7 + 5ae1049 commit 2f60116
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 2 deletions.
4 changes: 4 additions & 0 deletions docs/src/device/wmma.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ VERSION >= v"1.4.0-DEV.666"
```
then make sure you are running Julia v1.4.0-DEV.666 or later!

!!! note

For optimal performance, you should use Julia `v1.5.0-DEV.324` or later.

## Introduction of Terminology

The WMMA operations perform a matrix multiply-accumulate.
Expand Down
30 changes: 28 additions & 2 deletions src/device/cuda/wmma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ using CUDAnative: AS, DevicePtr
# CONSTANTS
################################################################################

# Determines whether Core.AddrSpacePtr is available
# (introduced in Julia v1.5.0-DEV.324 by JuliaLang/julia#34760)
const addrspaceptr_available = (VERSION >= v"1.5.0-DEV.324")

# Maps PTX types to Julia array types
const map_ptx_to_jl_array = Dict(
"f16" => Float16,
Expand Down Expand Up @@ -49,6 +52,25 @@ get_frag_info(matrix, ptx_el_type) = (

"""
    get_addrspace_info(addr_space)

Look up `addr_space` in `map_ptx_as_to_as_ty` and return the corresponding
LLVM address space as an `Int`.
"""
function get_addrspace_info(addr_space)
    return convert(Int, map_ptx_as_to_as_ty[addr_space])
end

if addrspaceptr_available
# Conversion used by the WMMA `ccall`s below: turn a `DevicePtr{T, AS}` into a
# `Core.AddrSpacePtr{T, as}` so the intrinsic sees a pointer in LLVM address
# space `as`. NOTE(review): the integer width is hard-coded to `Int64`, i.e.
# this assumes 64-bit device pointers — confirm for the supported targets.
@generated function Base.cconvert(::Type{Core.AddrSpacePtr{T, as}}, x::DevicePtr{T, AS}) where {T, as, AS}
# Addrspacecast from i8* to i8* is invalid in LLVM
if as == 0
# Target is the generic address space (0): reinterpreting the pointer bits
# is enough, no cast instruction needed.
return quote
return Base.bitcast(Core.AddrSpacePtr{T, as}, x)
end
else
# Non-generic target: emit literal LLVM IR that rebuilds the pointer from
# its integer value and addrspacecasts it from address space 0 to `as`.
# NOTE(review): the source-side address space `AS` of the DevicePtr is not
# consulted here — the cast always originates from AS 0.
ir = "%p = inttoptr i64 %0 to i8*
%ptr = addrspacecast i8* %p to i8 addrspace($as)*
ret i8 addrspace($as)* %ptr"

return quote
return Base.llvmcall($ir, Core.AddrSpacePtr{T, as}, Tuple{Int64}, Base.bitcast(Int64, x))
end
end
end
end

################################################################################
# LOW LEVEL API
################################################################################
Expand Down Expand Up @@ -103,7 +125,9 @@ for mat in ["a", "b", "c"],

ccall_name = "extern $llvm_intr"

@eval $func_name(src_addr, stride) = ccall($ccall_name, llvmcall, NTuple{$sz, $frag_ty}, (Ref{$arr_ty}, Int32), src_addr, stride)
ptr_ty = addrspaceptr_available ? Core.AddrSpacePtr{arr_ty, addr_space_int} : Ref{arr_ty}

@eval $func_name(src_addr, stride) = ccall($ccall_name, llvmcall, NTuple{$sz, $frag_ty}, ($ptr_ty, Int32), src_addr, stride)
@eval export $func_name
@eval @doc (@doc llvm_wmma_load) $func_name
end
Expand Down Expand Up @@ -155,7 +179,9 @@ for mat in ["d"],
frag_types = ntuple(i -> frag_ty, sz)
frag_vars = ntuple(i -> :(data[$i]), sz)

@eval $func_name(dst_addr, data, stride) = ccall($ccall_name, llvmcall, Nothing, (Ref{$arr_ty}, $(frag_types...), Int32), dst_addr, $(frag_vars...), stride)
ptr_ty = addrspaceptr_available ? Core.AddrSpacePtr{arr_ty, addr_space_int} : Ref{arr_ty}

@eval $func_name(dst_addr, data, stride) = ccall($ccall_name, llvmcall, Nothing, ($ptr_ty, $(frag_types...), Int32), dst_addr, $(frag_vars...), stride)
@eval export $func_name
@eval @doc (@doc llvm_wmma_store) $func_name
end
Expand Down
42 changes: 42 additions & 0 deletions test/device/wmma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,5 +248,47 @@ using CUDAnative.WMMA
end

################################################################################

# These checks require Core.AddrSpacePtr (JuliaLang/julia#34760) so that the
# WMMA intrinsics are called with address-space-qualified pointers; see
# https://github.com/JuliaGPU/CUDAnative.jl/issues/548 for background.
if VERSION >= v"1.5.0-DEV.324"
    @testset "Codegen addressing" begin
        @testset "Global" begin
            # Store kernel whose destination lives in global memory.
            function store_to_global(dst)
                conf = WMMA.Config{16, 16, 16, Float32}

                frag = WMMA.fill_c(Float32(0), conf)
                WMMA.store_d(pointer(dst), frag, 16, WMMA.ColMajor, conf)

                return
            end

            ptx = sprint() do io
                CUDAnative.code_ptx(io, store_to_global, (CuDeviceArray{Float32,1,CUDAnative.AS.Global},))
            end

            # The generic (unqualified) store must not appear; the .global one must.
            @test !occursin("wmma.store.d.sync.aligned.col.m16n16k16.f32", ptx)
            @test occursin("wmma.store.d.sync.aligned.col.m16n16k16.global.f32", ptx)
        end

        @testset "Shared" begin
            # Store kernel whose destination lives in statically-allocated shared memory.
            function store_to_shared()
                shmem = @cuStaticSharedMem(Float32, (16, 16))
                conf = WMMA.Config{16, 16, 16, Float32}

                frag = WMMA.fill_c(Float32(0), conf)
                WMMA.store_d(pointer(shmem), frag, 16, WMMA.ColMajor, conf)

                return
            end

            ptx = sprint() do io
                CUDAnative.code_ptx(io, store_to_shared, ())
            end

            # The generic (unqualified) store must not appear; the .shared one must.
            @test !occursin("wmma.store.d.sync.aligned.col.m16n16k16.f32", ptx)
            @test occursin("wmma.store.d.sync.aligned.col.m16n16k16.shared.f32", ptx)
        end
    end
end

################################################################################

end
end

0 comments on commit 2f60116

Please sign in to comment.