Implement wrappers for WMMA LLVM intrinsics #494
Conversation
Force-pushed from 7561da6 to faae545.
Thanks! This is a really nice first PR.
I would say definitely for the shared AS. I will leave the detailed review up to Tim, but from my perspective it would be nice to shorten the Julia functions a bit. Secondly (and more a matter of style), I prefer using CUDAnative.jl/src/device/pointer.jl line 219 (at 9d2737f).
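(To make the "shorten the Julia functions" suggestion concrete, here is a minimal sketch of generating the repetitive wrappers in a loop with `@eval`. It is only an illustration, not the style of pointer.jl nor the code in this PR; the wrapper and intrinsic names follow this PR's naming scheme, and passing the destination pointer as an `Int64` mirrors the `ccall` snippet quoted later in this thread.)

```julia
# Illustrative sketch only: generate the per-layout store wrappers in a loop
# instead of writing each definition out by hand. This is device code and is
# only meaningful when compiled through CUDAnative.
for layout in ("col", "row")
    fname = Symbol("llvm_wmma_store_d_", layout, "_m16n16k16_stride_f32")
    intr  = "extern llvm.nvvm.wmma.store.d.sync.$(layout).m16n16k16.stride.f32"
    @eval $fname(dst, d0, d1, d2, d3, d4, d5, d6, d7, stride) =
        ccall($intr, llvmcall, Nothing,
              (Int64, Float32, Float32, Float32, Float32,
               Float32, Float32, Float32, Float32, Int32),
              dst, d0, d1, d2, d3, d4, d5, d6, d7, stride)
end
```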
Looks good! Appreciate the documentation. Did you look into the following?

```
julia> foo(x) = ccall("llvm.donothing", llvmcall, Nothing, (NTuple{8, NTuple{2, VecElement{Float16}}},), x)
foo (generic function with 2 methods)

julia> code_llvm(foo, Tuple{NTuple{8, NTuple{2, VecElement{Float16}}}}; optimize=false)
; @ REPL[19]:1 within `foo'
define void @julia_foo_16071([8 x <2 x i16>] addrspace(11)* nocapture nonnull readonly dereferenceable(32)) {
top:
  %1 = call %jl_value_t*** @julia.ptls_states()
  %2 = bitcast %jl_value_t*** %1 to %jl_value_t addrspace(10)**
  %3 = getelementptr inbounds %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %2, i64 4
  %4 = bitcast %jl_value_t addrspace(10)** %3 to i64**
  %5 = load i64*, i64** %4
  %6 = load [8 x <2 x i16>], [8 x <2 x i16>] addrspace(11)* %0, align 4
  call void @llvm.donothing([8 x <2 x i16>] %6) [ "jl_roots"([8 x <2 x i16>] addrspace(11)* %0) ]
  ret void
}
```

@vchuravy, has there been any movement on JuliaLang/julia#26381? That still seems wanted, because IEEEFloat16.jl won't integrate with the above.
No movement on the f16 front. We can't use f16 if the LLVM backend doesn't support it, so we can't enable it universally and would need to special-case it per target backend, which is messy.
It doesn't seem like passing […] works. For reference:

```
$ JULIA_LLVM_ARGS="--version" jl
LLVM (http://llvm.org/):
  LLVM version 6.0.1
  Optimized build with assertions.
  Default target: x86_64--linux-gnu
  Host CPU: ivybridge
```

The following snippet:

```julia
using CuArrays
using CUDAnative

d = rand(Float32, (16, 16))
d_dev = CuArray(d)

function kernel(d_dev)
    ccall("extern llvm.nvvm.wmma.store.d.sync.col.m16n16k16.stride.f32", llvmcall, Nothing, (Int64, Float32, Float32, Float32, Float32, Float32, Float32, Float32, Float32, Int32), pointer(d_dev), 1, 2, 3, 4, 5, 6, 7, 8, 16)
    return
end

@cuda threads=32 kernel(d_dev)
Array(d_dev)
```

gives an assertion failure (piped through c++filt): […]

Everything works fine when disabling LLVM assertions. Changing the type from […]
Ha, so this isn't just broken for […].
Would it make sense to implement this in https://github.com/JuliaLang/julia (so it works outside of the context of GPUs), or just using the compiler hooks in CUDAnative? Judging from the discussion at JuliaLang/julia#23367, anonymising pointers and other types was fully intended, even though it broke […].
Definitely, I meant this to be a fix in the Julia compiler. Anonymisation was indeed intended; we should just fix our ABI when interfacing with LLVM intrinsics.
@maleadt While adding tests for the shared address space, I stumbled on two issues: […]
Force-pushed from e595a61 to 28f44b5.
bors try
try: Build succeeded
I addressed your comments. The main changes are: […]

Feel free to comment on anything you'd still like to change.

bors try
try: Build succeeded
Finally, it is important to note that the resultant ``D`` matrix can be used as a ``C`` matrix for a subsequent multiply-accumulate.
This is useful if one needs to calculate a sum of the form ``\sum_{i=0}^{n} A_i B_i``, where ``A_i`` and ``B_i`` are matrices of the correct dimension.
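(As an illustration of that accumulation pattern, here is a sketch that is not part of this PR's diff: it reuses the result fragment as the accumulator, and assumes the low-level wrapper names introduced in this PR, `n` unpadded column-major 16x16 tiles stored contiguously in `a_dev` and `b_dev`, and byte-based pointer arithmetic on device pointers.)

```julia
# Sketch: compute D = A_1*B_1 + ... + A_n*B_n + C by feeding each result back
# in as the accumulator ("C" fragment) of the next multiply-accumulate.
function accumulate_kernel(a_dev, b_dev, c_dev, d_dev, n)
    acc = llvm_wmma_load_c_col_m16n16k16_stride_f32(pointer(c_dev), 16)
    for i in 0:n-1
        offset = i * 16 * 16 * sizeof(Float16)   # byte offset of the i-th tile
        a_frag = llvm_wmma_load_a_col_m16n16k16_stride_f16(pointer(a_dev) + offset, 16)
        b_frag = llvm_wmma_load_b_col_m16n16k16_stride_f16(pointer(b_dev) + offset, 16)
        acc = llvm_wmma_mma_col_col_m16n16k16_f32_f32(a_frag, b_frag, acc)
    end
    llvm_wmma_store_d_col_m16n16k16_stride_f32(pointer(d_dev), acc, 16)
    return
end
```

Such a kernel would be launched with `@cuda threads=32 ...`, like the example elsewhere in this PR.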
## LLVM Intrinsics
From a user perspective, I would want to first read about the C-like/high-level API, and then, if I am interested, about the intrinsics.
Looks very nice!
Really nice! Let's go ahead and merge this.
494: Implement wrappers for WMMA LLVM intrinsics r=maleadt a=thomasfaingnaert

This PR adds low-level wrappers around the LLVM WMMA intrinsics.
There is a one-to-one mapping between Julia functions and the LLVM intrinsics, which means that the function names can be very long.
The return types are the Julia types that correspond closest to the return type of the LLVM intrinsic (e.g. `[8 x <2 x half>]` becomes `NTuple{8, NTuple{2, VecElement{Float16}}}`).
In essence, these wrappers return the SSA nodes returned by the LLVM intrinsic.

Once this PR is finalised, I will start on a higher level API, similar to how WMMA is used in CUDA C++.

I added all intrinsics available in LLVM 6, PTX 6.0, SM 70, with the following exceptions:
- The load/store intrinsics have a version without a stride parameter. In that case, the stride is derived from the datatype of the arguments and the WMMA shape. The same behaviour can be achieved by explicitly specifying that stride, so I decided to leave the strideless version out.
- The MMA intrinsic can use saturation arithmetic. However, this is deprecated for floating point operations starting from PTX 6.4, so I decided not to add it.

Example usage:
<details>
<summary>Julia code</summary>

```julia
using CUDAnative
using CuArrays
using Test

# Generate input matrices
a = rand(Float16, (16, 16))
a_dev = CuArray(a)
b = rand(Float16, (16, 16))
b_dev = CuArray(b)
c = rand(Float32, (16, 16))
c_dev = CuArray(c)

# Allocate space for result
d_dev = similar(c_dev)

# Matrix multiply-accumulate kernel (D = A * B + C)
function kernel(a_dev, b_dev, c_dev, d_dev)
    a_frag = llvm_wmma_load_a_col_m16n16k16_stride_f16(pointer(a_dev), 16)
    b_frag = llvm_wmma_load_b_col_m16n16k16_stride_f16(pointer(b_dev), 16)
    c_frag = llvm_wmma_load_c_col_m16n16k16_stride_f32(pointer(c_dev), 16)

    d_frag = llvm_wmma_mma_col_col_m16n16k16_f32_f32(a_frag, b_frag, c_frag)

    llvm_wmma_store_d_col_m16n16k16_stride_f32(pointer(d_dev), d_frag, 16)
    return
end

@cuda threads=32 kernel(a_dev, b_dev, c_dev, d_dev)
@test a * b + c ≈ Array(d_dev) rtol=0.01
```
</details>

This will be compiled to the following LLVM IR:
<details>
<summary>LLVM IR</summary>

```llvm
%src_ptr.i.i = inttoptr i64 %.fca.1.extract15 to i8*
%ret.llvm.i.i = call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.load.a.sync.col.m16n16k16.stride.f16(i8* %src_ptr.i.i, i32 16)
%ret.llvm.0.i.i = extractvalue { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } %ret.llvm.i.i, 0
%ret.llvm.1.i.i = extractvalue { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } %ret.llvm.i.i, 1
%ret.llvm.2.i.i = extractvalue { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } %ret.llvm.i.i, 2
%ret.llvm.3.i.i = extractvalue { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } %ret.llvm.i.i, 3
%ret.llvm.4.i.i = extractvalue { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } %ret.llvm.i.i, 4
%ret.llvm.5.i.i = extractvalue { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } %ret.llvm.i.i, 5
%ret.llvm.6.i.i = extractvalue { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } %ret.llvm.i.i, 6
%ret.llvm.7.i.i = extractvalue { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } %ret.llvm.i.i, 7
%src_ptr.i5.i = inttoptr i64 %.fca.1.extract9 to i8*
%ret.llvm.i6.i = call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.load.b.sync.col.m16n16k16.stride.f16(i8* %src_ptr.i5.i, i32 16)
%ret.llvm.0.i7.i = extractvalue { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } %ret.llvm.i6.i, 0
%ret.llvm.1.i8.i = extractvalue { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } %ret.llvm.i6.i, 1
%ret.llvm.2.i9.i = extractvalue { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } %ret.llvm.i6.i, 2
%ret.llvm.3.i10.i = extractvalue { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } %ret.llvm.i6.i, 3
%ret.llvm.4.i11.i = extractvalue { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } %ret.llvm.i6.i, 4
%ret.llvm.5.i12.i = extractvalue { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } %ret.llvm.i6.i, 5
%ret.llvm.6.i13.i = extractvalue { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } %ret.llvm.i6.i, 6
%ret.llvm.7.i14.i = extractvalue { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } %ret.llvm.i6.i, 7
%src_ptr.i31.i = inttoptr i64 %.fca.1.extract3 to i8*
%ret.llvm.i32.i = call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.load.c.sync.col.m16n16k16.stride.f32(i8* %src_ptr.i31.i, i32 16)
%ret.llvm.0.i33.i = extractvalue { float, float, float, float, float, float, float, float } %ret.llvm.i32.i, 0
%ret.llvm.1.i34.i = extractvalue { float, float, float, float, float, float, float, float } %ret.llvm.i32.i, 1
%ret.llvm.2.i35.i = extractvalue { float, float, float, float, float, float, float, float } %ret.llvm.i32.i, 2
%ret.llvm.3.i36.i = extractvalue { float, float, float, float, float, float, float, float } %ret.llvm.i32.i, 3
%ret.llvm.4.i37.i = extractvalue { float, float, float, float, float, float, float, float } %ret.llvm.i32.i, 4
%ret.llvm.5.i38.i = extractvalue { float, float, float, float, float, float, float, float } %ret.llvm.i32.i, 5
%ret.llvm.6.i39.i = extractvalue { float, float, float, float, float, float, float, float } %ret.llvm.i32.i, 6
%ret.llvm.7.i40.i = extractvalue { float, float, float, float, float, float, float, float } %ret.llvm.i32.i, 7
%d.llvm.i.i = call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.mma.sync.col.col.m16n16k16.f32.f32(<2 x half> %ret.llvm.0.i.i, <2 x half> %ret.llvm.1.i.i, <2 x half> %ret.llvm.2.i.i, <2 x half> %ret.llvm.3.i.i, <2 x half> %ret.llvm.4.i.i, <2 x half> %ret.llvm.5.i.i, <2 x half> %ret.llvm.6.i.i, <2 x half> %ret.llvm.7.i.i, <2 x half> %ret.llvm.0.i7.i, <2 x half> %ret.llvm.1.i8.i, <2 x half> %ret.llvm.2.i9.i, <2 x half> %ret.llvm.3.i10.i, <2 x half> %ret.llvm.4.i11.i, <2 x half> %ret.llvm.5.i12.i, <2 x half> %ret.llvm.6.i13.i, <2 x half> %ret.llvm.7.i14.i, float %ret.llvm.0.i33.i, float %ret.llvm.1.i34.i, float %ret.llvm.2.i35.i, float %ret.llvm.3.i36.i, float %ret.llvm.4.i37.i, float %ret.llvm.5.i38.i, float %ret.llvm.6.i39.i, float %ret.llvm.7.i40.i)
%d.llvm.0.i.i = extractvalue { float, float, float, float, float, float, float, float } %d.llvm.i.i, 0
%d.llvm.1.i.i = extractvalue { float, float, float, float, float, float, float, float } %d.llvm.i.i, 1
%d.llvm.2.i.i = extractvalue { float, float, float, float, float, float, float, float } %d.llvm.i.i, 2
%d.llvm.3.i.i = extractvalue { float, float, float, float, float, float, float, float } %d.llvm.i.i, 3
%d.llvm.4.i.i = extractvalue { float, float, float, float, float, float, float, float } %d.llvm.i.i, 4
%d.llvm.5.i.i = extractvalue { float, float, float, float, float, float, float, float } %d.llvm.i.i, 5
%d.llvm.6.i.i = extractvalue { float, float, float, float, float, float, float, float } %d.llvm.i.i, 6
%d.llvm.7.i.i = extractvalue { float, float, float, float, float, float, float, float } %d.llvm.i.i, 7
%dst_ptr.i.i = inttoptr i64 %.fca.1.extract to i8*
call void @llvm.nvvm.wmma.store.d.sync.col.m16n16k16.stride.f32(i8* %dst_ptr.i.i, float %d.llvm.0.i.i, float %d.llvm.1.i.i, float %d.llvm.2.i.i, float %d.llvm.3.i.i, float %d.llvm.4.i.i, float %d.llvm.5.i.i, float %d.llvm.6.i.i, float %d.llvm.7.i.i, i32 16)
ret void
```
</details>

Note that all the `bitcast`, `extractvalue` and `insertvalue` instructions (necessary to generate correct LLVM IR for use with `llvmcall`) are optimised away. The remaining `extractvalue` is necessary to convert the struct return type to separate arguments.

Finally, the NVPTX backend generated the following PTX code:
<details>
<summary>PTX code</summary>

```
LBB1_8:                                 // %julia_kernel_3.exit
    mov.u32 %r1, 16;
    wmma.load.a.sync.col.m16n16k16.f16 {%hh1, %hh2, %hh3, %hh4, %hh5, %hh6, %hh7, %hh8}, [%rd1], %r1;
    wmma.load.b.sync.col.m16n16k16.f16 {%hh9, %hh10, %hh11, %hh12, %hh13, %hh14, %hh15, %hh16}, [%rd2], %r1;
    wmma.load.c.sync.col.m16n16k16.f32 {%f1, %f2, %f3, %f4, %f5, %f6, %f7, %f8}, [%rd3], %r1;
    wmma.mma.sync.col.col.m16n16k16.f32.f32 {%f9, %f10, %f11, %f12, %f13, %f14, %f15, %f16}, {%hh1, %hh2, %hh3, %hh4, %hh5, %hh6, %hh7, %hh8}, {%hh9, %hh10, %hh11, %hh12, %hh13, %hh14, %hh15, %hh16}, {%f1, %f2, %f3, %f4, %f5, %f6, %f7, %f8};
    wmma.store.d.sync.col.m16n16k16.f32 [%rd4], {%f9, %f10, %f11, %f12, %f13, %f14, %f15, %f16}, %r1;
    ret;
```
</details>

**TODO/Questions:**
- ~~I should probably add documentation for this, or should I leave this for the higher-level API?~~
- Would you prefer to have the non-stride versions anyway, or can I leave these out?
- ~~The loads and stores are tested with the default address space (generic), using global arrays. Should I add tests for the intrinsic versions with global and shared address spaces as well?~~

Co-authored-by: Thomas Faingnaert <[email protected]>
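(One concrete reading of the stride remark above, as an illustration rather than something stated in the PR: for an unpadded, column-major 16x16 `Float16` tile, the stride the strideless intrinsic would derive is just the leading dimension of 16 elements, so the stride variant can reproduce that behaviour by passing the value explicitly.)

```julia
# Illustration only: an explicit stride of 16 (the leading dimension of an
# unpadded, column-major 16x16 Float16 tile) stands in for what the strideless
# load intrinsic would derive from the element type and WMMA shape.
function load_a_tile(a_ptr)
    ld = 16  # leading dimension in elements
    return llvm_wmma_load_a_col_m16n16k16_stride_f16(a_ptr, ld)
end
```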
Build succeeded
@maleadt Just a heads up: julia:nightly is failing on master, which causes this PR to fail too.
Yeah, no worries. JuliaLang/julia#34611 broke a bunch of packages on Julia#master.
Forgot to squash-merge, so I went ahead and force-pushed master (I know, I know, but otherwise bisecting is broken).