Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

inference: form PartialStruct for extra type information propagation #42831

Merged
merged 3 commits into from
Nov 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 29 additions & 28 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -345,9 +345,9 @@ function from_interconditional(@nospecialize(typ), (; fargs, argtypes)::ArgInfo,
else
elsetype = tmeet(elsetype, widenconst(new_elsetype))
end
if (slot > 0 || condval !== false) && !(old ⊑ vtype) # essentially vtype ⋤ old
if (slot > 0 || condval !== false) && vtype ⋤ old
slot = id
elseif (slot > 0 || condval !== true) && !(old ⊑ elsetype) # essentially elsetype ⋤ old
elseif (slot > 0 || condval !== true) && elsetype ⋤ old
slot = id
else # reset: no new useful information for this slot
vtype = elsetype = Any
Expand Down Expand Up @@ -1598,36 +1598,35 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
elseif ehead === :new
t = instanceof_tfunc(abstract_eval_value(interp, e.args[1], vtypes, sv))[1]
if isconcretetype(t) && !ismutabletype(t)
args = Vector{Any}(undef, length(e.args)-1)
ats = Vector{Any}(undef, length(e.args)-1)
anyconst = false
allconst = true
nargs = length(e.args) - 1
ats = Vector{Any}(undef, nargs)
local anyrefine = false
local allconst = true
for i = 2:length(e.args)
at = widenconditional(abstract_eval_value(interp, e.args[i], vtypes, sv))
if !anyconst
anyconst = has_nontrivial_const_info(at)
end
ats[i-1] = at
ft = fieldtype(t, i-1)
at = tmeet(at, ft)
if at === Bottom
t = Bottom
allconst = anyconst = false
break
elseif at isa Const
if !(at.val isa fieldtype(t, i - 1))
t = Bottom
allconst = anyconst = false
break
end
args[i-1] = at.val
else
@goto t_computed
elseif !isa(at, Const)
allconst = false
end
if !anyrefine
anyrefine = has_nontrivial_const_info(at) || # constant information
at ⋤ ft # just a type-level information, but more precise than the declared type
end
ats[i-1] = at
end
# For now, don't allow partially initialized Const/PartialStruct
if t !== Bottom && fieldcount(t) == length(ats)
if fieldcount(t) == nargs
if allconst
t = Const(ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), t, args, length(args)))
elseif anyconst
argvals = Vector{Any}(undef, nargs)
for j in 1:nargs
argvals[j] = (ats[j]::Const).val
end
t = Const(ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), t, argvals, nargs))
elseif anyrefine
t = PartialStruct(t, ats)
end
end
Expand All @@ -1638,7 +1637,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
at = abstract_eval_value(interp, e.args[2], vtypes, sv)
n = fieldcount(t)
if isa(at, Const) && isa(at.val, Tuple) && n == length(at.val::Tuple) &&
let t = t; _all(i->getfield(at.val::Tuple, i) isa fieldtype(t, i), 1:n); end
let t = t, at = at; _all(i->getfield(at.val::Tuple, i) isa fieldtype(t, i), 1:n); end
t = Const(ccall(:jl_new_structt, Any, (Any, Any), t, at.val))
elseif isa(at, PartialStruct) && at ⊑ Tuple && n == length(at.fields::Vector{Any}) &&
let t = t, at = at; _all(i->(at.fields::Vector{Any})[i] ⊑ fieldtype(t, i), 1:n); end
Expand Down Expand Up @@ -1718,6 +1717,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
else
t = abstract_eval_value_expr(interp, e, vtypes, sv)
end
@label t_computed
@assert !isa(t, TypeVar) "unhandled TypeVar"
if isa(t, DataType) && isdefined(t, :instance)
# replace singleton types with their equivalent Const object
Expand Down Expand Up @@ -1801,17 +1801,18 @@ function widenreturn(@nospecialize(rt), @nospecialize(bestguess), nslots::Int, s
isa(rt, Type) && return rt
if isa(rt, PartialStruct)
fields = copy(rt.fields)
haveconst = false
local anyrefine = false
for i in 1:length(fields)
a = fields[i]
a = isvarargtype(a) ? a : widenreturn(a, bestguess, nslots, slottypes, changes)
if !haveconst && has_const_info(a)
if !anyrefine
# TODO: consider adding && const_prop_profitable(a) here?
haveconst = true
anyrefine = has_const_info(a) ||
a ⊏ fieldtype(rt.typ, i)
end
fields[i] = a
end
haveconst && return PartialStruct(rt.typ, fields)
anyrefine && return PartialStruct(rt.typ, fields)
end
if isa(rt, PartialOpaque)
return rt # XXX: this case was missed in #39512
Expand Down
23 changes: 22 additions & 1 deletion base/compiler/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,12 @@ function maybe_extract_const_bool(c::AnyConditional)
end
maybe_extract_const_bool(@nospecialize c) = nothing

function ⊑(@nospecialize(a), @nospecialize(b))
"""
a ⊑ b -> Bool

The non-strict partial order over the type inference lattice.
"""
@nospecialize(a) ⊑ @nospecialize(b) = begin
if isa(b, LimitedAccuracy)
if !isa(a, LimitedAccuracy)
return false
Expand Down Expand Up @@ -232,6 +237,22 @@ function ⊑(@nospecialize(a), @nospecialize(b))
end
end

"""
a ⊏ b -> Bool

The strict partial order over the type inference lattice.
This is defined as the irreflexive kernel of `⊑`.
"""
@nospecialize(a) ⊏ @nospecialize(b) = a ⊑ b && !⊑(b, a)

"""
a ⋤ b -> Bool

This order could be used as a slightly more efficient version of the strict order `⊏`,
where we can safely assume `a ⊑ b` holds.
"""
@nospecialize(a) ⋤ @nospecialize(b) = !⊑(b, a)

# Check if two lattice elements are partial order equivalent. This is basically
# `a ⊑ b && b ⊑ a` but with extra performance optimizations.
function is_lattice_equal(@nospecialize(a), @nospecialize(b))
Expand Down
23 changes: 23 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3669,3 +3669,26 @@ end

# issue #42646
@test only(Base.return_types(getindex, (Array{undef}, Int))) >: Union{} # check that it does not throw

# form PartialStruct for extra type information propagation
struct FieldTypeRefinement{S,T}
s::S
t::T
end
@test Base.return_types((Int,)) do s
o = FieldTypeRefinement{Any,Int}(s, s)
o.s
end |> only == Int
@test Base.return_types((Int,)) do s
o = FieldTypeRefinement{Int,Any}(s, s)
o.t
end |> only == Int
@test Base.return_types((Int,)) do s
o = FieldTypeRefinement{Any,Any}(s, s)
o.s, o.t
end |> only == Tuple{Int,Int}
@test Base.return_types((Int,)) do a
s1 = Some{Any}(a)
s2 = Some{Any}(s1)
s2.value.value
end |> only == Int
23 changes: 7 additions & 16 deletions test/compiler/irpasses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -426,31 +426,22 @@ let # `getfield_elim_pass!` should work with constant globals
end
end

let # `typeassert_elim_pass!`
let
# `typeassert` elimination after SROA
# NOTE we can remove this optimization once inference is able to reason about memory-effects
src = @eval Module() begin
struct Foo; x; end
mutable struct Foo; x; end

code_typed((Int,)) do a
x1 = Foo(a)
x2 = Foo(x1)
x3 = Foo(x2)

r1 = (x2.x::Foo).x
r2 = (x2.x::Foo).x::Int
r3 = (x2.x::Foo).x::Integer
r4 = ((x3.x::Foo).x::Foo).x

return r1, r2, r3, r4
return typeassert(x2.x, Foo).x
end |> only |> first
end
# eliminate `typeassert(f2.a, Foo)`
@test all(src.code) do @nospecialize(stmt)
# eliminate `typeassert(x2.x, Foo)`
@test all(src.code) do @nospecialize stmt
Meta.isexpr(stmt, :call) || return true
ft = Core.Compiler.argextype(stmt.args[1], src, Any[], src.slottypes)
return Core.Compiler.widenconst(ft) !== typeof(typeassert)
end
# succeeding simple DCE will eliminate `Foo(a)`
@test all(src.code) do @nospecialize(stmt)
return !Meta.isexpr(stmt, :new)
end
end