Skip to content

Commit

Permalink
Allow external lattice elements to properly union split (#49030)
Browse files Browse the repository at this point in the history
Currently `MustAlias` is the only lattice element that is allowed
to widen to union types. However, there are others in external
packages. Expand the support we have for this in order to allow
union splitting of lattice elements.

Co-authored-by: Shuhei Kadowaki <[email protected]>
  • Loading branch information
Keno and aviatesk authored Mar 18, 2023
1 parent 4486bc4 commit 0a9abc1
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 27 deletions.
13 changes: 7 additions & 6 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
# as we may want to concrete-evaluate this frame in cases when there are
# no overlayed calls, try an additional effort now to check if this call
# isn't overlayed rather than just handling it conservatively
matches = find_matching_methods(arginfo.argtypes, atype, method_table(interp),
matches = find_matching_methods(typeinf_lattice(interp), arginfo.argtypes, atype, method_table(interp),
InferenceParams(interp).max_union_splitting, max_methods)
if !isa(matches, FailedMethodMatch)
nonoverlayed = matches.nonoverlayed
Expand All @@ -75,7 +75,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
end

argtypes = arginfo.argtypes
matches = find_matching_methods(argtypes, atype, method_table(interp),
matches = find_matching_methods(typeinf_lattice(interp), argtypes, atype, method_table(interp),
InferenceParams(interp).max_union_splitting, max_methods)
if isa(matches, FailedMethodMatch)
add_remark!(interp, sv, matches.reason)
Expand Down Expand Up @@ -273,11 +273,12 @@ struct UnionSplitMethodMatches
end
any_ambig(m::UnionSplitMethodMatches) = any(any_ambig, m.info.matches)

function find_matching_methods(argtypes::Vector{Any}, @nospecialize(atype), method_table::MethodTableView,
function find_matching_methods(𝕃::AbstractLattice,
argtypes::Vector{Any}, @nospecialize(atype), method_table::MethodTableView,
max_union_splitting::Int, max_methods::Int)
# NOTE this is valid as far as any "constant" lattice element doesn't represent `Union` type
if 1 < unionsplitcost(argtypes) <= max_union_splitting
split_argtypes = switchtupleunion(argtypes)
if 1 < unionsplitcost(𝕃, argtypes) <= max_union_splitting
split_argtypes = switchtupleunion(𝕃, argtypes)
infos = MethodMatchInfo[]
applicable = Any[]
applicable_argtypes = Vector{Any}[] # arrays like `argtypes`, including constants, for each match
Expand Down Expand Up @@ -1496,7 +1497,7 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, si::
end
res = Union{}
nargs = length(aargtypes)
splitunions = 1 < unionsplitcost(aargtypes) <= InferenceParams(interp).max_apply_union_enum
splitunions = 1 < unionsplitcost(typeinf_lattice(interp), aargtypes) <= InferenceParams(interp).max_apply_union_enum
ctypes = [Any[aft]]
infos = Vector{MaybeAbstractIterationInfo}[MaybeAbstractIterationInfo[]]
effects = EFFECTS_TOTAL
Expand Down
4 changes: 4 additions & 0 deletions base/compiler/abstractlattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,10 @@ has_mustalias(𝕃::AbstractLattice) = has_mustalias(widenlattice(𝕃))
has_mustalias(::AnyMustAliasesLattice) = true
has_mustalias(::JLTypeLattice) = false

has_extended_unionsplit(𝕃::AbstractLattice) = has_extended_unionsplit(widenlattice(𝕃))
has_extended_unionsplit(::AnyMustAliasesLattice) = true
has_extended_unionsplit(::JLTypeLattice) = false

# Curried versions
(lattice::AbstractLattice) = (@nospecialize(a), @nospecialize(b)) -> (lattice, a, b)
(lattice::AbstractLattice) = (@nospecialize(a), @nospecialize(b)) -> (lattice, a, b)
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2542,7 +2542,7 @@ function abstract_applicable(interp::AbstractInterpreter, argtypes::Vector{Any},
isvarargtype(argtypes[2]) && return CallMeta(Bool, EFFECTS_UNKNOWN, NoCallInfo())
argtypes = argtypes[2:end]
atype = argtypes_to_type(argtypes)
matches = find_matching_methods(argtypes, atype, method_table(interp),
matches = find_matching_methods(typeinf_lattice(interp), argtypes, atype, method_table(interp),
InferenceParams(interp).max_union_splitting, max_methods)
if isa(matches, FailedMethodMatch)
rt = Bool # too many matches to analyze
Expand Down
2 changes: 2 additions & 0 deletions base/compiler/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ end
MustAlias(var::SlotNumber, @nospecialize(vartyp), fldidx::Int, @nospecialize(fldtyp)) =
MustAlias(slot_id(var), vartyp, fldidx, fldtyp)

_uniontypes(x::MustAlias, ts) = _uniontypes(widenconst(x), ts)

"""
alias::InterMustAlias
Expand Down
28 changes: 15 additions & 13 deletions base/compiler/typeutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ function typesubtract(@nospecialize(a), @nospecialize(b), max_union_splitting::I
if ub isa DataType
if a.name === ub.name === Tuple.name &&
length(a.parameters) == length(ub.parameters)
if 1 < unionsplitcost(a.parameters) <= max_union_splitting
if 1 < unionsplitcost(JLTypeLattice(), a.parameters) <= max_union_splitting
ta = switchtupleunion(a)
return typesubtract(Union{ta...}, b, 0)
elseif b isa DataType
Expand Down Expand Up @@ -227,12 +227,11 @@ end
# or outside of the Tuple/Union nesting, though somewhat more expensive to be
# outside than inside because the representation is larger (because and it
# informs the callee whether any splitting is possible).
function unionsplitcost(argtypes::Union{SimpleVector,Vector{Any}})
function unionsplitcost(𝕃::AbstractLattice, argtypes::Union{SimpleVector,Vector{Any}})
nu = 1
max = 2
for ti in argtypes
# TODO remove this to implement callsite refinement of MustAlias
if isa(ti, MustAlias) && isa(widenconst(ti), Union)
if has_extended_unionsplit(𝕃) && !isvarargtype(ti)
ti = widenconst(ti)
end
if isa(ti, Union)
Expand All @@ -252,12 +251,12 @@ end
# and `Union{return...} == ty`
function switchtupleunion(@nospecialize(ty))
tparams = (unwrap_unionall(ty)::DataType).parameters
return _switchtupleunion(Any[tparams...], length(tparams), [], ty)
return _switchtupleunion(JLTypeLattice(), Any[tparams...], length(tparams), [], ty)
end

switchtupleunion(argtypes::Vector{Any}) = _switchtupleunion(argtypes, length(argtypes), [], nothing)
switchtupleunion(𝕃::AbstractLattice, argtypes::Vector{Any}) = _switchtupleunion(𝕃, argtypes, length(argtypes), [], nothing)

function _switchtupleunion(t::Vector{Any}, i::Int, tunion::Vector{Any}, @nospecialize(origt))
function _switchtupleunion(𝕃::AbstractLattice, t::Vector{Any}, i::Int, tunion::Vector{Any}, @nospecialize(origt))
if i == 0
if origt === nothing
push!(tunion, copy(t))
Expand All @@ -268,17 +267,20 @@ function _switchtupleunion(t::Vector{Any}, i::Int, tunion::Vector{Any}, @nospeci
else
origti = ti = t[i]
# TODO remove this to implement callsite refinement of MustAlias
if isa(ti, MustAlias) && isa(widenconst(ti), Union)
ti = widenconst(ti)
end
if isa(ti, Union)
for ty in uniontypes(ti::Union)
for ty in uniontypes(ti)
t[i] = ty
_switchtupleunion(𝕃, t, i - 1, tunion, origt)
end
t[i] = origti
elseif has_extended_unionsplit(𝕃) && !isa(ti, Const) && !isvarargtype(ti) && isa(widenconst(ti), Union)
for ty in uniontypes(ti)
t[i] = ty
_switchtupleunion(t, i - 1, tunion, origt)
_switchtupleunion(𝕃, t, i - 1, tunion, origt)
end
t[i] = origti
else
_switchtupleunion(t, i - 1, tunion, origt)
_switchtupleunion(𝕃, t, i - 1, tunion, origt)
end
end
return tunion
Expand Down
14 changes: 7 additions & 7 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2944,11 +2944,11 @@ end
# issue #28356
# unit test to make sure countunionsplit overflows gracefully
# we don't care what number is returned as long as it's large
@test Core.Compiler.unionsplitcost(Any[Union{Int32, Int64} for i=1:80]) > 100000
@test Core.Compiler.unionsplitcost(Any[Union{Int8, Int16, Int32, Int64}]) == 2
@test Core.Compiler.unionsplitcost(Any[Union{Int8, Int16, Int32, Int64}, Union{Int8, Int16, Int32, Int64}, Int8]) == 8
@test Core.Compiler.unionsplitcost(Any[Union{Int8, Int16, Int32, Int64}, Union{Int8, Int16, Int32}, Int8]) == 6
@test Core.Compiler.unionsplitcost(Any[Union{Int8, Int16, Int32}, Union{Int8, Int16, Int32, Int64}, Int8]) == 6
@test Core.Compiler.unionsplitcost(Core.Compiler.JLTypeLattice(), Any[Union{Int32, Int64} for i=1:80]) > 100000
@test Core.Compiler.unionsplitcost(Core.Compiler.JLTypeLattice(), Any[Union{Int8, Int16, Int32, Int64}]) == 2
@test Core.Compiler.unionsplitcost(Core.Compiler.JLTypeLattice(), Any[Union{Int8, Int16, Int32, Int64}, Union{Int8, Int16, Int32, Int64}, Int8]) == 8
@test Core.Compiler.unionsplitcost(Core.Compiler.JLTypeLattice(), Any[Union{Int8, Int16, Int32, Int64}, Union{Int8, Int16, Int32}, Int8]) == 6
@test Core.Compiler.unionsplitcost(Core.Compiler.JLTypeLattice(), Any[Union{Int8, Int16, Int32}, Union{Int8, Int16, Int32, Int64}, Int8]) == 6

# make sure compiler doesn't hang in union splitting

Expand Down Expand Up @@ -3949,13 +3949,13 @@ end

# argtypes
let
tunion = Core.Compiler.switchtupleunion(Any[Union{Int32,Int64}, Core.Const(nothing)])
tunion = Core.Compiler.switchtupleunion(Core.Compiler.ConstsLattice(), Any[Union{Int32,Int64}, Core.Const(nothing)])
@test length(tunion) == 2
@test Any[Int32, Core.Const(nothing)] in tunion
@test Any[Int64, Core.Const(nothing)] in tunion
end
let
tunion = Core.Compiler.switchtupleunion(Any[Union{Int32,Int64}, Union{Float32,Float64}, Core.Const(nothing)])
tunion = Core.Compiler.switchtupleunion(Core.Compiler.ConstsLattice(), Any[Union{Int32,Int64}, Union{Float32,Float64}, Core.Const(nothing)])
@test length(tunion) == 4
@test Any[Int32, Float32, Core.Const(nothing)] in tunion
@test Any[Int32, Float64, Core.Const(nothing)] in tunion
Expand Down

4 comments on commit 0a9abc1

@aviatesk
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nanosoldier runbenchmarks("inference", vs="@4486bc40b42b350260bd4016297dd3adf2186651")

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your benchmark job has completed - possible performance regressions were detected. A full report can be found here.

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Executing the daily package evaluation, I will reply here when finished:

@nanosoldier runtests(isdaily = true)

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your package evaluation job has completed - possible new issues were detected.
A full report can be found here.

Please sign in to comment.