Implement wrappers for WMMA LLVM intrinsics
thomasfaingnaert committed Nov 10, 2019
1 parent 93e5a76 commit faae545
Showing 4 changed files with 383 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/device/cuda.jl
@@ -11,6 +11,7 @@ include("cuda/assertion.jl")
include("cuda/memory_dynamic.jl")
include("cuda/atomics.jl")
include("cuda/misc.jl")
include("cuda/wmma.jl")

# functionality from libdevice
#
246 changes: 246 additions & 0 deletions src/device/cuda/wmma.jl
@@ -0,0 +1,246 @@
################################################################################
# CONSTANTS
################################################################################

# Maps PTX types to LLVM types
map_ptx_to_llvm = Dict(
"f16" => "<2 x half>",
"f32" => "float"
)

# Maps PTX types to the LLVM type that llvmcall expects
map_ptx_to_llvmcall = Dict(
"f16" => "<2 x i16>",
"f32" => "float"
)

# Maps PTX types to Julia types
map_ptx_to_jl = Dict(
"f16" => NTuple{2, VecElement{Float16}},
"f32" => Float32
)

# Maps matrix & PTX types to fragment sizes
map_frag_sizes = Dict(
"a.f16" => 8,
"b.f16" => 8,
"c.f16" => 4,
"c.f32" => 8,
"d.f16" => 4,
"d.f32" => 8
)
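
# Note that the sizes above count fragment *elements*, and each f16 element is
# one <2 x half> register. For example, "c.f16" => 4 means 4 x 2 = 8 Float16
# values per thread, so a 32-thread warp holds exactly the 16x16 accumulator
# tile (32 x 8 = 256 values).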

################################################################################
# HELPER FUNCTIONS
################################################################################

macro gen_ir(template, count, delim="\n")
return quote
join([$(esc(template)) for $(esc(:i)) in 0:$(esc(count))-1], $(esc(delim)))
end
end

function join_nonempty(args...)
delim = args[end]
arr = [args[1:end-1]...]

return join(arr[arr .!= ""], delim)
end

get_llvm_ty(matrix, ptx_el_type) = map_ptx_to_llvm[ptx_el_type]

get_llvmcall_ty(matrix, ptx_el_type) = map_ptx_to_llvmcall[ptx_el_type]

get_jl_ty(matrix, ptx_el_type) = map_ptx_to_jl[ptx_el_type]

get_frag_sz(matrix, ptx_el_type) = map_frag_sizes["$matrix.$ptx_el_type"]
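
# For illustration, a sketch of how these helpers compose (the example inputs
# are assumptions; the results follow from the definitions above):
#
#   @gen_ir("%x.$i", 3, ", ")                        # "%x.0, %x.1, %x.2"
#   join_nonempty("llvm", "wmma", "", "load", "_")   # "llvm_wmma_load"
#   get_frag_sz("a", "f16")                          # 8
#   get_jl_ty("a", "f16")                            # NTuple{2, VecElement{Float16}}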

################################################################################
# LOW LEVEL API
################################################################################

# -----------
# Matrix load
# -----------

for mat in ["a", "b", "c"],
layout in ["col", "row"],
shape in ["m16n16k16"],
addr_space in ["", "shared", "global"],
stride in ["stride"],
elem_type in ["f16", "f32"]

# TODO: Non-stride versions?

# Float32 is only supported for C
if (elem_type == "f32") && (mat != "c")
continue
end

# Name of the Julia wrapper function
func_name = Symbol(join_nonempty("llvm", "wmma", "load", mat, layout, shape, addr_space, stride, elem_type, "_"))

# Name of the LLVM intrinsic
llvm_intr = join_nonempty("@llvm", "nvvm", "wmma", "load", mat, "sync", layout, shape, addr_space, stride, elem_type, ".")

# Determine types for this (matrix, elem_type) combination
sz = get_frag_sz(mat, elem_type)
llvm_ty = get_llvm_ty(mat, elem_type)
struct_ty = "{ $(@gen_ir(llvm_ty, sz, ", ")) }"
lc_ty = get_llvmcall_ty(mat, elem_type)
jl_ty = get_jl_ty(mat, elem_type)

# Generate LLVM IR
ir = ("declare $struct_ty $llvm_intr(i8*, i32)",
"
%src_ptr = inttoptr i64 %0 to i8*
%ret.llvm = call $struct_ty $llvm_intr(i8* %src_ptr, i32 %1)
$(@gen_ir("%ret.llvm.$i = extractvalue $struct_ty %ret.llvm, $i", sz))
$(@gen_ir("%ret.jl.$i = bitcast $llvm_ty %ret.llvm.$i to $lc_ty", sz))
$(@gen_ir("%ret.aggr.$i = insertvalue [$sz x $lc_ty] $(i == 0 ? "undef" : "%ret.aggr.$(i-1)"), $lc_ty %ret.jl.$i, $i", sz))
ret [$sz x $lc_ty] %ret.aggr.$(sz-1)
")

@eval $func_name(src_addr, stride) = Base.llvmcall($ir,
NTuple{$sz, $jl_ty},
Tuple{Int64, Int32},
convert(Int64, src_addr),
convert(Int32, stride))

@eval export $func_name
end
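
# A minimal usage sketch for one of the wrappers generated above (the name
# follows from the loop; `src_ptr` is assumed to be a device pointer to a
# 16x16 column-major Float16 matrix with leading dimension 16):
#
#   frag = llvm_wmma_load_a_col_m16n16k16_stride_f16(src_ptr, 16)
#   # frag isa NTuple{8, NTuple{2, VecElement{Float16}}}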

# ------------
# Matrix store
# ------------

for mat in ["d"],
layout in ["col", "row"],
shape in ["m16n16k16"],
addr_space in ["", "shared", "global"],
stride in ["stride"],
elem_type in ["f16", "f32"]

# TODO: Non-stride versions?

# Name of the Julia wrapper function
func_name = Symbol(join_nonempty("llvm", "wmma", "store", mat, layout, shape, addr_space, stride, elem_type, "_"))

# Name of the LLVM intrinsic
llvm_intr = join_nonempty("@llvm", "nvvm", "wmma", "store", mat, "sync", layout, shape, addr_space, stride, elem_type, ".")

# Determine types for this (matrix, elem_type) combination
sz = get_frag_sz(mat, elem_type)
llvm_ty = get_llvm_ty(mat, elem_type)
lc_ty = get_llvmcall_ty(mat, elem_type)
jl_ty = get_jl_ty(mat, elem_type)

# Generate LLVM IR
ir = ("declare void $llvm_intr(i8*, $(@gen_ir("$llvm_ty", sz, ", ")), i32)",
"
%dst_ptr = inttoptr i64 %0 to i8*
$(@gen_ir("%data.jl.$i = extractvalue [$sz x $lc_ty] %1, $i", sz))
$(@gen_ir("%data.llvm.$i = bitcast $lc_ty %data.jl.$i to $llvm_ty", sz))
call void $llvm_intr(i8* %dst_ptr, $(@gen_ir("$llvm_ty %data.llvm.$i", sz, ", ")) , i32 %2)
ret void
")

@eval $func_name(dst_addr, data, stride) = Base.llvmcall($ir,
Nothing,
Tuple{Int64, NTuple{$sz, $jl_ty}, Int32},
convert(Int64, dst_addr),
convert(NTuple{$sz, $jl_ty}, data),
convert(Int32, stride))

@eval export $func_name
end
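
# A matching store sketch (again, the name follows from the loop above;
# `dst_ptr` and `d_frag` are assumed to come from the caller and an MMA call):
#
#   llvm_wmma_store_d_col_m16n16k16_stride_f32(dst_ptr, d_frag, 16)
#   # d_frag isa NTuple{8, Float32}; the f16 variant takes
#   # NTuple{4, NTuple{2, VecElement{Float16}}}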

# --------------------------
# Matrix multiply accumulate
# --------------------------

for a_layout in ["col", "row"],
b_layout in ["col", "row"],
shape in ["m16n16k16"],
d_elem_type in ["f16", "f32"],
c_elem_type in ["f16", "f32"],
b_elem_type in ["f16"],
a_elem_type in ["f16"]

# Name of the Julia wrapper function
func_name = Symbol(join_nonempty("llvm", "wmma", "mma", a_layout, b_layout, shape, d_elem_type, c_elem_type, "_"))

# Name of the LLVM intrinsic
llvm_intr = join_nonempty("@llvm", "nvvm", "wmma", "mma", "sync", a_layout, b_layout, shape, d_elem_type, c_elem_type, ".")

# Determine types for the (matrix, elem_type) combinations for matrix A
a_sz = get_frag_sz("a", a_elem_type)
a_llvm_ty = get_llvm_ty("a", a_elem_type)
a_lc_ty = get_llvmcall_ty("a", a_elem_type)
a_jl_ty = get_jl_ty("a", a_elem_type)

# Determine types for the (matrix, elem_type) combinations for matrix B
b_sz = get_frag_sz("b", b_elem_type)
b_llvm_ty = get_llvm_ty("b", b_elem_type)
b_lc_ty = get_llvmcall_ty("b", b_elem_type)
b_jl_ty = get_jl_ty("b", b_elem_type)

# Determine types for the (matrix, elem_type) combinations for matrix C
c_sz = get_frag_sz("c", c_elem_type)
c_llvm_ty = get_llvm_ty("c", c_elem_type)
c_lc_ty = get_llvmcall_ty("c", c_elem_type)
c_jl_ty = get_jl_ty("c", c_elem_type)

# Determine types for the (matrix, elem_type) combinations for matrix D
d_sz = get_frag_sz("d", d_elem_type)
d_llvm_ty = get_llvm_ty("d", d_elem_type)
d_lc_ty = get_llvmcall_ty("d", d_elem_type)
d_jl_ty = get_jl_ty("d", d_elem_type)
d_struct_ty = "{ $(@gen_ir(d_llvm_ty, d_sz, ", ")) }"

# Create the argument string to the IR call
args = join([
@gen_ir("$a_llvm_ty %a.llvm.$i", a_sz, ", "),
@gen_ir("$b_llvm_ty %b.llvm.$i", b_sz, ", "),
@gen_ir("$c_llvm_ty %c.llvm.$i", c_sz, ", ")]
, ", ")

# Generate LLVM IR
ir = ("declare $d_struct_ty $llvm_intr($args)",
"
$(@gen_ir("%a.jl.$i = extractvalue [$a_sz x $a_lc_ty] %0, $i", a_sz))
$(@gen_ir("%b.jl.$i = extractvalue [$b_sz x $b_lc_ty] %1, $i", b_sz))
$(@gen_ir("%c.jl.$i = extractvalue [$c_sz x $c_lc_ty] %2, $i", c_sz))
$(@gen_ir("%a.llvm.$i = bitcast $a_lc_ty %a.jl.$i to $a_llvm_ty", a_sz))
$(@gen_ir("%b.llvm.$i = bitcast $b_lc_ty %b.jl.$i to $b_llvm_ty", b_sz))
$(@gen_ir("%c.llvm.$i = bitcast $c_lc_ty %c.jl.$i to $c_llvm_ty", c_sz))
%d.llvm = call $d_struct_ty $llvm_intr($args)
$(@gen_ir("%d.llvm.$i = extractvalue $d_struct_ty %d.llvm, $i", d_sz))
$(@gen_ir("%d.jl.$i = bitcast $d_llvm_ty %d.llvm.$i to $d_lc_ty", d_sz))
$(@gen_ir("%d.aggr.$i = insertvalue [$d_sz x $d_lc_ty] $(i == 0 ? "undef" : "%d.aggr.$(i-1)"), $d_lc_ty %d.jl.$i, $i", d_sz))
ret [$d_sz x $d_lc_ty] %d.aggr.$(d_sz-1)
")

@eval $func_name(a, b, c) = Base.llvmcall($ir,
NTuple{$d_sz, $d_jl_ty},
Tuple{NTuple{$a_sz, $a_jl_ty}, NTuple{$b_sz, $b_jl_ty}, NTuple{$c_sz, $c_jl_ty}},
convert(NTuple{$a_sz, $a_jl_ty}, a),
convert(NTuple{$b_sz, $b_jl_ty}, b),
convert(NTuple{$c_sz, $c_jl_ty}, c))

@eval export $func_name
end
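
# Chaining the generated wrappers computes D = A * B + C for one 16x16x16 tile
# per warp. A sketch with placeholder device pointers (see test/device/wmma.jl
# below for a runnable kernel):
#
#   a_frag = llvm_wmma_load_a_col_m16n16k16_stride_f16(a_ptr, 16)
#   b_frag = llvm_wmma_load_b_col_m16n16k16_stride_f16(b_ptr, 16)
#   c_frag = llvm_wmma_load_c_col_m16n16k16_stride_f32(c_ptr, 16)
#   d_frag = llvm_wmma_mma_col_col_m16n16k16_f32_f32(a_frag, b_frag, c_frag)
#   llvm_wmma_store_d_col_m16n16k16_stride_f32(d_ptr, d_frag, 16)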
135 changes: 135 additions & 0 deletions test/device/wmma.jl
@@ -0,0 +1,135 @@
@testset "WMMA" begin

################################################################################

@testset "LLVM intrinsics" begin

@testset "llvm_wmma_load" begin
@testset "$(mat)_$(layout)_$(shape)_$(addr_space)_$(elem_type)" for mat in ["a", "b", "c"],
layout in ["row", "col"],
shape in ["m16n16k16"],
addr_space in [""],
stride in ["stride"],
elem_type in ["f16", "f32"]

# TODO: Test address space?

# Float32 is only supported for C
if (elem_type == "f32") && (mat != "c")
continue
end

# Type-dependent variables
array_ty = elem_type == "f16" ? Float16 : Float32
expected = elem_type == "f16" ? (VecElement{Float16}(42), VecElement{Float16}(42)) : Float32(42)

# Get the function name
func = getfield(Main, Symbol("llvm_wmma_load_$(mat)_$(layout)_$(shape)_stride_$(elem_type)"))

input = 42 * ones(array_ty, (16, 16))
input_dev = CuArray(input)
result = Array{Bool}(undef, 1)
result_dev = CuArray(result)

function kernel(input_dev, result_dev)
data = func(pointer(input_dev), 16)
result_dev[1] = all(val -> val == expected, data)
return
end

@cuda threads=32 kernel(input_dev, result_dev)
@test all(Array(result_dev))
end
end

@testset "llvm_wmma_store" begin
@testset "$(mat)_$(layout)_$(shape)_$(addr_space)_$(elem_type)" for mat in ["d"],
layout in ["row", "col"],
shape in ["m16n16k16"],
addr_space in [""],
stride in ["stride"],
elem_type in ["f16", "f32"]

# TODO: Test address space?

# Type-dependent variables
array_ty = elem_type == "f16" ? Float16 : Float32
data = elem_type == "f16" ?
(
(VecElement{Float16}(42), VecElement{Float16}(42)),
(VecElement{Float16}(42), VecElement{Float16}(42)),
(VecElement{Float16}(42), VecElement{Float16}(42)),
(VecElement{Float16}(42), VecElement{Float16}(42))
) : (42, 42, 42, 42, 42, 42, 42, 42)

# Get the function name
func = getfield(Main, Symbol("llvm_wmma_store_$(mat)_$(layout)_$(shape)_stride_$(elem_type)"))

output = Array{array_ty}(undef, (16, 16))
output_dev = CuArray(output)

function kernel(output_dev)
func(pointer(output_dev), data, 16)
return
end

@cuda threads=32 kernel(output_dev)
@test all(Array(output_dev) .== 42.0)
end
end

@testset "llvm_wmma_mma" begin
@testset "$(a_layout)_$(b_layout)_$(shape)_$(d_elem_type)_$(c_elem_type)" for a_layout in ["row", "col"],
b_layout in ["row", "col"],
shape in ["m16n16k16"],
d_elem_type in ["f16", "f32"],
c_elem_type in ["f16", "f32"]

# Type-dependent variables
d_ty = d_elem_type == "f16" ? Float16 : Float32
c_ty = c_elem_type == "f16" ? Float16 : Float32

# Get the function names
lda_func = getfield(Main, Symbol("llvm_wmma_load_a_$(a_layout)_m16n16k16_stride_f16"))
ldb_func = getfield(Main, Symbol("llvm_wmma_load_b_$(b_layout)_m16n16k16_stride_f16"))
ldc_func = getfield(Main, Symbol("llvm_wmma_load_c_col_m16n16k16_stride_$(c_elem_type)"))
mma_func = getfield(Main, Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_m16n16k16_$(d_elem_type)_$(c_elem_type)"))
std_func = getfield(Main, Symbol("llvm_wmma_store_d_col_m16n16k16_stride_$(d_elem_type)"))

# Generate input matrices
a = rand(Float16, (16, 16))
a_dev = CuArray(a)
b = rand(Float16, (16, 16))
b_dev = CuArray(b)
c = rand(c_ty, (16, 16))
c_dev = CuArray(c)

# Reserve space for result
d = Array{d_ty}(undef, (16, 16))
d_dev = CuArray(d)

# Matrix MAC kernel (D = A * B + C)
function kernel(a_dev, b_dev, c_dev, d_dev)
a_frag = lda_func(pointer(a_dev), 16)
b_frag = ldb_func(pointer(b_dev), 16)
c_frag = ldc_func(pointer(c_dev), 16)

d_frag = mma_func(a_frag, b_frag, c_frag)

std_func(pointer(d_dev), d_frag, 16)
return
end

@cuda threads=32 kernel(a_dev, b_dev, c_dev, d_dev)

new_a = (a_layout == "col" ? a : transpose(a))
new_b = (b_layout == "col" ? b : transpose(b))

@test new_a * new_b + c ≈ Array(d_dev) rtol=0.01
end
end
end

################################################################################

end
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -70,6 +70,7 @@ else
include("device/pointer.jl")
include("device/array.jl")
include("device/cuda.jl")
include("device/wmma.jl")

include("examples.jl")
end
