-
Notifications
You must be signed in to change notification settings - Fork 54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
StaticArrays-related regression in 1.9 #366
Comments
Actually no, the above only reproduced in my running session, with a clean session it requires using the CUDA.jl compiler target:

function main()
source = FunctionSpec(typeof(f), Tuple{Int, Int})
target = CUDA.CUDACompilerTarget(device())
params = CUDA.CUDACompilerParams()
job = CompilerJob(target, source, params)
JuliaContext() do ctx
ir, ir_meta = GPUCompiler.compile(:llvm, job; ctx, strip=true)
println(ir)
end
end

So yeah, that kind of bug...
Apparently not a cache bug -- it must have been primed with a bad result in my active session, as CUDA.jl uses the same CI cache as the PTXCompilerTarget. Rather, the issue seems to be with an
Disabling these fixes the issue. I'm not sure it's very important we use these, but it does seem to signify a deeper issue...
Without CUDA.jl:

using GPUCompiler
using StaticArrays
# Minimal GPU runtime stub: GPUCompiler requires a runtime module exposing
# these hooks, but for plain IR generation none of them has to do real work.
module TestRuntime
    # dummy methods; each mirrors the signature GPUCompiler expects
    function signal_exception()
        return
    end
    function malloc(sz)
        return C_NULL
    end
    function report_oom(sz)
        return
    end
    function report_exception(ex)
        return
    end
    function report_exception_name(ex)
        return
    end
    function report_exception_frame(idx, func, file, line)
        return
    end
end
# Marker parameter type: selects this reproducer's compiler configuration via dispatch.
struct TestCompilerParams <: AbstractCompilerParams end
# Point GPUCompiler at the stub runtime module above for jobs using these params.
GPUCompiler.runtime_module(::CompilerJob{<:Any,TestCompilerParams}) = TestRuntime
# Kernel under test: the floor./max./* broadcast chain over two small static
# vectors is what regresses on Julia 1.9. Returns nothing.
function f(width, height)
    origin = SVector{2, Float32}(0.5f0, 0.5f0)
    resolution = SVector{2, UInt32}(width, height)
    floor.(UInt32, max.(0f0, origin) .* resolution)
    return
end
# this breaks stuff: route inference through an overlay method table that
# replaces Base.isnan(::Float32) with a libdevice intrinsic call.
# NOTE: the original pasted snippet defined `GPUCompiler.method_table` twice
# with identical bodies; the redundant second definition is dropped here.
using Base.Experimental: @overlay, @MethodTable
@MethodTable(method_table)
GPUCompiler.method_table(@nospecialize(job::CompilerJob)) = method_table
@overlay method_table Base.isnan(x::Float32) = (ccall("extern __nv_isnanf", llvmcall, Int32, (Cfloat,), x)) != 0
# Drive GPUCompiler end-to-end for `f(Int, Int)` on the native target and
# print the generated LLVM IR.
function main()
source = FunctionSpec(typeof(f), Tuple{Int, Int})
target = NativeCompilerTarget()
params = TestCompilerParams()
job = CompilerJob(target, source, params)
JuliaContext() do ctx
# strip=true drops debug metadata so the printed IR stays readable
ir, ir_meta = GPUCompiler.compile(:llvm, job; ctx, strip=true)
println(ir)
end
end
isinteractive() || main()
Without GPUCompiler.jl:

using LLVM
using Core.Compiler: CodeInstance, MethodInstance, InferenceParams, OptimizationParams, AbstractInterpreter, InferenceResult, InferenceState, OverlayMethodTable, WorldView
# Out-of-tree code cache: maps each MethodInstance to every CodeInstance ever
# inferred for it (entries for multiple world ranges may coexist).
struct CodeCache
dict::Dict{MethodInstance,Vector{CodeInstance}}
CodeCache() = new(Dict{MethodInstance,Vector{CodeInstance}}())
end
# Record a freshly-inferred CodeInstance for `mi`, creating the per-instance
# entry vector on first use.
function Core.Compiler.setindex!(cache::CodeCache, ci::CodeInstance, mi::MethodInstance)
    return push!(get!(() -> CodeInstance[], cache.dict, mi), ci)
