Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

inference: fixes cache lookup with extended lattice elements #53953

Merged
merged 2 commits into from
Apr 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 24 additions & 17 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1239,7 +1239,7 @@ const_prop_result(inf_result::InferenceResult) =
ConstCallResults(inf_result.result, inf_result.exc_result, ConstPropResult(inf_result),
inf_result.ipo_effects, inf_result.linfo)

# return cached constant analysis result
# return cached result of constant analysis
return_cached_result(::AbstractInterpreter, inf_result::InferenceResult, ::AbsIntState) =
const_prop_result(inf_result)

Expand All @@ -1248,7 +1248,16 @@ function const_prop_call(interp::AbstractInterpreter,
concrete_eval_result::Union{Nothing, ConstCallResults}=nothing)
inf_cache = get_inference_cache(interp)
𝕃ᵢ = typeinf_lattice(interp)
inf_result = cache_lookup(𝕃ᵢ, mi, arginfo.argtypes, inf_cache)
argtypes = has_conditional(𝕃ᵢ, sv) ? ConditionalArgtypes(arginfo, sv) : SimpleArgtypes(arginfo.argtypes)
# use `cache_argtypes` that has been constructed for fresh regular inference if available
volatile_inf_result = result.volatile_inf_result
if volatile_inf_result !== nothing
cache_argtypes = volatile_inf_result.inf_result.argtypes
else
cache_argtypes = matching_cache_argtypes(𝕃ᵢ, mi)
end
argtypes = matching_cache_argtypes(𝕃ᵢ, mi, argtypes, cache_argtypes)
inf_result = cache_lookup(𝕃ᵢ, mi, argtypes, inf_cache)
if inf_result !== nothing
# found the cache for this constant prop'
if inf_result.result === nothing
Expand All @@ -1258,13 +1267,18 @@ function const_prop_call(interp::AbstractInterpreter,
@assert inf_result.linfo === mi "MethodInstance for cached inference result does not match"
return return_cached_result(interp, inf_result, sv)
end
# perform fresh constant prop'
argtypes = has_conditional(𝕃ᵢ, sv) ? ConditionalArgtypes(arginfo, sv) : SimpleArgtypes(arginfo.argtypes)
inf_result = InferenceResult(mi, argtypes, typeinf_lattice(interp))
if !any(inf_result.overridden_by_const)
overridden_by_const = falses(length(argtypes))
for i = 1:length(argtypes)
if argtypes[i] !== cache_argtypes[i]
overridden_by_const[i] = true
end
end
if !any(overridden_by_const)
add_remark!(interp, sv, "[constprop] Could not handle constant info in matching_cache_argtypes")
return nothing
end
# perform fresh constant prop'
inf_result = InferenceResult(mi, argtypes, overridden_by_const)
frame = InferenceState(inf_result, #=cache_mode=#:local, interp)
if frame === nothing
add_remark!(interp, sv, "[constprop] Could not retrieve the source")
Expand All @@ -1286,26 +1300,19 @@ end

# TODO implement MustAlias forwarding

struct ConditionalArgtypes <: ForwardableArgtypes
struct ConditionalArgtypes
arginfo::ArgInfo
sv::InferenceState
end

"""
matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance,
conditional_argtypes::ConditionalArgtypes)

The implementation is able to forward `Conditional` of `conditional_argtypes`,
as well as the other general extended lattice information.
"""
function matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance,
conditional_argtypes::ConditionalArgtypes)
conditional_argtypes::ConditionalArgtypes,
cache_argtypes::Vector{Any})
(; arginfo, sv) = conditional_argtypes
(; fargs, argtypes) = arginfo
given_argtypes = Vector{Any}(undef, length(argtypes))
def = mi.def::Method
nargs = Int(def.nargs)
cache_argtypes, overridden_by_const = matching_cache_argtypes(𝕃, mi)
local condargs = nothing
for i in 1:length(argtypes)
argtype = argtypes[i]
Expand Down Expand Up @@ -1348,7 +1355,7 @@ function matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance,
else
given_argtypes = va_process_argtypes(𝕃, given_argtypes, mi)
end
return pick_const_args!(𝕃, cache_argtypes, overridden_by_const, given_argtypes)
return pick_const_args!(𝕃, given_argtypes, cache_argtypes)
end

