StaticArrays-related regression in 1.9 #366

Closed
maleadt opened this issue Oct 27, 2022 · 7 comments · Fixed by #369
Labels
bug Something isn't working

maleadt commented Oct 27, 2022

Seems like some StaticArrays/broadcast-related code now does a dynamic call on 1.9, but only when using the PTX target:

using GPUCompiler
using StaticArrays

module TestRuntime
    # dummy methods
    signal_exception() = return
    malloc(sz) = C_NULL
    report_oom(sz) = return
    report_exception(ex) = return
    report_exception_name(ex) = return
    report_exception_frame(idx, func, file, line) = return
end

struct TestCompilerParams <: AbstractCompilerParams end
GPUCompiler.runtime_module(::CompilerJob{<:Any,TestCompilerParams}) = TestRuntime

function f(width, height)
    xy = SVector{2, Float32}(0.5f0, 0.5f0)
    res = SVector{2, UInt32}(width, height)
    floor.(UInt32, max.(0f0, xy) .* res)
    return
end

function main()
    source = FunctionSpec(typeof(f), Tuple{Int, Int})
    target = PTXCompilerTarget(; cap=v"7.5")    # doesn't work
    #target = NativeCompilerTarget()            # works
    params = TestCompilerParams()
    job = CompilerJob(target, source, params)

    JuliaContext() do ctx
      ir, ir_meta = GPUCompiler.compile(:llvm, job; ctx, strip=true)
      println(ir)
    end
end

isinteractive() || main()
maleadt added the bug label Oct 27, 2022

maleadt commented Oct 27, 2022

Actually no, the above only reproduces in my running session; in a clean session it requires the CUDA.jl compiler target:

function main()
    source = FunctionSpec(typeof(f), Tuple{Int, Int})
    target = CUDA.CUDACompilerTarget(device())
    params = CUDA.CUDACompilerParams()
    job = CompilerJob(target, source, params)

    JuliaContext() do ctx
      ir, ir_meta = GPUCompiler.compile(:llvm, job; ctx, strip=true)
      println(ir)
    end
end

So yeah, that kind of bug...


maleadt commented Oct 27, 2022

Apparently not a cache bug -- it must have been primed with a bad result in my active session, as CUDA.jl uses the same CI cache as the PTXCompilerTarget. Rather, the issue seems to be with an isnan override:

@device_override Base.isnan(x::Float64) = (ccall("extern __nv_isnand", llvmcall, Int32, (Cdouble,), x)) != 0
@device_override Base.isnan(x::Float32) = (ccall("extern __nv_isnanf", llvmcall, Int32, (Cfloat,), x)) != 0

Disabling these fixes the issue. I'm not sure it's very important we use these, but it does seem to signify a deeper issue...
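For reference, a quick way to confirm this kind of regression is to scan the emitted IR for calls into Julia's generic dispatch entry points. A minimal sketch (the helper is hypothetical, not part of GPUCompiler):

```julia
# Hypothetical helper, not part of GPUCompiler: check printed LLVM IR for
# calls into Julia's generic dispatch entry points, which indicate that a
# call could not be statically resolved (i.e., a dynamic call).
function has_dynamic_call(ir::AbstractString)
    return occursin("jl_apply_generic", ir) || occursin("ijl_apply_generic", ir)
end
```

One could run this on `string(ir)` from the module returned by `GPUCompiler.compile` to mechanically compare the native and PTX targets.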


maleadt commented Oct 27, 2022

Without CUDA.jl:

using GPUCompiler
using StaticArrays

module TestRuntime
    # dummy methods
    signal_exception() = return
    malloc(sz) = C_NULL
    report_oom(sz) = return
    report_exception(ex) = return
    report_exception_name(ex) = return
    report_exception_frame(idx, func, file, line) = return
end

struct TestCompilerParams <: AbstractCompilerParams end
GPUCompiler.runtime_module(::CompilerJob{<:Any,TestCompilerParams}) = TestRuntime

function f(width, height)
    xy = SVector{2, Float32}(0.5f0, 0.5f0)
    res = SVector{2, UInt32}(width, height)
    floor.(UInt32, max.(0f0, xy) .* res)
    return
end

# this breaks stuff
using Base.Experimental: @overlay, @MethodTable
@MethodTable(method_table)
@overlay method_table Base.isnan(x::Float32) = (ccall("extern __nv_isnanf", llvmcall, Int32, (Cfloat,), x)) != 0
GPUCompiler.method_table(@nospecialize(job::CompilerJob)) = method_table

function main()
    source = FunctionSpec(typeof(f), Tuple{Int, Int})
    target = NativeCompilerTarget()
    params = TestCompilerParams()
    job = CompilerJob(target, source, params)

    JuliaContext() do ctx
      ir, ir_meta = GPUCompiler.compile(:llvm, job; ctx, strip=true)
      println(ir)
    end
end

isinteractive() || main()


maleadt commented Oct 27, 2022

Without GPUCompiler.jl:

using LLVM

using Core.Compiler: CodeInstance, MethodInstance, InferenceParams, OptimizationParams, AbstractInterpreter, InferenceResult, InferenceState, OverlayMethodTable, WorldView

struct CodeCache
    dict::Dict{MethodInstance,Vector{CodeInstance}}

    CodeCache() = new(Dict{MethodInstance,Vector{CodeInstance}}())
end

function Core.Compiler.setindex!(cache::CodeCache, ci::CodeInstance, mi::MethodInstance)
    cis = get!(cache.dict, mi, CodeInstance[])
    push!(cis, ci)
end

struct CustomInterpreter <: AbstractInterpreter
    global_cache::CodeCache
    method_table::Union{Nothing,Core.MethodTable}

    # Cache of inference results for this particular interpreter
    local_cache::Vector{InferenceResult}
    # The world age we're working inside of
    world::UInt

    # Parameters for inference and optimization
    inf_params::InferenceParams
    opt_params::OptimizationParams


    function CustomInterpreter(cache::CodeCache, mt::Union{Nothing,Core.MethodTable}, world::UInt, ip::InferenceParams, op::OptimizationParams)
        @assert world <= Base.get_world_counter()

        return new(
            cache,
            mt,

            # Initially empty cache
            Vector{InferenceResult}(),

            # world age counter
            world,

            # parameters for inference and optimization
            ip,
            op
        )
    end
end

Core.Compiler.InferenceParams(interp::CustomInterpreter) = interp.inf_params
Core.Compiler.OptimizationParams(interp::CustomInterpreter) = interp.opt_params
Core.Compiler.get_world_counter(interp::CustomInterpreter) = interp.world
Core.Compiler.get_inference_cache(interp::CustomInterpreter) = interp.local_cache
Core.Compiler.code_cache(interp::CustomInterpreter) = WorldView(interp.global_cache, interp.world)

# No need to do any locking since we're not putting our results into the runtime cache
Core.Compiler.lock_mi_inference(interp::CustomInterpreter, mi::MethodInstance) = nothing
Core.Compiler.unlock_mi_inference(interp::CustomInterpreter, mi::MethodInstance) = nothing

Core.Compiler.may_optimize(interp::CustomInterpreter) = true
Core.Compiler.may_compress(interp::CustomInterpreter) = true
Core.Compiler.may_discard_trees(interp::CustomInterpreter) = true
Core.Compiler.verbose_stmt_info(interp::CustomInterpreter) = false

if v"1.8-beta2" <= VERSION < v"1.9-" || VERSION >= v"1.9.0-DEV.120"
Core.Compiler.method_table(interp::CustomInterpreter) =
    OverlayMethodTable(interp.world, interp.method_table)
else
Core.Compiler.method_table(interp::CustomInterpreter, sv::InferenceState) =
    OverlayMethodTable(interp.world, interp.method_table)
end


## world view of the cache

function Core.Compiler.haskey(wvc::WorldView{CodeCache}, mi::MethodInstance)
    Core.Compiler.get(wvc, mi, nothing) !== nothing
end

function Core.Compiler.get(wvc::WorldView{CodeCache}, mi::MethodInstance, default)
    # check the cache
    for ci in get!(wvc.cache.dict, mi, CodeInstance[])
        if ci.min_world <= wvc.worlds.min_world && wvc.worlds.max_world <= ci.max_world
            # TODO: if (code && (code == jl_nothing || jl_ir_flag_inferred((jl_array_t*)code)))
            src = if ci.inferred isa Vector{UInt8}
                ccall(:jl_uncompress_ir, Any, (Any, Ptr{Cvoid}, Any),
                       mi.def, C_NULL, ci.inferred)
            else
                ci.inferred
            end
            return ci
        end
    end

    return default
end

function Core.Compiler.getindex(wvc::WorldView{CodeCache}, mi::MethodInstance)
    r = Core.Compiler.get(wvc, mi, nothing)
    r === nothing && throw(KeyError(mi))
    return r::CodeInstance
end

function Core.Compiler.setindex!(wvc::WorldView{CodeCache}, ci::CodeInstance, mi::MethodInstance)
    src = if ci.inferred isa Vector{UInt8}
        ccall(:jl_uncompress_ir, Any, (Any, Ptr{Cvoid}, Any),
                mi.def, C_NULL, ci.inferred)
    else
        ci.inferred
    end
    Core.Compiler.setindex!(wvc.cache, ci, mi)
end


## codegen/inference integration

function ci_cache_populate(interp, cache, mt, mi, min_world, max_world)
    src = Core.Compiler.typeinf_ext_toplevel(interp, mi)

    # inference populates the cache, so we don't need to jl_get_method_inferred
    wvc = WorldView(cache, min_world, max_world)
    @assert Core.Compiler.haskey(wvc, mi)

    # if src is rettyp_const, the codeinfo won't cache ci.inferred
    # (because it is normally not supposed to be used ever again).
    # to avoid the need to re-infer, set that field here.
    ci = Core.Compiler.getindex(wvc, mi)
    if ci !== nothing && ci.inferred === nothing
        @static if VERSION >= v"1.9.0-DEV.1115"
            @atomic ci.inferred = src
        else
            ci.inferred = src
        end
    end

    return ci
end

function ci_cache_lookup(cache, mi, min_world, max_world)
    wvc = WorldView(cache, min_world, max_world)
    ci = Core.Compiler.get(wvc, mi, nothing)
    if ci !== nothing && ci.inferred === nothing
        # if for some reason we did end up with a codeinfo without inferred source, e.g.,
        # because of calling `Base.return_types` which only sets rettyp, pretend we didn't
        # run inference so that we re-infer now and not during codegen (which is disallowed)
        return nothing
    end
    return ci
end

# for platforms without @cfunction-with-closure support
const _method_instances = Ref{Any}()
const _cache = Ref{Any}()
function _lookup_fun(mi, min_world, max_world)
    push!(_method_instances[], mi)
    ci_cache_lookup(_cache[], mi, min_world, max_world)
end

if VERSION >= v"1.9.0-DEV.516"
    const JuliaContextType = ThreadSafeContext
else
    const JuliaContextType = Context
end

function JuliaContext()
    if VERSION >= v"1.9.0-DEV.115"
        # Julia 1.9 knows how to deal with arbitrary contexts
        JuliaContextType()
    else
        # earlier versions of Julia claim so, but actually use a global context
        isboxed_ref = Ref{Bool}()
        typ = LLVMType(ccall(:jl_type_to_llvm, LLVM.API.LLVMTypeRef,
                       (Any, Ptr{Bool}), Any, isboxed_ref))
        context(typ)
    end
end
function JuliaContext(f)
    if VERSION >= v"1.9.0-DEV.115"
        JuliaContextType(f)
    else
        f(JuliaContext())
        # we cannot dispose of the global unique context
    end
end

if VERSION >= v"1.9.0-DEV.516"
unwrap_context(ctx::ThreadSafeContext) = context(ctx)
end
unwrap_context(ctx::Context) = ctx

####################

# this breaks stuff
using Base.Experimental: @overlay, @MethodTable
@MethodTable(method_table)
@overlay method_table Base.isnan(x::Float32) = (ccall("extern __nv_isnanf", llvmcall, Int32, (Cfloat,), x)) != 0

#########

using StaticArrays
function kernel(width, height)
    xy = SVector{2, Float32}(0.5f0, 0.5f0)
    res = SVector{2, UInt32}(width, height)
    floor.(UInt32, max.(0f0, xy) .* res)
    return
end

function main()
    f = typeof(kernel)
    tt = Tuple{Int, Int}
    world = Base.get_world_counter()

    JuliaContext() do ctx
        # get the method instance
        u = Base.unwrap_unionall(tt)
        sig = Base.rewrap_unionall(Tuple{f, u.parameters...}, tt)
        meth = which(sig)

        (ti, env) = ccall(:jl_type_intersection_with_env, Any,
                          (Any, Any), sig, meth.sig)::Core.SimpleVector

        meth = Base.func_for_method_checked(meth, ti, env)

        method_instance = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance},
                      (Any, Any, Any, UInt), meth, ti, env, world)


        # populate the cache
        cache = CodeCache()
        mt = method_table
        interp = CustomInterpreter(cache, method_table, Base.get_world_counter(), InferenceParams(;unoptimize_throw_blocks=false), OptimizationParams())
        if ci_cache_lookup(cache, method_instance, world, typemax(Cint)) === nothing
            ci_cache_populate(interp, cache, mt, method_instance, world, typemax(Cint))
        end

        # create a callback to look-up function in our cache,
        # and keep track of the method instances we needed.
        method_instances = []
        if Sys.ARCH == :x86 || Sys.ARCH == :x86_64
            function lookup_fun(mi, min_world, max_world)
                push!(method_instances, mi)
                ci_cache_lookup(cache, mi, min_world, max_world)
            end
            lookup_cb = @cfunction($lookup_fun, Any, (Any, UInt, UInt))
        else
            _cache[] = cache
            _method_instances[] = method_instances
            lookup_cb = @cfunction(_lookup_fun, Any, (Any, UInt, UInt))
        end

        # set-up the compiler interface
        params = Base.CodegenParams(;
            track_allocations  = false,
            code_coverage      = false,
            prefer_specsig     = true,
            gnu_pubnames       = false,
            lookup             = Base.unsafe_convert(Ptr{Nothing}, lookup_cb))

        # generate IR
        GC.@preserve lookup_cb begin
            native_code = if VERSION >= v"1.9.0-DEV.516"
                mod = LLVM.Module("start"; ctx=unwrap_context(ctx))

                # configure the module
                flags(mod)["Dwarf Version", LLVM.API.LLVMModuleFlagBehaviorWarning] =
                    Metadata(ConstantInt(4; ctx=unwrap_context(ctx)))
                flags(mod)["Debug Info Version", LLVM.API.LLVMModuleFlagBehaviorWarning] =
                    Metadata(ConstantInt(DEBUG_METADATA_VERSION(); ctx=unwrap_context(ctx)))

                ts_mod = ThreadSafeModule(mod; ctx)
                ccall(:jl_create_native, Ptr{Cvoid},
                      (Vector{MethodInstance}, LLVM.API.LLVMOrcThreadSafeModuleRef, Ptr{Base.CodegenParams}, Cint),
                      [method_instance], ts_mod, Ref(params), #=extern policy=# 1)
            elseif VERSION >= v"1.9.0-DEV.115"
                ccall(:jl_create_native, Ptr{Cvoid},
                      (Vector{MethodInstance}, LLVM.API.LLVMContextRef, Ptr{Base.CodegenParams}, Cint),
                      [method_instance], ctx, Ref(params), #=extern policy=# 1)
            elseif VERSION >= v"1.8.0-DEV.661"
                @assert ctx == JuliaContext()
                ccall(:jl_create_native, Ptr{Cvoid},
                      (Vector{MethodInstance}, Ptr{Base.CodegenParams}, Cint),
                      [method_instance], Ref(params), #=extern policy=# 1)
            else
                @assert ctx == JuliaContext()
                ccall(:jl_create_native, Ptr{Cvoid},
                      (Vector{MethodInstance}, Base.CodegenParams, Cint),
                      [method_instance], params, #=extern policy=# 1)
            end
            @assert native_code != C_NULL
            llvm_mod_ref = if VERSION >= v"1.9.0-DEV.516"
                ccall(:jl_get_llvm_module, LLVM.API.LLVMOrcThreadSafeModuleRef,
                      (Ptr{Cvoid},), native_code)
            else
                ccall(:jl_get_llvm_module, LLVM.API.LLVMModuleRef,
                      (Ptr{Cvoid},), native_code)
            end
            @assert llvm_mod_ref != C_NULL
            if VERSION >= v"1.9.0-DEV.516"
                llvm_ts_mod = LLVM.ThreadSafeModule(llvm_mod_ref)
                llvm_mod = nothing
                llvm_ts_mod() do mod
                    llvm_mod = mod
                end
            else
                llvm_mod = LLVM.Module(llvm_mod_ref)
            end
            println(llvm_mod)
        end
    end
end

isinteractive() || main()


maleadt commented Oct 27, 2022

Bisected to:

aa20b321a50d4b5a9f9dc5948c4733ccf4a6781f is the first bad commit
commit aa20b321a50d4b5a9f9dc5948c4733ccf4a6781f
Author: Ian Atol <[email protected]>
Date:   Thu Sep 1 14:54:39 2022 -0700

    Semi-concrete IR interpreter (#44803)

    Co-authored-by: Keno Fischer <[email protected]>

I'll reduce the MWE further and try to figure out what's up / file this upstream.


maleadt commented Oct 27, 2022

"Minimal" reproducer:

using LLVM

using Core.Compiler: CodeInstance, MethodInstance, InferenceParams, OptimizationParams,
                     AbstractInterpreter, InferenceResult, InferenceState,
                     OverlayMethodTable, WorldView


## cache

struct CodeCache
    dict::Dict{MethodInstance,Vector{CodeInstance}}
    CodeCache() = new(Dict{MethodInstance,Vector{CodeInstance}}())
end

function Core.Compiler.setindex!(cache::CodeCache, ci::CodeInstance, mi::MethodInstance)
    cis = get!(cache.dict, mi, CodeInstance[])
    push!(cis, ci)
end


## interpreter

struct CustomInterpreter <: AbstractInterpreter
    global_cache::CodeCache
    local_cache::Vector{InferenceResult}
    method_table::Union{Nothing,Core.MethodTable}
    world::UInt

    CustomInterpreter(cache::CodeCache, mt::Union{Nothing,Core.MethodTable}, world::UInt) =
        new(cache, Vector{InferenceResult}(), mt, world)
end

Core.Compiler.InferenceParams(interp::CustomInterpreter) = InferenceParams()
Core.Compiler.OptimizationParams(interp::CustomInterpreter) = OptimizationParams()
Core.Compiler.get_world_counter(interp::CustomInterpreter) = interp.world
Core.Compiler.get_inference_cache(interp::CustomInterpreter) = interp.local_cache
Core.Compiler.code_cache(interp::CustomInterpreter) = WorldView(interp.global_cache, interp.world)

Core.Compiler.lock_mi_inference(interp::CustomInterpreter, mi::MethodInstance) = nothing
Core.Compiler.unlock_mi_inference(interp::CustomInterpreter, mi::MethodInstance) = nothing

Core.Compiler.may_optimize(interp::CustomInterpreter) = true
Core.Compiler.may_compress(interp::CustomInterpreter) = true
Core.Compiler.may_discard_trees(interp::CustomInterpreter) = true
Core.Compiler.verbose_stmt_info(interp::CustomInterpreter) = false

if v"1.8-beta2" <= VERSION < v"1.9-" || VERSION >= v"1.9.0-DEV.120"
Core.Compiler.method_table(interp::CustomInterpreter) =
    OverlayMethodTable(interp.world, interp.method_table)
else
Core.Compiler.method_table(interp::CustomInterpreter, sv::InferenceState) =
    OverlayMethodTable(interp.world, interp.method_table)
end


## world view of the cache

function Core.Compiler.haskey(wvc::WorldView{CodeCache}, mi::MethodInstance)
    Core.Compiler.get(wvc, mi, nothing) !== nothing
end

function Core.Compiler.get(wvc::WorldView{CodeCache}, mi::MethodInstance, default)
    # check the cache
    for ci in get!(wvc.cache.dict, mi, CodeInstance[])
        if ci.min_world <= wvc.worlds.min_world && wvc.worlds.max_world <= ci.max_world
            # TODO: if (code && (code == jl_nothing || jl_ir_flag_inferred((jl_array_t*)code)))
            src = if ci.inferred isa Vector{UInt8}
                ccall(:jl_uncompress_ir, Any, (Any, Ptr{Cvoid}, Any),
                       mi.def, C_NULL, ci.inferred)
            else
                ci.inferred
            end
            return ci
        end
    end

    return default
end

function Core.Compiler.getindex(wvc::WorldView{CodeCache}, mi::MethodInstance)
    r = Core.Compiler.get(wvc, mi, nothing)
    r === nothing && throw(KeyError(mi))
    return r::CodeInstance
end

function Core.Compiler.setindex!(wvc::WorldView{CodeCache}, ci::CodeInstance, mi::MethodInstance)
    src = if ci.inferred isa Vector{UInt8}
        ccall(:jl_uncompress_ir, Any, (Any, Ptr{Cvoid}, Any),
                mi.def, C_NULL, ci.inferred)
    else
        ci.inferred
    end
    Core.Compiler.setindex!(wvc.cache, ci, mi)
end


## codegen/inference integration

function ci_cache_populate(interp, cache, mt, mi, min_world, max_world)
    src = Core.Compiler.typeinf_ext_toplevel(interp, mi)

    # inference populates the cache, so we don't need to jl_get_method_inferred
    wvc = WorldView(cache, min_world, max_world)

    # if src is rettyp_const, the codeinfo won't cache ci.inferred
    # (because it is normally not supposed to be used ever again).
    # to avoid the need to re-infer, set that field here.
    ci = Core.Compiler.getindex(wvc, mi)
    if ci !== nothing && ci.inferred === nothing
        @static if VERSION >= v"1.9.0-DEV.1115"
            @atomic ci.inferred = src
        else
            ci.inferred = src
        end
    end

    return ci
end

function ci_cache_lookup(cache, mi, min_world, max_world)
    wvc = WorldView(cache, min_world, max_world)
    ci = Core.Compiler.get(wvc, mi, nothing)
    if ci !== nothing && ci.inferred === nothing
        # if for some reason we did end up with a codeinfo without inferred source, e.g.,
        # because of calling `Base.return_types` which only sets rettyp, pretend we didn't
        # run inference so that we re-infer now and not during codegen (which is disallowed)
        return nothing
    end
    return ci
end


## LLVM context handling

if VERSION >= v"1.9.0-DEV.516"
    const JuliaContextType = ThreadSafeContext
else
    const JuliaContextType = Context
end

function JuliaContext()
    if VERSION >= v"1.9.0-DEV.115"
        # Julia 1.9 knows how to deal with arbitrary contexts
        JuliaContextType()
    else
        # earlier versions of Julia claim so, but actually use a global context
        isboxed_ref = Ref{Bool}()
        typ = LLVMType(ccall(:jl_type_to_llvm, LLVM.API.LLVMTypeRef,
                       (Any, Ptr{Bool}), Any, isboxed_ref))
        context(typ)
    end
end
function JuliaContext(f)
    if VERSION >= v"1.9.0-DEV.115"
        JuliaContextType(f)
    else
        f(JuliaContext())
        # we cannot dispose of the global unique context
    end
end

if VERSION >= v"1.9.0-DEV.516"
unwrap_context(ctx::ThreadSafeContext) = context(ctx)
end
unwrap_context(ctx::Context) = ctx


## main

function irgen(f, tt, world)
    JuliaContext() do ctx
        # get the method instance
        u = Base.unwrap_unionall(tt)
        sig = Base.rewrap_unionall(Tuple{f, u.parameters...}, tt)
        meth = which(sig)

        (ti, env) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), sig, meth.sig)

        meth = Base.func_for_method_checked(meth, ti, env)

        mi = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance},
                    (Any, Any, Any, UInt), meth, ti, env, world)

        # populate the cache
        cache = CodeCache()
        interp = CustomInterpreter(cache, method_table, Base.get_world_counter())
        ci_cache_populate(interp, cache, method_table, mi, world, typemax(Cint))

        # set-up the compiler interface
        function lookup_fun(mi, min_world, max_world)
            ci_cache_lookup(cache, mi, min_world, max_world)
        end
        lookup_cb = @cfunction($lookup_fun, Any, (Any, UInt, UInt))
        params = Base.CodegenParams(;lookup = Base.unsafe_convert(Ptr{Nothing}, lookup_cb))

        GC.@preserve lookup_cb begin
            # generate inferred
            native_code = if VERSION >= v"1.9.0-DEV.516"
                mod = LLVM.Module("start"; ctx=unwrap_context(ctx))
                ts_mod = ThreadSafeModule(mod; ctx)
                ccall(:jl_create_native, Ptr{Cvoid},
                        (Vector{MethodInstance}, LLVM.API.LLVMOrcThreadSafeModuleRef, Ptr{Base.CodegenParams}, Cint),
                        [mi], ts_mod, Ref(params), #=extern policy=# 1)
            elseif VERSION >= v"1.9.0-DEV.115"
                ccall(:jl_create_native, Ptr{Cvoid},
                        (Vector{MethodInstance}, LLVM.API.LLVMContextRef, Ptr{Base.CodegenParams}, Cint),
                        [mi], ctx, Ref(params), #=extern policy=# 1)
            elseif VERSION >= v"1.8.0-DEV.661"
                ccall(:jl_create_native, Ptr{Cvoid},
                        (Vector{MethodInstance}, Ptr{Base.CodegenParams}, Cint),
                        [mi], Ref(params), #=extern policy=# 1)
            else
                ccall(:jl_create_native, Ptr{Cvoid},
                        (Vector{MethodInstance}, Base.CodegenParams, Cint),
                        [mi], params, #=extern policy=# 1)
            end

            # get the module
            llvm_mod_ref = if VERSION >= v"1.9.0-DEV.516"
                ccall(:jl_get_llvm_module, LLVM.API.LLVMOrcThreadSafeModuleRef,
                        (Ptr{Cvoid},), native_code)
            else
                ccall(:jl_get_llvm_module, LLVM.API.LLVMModuleRef,
                        (Ptr{Cvoid},), native_code)
            end

            # display its IR
            if VERSION >= v"1.9.0-DEV.516"
                llvm_ts_mod = LLVM.ThreadSafeModule(llvm_mod_ref)
                llvm_mod = nothing
                llvm_ts_mod() do mod
                    llvm_mod = mod
                end
            else
                llvm_mod = LLVM.Module(llvm_mod_ref)
            end
            println(llvm_mod)
        end
    end
end

# this breaks stuff
using Base.Experimental: @overlay, @MethodTable
@MethodTable(method_table)
@overlay method_table Base.isnan(x::Float32) =
    ccall("extern __nv_isnanf", llvmcall, Int32, (Cfloat,), x) != 0

# simple kernel
using StaticArrays
function kernel(width, height)
    xy = SVector{2, Float32}(0.5f0, 0.5f0)
    res = SVector{2, UInt32}(width, height)
    floor.(UInt32, max.(0f0, xy) .* res)
    return
end

f = typeof(kernel)
tt = Tuple{Int, Int}
world = Base.get_world_counter()
@time irgen(f, tt, world)


maleadt commented Oct 27, 2022

Filed upstream as JuliaLang/julia#47349
