From 2e388e3731fcdd8d1db4c1aed5c6a39df3ef7153 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Mon, 25 Oct 2021 01:30:09 +0900 Subject: [PATCH 1/2] optimizer: eliminate excessive specialization in inlining code This commit includes several code quality improvements in inlining code: - eliminate excessive specializations around: * `item::Pair{Any, Any}` constructions * iterations on `Vector{Pair{Any, Any}}` - replace `Pair{Any, Any}` with new, more explicit data type `InliningCase` - remove dead code --- base/compiler/ssair/inlining.jl | 202 ++++++++++++++++---------------- 1 file changed, 102 insertions(+), 100 deletions(-) diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index 77bc6b43604f7..8cece4cf21657 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -35,7 +35,6 @@ end struct DelayedInliningSpec match::Union{MethodMatch, InferenceResult} atypes::Vector{Any} - stmttype::Any end struct InliningTodo @@ -44,23 +43,32 @@ struct InliningTodo spec::Union{ResolvedInliningSpec, DelayedInliningSpec} end -InliningTodo(mi::MethodInstance, match::MethodMatch, - atypes::Vector{Any}, @nospecialize(stmttype)) = InliningTodo(mi, DelayedInliningSpec(match, atypes, stmttype)) +InliningTodo(mi::MethodInstance, match::MethodMatch, atypes::Vector{Any}) = + InliningTodo(mi, DelayedInliningSpec(match, atypes)) -InliningTodo(result::InferenceResult, atypes::Vector{Any}, @nospecialize(stmttype)) = - InliningTodo(result.linfo, DelayedInliningSpec(result, atypes, stmttype)) +InliningTodo(result::InferenceResult, atypes::Vector{Any}) = + InliningTodo(result.linfo, DelayedInliningSpec(result, atypes)) struct ConstantCase val::Any ConstantCase(val) = new(val) end +struct InliningCase + sig # ::Type + item # Union{InliningTodo, MethodInstance, ConstantCase} + function InliningCase(@nospecialize(sig), @nospecialize(item)) + @assert isa(item, Union{InliningTodo, MethodInstance, ConstantCase}) "invalid inlining item" + return 
new(sig, item) + end +end + struct UnionSplit fully_covered::Bool atype # ::Type - cases::Vector{Pair{Any, Any}} + cases::Vector{InliningCase} bbs::Vector{Int} - UnionSplit(fully_covered::Bool, atype, cases::Vector{Pair{Any, Any}}) = + UnionSplit(fully_covered::Bool, atype, cases::Vector{InliningCase}) = new(fully_covered, atype, cases, Int[]) end @@ -137,14 +145,13 @@ function cfg_inline_item!(ir::IRCode, idx::Int, spec::ResolvedInliningSpec, stat need_split = true #!(idx == last_block_idx) end - if !need_split - delete!(state.merged_orig_blocks, last(new_range)) - end + need_split || delete!(state.merged_orig_blocks, last(new_range)) push!(state.todo_bbs, (length(state.new_cfg_blocks) - 1 + (need_split_before ? 1 : 0), post_bb_id)) from_unionsplit || delete!(state.split_targets, length(state.new_cfg_blocks)) - orig_succs = copy(state.new_cfg_blocks[end].succs) + local orig_succs + need_split && (orig_succs = copy(state.new_cfg_blocks[end].succs)) empty!(state.new_cfg_blocks[end].succs) if need_split_before l = length(state.new_cfg_blocks) @@ -204,53 +211,51 @@ function cfg_inline_item!(ir::IRCode, idx::Int, spec::ResolvedInliningSpec, stat end end end + any_edges || push!(state.dead_blocks, post_bb_id) - if !any_edges - push!(state.dead_blocks, post_bb_id) - end + return nothing end -function cfg_inline_unionsplit!(ir::IRCode, idx::Int, item::UnionSplit, state::CFGInliningState) - block = block_for_inst(ir, idx) - inline_into_block!(state, block) +function cfg_inline_unionsplit!(ir::IRCode, idx::Int, + (; fully_covered, #=atype,=# cases, bbs)::UnionSplit, + state::CFGInliningState) + inline_into_block!(state, block_for_inst(ir, idx)) from_bbs = Int[] delete!(state.split_targets, length(state.new_cfg_blocks)) orig_succs = copy(state.new_cfg_blocks[end].succs) empty!(state.new_cfg_blocks[end].succs) - for (i, (_, case)) in enumerate(item.cases) + for i in 1:length(cases) # The condition gets sunk into the previous block # Add a block for the union-split body 
push!(state.new_cfg_blocks, BasicBlock(StmtRange(idx, idx))) cond_bb = length(state.new_cfg_blocks)-1 push!(state.new_cfg_blocks[end].preds, cond_bb) push!(state.new_cfg_blocks[cond_bb].succs, cond_bb+1) + case = cases[i].item if isa(case, InliningTodo) spec = case.spec::ResolvedInliningSpec if !spec.linear_inline_eligible cfg_inline_item!(ir, idx, spec, state, true) end end - bb = length(state.new_cfg_blocks) - push!(from_bbs, bb) + push!(from_bbs, length(state.new_cfg_blocks)) # TODO: Right now we unconditionally generate a fallback block # in case of subtyping errors - This is probably unnecessary. - if true # i != length(item.cases) || !item.fully_covered + if true # i != length(cases) || !fully_covered # This block will have the next condition or the final else case push!(state.new_cfg_blocks, BasicBlock(StmtRange(idx, idx))) push!(state.new_cfg_blocks[cond_bb].succs, length(state.new_cfg_blocks)) push!(state.new_cfg_blocks[end].preds, cond_bb) - push!(item.bbs, length(state.new_cfg_blocks)) + push!(bbs, length(state.new_cfg_blocks)) end end # The edge from the fallback block. 
- if !item.fully_covered - push!(from_bbs, length(state.new_cfg_blocks)) - end + fully_covered || push!(from_bbs, length(state.new_cfg_blocks)) # This block will be the block everyone returns to push!(state.new_cfg_blocks, BasicBlock(StmtRange(idx, idx), from_bbs, orig_succs)) join_bb = length(state.new_cfg_blocks) push!(state.split_targets, join_bb) - push!(item.bbs, join_bb) + push!(bbs, join_bb) for bb in from_bbs push!(state.new_cfg_blocks[bb].succs, join_bb) end @@ -258,8 +263,10 @@ end function finish_cfg_inline!(state::CFGInliningState) new_range = (state.first_bb + 1):length(state.cfg.blocks) - l = length(state.new_cfg_blocks) - state.bb_rename[new_range] = (l+1:l+length(new_range)) + state.bb_rename[new_range] = let + l = length(state.new_cfg_blocks) + l+1:l+length(new_range) + end append!(state.new_cfg_blocks, state.cfg.blocks[new_range]) # Rename edges original bbs @@ -307,7 +314,6 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector sparam_vals = item.mi.sparam_vals def = item.mi.def::Method inline_cfg = spec.ir.cfg - stmt = compact.result[idx][:inst] linetable_offset::Int32 = length(linetable) # Append the linetable of the inlined function to our line table inlined_at = Int(compact.result[idx][:line]) @@ -339,8 +345,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector vararg = mk_tuplecall!(compact, argexprs[nargs_def:end], topline) argexprs = Any[argexprs[1:(nargs_def - 1)]..., vararg] end - is_opaque = def.is_for_opaque_closure - if is_opaque + if def.is_for_opaque_closure # Replace the first argument by a load of the capture environment argexprs[1] = insert_node_here!(compact, NewInstruction(Expr(:call, GlobalRef(Core, :getfield), argexprs[1], QuoteNode(:captures)), @@ -358,7 +363,6 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector local return_value # Special case inlining that maintains the current basic block if there's only one BB in the target if 
spec.linear_inline_eligible - terminator = spec.ir[SSAValue(last(inline_cfg.blocks[1].stmts))] #compact[idx] = nothing inline_compact = IncrementalCompact(compact, spec.ir, compact.result_idx) for ((_, idx′), stmt′) in inline_compact @@ -450,16 +454,18 @@ const FATAL_TYPE_BOUND_ERROR = ErrorException("fatal error in type inference (ty function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, argexprs::Vector{Any}, linetable::Vector{LineInfoNode}, - item::UnionSplit, boundscheck::Symbol, todo_bbs::Vector{Tuple{Int, Int}}) + (; fully_covered, atype, cases, bbs)::UnionSplit, + boundscheck::Symbol, todo_bbs::Vector{Tuple{Int, Int}}) stmt, typ, line = compact.result[idx][:inst], compact.result[idx][:type], compact.result[idx][:line] - atype = item.atype - generic_bb = item.bbs[end-1] - join_bb = item.bbs[end] - bb = compact.active_result_bb + join_bb = bbs[end] pn = PhiNode() - has_generic = false - @assert length(item.bbs) > length(item.cases) - for ((metharg, case), next_cond_bb) in zip(item.cases, item.bbs) + local bb = compact.active_result_bb + @assert length(bbs) > length(cases) + for i in 1:length(cases) + ithcase = cases[i] + metharg = ithcase.sig + case = ithcase.item + next_cond_bb = bbs[i] @assert !isa(metharg, UnionAll) cond = true aparams, mparams = atype.parameters::SimpleVector, metharg.parameters::SimpleVector @@ -515,7 +521,7 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, end bb += 1 # We're now in the fall through block, decide what to do - if item.fully_covered + if fully_covered e = Expr(:call, GlobalRef(Core, :throw), FATAL_TYPE_BOUND_ERROR) insert_node_here!(compact, NewInstruction(e, Union{}, line)) insert_node_here!(compact, NewInstruction(ReturnNode(), Union{}, line)) @@ -677,7 +683,8 @@ function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx:: handled = false if isa(info, ConstCallInfo) if !is_stmt_noinline(flag) && maybe_handle_const_call!( - ir, state1.id, new_stmt, info, 
new_sig,call.rt, istate, flag, false, todo) + ir, state1.id, new_stmt, info, new_sig, + istate, flag, false, todo) handled = true else info = info.call @@ -687,8 +694,9 @@ function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx:: info = isa(info, MethodMatchInfo) ? MethodMatchInfo[info] : info.matches # See if we can inline this call to `iterate` - analyze_single_call!(ir, todo, state1.id, new_stmt, - new_sig, call.rt, info, istate, flag) + analyze_single_call!( + ir, todo, state1.id, new_stmt, + new_sig, info, istate, flag) end if i != length(thisarginfo.each) valT = getfield_tfunc(call.rt, Const(1)) @@ -708,11 +716,13 @@ function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx:: return new_argexprs, new_atypes end -function rewrite_invoke_exprargs!(argexprs::Vector{Any}) +function rewrite_invoke_exprargs!(expr::Expr) + argexprs = expr.args argexpr0 = argexprs[2] - argexprs = argexprs[4:end] - pushfirst!(argexprs, argexpr0) - return argexprs + argexprs = argexprs[3:end] + argexprs[1] = argexpr0 + expr.args = argexprs + return expr end function compileable_specialization(et::Union{EdgeTracker, Nothing}, match::MethodMatch) @@ -778,9 +788,15 @@ function resolve_todo(todo::InliningTodo, state::InliningState, flag::UInt8) return InliningTodo(mi, src) end -function resolve_todo(todo::UnionSplit, state::InliningState, flag::UInt8) - UnionSplit(todo.fully_covered, todo.atype, - Pair{Any,Any}[sig=>resolve_todo(item, state, flag) for (sig, item) in todo.cases]) +function resolve_todo((; fully_covered, atype, cases, #=bbs=#)::UnionSplit, state::InliningState, flag::UInt8) + ncases = length(cases) + newcases = Vector{InliningCase}(undef, ncases) + for i in 1:ncases + (; sig, item) = cases[i] + newitem = resolve_todo(item, state, flag) + newcases[i] = InliningCase(sig, newitem) # fill preallocated slot i; push! would append after the undef entries + end + return UnionSplit(fully_covered, atype, newcases) end function validate_sparams(sparams::SimpleVector) @@ -791,7 +807,7 @@ function 
validate_sparams(sparams::SimpleVector) end function analyze_method!(match::MethodMatch, atypes::Vector{Any}, - state::InliningState, @nospecialize(stmttyp), flag::UInt8) + state::InliningState, flag::UInt8) method = match.method methsig = method.sig @@ -821,7 +837,7 @@ function analyze_method!(match::MethodMatch, atypes::Vector{Any}, return compileable_specialization(et, match) end - todo = InliningTodo(mi, match, atypes, stmttyp) + todo = InliningTodo(mi, match, atypes) # If we don't have caches here, delay resolving this MethodInstance # until the batch inlining step (or an external post-processing pass) state.mi_cache === nothing && return todo @@ -846,17 +862,13 @@ function handle_single_case!(ir::IRCode, stmt::Expr, idx::Int, @nospecialize(cas if isa(case, ConstantCase) ir[SSAValue(idx)] = case.val elseif isa(case, MethodInstance) - if isinvoke - stmt.args = rewrite_invoke_exprargs!(stmt.args) - end + isinvoke && rewrite_invoke_exprargs!(stmt) stmt.head = :invoke pushfirst!(stmt.args, case) elseif case === nothing # Do, well, nothing else - if isinvoke - stmt.args = rewrite_invoke_exprargs!(stmt.args) - end + isinvoke && rewrite_invoke_exprargs!(stmt) push!(todo, idx=>(case::InliningTodo)) end nothing @@ -1005,7 +1017,6 @@ is_builtin(s::Signature) = function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, (; match, result)::InvokeCallInfo, state::InliningState, todo::Vector{Pair{Int, Any}}, flag::UInt8) stmt = ir.stmts[idx][:inst] - calltype = ir.stmts[idx][:type] if !match.fully_covers # TODO: We could union split out the signature check and continue on @@ -1018,7 +1029,7 @@ function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, (; match, result): pushfirst!(atypes, atype0) if isa(result, InferenceResult) && !is_stmt_noinline(flag) - (; mi) = item = InliningTodo(result, atypes, calltype) + (; mi) = item = InliningTodo(result, atypes) validate_sparams(mi.sparam_vals) || return nothing if argtypes_to_type(atypes) <: mi.def.sig state.mi_cache !== 
nothing && (item = resolve_todo(item, state, flag)) @@ -1027,7 +1038,7 @@ function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, (; match, result): end end - result = analyze_method!(match, atypes, state, calltype, flag) + result = analyze_method!(match, atypes, state, flag) handle_single_case!(ir, stmt, idx, result, true, todo) return nothing end @@ -1136,13 +1147,12 @@ function process_simple!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int, sta return sig end -function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int, @nospecialize(stmt), - sig::Signature, @nospecialize(calltype), infos::Vector{MethodMatchInfo}, - state::InliningState, flag::UInt8) - cases = Pair{Any, Any}[] - signature_union = Union{} - only_method = nothing # keep track of whether there is one matching method - too_many = false +function analyze_single_call!( + ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int, @nospecialize(stmt), + sig::Signature, infos::Vector{MethodMatchInfo}, state::InliningState, flag::UInt8) + cases = InliningCase[] + local signature_union = Bottom + local only_method = nothing # keep track of whether there is one matching method local meth local fully_covered = true for i in 1:length(infos) @@ -1151,8 +1161,7 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int if meth.ambig # Too many applicable methods # Or there is a (partial?) 
ambiguity - too_many = true - break + return elseif length(meth) == 0 # No applicable methods; try next union split continue @@ -1172,19 +1181,17 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int fully_covered = false continue end - case = analyze_method!(match, sig.atypes, state, calltype, flag) - if case === nothing + item = analyze_method!(match, sig.atypes, state, flag) + if item === nothing fully_covered = false continue - elseif _any(p->p[1] === spec_types, cases) + elseif _any(case->case.sig === spec_types, cases) continue end - push!(cases, Pair{Any,Any}(spec_types, case)) + push!(cases, InliningCase(spec_types, item)) end end - too_many && return - signature_fully_covered = sig.atype <: signature_union # If we're fully covered and there's only one applicable method, # we inline, even if the signature is not a dispatch tuple @@ -1199,9 +1206,9 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int match = meth[1] end fully_covered = true - case = analyze_method!(match, sig.atypes, state, calltype, flag) - case === nothing && return - push!(cases, Pair{Any,Any}(match.spec_types, case)) + item = analyze_method!(match, sig.atypes, state, flag) + item === nothing && return + push!(cases, InliningCase(match.spec_types, item)) end if !signature_fully_covered fully_covered = false @@ -1211,7 +1218,7 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int # be able to do the inlining now (for constant cases), or push it directly # onto the todo list if fully_covered && length(cases) == 1 - handle_single_case!(ir, stmt, idx, cases[1][2], false, todo) + handle_single_case!(ir, stmt, idx, cases[1].item, false, todo) return end length(cases) == 0 && return @@ -1219,31 +1226,26 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int return nothing end -function maybe_handle_const_call!(ir::IRCode, idx::Int, stmt::Expr, - info::ConstCallInfo, 
sig::Signature, @nospecialize(calltype), - state::InliningState, flag::UInt8, - isinvoke::Bool, todo::Vector{Pair{Int, Any}}) +function maybe_handle_const_call!( + ir::IRCode, idx::Int, stmt::Expr, info::ConstCallInfo, sig::Signature, + state::InliningState, flag::UInt8, isinvoke::Bool, todo::Vector{Pair{Int, Any}}) # when multiple matches are found, bail out and later inliner will union-split this signature # TODO effectively use multiple constant analysis results here length(info.results) == 1 || return false result = info.results[1] isa(result, InferenceResult) || return false - (; mi) = item = InliningTodo(result, sig.atypes, calltype) + (; mi) = item = InliningTodo(result, sig.atypes) validate_sparams(mi.sparam_vals) || return true - mthd_sig = mi.def.sig - mistypes = mi.specTypes state.mi_cache !== nothing && (item = resolve_todo(item, state, flag)) - if sig.atype <: mthd_sig + if sig.atype <: mi.def.sig handle_single_case!(ir, stmt, idx, item, isinvoke, todo) return true else item === nothing && return true # Union split out the error case - item = UnionSplit(false, sig.atype, Pair{Any, Any}[mistypes => item]) - if isinvoke - stmt.args = rewrite_invoke_exprargs!(stmt.args) - end + item = UnionSplit(false, sig.atype, InliningCase[InliningCase(mi.specTypes, item)]) + isinvoke && rewrite_invoke_exprargs!(stmt) push!(todo, idx=>item) return true end @@ -1258,11 +1260,11 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState) sig === nothing && continue stmt = ir.stmts[idx][:inst] - calltype = ir.stmts[idx][:type] info = ir.stmts[idx][:info] # Check whether this call was @pure and evaluates to a constant if info isa MethodResultPure + calltype = ir.stmts[idx][:type] if calltype isa Const && is_inlineable_constant(calltype.val) ir.stmts[idx][:inst] = quoted(calltype.val) continue @@ -1278,12 +1280,12 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState) continue end - # If inference arrived at this result by using constant propagation, - 
# it'll have performed a specialized analysis for just this case. Use its - # result. + # if inference arrived here with constant-prop'ed result(s), + # we can perform a specialized analysis for just this case if isa(info, ConstCallInfo) if !is_stmt_noinline(flag) && maybe_handle_const_call!( - ir, idx, stmt, info, sig, calltype, state, flag, sig.f === Core.invoke, todo) + ir, idx, stmt, info, sig, + state, flag, sig.f === Core.invoke, todo) continue else info = info.call @@ -1291,7 +1293,7 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState) end if isa(info, OpaqueClosureCallInfo) - result = analyze_method!(info.match, sig.atypes, state, calltype, flag) + result = analyze_method!(info.match, sig.atypes, state, flag) handle_single_case!(ir, stmt, idx, result, false, todo) continue end @@ -1313,7 +1315,7 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState) continue end - analyze_single_call!(ir, todo, idx, stmt, sig, calltype, infos, state, flag) + analyze_single_call!(ir, todo, idx, stmt, sig, infos, state, flag) end todo end From 1510eaa93e60c9c8d7e92fd1d296b78ded49d52e Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Mon, 25 Oct 2021 01:35:12 +0900 Subject: [PATCH 2/2] optimizer: fix #42754, inline union-split const-prop'ed sources MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit complements #39754 and #39305: implements a logic to use constant-prop'ed results for inlining at union-split callsite. Currently it works only for cases when constant-prop' succeeded for all (union-split) signatures. 
> example ```julia julia> mutable struct X # NOTE in order to confuse `fieldtype_tfunc`, we need to have at least two fields with different types a::Union{Nothing, Int} b::Symbol end; julia> code_typed((X, Union{Nothing,Int})) do x, a # this `setproperty` call would be union-split and constant-prop will happen for # each signature: inlining would fail if we don't use constant-prop'ed source # since the approximated inlining cost of `convert(fieldtype(X, sym), a)` would # end up very high if we don't propagate `sym::Const(:a)` x.a = a x end |> only |> first ``` > before this commit ```julia CodeInfo( 1 ─ %1 = Base.setproperty!::typeof(setproperty!) │ %2 = (isa)(a, Nothing)::Bool └── goto #3 if not %2 2 ─ %4 = π (a, Nothing) │ invoke %1(_2::X, :a::Symbol, %4::Nothing)::Any └── goto #6 3 ─ %7 = (isa)(a, Int64)::Bool └── goto #5 if not %7 4 ─ %9 = π (a, Int64) │ invoke %1(_2::X, :a::Symbol, %9::Int64)::Any └── goto #6 5 ─ Core.throw(ErrorException("fatal error in type inference (type bound)"))::Union{} └── unreachable 6 ┄ return x ) ``` > after this commit ```julia CodeInfo( 1 ─ %1 = (isa)(a, Nothing)::Bool └── goto #3 if not %1 2 ─ Base.setfield!(x, :a, nothing)::Nothing └── goto #6 3 ─ %5 = (isa)(a, Int64)::Bool └── goto #5 if not %5 4 ─ %7 = π (a, Int64) │ Base.setfield!(x, :a, %7)::Int64 └── goto #6 5 ─ Core.throw(ErrorException("fatal error in type inference (type bound)"))::Union{} └── unreachable 6 ┄ return x ) ``` --- base/compiler/ssair/inlining.jl | 147 ++++++++++++++++++++++---------- test/compiler/inline.jl | 77 +++++++++++++++++ 2 files changed, 177 insertions(+), 47 deletions(-) diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index 8cece4cf21657..7c622e50482e5 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -1147,9 +1147,10 @@ function process_simple!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int, sta return sig end +# TODO inline non-`isdispatchtuple`, union-split callsites 
function analyze_single_call!( ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int, @nospecialize(stmt), - sig::Signature, infos::Vector{MethodMatchInfo}, state::InliningState, flag::UInt8) + (; atypes, atype)::Signature, infos::Vector{MethodMatchInfo}, state::InliningState, flag::UInt8) cases = InliningCase[] local signature_union = Bottom local only_method = nothing # keep track of whether there is one matching method @@ -1181,7 +1182,7 @@ function analyze_single_call!( fully_covered = false continue end - item = analyze_method!(match, sig.atypes, state, flag) + item = analyze_method!(match, atypes, state, flag) if item === nothing fully_covered = false continue @@ -1192,25 +1193,25 @@ function analyze_single_call!( end end - signature_fully_covered = sig.atype <: signature_union - # If we're fully covered and there's only one applicable method, - # we inline, even if the signature is not a dispatch tuple - if signature_fully_covered && length(cases) == 0 && only_method isa Method - if length(infos) > 1 - (metharg, methsp) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), - sig.atype, only_method.sig)::SimpleVector - match = MethodMatch(metharg, methsp, only_method, true) - else - meth = meth::MethodLookupResult - @assert length(meth) == 1 - match = meth[1] + # if the signature is fully covered and there is only one applicable method, + # we can try to inline it even if the signature is not a dispatch tuple + if atype <: signature_union + if length(cases) == 0 && only_method isa Method + if length(infos) > 1 + (metharg, methsp) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), + atype, only_method.sig)::SimpleVector + match = MethodMatch(metharg, methsp, only_method, true) + else + meth = meth::MethodLookupResult + @assert length(meth) == 1 + match = meth[1] + end + item = analyze_method!(match, atypes, state, flag) + item === nothing && return + push!(cases, InliningCase(match.spec_types, item)) + fully_covered = true end - fully_covered = true - 
item = analyze_method!(match, sig.atypes, state, flag) - item === nothing && return - push!(cases, InliningCase(match.spec_types, item)) - end - if !signature_fully_covered + else fully_covered = false end @@ -1219,36 +1220,81 @@ function analyze_single_call!( # onto the todo list if fully_covered && length(cases) == 1 handle_single_case!(ir, stmt, idx, cases[1].item, false, todo) - return + elseif length(cases) > 0 + push!(todo, idx=>UnionSplit(fully_covered, atype, cases)) end - length(cases) == 0 && return - push!(todo, idx=>UnionSplit(fully_covered, sig.atype, cases)) return nothing end +# try to create `InliningCase`s using constant-prop'ed results +# currently it works only when constant-prop' succeeded for all (union-split) signatures +# TODO use any of constant-prop'ed results, and leave the other unhandled cases to later +# TODO this function contains a lot of duplications with `analyze_single_call!`, factor them out function maybe_handle_const_call!( - ir::IRCode, idx::Int, stmt::Expr, info::ConstCallInfo, sig::Signature, + ir::IRCode, idx::Int, stmt::Expr, (; results)::ConstCallInfo, (; atypes, atype)::Signature, state::InliningState, flag::UInt8, isinvoke::Bool, todo::Vector{Pair{Int, Any}}) - # when multiple matches are found, bail out and later inliner will union-split this signature - # TODO effectively use multiple constant analysis results here - length(info.results) == 1 || return false - result = info.results[1] - isa(result, InferenceResult) || return false - - (; mi) = item = InliningTodo(result, sig.atypes) - validate_sparams(mi.sparam_vals) || return true - state.mi_cache !== nothing && (item = resolve_todo(item, state, flag)) - if sig.atype <: mi.def.sig - handle_single_case!(ir, stmt, idx, item, isinvoke, todo) - return true + cases = InliningCase[] # TODO avoid this allocation for single cases ? 
+ local fully_covered = true + local signature_union = Bottom + for result in results + isa(result, InferenceResult) || return false + (; mi) = item = InliningTodo(result, atypes) + spec_types = mi.specTypes + signature_union = Union{signature_union, spec_types} + if !isdispatchtuple(spec_types) + fully_covered = false + continue + end + if !validate_sparams(mi.sparam_vals) + fully_covered = false + continue + end + state.mi_cache !== nothing && (item = resolve_todo(item, state, flag)) + if item === nothing + fully_covered = false + continue + end + push!(cases, InliningCase(spec_types, item)) + end + + # if the signature is fully covered and there is only one applicable method, + # we can try to inline it even if the signature is not a dispatch tuple + if atype <: signature_union + if length(cases) == 0 && length(results) == 1 + (; mi) = item = InliningTodo(results[1]::InferenceResult, atypes) + state.mi_cache !== nothing && (item = resolve_todo(item, state, flag)) + validate_sparams(mi.sparam_vals) || return true + item === nothing && return true + push!(cases, InliningCase(mi.specTypes, item)) + fully_covered = true + end else - item === nothing && return true - # Union split out the error case - item = UnionSplit(false, sig.atype, InliningCase[InliningCase(mi.specTypes, item)]) + fully_covered = false + end + + # If we only have one case and that case is fully covered, we may either + # be able to do the inlining now (for constant cases), or push it directly + # onto the todo list + if fully_covered && length(cases) == 1 + handle_single_case!(ir, stmt, idx, cases[1].item, isinvoke, todo) + elseif length(cases) > 0 isinvoke && rewrite_invoke_exprargs!(stmt) - push!(todo, idx=>item) - return true + push!(todo, idx=>UnionSplit(fully_covered, atype, cases)) end + return true +end + +function handle_const_opaque_closure_call!( + ir::IRCode, idx::Int, stmt::Expr, (; results)::ConstCallInfo, + (; atypes)::Signature, state::InliningState, flag::UInt8, 
todo::Vector{Pair{Int, Any}}) + @assert length(results) == 1 + result = results[1]::InferenceResult + item = InliningTodo(result, atypes) + isdispatchtuple(item.mi.specTypes) || return + validate_sparams(item.mi.sparam_vals) || return + state.mi_cache !== nothing && (item = resolve_todo(item, state, flag)) + handle_single_case!(ir, stmt, idx, item, false, todo) + return nothing end function assemble_inline_todo!(ir::IRCode, state::InliningState) @@ -1283,18 +1329,25 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState) # if inference arrived here with constant-prop'ed result(s), # we can perform a specialized analysis for just this case if isa(info, ConstCallInfo) - if !is_stmt_noinline(flag) && maybe_handle_const_call!( - ir, idx, stmt, info, sig, - state, flag, sig.f === Core.invoke, todo) - continue + if !is_stmt_noinline(flag) + if isa(info.call, OpaqueClosureCallInfo) + handle_const_opaque_closure_call!( + ir, idx, stmt, info, + sig, state, flag, todo) + continue + else + maybe_handle_const_call!( + ir, idx, stmt, info, sig, + state, flag, sig.f === Core.invoke, todo) && continue + end else info = info.call end end if isa(info, OpaqueClosureCallInfo) - result = analyze_method!(info.match, sig.atypes, state, flag) - handle_single_case!(ir, stmt, idx, result, false, todo) + item = analyze_method!(info.match, sig.atypes, state, flag) + handle_single_case!(ir, stmt, idx, item, false, todo) continue end diff --git a/test/compiler/inline.jl b/test/compiler/inline.jl index 6bdb71bf8f292..a891937c72942 100644 --- a/test/compiler/inline.jl +++ b/test/compiler/inline.jl @@ -680,3 +680,80 @@ let f(x) = (x...,) # the the original apply call is not union-split, but the inserted `iterate` call is. 
@test code_typed(f, Tuple{Union{Int64, CartesianIndex{1}, CartesianIndex{3}}})[1][2] == Tuple{Int64} end + +# https://github.com/JuliaLang/julia/issues/42754 +# inline union-split constant-prop'ed sources +mutable struct X42754 + # NOTE in order to confuse `fieldtype_tfunc`, we need to have at least two fields with different types + a::Union{Nothing, Int} + b::Symbol +end +let code = code_typed1((X42754, Union{Nothing,Int})) do x, a + # this `setproperty` call would be union-split and constant-prop will happen for + # each signature: inlining would fail if we don't use constant-prop'ed source + # since the approximate inlining cost of `convert(fieldtype(X, sym), a)` would + # end up very high if we don't propagate `sym::Const(:a)` + x.a = a + x + end + @test all(code) do @nospecialize(x) + isinvoke(x, :setproperty!) && return false + if Meta.isexpr(x, :call) + f = x.args[1] + isa(f, GlobalRef) && f.name === :setproperty! && return false + end + return true + end +end + +import Base: @constprop + +# test single, non-dispatchtuple callsite inlining + +@constprop :none @inline test_single_nondispatchtuple(@nospecialize(t)) = + isa(t, DataType) && t.name === Type.body.name +let + code = code_typed1((Any,)) do x + test_single_nondispatchtuple(x) + end + @test all(code) do @nospecialize(x) + isinvoke(x, :test_single_nondispatchtuple) && return false + if Meta.isexpr(x, :call) + f = x.args[1] + isa(f, GlobalRef) && f.name === :test_single_nondispatchtuple && return false + end + return true + end +end + +@constprop :aggressive @inline test_single_nondispatchtuple(c, @nospecialize(t)) = + c && isa(t, DataType) && t.name === Type.body.name +let + code = code_typed1((Any,)) do x + test_single_nondispatchtuple(true, x) + end + @test all(code) do @nospecialize(x) + isinvoke(x, :test_single_nondispatchtuple) && return false + if Meta.isexpr(x, :call) + f = x.args[1] + isa(f, GlobalRef) && f.name === :test_single_nondispatchtuple && return false + end + return true + end +end + 
+# validate inlining processing + +@constprop :none @inline validate_unionsplit_inlining(@nospecialize(t)) = throw("invalid inlining processing detected") +@constprop :none @noinline validate_unionsplit_inlining(i::Integer) = (println(IOBuffer(), "prevent inlining"); false) +let + invoke(xs) = validate_unionsplit_inlining(xs[1]) + @test invoke(Any[10]) === false +end + +@constprop :aggressive @inline validate_unionsplit_inlining(c, @nospecialize(t)) = c && throw("invalid inlining processing detected") +@constprop :aggressive @noinline validate_unionsplit_inlining(c, i::Integer) = c && (println(IOBuffer(), "prevent inlining"); false) +let + invoke(xs) = validate_unionsplit_inlining(true, xs[1]) + @test invoke(Any[10]) === false +end