# This is only for use with `Conditional`.
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -203,14 +203,14 @@ include("compiler/ssair/ir.jl")
include("compiler/ssair/tarjan.jl")

include("compiler/abstractlattice.jl")
include("compiler/stmtinfo.jl")
include("compiler/inferenceresult.jl")
include("compiler/inferencestate.jl")

include("compiler/typeutils.jl")
include("compiler/typelimits.jl")
include("compiler/typelattice.jl")
include("compiler/tfuncs.jl")
include("compiler/stmtinfo.jl")

include("compiler/abstractinterpretation.jl")
include("compiler/typeinfer.jl")
Expand Down
110 changes: 33 additions & 77 deletions base/compiler/inferenceresult.jl
Original file line number Diff line number Diff line change
@@ -1,63 +1,30 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

"""
matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance) ->
(cache_argtypes::Vector{Any}, overridden_by_const::BitVector)

Returns argument types `cache_argtypes::Vector{Any}` for `mi` that are in the native
Julia type domain. `overridden_by_const::BitVector` is all `false` meaning that
there is no additional extended lattice information there.

matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance, argtypes::ForwardableArgtypes) ->
(cache_argtypes::Vector{Any}, overridden_by_const::BitVector)

Returns cache-correct extended lattice argument types `cache_argtypes::Vector{Any}`
for `mi` given some `argtypes` accompanied by `overridden_by_const::BitVector`
that marks which argument contains additional extended lattice information.

In theory, there could be a `cache` containing a matching `InferenceResult`
for the provided `mi` and `given_argtypes`. The purpose of this function is
to return a valid value for `cache_lookup(𝕃, mi, argtypes, cache).argtypes`,
so that we can construct cache-correct `InferenceResult`s in the first place.
"""
function matching_cache_argtypes end

function matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance)
method = isa(mi.def, Method) ? mi.def::Method : nothing
cache_argtypes = most_general_argtypes(method, mi.specTypes)
overridden_by_const = falses(length(cache_argtypes))
return cache_argtypes, overridden_by_const
(; def, specTypes) = mi
return most_general_argtypes(isa(def, Method) ? def : nothing, specTypes)
end

struct SimpleArgtypes <: ForwardableArgtypes
struct SimpleArgtypes
argtypes::Vector{Any}
end

"""
matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance, argtypes::SimpleArgtypes)

The implementation for `argtypes` with general extended lattice information.
This is supposed to be used for debugging and testing or external `AbstractInterpreter`
usages and in general `matching_cache_argtypes(::MethodInstance, ::ConditionalArgtypes)`
is more preferred it can forward `Conditional` information.
"""
function matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance, simple_argtypes::SimpleArgtypes)
function matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance,
simple_argtypes::SimpleArgtypes,
cache_argtypes::Vector{Any})
(; argtypes) = simple_argtypes
given_argtypes = Vector{Any}(undef, length(argtypes))
for i = 1:length(argtypes)
given_argtypes[i] = widenslotwrapper(argtypes[i])
end
given_argtypes = va_process_argtypes(𝕃, given_argtypes, mi)
return pick_const_args(𝕃, mi, given_argtypes)
return pick_const_args!(𝕃, given_argtypes, cache_argtypes)
end

function pick_const_args(𝕃::AbstractLattice, mi::MethodInstance, given_argtypes::Vector{Any})
cache_argtypes, overridden_by_const = matching_cache_argtypes(𝕃, mi)
return pick_const_args!(𝕃, cache_argtypes, overridden_by_const, given_argtypes)
end

