This repository has been archived by the owner on May 27, 2021. It is now read-only.

Fix method invalidation #581

Merged
merged 3 commits on Mar 3, 2020
242 changes: 132 additions & 110 deletions src/execution.jl
@@ -61,23 +61,6 @@ function assign_args!(code, args)
return vars, var_exprs
end

# fast lookup of global world age
world_age() = ccall(:jl_get_tls_world_age, UInt, ())

# slow lookup of local method age
function method_age(f, t)::UInt
for m in Base._methods(f, t, 1, typemax(UInt))
if VERSION >= v"1.2.0-DEV.573"
return m[3].primary_world
else
return m[3].min_world
end
end

tt = Base.to_tuple_type(t)
throw(MethodError(f, tt))
end


## high-level @cuda interface

@@ -120,16 +103,11 @@ A device-side launch, aka. dynamic parallelism, is similar but more restricted:
"""
macro cuda(ex...)
# destructure the `@cuda` expression
if length(ex) > 0 && ex[1].head == :tuple
error("The tuple argument to @cuda has been replaced by keywords: `@cuda threads=... fun(args...)`")
end
call = ex[end]
kwargs = ex[1:end-1]

# destructure the kernel call
if call.head != :call
throw(ArgumentError("second argument to @cuda should be a function call"))
end
Meta.isexpr(call, :call) || throw(ArgumentError("second argument to @cuda should be a function call"))
f = call.args[1]
args = call.args[2:end]

@@ -334,8 +312,87 @@ end

## host-side API

const agecache = Dict{UInt, UInt}()
using Core.Compiler: retrieve_code_info, CodeInfo, MethodInstance, SSAValue, SlotNumber
using Base: _methods_by_ftype

# actual compilation
function cufunction_slow(f, tt, spec; name=nothing, kwargs...)
start = time_ns()

# compile to PTX
ctx = context()
dev = device(ctx)
cap = supported_capability(dev)
asm, kernel_fn, undefined_fns =
compile(:ptx, cap, f, tt; name=name, strict=true, kwargs...)

# settings to JIT based on Julia's debug setting
jit_options = Dict{CUDAdrv.CUjit_option,Any}()
if Base.JLOptions().debug_level == 1
jit_options[CUDAdrv.JIT_GENERATE_LINE_INFO] = true
elseif Base.JLOptions().debug_level >= 2
jit_options[CUDAdrv.JIT_GENERATE_DEBUG_INFO] = true
end

# link the CUDA device library
image = asm
# linking the device runtime library requires use of the CUDA linker,
# which in turn switches compilation to device relocatable code (-rdc) mode.
#
# even if not doing any actual calls that need -rdc (i.e., calls to the runtime
# library), this significantly hurts performance, so don't do it unconditionally
intrinsic_fns = ["vprintf", "malloc", "free", "__assertfail",
"__nvvm_reflect" #= TODO: should have been optimized away =#]
if !isempty(setdiff(undefined_fns, intrinsic_fns))
@timeit_debug to "device runtime library" begin
linker = CUDAdrv.CuLink(jit_options)
CUDAdrv.add_file!(linker, libcudadevrt[], CUDAdrv.JIT_INPUT_LIBRARY)
CUDAdrv.add_data!(linker, kernel_fn, asm)
image = CUDAdrv.complete(linker)
end
end

# JIT into an executable kernel object
mod = CuModule(image, jit_options)
fun = CuFunction(mod, kernel_fn)
kernel = HostKernel{f,tt}(ctx, mod, fun)

create_exceptions!(mod)

stop = time_ns()
@debug begin
ver = version(kernel)
mem = memory(kernel)
reg = registers(kernel)
fn = something(name, nameof(f))
"""Compiled $fn($(join(tt.parameters, ", "))) to PTX $(ver.ptx) for SM $(ver.binary) in $(round((time_ns() - start) / 1000000; digits=2)) ms.
Kernel uses $reg registers, and $(Base.format_bytes(mem.local)) local, $(Base.format_bytes(mem.shared)) shared, and $(Base.format_bytes(mem.constant)) constant memory."""
end

return kernel
end

# cached compilation
const compilecache = Dict{UInt, HostKernel}()
@inline function cufunction_fast(f, tt, spec; name=nothing, kwargs...)
# generate a key for indexing the compilation cache
ctx = context()
key = hash(spec)
key = hash(pointer_from_objref(ctx), key) # contexts are unique, but handles might alias
# TODO: implement this as a hash function in CUDAdrv
key = hash(name, key)
key = hash(kwargs, key)
for nf in 1:nfields(f)
# mix in the values of any captured variable
key = hash(getfield(f, nf), key)
end

return get!(compilecache, key) do
cufunction_slow(f, tt, spec; name=name, kwargs...)
end::HostKernel{f,tt}
end
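
(Illustrative note, not part of the diff: closures produced by the same definition share a type but can capture different values, which is why the loop above mixes every captured field into the cache key. A minimal sketch, with hypothetical names:)

```julia
# Two closures from the same definition: identical type, different captured value.
make_writer(x) = ptr -> (unsafe_store!(ptr, x); nothing)
f1 = make_writer(1)
f2 = make_writer(2)
@assert typeof(f1) === typeof(f2)            # same function type
@assert getfield(f1, 1) !== getfield(f2, 1)  # but the captured `x` differs
# hashing each `getfield(f, nf)` keeps f1 and f2 from sharing a compilecache entry
```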

specialization_counter = 0

"""
cufunction(f, tt=Tuple{}; kwargs...)
@@ -356,92 +413,57 @@ The output of this function is automatically cached, i.e. you can simply call `cufunction`
in a hot path without degrading performance. New code will be generated automatically when
the function changes, or when different types or keyword arguments are provided.
"""
@generated function cufunction(f::Core.Function, tt::Type=Tuple{}; name=nothing, kwargs...)
tt = Base.to_tuple_type(tt.parameters[1])
sig = Base.signature_type(f, tt)
t = Tuple(tt.parameters)

precomp_key = hash(sig) # precomputable part of the keys
quote
Base.@_inline_meta

# look-up the method age
key = hash(world_age(), $precomp_key)
if haskey(agecache, key)
age = agecache[key]
else
age = method_age(f, $t)
agecache[key] = age
end

# generate a key for indexing the compilation cache
ctx = context()
key = hash(age, $precomp_key)
key = hash(pointer_from_objref(ctx), key) # contexts are unique, but handles might alias
key = hash(name, key)
key = hash(kwargs, key)
for nf in 1:nfields(f)
# mix in the values of any captured variable
key = hash(getfield(f, nf), key)
end

# compile the function
if !haskey(compilecache, key)
start = time_ns()

# compile to PTX
dev = device(ctx)
cap = supported_capability(dev)
asm, kernel_fn, undefined_fns =
compile(:ptx, cap, f, tt; name=name, strict=true, kwargs...)

# settings to JIT based on Julia's debug setting
jit_options = Dict{CUDAdrv.CUjit_option,Any}()
if Base.JLOptions().debug_level == 1
jit_options[CUDAdrv.JIT_GENERATE_LINE_INFO] = true
elseif Base.JLOptions().debug_level >= 2
jit_options[CUDAdrv.JIT_GENERATE_DEBUG_INFO] = true
end

# link the CUDA device library
image = asm
# linking the device runtime library requires use of the CUDA linker,
# which in turn switches compilation to device relocatable code (-rdc) mode.
#
# even if not doing any actual calls that need -rdc (i.e., calls to the runtime
# library), this significantly hurts performance, so don't do it unconditionally
intrinsic_fns = ["vprintf", "malloc", "free", "__assertfail",
"__nvvm_reflect" #= TODO: should have been optimized away =#]
if !isempty(setdiff(undefined_fns, intrinsic_fns))
@timeit_debug to "device runtime library" begin
linker = CUDAdrv.CuLink(jit_options)
CUDAdrv.add_file!(linker, libcudadevrt[], CUDAdrv.JIT_INPUT_LIBRARY)
CUDAdrv.add_data!(linker, kernel_fn, asm)
image = CUDAdrv.complete(linker)
end
end

# JIT into an executable kernel object
mod = CuModule(image, jit_options)
fun = CuFunction(mod, kernel_fn)
kernel = HostKernel{f,tt}(ctx, mod, fun)

create_exceptions!(mod)

compilecache[key] = kernel
stop = time_ns()
@debug begin
ver = version(kernel)
mem = memory(kernel)
reg = registers(kernel)
fn = something(name, nameof(f))
"""Compiled $fn($(join(tt.parameters, ", "))) to PTX $(ver.ptx) for SM $(ver.binary) in $(round((time_ns() - start) / 1000000; digits=2)) ms.
Kernel uses $reg registers, and $(Base.format_bytes(mem.local)) local, $(Base.format_bytes(mem.shared)) shared, and $(Base.format_bytes(mem.constant)) constant memory."""
end
end

return compilecache[key]::HostKernel{f,tt}
end
@generated function cufunction(f::Core.Function, tt::Type=Tuple{}; kwargs...)
# generated function that crafts a custom code info to call the actual cufunction impl.
# this gives us the flexibility to insert manual back edges for automatic recompilation.
tt = tt.parameters[1]

# get a hold of the method and code info of the kernel function
sig = Tuple{f, tt.parameters...}
mthds = _methods_by_ftype(sig, -1, typemax(UInt))
Base.isdispatchtuple(tt) || return(:(error("$tt is not a dispatch tuple")))
length(mthds) == 1 || return (:(throw(MethodError(f,tt))))
mtypes, msp, m = mthds[1]
mi = ccall(:jl_specializations_get_linfo, Ref{MethodInstance}, (Any, Any, Any), m, mtypes, msp)
mi.def.isva && return :(error("varargs kernel methods are not supported"))
ci = retrieve_code_info(mi)
@assert isa(ci, CodeInfo)

# generate a unique id to represent this specialization
global specialization_counter
id = (specialization_counter += 1)
# TODO: save the mi/ci here (or embed it in the AST to pass to cufunction)
# and use that to drive compilation

# prepare a new code info
new_ci = copy(ci)
empty!(new_ci.code)
empty!(new_ci.codelocs)
resize!(new_ci.linetable, 1) # codegen assumes at least one entry
empty!(new_ci.ssaflags)
new_ci.ssavaluetypes = 0
new_ci.edges = MethodInstance[mi]
# XXX: setting this edge does not give us proper method invalidation, see
# JuliaLang/julia#34962 which demonstrates we also need to "call" the kernel.
# invoking `code_llvm` also does the necessary codegen, as does calling the
# underlying C methods -- which CUDAnative does, so everything Just Works.

# prepare the slots
new_ci.slotnames = Symbol[:kwfunc, :kwargs, Symbol("#self#"), :f, :tt]
new_ci.slotflags = UInt8[0x00 for i = 1:5]
kwargs = SlotNumber(2)
f = SlotNumber(4)
tt = SlotNumber(5)

# call the compiler
append!(new_ci.code, [Expr(:call, Core.kwfunc, cufunction_fast),
Expr(:call, merge, NamedTuple(), kwargs),
Expr(:call, SSAValue(1), SSAValue(2), cufunction_fast, f, tt, id),
Expr(:return, SSAValue(3))])
append!(new_ci.codelocs, [0, 0, 0, 0])
new_ci.ssavaluetypes += 4

return new_ci
end
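
(Illustrative usage, not part of the diff: the point of the manual back edge above is that redefining a device-side callee invalidates the cached kernel, so the next `@cuda` launch recompiles. A minimal host-side sketch, assuming a working CUDA setup with CUDAnative and CuArrays, mirroring the new test below:)

```julia
using CUDAnative, CuArrays

child() = 1
kernel(ptr) = (unsafe_store!(ptr, child()); return)

a = CuArray(zeros(Int))
@cuda kernel(pointer(a))      # first launch compiles and caches the kernel
@assert Array(a)[] == 1

child() = 2                   # redefine the callee
@cuda kernel(pointer(a))      # the back edge forces recompilation on relaunch
@assert Array(a)[] == 2
```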

# https://github.com/JuliaLang/julia/issues/14919
20 changes: 20 additions & 0 deletions test/device/execution.jl
@@ -316,6 +316,26 @@ end
end


@testset "automatic recompilation (bis)" begin
arr = CuArray(zeros(Int))

@eval doit(ptr) = unsafe_store!(ptr, 1)

function kernel(ptr)
doit(ptr)
return
end

@cuda kernel(pointer(arr))
@test Array(arr)[] == 1

@eval doit(ptr) = unsafe_store!(ptr, 2)

@cuda kernel(pointer(arr))
@test Array(arr)[] == 2
end


@testset "non-isbits arguments" begin
function kernel1(T, i)
sink(i)