end
# AbstractInterpreter that infers against an overlay method table and stores
# results in our own CodeCache rather than the runtime's global cache.
struct CustomInterpreter <: AbstractInterpreter
global_cache::CodeCache
method_table::Union{Nothing,Core.MethodTable}
# Cache of inference results for this particular interpreter
local_cache::Vector{InferenceResult}
# The world age we're working inside of
world::UInt
# Parameters for inference and optimization
inf_params::InferenceParams
opt_params::OptimizationParams
function CustomInterpreter(cache::CodeCache, mt::Union{Nothing,Core.MethodTable}, world::UInt, ip::InferenceParams, op::OptimizationParams)
# inferring in a not-yet-existing world would make the results unusable
@assert world <= Base.get_world_counter()
return new(
cache,
mt,
# Initially empty cache
Vector{InferenceResult}(),
# world age counter
world,
# parameters for inference and optimization
ip,
op
)
end
end
## Core.Compiler interface for CustomInterpreter

# Parameters and world age are simply forwarded from the interpreter's fields.
Core.Compiler.InferenceParams(itp::CustomInterpreter) = itp.inf_params
Core.Compiler.OptimizationParams(itp::CustomInterpreter) = itp.opt_params
Core.Compiler.get_world_counter(itp::CustomInterpreter) = itp.world

# Per-interpreter local cache, and the global cache viewed at our world age.
Core.Compiler.get_inference_cache(itp::CustomInterpreter) = itp.local_cache
Core.Compiler.code_cache(itp::CustomInterpreter) = WorldView(itp.global_cache, itp.world)

# No need to do any locking since we're not putting our results into the runtime cache
Core.Compiler.lock_mi_inference(itp::CustomInterpreter, mi::MethodInstance) = nothing
Core.Compiler.unlock_mi_inference(itp::CustomInterpreter, mi::MethodInstance) = nothing

# Run the standard optimization pipeline; allow compressed/discarded source.
Core.Compiler.may_optimize(itp::CustomInterpreter) = true
Core.Compiler.may_compress(itp::CustomInterpreter) = true
Core.Compiler.may_discard_trees(itp::CustomInterpreter) = true
Core.Compiler.verbose_stmt_info(itp::CustomInterpreter) = false
# The overlay-method-table hook changed signature across Julia versions:
# newer releases pass only the interpreter, older ones also an InferenceState.
if v"1.8-beta2" <= VERSION < v"1.9-" || VERSION >= v"1.9.0-DEV.120"
Core.Compiler.method_table(interp::CustomInterpreter) =
OverlayMethodTable(interp.world, interp.method_table)
else
Core.Compiler.method_table(interp::CustomInterpreter, sv::InferenceState) =
OverlayMethodTable(interp.world, interp.method_table)
end
## world view of the cache

# A MethodInstance is present when the cache holds at least one CodeInstance
# valid over the view's entire world range.
function Core.Compiler.haskey(wvc::WorldView{CodeCache}, mi::MethodInstance)
    return !isnothing(Core.Compiler.get(wvc, mi, nothing))
end
# Look up a CodeInstance for `mi` that is valid over the view's entire world
# range; return `default` when none matches.
# Fix: the original computed a decompressed `src` via a jl_uncompress_ir ccall
# and then discarded it — dead work removed; the intent is kept as a TODO.
function Core.Compiler.get(wvc::WorldView{CodeCache}, mi::MethodInstance, default)
    # check the cache
    for ci in get!(wvc.cache.dict, mi, CodeInstance[])
        if ci.min_world <= wvc.worlds.min_world && wvc.worlds.max_world <= ci.max_world
            # TODO: also verify the cached source is usable, i.e.
            # if (code && (code == jl_nothing || jl_ir_flag_inferred((jl_array_t*)code)))
            return ci
        end
    end
    return default
end
# Indexing variant of `get`: throw a KeyError when no valid entry exists.
function Core.Compiler.getindex(wvc::WorldView{CodeCache}, mi::MethodInstance)
    entry = Core.Compiler.get(wvc, mi, nothing)
    if entry === nothing
        throw(KeyError(mi))
    end
    return entry::CodeInstance