function pick_const_args!(𝕃::AbstractLattice, cache_argtypes::Vector{Any}, overridden_by_const::BitVector, given_argtypes::Vector{Any})
for i = 1:length(given_argtypes)
function pick_const_args!(𝕃::AbstractLattice, given_argtypes::Vector{Any}, cache_argtypes::Vector{Any})
nargtypes = length(given_argtypes)
@assert nargtypes == length(cache_argtypes) #= == nargs =# "invalid `given_argtypes` for `mi`"
for i = 1:nargtypes
given_argtype = given_argtypes[i]
cache_argtype = cache_argtypes[i]
if !is_argtype_match(𝕃, given_argtype, cache_argtype, false)
Expand All @@ -66,13 +33,13 @@ function pick_const_args!(𝕃::AbstractLattice, cache_argtypes::Vector{Any}, ov
!⊏(𝕃, given_argtype, cache_argtype))
# if the type information of this `PartialStruct` is less strict than
# declared method signature, narrow it down using `tmeet`
given_argtype = tmeet(𝕃, given_argtype, cache_argtype)
given_argtypes[i] = tmeet(𝕃, given_argtype, cache_argtype)
end
cache_argtypes[i] = given_argtype
overridden_by_const[i] = true
else
given_argtypes[i] = cache_argtype
end
end
return cache_argtypes, overridden_by_const
return given_argtypes
end

