diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index 7bb583ffd8bbc6..a3faba0e76f5d9 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -75,7 +75,7 @@ function ssa_inlining_pass!(ir::IRCode, linetable::Vector{LineInfoNode}, state:: @timeit "analysis" todo = assemble_inline_todo!(ir, state) isempty(todo) && return ir # Do the actual inlining for every call we identified - @timeit "execution" ir = batch_inline!(todo, ir, linetable, propagate_inbounds) + @timeit "execution" ir = batch_inline!(todo, ir, linetable, propagate_inbounds, state.params) return ir end @@ -210,7 +210,8 @@ end function cfg_inline_unionsplit!(ir::IRCode, idx::Int, (; fully_covered, #=atype,=# cases, bbs)::UnionSplit, - state::CFGInliningState) + state::CFGInliningState, + params::OptimizationParams) inline_into_block!(state, block_for_inst(ir, idx)) from_bbs = Int[] delete!(state.split_targets, length(state.new_cfg_blocks)) @@ -233,7 +234,7 @@ function cfg_inline_unionsplit!(ir::IRCode, idx::Int, push!(from_bbs, length(state.new_cfg_blocks)) # TODO: Right now we unconditionally generate a fallback block # in case of subtyping errors - This is probably unnecessary. - if true # i != length(cases) || !fully_covered + if i != length(cases) || (!fully_covered || !params.trust_inference) # This block will have the next condition or the final else case push!(state.new_cfg_blocks, BasicBlock(StmtRange(idx, idx))) push!(state.new_cfg_blocks[cond_bb].succs, length(state.new_cfg_blocks)) @@ -457,12 +458,13 @@ const FATAL_TYPE_BOUND_ERROR = ErrorException("fatal error in type inference (ty function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, argexprs::Vector{Any}, linetable::Vector{LineInfoNode}, (; fully_covered, atype, cases, bbs)::UnionSplit, - boundscheck::Symbol, todo_bbs::Vector{Tuple{Int, Int}}) + boundscheck::Symbol, todo_bbs::Vector{Tuple{Int, Int}}, + params::OptimizationParams) stmt, typ, line = compact.result[idx][:inst], compact.result[idx][:type], compact.result[idx][:line] join_bb = bbs[end] pn = PhiNode() local bb = compact.active_result_bb - @assert length(bbs) > length(cases) + @assert length(bbs) >= length(cases) for i in 1:length(cases) ithcase = cases[i] metharg = ithcase.sig @@ -472,21 +474,23 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, cond = true aparams, mparams = atype.parameters::SimpleVector, metharg.parameters::SimpleVector @assert length(aparams) == length(mparams) - for i in 1:length(aparams) - a, m = aparams[i], mparams[i] - # If this is always true, we don't need to check for it - a <: m && continue - # Generate isa check - isa_expr = Expr(:call, isa, argexprs[i], m) - ssa = insert_node_here!(compact, NewInstruction(isa_expr, Bool, line)) - if cond === true - cond = ssa - else - and_expr = Expr(:call, and_int, cond, ssa) - cond = insert_node_here!(compact, NewInstruction(and_expr, Bool, line)) + if i != length(cases) || !fully_covered || !params.trust_inference + for i in 1:length(aparams) + a, m = aparams[i], mparams[i] + # If this is always true, we don't need to check for it + a <: m && continue + # Generate isa check + isa_expr = Expr(:call, isa, argexprs[i], m) + ssa = insert_node_here!(compact, NewInstruction(isa_expr, Bool, line)) + if cond === true + cond = ssa + else + and_expr = Expr(:call, and_int, cond, ssa) + cond = insert_node_here!(compact, NewInstruction(and_expr, Bool, line)) + end end + insert_node_here!(compact, NewInstruction(GotoIfNot(cond, next_cond_bb), Union{}, line)) end - insert_node_here!(compact, NewInstruction(GotoIfNot(cond, next_cond_bb), Union{}, line)) bb = next_cond_bb - 1 finish_current_bb!(compact, 0) argexprs′ = argexprs @@ -525,10 +529,12 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, bb += 1 # We're now in the fall through block, decide what to do if fully_covered - e = Expr(:call, GlobalRef(Core, :throw), FATAL_TYPE_BOUND_ERROR) - insert_node_here!(compact, NewInstruction(e, Union{}, line)) - insert_node_here!(compact, NewInstruction(ReturnNode(), Union{}, line)) - finish_current_bb!(compact, 0) + if !params.trust_inference + e = Expr(:call, GlobalRef(Core, :throw), FATAL_TYPE_BOUND_ERROR) + insert_node_here!(compact, NewInstruction(e, Union{}, line)) + insert_node_here!(compact, NewInstruction(ReturnNode(), Union{}, line)) + finish_current_bb!(compact, 0) + end else ssa = insert_node_here!(compact, NewInstruction(stmt, typ, line)) push!(pn.edges, bb) @@ -541,12 +547,12 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, return insert_node_here!(compact, NewInstruction(pn, typ, line)) end -function batch_inline!(todo::Vector{Pair{Int, Any}}, ir::IRCode, linetable::Vector{LineInfoNode}, propagate_inbounds::Bool) +function batch_inline!(todo::Vector{Pair{Int, Any}}, ir::IRCode, linetable::Vector{LineInfoNode}, propagate_inbounds::Bool, params::OptimizationParams) # Compute the new CFG first (modulo statement ranges, which will be computed below) state = CFGInliningState(ir) for (idx, item) in todo if isa(item, UnionSplit) - cfg_inline_unionsplit!(ir, idx, item::UnionSplit, state) + cfg_inline_unionsplit!(ir, idx, item::UnionSplit, state, params) else item = item::InliningTodo spec = item.spec::ResolvedInliningSpec @@ -599,7 +605,7 @@ function batch_inline!(todo::Vector{Pair{Int, Any}}, ir::IRCode, linetable::Vect if isa(item, InliningTodo) compact.ssa_rename[old_idx] = ir_inline_item!(compact, idx, argexprs, linetable, item, boundscheck, state.todo_bbs) elseif isa(item, UnionSplit) - compact.ssa_rename[old_idx] = ir_inline_unionsplit!(compact, idx, argexprs, linetable, item, boundscheck, state.todo_bbs) + compact.ssa_rename[old_idx] = ir_inline_unionsplit!(compact, idx, argexprs, linetable, item, boundscheck, state.todo_bbs, params) end compact[idx] = nothing refinish && finish_current_bb!(compact, 0) @@ -845,7 +851,7 @@ end function handle_single_case!( ir::IRCode, idx::Int, stmt::Expr, - @nospecialize(case), todo::Vector{Pair{Int, Any}}, isinvoke::Bool = false) + @nospecialize(case), todo::Vector{Pair{Int, Any}}, params::OptimizationParams, isinvoke::Bool = false) if isa(case, ConstantCase) ir[SSAValue(idx)][:inst] = case.val elseif isa(case, MethodInstance) @@ -1017,12 +1023,12 @@ function inline_invoke!( validate_sparams(mi.sparam_vals) || return nothing if argtypes_to_type(argtypes) <: mi.def.sig state.mi_cache !== nothing && (item = resolve_todo(item, state, flag)) - handle_single_case!(ir, idx, stmt, item, todo, true) + handle_single_case!(ir, idx, stmt, item, todo, state.params, true) return nothing end end item = analyze_method!(match, argtypes, flag, state) - handle_single_case!(ir, idx, stmt, item, todo, true) + handle_single_case!(ir, idx, stmt, item, todo, state.params, true) return nothing end @@ -1190,7 +1196,7 @@ function analyze_single_call!( fully_covered &= atype <: signature_union end - handle_cases!(ir, idx, stmt, atype, cases, fully_covered, todo) + handle_cases!(ir, idx, stmt, atype, cases, fully_covered, todo, state.params) end # similar to `analyze_single_call!`, but with constant results @@ -1241,7 +1247,7 @@ function handle_const_call!( fully_covered &= atype <: signature_union end - handle_cases!(ir, idx, stmt, atype, cases, fully_covered, todo) + handle_cases!(ir, idx, stmt, atype, cases, fully_covered, todo, state.params) end function handle_match!( @@ -1270,12 +1276,13 @@ function handle_const_result!( end function handle_cases!(ir::IRCode, idx::Int, stmt::Expr, @nospecialize(atype), - cases::Vector{InliningCase}, fully_covered::Bool, todo::Vector{Pair{Int, Any}}) + cases::Vector{InliningCase}, fully_covered::Bool, todo::Vector{Pair{Int, Any}}, + params::OptimizationParams) # If we only have one case and that case is fully covered, we may either # be able to do the inlining now (for constant cases), or push it directly # onto the todo list if fully_covered && length(cases) == 1 - handle_single_case!(ir, idx, stmt, cases[1].item, todo) + handle_single_case!(ir, idx, stmt, cases[1].item, todo, params) elseif length(cases) > 0 push!(todo, idx=>UnionSplit(fully_covered, atype, cases)) end @@ -1289,7 +1296,7 @@ function handle_const_opaque_closure_call!( isdispatchtuple(item.mi.specTypes) || return validate_sparams(item.mi.sparam_vals) || return state.mi_cache !== nothing && (item = resolve_todo(item, state, flag)) - handle_single_case!(ir, idx, stmt, item, todo) + handle_single_case!(ir, idx, stmt, item, todo, state.params) return nothing end @@ -1334,7 +1341,7 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState) sig, state, todo) else item = analyze_method!(info.match, sig.argtypes, flag, state) - handle_single_case!(ir, idx, stmt, item, todo) + handle_single_case!(ir, idx, stmt, item, todo, state.params) end continue end diff --git a/base/compiler/types.jl b/base/compiler/types.jl index c72896b61b0e53..e5894ab3d3f899 100644 --- a/base/compiler/types.jl +++ b/base/compiler/types.jl @@ -54,6 +54,8 @@ struct OptimizationParams inline_tupleret_bonus::Int # extra inlining willingness for non-concrete tuple return types (in hopes of splitting it up) inline_error_path_cost::Int # cost of (un-optimized) calls in blocks that throw + trust_inference::Bool + # Duplicating for now because optimizer inlining requires it. # Keno assures me this will be removed in the near future MAX_METHODS::Int @@ -69,6 +71,7 @@ struct OptimizationParams max_methods::Int = 3, tuple_splat::Int = 32, union_splitting::Int = 4, + trust_inference::Bool = false ) return new( inlining, @@ -76,9 +79,10 @@ struct OptimizationParams inline_nonleaf_penalty, inline_tupleret_bonus, inline_error_path_cost, + trust_inference, max_methods, tuple_splat, - union_splitting, + union_splitting ) end end