end
# Store a CodeInstance through the world view by delegating to the underlying
# cache (the validity range lives on the CodeInstance itself).
# Fix: the original decompressed `ci.inferred` into an unused local `src`
# before delegating — dead work removed.
function Core.Compiler.setindex!(wvc::WorldView{CodeCache}, ci::CodeInstance, mi::MethodInstance)
    Core.Compiler.setindex!(wvc.cache, ci, mi)
end
## codegen/inference integration

# Run type inference on `mi` and ensure the resulting CodeInstance carries
# inferred source, so codegen never has to re-enter inference.
function ci_cache_populate(interp, cache, mt, mi, min_world, max_world)
src = Core.Compiler.typeinf_ext_toplevel(interp, mi)
# inference populates the cache, so we don't need to jl_get_method_inferred
wvc = WorldView(cache, min_world, max_world)
@assert Core.Compiler.haskey(wvc, mi)
# if src is rettyp_const, the codeinfo won't cache ci.inferred
# (because it is normally not supposed to be used ever again).
# to avoid the need to re-infer, set that field here.
ci = Core.Compiler.getindex(wvc, mi)
if ci !== nothing && ci.inferred === nothing
@static if VERSION >= v"1.9.0-DEV.1115"
# ci.inferred became an atomic field in this dev build
@atomic ci.inferred = src
else
ci.inferred = src
end
end
return ci
end
# Return the cached CodeInstance for `mi`, or `nothing` when there is no
# entry that codegen could consume.
function ci_cache_lookup(cache, mi, min_world, max_world)
    entry = Core.Compiler.get(WorldView(cache, min_world, max_world), mi, nothing)
    if entry === nothing || entry.inferred === nothing
        # if for some reason we did end up with a codeinfo without inferred source, e.g.,
        # because of calling `Base.return_types` which only sets rettyp, pretend we didn't
        # run inference so that we re-infer now and not during codegen (which is disallowed)
        return nothing
    end
    return entry
end
# for platforms without @cfunction-with-closure support:
# the cache and the method-instance log cannot be captured in a closure there,
# so they are smuggled through these globals instead
const _method_instances = Ref{Any}()
const _cache = Ref{Any}()
function _lookup_fun(mi, min_world, max_world)
push!(_method_instances[], mi)
ci_cache_lookup(_cache[], mi, min_world, max_world)
end
# Which LLVM context type Julia codegen expects: 1.9+ uses ORC thread-safe
# contexts, earlier versions use plain contexts.
if VERSION >= v"1.9.0-DEV.516"
const JuliaContextType = ThreadSafeContext
else
const JuliaContextType = Context
end
# Obtain a context suitable for Julia codegen.
function JuliaContext()
if VERSION >= v"1.9.0-DEV.115"
# Julia 1.9 knows how to deal with arbitrary contexts
JuliaContextType()
else
# earlier versions of Julia claim so, but actually use a global context;
# recover it by asking codegen to lower a type and taking that type's context
isboxed_ref = Ref{Bool}()
typ = LLVMType(ccall(:jl_type_to_llvm, LLVM.API.LLVMTypeRef,
(Any, Ptr{Bool}), Any, isboxed_ref))
context(typ)
end
end
# do-block form: run `f` with a freshly-obtained Julia context.
function JuliaContext(f)
if VERSION >= v"1.9.0-DEV.115"
JuliaContextType(f)
else
f(JuliaContext())
# we cannot dispose of the global unique context
end
end
# Extract the raw LLVM Context from whichever wrapper type is in use.
if VERSION >= v"1.9.0-DEV.516"
unwrap_context(ctx::ThreadSafeContext) = context(ctx)
end
unwrap_context(ctx::Context) = ctx
####################
# this breaks stuff: an overlay method table that replaces
# Base.isnan(::Float32) with a call to the CUDA libdevice intrinsic
using Base.Experimental: @overlay, @MethodTable
@MethodTable(method_table)
@overlay method_table Base.isnan(x::Float32) = (ccall("extern __nv_isnanf", llvmcall, Int32, (Cfloat,), x)) != 0
#########
using StaticArrays
# Kernel under test: broadcast chain over two small static vectors that
# regressed on Julia 1.9. Returns nothing.
function kernel(width, height)
    origin = SVector{2, Float32}(0.5f0, 0.5f0)
    extent = SVector{2, UInt32}(width, height)
    floor.(UInt32, max.(0f0, origin) .* extent)
    return
