Skip to content

Commit

Permalink
inference: don't put globally-cached results into inference local cache (#51888)
Browse files Browse the repository at this point in the history

Now, we fill the inference's local cache exclusively with locally-cached
results. Given that sources of globally-cached results will be populated
from the global cache when needed (for inlining), there's no need for
them to waste memory in the local cache.

@nanosoldier `runbenchmarks("inference", vs=":master")`
  • Loading branch information
aviatesk authored and pull[bot] committed Nov 2, 2023
1 parent 51f95ff commit f4dc5c1
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 44 deletions.
2 changes: 1 addition & 1 deletion base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1181,7 +1181,7 @@ function const_prop_call(interp::AbstractInterpreter,
add_remark!(interp, sv, "[constprop] Could not handle constant info in matching_cache_argtypes")
return nothing
end
frame = InferenceState(inf_result, #=cache=#:local, interp)
frame = InferenceState(inf_result, #=cache_mode=#:local, interp)
if frame === nothing
add_remark!(interp, sv, "[constprop] Could not retrieve the source")
return nothing # this is probably a bad generated function (unsound), but just ignore it
Expand Down
15 changes: 7 additions & 8 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -240,15 +240,15 @@ mutable struct InferenceState
# Whether to restrict inference of abstract call sites to avoid excessive work
# Set by default for toplevel frame.
restrict_abstract_call_sites::Bool
cached::Bool # TODO move this to InferenceResult?
cache_mode::Symbol # TODO move this to InferenceResult?
insert_coverage::Bool

# The interpreter that created this inference state. Not looked at by
# NativeInterpreter. But other interpreters may use this to detect cycles
interp::AbstractInterpreter

# src is assumed to be a newly-allocated CodeInfo, that can be modified in-place to contain intermediate results
function InferenceState(result::InferenceResult, src::CodeInfo, cache::Symbol,
function InferenceState(result::InferenceResult, src::CodeInfo, cache_mode::Symbol,
interp::AbstractInterpreter)
linfo = result.linfo
world = get_world_counter(interp)
Expand Down Expand Up @@ -303,19 +303,18 @@ mutable struct InferenceState
end

restrict_abstract_call_sites = isa(def, Module)
@assert cache === :no || cache === :local || cache === :global
cached = cache === :global
@assert cache_mode === :no || cache_mode === :local || cache_mode === :global

# some more setups
InferenceParams(interp).unoptimize_throw_blocks && mark_throw_blocks!(src, handler_at)
cache !== :no && push!(get_inference_cache(interp), result)
cache_mode === :local && push!(get_inference_cache(interp), result)

return new(
linfo, world, mod, sptypes, slottypes, src, cfg, method_info,
currbb, currpc, ip, handler_at, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info,
pclimitations, limitations, cycle_backedges, callers_in_cycle, dont_work_on_me, parent,
result, unreachable, valid_worlds, bestguess, ipo_effects,
restrict_abstract_call_sites, cached, insert_coverage,
restrict_abstract_call_sites, cache_mode, insert_coverage,
interp)
end
end
Expand Down Expand Up @@ -667,7 +666,7 @@ end
function print_callstack(sv::InferenceState)
while sv !== nothing
print(sv.linfo)
!sv.cached && print(" [uncached]")
sv.cache_mode === :global || print(" [uncached]")
println()
for cycle in sv.callers_in_cycle
print(' ', cycle.linfo)
Expand Down Expand Up @@ -765,7 +764,7 @@ frame_parent(sv::IRInterpretationState) = sv.parent::Union{Nothing,AbsIntState}
is_constproped(sv::InferenceState) = any(sv.result.overridden_by_const)
is_constproped(::IRInterpretationState) = true

is_cached(sv::InferenceState) = sv.cached
is_cached(sv::InferenceState) = sv.cache_mode === :global
is_cached(::IRInterpretationState) = false

method_info(sv::InferenceState) = sv.method_info
Expand Down
58 changes: 30 additions & 28 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ const track_newly_inferred = RefValue{Bool}(false)
const newly_inferred = CodeInstance[]

# build (and start inferring) the inference frame for the top-level MethodInstance
function typeinf(interp::AbstractInterpreter, result::InferenceResult, cache::Symbol)
frame = InferenceState(result, cache, interp)
function typeinf(interp::AbstractInterpreter, result::InferenceResult, cache_mode::Symbol)
frame = InferenceState(result, cache_mode, interp)
frame === nothing && return false
cache === :global && lock_mi_inference(interp, result.linfo)
cache_mode === :global && lock_mi_inference(interp, result.linfo)
return typeinf(interp, frame)
end

Expand Down Expand Up @@ -239,7 +239,7 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState)
# but that is what !must_be_codeinf permits
# This is hopefully unreachable when must_be_codeinf is true
end
return
return nothing
end

function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
Expand Down Expand Up @@ -267,11 +267,14 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
end
for caller in frames
finish!(caller.interp, caller)
if caller.cached
if caller.cache_mode === :global
cache_result!(caller.interp, caller.result)
end
# n.b. We do not drop result.src here, even though that wastes memory while it is still in the local cache
# since the user might have requested call-site inlining of it.
# Drop result.src here since otherwise it can waste memory.
# N.B. If `cache_mode === :local`, the inliner may request to use it later.
if caller.cache_mode !== :local
caller.result.src = nothing
end
end
empty!(frames)
return true
Expand Down Expand Up @@ -381,23 +384,23 @@ function cache_result!(interp::AbstractInterpreter, result::InferenceResult)
end
# check if the existing linfo metadata is also sufficient to describe the current inference result
# to decide if it is worth caching this
linfo = result.linfo
already_inferred = already_inferred_quick_test(interp, linfo)
if !already_inferred && haskey(WorldView(code_cache(interp), valid_worlds), linfo)
mi = result.linfo
already_inferred = already_inferred_quick_test(interp, mi)
if !already_inferred && haskey(WorldView(code_cache(interp), valid_worlds), mi)
already_inferred = true
end

# TODO: also don't store inferred code if we've previously decided to interpret this function
if !already_inferred
code_cache(interp)[linfo] = ci = CodeInstance(interp, result, valid_worlds)
code_cache(interp)[mi] = ci = CodeInstance(interp, result, valid_worlds)
if track_newly_inferred[]
m = linfo.def
m = mi.def
if isa(m, Method) && m.module != Core
ccall(:jl_push_newly_inferred, Cvoid, (Any,), ci)
end
end
end
unlock_mi_inference(interp, linfo)
unlock_mi_inference(interp, mi)
nothing
end

Expand Down Expand Up @@ -543,7 +546,7 @@ function finish(me::InferenceState, interp::AbstractInterpreter)
# a parent may be cached still, but not this intermediate work:
# we can throw everything else away now
me.result.src = nothing
me.cached = false
me.cache_mode = :no
set_inlineable!(me.src, false)
unlock_mi_inference(interp, me.linfo)
elseif limited_src
Expand All @@ -555,16 +558,14 @@ function finish(me::InferenceState, interp::AbstractInterpreter)
# annotate fulltree with type information,
# either because we are the outermost code, or we might use this later
type_annotate!(interp, me)
doopt = (me.cached || me.parent !== nothing)
doopt = (me.cache_mode !== :no || me.parent !== nothing)
# Disable the optimizer if we've already determined that there's nothing for
# it to do.
if may_discard_trees(interp) && is_result_constabi_eligible(me.result)
doopt = false
end
if doopt && may_optimize(interp)
me.result.src = OptimizationState(me, interp)
else
me.result.src = me.src::CodeInfo # stash a convenience copy of the code (e.g. for reflection)
end
end
validate_code_in_debug_mode(me.linfo, me.src, "inferred")
Expand Down Expand Up @@ -814,15 +815,15 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
# we already inferred this edge before and decided to discard the inferred code,
# nevertheless we re-infer it here again and keep it around in the local cache
# since the inliner will request to use it later
cache = :local
cache_mode = :local
else
rt = cached_return_type(code)
effects = ipo_effects(code)
update_valid_age!(caller, WorldRange(min_world(code), max_world(code)))
return EdgeCallResult(rt, mi, effects)
end
else
cache = :global # cache edge targets by default
cache_mode = :global # cache edge targets by default
end
if ccall(:jl_get_module_infer, Cint, (Any,), method.module) == 0 && !generating_output(#=incremental=#false)
add_remark!(interp, caller, "Inference is disabled for the target module")
Expand All @@ -839,7 +840,7 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
# completely new
lock_mi_inference(interp, mi)
result = InferenceResult(mi, typeinf_lattice(interp))
frame = InferenceState(result, cache, interp) # always use the cache for edge targets
frame = InferenceState(result, cache_mode, interp) # always use the cache for edge targets
if frame === nothing
add_remark!(interp, caller, "Failed to retrieve source")
# can't get the source for this, so we know nothing
Expand Down Expand Up @@ -970,7 +971,8 @@ typeinf_frame(interp::AbstractInterpreter, method::Method, @nospecialize(atype),
function typeinf_frame(interp::AbstractInterpreter, mi::MethodInstance, run_optimizer::Bool)
start_time = ccall(:jl_typeinf_timing_begin, UInt64, ())
result = InferenceResult(mi, typeinf_lattice(interp))
frame = InferenceState(result, run_optimizer ? :global : :no, interp)
cache_mode = run_optimizer ? :global : :no
frame = InferenceState(result, cache_mode, interp)
frame === nothing && return nothing
typeinf(interp, frame)
ccall(:jl_typeinf_timing_end, Cvoid, (UInt64,), start_time)
Expand Down Expand Up @@ -1010,7 +1012,7 @@ function typeinf_ext(interp::AbstractInterpreter, mi::MethodInstance)
end
lock_mi_inference(interp, mi)
result = InferenceResult(mi, typeinf_lattice(interp))
frame = InferenceState(result, #=cache=#:global, interp)
frame = InferenceState(result, #=cache_mode=#:global, interp)
frame === nothing && return nothing
typeinf(interp, frame)
ccall(:jl_typeinf_timing_end, Cvoid, (UInt64,), start_time)
Expand Down Expand Up @@ -1043,18 +1045,18 @@ end

# This is a bridge for the C code calling `jl_typeinf_func()`
typeinf_ext_toplevel(mi::MethodInstance, world::UInt) = typeinf_ext_toplevel(NativeInterpreter(world), mi)
function typeinf_ext_toplevel(interp::AbstractInterpreter, linfo::MethodInstance)
if isa(linfo.def, Method)
function typeinf_ext_toplevel(interp::AbstractInterpreter, mi::MethodInstance)
if isa(mi.def, Method)
# method lambda - infer this specialization via the method cache
src = typeinf_ext(interp, linfo)
src = typeinf_ext(interp, mi)
else
src = linfo.uninferred::CodeInfo
src = mi.uninferred::CodeInfo
if !src.inferred
# toplevel lambda - infer directly
start_time = ccall(:jl_typeinf_timing_begin, UInt64, ())
if !src.inferred
result = InferenceResult(linfo, typeinf_lattice(interp))
frame = InferenceState(result, src, #=cache=#:global, interp)
result = InferenceResult(mi, typeinf_lattice(interp))
frame = InferenceState(result, src, #=cache_mode=#:global, interp)
typeinf(interp, frame)
@assert is_inferred(frame) # TODO: deal with this better
src = frame.src
Expand Down
4 changes: 2 additions & 2 deletions stdlib/REPL/src/REPLCompletions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -519,10 +519,10 @@ CC.bail_out_toplevel_call(::REPLInterpreter, ::CC.InferenceLoopState, ::CC.Infer
# `REPLInterpreter` is specifically used by `repl_eval_ex`, where all top-level frames are
# `repl_frame` always. However, this assumption wouldn't stand if `REPLInterpreter` were to
# be employed, for instance, by `typeinf_ext_toplevel`.
is_repl_frame(sv::CC.InferenceState) = sv.linfo.def isa Module && !sv.cached
is_repl_frame(sv::CC.InferenceState) = sv.linfo.def isa Module && sv.cache_mode === :no

function is_call_graph_uncached(sv::CC.InferenceState)
sv.cached && return false
sv.cache_mode === :global && return false
parent = sv.parent
parent === nothing && return true
return is_call_graph_uncached(parent::CC.InferenceState)
Expand Down
6 changes: 1 addition & 5 deletions test/compiler/AbstractInterpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -437,18 +437,14 @@ custom_lookup_context(x::Int) = custom_lookup_target(true, x)
const CONST_INVOKE_INTERP_WORLD = Base.get_world_counter()
const CONST_INVOKE_INTERP = ConstInvokeInterp(; world=CONST_INVOKE_INTERP_WORLD)
function custom_lookup(mi::MethodInstance, min_world::UInt, max_world::UInt)
local matched_mi = nothing
for inf_result in CONST_INVOKE_INTERP.inf_cache
if inf_result.linfo === mi
if CC.any(inf_result.overridden_by_const)
return CodeInstance(CONST_INVOKE_INTERP, inf_result, inf_result.valid_worlds)
elseif matched_mi === nothing
matched_mi = inf_result.linfo
end
end
end
matched_mi === nothing && return nothing
return CONST_INVOKE_INTERP.code_cache.dict[matched_mi]
return CONST_INVOKE_INTERP.code_cache.dict[mi]
end

let # generate cache
Expand Down

0 comments on commit f4dc5c1

Please sign in to comment.