From 0b14b3e49cfe1e5098bd088f445c1bc52021c74a Mon Sep 17 00:00:00 2001 From: Martin Holters Date: Mon, 30 Aug 2021 22:44:44 +0200 Subject: [PATCH] Fix a precision issue in `abstract_iteration` (#41839) If the first loop exits in the first iteration, the `statetype` is still `Bottom`. In that case, the new `stateordonet` needs to be determined with the two-arg version of `iterate` again. Explicitly test that inference produces a sound (and reasonably precise) result when splatting an iterator (in this case a long range) that allows constant-propagation up to the `MAX_TUPLE_SPLAT` limit. Fixes #41022 Co-authored-by: Jameson Nash --- base/compiler/abstractinterpretation.jl | 35 ++++++++++++++++++------- test/compiler/inference.jl | 17 +++++++++++- 2 files changed, 41 insertions(+), 11 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index ac007a5aa77bfe..4c3f7d46fd98b7 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -844,9 +844,11 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n return ret, AbstractIterationInfo(calls) end if Nothing <: stateordonet_widened || length(ret) >= InferenceParams(interp).MAX_TUPLE_SPLAT + stateordonet = stateordonet_widened break end if !isa(stateordonet_widened, DataType) || !(stateordonet_widened <: Tuple) || isvatuple(stateordonet_widened) || length(stateordonet_widened.parameters) != 2 + stateordonet = stateordonet_widened break end nstatetype = getfield_tfunc(stateordonet, Const(2)) @@ -864,27 +866,40 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n end # From here on, we start asking for results on the widened types, rather than # the precise (potentially const) state type - statetype = widenconst(statetype) - valtype = widenconst(valtype) + # statetype and valtype are reinitialized in the first iteration below from the + # (widened) stateordonet, which has not yet been fully analyzed in the loop above + statetype = Bottom + valtype = Bottom + may_have_terminated = Nothing <: stateordonet while valtype !== Any - stateordonet = abstract_call_known(interp, iteratef, nothing, Any[Const(iteratef), itertype, statetype], sv).rt - stateordonet = widenconst(stateordonet) - nounion = typesubtract(stateordonet, Nothing, 0) - if !isa(nounion, DataType) || !(nounion <: Tuple) || isvatuple(nounion) || length(nounion.parameters) != 2 + nounion = typeintersect(stateordonet, Tuple{Any,Any}) + if nounion !== Union{} && !isa(nounion, DataType) + # nounion is of a type we cannot handle valtype = Any break end - if nounion.parameters[1] <: valtype && nounion.parameters[2] <: statetype + if nounion === Union{} || (nounion.parameters[1] <: valtype && nounion.parameters[2] <: statetype) + # reached a fixpoint or iterator failed/gave invalid answer if typeintersect(stateordonet, Nothing) === Union{} - # Reached a fixpoint, but Nothing is not possible => iterator is infinite or failing - return Any[Bottom], nothing + # ... but cannot terminate + if !may_have_terminated + # ... and cannot have terminated prior to this loop + return Any[Bottom], nothing + else + # iterator may have terminated prior to this loop, but not during it + valtype = Bottom + end end break end valtype = tmerge(valtype, nounion.parameters[1]) statetype = tmerge(statetype, nounion.parameters[2]) + stateordonet = abstract_call_known(interp, iteratef, nothing, Any[Const(iteratef), itertype, statetype], sv).rt + stateordonet = widenconst(stateordonet) + end + if valtype !== Union{} + push!(ret, Vararg{valtype}) end - push!(ret, Vararg{valtype}) return ret, nothing end diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 4890329f1a3751..c78cd52297581b 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -2877,9 +2877,24 @@ partial_return_2(x) = Val{partial_return_1(x)[2]} @test Base.return_types(partial_return_2, (Int,)) == Any[Type{Val{1}}] -# Precision of abstract_iteration +# Soundness and precision of abstract_iteration +f41839() = (1:100...,) +@test NTuple{100,Int} <: only(Base.return_types(f41839, ())) <: Tuple{Vararg{Int}} f_splat(x) = (x...,) @test Base.return_types(f_splat, (Pair{Int,Int},)) == Any[Tuple{Int, Int}] +@test Base.return_types(f_splat, (UnitRange{Int},)) == Any[Tuple{Vararg{Int}}] +struct Itr41839_1 end # empty or infinite +Base.iterate(::Itr41839_1) = rand(Bool) ? (nothing, nothing) : nothing +Base.iterate(::Itr41839_1, ::Nothing) = (nothing, nothing) +@test Base.return_types(f_splat, (Itr41839_1,)) == Any[Tuple{}] +struct Itr41839_2 end # empty or failing +Base.iterate(::Itr41839_2) = rand(Bool) ? (nothing, nothing) : nothing +Base.iterate(::Itr41839_2, ::Nothing) = error() +@test Base.return_types(f_splat, (Itr41839_2,)) == Any[Tuple{}] +struct Itr41839_3 end +Base.iterate(::Itr41839_3 ) = rand(Bool) ? nothing : (nothing, 1) +Base.iterate(::Itr41839_3 , i) = i < 16 ? (i, i + 1) : nothing +@test only(Base.return_types(f_splat, (Itr41839_3,))) <: Tuple{Vararg{Union{Nothing, Int}}} # issue #32699 f32699(a) = (id = a[1],).id