Merge #557
557: Use AddrSpacePtr to call WMMA intrinsics r=vchuravy a=thomasfaingnaert

Closes #548 
Depends on JuliaLang/julia#34760

Co-authored-by: Thomas Faingnaert <[email protected]>
bors[bot] and thomasfaingnaert authored Feb 23, 2020
2 parents 0c83640 + e176c58 commit 1c8d85a
Showing 3 changed files with 77 additions and 0 deletions.
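The crux of the change: the WMMA intrinsic wrappers used to take their pointer argument as `Ref{T}`, which lowers to a generic (address space 0) pointer, forcing the PTX backend to emit the generic WMMA instruction even for global or shared memory. With `Core.AddrSpacePtr` (JuliaLang/julia#34760) the address space is part of the pointer's type, so the backend can select the qualified `.global`/`.shared` variants. A hedged sketch condensed from the generated code in the diff below; the intrinsic name is a placeholder and the fragment type is illustrative:

```julia
# Before: Ref{Float16} erases the address space, yielding a generic pointer.
load_generic(p, stride) =
    ccall("extern llvm.nvvm.wmma.<load intrinsic>", llvmcall,
          NTuple{8, NTuple{2, VecElement{Float16}}},
          (Ref{Float16}, Int32), p, stride)

# After: Core.AddrSpacePtr{Float16, 1} carries LLVM address space 1 (global),
# so the backend can emit the `.global` flavour of the WMMA load. The
# `convert` method used here is the one added in src/device/cuda/wmma.jl below.
load_global(p, stride) =
    ccall("extern llvm.nvvm.wmma.<load intrinsic>", llvmcall,
          NTuple{8, NTuple{2, VecElement{Float16}}},
          (Core.AddrSpacePtr{Float16, 1}, Int32),
          convert(Core.AddrSpacePtr{Float16, 1}, p), stride)
```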
4 changes: 4 additions & 0 deletions docs/src/device/wmma.md
@@ -24,6 +24,10 @@ VERSION >= v"1.4.0-DEV.666"
```
then make sure you are running Julia v1.4.0-DEV.666 or later!

!!! note

    For optimal performance, you should use Julia `v1.5.0-DEV.324` or later.

## Introduction of Terminology

The WMMA operations perform a matrix multiply-accumulate.
31 changes: 31 additions & 0 deletions src/device/cuda/wmma.jl
@@ -1,12 +1,16 @@
export WMMA
module WMMA

import Base.convert
using CUDAnative: AS, DevicePtr

################################################################################
# CONSTANTS
################################################################################

# Determines whether Core.AddrSpacePtr is available
const addrspaceptr_available = (VERSION >= v"1.5.0-DEV.324")

# Maps PTX types to Julia array types
const map_ptx_to_jl_array = Dict(
    "f16" => Float16,
@@ -49,6 +53,25 @@ get_frag_info(matrix, ptx_el_type) = (

get_addrspace_info(addr_space) = convert(Int, map_ptx_as_to_as_ty[addr_space])

if addrspaceptr_available
    @generated function convert(::Type{Core.AddrSpacePtr{T, as}}, x::DevicePtr{T, AS}) where {T, as, AS}
        # Addrspacecast from i8* to i8* is invalid in LLVM
        if as == 0
            return quote
                return Base.bitcast(Core.AddrSpacePtr{T, as}, x)
            end
        else
            ir = "%p = inttoptr i64 %0 to i8*
                  %ptr = addrspacecast i8* %p to i8 addrspace($as)*
                  ret i8 addrspace($as)* %ptr"

            return quote
                return Base.llvmcall($ir, Core.AddrSpacePtr{T, as}, Tuple{Int64}, Base.bitcast(Int64, x))
            end
        end
    end
end
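A hedged usage sketch of the conversion defined above (everything here except `convert`, `AS`, and `DevicePtr` is hypothetical):

```julia
# `shmem_ptr` stands for a hypothetical DevicePtr{Float16, AS.Shared} obtained
# in device code. Shared memory is LLVM address space 3, so the result carries
# that address space in its type and can be passed straight to an
# address-space-qualified intrinsic.
asptr = convert(Core.AddrSpacePtr{Float16, 3}, shmem_ptr)
```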

################################################################################
# LOW LEVEL API
################################################################################
@@ -103,7 +126,11 @@ for mat in ["a", "b", "c"],

    ccall_name = "extern $llvm_intr"

    if addrspaceptr_available
        @eval $func_name(src_addr, stride) = ccall($ccall_name, llvmcall, NTuple{$sz, $frag_ty}, (Core.AddrSpacePtr{$arr_ty, $addr_space_int}, Int32), convert(Core.AddrSpacePtr{$arr_ty, $addr_space_int}, src_addr), stride)
    else
        @eval $func_name(src_addr, stride) = ccall($ccall_name, llvmcall, NTuple{$sz, $frag_ty}, (Ref{$arr_ty}, Int32), src_addr, stride)
    end
    @eval export $func_name
    @eval @doc (@doc llvm_wmma_load) $func_name
end
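A hedged sketch of how one of these generated load wrappers might be called; the name below follows CUDAnative's documented `llvm_wmma_load_{matrix}_{layout}_{shape}_{addr_space}_stride_{elem_type}` convention, but the exact name, pointer, and stride here are assumptions, not taken from this hunk:

```julia
# Hypothetical device-side call: `ptr` stands for a DevicePtr{Float16, AS.Global}
# and 16 for the leading dimension. With addrspaceptr_available the wrapper
# converts `ptr` to a Core.AddrSpacePtr{Float16, 1} before the ccall, so the
# PTX backend can emit the `.global` flavour of the load.
frag = llvm_wmma_load_a_col_m16n16k16_global_stride_f16(ptr, Int32(16))
```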
@@ -155,7 +182,11 @@ for mat in ["d"],

    frag_types = ntuple(i -> frag_ty, sz)
    frag_vars = ntuple(i -> :(data[$i]), sz)

    if addrspaceptr_available
        @eval $func_name(dst_addr, data, stride) = ccall($ccall_name, llvmcall, Nothing, (Core.AddrSpacePtr{$arr_ty, $addr_space_int}, $(frag_types...), Int32), convert(Core.AddrSpacePtr{$arr_ty, $addr_space_int}, dst_addr), $(frag_vars...), stride)
    else
        @eval $func_name(dst_addr, data, stride) = ccall($ccall_name, llvmcall, Nothing, (Ref{$arr_ty}, $(frag_types...), Int32), dst_addr, $(frag_vars...), stride)
    end
    @eval export $func_name
    @eval @doc (@doc llvm_wmma_store) $func_name
end
42 changes: 42 additions & 0 deletions test/device/wmma.jl
@@ -248,5 +248,47 @@ using CUDAnative.WMMA
end

################################################################################

# Need https://github.com/JuliaLang/julia/pull/34760
# See https://github.com/JuliaGPU/CUDAnative.jl/issues/548
if VERSION >= v"1.5.0-DEV.324"
@testset "Codegen addressing" begin
    @testset "Global" begin
        function kernel(d)
            conf = WMMA.Config{16, 16, 16, Float32}

            d_frag = WMMA.fill_c(Float32(0), conf)
            WMMA.store_d(pointer(d), d_frag, 16, WMMA.ColMajor, conf)

            return
        end

        ptx = sprint(io -> CUDAnative.code_ptx(io, kernel, (CuDeviceArray{Float32,1,CUDAnative.AS.Global},)))

        @test !occursin("wmma.store.d.sync.aligned.col.m16n16k16.f32", ptx)
        @test occursin("wmma.store.d.sync.aligned.col.m16n16k16.global.f32", ptx)
    end

    @testset "Shared" begin
        function kernel()
            shmem = @cuStaticSharedMem(Float32, (16, 16))
            conf = WMMA.Config{16, 16, 16, Float32}

            d_frag = WMMA.fill_c(Float32(0), conf)
            WMMA.store_d(pointer(shmem), d_frag, 16, WMMA.ColMajor, conf)

            return
        end

        ptx = sprint(io -> CUDAnative.code_ptx(io, kernel, ()))

        @test !occursin("wmma.store.d.sync.aligned.col.m16n16k16.f32", ptx)
        @test occursin("wmma.store.d.sync.aligned.col.m16n16k16.shared.f32", ptx)
    end
end
end

################################################################################

end
end
