Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optimizer: fix up the inlining algorithm to use correct nargs/isva #55976

Merged
merged 1 commit into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1282,7 +1282,7 @@ function semi_concrete_eval_call(interp::AbstractInterpreter,
effects = Effects(effects; noub=ALWAYS_TRUE)
end
exct = refine_exception_type(result.exct, effects)
return ConstCallResults(rt, exct, SemiConcreteResult(mi, ir, effects), effects, mi)
return ConstCallResults(rt, exct, SemiConcreteResult(mi, ir, effects, spec_info(irsv)), effects, mi)
end
end
end
Expand Down
25 changes: 12 additions & 13 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ mutable struct InferenceState
slottypes::Vector{Any}
src::CodeInfo
cfg::CFG
method_info::MethodInfo
spec_info::SpecInfo

#= intermediate states for local abstract interpretation =#
currbb::Int
Expand Down Expand Up @@ -294,7 +294,7 @@ mutable struct InferenceState
sptypes = sptypes_from_meth_instance(mi)
code = src.code::Vector{Any}
cfg = compute_basic_blocks(code)
method_info = MethodInfo(src)
spec_info = SpecInfo(src)

currbb = currpc = 1
ip = BitSet(1) # TODO BitSetBoundedMinPrioritySet(1)
Expand Down Expand Up @@ -351,7 +351,7 @@ mutable struct InferenceState
restrict_abstract_call_sites = isa(def, Module)

this = new(
mi, world, mod, sptypes, slottypes, src, cfg, method_info,
mi, world, mod, sptypes, slottypes, src, cfg, spec_info,
currbb, currpc, ip, handler_info, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info,
tasks, pclimitations, limitations, cycle_backedges, callstack, 0, 0, 0,
result, unreachable, valid_worlds, bestguess, exc_bestguess, ipo_effects,
Expand Down Expand Up @@ -791,7 +791,7 @@ end

# TODO add `result::InferenceResult` and put the irinterp result into the inference cache?
mutable struct IRInterpretationState
const method_info::MethodInfo
const spec_info::SpecInfo
const ir::IRCode
const mi::MethodInstance
const world::UInt
Expand All @@ -809,7 +809,7 @@ mutable struct IRInterpretationState
parentid::Int

function IRInterpretationState(interp::AbstractInterpreter,
method_info::MethodInfo, ir::IRCode, mi::MethodInstance, argtypes::Vector{Any},
spec_info::SpecInfo, ir::IRCode, mi::MethodInstance, argtypes::Vector{Any},
world::UInt, min_world::UInt, max_world::UInt)
curridx = 1
given_argtypes = Vector{Any}(undef, length(argtypes))
Expand All @@ -831,7 +831,7 @@ mutable struct IRInterpretationState
tasks = WorkThunk[]
edges = Any[]
callstack = AbsIntState[]
return new(method_info, ir, mi, world, curridx, argtypes_refined, ir.sptypes, tpdum,
return new(spec_info, ir, mi, world, curridx, argtypes_refined, ir.sptypes, tpdum,
ssa_refined, lazyreachability, valid_worlds, tasks, edges, callstack, 0, 0)
end
end
Expand All @@ -845,14 +845,13 @@ function IRInterpretationState(interp::AbstractInterpreter,
else
isa(src, CodeInfo) || return nothing
end
method_info = MethodInfo(src)
spec_info = SpecInfo(src)
ir = inflate_ir(src, mi)
argtypes = va_process_argtypes(optimizer_lattice(interp), argtypes, src.nargs, src.isva)
return IRInterpretationState(interp, method_info, ir, mi, argtypes, world,
return IRInterpretationState(interp, spec_info, ir, mi, argtypes, world,
codeinst.min_world, codeinst.max_world)
end


# AbsIntState
# ===========

Expand Down Expand Up @@ -927,11 +926,11 @@ is_constproped(::IRInterpretationState) = true
is_cached(sv::InferenceState) = !iszero(sv.cache_mode & CACHE_MODE_GLOBAL)
is_cached(::IRInterpretationState) = false

method_info(sv::InferenceState) = sv.method_info
method_info(sv::IRInterpretationState) = sv.method_info
spec_info(sv::InferenceState) = sv.spec_info
spec_info(sv::IRInterpretationState) = sv.spec_info

propagate_inbounds(sv::AbsIntState) = method_info(sv).propagate_inbounds
method_for_inference_limit_heuristics(sv::AbsIntState) = method_info(sv).method_for_inference_limit_heuristics
propagate_inbounds(sv::AbsIntState) = spec_info(sv).propagate_inbounds
method_for_inference_limit_heuristics(sv::AbsIntState) = spec_info(sv).method_for_inference_limit_heuristics

frame_world(sv::InferenceState) = sv.world
frame_world(sv::IRInterpretationState) = sv.world
Expand Down
14 changes: 7 additions & 7 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -644,10 +644,10 @@ function ((; code_cache)::GetNativeEscapeCache)(mi::MethodInstance)
return false
end

function refine_effects!(interp::AbstractInterpreter, sv::PostOptAnalysisState)
function refine_effects!(interp::AbstractInterpreter, opt::OptimizationState, sv::PostOptAnalysisState)
if !is_effect_free(sv.result.ipo_effects) && sv.all_effect_free && !isempty(sv.ea_analysis_pending)
ir = sv.ir
nargs = let def = sv.result.linfo.def; isa(def, Method) ? Int(def.nargs) : 0; end
nargs = Int(opt.src.nargs)
estate = EscapeAnalysis.analyze_escapes(ir, nargs, optimizer_lattice(interp), GetNativeEscapeCache(interp))
argescapes = EscapeAnalysis.ArgEscapeCache(estate)
stack_analysis_result!(sv.result, argescapes)
Expand Down Expand Up @@ -939,7 +939,8 @@ function check_inconsistentcy!(sv::PostOptAnalysisState, scanner::BBScanner)
end
end

function ipo_dataflow_analysis!(interp::AbstractInterpreter, ir::IRCode, result::InferenceResult)
function ipo_dataflow_analysis!(interp::AbstractInterpreter, opt::OptimizationState,
ir::IRCode, result::InferenceResult)
if !is_ipo_dataflow_analysis_profitable(result.ipo_effects)
return false
end
Expand Down Expand Up @@ -967,13 +968,13 @@ function ipo_dataflow_analysis!(interp::AbstractInterpreter, ir::IRCode, result:
end
end

return refine_effects!(interp, sv)
return refine_effects!(interp, opt, sv)
end

# run the optimization work
function optimize(interp::AbstractInterpreter, opt::OptimizationState, caller::InferenceResult)
@timeit "optimizer" ir = run_passes_ipo_safe(opt.src, opt, caller)
ipo_dataflow_analysis!(interp, ir, caller)
@timeit "optimizer" ir = run_passes_ipo_safe(opt.src, opt)
ipo_dataflow_analysis!(interp, opt, ir, caller)
return finish(interp, opt, ir, caller)
end

Expand All @@ -995,7 +996,6 @@ matchpass(::Nothing, _, _) = false
function run_passes_ipo_safe(
ci::CodeInfo,
sv::OptimizationState,
caller::InferenceResult,
optimize_until = nothing, # run all passes by default
)
__stage__ = 0 # used by @pass
Expand Down
68 changes: 37 additions & 31 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ struct InliningTodo
mi::MethodInstance
# The IR of the inlinee
ir::IRCode
# The SpecInfo for the inlinee
spec_info::SpecInfo
# The DebugInfo table for the inlinee
di::DebugInfo
# If the function being inlined is a single basic block we can use a
Expand All @@ -20,8 +22,8 @@ struct InliningTodo
# Effects of the call statement
effects::Effects
end
function InliningTodo(mi::MethodInstance, (ir, di)::Tuple{IRCode, DebugInfo}, effects::Effects)
return InliningTodo(mi, ir, di, linear_inline_eligible(ir), effects)
function InliningTodo(mi::MethodInstance, ir::IRCode, spec_info::SpecInfo, di::DebugInfo, effects::Effects)
return InliningTodo(mi, ir, spec_info, di, linear_inline_eligible(ir), effects)
end

struct ConstantCase
Expand Down Expand Up @@ -321,7 +323,8 @@ function ir_inline_linetable!(debuginfo::DebugInfoStream, inlinee_debuginfo::Deb
end

function ir_prepare_inlining!(insert_node!::Inserter, inline_target::Union{IRCode, IncrementalCompact},
ir::IRCode, di::DebugInfo, mi::MethodInstance, inlined_at::NTuple{3,Int32}, argexprs::Vector{Any})
ir::IRCode, spec_info::SpecInfo, di::DebugInfo, mi::MethodInstance,
inlined_at::NTuple{3,Int32}, argexprs::Vector{Any})
def = mi.def::Method
debuginfo = inline_target isa IRCode ? inline_target.debuginfo : inline_target.ir.debuginfo
topline = new_inlined_at = ir_inline_linetable!(debuginfo, di, inlined_at)
Expand All @@ -334,8 +337,8 @@ function ir_prepare_inlining!(insert_node!::Inserter, inline_target::Union{IRCod
spvals_ssa = insert_node!(
removable_if_unused(NewInstruction(Expr(:call, Core._compute_sparams, def, argexprs...), SimpleVector, topline)))
end
if def.isva
nargs_def = Int(def.nargs::Int32)
if spec_info.isva
nargs_def = spec_info.nargs
if nargs_def > 0
argexprs = fix_va_argexprs!(insert_node!, inline_target, argexprs, nargs_def, topline)
end
Expand All @@ -362,7 +365,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
item::InliningTodo, boundscheck::Symbol, todo_bbs::Vector{Tuple{Int, Int}})
# Ok, do the inlining here
inlined_at = compact.result[idx][:line]
ssa_substitute = ir_prepare_inlining!(InsertHere(compact), compact, item.ir, item.di, item.mi, inlined_at, argexprs)
ssa_substitute = ir_prepare_inlining!(InsertHere(compact), compact, item.ir, item.spec_info, item.di, item.mi, inlined_at, argexprs)
boundscheck = has_flag(compact.result[idx], IR_FLAG_INBOUNDS) ? :off : boundscheck

# If the iterator already moved on to the next basic block,
Expand Down Expand Up @@ -860,15 +863,14 @@ function resolve_todo(mi::MethodInstance, result::Union{Nothing,InferenceResult,
if inferred_result isa ConstantCase
add_inlining_backedge!(et, mi)
return inferred_result
end
if inferred_result isa InferredResult
elseif inferred_result isa InferredResult
(; src, effects) = inferred_result
elseif inferred_result isa CodeInstance
src = @atomic :monotonic inferred_result.inferred
effects = decode_effects(inferred_result.ipo_purity_bits)
else
src = nothing
effects = Effects()
else # there is no cached source available, bail out
return compileable_specialization(mi, Effects(), et, info;
compilesig_invokes=OptimizationParams(state.interp).compilesig_invokes)
end

# the duplicated check might have been done already within `analyze_method!`, but still
Expand All @@ -883,9 +885,12 @@ function resolve_todo(mi::MethodInstance, result::Union{Nothing,InferenceResult,
compilesig_invokes=OptimizationParams(state.interp).compilesig_invokes)

add_inlining_backedge!(et, mi)
ir = inferred_result isa CodeInstance ? retrieve_ir_for_inlining(inferred_result, src) :
retrieve_ir_for_inlining(mi, src, preserve_local_sources)
return InliningTodo(mi, ir, effects)
if inferred_result isa CodeInstance
ir, spec_info, debuginfo = retrieve_ir_for_inlining(inferred_result, src)
else
ir, spec_info, debuginfo = retrieve_ir_for_inlining(mi, src, preserve_local_sources)
end
return InliningTodo(mi, ir, spec_info, debuginfo, effects)
end

# the special resolver for :invoke-d call
Expand All @@ -901,23 +906,17 @@ function resolve_todo(mi::MethodInstance, @nospecialize(info::CallInfo), flag::U
if cached_result isa ConstantCase
add_inlining_backedge!(et, mi)
return cached_result
end
if cached_result isa InferredResult
(; src, effects) = cached_result
elseif cached_result isa CodeInstance
src = @atomic :monotonic cached_result.inferred
effects = decode_effects(cached_result.ipo_purity_bits)
else
src = nothing
effects = Effects()
else # there is no cached source available, bail out
return nothing
end

preserve_local_sources = true
src_inlining_policy(state.interp, src, info, flag) || return nothing
ir = cached_result isa CodeInstance ? retrieve_ir_for_inlining(cached_result, src) :
retrieve_ir_for_inlining(mi, src, preserve_local_sources)
ir, spec_info, debuginfo = retrieve_ir_for_inlining(cached_result, src)
add_inlining_backedge!(et, mi)
return InliningTodo(mi, ir, effects)
return InliningTodo(mi, ir, spec_info, debuginfo, effects)
end

function validate_sparams(sparams::SimpleVector)
Expand Down Expand Up @@ -971,22 +970,29 @@ function analyze_method!(match::MethodMatch, argtypes::Vector{Any},
return resolve_todo(mi, volatile_inf_result, info, flag, state; invokesig)
end

function retrieve_ir_for_inlining(cached_result::CodeInstance, src::MaybeCompressed)
src = _uncompressed_ir(cached_result, src)::CodeInfo
return inflate_ir!(src, cached_result.def), src.debuginfo
function retrieve_ir_for_inlining(cached_result::CodeInstance, src::String)
src = _uncompressed_ir(cached_result, src)
return inflate_ir!(src, cached_result.def), SpecInfo(src), src.debuginfo
end
function retrieve_ir_for_inlining(cached_result::CodeInstance, src::CodeInfo)
return inflate_ir!(copy(src), cached_result.def), SpecInfo(src), src.debuginfo
end
function retrieve_ir_for_inlining(mi::MethodInstance, src::CodeInfo, preserve_local_sources::Bool)
if preserve_local_sources
src = copy(src)
end
return inflate_ir!(src, mi), src.debuginfo
return inflate_ir!(src, mi), SpecInfo(src), src.debuginfo
end
function retrieve_ir_for_inlining(mi::MethodInstance, ir::IRCode, preserve_local_sources::Bool)
if preserve_local_sources
ir = copy(ir)
end
# COMBAK this is not correct, we should make `InferenceResult` propagate `SpecInfo`
spec_info = let m = mi.def::Method
SpecInfo(Int(m.nargs), m.isva, false, nothing)
end
ir.debuginfo.def = mi
return ir, DebugInfo(ir.debuginfo, length(ir.stmts))
return ir, spec_info, DebugInfo(ir.debuginfo, length(ir.stmts))
end

function handle_single_case!(todo::Vector{Pair{Int,Any}},
Expand Down Expand Up @@ -1466,8 +1472,8 @@ function semiconcrete_result_item(result::SemiConcreteResult,

add_inlining_backedge!(et, mi)
preserve_local_sources = OptimizationParams(state.interp).preserve_local_sources
ir = retrieve_ir_for_inlining(mi, result.ir, preserve_local_sources)
return InliningTodo(mi, ir, result.effects)
ir, _, debuginfo = retrieve_ir_for_inlining(mi, result.ir, preserve_local_sources)
return InliningTodo(mi, ir, result.spec_info, debuginfo, result.effects)
end

function handle_semi_concrete_result!(cases::Vector{InliningCase}, result::SemiConcreteResult,
Expand Down
4 changes: 2 additions & 2 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1532,7 +1532,7 @@ function try_inline_finalizer!(ir::IRCode, argexprs::Vector{Any}, idx::Int,
end

src_inlining_policy(inlining.interp, src, info, IR_FLAG_NULL) || return false
src, di = retrieve_ir_for_inlining(code, src)
src, spec_info, di = retrieve_ir_for_inlining(code, src)

# For now: Require finalizer to only have one basic block
length(src.cfg.blocks) == 1 || return false
Expand All @@ -1542,7 +1542,7 @@ function try_inline_finalizer!(ir::IRCode, argexprs::Vector{Any}, idx::Int,

# TODO: Should there be a special line number node for inlined finalizers?
inline_at = ir[SSAValue(idx)][:line]
ssa_substitute = ir_prepare_inlining!(InsertBefore(ir, SSAValue(idx)), ir, src, di, mi, inline_at, argexprs)
ssa_substitute = ir_prepare_inlining!(InsertBefore(ir, SSAValue(idx)), ir, src, spec_info, di, mi, inline_at, argexprs)

# TODO: Use the actual inliner here rather than open coding this special purpose inliner.
ssa_rename = Vector{Any}(undef, length(src.stmts))
Expand Down
1 change: 1 addition & 0 deletions base/compiler/stmtinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ struct SemiConcreteResult <: ConstResult
mi::MethodInstance
ir::IRCode
effects::Effects
spec_info::SpecInfo
end

# XXX Technically this does not represent a result of constant inference, but rather that of
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -941,7 +941,7 @@ function typeinf_ircode(interp::AbstractInterpreter, mi::MethodInstance,
end
(; result) = frame
opt = OptimizationState(frame, interp)
ir = run_passes_ipo_safe(opt.src, opt, result, optimize_until)
ir = run_passes_ipo_safe(opt.src, opt, optimize_until)
rt = widenconst(ignorelimited(result.result))
return ir, rt
end
Expand Down
7 changes: 5 additions & 2 deletions base/compiler/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,14 @@ struct StmtInfo
used::Bool
end

struct MethodInfo
struct SpecInfo
nargs::Int
isva::Bool
propagate_inbounds::Bool
method_for_inference_limit_heuristics::Union{Nothing,Method}
end
MethodInfo(src::CodeInfo) = MethodInfo(
SpecInfo(src::CodeInfo) = SpecInfo(
Int(src.nargs), src.isva,
src.propagate_inbounds,
src.method_for_inference_limit_heuristics::Union{Nothing,Method})

Expand Down
11 changes: 7 additions & 4 deletions test/compiler/EscapeAnalysis/EAUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,14 @@ CC.get_inference_world(interp::EscapeAnalyzer) = interp.world
CC.get_inference_cache(interp::EscapeAnalyzer) = interp.inf_cache
CC.cache_owner(::EscapeAnalyzer) = EAToken()

function CC.ipo_dataflow_analysis!(interp::EscapeAnalyzer, ir::IRCode, caller::InferenceResult)
function CC.ipo_dataflow_analysis!(interp::EscapeAnalyzer, opt::OptimizationState,
ir::IRCode, caller::InferenceResult)
# run EA on all frames that have been optimized
nargs = let def = caller.linfo.def; isa(def, Method) ? Int(def.nargs) : 0; end
nargs = Int(opt.src.nargs)
𝕃ₒ = CC.optimizer_lattice(interp)
get_escape_cache = GetEscapeCache(interp)
estate = try
analyze_escapes(ir, nargs, CC.optimizer_lattice(interp), get_escape_cache)
analyze_escapes(ir, nargs, 𝕃ₒ, get_escape_cache)
catch err
@error "error happened within EA, inspect `Main.failed_escapeanalysis`"
Main.failed_escapeanalysis = FailedAnalysis(ir, nargs, get_escape_cache)
Expand All @@ -133,7 +135,8 @@ function CC.ipo_dataflow_analysis!(interp::EscapeAnalyzer, ir::IRCode, caller::I
end
record_escapes!(interp, caller, estate, ir)

@invoke CC.ipo_dataflow_analysis!(interp::AbstractInterpreter, ir::IRCode, caller::InferenceResult)
@invoke CC.ipo_dataflow_analysis!(interp::AbstractInterpreter, opt::OptimizationState,
ir::IRCode, caller::InferenceResult)
end

function record_escapes!(interp::EscapeAnalyzer,
Expand Down
Loading