From ab23760c9f17261fd0c55cd0c889a7b124dfbe07 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Thu, 28 Oct 2021 16:26:34 +0900 Subject: [PATCH] inference: form `PartialStruct` for extra type information propagation This commit forms `PartialStruct` whenever there is any type-level refinement available about a field, even if it's not "constant" information. In Julia "definitions" are allowed to be abstract whereas "usages" (i.e. callsites) are often concrete. The basic idea is to allow inference to make more use of such precise callsite type information by encoding it as `PartialStruct`. This may increase optimization possibilities of "unidiomatic" Julia code, which may contain poorly-typed definitions, like this very contrived example: ```julia struct Problem n; s; c; t end function main(args...) prob = Problem(args...) s = 0 for i in 1:prob.n m = mod(i, 3) s += m == 0 ? sin(prob.s) : m == 1 ? cos(prob.c) : tan(prob.t) end return prob, s end main(10000, 1, 2, 3) ``` One obvious limitation is that this extra type information can be propagated inter-procedurally only via const-propagation. I'm not sure this kind of "just a type-level" refinement can often make constant-prop' successful (i.e. shape-up a method body and allow it to be inlined, encoding the extra type information into the generated code), so I did not modify any part of const-prop' heuristics. So the improvements from this change are mostly for local analysis, and for very simple inter-procedural calls. 
--- base/compiler/abstractinterpretation.jl | 28 +++++++++++++------------ base/compiler/typelattice.jl | 1 + test/compiler/inference.jl | 23 ++++++++++++++++++++ test/compiler/irpasses.jl | 23 +++++++------------- 4 files changed, 46 insertions(+), 29 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index b338a80cbbe74a..0f14e3cc76664b 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -192,9 +192,9 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), else elsetype = tmeet(elsetype, widenconst(new_elsetype)) end - if (slot > 0 || condval !== false) && !(old ⊑ vtype) # essentially vtype ⋤ old + if (slot > 0 || condval !== false) && vtype ⋤ old slot = id - elseif (slot > 0 || condval !== true) && !(old ⊑ elsetype) # essentially elsetype ⋤ old + elseif (slot > 0 || condval !== true) && elsetype ⋤ old slot = id else # reset: no new useful information for this slot vtype = elsetype = Any @@ -1542,22 +1542,23 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e), if isconcretetype(t) && !ismutabletype(t) args = Vector{Any}(undef, length(e.args)-1) ats = Vector{Any}(undef, length(e.args)-1) - anyconst = false - allconst = true + local anyrefine = false + local allconst = true for i = 2:length(e.args) at = widenconditional(abstract_eval_value(interp, e.args[i], vtypes, sv)) - if !anyconst - anyconst = has_nontrivial_const_info(at) + if !anyrefine + anyrefine = has_nontrivial_const_info(at) || # constant information + at ⋤ fieldtype(t, i - 1) # just a type-level information, but more precise than the declared type end ats[i-1] = at if at === Bottom t = Bottom - allconst = anyconst = false + anyrefine = allconst = false break elseif at isa Const if !(at.val isa fieldtype(t, i - 1)) t = Bottom - allconst = anyconst = false + anyrefine = allconst = false break end args[i-1] = at.val @@ -1569,7 +1570,7 @@ 
function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e), if t !== Bottom && fieldcount(t) == length(ats) if allconst t = Const(ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), t, args, length(args))) - elseif anyconst + elseif anyrefine t = PartialStruct(t, ats) end end @@ -1741,17 +1742,18 @@ function widenreturn(@nospecialize(rt), @nospecialize(bestguess), nslots::Int, s isa(rt, Type) && return rt if isa(rt, PartialStruct) fields = copy(rt.fields) - haveconst = false + local anyrefine = false for i in 1:length(fields) a = fields[i] a = isvarargtype(a) ? a : widenreturn(a, bestguess, nslots, slottypes, changes) - if !haveconst && has_const_info(a) + if !anyrefine # TODO: consider adding && const_prop_profitable(a) here? - haveconst = true + anyrefine = has_const_info(a) || + a ⋤ fieldtype(rt.typ, i) end fields[i] = a end - haveconst && return PartialStruct(rt.typ, fields) + anyrefine && return PartialStruct(rt.typ, fields) end if isa(rt, PartialOpaque) return rt # XXX: this case was missed in #39512 diff --git a/base/compiler/typelattice.jl b/base/compiler/typelattice.jl index 7b4286b3adfdd9..eea23d2e4a5924 100644 --- a/base/compiler/typelattice.jl +++ b/base/compiler/typelattice.jl @@ -239,6 +239,7 @@ function ⊑(@nospecialize(a), @nospecialize(b)) return a === b end end +⋤(@nospecialize(a), @nospecialize(b)) = !⊑(b, a) # Check if two lattice elements are partial order equivalent. This is basically # `a ⊑ b && b ⊑ a` but with extra performance optimizations. 
diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 9d768d4c0d4801..bb21f65820c131 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -3645,3 +3645,26 @@ end # issue #42646 @test only(Base.return_types(getindex, (Array{undef}, Int))) >: Union{} # check that it does not throw + +# form PartialStruct for extra type information propagation +struct FieldTypeRefinement{S,T} + s::S + t::T +end +@test Base.return_types((Int,)) do s + o = FieldTypeRefinement{Any,Int}(s, s) + o.s +end |> only == Int +@test Base.return_types((Int,)) do s + o = FieldTypeRefinement{Int,Any}(s, s) + o.t +end |> only == Int +@test Base.return_types((Int,)) do s + o = FieldTypeRefinement{Any,Any}(s, s) + o.s, o.t +end |> only == Tuple{Int,Int} +@test Base.return_types((Int,)) do a + s1 = Some{Any}(a) + s2 = Some{Any}(s1) + s2.value.value +end |> only == Int diff --git a/test/compiler/irpasses.jl b/test/compiler/irpasses.jl index bbcd8f2104a278..c2f05ea4db297c 100644 --- a/test/compiler/irpasses.jl +++ b/test/compiler/irpasses.jl @@ -426,31 +426,22 @@ let # `getfield_elim_pass!` should work with constant globals end end -let # `typeassert_elim_pass!` +let + # `typeassert` elimination after SROA + # NOTE we can remove this optimization once inference is able to reason about memory-effects src = @eval Module() begin - struct Foo; x; end + mutable struct Foo; x; end code_typed((Int,)) do a x1 = Foo(a) x2 = Foo(x1) - x3 = Foo(x2) - - r1 = (x2.x::Foo).x - r2 = (x2.x::Foo).x::Int - r3 = (x2.x::Foo).x::Integer - r4 = ((x3.x::Foo).x::Foo).x - - return r1, r2, r3, r4 + return typeassert(x2.x, Foo).x end |> only |> first end - # eliminate `typeassert(f2.a, Foo)` - @test all(src.code) do @nospecialize(stmt) + # eliminate `typeassert(x2.x, Foo)` + @test all(src.code) do @nospecialize stmt Meta.isexpr(stmt, :call) || return true ft = Core.Compiler.argextype(stmt.args[1], src, Any[], src.slottypes) return Core.Compiler.widenconst(ft) !== typeof(typeassert) end - # 
succeeding simple DCE will eliminate `Foo(a)` - @test all(src.code) do @nospecialize(stmt) - return !Meta.isexpr(stmt, :new) - end end