end
# End-to-end reproducer: infer `kernel(Int, Int)` through CustomInterpreter
# (with the overlay method table), hand the populated cache to jl_create_native
# via the `lookup` hook, and print the resulting LLVM module.
function main()
f = typeof(kernel)
tt = Tuple{Int, Int}
world = Base.get_world_counter()
JuliaContext() do ctx
# get the method instance
u = Base.unwrap_unionall(tt)
sig = Base.rewrap_unionall(Tuple{f, u.parameters...}, tt)
meth = which(sig)
(ti, env) = ccall(:jl_type_intersection_with_env, Any,
(Any, Any), sig, meth.sig)::Core.SimpleVector
meth = Base.func_for_method_checked(meth, ti, env)
method_instance = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance},
(Any, Any, Any, UInt), meth, ti, env, world)
# populate the cache
cache = CodeCache()
mt = method_table
interp = CustomInterpreter(cache, method_table, Base.get_world_counter(), InferenceParams(;unoptimize_throw_blocks=false), OptimizationParams())
# only run inference when the cache doesn't already hold a usable entry
if ci_cache_lookup(cache, method_instance, world, typemax(Cint)) === nothing
ci_cache_populate(interp, cache, mt, method_instance, world, typemax(Cint))
end
# create a callback to look-up function in our cache,
# and keep track of the method instances we needed.
method_instances = []
if Sys.ARCH == :x86 || Sys.ARCH == :x86_64
# closures through @cfunction are only supported on x86(-64)
function lookup_fun(mi, min_world, max_world)
push!(method_instances, mi)
ci_cache_lookup(cache, mi, min_world, max_world)
end
lookup_cb = @cfunction($lookup_fun, Any, (Any, UInt, UInt))
else
# other architectures fall back to the global-based _lookup_fun
_cache[] = cache
_method_instances[] = method_instances
lookup_cb = @cfunction(_lookup_fun, Any, (Any, UInt, UInt))
end
# set-up the compiler interface
params = Base.CodegenParams(;
track_allocations = false,
code_coverage = false,
prefer_specsig = true,
gnu_pubnames = false,
lookup = Base.unsafe_convert(Ptr{Nothing}, lookup_cb))
# generate IR
GC.@preserve lookup_cb begin
# jl_create_native's signature changed repeatedly across 1.8/1.9 dev builds
native_code = if VERSION >= v"1.9.0-DEV.516"
mod = LLVM.Module("start"; ctx=unwrap_context(ctx))
# configure the module
flags(mod)["Dwarf Version", LLVM.API.LLVMModuleFlagBehaviorWarning] =
Metadata(ConstantInt(4; ctx=unwrap_context(ctx)))
flags(mod)["Debug Info Version", LLVM.API.LLVMModuleFlagBehaviorWarning] =
Metadata(ConstantInt(DEBUG_METADATA_VERSION(); ctx=unwrap_context(ctx)))
ts_mod = ThreadSafeModule(mod; ctx)
ccall(:jl_create_native, Ptr{Cvoid},
(Vector{MethodInstance}, LLVM.API.LLVMOrcThreadSafeModuleRef, Ptr{Base.CodegenParams}, Cint),
[method_instance], ts_mod, Ref(params), #=extern policy=# 1)
elseif VERSION >= v"1.9.0-DEV.115"
ccall(:jl_create_native, Ptr{Cvoid},
(Vector{MethodInstance}, LLVM.API.LLVMContextRef, Ptr{Base.CodegenParams}, Cint),
[method_instance], ctx, Ref(params), #=extern policy=# 1)
elseif VERSION >= v"1.8.0-DEV.661"
@assert ctx == JuliaContext()
ccall(:jl_create_native, Ptr{Cvoid},
(Vector{MethodInstance}, Ptr{Base.CodegenParams}, Cint),
[method_instance], Ref(params), #=extern policy=# 1)
else
@assert ctx == JuliaContext()
ccall(:jl_create_native, Ptr{Cvoid},
(Vector{MethodInstance}, Base.CodegenParams, Cint),
[method_instance], params, #=extern policy=# 1)
end
@assert native_code != C_NULL
llvm_mod_ref = if VERSION >= v"1.9.0-DEV.516"
ccall(:jl_get_llvm_module, LLVM.API.LLVMOrcThreadSafeModuleRef,
(Ptr{Cvoid},), native_code)
else
ccall(:jl_get_llvm_module, LLVM.API.LLVMModuleRef,
(Ptr{Cvoid},), native_code)
end
@assert llvm_mod_ref != C_NULL
if VERSION >= v"1.9.0-DEV.516"
# unwrap the thread-safe module to reach the plain LLVM module
llvm_ts_mod = LLVM.ThreadSafeModule(llvm_mod_ref)
llvm_mod = nothing
llvm_ts_mod() do mod
llvm_mod = mod
end
else
llvm_mod = LLVM.Module(llvm_mod_ref)
end
println(llvm_mod)
end
end
end
isinteractive() || main()
Bisected to:
I'll reduce the MWE further and try to figure out what's up / file this upstream.
"Minimal" reproducer:

using LLVM
using Core.Compiler: CodeInstance, MethodInstance, InferenceParams, OptimizationParams,
AbstractInterpreter, InferenceResult, InferenceState,
OverlayMethodTable, WorldView
## cache

# Out-of-tree code cache: MethodInstance -> all CodeInstances inferred for it.
struct CodeCache
dict::Dict{MethodInstance,Vector{CodeInstance}}
CodeCache() = new(Dict{MethodInstance,Vector{CodeInstance}}())
end
# Append `ci` to the list of CodeInstances cached for `mi`.
function Core.Compiler.setindex!(cache::CodeCache, ci::CodeInstance, mi::MethodInstance)
    entries = get!(() -> CodeInstance[], cache.dict, mi)
    return push!(entries, ci)
end
## interpreter

# Minimal AbstractInterpreter: overlay method table plus out-of-tree caching.
struct CustomInterpreter <: AbstractInterpreter
global_cache::CodeCache
# per-interpreter inference results
local_cache::Vector{InferenceResult}
method_table::Union{Nothing,Core.MethodTable}
# world age to infer in
world::UInt
CustomInterpreter(cache::CodeCache, mt::Union{Nothing,Core.MethodTable}, world::UInt) =
new(cache, Vector{InferenceResult}(), mt, world)
end
## Core.Compiler interface: default parameters, field forwarding for the
## world age and caches, and standard optimization settings.
Core.Compiler.InferenceParams(itp::CustomInterpreter) = InferenceParams()
Core.Compiler.OptimizationParams(itp::CustomInterpreter) = OptimizationParams()
Core.Compiler.get_world_counter(itp::CustomInterpreter) = itp.world
Core.Compiler.get_inference_cache(itp::CustomInterpreter) = itp.local_cache
Core.Compiler.code_cache(itp::CustomInterpreter) = WorldView(itp.global_cache, itp.world)
# no locking: results never enter the runtime's global cache
Core.Compiler.lock_mi_inference(itp::CustomInterpreter, mi::MethodInstance) = nothing
Core.Compiler.unlock_mi_inference(itp::CustomInterpreter, mi::MethodInstance) = nothing
Core.Compiler.may_optimize(itp::CustomInterpreter) = true
Core.Compiler.may_compress(itp::CustomInterpreter) = true
Core.Compiler.may_discard_trees(itp::CustomInterpreter) = true
Core.Compiler.verbose_stmt_info(itp::CustomInterpreter) = false
# Version-dependent signature of the overlay-method-table hook (see above).
if v"1.8-beta2" <= VERSION < v"1.9-" || VERSION >= v"1.9.0-DEV.120"
Core.Compiler.method_table(interp::CustomInterpreter) =
OverlayMethodTable(interp.world, interp.method_table)
else
Core.Compiler.method_table(interp::CustomInterpreter, sv::InferenceState) =
OverlayMethodTable(interp.world, interp.method_table)
end
## world view of the cache

# Presence check: true when a CodeInstance valid for the view's worlds exists.
function Core.Compiler.haskey(wvc::WorldView{CodeCache}, mi::MethodInstance)
    return !isnothing(Core.Compiler.get(wvc, mi, nothing))
end
# Look up a CodeInstance for `mi` valid over the view's world range, or `default`.
# Fix: removed the dead `src` computation (a jl_uncompress_ir ccall whose
# result was never used); the intended future check is kept as a TODO.
function Core.Compiler.get(wvc::WorldView{CodeCache}, mi::MethodInstance, default)
    # check the cache
    for ci in get!(wvc.cache.dict, mi, CodeInstance[])
        if ci.min_world <= wvc.worlds.min_world && wvc.worlds.max_world <= ci.max_world
            # TODO: also verify the cached source is usable, i.e.
            # if (code && (code == jl_nothing || jl_ir_flag_inferred((jl_array_t*)code)))
            return ci
        end
    end
    return default
end
# Like `get`, but throw a KeyError when nothing valid is cached.
function Core.Compiler.getindex(wvc::WorldView{CodeCache}, mi::MethodInstance)
    entry = Core.Compiler.get(wvc, mi, nothing)
    entry === nothing && throw(KeyError(mi))
    return entry::CodeInstance
end
# Store `ci` through the world view by delegating to the underlying cache.
# Fix: the original decompressed `ci.inferred` into an unused local `src`
# before delegating — dead work removed.
function Core.Compiler.setindex!(wvc::WorldView{CodeCache}, ci::CodeInstance, mi::MethodInstance)
    Core.Compiler.setindex!(wvc.cache, ci, mi)
end
## codegen/inference integration

# Run type inference on `mi` and make sure the resulting CodeInstance carries
# inferred source, so codegen never has to re-enter inference.
function ci_cache_populate(interp, cache, mt, mi, min_world, max_world)
src = Core.Compiler.typeinf_ext_toplevel(interp, mi)
# inference populates the cache, so we don't need to jl_get_method_inferred
wvc = WorldView(cache, min_world, max_world)
# if src is rettyp_const, the codeinfo won't cache ci.inferred
# (because it is normally not supposed to be used ever again).
# to avoid the need to re-infer, set that field here.
ci = Core.Compiler.getindex(wvc, mi)
if ci !== nothing && ci.inferred === nothing
@static if VERSION >= v"1.9.0-DEV.1115"
# ci.inferred became an atomic field in this dev build
@atomic ci.inferred = src
else
ci.inferred = src
end
end
return ci
end
# Fetch the CodeInstance codegen should use, or `nothing` to force re-inference.
function ci_cache_lookup(cache, mi, min_world, max_world)
    entry = Core.Compiler.get(WorldView(cache, min_world, max_world), mi, nothing)
    # a CodeInstance without inferred source (e.g. produced via
    # `Base.return_types`, which only sets the return type) is useless to
    # codegen, so pretend it's absent and re-infer now rather than during
    # codegen (which is disallowed)
    (entry === nothing || entry.inferred === nothing) && return nothing
    return entry
end
## LLVM context handling

# 1.9+ codegen expects ORC thread-safe contexts; earlier versions plain ones.
if VERSION >= v"1.9.0-DEV.516"
const JuliaContextType = ThreadSafeContext
else
const JuliaContextType = Context
end
# Obtain a context suitable for Julia codegen.
function JuliaContext()
if VERSION >= v"1.9.0-DEV.115"
# Julia 1.9 knows how to deal with arbitrary contexts
JuliaContextType()
else
# earlier versions of Julia claim so, but actually use a global context;
# recover it by asking codegen to lower a type and taking its context
isboxed_ref = Ref{Bool}()
typ = LLVMType(ccall(:jl_type_to_llvm, LLVM.API.LLVMTypeRef,
(Any, Ptr{Bool}), Any, isboxed_ref))
context(typ)
end
end
# do-block form: run `f` with a Julia codegen context.
function JuliaContext(f)
if VERSION >= v"1.9.0-DEV.115"
JuliaContextType(f)
else
f(JuliaContext())
# we cannot dispose of the global unique context
end
end
# Extract the raw LLVM Context from whichever wrapper type is in use.
if VERSION >= v"1.9.0-DEV.516"
unwrap_context(ctx::ThreadSafeContext) = context(ctx)
end
unwrap_context(ctx::Context) = ctx
## main

# Reproducer driver: resolve the MethodInstance for `f(tt)` at `world`, infer
# it with CustomInterpreter (overlay table), then feed the cache to
# jl_create_native through the `lookup` hook and print the LLVM module.
function irgen(f, tt, world)
JuliaContext() do ctx
# get the method instance
u = Base.unwrap_unionall(tt)
sig = Base.rewrap_unionall(Tuple{f, u.parameters...}, tt)
meth = which(sig)
(ti, env) = ccall(:jl_type_intersection_with_env, Any,(Any, Any), sig, meth.sig)
meth = Base.func_for_method_checked(meth, ti, env)
mi = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance},
(Any, Any, Any, UInt), meth, ti, env, world)
# populate the cache
cache = CodeCache()
interp = CustomInterpreter(cache, method_table, Base.get_world_counter())
ci_cache_populate(interp, cache, method_table, mi, world, typemax(Cint))
# set-up the compiler interface
# (codegen calls back into Julia through `lookup` to find inferred code)
function lookup_fun(mi, min_world, max_world)
ci_cache_lookup(cache, mi, min_world, max_world)
end
lookup_cb = @cfunction($lookup_fun, Any, (Any, UInt, UInt))
params = Base.CodegenParams(;lookup = Base.unsafe_convert(Ptr{Nothing}, lookup_cb))
GC.@preserve lookup_cb begin
# generate inferred
# (jl_create_native's signature differs across Julia versions)
native_code = if VERSION >= v"1.9.0-DEV.516"
mod = LLVM.Module("start"; ctx=unwrap_context(ctx))
ts_mod = ThreadSafeModule(mod; ctx)
ccall(:jl_create_native, Ptr{Cvoid},
(Vector{MethodInstance}, LLVM.API.LLVMOrcThreadSafeModuleRef, Ptr{Base.CodegenParams}, Cint),
[mi], ts_mod, Ref(params), #=extern policy=# 1)
elseif VERSION >= v"1.9.0-DEV.115"
ccall(:jl_create_native, Ptr{Cvoid},
(Vector{MethodInstance}, LLVM.API.LLVMContextRef, Ptr{Base.CodegenParams}, Cint),
[mi], ctx, Ref(params), #=extern policy=# 1)
elseif VERSION >= v"1.8.0-DEV.661"
ccall(:jl_create_native, Ptr{Cvoid},
(Vector{MethodInstance}, Ptr{Base.CodegenParams}, Cint),
[mi], Ref(params), #=extern policy=# 1)
else
ccall(:jl_create_native, Ptr{Cvoid},
(Vector{MethodInstance}, Base.CodegenParams, Cint),
[mi], params, #=extern policy=# 1)
end
# get the module
llvm_mod_ref = if VERSION >= v"1.9.0-DEV.516"
ccall(:jl_get_llvm_module, LLVM.API.LLVMOrcThreadSafeModuleRef,
(Ptr{Cvoid},), native_code)
else
ccall(:jl_get_llvm_module, LLVM.API.LLVMModuleRef,
(Ptr{Cvoid},), native_code)
end
# display its IR
if VERSION >= v"1.9.0-DEV.516"
llvm_ts_mod = LLVM.ThreadSafeModule(llvm_mod_ref)
llvm_mod = nothing
llvm_ts_mod() do mod
llvm_mod = mod
end
else
llvm_mod = LLVM.Module(llvm_mod_ref)
end
println(llvm_mod)
end
end
end
# this breaks stuff: an overlay method table replacing Base.isnan(::Float32)
# with a CUDA libdevice intrinsic call
using Base.Experimental: @overlay, @MethodTable
@MethodTable(method_table)
@overlay method_table Base.isnan(x::Float32) =
ccall("extern __nv_isnanf", llvmcall, Int32, (Cfloat,), x) != 0
# simple kernel exercising the broadcast chain that regressed on 1.9
using StaticArrays
function kernel(width, height)
    origin = SVector{2, Float32}(0.5f0, 0.5f0)
    extent = SVector{2, UInt32}(width, height)
    floor.(UInt32, max.(0f0, origin) .* extent)
    return
end
# Compile the kernel for (Int, Int) in the current world age.
f = typeof(kernel)
tt = Tuple{Int, Int}
world = Base.get_world_counter()
@time irgen(f, tt, world)
Filed upstream as JuliaLang/julia#47349
Seems like some StaticArrays/broadcast-related code now does a dynamic call on 1.9, but only using the PTX target:
The text was updated successfully, but these errors were encountered: