Inference regression because of semi-concrete interpretation #47349

Closed
maleadt opened this issue Oct 27, 2022 · 7 comments · Fixed by #47490
Assignees
Labels
compiler:inference Type inference gpu Affects running Julia on a GPU regression Regression in behavior compared to a previous version

Comments

maleadt commented Oct 27, 2022

I have a fairly involved MWE, reduced from a CUDA.jl bug report, where a StaticArrays-based kernel (doing a broadcast with floor) stops generating static code when combined with a method overlay definition of Base.isnan. Essentially:

# simple kernel
using StaticArrays
function kernel(width, height)
    xy = SVector{2, Float32}(0.5f0, 0.5f0)
    res = SVector{2, UInt32}(width, height)
    floor.(UInt32, max.(0f0, xy) .* res)
    return
end

# this breaks static irgen
using Base.Experimental: @overlay, @MethodTable
@MethodTable(method_table)
@overlay method_table Base.isnan(x::Float32) =
    ccall("extern __nv_isnanf", llvmcall, Int32, (Cfloat,), x) != 0

f = typeof(kernel)
tt = Tuple{Int, Int}
world = Base.get_world_counter()
irgen(f, tt, world)

On 1.8, the generated IR is fully static and contains calls to the floor intrinsic:

  %38 = call float @llvm.floor.f32(float %37), !dbg !91
  %69 = call float @llvm.floor.f32(float %68), !dbg !91

On 1.9, however, the IR is full of dynamic dispatches with no floor to be seen, even though the code_warntype output looks similar. I've bisected this to #44803, but have yet to look into it more closely.
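
A quick way to check for this kind of regression is to scan the textual IR for the expected intrinsic and for the generic dispatch entry point. The following is a minimal sketch that uses the native compiler via `code_llvm` rather than the custom interpreter from this issue. The helper names (`llvm_ir`, `static_floor`, `boxed_floor`, `has_floor_intrinsic`, `has_dynamic_dispatch`) are made up for illustration, and the `jl_apply_generic` substring is an assumption about how dynamic calls are spelled in the IR.

```julia
using InteractiveUtils  # provides code_llvm

# Hypothetical helper: textual LLVM IR for `f` specialized on argument types `tt`.
llvm_ir(f, tt) = sprint(code_llvm, f, tt)

# Statically typed: `floor` on a Float32 should lower to the LLVM intrinsic.
static_floor(x::Float32) = floor(x)

# Untyped container element: the `floor` call cannot be resolved at compile time.
boxed_floor(v::Vector{Any}) = floor(v[1])

has_floor_intrinsic(ir) = occursin("llvm.floor.f32", ir)
# Assumption: dynamic dispatch shows up as a call to (i)jl_apply_generic.
has_dynamic_dispatch(ir) = occursin("jl_apply_generic", ir)

ir_static = llvm_ir(static_floor, Tuple{Float32})
ir_boxed  = llvm_ir(boxed_floor, Tuple{Vector{Any}})

println("static: floor intrinsic = ", has_floor_intrinsic(ir_static),
        ", dynamic dispatch = ", has_dynamic_dispatch(ir_static))
println("boxed:  floor intrinsic = ", has_floor_intrinsic(ir_boxed),
        ", dynamic dispatch = ", has_dynamic_dispatch(ir_boxed))
```

The same two string checks could be applied to the module printed by the full MWE to tell the 1.8 and 1.9 outputs apart without reading the IR by hand.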

The full MWE, reduced from GPUCompiler.jl:

using LLVM

using Core.Compiler: CodeInstance, MethodInstance, InferenceParams, OptimizationParams,
                     AbstractInterpreter, InferenceResult, InferenceState,
                     OverlayMethodTable, WorldView


## cache

struct CodeCache
    dict::Dict{MethodInstance,Vector{CodeInstance}}
    CodeCache() = new(Dict{MethodInstance,Vector{CodeInstance}}())
end

function Core.Compiler.setindex!(cache::CodeCache, ci::CodeInstance, mi::MethodInstance)
    cis = get!(cache.dict, mi, CodeInstance[])
    push!(cis, ci)
end


## interpreter

struct CustomInterpreter <: AbstractInterpreter
    global_cache::CodeCache
    local_cache::Vector{InferenceResult}
    method_table::Union{Nothing,Core.MethodTable}
    world::UInt

    CustomInterpreter(cache::CodeCache, mt::Union{Nothing,Core.MethodTable}, world::UInt) =
        new(cache, Vector{InferenceResult}(), mt, world)
end

Core.Compiler.InferenceParams(interp::CustomInterpreter) = InferenceParams()
Core.Compiler.OptimizationParams(interp::CustomInterpreter) = OptimizationParams()
Core.Compiler.get_world_counter(interp::CustomInterpreter) = interp.world
Core.Compiler.get_inference_cache(interp::CustomInterpreter) = interp.local_cache
Core.Compiler.code_cache(interp::CustomInterpreter) = WorldView(interp.global_cache, interp.world)

Core.Compiler.lock_mi_inference(interp::CustomInterpreter, mi::MethodInstance) = nothing
Core.Compiler.unlock_mi_inference(interp::CustomInterpreter, mi::MethodInstance) = nothing

Core.Compiler.may_optimize(interp::CustomInterpreter) = true
Core.Compiler.may_compress(interp::CustomInterpreter) = true
Core.Compiler.may_discard_trees(interp::CustomInterpreter) = true
Core.Compiler.verbose_stmt_info(interp::CustomInterpreter) = false

if v"1.8-beta2" <= VERSION < v"1.9-" || VERSION >= v"1.9.0-DEV.120"
Core.Compiler.method_table(interp::CustomInterpreter) =
    OverlayMethodTable(interp.world, interp.method_table)
else
Core.Compiler.method_table(interp::CustomInterpreter, sv::InferenceState) =
    OverlayMethodTable(interp.world, interp.method_table)
end


## world view of the cache

function Core.Compiler.haskey(wvc::WorldView{CodeCache}, mi::MethodInstance)
    Core.Compiler.get(wvc, mi, nothing) !== nothing
end

function Core.Compiler.get(wvc::WorldView{CodeCache}, mi::MethodInstance, default)
    # check the cache
    for ci in get!(wvc.cache.dict, mi, CodeInstance[])
        if ci.min_world <= wvc.worlds.min_world && wvc.worlds.max_world <= ci.max_world
            # TODO: if (code && (code == jl_nothing || jl_ir_flag_inferred((jl_array_t*)code)))
            src = if ci.inferred isa Vector{UInt8}
                ccall(:jl_uncompress_ir, Any, (Any, Ptr{Cvoid}, Any),
                       mi.def, C_NULL, ci.inferred)
            else
                ci.inferred
            end
            return ci
        end
    end

    return default
end

function Core.Compiler.getindex(wvc::WorldView{CodeCache}, mi::MethodInstance)
    r = Core.Compiler.get(wvc, mi, nothing)
    r === nothing && throw(KeyError(mi))
    return r::CodeInstance
end

function Core.Compiler.setindex!(wvc::WorldView{CodeCache}, ci::CodeInstance, mi::MethodInstance)
    src = if ci.inferred isa Vector{UInt8}
        ccall(:jl_uncompress_ir, Any, (Any, Ptr{Cvoid}, Any),
                mi.def, C_NULL, ci.inferred)
    else
        ci.inferred
    end
    Core.Compiler.setindex!(wvc.cache, ci, mi)
end


## codegen/inference integration

function ci_cache_populate(interp, cache, mt, mi, min_world, max_world)
    src = Core.Compiler.typeinf_ext_toplevel(interp, mi)

    # inference populates the cache, so we don't need to jl_get_method_inferred
    wvc = WorldView(cache, min_world, max_world)

    # if src is rettyp_const, the codeinfo won't cache ci.inferred
    # (because it is normally not supposed to be used ever again).
    # to avoid the need to re-infer, set that field here.
    ci = Core.Compiler.getindex(wvc, mi)
    if ci !== nothing && ci.inferred === nothing
        @static if VERSION >= v"1.9.0-DEV.1115"
            @atomic ci.inferred = src
        else
            ci.inferred = src
        end
    end

    return ci
end

function ci_cache_lookup(cache, mi, min_world, max_world)
    wvc = WorldView(cache, min_world, max_world)
    ci = Core.Compiler.get(wvc, mi, nothing)
    if ci !== nothing && ci.inferred === nothing
        # if for some reason we did end up with a codeinfo without inferred source, e.g.,
        # because of calling `Base.return_types` which only sets rettyp, pretend we didn't
        # run inference so that we re-infer now and not during codegen (which is disallowed)
        return nothing
    end
    return ci
end


## LLVM context handling

if VERSION >= v"1.9.0-DEV.516"
    const JuliaContextType = ThreadSafeContext
else
    const JuliaContextType = Context
end

function JuliaContext()
    if VERSION >= v"1.9.0-DEV.115"
        # Julia 1.9 knows how to deal with arbitrary contexts
        JuliaContextType()
    else
        # earlier versions of Julia claim so, but actually use a global context
        isboxed_ref = Ref{Bool}()
        typ = LLVMType(ccall(:jl_type_to_llvm, LLVM.API.LLVMTypeRef,
                       (Any, Ptr{Bool}), Any, isboxed_ref))
        context(typ)
    end
end
function JuliaContext(f)
    if VERSION >= v"1.9.0-DEV.115"
        JuliaContextType(f)
    else
        f(JuliaContext())
        # we cannot dispose of the global unique context
    end
end

if VERSION >= v"1.9.0-DEV.516"
unwrap_context(ctx::ThreadSafeContext) = context(ctx)
end
unwrap_context(ctx::Context) = ctx


## main

function irgen(f, tt, world)
    JuliaContext() do ctx
        # get the method instance
        u = Base.unwrap_unionall(tt)
        sig = Base.rewrap_unionall(Tuple{f, u.parameters...}, tt)
        meth = which(sig)

        (ti, env) = ccall(:jl_type_intersection_with_env, Any,(Any, Any), sig, meth.sig)

        meth = Base.func_for_method_checked(meth, ti, env)

        mi = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance},
                    (Any, Any, Any, UInt), meth, ti, env, world)

        # populate the cache
        cache = CodeCache()
        interp = CustomInterpreter(cache, method_table, Base.get_world_counter())
        ci_cache_populate(interp, cache, method_table, mi, world, typemax(Cint))

        # set-up the compiler interface
        function lookup_fun(mi, min_world, max_world)
            ci_cache_lookup(cache, mi, min_world, max_world)
        end
        lookup_cb = @cfunction($lookup_fun, Any, (Any, UInt, UInt))
        params = Base.CodegenParams(;lookup = Base.unsafe_convert(Ptr{Nothing}, lookup_cb))

        GC.@preserve lookup_cb begin
            # generate inferred
            native_code = if VERSION >= v"1.9.0-DEV.516"
                mod = LLVM.Module("start"; ctx=unwrap_context(ctx))
                ts_mod = ThreadSafeModule(mod; ctx)
                ccall(:jl_create_native, Ptr{Cvoid},
                        (Vector{MethodInstance}, LLVM.API.LLVMOrcThreadSafeModuleRef, Ptr{Base.CodegenParams}, Cint),
                        [mi], ts_mod, Ref(params), #=extern policy=# 1)
            elseif VERSION >= v"1.9.0-DEV.115"
                ccall(:jl_create_native, Ptr{Cvoid},
                        (Vector{MethodInstance}, LLVM.API.LLVMContextRef, Ptr{Base.CodegenParams}, Cint),
                        [mi], ctx, Ref(params), #=extern policy=# 1)
            elseif VERSION >= v"1.8.0-DEV.661"
                ccall(:jl_create_native, Ptr{Cvoid},
                        (Vector{MethodInstance}, Ptr{Base.CodegenParams}, Cint),
                        [mi], Ref(params), #=extern policy=# 1)
            else
                ccall(:jl_create_native, Ptr{Cvoid},
                        (Vector{MethodInstance}, Base.CodegenParams, Cint),
                        [mi], params, #=extern policy=# 1)
            end

            # get the module
            llvm_mod_ref = if VERSION >= v"1.9.0-DEV.516"
                ccall(:jl_get_llvm_module, LLVM.API.LLVMOrcThreadSafeModuleRef,
                        (Ptr{Cvoid},), native_code)
            else
                ccall(:jl_get_llvm_module, LLVM.API.LLVMModuleRef,
                        (Ptr{Cvoid},), native_code)
            end

            # display its IR
            if VERSION >= v"1.9.0-DEV.516"
                llvm_ts_mod = LLVM.ThreadSafeModule(llvm_mod_ref)
                llvm_mod = nothing
                llvm_ts_mod() do mod
                    llvm_mod = mod
                end
            else
                llvm_mod = LLVM.Module(llvm_mod_ref)
            end
            println(llvm_mod)
        end
    end
end

# this breaks stuff
using Base.Experimental: @overlay, @MethodTable
@MethodTable(method_table)
@overlay method_table Base.isnan(x::Float32) =
    ccall("extern __nv_isnanf", llvmcall, Int32, (Cfloat,), x) != 0

# simple kernel
using StaticArrays
function kernel(width, height)
    xy = SVector{2, Float32}(0.5f0, 0.5f0)
    res = SVector{2, UInt32}(width, height)
    floor.(UInt32, max.(0f0, xy) .* res)
    return
end

f = typeof(kernel)
tt = Tuple{Int, Int}
world = Base.get_world_counter()
@time irgen(f, tt, world)

This also serves as a "minimal" example of how to generate LLVM IR using a custom interpreter and cache. That process is significantly slower on 1.9, taking ~7 seconds where it used to take about 3, so I guess it also demonstrates #47296.

@maleadt maleadt added regression Regression in behavior compared to a previous version compiler:codegen Generation of LLVM IR and native code gpu Affects running Julia on a GPU labels Oct 27, 2022

maleadt commented Oct 27, 2022

Cthulhu doesn't help here: on both sides of the semi-concrete merge, the Julia and LLVM IR it reports are identical and always contain static calls to floor. Guess I'll have to debug this at a lower level. @Keno, any thoughts on what may have caused this?

gbaraldi commented

@brenhinkeller This might be what we are seeing in StaticCompiler.jl with the overlays on.

brenhinkeller commented

Oh! Interesting..

@aviatesk aviatesk self-assigned this Nov 7, 2022

aviatesk commented Nov 7, 2022

It looks like this issue is caused by semi-concrete interpretation on overlaid methods (although I'm still debugging why that happens).
In the meantime, we can turn off semi-concrete interpretation like this:

function Core.Compiler.concrete_eval_eligible(interp::CustomInterpreter,
    @nospecialize(f), result::Core.Compiler.MethodCallResult, arginfo::Core.Compiler.ArgInfo)
    # defer to the default eligibility check
    ret = @invoke Core.Compiler.concrete_eval_eligible(interp::AbstractInterpreter,
        f::Any, result::Core.Compiler.MethodCallResult, arginfo::Core.Compiler.ArgInfo)
    # mapping `false` to `nothing` turns off semi-concrete interpretation
    ret === false && return nothing # XXX JuliaLang/julia#47349
    return ret
end

and get (seemingly) optimal LLVM IR.

aviatesk commented Nov 8, 2022

Update: the root problem here seems to be a general precision issue in semi-concrete interpretation. More specifically, semi-concrete interpretation is enabled for StaticArrays.StableFlatten.prepare_args, but it turns out that const-prop'-based abstract interpretation returns a more accurate result than semi-concrete interpretation does. So the root problem is not that isnan happens to be overlaid.

With this overload, we can see that there are some precision issues within semi-concrete eval:

@eval Core.Compiler function abstract_call_method_with_const_args(interp::AbstractInterpreter,
    result::MethodCallResult, @nospecialize(f), arginfo::ArgInfo, si::StmtInfo, match::MethodMatch,
    sv::InferenceState, invokecall::Union{Nothing,InvokeCall}=nothing)
    if !const_prop_enabled(interp, sv, match)
        return nothing
    end
    if is_removable_if_unused(result.effects)
        if isa(result.rt, Const) || call_result_unused(si)
            add_remark!(interp, sv, "[constprop] No more information to be gained")
            return nothing
        end
    end
    res = concrete_eval_call(interp, f, result, arginfo, si, sv, invokecall)
    isa(res, ConstCallResults) && return res
    mi = maybe_get_const_prop_profitable(interp, result, f, arginfo, si, match, sv)
    mi === nothing && return nothing
    # try semi-concrete evaluation
    local semiresult = nothing
    if res::Bool && !has_conditional(arginfo)
        mi_cache = WorldView(code_cache(interp), sv.world)
        code = get(mi_cache, mi, nothing)
        if code !== nothing
            ir = codeinst_to_ir(interp, code)
            if isa(ir, IRCode)
                irsv = IRInterpretationState(interp, ir, mi, sv.world, arginfo.argtypes)
                rt = ir_abstract_constant_propagation(interp, irsv)
                if !isa(rt, Type) || typeintersect(rt, Bool) === Union{}
                    semiresult = (rt, mi, ir)
                end
            end
        end
    end
    # try constant prop'
    inf_cache = get_inference_cache(interp)
    inf_result = cache_lookup(typeinf_lattice(interp), mi, arginfo.argtypes, inf_cache)
    if inf_result === nothing
        # if there might be a cycle, check to make sure we don't end up
        # calling ourselves here.
        if result.edgecycle && (result.edgelimited ?
            is_constprop_method_recursed(match.method, sv) :
            # if the type complexity limiting didn't decide to limit the call signature (`result.edgelimited = false`)
            # we can relax the cycle detection by comparing `MethodInstance`s and allow inference to
            # propagate different constant elements if the recursion is finite over the lattice
            is_constprop_edge_recursed(mi, sv))
            add_remark!(interp, sv, "[constprop] Edge cycle encountered")
            return nothing
        end
        inf_result = InferenceResult(mi, (arginfo, sv))
        if !any(inf_result.overridden_by_const)
            add_remark!(interp, sv, "[constprop] Could not handle constant info in matching_cache_argtypes")
            return nothing
        end
        frame = InferenceState(inf_result, #=cache=#:local, interp)
        if frame === nothing
            add_remark!(interp, sv, "[constprop] Could not retrieve the source")
            return nothing # this is probably a bad generated function (unsound), but just ignore it
        end
        frame.parent = sv
        if !typeinf(interp, frame)
            add_remark!(interp, sv, "[constprop] Fresh constant inference hit a cycle")
            return nothing
        end
        @assert !isa(inf_result.result, InferenceState)
    else
        if isa(inf_result.result, InferenceState)
            add_remark!(interp, sv, "[constprop] Found cached constant inference in a cycle")
            return nothing
        end
    end
    if semiresult !== nothing
        (rt, mi, ir) = semiresult
        if inf_result.result ⊏ rt
            println("semi concrete result is worse than constprop': ", sv.linfo, mi, sv.linfo.def)
            println(rt, " vs. ", inf_result.result)
        end
    end
    return ConstCallResults(inf_result.result, ConstPropResult(inf_result), inf_result.ipo_effects, mi)
end

This prints:

semi concrete result is worse than constprop': prepare_args(Tuple{StaticArrays.StableFlatten.var"#8#10"{1}, StaticArrays.StableFlatten.var"#makeargs1#11"{typeof(Base.:(*)), Tuple{StaticArrays.StableFlatten.var"#makeargs1#11"{typeof(Base.max), Tuple{StaticArrays.StableFlatten.var"#8#10"{2}, StaticArrays.StableFlatten.var"#8#10"{3}}}, StaticArrays.StableFlatten.var"#8#10"{4}}}}, Tuple{DataType, Float32, Float32, UInt32}) from prepare_args(Tuple, Tuple)(::StaticArrays.StableFlatten.var"#8#10"{1})(Tuple{DataType, Float32, Float32, UInt32}) from (::StaticArrays.StableFlatten.var"#8#10"{N})(Tuple) where {N}prepare_args(Tuple, Tuple)
DataType vs. Core.Const(val=UInt32)
semi concrete result is worse than constprop': prepare_args(Tuple{StaticArrays.StableFlatten.var"#8#10"{1}, StaticArrays.StableFlatten.var"#makeargs1#11"{typeof(Base.:(*)), Tuple{StaticArrays.StableFlatten.var"#makeargs1#11"{typeof(Base.max), Tuple{StaticArrays.StableFlatten.var"#8#10"{2}, StaticArrays.StableFlatten.var"#8#10"{3}}}, StaticArrays.StableFlatten.var"#8#10"{4}}}}, Tuple{DataType, Float32, Float32, UInt32}) from prepare_args(Tuple, Tuple)(::StaticArrays.StableFlatten.var"#8#10"{1})(Tuple{DataType, Float32, Float32, UInt32}) from (::StaticArrays.StableFlatten.var"#8#10"{N})(Tuple) where {N}prepare_args(Tuple, Tuple)
DataType vs. Core.Const(val=UInt32)
semi concrete result is worse than constprop': prepare_args(Tuple{StaticArrays.StableFlatten.var"#8#10"{2}, StaticArrays.StableFlatten.var"#8#10"{3}}, Tuple{DataType, Float32, Float32, UInt32}) from prepare_args(Tuple, Tuple)(::StaticArrays.StableFlatten.var"#8#10"{2})(Tuple{DataType, Float32, Float32, UInt32}) from (::StaticArrays.StableFlatten.var"#8#10"{N})(Tuple) where {N}prepare_args(Tuple, Tuple)
Float32 vs. Core.Const(val=0f)
semi concrete result is worse than constprop': prepare_args(Tuple{StaticArrays.StableFlatten.var"#8#10"{3}}, Tuple{DataType, Float32, Float32, UInt32}) from prepare_args(Tuple, Tuple)(::StaticArrays.StableFlatten.var"#8#10"{3})(Tuple{DataType, Float32, Float32, UInt32}) from (::StaticArrays.StableFlatten.var"#8#10"{N})(Tuple) where {N}prepare_args(Tuple, Tuple)
Float32 vs. Core.Const(val=0.5f)

I will try to come up with a minimal test case and fix the precision issue. In the meantime, external AbstractInterpreters can disable semi-concrete eval as shown above to preserve the accuracy of type inference.
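
To read the output above: a line such as `DataType vs. Core.Const(val=UInt32)` means that semi-concrete interpretation only knew the argument is some type, while constant propagation knew it is exactly `UInt32`. The sketch below illustrates why that difference matters for a call like `floor(T, x)`. It relies only on `Base.return_types`; `floor_to` is a hypothetical stand-in and not code from StaticArrays, and the exact widened result may vary between Julia versions.

```julia
# The value `UInt32` is itself an instance of `DataType`, which is all that
# remains once the constant information is lost.
@show typeof(UInt32)

# Hypothetical stand-in for a call whose result type depends on a type argument.
floor_to(T, x) = floor(T, x)

# Precise argument type: inference knows `T` is exactly UInt32.
precise = only(Base.return_types(floor_to, Tuple{Type{UInt32}, Float32}))

# Widened argument type: inference only knows `T` is some DataType.
widened = only(Base.return_types(floor_to, Tuple{DataType, Float32}))

@show precise
@show widened
```

With the precise signature the call can be resolved statically, whereas the widened one leaves the callee unknown, which is consistent with the dynamic dispatches seen in the 1.9 IR.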

@aviatesk aviatesk changed the title Regression in code generation with overlay methods Inference regression because of semi-concrete eval Nov 8, 2022
@aviatesk aviatesk changed the title Inference regression because of semi-concrete eval Inference regression because of semi-concrete interpretatoin Nov 8, 2022
@aviatesk aviatesk changed the title Inference regression because of semi-concrete interpretatoin Inference regression because of semi-concrete interpretation Nov 8, 2022
@aviatesk aviatesk added compiler:inference Type inference and removed compiler:codegen Generation of LLVM IR and native code labels Nov 8, 2022
aviatesk added a commit that referenced this issue Nov 8, 2022
#45459 moved `:static_parameter` always to statement
position, as our optimizer assumes `:static_parameter` in value position
is effect-free.
But it turns out that this can cause precision issues for semi-concrete
interpretation, as discovered in #47349, since the type of
`:static_parameter` in statement position is widened when converted
to compressed IR for the cache.

This commit follows up #45459 so that we inline
effect-free `:static_parameter` during IR conversion and get a more
reasonable semi-concrete interpretation.
aviatesk added a commit that referenced this issue Nov 9, 2022

aviatesk commented Nov 9, 2022

This particular issue has been fixed, but I think there are other inference precision issues within semi-concrete interpretation at the moment. So I suggest you keep JuliaGPU/GPUCompiler.jl#369 for the time being. I will ping you (@maleadt) once I am satisfied with the upcoming fixes.

aviatesk added a commit that referenced this issue Nov 28, 2022
This should let us generate smaller IR in general, as well as reduce
the chances that semi-concrete interpretation causes the precision issue
that stems from the current limitation that compressed IR cannot
propagate constant type information in statement position to
semi-concrete interpretation (xref: #47349).

aviatesk commented Jan 4, 2023

#47994 fixed most cases where semi-concrete interpretation ended up with a wider result than constant propagation. It may be okay to turn on semi-concrete interpretation in GPUCompiler after 1.10.0-DEV.203. However, there are still a few known remaining cases, so you may want to keep the overload unless you are willing to optimize compilation performance at the risk of inference regression.
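
For readers unsure how a bound like `1.10.0-DEV.203` compares against release versions, here is a small sketch of Julia's `VersionNumber` ordering, which the `VERSION` checks in the MWE also rely on. The function name `semi_concrete_probably_ok` is hypothetical, and the threshold is simply the one mentioned in the comment above; whether to actually drop the overload remains a judgment call.

```julia
# Hypothetical predicate: true for builds at or after the mentioned threshold.
semi_concrete_probably_ok(v::VersionNumber) = v >= v"1.10.0-DEV.203"

# Prerelease builds sort before the release they lead up to.
@assert v"1.10.0-DEV.203" < v"1.10.0"
# A trailing `-` denotes the lowest possible prerelease of that version.
@assert v"1.9-" < v"1.9.0-DEV.120" < v"1.9.0"

for v in (v"1.9.0", v"1.10.0-DEV.100", v"1.10.0-DEV.203", v"1.10.0")
    println(v, " => ", semi_concrete_probably_ok(v))
end
```

A package could wrap the `concrete_eval_eligible` overload in an `if` on such a predicate, though the signature of that internal function is not guaranteed to stay the same across versions.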
