Skip to content

Commit

Permalink
Allow inlining methods with unmatched type parameters (#45062)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ian Atol authored Aug 25, 2022
1 parent 36aab14 commit 19f44b6
Show file tree
Hide file tree
Showing 17 changed files with 355 additions and 82 deletions.
143 changes: 117 additions & 26 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -376,11 +376,28 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
boundscheck = :off
end
end
if !validate_sparams(sparam_vals)
if def.isva
nonva_args = argexprs[1:end-1]
va_arg = argexprs[end]
tuple_call = Expr(:call, TOP_TUPLE, def, nonva_args...)
tuple_type = tuple_tfunc(Any[argextype(arg, compact) for arg in nonva_args])
tupl = insert_node_here!(compact, NewInstruction(tuple_call, tuple_type, topline))
apply_iter_expr = Expr(:call, Core._apply_iterate, iterate, Core._compute_sparams, tupl, va_arg)
sparam_vals = insert_node_here!(compact,
effect_free(NewInstruction(apply_iter_expr, SimpleVector, topline)))
else
sparam_vals = insert_node_here!(compact,
effect_free(NewInstruction(Expr(:call, Core._compute_sparams, def, argexprs...), SimpleVector, topline)))
end
end
# If the iterator already moved on to the next basic block,
# temporarily re-open in again.
local return_value
sig = def.sig
# Special case inlining that maintains the current basic block if there's only one BB in the target
new_new_offset = length(compact.new_new_nodes)
late_fixup_offset = length(compact.late_fixup)
if spec.linear_inline_eligible
#compact[idx] = nothing
inline_compact = IncrementalCompact(compact, spec.ir, compact.result_idx)
Expand All @@ -389,7 +406,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
# face of rename_arguments! mutating in place - should figure out
# something better eventually.
inline_compact[idx′] = nothing
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, compact)
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, inline_compact)
if isa(stmt′, ReturnNode)
val = stmt′.val
return_value = SSAValue(idx′)
Expand All @@ -402,7 +419,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
end
inline_compact[idx′] = stmt′
end
just_fixup!(inline_compact)
just_fixup!(inline_compact, new_new_offset, late_fixup_offset)
compact.result_idx = inline_compact.result_idx
else
bb_offset, post_bb_id = popfirst!(todo_bbs)
Expand All @@ -416,7 +433,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
inline_compact = IncrementalCompact(compact, spec.ir, compact.result_idx)
for ((_, idx′), stmt′) in inline_compact
inline_compact[idx′] = nothing
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, compact)
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, inline_compact)
if isa(stmt′, ReturnNode)
if isdefined(stmt′, :val)
val = stmt′.val
Expand All @@ -436,7 +453,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
end
inline_compact[idx′] = stmt′
end
just_fixup!(inline_compact)
just_fixup!(inline_compact, new_new_offset, late_fixup_offset)
compact.result_idx = inline_compact.result_idx
compact.active_result_bb = inline_compact.active_result_bb
if length(pn.edges) == 1
Expand All @@ -460,7 +477,8 @@ function fix_va_argexprs!(compact::IncrementalCompact,
push!(tuple_typs, argextype(arg, compact))
end
tuple_typ = tuple_tfunc(tuple_typs)
push!(newargexprs, insert_node_here!(compact, NewInstruction(tuple_call, tuple_typ, line_idx)))
tuple_inst = NewInstruction(tuple_call, tuple_typ, line_idx)
push!(newargexprs, insert_node_here!(compact, tuple_inst))
return newargexprs
end

Expand Down Expand Up @@ -875,8 +893,26 @@ function validate_sparams(sparams::SimpleVector)
return true
end

function may_have_fcalls(m::Method)
may_have_fcall = true
if isdefined(m, :source)
src = m.source
isa(src, Vector{UInt8}) && (src = uncompressed_ir(m))
if isa(src, CodeInfo)
may_have_fcall = src.has_fcall
end
end
return may_have_fcall
end

function can_inline_typevars(m::MethodMatch, argtypes::Vector{Any})
may_have_fcalls(m.method) && return false
any(@nospecialize(x) -> x isa UnionAll, argtypes[2:end]) && return false
return true
end

function analyze_method!(match::MethodMatch, argtypes::Vector{Any}, invokesig,
flag::UInt8, state::InliningState)
flag::UInt8, state::InliningState, allow_typevars::Bool = false)
method = match.method
spec_types = match.spec_types

