Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inlining/Finalizer-elimination cleanup refactor #46700

Merged
merged 1 commit into from
Sep 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 74 additions & 82 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,11 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
if extra_coverage_line != 0
insert_node_here!(compact, NewInstruction(Expr(:code_coverage_effect), Nothing, extra_coverage_line))
end
if !validate_sparams(sparam_vals)
# N.B. This works on the caller-side argexprs, (i.e. before the va fixup below)
sparam_vals = insert_node_here!(compact,
effect_free(NewInstruction(Expr(:call, Core._compute_sparams, def, argexprs...), SimpleVector, topline)))
end
if def.isva
nargs_def = Int(def.nargs::Int32)
if nargs_def > 0
Expand All @@ -378,21 +383,6 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
boundscheck = :off
end
end
if !validate_sparams(sparam_vals)
if def.isva
nonva_args = argexprs[1:end-1]
va_arg = argexprs[end]
tuple_call = Expr(:call, TOP_TUPLE, def, nonva_args...)
tuple_type = tuple_tfunc(OptimizerLattice(), Any[argextype(arg, compact) for arg in nonva_args])
tupl = insert_node_here!(compact, NewInstruction(tuple_call, tuple_type, topline))
apply_iter_expr = Expr(:call, Core._apply_iterate, iterate, Core._compute_sparams, tupl, va_arg)
sparam_vals = insert_node_here!(compact,
effect_free(NewInstruction(apply_iter_expr, SimpleVector, topline)))
else
sparam_vals = insert_node_here!(compact,
effect_free(NewInstruction(Expr(:call, Core._compute_sparams, def, argexprs...), SimpleVector, topline)))
end
end
# If the iterator already moved on to the next basic block,
# temporarily re-open in again.
local return_value
Expand Down Expand Up @@ -911,11 +901,12 @@ function may_have_fcalls(m::Method)
return may_have_fcall
end

function can_inline_typevars(m::MethodMatch, argtypes::Vector{Any})
may_have_fcalls(m.method) && return false
function can_inline_typevars(method::Method, argtypes::Vector{Any})
may_have_fcalls(method) && return false
any(@nospecialize(x) -> x isa UnionAll, argtypes[2:end]) && return false
return true
end
can_inline_typevars(m::MethodMatch, argtypes::Vector{Any}) = can_inline_typevars(m.method, argtypes)

function analyze_method!(match::MethodMatch, argtypes::Vector{Any}, invokesig,
flag::UInt8, state::InliningState, allow_typevars::Bool = false)
Expand Down Expand Up @@ -969,6 +960,19 @@ end
retrieve_ir_for_inlining(mi::MethodInstance, src::CodeInfo) = inflate_ir(src, mi)
retrieve_ir_for_inlining(mi::MethodInstance, ir::IRCode) = copy(ir)

function flags_for_effects(effects::Effects)
flags::UInt8 = 0
if is_consistent(effects)
flags |= IR_FLAG_CONSISTENT
end
if is_removable_if_unused(effects)
flags |= IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW
elseif is_nothrow(effects)
flags |= IR_FLAG_NOTHROW
end
return flags
end

