
Use a custom generator function for control over back-edges.
maleadt committed Feb 27, 2020
1 parent 079cb34 commit c39de57
Showing 2 changed files with 155 additions and 104 deletions.
239 changes: 135 additions & 104 deletions src/execution.jl
@@ -61,23 +61,6 @@ function assign_args!(code, args)
return vars, var_exprs
end

# fast lookup of global world age
world_age() = ccall(:jl_get_tls_world_age, UInt, ())

# slow lookup of local method age
function method_age(f, t)::UInt
for m in Base._methods(f, t, 1, typemax(UInt))
if VERSION >= v"1.2.0-DEV.573"
return m[3].primary_world
else
return m[3].min_world
end
end

tt = Base.to_tuple_type(t)
throw(MethodError(f, tt))
end


## high-level @cuda interface

@@ -334,8 +317,97 @@ end

## host-side API

const agecache = Dict{UInt, UInt}()
using Core.Compiler: retrieve_code_info, CodeInfo, MethodInstance, SSAValue, SlotNumber
using Base: _methods_by_ftype

# HACK: generate a call that is never executed at run time, but still makes the
# compiler record edges, so that redefining a method invalidates the caller
const opaque_false = Ref(false)
function fake_call(f)
opaque_false[] || return
f(Ref{Any}()[]...)
end
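# An illustrative sketch (not part of this commit) of the trick: the guarded call
# is dead at run time, but the compiler cannot prove `opaque_false[]` is always
# false, so it keeps the call site and records edges for `f`. Redefining a method
# of `f` then invalidates any code containing `fake_call(f)`, forcing recompilation:
#
#     child() = 1
#     caller() = (fake_call(child); 42)
#     caller()            # compiles `caller`, recording edges for `child`
#     @eval child() = 2   # the redefinition invalidates `caller`
#     caller()            # still returns 42, but from freshly compiled code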

# actual compilation
function cufunction_slow(f, tt, spec; name=nothing, kwargs...)
start = time_ns()

# generate a fake call to ensure we get recompiled upon method invalidation
fake_call(f)

# compile to PTX
ctx = context()
dev = device(ctx)
cap = supported_capability(dev)
asm, kernel_fn, undefined_fns =
compile(:ptx, cap, f, tt; name=name, strict=true, kwargs...)

# settings to JIT based on Julia's debug setting
jit_options = Dict{CUDAdrv.CUjit_option,Any}()
if Base.JLOptions().debug_level == 1
jit_options[CUDAdrv.JIT_GENERATE_LINE_INFO] = true
elseif Base.JLOptions().debug_level >= 2
jit_options[CUDAdrv.JIT_GENERATE_DEBUG_INFO] = true
end
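# usage note (assumption, not from this commit): these options mirror Julia's own
# debug flag, so e.g. `julia -g2` produces kernels that carry full debug info for
# cuda-gdb, while the default `-g1` only emits line tables useful for profilers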

# link the CUDA device library
image = asm
# linking the device runtime library requires use of the CUDA linker,
# which in turn switches compilation to device relocatable code (-rdc) mode.
#
# even when the kernel does not contain any calls that need -rdc (i.e., calls into
# the runtime library), that mode significantly hurts performance, so only invoke
# the linker when the kernel actually references a non-intrinsic function
intrinsic_fns = ["vprintf", "malloc", "free", "__assertfail",
"__nvvm_reflect" #= TODO: should have been optimized away =#]
if !isempty(setdiff(undefined_fns, intrinsic_fns))
@timeit_debug to "device runtime library" begin
linker = CUDAdrv.CuLink(jit_options)
CUDAdrv.add_file!(linker, libcudadevrt[], CUDAdrv.JIT_INPUT_LIBRARY)
CUDAdrv.add_data!(linker, kernel_fn, asm)
image = CUDAdrv.complete(linker)
end
end
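# illustrative note (not from this commit): a device-side call to `malloc` is in
# the intrinsic list above and hence skips the linker, whereas a kernel calling
# into the device runtime (e.g. for dynamic parallelism) would take the linking
# branch and pull in libcudadevrt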

# JIT into an executable kernel object
mod = CuModule(image, jit_options)
fun = CuFunction(mod, kernel_fn)
kernel = HostKernel{f,tt}(ctx, mod, fun)

create_exceptions!(mod)

stop = time_ns()
@debug begin
ver = version(kernel)
mem = memory(kernel)
reg = registers(kernel)
fn = something(name, nameof(f))
"""Compiled $fn($(join(tt.parameters, ", "))) to PTX $(ver.ptx) for SM $(ver.binary) in $(round((time_ns() - start) / 1000000; digits=2)) ms.
Kernel uses $reg registers, and $(Base.format_bytes(mem.local)) local, $(Base.format_bytes(mem.shared)) shared, and $(Base.format_bytes(mem.constant)) constant memory."""
end

return kernel
end

# cached compilation
const compilecache = Dict{UInt, HostKernel}()
@inline function cufunction_fast(f, tt, spec; name=nothing, kwargs...)
# generate a key for indexing the compilation cache
ctx = context()
key = hash(spec) # the unique specialization id minted by the generated `cufunction`
key = hash(pointer_from_objref(ctx), key) # contexts are unique, but handles might alias
# TODO: implement this as a hash function in CUDAdrv
key = hash(name, key)
key = hash(kwargs, key)
for nf in 1:nfields(f)
# mix in the values of any captured variable
key = hash(getfield(f, nf), key)
end

return get!(compilecache, key) do
cufunction_slow(f, tt, spec; name=name, kwargs...)
end::HostKernel{f,tt}
end
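# The caching pattern above, in isolation (illustrative sketch, not part of this
# commit): `get!` with a do-block only runs the expensive computation on a miss,
# storing the result under the key, so each distinct key compiles exactly once:
#
#     cache = Dict{UInt,String}()
#     expensive(x) = (sleep(0.1); string(x))
#     cached(x) = get!(() -> expensive(x), cache, hash(x))
#     cached(42)   # slow: computes and stores
#     cached(42)   # fast: served from the cache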

# each specialization gets a unique id, mixed into the cache key as `spec`, so
# that a redefined kernel function maps to a fresh cache entry
specialization_counter = 0

"""
cufunction(f, tt=Tuple{}; kwargs...)
@@ -356,92 +428,51 @@
The output of this function is automatically cached, i.e. you can simply call `cufunction`
in a hot path without degrading performance. New code will be generated automatically when
the function changes, or when different types or keyword arguments are provided.
"""
@generated function cufunction(f::Core.Function, tt::Type=Tuple{}; name=nothing, kwargs...)
tt = Base.to_tuple_type(tt.parameters[1])
sig = Base.signature_type(f, tt)
t = Tuple(tt.parameters)

precomp_key = hash(sig) # precomputable part of the keys
quote
Base.@_inline_meta

# look-up the method age
key = hash(world_age(), $precomp_key)
if haskey(agecache, key)
age = agecache[key]
else
age = method_age(f, $t)
agecache[key] = age
end

# generate a key for indexing the compilation cache
ctx = context()
key = hash(age, $precomp_key)
key = hash(pointer_from_objref(ctx), key) # contexts are unique, but handles might alias
key = hash(name, key)
key = hash(kwargs, key)
for nf in 1:nfields(f)
# mix in the values of any captured variable
key = hash(getfield(f, nf), key)
end

# compile the function
if !haskey(compilecache, key)
start = time_ns()

# compile to PTX
dev = device(ctx)
cap = supported_capability(dev)
asm, kernel_fn, undefined_fns =
compile(:ptx, cap, f, tt; name=name, strict=true, kwargs...)

# settings to JIT based on Julia's debug setting
jit_options = Dict{CUDAdrv.CUjit_option,Any}()
if Base.JLOptions().debug_level == 1
jit_options[CUDAdrv.JIT_GENERATE_LINE_INFO] = true
elseif Base.JLOptions().debug_level >= 2
jit_options[CUDAdrv.JIT_GENERATE_DEBUG_INFO] = true
end

# link the CUDA device library
image = asm
# linking the device runtime library requires use of the CUDA linker,
# which in turn switches compilation to device relocatable code (-rdc) mode.
#
# even when the kernel does not contain any calls that need -rdc (i.e., calls into
# the runtime library), that mode significantly hurts performance, so only invoke
# the linker when the kernel actually references a non-intrinsic function
intrinsic_fns = ["vprintf", "malloc", "free", "__assertfail",
"__nvvm_reflect" #= TODO: should have been optimized away =#]
if !isempty(setdiff(undefined_fns, intrinsic_fns))
@timeit_debug to "device runtime library" begin
linker = CUDAdrv.CuLink(jit_options)
CUDAdrv.add_file!(linker, libcudadevrt[], CUDAdrv.JIT_INPUT_LIBRARY)
CUDAdrv.add_data!(linker, kernel_fn, asm)
image = CUDAdrv.complete(linker)
end
end

# JIT into an executable kernel object
mod = CuModule(image, jit_options)
fun = CuFunction(mod, kernel_fn)
kernel = HostKernel{f,tt}(ctx, mod, fun)

create_exceptions!(mod)

compilecache[key] = kernel
stop = time_ns()
@debug begin
ver = version(kernel)
mem = memory(kernel)
reg = registers(kernel)
fn = something(name, nameof(f))
"""Compiled $fn($(join(tt.parameters, ", "))) to PTX $(ver.ptx) for SM $(ver.binary) in $(round((time_ns() - start) / 1000000; digits=2)) ms.
Kernel uses $reg registers, and $(Base.format_bytes(mem.local)) local, $(Base.format_bytes(mem.shared)) shared, and $(Base.format_bytes(mem.constant)) constant memory."""
end
end

return compilecache[key]::HostKernel{f,tt}
end
@generated function cufunction(f::Core.Function, tt::Type=Tuple{}; kwargs...)
# generated function that crafts a custom code info to call the actual cufunction impl.
# this gives us the flexibility to insert manual back edges for automatic recompilation.
tt = tt.parameters[1]

# get a hold of the method and code info of the kernel function
sig = Tuple{f, tt.parameters...}
mthds = _methods_by_ftype(sig, -1, typemax(UInt))
length(mthds) == 1 || return (:(throw(MethodError(f,tt))))
mtypes, msp, m = mthds[1]
mi = ccall(:jl_specializations_get_linfo, Ref{MethodInstance}, (Any, Any, Any), m, mtypes, msp)
mi.def.isva && return :(error("varargs kernel methods are not supported"))
ci = retrieve_code_info(mi)
@assert isa(ci, CodeInfo)

# generate a unique id to represent this specialization
global specialization_counter
id = (specialization_counter += 1)
# TODO: save the mi/ci here (or embed it in the AST to pass to cufunction)
# and use that to drive compilation

# prepare a new code info
new_ci = copy(ci)
empty!(new_ci.code)
empty!(new_ci.codelocs)
empty!(new_ci.linetable)
empty!(new_ci.ssaflags)
new_ci.edges = MethodInstance[mi]

# prepare the slots
new_ci.slotnames = Symbol[:kwfunc, :kwargs, Symbol("#self#"), :f, :tt]
new_ci.slotflags = UInt8[0x00 for i = 1:5]
kwargs = SlotNumber(2)
f = SlotNumber(4)
tt = SlotNumber(5)

# call the compiler
append!(new_ci.code, [Expr(:call, Core.kwfunc, cufunction_fast),
Expr(:call, merge, NamedTuple(), kwargs),
Expr(:call, SSAValue(1), SSAValue(2), cufunction_fast, f, tt, id),
Expr(:return, SSAValue(3))])
append!(new_ci.codelocs, [0, 0, 0, 0]) # one codeloc per statement
new_ci.ssavaluetypes = 4               # one SSA value per statement

return new_ci
end
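# For reference (assumption, not part of this commit): the CodeInfo synthesized
# above behaves roughly like the surface-level definition
#
#     @inline cufunction(f, tt=Tuple{}; kwargs...) =
#         cufunction_fast(f, tt, id; kwargs...)
#
# where `id` is the counter value minted at generation time; writing it as raw
# IR is what lets us attach `mi` as an explicit edge for invalidation.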

# https://github.com/JuliaLang/julia/issues/14919
20 changes: 20 additions & 0 deletions test/device/execution.jl
@@ -316,6 +316,26 @@ end
end


@testset "automatic recompilation (bis)" begin
arr = CuArray(zeros(Int))

@eval doit(ptr) = unsafe_store!(ptr, 1)

function kernel(ptr)
doit(ptr)
return
end

@cuda kernel(pointer(arr))
@test Array(arr)[] == 1

@eval doit(ptr) = unsafe_store!(ptr, 2)

@cuda kernel(pointer(arr))
@test Array(arr)[] == 2
end


@testset "non-isbits arguments" begin
function kernel1(T, i)
sink(i)