Expand All @@ -898,8 +934,9 @@ function analyze_method!(match::MethodMatch, argtypes::Vector{Any}, invokesig,
end
end

# Bail out if any static parameters are left as TypeVar
validate_sparams(match.sparams) || return nothing
if !validate_sparams(match.sparams)
(allow_typevars && can_inline_typevars(match, argtypes)) || return nothing
end

et = state.et

Expand Down Expand Up @@ -1231,6 +1268,9 @@ function compute_inlining_cases(infos::Vector{MethodMatchInfo},
flag::UInt8, sig::Signature, state::InliningState)
argtypes = sig.argtypes
cases = InliningCase[]
local only_method = nothing
local meth::MethodLookupResult
local revisit_idx = nothing
local any_fully_covered = false
local handled_all_cases = true
for i in 1:length(infos)
Expand All @@ -1243,14 +1283,58 @@ function compute_inlining_cases(infos::Vector{MethodMatchInfo},
# No applicable methods; try next union split
handled_all_cases = false
continue
else
if length(meth) == 1 && only_method !== false
if only_method === nothing
only_method = meth[1].method
elseif only_method !== meth[1].method
only_method = false
end
else
only_method = false
end
end
for match in meth
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, #=allow_abstract=#true)
for (j, match) in enumerate(meth)
any_fully_covered |= match.fully_covers
if !validate_sparams(match.sparams)
if !match.fully_covers
handled_all_cases = false
continue
end
if revisit_idx === nothing
revisit_idx = (i, j)
else
handled_all_cases = false
revisit_idx = nothing
end
else
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, #=allow_abstract=#true, #=allow_typevars=#false)
end
end
end

if !handled_all_cases
if handled_all_cases && revisit_idx !== nothing
# we handled everything except one match with unmatched sparams,
# so try to handle it by bypassing validate_sparams
(i, j) = revisit_idx
match = infos[i].results[j]
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, #=allow_abstract=#true, #=allow_typevars=#true)
elseif length(cases) == 0 && only_method isa Method
# if the signature is fully covered and there is only one applicable method,
# we can try to inline it even in the prescence of unmatched sparams
# -- But don't try it if we already tried to handle the match in the revisit_idx
# case, because that'll (necessarily) be the same method.
if length(infos) > 1
atype = argtypes_to_type(argtypes)
(metharg, methsp) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), atype, only_method.sig)::SimpleVector
match = MethodMatch(metharg, methsp::SimpleVector, only_method, true)
else
@assert length(meth) == 1
match = meth[1]
end
handle_match!(match, argtypes, flag, state, cases, #=allow_abstract=#true, #=allow_typevars=#true) || return nothing
any_fully_covered = handled_all_cases = match.fully_covers
elseif !handled_all_cases
# if we've not seen all candidates, union split is valid only for dispatch tuples
filter!(case::InliningCase->isdispatchtuple(case.sig), cases)
end
Expand Down Expand Up @@ -1286,10 +1370,10 @@ function compute_inlining_cases(info::ConstCallInfo,
case = concrete_result_item(result, state)
push!(cases, InliningCase(result.mi.specTypes, case))
elseif isa(result, ConstPropResult)
handled_all_cases &= handle_const_prop_result!(result, argtypes, flag, state, cases, true)
handled_all_cases &= handle_const_prop_result!(result, argtypes, flag, state, cases, #=allow_abstract=#true)
else
@assert result === nothing
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, #=allow_abstract=#true)
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, #=allow_abstract=#true, #=allow_typevars=#false)
end
end
end
Expand Down Expand Up @@ -1324,22 +1408,22 @@ end

function handle_match!(
match::MethodMatch, argtypes::Vector{Any}, flag::UInt8, state::InliningState,
cases::Vector{InliningCase}, allow_abstract::Bool = false)
cases::Vector{InliningCase}, allow_abstract::Bool, allow_typevars::Bool)
spec_types = match.spec_types
allow_abstract || isdispatchtuple(spec_types) || return false
# we may see duplicated dispatch signatures here when a signature gets widened
# We may see duplicated dispatch signatures here when a signature gets widened
# during abstract interpretation: for the purpose of inlining, we can just skip
# processing this dispatch candidate
_any(case->case.sig === spec_types, cases) && return true
item = analyze_method!(match, argtypes, nothing, flag, state)
# processing this dispatch candidate (unless unmatched type parameters are present)
!allow_typevars && _any(case->case.sig === spec_types, cases) && return true
item = analyze_method!(match, argtypes, nothing, flag, state, allow_typevars)
item === nothing && return false
push!(cases, InliningCase(spec_types, item))
return true
end

function handle_const_prop_result!(
result::ConstPropResult, argtypes::Vector{Any}, flag::UInt8, state::InliningState,
cases::Vector{InliningCase}, allow_abstract::Bool = false)
cases::Vector{InliningCase}, allow_abstract::Bool)
(; mi) = item = InliningTodo(result.result, argtypes)
spec_types = mi.specTypes
allow_abstract || isdispatchtuple(spec_types) || return false
Expand Down Expand Up @@ -1624,30 +1708,37 @@ function late_inline_special_case!(
end

function ssa_substitute!(idx::Int, @nospecialize(val), arg_replacements::Vector{Any},
@nospecialize(spsig), spvals::SimpleVector,
@nospecialize(spsig), spvals::Union{SimpleVector, SSAValue},
linetable_offset::Int32, boundscheck::Symbol, compact::IncrementalCompact)
compact.result[idx][:flag] &= ~IR_FLAG_INBOUNDS
compact.result[idx][:line] += linetable_offset
return ssa_substitute_op!(val, arg_replacements, spsig, spvals, boundscheck)
return ssa_substitute_op!(val, arg_replacements, spsig, spvals, boundscheck, compact, idx)
end

function ssa_substitute_op!(@nospecialize(val), arg_replacements::Vector{Any},
@nospecialize(spsig), spvals::SimpleVector, boundscheck::Symbol)
@nospecialize(spsig), spvals::Union{SimpleVector, SSAValue},
boundscheck::Symbol, compact::IncrementalCompact, idx::Int)
if isa(val, Argument)
return arg_replacements[val.n]
end
if isa(val, Expr)
e = val::Expr
head = e.head
if head === :static_parameter
return quoted(spvals[e.args[1]::Int])
elseif head === :cfunction
if isa(spvals, SimpleVector)
return quoted(spvals[e.args[1]::Int])
else
ret = insert_node!(compact, SSAValue(idx),
effect_free(NewInstruction(Expr(:call, Core._svec_ref, false, spvals, e.args[1]), Any)))
return ret
end
elseif head === :cfunction && isa(spvals, SimpleVector)
@assert !isa(spsig, UnionAll) || !isempty(spvals)
e.args[3] = ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), e.args[3], spsig, spvals)
e.args[4] = svec(Any[
ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), argt, spsig, spvals)
for argt in e.args[4]::SimpleVector ]...)
elseif head === :foreigncall
elseif head === :foreigncall && isa(spvals, SimpleVector)
@assert !isa(spsig, UnionAll) || !isempty(spvals)
for i = 1:length(e.args)
if i == 2
Expand All @@ -1671,7 +1762,7 @@ function ssa_substitute_op!(@nospecialize(val), arg_replacements::Vector{Any},
isa(val, Union{SSAValue, NewSSAValue}) && return val # avoid infinite loop
urs = userefs(val)
for op in urs
op[] = ssa_substitute_op!(op[], arg_replacements, spsig, spvals, boundscheck)
op[] = ssa_substitute_op!(op[], arg_replacements, spsig, spvals, boundscheck, compact, idx)
end
return urs[]
end
Loading

0 comments on commit 19f44b6

Please sign in to comment.