Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

inlining: bail out unless match.spec_types <: match.method.sig #53720

Merged
merged 1 commit into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 24 additions & 20 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -542,11 +542,11 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, argexprs::
@assert nparams == fieldcount(mtype)
if !(i == ncases && fully_covered)
for i = 1:nparams
a, m = fieldtype(atype, i), fieldtype(mtype, i)
aft, mft = fieldtype(atype, i), fieldtype(mtype, i)
# If this is always true, we don't need to check for it
a <: m && continue
aft <: mft && continue
# Generate isa check
isa_expr = Expr(:call, isa, argexprs[i], m)
isa_expr = Expr(:call, isa, argexprs[i], mft)
ssa = insert_node_here!(compact, NewInstruction(isa_expr, Bool, line))
if cond === true
cond = ssa
Expand All @@ -565,10 +565,10 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, argexprs::
for i = 1:nparams
argex = argexprs[i]
(isa(argex, SSAValue) || isa(argex, Argument)) || continue
a, m = fieldtype(atype, i), fieldtype(mtype, i)
if !(a <: m)
aft, mft = fieldtype(atype, i), fieldtype(mtype, i)
if !(aft <: mft)
argexprs′[i] = insert_node_here!(compact,
NewInstruction(PiNode(argex, m), m, line))
NewInstruction(PiNode(argex, mft), mft, line))
end
end
end
Expand Down Expand Up @@ -944,7 +944,8 @@ function analyze_method!(match::MethodMatch, argtypes::Vector{Any},
if !match.fully_covers
# type-intersection was not able to give us a simple list of types, so
# ir_inline_unionsplit won't be able to deal with inlining this
if !(spec_types isa DataType && length(spec_types.parameters) == length(argtypes) && !isvarargtype(spec_types.parameters[end]))
if !(spec_types isa DataType && length(spec_types.parameters) == npassedargs &&
!isvarargtype(spec_types.parameters[end]))
return nothing
end
end
Expand Down Expand Up @@ -1355,16 +1356,18 @@ function compute_inlining_cases(@nospecialize(info::CallInfo), flag::UInt32, sig
joint_effects = merge_effects(joint_effects, info_effects(result, match, state))
split_fully_covered |= match.fully_covers
if !validate_sparams(match.sparams)
if !match.fully_covers
handled_all_cases = false
continue
end
if revisit_idx === nothing
revisit_idx = (i, j, all_result_count)
if match.fully_covers
if revisit_idx === nothing
revisit_idx = (i, j, all_result_count)
else
handled_all_cases = false
revisit_idx = nothing
end
else
handled_all_cases = false
revisit_idx = nothing
end
elseif !(match.spec_types <: match.method.sig) # the requirement for correct union-split
handled_all_cases = false
else
handled_all_cases &= handle_any_const_result!(cases,
result, match, argtypes, info, flag, state; allow_abstract=true, allow_typevars=false)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
result, match, argtypes, info, flag, state; allow_abstract=true, allow_typevars=false)
result, match, argtypes, info, flag, state; allow_abstract=handled_all_cases, allow_typevars=false)

This can only permit later abstract calls if we proved that all previous calls were disjoint from this. One conservative way to express that requirement is to only allow isdispatchtuple here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this filtering enough for this?

elseif !isempty(cases)
# if we've not seen all candidates, union split is valid only for dispatch tuples
filter!(case::InliningCase->isdispatchtuple(case.sig), cases)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additionally, if nunion != 1, then the list of matches cannot be assumed to be sorted, so we can only allow the abstract calls that we can prove are disjoint from the other union elements, since we do not attempt to merge-sort them to combine them (the merge sort might fail) but we simply concat the lists.

Suggested change
result, match, argtypes, info, flag, state; allow_abstract=true, allow_typevars=false)
result, match, argtypes, info, flag, state; allow_abstract=handled_all_cases && nunion == 1, allow_typevars=false)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, okay, I wasn't expecting we would over-estimate this then filter it later. That late filter is actually a bit more strict than necessary then, since it doesn't consider the ordering. But that does at least seem valid.

Expand Down Expand Up @@ -1399,14 +1402,15 @@ function handle_call!(todo::Vector{Pair{Int,Any}},
cases = compute_inlining_cases(info, flag, sig, state)
cases === nothing && return nothing
cases, all_covered, joint_effects = cases
handle_cases!(todo, ir, idx, stmt, argtypes_to_type(sig.argtypes), cases,
all_covered, joint_effects)
atype = argtypes_to_type(sig.argtypes)
handle_cases!(todo, ir, idx, stmt, atype, cases, all_covered, joint_effects)
end

function handle_match!(cases::Vector{InliningCase},
match::MethodMatch, argtypes::Vector{Any}, @nospecialize(info::CallInfo), flag::UInt32,
state::InliningState;
allow_abstract::Bool, allow_typevars::Bool, volatile_inf_result::Union{Nothing,VolatileInferenceResult})
allow_abstract::Bool, allow_typevars::Bool,
volatile_inf_result::Union{Nothing,VolatileInferenceResult})
spec_types = match.spec_types
allow_abstract || isdispatchtuple(spec_types) || return false
# We may see duplicated dispatch signatures here when a signature gets widened
Expand Down Expand Up @@ -1493,19 +1497,19 @@ function concrete_result_item(result::ConcreteResult, @nospecialize(info::CallIn
end

function handle_cases!(todo::Vector{Pair{Int,Any}}, ir::IRCode, idx::Int, stmt::Expr,
@nospecialize(atype), cases::Vector{InliningCase}, fully_covered::Bool,
@nospecialize(atype), cases::Vector{InliningCase}, all_covered::Bool,
joint_effects::Effects)
# If we only have one case and that case is fully covered, we may either
# be able to do the inlining now (for constant cases), or push it directly
# onto the todo list
if fully_covered && length(cases) == 1
if all_covered && length(cases) == 1
handle_single_case!(todo, ir, idx, stmt, cases[1].item)
elseif length(cases) > 0
isa(atype, DataType) || return nothing
for case in cases
isa(case.sig, DataType) || return nothing
end
push!(todo, idx=>UnionSplit(fully_covered, atype, cases))
push!(todo, idx=>UnionSplit(all_covered, atype, cases))
else
add_flag!(ir[SSAValue(idx)], flags_for_effects(joint_effects))
end
Expand Down
21 changes: 21 additions & 0 deletions test/compiler/inline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2159,3 +2159,24 @@ end
@test !Core.Compiler.is_nothrow(Base.infer_effects(issue53062, (Bool,)))
@test issue53062(false) == -1
@test_throws MethodError issue53062(true)

struct Issue52644
tuple::Type{<:Tuple}
end
issue52644(::DataType) = :DataType
issue52644(::UnionAll) = :UnionAll
let ir = Base.code_ircode((Issue52644,); optimize_until="Inlining") do t
issue52644(t.tuple)
end |> only |> first
irfunc = Core.OpaqueClosure(ir)
@test irfunc(Issue52644(Tuple{})) === :DataType
@test irfunc(Issue52644(Tuple{<:Integer})) === :UnionAll
end
issue52644_single(x::DataType) = :DataType
let ir = Base.code_ircode((Issue52644,); optimize_until="Inlining") do t
issue52644_single(t.tuple)
end |> only |> first
irfunc = Core.OpaqueClosure(ir)
@test irfunc(Issue52644(Tuple{})) === :DataType
@test_throws MethodError irfunc(Issue52644(Tuple{<:Integer}))
end