From c713708ada1bc79a2649431760da3fd21ba0e7f8 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Thu, 28 Oct 2021 16:26:34 +0900 Subject: [PATCH] inference: form `PartialStruct` for extra type information propagation This commit forms `PartialStruct` whenever there is any type-level refinement available about a field, even if it's not "constant" information. In Julia "definitions" are allowed to be abstract whereas "usages" (i.e. callsites) are often concrete. The basic idea is to allow inference to make more use of such precise callsite type information by encoding it as `PartialStruct`. This may increase optimization possibilities of "unidiomatic" Julia code, which may contain poorly-typed definitions, like this very contrived example: ```julia struct Problem n; s; c; t end function main(args...) prob = Problem(args...) s = 0 for i in 1:prob.n m = mod(i, 3) s += m == 0 ? sin(prob.s) : m == 1 ? cos(prob.c) : tan(prob.t) end return prob, s end main(10000, 1, 2, 3) ``` One of the obvious limitation is that this extra type information can be propagated inter-procedurally only as a const-propagation. I'm not sure this kind of "just a type-level" refinement can often make constant-prop' successful (i.e. shape-up a method body and allow it to be inlined, encoding the extra type information into the generated code), thus I didn't not modify any part of const-prop' heuristics. So the improvements from this change is almost for local analysis, and for very simple inter-procedural calls. --- base/compiler/abstractinterpretation.jl | 30 ++++++++++++++++--------- base/compiler/typelattice.jl | 1 + test/compiler/inference.jl | 23 +++++++++++++++++++ 3 files changed, 43 insertions(+), 11 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index b338a80cbbe74a..4335ecdfb82e4f 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -1542,22 +1542,26 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e), if isconcretetype(t) && !ismutabletype(t) args = Vector{Any}(undef, length(e.args)-1) ats = Vector{Any}(undef, length(e.args)-1) - anyconst = false - allconst = true + local anyconst = anyrefine = false + local allconst = true for i = 2:length(e.args) at = widenconditional(abstract_eval_value(interp, e.args[i], vtypes, sv)) if !anyconst - anyconst = has_nontrivial_const_info(at) + if has_nontrivial_const_info(at) + anyconst = true + elseif !anyrefine + anyrefine = at ⋤ fieldtype(t, i - 1) + end end ats[i-1] = at if at === Bottom t = Bottom - allconst = anyconst = false + anyconst = anyrefine = allconst = false break elseif at isa Const if !(at.val isa fieldtype(t, i - 1)) t = Bottom - allconst = anyconst = false + anyconst = anyrefine = allconst = false break end args[i-1] = at.val @@ -1569,7 +1573,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e), if t !== Bottom && fieldcount(t) == length(ats) if allconst t = Const(ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), t, args, length(args))) - elseif anyconst + elseif anyconst || anyrefine t = PartialStruct(t, ats) end end @@ -1741,17 +1745,21 @@ function widenreturn(@nospecialize(rt), @nospecialize(bestguess), nslots::Int, s isa(rt, Type) && return rt if isa(rt, PartialStruct) fields = copy(rt.fields) - haveconst = false + anyconst = anyrefine = false for i in 1:length(fields) a = fields[i] a = isvarargtype(a) ? a : widenreturn(a, bestguess, nslots, slottypes, changes) - if !haveconst && has_const_info(a) - # TODO: consider adding && const_prop_profitable(a) here? - haveconst = true + if !anyconst + if has_const_info(a) + # TODO: consider adding && const_prop_profitable(a) here? + anyconst = true + elseif !anyrefine + anyrefine = a ⋤ fieldtype(rt.typ, i) + end end fields[i] = a end - haveconst && return PartialStruct(rt.typ, fields) + (anyconst || anyrefine) && return PartialStruct(rt.typ, fields) end if isa(rt, PartialOpaque) return rt # XXX: this case was missed in #39512 diff --git a/base/compiler/typelattice.jl b/base/compiler/typelattice.jl index 7b4286b3adfdd9..eea23d2e4a5924 100644 --- a/base/compiler/typelattice.jl +++ b/base/compiler/typelattice.jl @@ -239,6 +239,7 @@ function ⊑(@nospecialize(a), @nospecialize(b)) return a === b end end +⋤(@nospecialize(a), @nospecialize(b)) = !⊑(b, a) # Check if two lattice elements are partial order equivalent. This is basically # `a ⊑ b && b ⊑ a` but with extra performance optimizations. diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 9d768d4c0d4801..bb21f65820c131 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -3645,3 +3645,26 @@ end # issue #42646 @test only(Base.return_types(getindex, (Array{undef}, Int))) >: Union{} # check that it does not throw + +# form PartialStruct for extra type information propagation +struct FieldTypeRefinement{S,T} + s::S + t::T +end +@test Base.return_types((Int,)) do s + o = FieldTypeRefinement{Any,Int}(s, s) + o.s +end |> only == Int +@test Base.return_types((Int,)) do s + o = FieldTypeRefinement{Int,Any}(s, s) + o.t +end |> only == Int +@test Base.return_types((Int,)) do s + o = FieldTypeRefinement{Any,Any}(s, s) + o.s, o.t +end |> only == Tuple{Int,Int} +@test Base.return_types((Int,)) do a + s1 = Some{Any}(a) + s2 = Some{Any}(s1) + s2.value.value +end |> only == Int