function is_argtype_match(𝕃::AbstractLattice,
Expand All @@ -89,9 +56,9 @@ end
va_process_argtypes(𝕃::AbstractLattice, given_argtypes::Vector{Any}, mi::MethodInstance) =
va_process_argtypes(Returns(nothing), 𝕃, given_argtypes, mi)
function va_process_argtypes(@specialize(va_handler!), 𝕃::AbstractLattice, given_argtypes::Vector{Any}, mi::MethodInstance)
def = mi.def
isva = isa(def, Method) ? def.isva : false
nargs = isa(def, Method) ? Int(def.nargs) : length(mi.specTypes.parameters)
def = mi.def::Method
isva = def.isva
nargs = Int(def.nargs)
if isva || isvarargtype(given_argtypes[end])
isva_given_argtypes = Vector{Any}(undef, nargs)
for i = 1:(nargs-isva)
Expand All @@ -112,14 +79,11 @@ function va_process_argtypes(@specialize(va_handler!), 𝕃::AbstractLattice, gi
return given_argtypes
end

function most_general_argtypes(method::Union{Method, Nothing}, @nospecialize(specTypes),
withfirst::Bool = true)
function most_general_argtypes(method::Union{Method,Nothing}, @nospecialize(specTypes))
toplevel = method === nothing
isva = !toplevel && method.isva
mi_argtypes = Any[(unwrap_unionall(specTypes)::DataType).parameters...]
nargs::Int = toplevel ? 0 : method.nargs
# For opaque closure, the closure environment is processed elsewhere
withfirst || (nargs -= 1)
cache_argtypes = Vector{Any}(undef, nargs)
# First, if we're dealing with a varargs method, then we set the last element of `args`
# to the appropriate `Tuple` type or `PartialStruct` instance.
Expand Down Expand Up @@ -162,17 +126,16 @@ function most_general_argtypes(method::Union{Method, Nothing}, @nospecialize(spe
cache_argtypes[nargs] = vargtype
nargs -= 1
end
# Now, we propagate type info from `linfo_argtypes` into `cache_argtypes`, improving some
# Now, we propagate type info from `mi_argtypes` into `cache_argtypes`, improving some
# type info as we go (where possible). Note that if we're dealing with a varargs method,
# we already handled the last element of `cache_argtypes` (and decremented `nargs` so that
# we don't overwrite the result of that work here).
if mi_argtypes_length > 0
n = mi_argtypes_length > nargs ? nargs : mi_argtypes_length
tail_index = n
tail_index = nargtypes = min(mi_argtypes_length, nargs)
local lastatype
for i = 1:n
for i = 1:nargtypes
atyp = mi_argtypes[i]
if i == n && isvarargtype(atyp)
if i == nargtypes && isvarargtype(atyp)
atyp = unwrapva(atyp)
tail_index -= 1
end
Expand All @@ -185,16 +148,16 @@ function most_general_argtypes(method::Union{Method, Nothing}, @nospecialize(spe
else
atyp = elim_free_typevars(rewrap_unionall(atyp, specTypes))
end
i == n && (lastatype = atyp)
i == nargtypes && (lastatype = atyp)
cache_argtypes[i] = atyp
end
for i = (tail_index + 1):nargs
for i = (tail_index+1):nargs
cache_argtypes[i] = lastatype
end
else
@assert nargs == 0 "invalid specialization of method" # wrong number of arguments
end
cache_argtypes
return cache_argtypes
end

# eliminate free `TypeVar`s in order to make the life much easier down the road:
Expand All @@ -213,22 +176,15 @@ end
function cache_lookup(𝕃::AbstractLattice, mi::MethodInstance, given_argtypes::Vector{Any},
cache::Vector{InferenceResult})
method = mi.def::Method
nargs = Int(method.nargs)
method.isva && (nargs -= 1)
length(given_argtypes) ≥ nargs || return nothing
nargtypes = length(given_argtypes)
@assert nargtypes == Int(method.nargs) "invalid `given_argtypes` for `mi`"
for cached_result in cache
cached_result.linfo === mi || continue
cached_result.linfo === mi || @goto next_cache
cache_argtypes = cached_result.argtypes
cache_overridden_by_const = cached_result.overridden_by_const
for i in 1:nargs
if !is_argtype_match(𝕃, widenmustalias(given_argtypes[i]),
cache_argtypes[i], cache_overridden_by_const[i])
@goto next_cache
end
end
if method.isva
if !is_argtype_match(𝕃, tuple_tfunc(𝕃, given_argtypes[(nargs + 1):end]),
cache_argtypes[end], cache_overridden_by_const[end])
@assert length(cache_argtypes) == nargtypes "invalid `cache_argtypes` for `mi`"
cache_overridden_by_const = cached_result.overridden_by_const::BitVector
for i in 1:nargtypes
if !is_argtype_match(𝕃, given_argtypes[i], cache_argtypes[i], cache_overridden_by_const[i])
@goto next_cache
end
end
Expand Down
5 changes: 4 additions & 1 deletion base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -832,7 +832,10 @@ end
frame_parent(sv::InferenceState) = sv.parent::Union{Nothing,AbsIntState}
frame_parent(sv::IRInterpretationState) = sv.parent::Union{Nothing,AbsIntState}

is_constproped(sv::InferenceState) = any(sv.result.overridden_by_const)
function is_constproped(sv::InferenceState)
(;overridden_by_const) = sv.result
return overridden_by_const !== nothing
end
is_constproped(::IRInterpretationState) = true

is_cached(sv::InferenceState) = !iszero(sv.cache_mode & CACHE_MODE_GLOBAL)
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/legacy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ the original `ci::CodeInfo` are modified.
"""
function inflate_ir!(ci::CodeInfo, mi::MethodInstance)
sptypes = sptypes_from_meth_instance(mi)
argtypes, _ = matching_cache_argtypes(fallback_lattice, mi)
argtypes = matching_cache_argtypes(fallback_lattice, mi)
return inflate_ir!(ci, sptypes, argtypes)
end
function inflate_ir!(ci::CodeInfo, sptypes::Vector{VarState}, argtypes::Vector{Any})
Expand Down
8 changes: 3 additions & 5 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,7 @@ struct EdgeCallResult
end
end

# return cached regular inference result
# return cached result of regular inference
function return_cached_result(::AbstractInterpreter, codeinst::CodeInstance, caller::AbsIntState)
rt = cached_return_type(codeinst)
effects = ipo_effects(codeinst)
Expand Down Expand Up @@ -869,10 +869,8 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
effects = isinferred ? frame.result.ipo_effects : adjust_effects(Effects(), method) # effects are adjusted already within `finish` for ipo_effects
exc_bestguess = refine_exception_type(frame.exc_bestguess, effects)
# propagate newly inferred source to the inliner, allowing efficient inlining w/o deserialization:
# note that this result is cached globally exclusively, we can use this local result destructively
volatile_inf_result = (isinferred && (force_inline ||
src_inlining_policy(interp, result.src, NoCallInfo(), IR_FLAG_NULL))) ?
VolatileInferenceResult(result) : nothing
# note that this result is cached globally exclusively, so we can use this local result destructively
volatile_inf_result = isinferred ? VolatileInferenceResult(result) : nothing
return EdgeCallResult(frame.bestguess, exc_bestguess, edge, effects, volatile_inf_result)
elseif frame === true
# unresolvable cycle
Expand Down
Loading