function handle_single_case!(
ir::IRCode, idx::Int, stmt::Expr,
@nospecialize(case), todo::Vector{Pair{Int, Any}}, params::OptimizationParams, isinvoke::Bool = false)
Expand All @@ -979,11 +983,7 @@ function handle_single_case!(
isinvoke && rewrite_invoke_exprargs!(stmt)
stmt.head = :invoke
pushfirst!(stmt.args, case.invoke)
if is_removable_if_unused(case.effects)
ir[SSAValue(idx)][:flag] |= IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW
elseif is_nothrow(case.effects)
ir[SSAValue(idx)][:flag] |= IR_FLAG_NOTHROW
end
ir[SSAValue(idx)][:flag] |= flags_for_effects(case.effects)
elseif case === nothing
# Do, well, nothing
else
Expand Down Expand Up @@ -1274,16 +1274,39 @@ function process_simple!(ir::IRCode, idx::Int, state::InliningState, todo::Vecto
return stmt, sig
end

# TODO inline non-`isdispatchtuple`, union-split callsites?
function compute_inlining_cases(infos::Vector{MethodMatchInfo},
function handle_any_const_result!(cases::Vector{InliningCase}, @nospecialize(result), match::MethodMatch, argtypes::Vector{Any}, flag::UInt8, state::InliningState, allow_typevars::Bool=false)
if isa(result, ConcreteResult)
case = concrete_result_item(result, state)
push!(cases, InliningCase(result.mi.specTypes, case))
return true
elseif isa(result, ConstPropResult)
return handle_const_prop_result!(result, argtypes, flag, state, cases, #=allow_abstract=#true, allow_typevars)
elseif isa(result, SemiConcreteResult)
return handle_semi_concrete_result!(result, cases, #=allow_abstract=#true)
else
@assert result === nothing
return handle_match!(match, argtypes, flag, state, cases, #=allow_abstract=#true, allow_typevars)
end
end

function compute_inlining_cases(info::Union{ConstCallInfo, Vector{MethodMatchInfo}},
flag::UInt8, sig::Signature, state::InliningState)
argtypes = sig.argtypes
if isa(info, ConstCallInfo)
(; call, results) = info
infos = isa(call, MethodMatchInfo) ? MethodMatchInfo[call] : call.matches
else
results = nothing
infos = info
end
cases = InliningCase[]
local any_fully_covered = false
local handled_all_cases::Bool = true
local revisit_idx = nothing
local only_method = nothing
local meth::MethodLookupResult
local revisit_idx = nothing
local any_fully_covered = false
local handled_all_cases = true
local all_result_count = 0

for i in 1:length(infos)
meth = infos[i].results
if meth.ambig
Expand All @@ -1306,30 +1329,33 @@ function compute_inlining_cases(infos::Vector{MethodMatchInfo},
end
end
for (j, match) in enumerate(meth)
all_result_count += 1
result = results === nothing ? nothing : results[all_result_count]
any_fully_covered |= match.fully_covers
if !validate_sparams(match.sparams)
if !match.fully_covers
handled_all_cases = false
continue
end
if revisit_idx === nothing
revisit_idx = (i, j)
revisit_idx = (i, j, all_result_count)
else
handled_all_cases = false
revisit_idx = nothing
end
else
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, #=allow_abstract=#true, #=allow_typevars=#false)
handled_all_cases &= handle_any_const_result!(cases, result, match, argtypes, flag, state, false)
end
end
end

if handled_all_cases && revisit_idx !== nothing
# we handled everything except one match with unmatched sparams,
# so try to handle it by bypassing validate_sparams
(i, j) = revisit_idx
(i, j, k) = revisit_idx
match = infos[i].results[j]
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, #=allow_abstract=#true, #=allow_typevars=#true)
result = results === nothing ? nothing : results[k]
handled_all_cases &= handle_any_const_result!(cases, result, match, argtypes, flag, state, true)
elseif length(cases) == 0 && only_method isa Method
# if the signature is fully covered and there is only one applicable method,
# we can try to inline it even in the presence of unmatched sparams
Expand All @@ -1339,11 +1365,13 @@ function compute_inlining_cases(infos::Vector{MethodMatchInfo},
atype = argtypes_to_type(argtypes)
(metharg, methsp) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), atype, only_method.sig)::SimpleVector
match = MethodMatch(metharg, methsp::SimpleVector, only_method, true)
result = nothing
else
@assert length(meth) == 1
match = meth[1]
result = results === nothing ? nothing : results[1]
end
handle_match!(match, argtypes, flag, state, cases, #=allow_abstract=#true, #=allow_typevars=#true) || return nothing
handle_any_const_result!(cases, result, match, argtypes, flag, state, true)
any_fully_covered = handled_all_cases = match.fully_covers
elseif !handled_all_cases
# if we've not seen all candidates, union split is valid only for dispatch tuples
Expand All @@ -1353,52 +1381,6 @@ function compute_inlining_cases(infos::Vector{MethodMatchInfo},
return cases, handled_all_cases & any_fully_covered
end

function compute_inlining_cases(info::ConstCallInfo,
flag::UInt8, sig::Signature, state::InliningState)
argtypes = sig.argtypes
(; call, results) = info
infos = isa(call, MethodMatchInfo) ? MethodMatchInfo[call] : call.matches
cases = InliningCase[]
local any_fully_covered = false
local handled_all_cases = true
local j = 0
for i in 1:length(infos)
meth = infos[i].results
if meth.ambig
# Too many applicable methods
# Or there is a (partial?) ambiguity
return nothing
elseif length(meth) == 0
# No applicable methods; try next union split
handled_all_cases = false
continue
end
for match in meth
j += 1
result = results[j]
any_fully_covered |= match.fully_covers
if isa(result, ConcreteResult)
case = concrete_result_item(result, state)
push!(cases, InliningCase(result.mi.specTypes, case))
elseif isa(result, ConstPropResult)
handled_all_cases &= handle_const_prop_result!(result, argtypes, flag, state, cases, #=allow_abstract=#true)
elseif isa(result, SemiConcreteResult)
handled_all_cases &= handle_semi_concrete_result!(result, cases, #=allow_abstract=#true)
else
@assert result === nothing
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, #=allow_abstract=#true, #=allow_typevars=#false)
end
end
end

if !handled_all_cases
# if we've not seen all candidates, union split is valid only for dispatch tuples
filter!(case::InliningCase->isdispatchtuple(case.sig), cases)
end

return cases, handled_all_cases & any_fully_covered
end

function handle_call!(
ir::IRCode, idx::Int, stmt::Expr, infos::Vector{MethodMatchInfo}, flag::UInt8,
sig::Signature, state::InliningState, todo::Vector{Pair{Int, Any}})
Expand Down Expand Up @@ -1436,11 +1418,13 @@ end

function handle_const_prop_result!(
result::ConstPropResult, argtypes::Vector{Any}, flag::UInt8, state::InliningState,
cases::Vector{InliningCase}, allow_abstract::Bool)
cases::Vector{InliningCase}, allow_abstract::Bool, allow_typevars::Bool = false)
(; mi) = item = InliningTodo(result.result, argtypes)
spec_types = mi.specTypes
allow_abstract || isdispatchtuple(spec_types) || return false
validate_sparams(mi.sparam_vals) || return false
if !validate_sparams(mi.sparam_vals)
(allow_typevars && can_inline_typevars(mi.def, argtypes)) || return false
end
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
item === nothing && return false
push!(cases, InliningCase(spec_types, item))
Expand Down Expand Up @@ -1496,7 +1480,15 @@ function handle_const_opaque_closure_call!(
end

function handle_finalizer_call!(
ir::IRCode, stmt::Expr, info::FinalizerInfo, state::InliningState)
ir::IRCode, idx::Int, stmt::Expr, info::FinalizerInfo, state::InliningState)

# Finalizers don't return values, so if their execution is not observable,
# we can just not register them
if is_removable_if_unused(info.effects)
ir[SSAValue(idx)] = nothing
return nothing
end

# Only inline finalizers that are known nothrow and notls.
# This avoids having to set up state for finalizer isolation
is_finalizer_inlineable(info.effects) || return nothing
Expand Down Expand Up @@ -1601,7 +1593,7 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)

# Handle finalizer
if isa(info, FinalizerInfo)
handle_finalizer_call!(ir, stmt, info, state)
handle_finalizer_call!(ir, idx, stmt, info, state)
continue
end

Expand Down
3 changes: 3 additions & 0 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,9 @@ effect_free(inst::NewInstruction) =
NewInstruction(inst.stmt, inst.type, inst.info, inst.line, inst.flag | IR_FLAG_EFFECT_FREE, true)
non_effect_free(inst::NewInstruction) =
NewInstruction(inst.stmt, inst.type, inst.info, inst.line, inst.flag & ~IR_FLAG_EFFECT_FREE, true)
with_flags(inst::NewInstruction, flags::UInt8) =
NewInstruction(inst.stmt, inst.type, inst.info, inst.line, inst.flag | flags, true)


struct IRCode
stmts::InstructionStream
Expand Down
45 changes: 28 additions & 17 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -873,10 +873,15 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing, InliningState} = nothin
end
elseif is_known_call(stmt, Core.finalizer, compact)
3 <= length(stmt.args) <= 5 || continue
# Inlining performs legality checks on the finalizer to determine
# whether or not we may inline it. If so, it appends extra arguments
# at the end of the intrinsic. Detect that here.
length(stmt.args) == 5 || continue
info = compact[SSAValue(idx)][:info]
if isa(info, FinalizerInfo)
is_finalizer_inlineable(info.effects) || continue
else
# Inlining performs legality checks on the finalizer to determine
# whether or not we may inline it. If so, it appends extra arguments
# at the end of the intrinsic. Detect that here.
length(stmt.args) == 5 || continue
end
is_finalizer = true
elseif isexpr(stmt, :foreigncall)
nccallargs = length(stmt.args[3]::SimpleVector)
Expand Down Expand Up @@ -1100,7 +1105,7 @@ end

is_nothrow(ir::IRCode, pc::Int) = ir.stmts[pc][:flag] & (IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW) ≠ 0

function try_resolve_finalizer!(ir::IRCode, idx::Int, finalizer_idx::Int, defuse::SSADefUse, inlining::InliningState)
function try_resolve_finalizer!(ir::IRCode, idx::Int, finalizer_idx::Int, defuse::SSADefUse, inlining::InliningState, info::Union{FinalizerInfo, Nothing})
# For now: Require that all uses and defs are in the same basic block,
# so that live range calculations are easy.
bb = ir.cfg.blocks[block_for_inst(ir.cfg, first(defuse.uses).idx)]
Expand Down Expand Up @@ -1128,23 +1133,28 @@ function try_resolve_finalizer!(ir::IRCode, idx::Int, finalizer_idx::Int, defuse
all(check_in_range, defuse.defs) || return nothing

# For now: Require all statements in the basic block range to be nothrow.
all(minval:maxval) do idx::Int
return is_nothrow(ir, idx) || idx == finalizer_idx
all(minval:maxval) do sidx::Int
return is_nothrow(ir, idx) || sidx == finalizer_idx || sidx == idx
end || return nothing

# Ok, `finalizer` rewrite is legal.
finalizer_stmt = ir[SSAValue(finalizer_idx)][:inst]
argexprs = Any[finalizer_stmt.args[2], finalizer_stmt.args[3]]
inline = finalizer_stmt.args[4]
if inline === nothing
# No code in the function - Nothing to do
else
mi = finalizer_stmt.args[5]::MethodInstance
if inline::Bool && try_inline_finalizer!(ir, argexprs, maxval, mi, inlining)
# the finalizer body has been inlined
flags = info === nothing ? UInt8(0) : flags_for_effects(info.effects)
if length(finalizer_stmt.args) >= 4
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be length(finalizer_stmt.args) >= 5?

inline = finalizer_stmt.args[4]
if inline === nothing
# No code in the function - Nothing to do
else
insert_node!(ir, maxval, NewInstruction(Expr(:invoke, mi, argexprs...), Nothing), true)
mi = finalizer_stmt.args[5]::MethodInstance
if inline::Bool && try_inline_finalizer!(ir, argexprs, maxval, mi, inlining)
# the finalizer body has been inlined
else
insert_node!(ir, maxval, with_flags(NewInstruction(Expr(:invoke, mi, argexprs...), Nothing), flags), true)
end
end
else
insert_node!(ir, maxval, with_flags(NewInstruction(Expr(:call, argexprs...), Nothing), flags), true)
Copy link
Member

@aviatesk aviatesk Sep 12, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So is my understanding right that this branch is for generating Core.finalizer call expression that may be handled by the second run of this optimization pass?

EDIT: or maybe for external consumers who may run the inlining pass and SROA pass in a different order?

end
# Erase the call to `finalizer`
ir[SSAValue(finalizer_idx)][:inst] = nothing
Expand Down Expand Up @@ -1184,7 +1194,7 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
end
end
if finalizer_idx !== nothing && inlining !== nothing
try_resolve_finalizer!(ir, idx, finalizer_idx, defuse, inlining)
try_resolve_finalizer!(ir, idx, finalizer_idx, defuse, inlining, ir[SSAValue(finalizer_idx)][:info])
continue
end
# Partition defuses by field
Expand Down Expand Up @@ -1409,7 +1419,8 @@ end

function is_union_phi(compact::IncrementalCompact, idx::Int)
inst = compact.result[idx]
return isa(inst[:inst], PhiNode) && is_some_union(inst[:type])
isa(inst[:inst], PhiNode) || return false
return is_some_union(inst[:type])
end

"""
Expand Down
Loading