This repository has been archived by the owner on May 27, 2021. It is now read-only.
-
-
Notifications
You must be signed in to change notification settings - Fork 55
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement wrappers for WMMA LLVM intrinsics
- Loading branch information
1 parent
93e5a76
commit faae545
Showing
4 changed files
with
383 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,246 @@ | ||
################################################################################ | ||
# CONSTANTS | ||
################################################################################ | ||
|
||
# Maps PTX types to LLVM types | ||
map_ptx_to_llvm = Dict( | ||
"f16" => "<2 x half>", | ||
"f32" => "float" | ||
) | ||
|
||
# Maps PTX types to the LLVM type that llvmcall expects | ||
map_ptx_to_llvmcall = Dict( | ||
"f16" => "<2 x i16>", | ||
"f32" => "float" | ||
) | ||
|
||
# Maps PTX types to Julia types | ||
map_ptx_to_jl = Dict( | ||
"f16" => NTuple{2, VecElement{Float16}}, | ||
"f32" => Float32 | ||
) | ||
|
||
# Maps matrix & PTX types to fragment sizes | ||
map_frag_sizes = Dict( | ||
"a.f16" => 8, | ||
"b.f16" => 8, | ||
"c.f16" => 4, | ||
"c.f32" => 8, | ||
"d.f16" => 4, | ||
"d.f32" => 8 | ||
) | ||
|
||
################################################################################ | ||
# HELPER FUNCTIONS | ||
################################################################################ | ||
|
||
macro gen_ir(template, count, delim="\n") | ||
return quote | ||
join([$(esc(template)) for $(esc(:i)) in 0:$(esc(count))-1], $(esc(delim))) | ||
end | ||
end | ||
|
||
function join_nonempty(args...) | ||
delim = args[end] | ||
arr = [args[1:end-1]...] | ||
|
||
return join(arr[arr .!= ""], delim) | ||
end | ||
|
||
get_llvm_ty(matrix, ptx_el_type) = map_ptx_to_llvm[ptx_el_type] | ||
|
||
get_llvmcall_ty(matrix, ptx_el_type) = map_ptx_to_llvmcall[ptx_el_type] | ||
|
||
get_jl_ty(matrix, ptx_el_type) = map_ptx_to_jl[ptx_el_type] | ||
|
||
get_frag_sz(matrix, ptx_el_type) = map_frag_sizes["$matrix.$ptx_el_type"] | ||
|
||
################################################################################ | ||
# LOW LEVEL API | ||
################################################################################ | ||
|
||
# ----------- | ||
# Matrix load | ||
# ----------- | ||
|
||
for mat in ["a", "b", "c"], | ||
layout in ["col", "row"], | ||
shape in ["m16n16k16"], | ||
addr_space in ["", "shared", "global"], | ||
stride in ["stride"], | ||
elem_type in ["f16", "f32"] | ||
|
||
# TODO: Non-stride versions? | ||
|
||
# Float32 is only supported for C | ||
if (elem_type == "f32") && (mat != "c") | ||
continue | ||
end | ||
|
||
# Name of the Julia wrapper function | ||
func_name = Symbol(join_nonempty("llvm", "wmma", "load", mat, layout, shape, addr_space, stride, elem_type, "_")) | ||
|
||
# Name of the LLVM intrinsic | ||
llvm_intr = join_nonempty("@llvm", "nvvm", "wmma", "load", mat, "sync", layout, shape, addr_space, stride, elem_type, ".") | ||
|
||
# Determine types for this (matrix, elem_type) combination | ||
sz = get_frag_sz(mat, elem_type) | ||
llvm_ty = get_llvm_ty(mat, elem_type) | ||
struct_ty = "{ $(@gen_ir(llvm_ty, sz, ", ")) }" | ||
lc_ty = get_llvmcall_ty(mat, elem_type) | ||
jl_ty = get_jl_ty(mat, elem_type) | ||
|
||
# Generate LLVM IR | ||
ir = ("declare $struct_ty $llvm_intr(i8*, i32)", | ||
" | ||
%src_ptr = inttoptr i64 %0 to i8* | ||
%ret.llvm = call $struct_ty $llvm_intr(i8* %src_ptr, i32 %1) | ||
$(@gen_ir("%ret.llvm.$i = extractvalue $struct_ty %ret.llvm, $i", sz)) | ||
$(@gen_ir("%ret.jl.$i = bitcast $llvm_ty %ret.llvm.$i to $lc_ty", sz)) | ||
$(@gen_ir("%ret.aggr.$i = insertvalue [$sz x $lc_ty] $(i == 0 ? "undef" : "%ret.aggr.$(i-1)"), $lc_ty %ret.jl.$i, $i", sz)) | ||
ret [$sz x $lc_ty] %ret.aggr.$(sz-1) | ||
") | ||
|
||
@eval $func_name(src_addr, stride) = Base.llvmcall($ir, | ||
NTuple{$sz, $jl_ty}, | ||
Tuple{Int64, Int32}, | ||
convert(Int64, src_addr), | ||
convert(Int32, stride)) | ||
|
||
@eval export $func_name | ||
end | ||
|
||
# ------------ | ||
# Matrix store | ||
# ------------ | ||
|
||
for mat in ["d"], | ||
layout in ["col", "row"], | ||
shape in ["m16n16k16"], | ||
addr_space in ["", "shared", "global"], | ||
stride in ["stride"], | ||
elem_type in ["f16", "f32"] | ||
|
||
# TODO: Non-stride versions? | ||
|
||
# Name of the Julia wrapper function | ||
func_name = Symbol(join_nonempty("llvm", "wmma", "store", mat, layout, shape, addr_space, stride, elem_type, "_")) | ||
|
||
# Name of the LLVM intrinsic | ||
llvm_intr = join_nonempty("@llvm", "nvvm", "wmma", "store", mat, "sync", layout, shape, addr_space, stride, elem_type, ".") | ||
|
||
# Determine types for this (matrix, elem_type) combination | ||
sz = get_frag_sz(mat, elem_type) | ||
llvm_ty = get_llvm_ty(mat, elem_type) | ||
lc_ty = get_llvmcall_ty(mat, elem_type) | ||
jl_ty = get_jl_ty(mat, elem_type) | ||
|
||
# Generate LLVM IR | ||
ir = ("declare void $llvm_intr(i8*, $(@gen_ir("$llvm_ty", sz, ", ")), i32)", | ||
" | ||
%dst_ptr = inttoptr i64 %0 to i8* | ||
$(@gen_ir("%data.jl.$i = extractvalue [$sz x $lc_ty] %1, $i", sz)) | ||
$(@gen_ir("%data.llvm.$i = bitcast $lc_ty %data.jl.$i to $llvm_ty", sz)) | ||
call void $llvm_intr(i8* %dst_ptr, $(@gen_ir("$llvm_ty %data.llvm.$i", sz, ", ")) , i32 %2) | ||
ret void | ||
") | ||
|
||
@eval $func_name(dst_addr, data, stride) = Base.llvmcall($ir, | ||
Nothing, | ||
Tuple{Int64, NTuple{$sz, $jl_ty}, Int32}, | ||
convert(Int64, dst_addr), | ||
convert(NTuple{$sz, $jl_ty}, data), | ||
convert(Int32, stride)) | ||
|
||
@eval export $func_name | ||
end | ||
|
||
# -------------------------- | ||
# Matrix multiply accumulate | ||
# -------------------------- | ||
|
||
for a_layout in ["col", "row"], | ||
b_layout in ["col", "row"], | ||
shape in ["m16n16k16"], | ||
d_elem_type in ["f16", "f32"], | ||
c_elem_type in ["f16", "f32"], | ||
b_elem_type in ["f16"], | ||
a_elem_type in ["f16"] | ||
|
||
# Name of the Julia wrapper function | ||
func_name = Symbol(join_nonempty("llvm", "wmma", "mma", a_layout, b_layout, shape, d_elem_type, c_elem_type, "_")) | ||
|
||
# Name of the LLVM intrinsic | ||
llvm_intr = join_nonempty("@llvm", "nvvm", "wmma", "mma", "sync", a_layout, b_layout, shape, d_elem_type, c_elem_type, ".") | ||
|
||
# Determine types for the (matrix, elem_type) combinations for matrix A | ||
a_sz = get_frag_sz("a", a_elem_type) | ||
a_llvm_ty = get_llvm_ty("a", a_elem_type) | ||
a_lc_ty = get_llvmcall_ty("a", a_elem_type) | ||
a_jl_ty = get_jl_ty("a", a_elem_type) | ||
|
||
# Determine types for the (matrix, elem_type) combinations for matrix B | ||
b_sz = get_frag_sz("b", b_elem_type) | ||
b_llvm_ty = get_llvm_ty("b", b_elem_type) | ||
b_lc_ty = get_llvmcall_ty("b", b_elem_type) | ||
b_jl_ty = get_jl_ty("b", b_elem_type) | ||
|
||
# Determine types for the (matrix, elem_type) combinations for matrix C | ||
c_sz = get_frag_sz("c", c_elem_type) | ||
c_llvm_ty = get_llvm_ty("c", c_elem_type) | ||
c_lc_ty = get_llvmcall_ty("c", c_elem_type) | ||
c_jl_ty = get_jl_ty("c", c_elem_type) | ||
|
||
# Determine types for the (matrix, elem_type) combinations for matrix D | ||
d_sz = get_frag_sz("d", d_elem_type) | ||
d_llvm_ty = get_llvm_ty("d", d_elem_type) | ||
d_lc_ty = get_llvmcall_ty("d", d_elem_type) | ||
d_jl_ty = get_jl_ty("d", d_elem_type) | ||
d_struct_ty = "{ $(@gen_ir(d_llvm_ty, d_sz, ", ")) }" | ||
|
||
# Create the argument string to the IR call | ||
args = join([ | ||
@gen_ir("$a_llvm_ty %a.llvm.$i", a_sz, ", "), | ||
@gen_ir("$b_llvm_ty %b.llvm.$i", b_sz, ", "), | ||
@gen_ir("$c_llvm_ty %c.llvm.$i", c_sz, ", ")] | ||
, ", ") | ||
|
||
# Generate LLVM IR | ||
ir = ("declare $d_struct_ty $llvm_intr($args)", | ||
" | ||
$(@gen_ir("%a.jl.$i = extractvalue [$a_sz x $a_lc_ty] %0, $i", a_sz)) | ||
$(@gen_ir("%b.jl.$i = extractvalue [$b_sz x $b_lc_ty] %1, $i", b_sz)) | ||
$(@gen_ir("%c.jl.$i = extractvalue [$c_sz x $c_lc_ty] %2, $i", c_sz)) | ||
$(@gen_ir("%a.llvm.$i = bitcast $a_lc_ty %a.jl.$i to $a_llvm_ty", a_sz)) | ||
$(@gen_ir("%b.llvm.$i = bitcast $b_lc_ty %b.jl.$i to $b_llvm_ty", b_sz)) | ||
$(@gen_ir("%c.llvm.$i = bitcast $c_lc_ty %c.jl.$i to $c_llvm_ty", c_sz)) | ||
%d.llvm = call $d_struct_ty $llvm_intr($args) | ||
$(@gen_ir("%d.llvm.$i = extractvalue $d_struct_ty %d.llvm, $i", d_sz)) | ||
$(@gen_ir("%d.jl.$i = bitcast $d_llvm_ty %d.llvm.$i to $d_lc_ty", d_sz)) | ||
$(@gen_ir("%d.aggr.$i = insertvalue [$d_sz x $d_lc_ty] $(i == 0 ? "undef" : "%d.aggr.$(i-1)"), $d_lc_ty %d.jl.$i, $i", d_sz)) | ||
ret [$d_sz x $d_lc_ty] %d.aggr.$(d_sz-1) | ||
") | ||
|
||
@eval $func_name(a, b, c) = Base.llvmcall($ir, | ||
NTuple{$d_sz, $d_jl_ty}, | ||
Tuple{NTuple{$a_sz, $a_jl_ty}, NTuple{$b_sz, $b_jl_ty}, NTuple{$c_sz, $c_jl_ty}}, | ||
convert(NTuple{$a_sz, $a_jl_ty}, a), | ||
convert(NTuple{$b_sz, $b_jl_ty}, b), | ||
convert(NTuple{$c_sz, $c_jl_ty}, c)) | ||
|
||
@eval export $func_name | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
@testset "WMMA" begin | ||
|
||
################################################################################ | ||
|
||
@testset "LLVM intrinsics" begin | ||
|
||
@testset "llvm_wmma_load" begin | ||
@testset "$(mat)_$(layout)_$(shape)_$(addr_space)_$(elem_type)" for mat in ["a", "b", "c"], | ||
layout in ["row", "col"], | ||
shape in ["m16n16k16"], | ||
addr_space in [""], | ||
stride in ["stride"], | ||
elem_type in ["f16", "f32"] | ||
|
||
# TODO: Test address space? | ||
|
||
# Float32 is only supported for C | ||
if (elem_type == "f32") && (mat != "c") | ||
continue | ||
end | ||
|
||
# Type-dependent variables | ||
array_ty = elem_type == "f16" ? Float16 : Float32 | ||
expected = elem_type == "f16" ? (VecElement{Float16}(42), VecElement{Float16}(42)) : Float32(42) | ||
|
||
# Get the function name | ||
func = getfield(Main, Symbol("llvm_wmma_load_$(mat)_$(layout)_$(shape)_stride_$(elem_type)")) | ||
|
||
input = 42 * ones(array_ty, (16, 16)) | ||
input_dev = CuArray(input) | ||
result = Array{Bool}(undef, 1) | ||
result_dev = CuArray(result) | ||
|
||
function kernel(input_dev, result_dev) | ||
data = func(pointer(input_dev), 16) | ||
result_dev[1] = all(val -> val == expected, data) | ||
return | ||
end | ||
|
||
@cuda threads=32 kernel(input_dev, result_dev) | ||
@test all(Array(result_dev)) | ||
end | ||
end | ||
|
||
@testset "llvm_wmma_store" begin | ||
@testset "$(mat)_$(layout)_$(shape)_$(addr_space)_$(elem_type)" for mat in ["d"], | ||
layout in ["row", "col"], | ||
shape in ["m16n16k16"], | ||
addr_space in [""], | ||
stride in ["stride"], | ||
elem_type in ["f16", "f32"] | ||
|
||
# TODO: Test address space? | ||
|
||
# Type-dependent variables | ||
array_ty = elem_type == "f16" ? Float16 : Float32 | ||
data = elem_type == "f16" ? | ||
( | ||
(VecElement{Float16}(42), VecElement{Float16}(42)), | ||
(VecElement{Float16}(42), VecElement{Float16}(42)), | ||
(VecElement{Float16}(42), VecElement{Float16}(42)), | ||
(VecElement{Float16}(42), VecElement{Float16}(42)) | ||
) : (42, 42, 42, 42, 42, 42, 42, 42) | ||
|
||
# Get the function name | ||
func = getfield(Main, Symbol("llvm_wmma_store_$(mat)_$(layout)_$(shape)_stride_$(elem_type)")) | ||
|
||
output = Array{array_ty}(undef, (16, 16)) | ||
output_dev = CuArray(output) | ||
|
||
function kernel(output_dev) | ||
func(pointer(output_dev), data, 16) | ||
return | ||
end | ||
|
||
@cuda threads=32 kernel(output_dev) | ||
@test all(Array(output_dev) .== 42.0) | ||
end | ||
end | ||
|
||
@testset "llvm_wmma_mma" begin | ||
@testset "$(a_layout)_$(b_layout)_$(shape)_$(d_elem_type)_$(c_elem_type)" for a_layout in ["row", "col"], | ||
b_layout in ["row", "col"], | ||
shape in ["m16n16k16"], | ||
d_elem_type in ["f16", "f32"], | ||
c_elem_type in ["f16", "f32"] | ||
|
||
# Type-dependent variables | ||
d_ty = d_elem_type == "f16" ? Float16 : Float32 | ||
c_ty = c_elem_type == "f16" ? Float16 : Float32 | ||
|
||
# Get the function names | ||
lda_func = getfield(Main, Symbol("llvm_wmma_load_a_$(a_layout)_m16n16k16_stride_f16")) | ||
ldb_func = getfield(Main, Symbol("llvm_wmma_load_b_$(b_layout)_m16n16k16_stride_f16")) | ||
ldc_func = getfield(Main, Symbol("llvm_wmma_load_c_col_m16n16k16_stride_$(c_elem_type)")) | ||
mma_func = getfield(Main, Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_m16n16k16_$(d_elem_type)_$(c_elem_type)")) | ||
std_func = getfield(Main, Symbol("llvm_wmma_store_d_col_m16n16k16_stride_$(d_elem_type)")) | ||
|
||
# Generate input matrices | ||
a = rand(Float16, (16, 16)) | ||
a_dev = CuArray(a) | ||
b = rand(Float16, (16, 16)) | ||
b_dev = CuArray(b) | ||
c = rand(c_ty, (16, 16)) | ||
c_dev = CuArray(c) | ||
|
||
# Reserve space for result | ||
d = Array{d_ty}(undef, (16, 16)) | ||
d_dev = CuArray(d) | ||
|
||
# Matrix MAC kernel (D = A * B + C) | ||
function kernel(a_dev, b_dev, c_dev, d_dev) | ||
a_frag = lda_func(pointer(a_dev), 16) | ||
b_frag = ldb_func(pointer(b_dev), 16) | ||
c_frag = ldc_func(pointer(c_dev), 16) | ||
|
||
d_frag = mma_func(a_frag, b_frag, c_frag) | ||
|
||
std_func(pointer(d_dev), d_frag, 16) | ||
return | ||
end | ||
|
||
@cuda threads=32 kernel(a_dev, b_dev, c_dev, d_dev) | ||
|
||
new_a = (a_layout == "col" ? a : transpose(a)) | ||
new_b = (b_layout == "col" ? b : transpose(b)) | ||
|
||
@test new_a * new_b + c ≈ Array(d_dev) rtol=0.01 | ||
end | ||
end | ||
end | ||
|
||
################################################################################ | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters