From 67bede44d933b0c946ab2e023561fc727ac96947 Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Fri, 3 Jul 2020 02:17:29 -0400 Subject: [PATCH] Propagate iteration info to optimizer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This supersedes #36169. Rather than re-implementing the iteration analysis as done there, this uses the new stmtinfo infrastrcture to propagate all the analysis done during inference all the way to inlining. As a result, it applies not only to splats of singletons, but also to splats of any other short iterable that inference can analyze. E.g.: ``` f(x) = (x...,) @code_typed f(1=>2) @benchmark f(1=>2) ``` Before: ``` julia> @code_typed f(1=>2) CodeInfo( 1 ─ %1 = Core._apply_iterate(Base.iterate, Core.tuple, x)::Tuple{Int64,Int64} └── return %1 ) => Tuple{Int64,Int64} julia> @benchmark f(1=>2) BenchmarkTools.Trial: memory estimate: 96 bytes allocs estimate: 3 -------------- minimum time: 242.659 ns (0.00% GC) median time: 246.904 ns (0.00% GC) mean time: 255.390 ns (1.08% GC) maximum time: 4.415 μs (93.94% GC) -------------- samples: 10000 evals/sample: 405 ``` After: ``` julia> @code_typed f(1=>2) CodeInfo( 1 ─ %1 = Base.getfield(x, 1)::Int64 │ %2 = Base.getfield(x, 2)::Int64 │ %3 = Core.tuple(%1, %2)::Tuple{Int64,Int64} └── return %3 ) => Tuple{Int64,Int64} julia> @benchmark f(1=>2) BenchmarkTools.Trial: memory estimate: 0 bytes allocs estimate: 0 -------------- minimum time: 1.701 ns (0.00% GC) median time: 1.925 ns (0.00% GC) mean time: 1.904 ns (0.00% GC) maximum time: 6.941 ns (0.00% GC) -------------- samples: 10000 evals/sample: 1000 ``` I also implemented the TODO, I had left in #36169 to inline the iterate calls themselves, which gives another 3x improvement over the solution in that PR: ``` julia> @code_typed f(1) CodeInfo( 1 ─ %1 = Core.tuple(x)::Tuple{Int64} └── return %1 ) => Tuple{Int64} julia> @benchmark f(1) BenchmarkTools.Trial: memory estimate: 0 bytes allocs estimate: 0 -------------- minimum time: 1.696 ns (0.00% GC) median time: 1.699 ns (0.00% GC) mean time: 1.702 ns (0.00% GC) maximum time: 5.389 ns (0.00% GC) -------------- samples: 10000 evals/sample: 1000 ``` Fixes #36087 Fixes #29114 --- base/compiler/abstractinterpretation.jl | 80 +++--- base/compiler/ssair/inlining.jl | 311 ++++++++++++++---------- base/compiler/ssair/ir.jl | 18 +- base/compiler/ssair/passes.jl | 4 +- base/compiler/stmtinfo.jl | 24 +- 5 files changed, 259 insertions(+), 178 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 298e3ccd65728c..5adc2f7492f0ad 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -78,7 +78,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), push!(fullmatch, thisfullmatch) end end - info = UnionSplitInfo(splitsigs, infos) + info = UnionSplitInfo(infos) else mt = ccall(:jl_method_table_for, Any, (Any,), atype) if mt === nothing @@ -505,13 +505,13 @@ end # returns an array of types function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft), @nospecialize(typ), vtypes::VarTable, sv::InferenceState) if isa(typ, PartialStruct) && typ.typ.name === Tuple.name - return typ.fields + return typ.fields, nothing end if isa(typ, Const) val = typ.val if isa(val, SimpleVector) || isa(val, Tuple) - return Any[ Const(val[i]) for i in 1:length(val) ] # avoid making a tuple Generator here! + return Any[ Const(val[i]) for i in 1:length(val) ], nothing # avoid making a tuple Generator here! end end @@ -529,27 +529,27 @@ function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft) if isa(tti, Union) utis = uniontypes(tti) if _any(t -> !isa(t, DataType) || !(t <: Tuple) || !isknownlength(t), utis) - return Any[Vararg{Any}] + return Any[Vararg{Any}], nothing end result = Any[rewrap_unionall(p, tti0) for p in utis[1].parameters] for t in utis[2:end] if length(t.parameters) != length(result) - return Any[Vararg{Any}] + return Any[Vararg{Any}], nothing end for j in 1:length(t.parameters) result[j] = tmerge(result[j], rewrap_unionall(t.parameters[j], tti0)) end end - return result + return result, nothing elseif tti0 <: Tuple if isa(tti0, DataType) if isvatuple(tti0) && length(tti0.parameters) == 1 - return Any[Vararg{unwrapva(tti0.parameters[1])}] + return Any[Vararg{unwrapva(tti0.parameters[1])}], nothing else - return Any[ p for p in tti0.parameters ] + return Any[ p for p in tti0.parameters ], nothing end elseif !isa(tti, DataType) - return Any[Vararg{Any}] + return Any[Vararg{Any}], nothing else len = length(tti.parameters) last = tti.parameters[len] @@ -558,12 +558,12 @@ function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft) if va elts[len] = Vararg{elts[len]} end - return elts + return elts, nothing end elseif tti0 === SimpleVector || tti0 === Any - return Any[Vararg{Any}] + return Any[Vararg{Any}], nothing elseif tti0 <: Array - return Any[Vararg{eltype(tti0)}] + return Any[Vararg{eltype(tti0)}], nothing else return abstract_iteration(interp, itft, typ, vtypes, sv) end @@ -572,7 +572,7 @@ end # simulate iteration protocol on container type up to fixpoint function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @nospecialize(itertype), vtypes::VarTable, sv::InferenceState) if !isdefined(Main, :Base) || !isdefined(Main.Base, :iterate) || !isconst(Main.Base, :iterate) - return Any[Vararg{Any}] + return Any[Vararg{Any}], nothing end if itft === nothing iteratef = getfield(Main.Base, :iterate) @@ -580,22 +580,27 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n elseif isa(itft, Const) iteratef = itft.val else - return Any[Vararg{Any}] + return Any[Vararg{Any}], nothing end @assert !isvarargtype(itertype) - stateordonet = abstract_call_known(interp, iteratef, nothing, Any[itft, itertype], vtypes, sv).rt + call = abstract_call_known(interp, iteratef, nothing, Any[itft, itertype], vtypes, sv) + stateordonet = call.rt + info = call.info # Return Bottom if this is not an iterator. # WARNING: Changes to the iteration protocol must be reflected here, # this is not just an optimization. - stateordonet === Bottom && return Any[Bottom] + stateordonet === Bottom && return Any[Bottom], AbstractIterationInfo(Any[Bottom], Any[info]) valtype = statetype = Bottom ret = Any[] + states = Any[stateordonet] + infos = Any[info] + # Try to unroll the iteration up to MAX_TUPLE_SPLAT, which covers any finite # length iterators, or interesting prefix while true stateordonet_widened = widenconst(stateordonet) if stateordonet_widened === Nothing - return ret + return ret, AbstractIterationInfo(states, infos) end if Nothing <: stateordonet_widened || length(ret) >= InferenceParams(interp).MAX_TUPLE_SPLAT break @@ -607,12 +612,15 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n # If there's no new information in this statetype, don't bother continuing, # the iterator won't be finite. if nstatetype ⊑ statetype - return Any[Bottom] + return Any[Bottom], nothing end valtype = getfield_tfunc(stateordonet, Const(1)) push!(ret, valtype) statetype = nstatetype - stateordonet = abstract_call_known(interp, iteratef, nothing, Any[Const(iteratef), itertype, statetype], vtypes, sv).rt + call = abstract_call_known(interp, iteratef, nothing, Any[Const(iteratef), itertype, statetype], vtypes, sv) + stateordonet = call.rt + push!(states, stateordonet) + push!(infos, call.info) end # From here on, we start asking for results on the widened types, rather than # the precise (potentially const) state type @@ -629,7 +637,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n if nounion.parameters[1] <: valtype && nounion.parameters[2] <: statetype if typeintersect(stateordonet, Nothing) === Union{} # Reached a fixpoint, but Nothing is not possible => iterator is infinite or failing - return Any[Bottom] + return Any[Bottom], nothing end break end @@ -637,7 +645,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n statetype = tmerge(statetype, nounion.parameters[2]) end push!(ret, Vararg{valtype}) - return ret + return ret, nothing end # do apply(af, fargs...), where af is a function value @@ -656,13 +664,15 @@ function abstract_apply(interp::AbstractInterpreter, @nospecialize(itft), @nospe nargs = length(aargtypes) splitunions = 1 < countunionsplit(aargtypes) <= InferenceParams(interp).MAX_APPLY_UNION_ENUM ctypes = Any[Any[aft]] + infos = [Union{Nothing, AbstractIterationInfo}[]] for i = 1:nargs ctypes´ = [] + infos′ = [] for ti in (splitunions ? uniontypes(aargtypes[i]) : Any[aargtypes[i]]) if !isvarargtype(ti) - cti = precise_container_type(interp, itft, ti, vtypes, sv) + cti, info = precise_container_type(interp, itft, ti, vtypes, sv) else - cti = precise_container_type(interp, itft, unwrapva(ti), vtypes, sv) + cti, info = precise_container_type(interp, itft, unwrapva(ti), vtypes, sv) # We can't represent a repeating sequence of the same types, # so tmerge everything together to get one type that represents # everything. @@ -678,19 +688,29 @@ function abstract_apply(interp::AbstractInterpreter, @nospecialize(itft), @nospe if _any(t -> t === Bottom, cti) continue end - for ct in ctypes + for j = 1:length(ctypes) + ct = ctypes[j] if isvarargtype(ct[end]) + # This is vararg, we're not gonna be able to do any inling, + # drop the info + info = nothing + tail = tuple_tail_elem(unwrapva(ct[end]), cti) push!(ctypes´, push!(ct[1:(end - 1)], tail)) else push!(ctypes´, append!(ct[:], cti)) end + push!(infos′, push!(copy(infos[j]), info)) end end ctypes = ctypes´ + infos = infos′ end - local info = nothing - for ct in ctypes + retinfos = ApplyCallInfo[] + retinfo = UnionSplitApplyCallInfo(retinfos) + for i = 1:length(ctypes) + ct = ctypes[i] + arginfo = infos[i] lct = length(ct) # truncate argument list at the first Vararg for i = 1:lct-1 @@ -701,15 +721,17 @@ function abstract_apply(interp::AbstractInterpreter, @nospecialize(itft), @nospe end end call = abstract_call(interp, nothing, ct, vtypes, sv, max_methods) - info = call.info + push!(retinfos, ApplyCallInfo(call.info, arginfo)) res = tmerge(res, call.rt) if res === Any + # No point carrying forward the info, we're not gonna inline it anyway + retinfo = nothing break end end # TODO: Add a special info type to capture all the iteration info. # For now, only propagate info if we don't also union-split the iteration - return CallMeta(res, length(ctypes) == 1 ? info : false) + return CallMeta(res, retinfo) end function is_method_pure(method::Method, @nospecialize(sig), sparams::SimpleVector) @@ -779,7 +801,7 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, fargs::U end rt = builtin_tfunction(interp, f, argtypes[2:end], sv) if f === getfield && isa(fargs, Vector{Any}) && la == 3 && isa(argtypes[3], Const) && isa(argtypes[3].val, Int) && argtypes[2] ⊑ Tuple - cti = precise_container_type(interp, nothing, argtypes[2], vtypes, sv) + cti, _ = precise_container_type(interp, nothing, argtypes[2], vtypes, sv) idx = argtypes[3].val if 1 <= idx <= length(cti) rt = unwrapva(cti[idx]) diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index 91458186a1c6be..8ba39480d62327 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -330,7 +330,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector terminator = item.ir[SSAValue(last(inline_cfg.blocks[1].stmts))] #compact[idx] = nothing inline_compact = IncrementalCompact(compact, item.ir, compact.result_idx) - for (idx′, stmt′) in inline_compact + for ((_, idx′), stmt′) in inline_compact # This dance is done to maintain accurate usage counts in the # face of rename_arguments! mutating in place - should figure out # something better eventually. @@ -360,7 +360,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector pn = PhiNode() #compact[idx] = nothing inline_compact = IncrementalCompact(compact, item.ir, compact.result_idx) - for (idx′, stmt′) in inline_compact + for ((_, idx′), stmt′) in inline_compact inline_compact[idx′] = nothing stmt′ = ssa_substitute!(idx′, stmt′, argexprs, item.method.sig, item.sparams, linetable_offset, boundscheck_idx, compact) if isa(stmt′, ReturnNode) @@ -529,8 +529,8 @@ function batch_inline!(todo::Vector{Any}, ir::IRCode, linetable::Vector{LineInfo resize!(compact, nnewnodes) item = popfirst!(todo) inline_idx = item.idx - for (idx, stmt) in compact - if compact.idx - 1 == inline_idx + for ((old_idx, idx), stmt) in compact + if old_idx == inline_idx argexprs = copy(stmt.args) refinish = false if compact.result_idx == first(compact.result_bbs[compact.active_result_bb].stmts) @@ -550,7 +550,7 @@ function batch_inline!(todo::Vector{Any}, ir::IRCode, linetable::Vector{LineInfo end end if isa(item, InliningTodo) - compact.ssa_rename[compact.idx-1] = ir_inline_item!(compact, idx, argexprs, linetable, item, boundscheck, state.todo_bbs) + compact.ssa_rename[old_idx] = ir_inline_item!(compact, idx, argexprs, linetable, item, boundscheck, state.todo_bbs) elseif isa(item, UnionSplit) ir_inline_unionsplit!(compact, idx, argexprs, linetable, item, boundscheck, state.todo_bbs) end @@ -596,49 +596,74 @@ function spec_lambda(@nospecialize(atype), sv::OptimizationState, @nospecialize( end # This assumes the caller has verified that all arguments to the _apply call are Tuples. -function rewrite_apply_exprargs!(ir::IRCode, idx::Int, argexprs::Vector{Any}, atypes::Vector{Any}, arg_start::Int) +function rewrite_apply_exprargs!(ir::IRCode, todo, idx::Int, argexprs::Vector{Any}, atypes::Vector{Any}, arginfos::Vector{Any}, arg_start::Int, sv) new_argexprs = Any[argexprs[arg_start]] new_atypes = Any[atypes[arg_start]] # loop over original arguments and flatten any known iterators for i in (arg_start+1):length(argexprs) def = argexprs[i] def_type = atypes[i] - if def_type isa PartialStruct - # def_type.typ <: Tuple is assumed - def_atypes = def_type.fields - else - def_atypes = Any[] - if isa(def_type, Const) # && isa(def_type.val, Union{Tuple, SimpleVector}) is implied - for p in def_type.val - push!(def_atypes, Const(p)) - end + thisarginfo = arginfos[i-arg_start] + if thisarginfo === nothing + if def_type isa PartialStruct + # def_type.typ <: Tuple is assumed + def_atypes = def_type.fields else - ti = widenconst(def_type) - if ti.name === NamedTuple_typename - ti = ti.parameters[2] - end - for p in ti.parameters - if isa(p, DataType) && isdefined(p, :instance) - # replace singleton types with their equivalent Const object - p = Const(p.instance) - elseif isconstType(p) - p = Const(p.parameters[1]) + def_atypes = Any[] + if isa(def_type, Const) # && isa(def_type.val, Union{Tuple, SimpleVector}) is implied + for p in def_type.val + push!(def_atypes, Const(p)) + end + else + ti = widenconst(def_type) + if ti.name === NamedTuple_typename + ti = ti.parameters[2] + end + for p in ti.parameters + if isa(p, DataType) && isdefined(p, :instance) + # replace singleton types with their equivalent Const object + p = Const(p.instance) + elseif isconstType(p) + p = Const(p.parameters[1]) + end + push!(def_atypes, p) end - push!(def_atypes, p) end end - end - # now push flattened types into new_atypes and getfield exprs into new_argexprs - for j in 1:length(def_atypes) - def_atype = def_atypes[j] - if isa(def_atype, Const) && is_inlineable_constant(def_atype.val) - new_argexpr = quoted(def_atype.val) - else - new_call = Expr(:call, Core.getfield, def, j) - new_argexpr = insert_node!(ir, idx, def_atype, new_call) + # now push flattened types into new_atypes and getfield exprs into new_argexprs + for j in 1:length(def_atypes) + def_atype = def_atypes[j] + if isa(def_atype, Const) && is_inlineable_constant(def_atype.val) + new_argexpr = quoted(def_atype.val) + else + new_call = Expr(:call, Core.getfield, def, j) + new_argexpr = insert_node!(ir, idx, def_atype, new_call) + end + push!(new_argexprs, new_argexpr) + push!(new_atypes, def_atype) + end + else + state = Core.svec() + for i = 1:length(thisarginfo.each) + mthd = thisarginfo.each[i] + T = thisarginfo.it_rt[i] + new_stmt = Expr(:call, argexprs[2], def, state...) + state1 = insert_node!(ir, idx, T, new_stmt) + new_sig = with_atype(call_sig(ir, new_stmt)) + # See if we can inline this call to `iterate` + analyze_single_call!(ir, todo, state1.id, new_stmt, + new_sig, T, Any[mthd], sv) + if i != length(thisarginfo.each) + valT = getfield_tfunc(T, Const(1)) + val_extracted = insert_node!(ir, idx, valT, + Expr(:call, Core.getfield, state1, 1)) + push!(new_argexprs, val_extracted) + push!(new_atypes, valT) + state_extracted = insert_node!(ir, idx, getfield_tfunc(T, Const(2)), + Expr(:call, Core.getfield, state1, 2)) + state = Core.svec(state_extracted) + end end - push!(new_argexprs, new_argexpr) - push!(new_atypes, def_atype) end end return new_argexprs, new_atypes @@ -876,9 +901,22 @@ function call_sig(ir::IRCode, stmt::Expr) Signature(f, ft, atypes) end -function inline_apply!(ir::IRCode, idx::Int, sig::Signature, params::OptimizationParams) +function inline_apply!(ir::IRCode, todo, idx::Int, sig::Signature, params::OptimizationParams, sv) stmt = ir.stmts[idx][:inst] while sig.f === Core._apply || sig.f === Core._apply_iterate + info = ir.stmts[idx][:info] + if isa(info, UnionSplitApplyCallInfo) + if length(info.infos) != 1 + # TODO: Handle union split applies? + new_info = info = nothing + else + info = info.infos[1] + new_info = info.call + end + else + @assert info === nothing || info === false + new_info = info = nothing + end arg_start = sig.f === Core._apply ? 2 : 3 atypes = sig.atypes if arg_start > length(atypes) @@ -906,15 +944,22 @@ function inline_apply!(ir::IRCode, idx::Int, sig::Signature, params::Optimizatio end # Try to figure out the signature of the function being called # and if rewrite_apply_exprargs can deal with this form + infos = Any[] for i = (arg_start + 1):length(atypes) - # TODO: We could basically run the iteration protocol here + thisarginfo = nothing if !is_valid_type_for_apply_rewrite(atypes[i], params) - return nothing + if isa(info, ApplyCallInfo) && info.arginfo[i-arg_start] !== nothing + thisarginfo = info.arginfo[i-arg_start] + else + return nothing + end end + push!(infos, thisarginfo) end # Independent of whether we can inline, the above analysis allows us to rewrite # this apply call to a regular call - stmt.args, atypes = rewrite_apply_exprargs!(ir, idx, stmt.args, atypes, arg_start) + stmt.args, atypes = rewrite_apply_exprargs!(ir, todo, idx, stmt.args, atypes, infos, arg_start, sv) + ir.stmts[idx][:info] = new_info has_free_typevars(ft) && return nothing f = singleton_type(ft) sig = Signature(f, ft, atypes) @@ -945,7 +990,7 @@ end # Handles all analysis and inlining of intrinsics and builtins. In particular, # this method does not access the method table or otherwise process generic # functions. -function process_simple!(ir::IRCode, idx::Int, params::OptimizationParams, world::UInt) +function process_simple!(ir::IRCode, todo, idx::Int, params::OptimizationParams, world::UInt, sv) stmt = ir.stmts[idx][:inst] stmt isa Expr || return nothing if stmt.head === :splatnew @@ -959,7 +1004,7 @@ function process_simple!(ir::IRCode, idx::Int, params::OptimizationParams, world sig === nothing && return nothing # Handle _apply - sig = inline_apply!(ir, idx, sig, params) + sig = inline_apply!(ir, todo, idx, sig, params, sv) sig === nothing && return nothing # Check if we match any of the early inliners @@ -997,7 +1042,7 @@ end # This is not currently called in the regular course, but may be needed # if we ever want to re-run inlining again later in the pass pipeline after # additional type information was discovered. -function recompute_method_matches(atype, sv) +function recompute_method_matches(@nospecialize(atype), sv::OptimizationState) # Regular case: Retrieve matching methods from cache (or compute them) # World age does not need to be taken into account in the cache # because it is forwarded from type inference through `sv.params` @@ -1010,13 +1055,95 @@ function recompute_method_matches(atype, sv) MethodMatchInfo(meth, ambig) end +function analyze_single_call!(ir, todo, idx, stmt, sig, calltype, infos, sv) + cases = Pair{Any, Any}[] + signature_union = Union{} + only_method = nothing # keep track of whether there is one matching method + too_many = false + local meth + local fully_covered = true + for i in 1:length(infos) + info = infos[i] + meth = info.applicable + if meth === false || info.ambig + # Too many applicable methods + too_many = true + break + elseif length(meth) == 0 + # No applicable methods; try next union split + continue + elseif length(meth) == 1 && only_method !== false + if only_method === nothing + only_method = meth[1][3] + elseif only_method !== meth[1][3] + only_method = false + end + else + only_method = false + end + for match in meth::Vector{Any} + (metharg, methsp, method) = (match[1]::Type, match[2]::SimpleVector, match[3]::Method) + # TODO: This could be better + signature_union = Union{signature_union, metharg} + if !isdispatchtuple(metharg) + fully_covered = false + continue + end + case_sig = Signature(sig.f, sig.ft, sig.atypes, metharg) + case = analyze_method!(idx, case_sig, metharg, methsp, method, + stmt, sv, false, nothing, calltype) + if case === nothing + fully_covered = false + continue + elseif _any(p->p[1] === metharg, cases) + continue + end + push!(cases, Pair{Any,Any}(metharg, case)) + end + end + + too_many && return + + signature_fully_covered = sig.atype <: signature_union + # If we're fully covered and there's only one applicable method, + # we inline, even if the signature is not a dispatch tuple + if signature_fully_covered && length(cases) == 0 && only_method isa Method + if length(infos) > 1 + method = only_method + (metharg, methsp) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), + sig.atype, method.sig)::SimpleVector + else + @assert length(meth) == 1 + (metharg, methsp, method) = (meth[1][1]::Type, meth[1][2]::SimpleVector, meth[1][3]::Method) + end + fully_covered = true + case = analyze_method!(idx, sig, metharg, methsp, method, + stmt, sv, false, nothing, calltype) + case === nothing && return + push!(cases, Pair{Any,Any}(metharg, case)) + end + if !signature_fully_covered + fully_covered = false + end + + # If we only have one case and that case is fully covered, we may either + # be able to do the inlining now (for constant cases), or push it directly + # onto the todo list + if fully_covered && length(cases) == 1 + handle_single_case!(ir, stmt, idx, cases[1][2], false, todo) + return + end + length(cases) == 0 && return + push!(todo, UnionSplit(idx, fully_covered, sig.atype, cases)) +end + function assemble_inline_todo!(ir::IRCode, sv::OptimizationState) # todo = (inline_idx, (isva, isinvoke, na), method, spvals, inline_linetable, inline_ir, lie) todo = Any[] skip = find_throw_blocks(ir.stmts.inst, RefValue(ir)) for idx in 1:length(ir.stmts) idx in skip && continue - r = process_simple!(ir, idx, sv.params, sv.world) + r = process_simple!(ir, todo, idx, sv.params, sv.world, sv) r === nothing && continue stmt = ir.stmts[idx][:inst] @@ -1039,107 +1166,21 @@ function assemble_inline_todo!(ir::IRCode, sv::OptimizationState) nu = countunionsplit(sig.atypes) if nu == 1 || nu > sv.params.MAX_UNION_SPLITTING if !isa(info, MethodMatchInfo) - info = nothing + info = recompute_method_matches(sig.atype, sv) end infos = Any[info] - splits = Any[sig.atype] else if !isa(info, UnionSplitInfo) - splits = Any[] + infos = MethodMatchInfo[] for union_sig in UnionSplitSignature(sig.atypes) - push!(splits, argtypes_to_type(union_sig)) + push!(infos, recompute_method_matches(union_sig, sv)) end - infos = Any[nothing for i = 1:length(splits)] else - splits = info.sigs infos = info.matches end end - cases = Pair{Any, Any}[] - signature_union = Union{} - only_method = nothing # keep track of whether there is one matching method - too_many = false - local meth - local fully_covered = true - for i in 1:length(splits) - atype = splits[i] - info = infos[i] - if info === nothing - info = recompute_method_matches(atype, sv) - end - meth = info.applicable - if meth === false || info.ambig - # Too many applicable methods - # Or there is a (partial?) ambiguity - too_many = true - break - elseif length(meth) == 0 - # No applicable methods; try next union split - continue - elseif length(meth) == 1 && only_method !== false - if only_method === nothing - only_method = meth[1][3] - elseif only_method !== meth[1][3] - only_method = false - end - else - only_method = false - end - for match in meth::Vector{Any} - (metharg, methsp, method) = (match[1]::Type, match[2]::SimpleVector, match[3]::Method) - # TODO: This could be better - signature_union = Union{signature_union, metharg} - if !isdispatchtuple(metharg) - fully_covered = false - continue - end - case_sig = Signature(sig.f, sig.ft, sig.atypes, metharg) - case = analyze_method!(idx, case_sig, metharg, methsp, method, - stmt, sv, false, nothing, calltype) - if case === nothing - fully_covered = false - continue - elseif _any(p->p[1] === metharg, cases) - continue - end - push!(cases, Pair{Any,Any}(metharg, case)) - end - end - - too_many && continue - - signature_fully_covered = sig.atype <: signature_union - # If we're fully covered and there's only one applicable method, - # we inline, even if the signature is not a dispatch tuple - if signature_fully_covered && length(cases) == 0 && only_method isa Method - if length(splits) > 1 - method = only_method - (metharg, methsp) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), - sig.atype, method.sig)::SimpleVector - else - @assert length(meth) == 1 - (metharg, methsp, method) = (meth[1][1]::Type, meth[1][2]::SimpleVector, meth[1][3]::Method) - end - fully_covered = true - case = analyze_method!(idx, sig, metharg, methsp, method, - stmt, sv, false, nothing, calltype) - case === nothing && continue - push!(cases, Pair{Any,Any}(metharg, case)) - end - if !signature_fully_covered - fully_covered = false - end - - # If we only have one case and that case is fully covered, we may either - # be able to do the inlining now (for constant cases), or push it directly - # onto the todo list - if fully_covered && length(cases) == 1 - handle_single_case!(ir, stmt, idx, cases[1][2], false, todo) - continue - end - length(cases) == 0 && continue - push!(todo, UnionSplit(idx, fully_covered, sig.atype, cases)) + analyze_single_call!(ir, todo, idx, stmt, sig, calltype, infos, sv) end todo end diff --git a/base/compiler/ssair/ir.jl b/base/compiler/ssair/ir.jl index 5e34e20831c82d..621072527a334a 100644 --- a/base/compiler/ssair/ir.jl +++ b/base/compiler/ssair/ir.jl @@ -275,8 +275,11 @@ function getindex(x::IRCode, s::SSAValue) end function setindex!(x::IRCode, @nospecialize(repl), s::SSAValue) - @assert s.id <= length(x.stmts) - x.stmts[s.id][:inst] = repl + if s.id <= length(x.stmts) + x.stmts[s.id][:inst] = repl + else + x.new_nodes.stmts[s.id - length(x.stmts)][:inst] = repl + end return x end @@ -1074,7 +1077,9 @@ function process_newnode!(compact::IncrementalCompact, new_idx::Int, new_node_en finish_current_bb!(compact, active_bb, old_result_idx) end (old_result_idx == result_idx) && return iterate(compact, (idx, active_bb)) - return Pair{Int, Any}(old_result_idx, compact.result[old_result_idx][:inst]), (idx, active_bb) + return Pair{Pair{Int, Int}, Any}( + Pair{Int,Int}(new_idx,old_result_idx), + compact.result[old_result_idx][:inst]), (idx, active_bb) end struct CompactPeekIterator @@ -1141,9 +1146,9 @@ function iterate(compact::IncrementalCompact, (idx, active_bb)::Tuple{Int, Int}= # Move to next block compact.idx += 1 if finish_current_bb!(compact, active_bb, old_result_idx, true) - return iterate(compact, (compact.idx, active_bb + 1)) + return iterate(compact, (compact.idx-1, active_bb + 1)) else - return Pair{Int, Any}(old_result_idx, compact.result[old_result_idx][:inst]), (compact.idx, active_bb + 1) + return Pair{Pair{Int, Int}, Any}(Pair{Int,Int}(compact.idx-1, old_result_idx), compact.result[old_result_idx][:inst]), (compact.idx, active_bb + 1) end end if compact.new_nodes_idx <= length(compact.perm) && @@ -1180,7 +1185,8 @@ function iterate(compact::IncrementalCompact, (idx, active_bb)::Tuple{Int, Int}= @goto restart end @assert isassigned(compact.result.inst, old_result_idx) - return Pair{Int, Any}(old_result_idx, compact.result[old_result_idx][:inst]), (compact.idx, active_bb) + return Pair{Pair{Int,Int}, Any}(Pair{Int,Int}(compact.idx-1, old_result_idx), + compact.result[old_result_idx][:inst]), (compact.idx, active_bb) end function maybe_erase_unused!(extra_worklist, compact, idx, callback = x->nothing) diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index 4b444aa5047150..a2dc6d6e75f605 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -531,7 +531,7 @@ function getfield_elim_pass!(ir::IRCode, domtree::DomTree) lifting_cache = IdDict{Pair{AnySSAValue, Any}, AnySSAValue}() revisit_worklist = Int[] #ndone, nmax = 0, 200 - for (idx, stmt) in compact + for ((_, idx), stmt) in compact isa(stmt, Expr) || continue #ndone >= nmax && continue #ndone += 1 @@ -872,7 +872,7 @@ function adce_pass!(ir::IRCode) phi_uses = fill(0, length(ir.stmts) + length(ir.new_nodes)) all_phis = Int[] compact = IncrementalCompact(ir) - for (idx, stmt) in compact + for ((_, idx), stmt) in compact if isa(stmt, PhiNode) push!(all_phis, idx) end diff --git a/base/compiler/stmtinfo.jl b/base/compiler/stmtinfo.jl index dda51817f76bef..cc39903d08bbab 100644 --- a/base/compiler/stmtinfo.jl +++ b/base/compiler/stmtinfo.jl @@ -4,15 +4,27 @@ struct MethodMatchInfo end struct UnionSplitInfo - # TODO: In principle we shouldn't have to store this, but could just - # recompute it using `switchtuple` union. However, it is not the case - # that if T == S, then switchtupleunion(T) == switchtupleunion(S), e.g. for - # T = Tuple{Tuple{Union{Float64, Int64},String}} - # S = Tuple{Union{Tuple{Float64, String}, Tuple{Int64, String}}} - sigs::Vector{Any} matches::Vector{MethodMatchInfo} end +struct AbstractIterationInfo + # The rt for each iterate call + it_rt::Vector{Any} + # The call info, for each implied call to `iterate` (in order) + each::Vector{Any} +end + +struct ApplyCallInfo + # The info for the call itself + call::Any + # AbstractIterationInfo for each argument, if applicable + arginfo::Vector{Union{Nothing, AbstractIterationInfo}} +end + +struct UnionSplitApplyCallInfo + infos::Vector{ApplyCallInfo} +end + struct CallMeta rt::Any info::Any