Skip to content

Commit

Permalink
propagate VolatileInferenceResult(::InferenceResult)
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk committed Nov 2, 2023
1 parent c513ca1 commit 2019fc3
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 109 deletions.
37 changes: 13 additions & 24 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,12 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
splitsigs = switchtupleunion(sig)
for sig_n in splitsigs
result = abstract_call_method(interp, method, sig_n, svec(), multiple_matches, si, sv)
(; rt, edge, effects, inferred_src) = result
(; rt, edge, effects, volatile_inf_result) = result
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
this_arginfo = ArgInfo(fargs, this_argtypes)
const_call_result = abstract_call_method_with_const_args(interp,
result, f, this_arginfo, si, match, sv)
const_result = nothing
const_result = volatile_inf_result
if const_call_result !== nothing
if const_call_result.rt ₚ rt
rt = const_call_result.rt
Expand All @@ -90,7 +90,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
this_rt = widenwrappedconditional(this_rt)
else
result = abstract_call_method(interp, method, sig, match.sparams, multiple_matches, si, sv)
(; rt, edge, effects, inferred_src) = result
(; rt, edge, effects, volatile_inf_result) = result
this_conditional = ignorelimited(rt)
this_rt = widenwrappedconditional(rt)
# try constant propagation with argtypes for this match
Expand All @@ -99,7 +99,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
this_arginfo = ArgInfo(fargs, this_argtypes)
const_call_result = abstract_call_method_with_const_args(interp,
result, f, this_arginfo, si, match, sv)
const_result = nothing
const_result = volatile_inf_result
if const_call_result !== nothing
this_const_conditional = ignorelimited(const_call_result.rt)
this_const_rt = widenwrappedconditional(const_call_result.rt)
Expand All @@ -119,11 +119,6 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
const_results = fill!(Vector{Union{Nothing,ConstResult}}(undef, napplicable), nothing)
end
const_results[i] = const_result
elseif inferred_src !== nothing
if const_results === nothing
const_results = fill!(Vector{Union{Nothing,ConstResult}}(undef, napplicable), nothing)
end
const_results[i] = InferredResult(inferred_src)
end
edge === nothing || push!(edges, edge)
end
Expand Down Expand Up @@ -626,7 +621,7 @@ function abstract_call_method(interp::AbstractInterpreter,
sparams = recomputed[2]::SimpleVector
end

(; rt, edge, effects, inferred_src) = typeinf_edge(interp, method, sig, sparams, sv)
(; rt, edge, effects, volatile_inf_result) = typeinf_edge(interp, method, sig, sparams, sv)

if edge === nothing
edgecycle = edgelimited = true
Expand All @@ -650,7 +645,7 @@ function abstract_call_method(interp::AbstractInterpreter,
end
end

return MethodCallResult(rt, edgecycle, edgelimited, edge, effects, inferred_src)
return MethodCallResult(rt, edgecycle, edgelimited, edge, effects, volatile_inf_result)
end

function edge_matches_sv(interp::AbstractInterpreter, frame::AbsIntState,
Expand Down Expand Up @@ -753,14 +748,14 @@ struct MethodCallResult
edgelimited::Bool
edge::Union{Nothing,MethodInstance}
effects::Effects
inferred_src::Union{Nothing,CodeInfo}
volatile_inf_result::Union{Nothing,VolatileInferenceResult}
function MethodCallResult(@nospecialize(rt),
edgecycle::Bool,
edgelimited::Bool,
edge::Union{Nothing,MethodInstance},
effects::Effects,
inferred_src::Union{Nothing,CodeInfo}=nothing)
return new(rt, edgecycle, edgelimited, edge, effects, inferred_src)
volatile_inf_result::Union{Nothing,VolatileInferenceResult}=nothing)
return new(rt, edgecycle, edgelimited, edge, effects, volatile_inf_result)
end
end

Expand Down Expand Up @@ -1952,7 +1947,7 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
tienv = ccall(:jl_type_intersection_with_env, Any, (Any, Any), nargtype, method.sig)::SimpleVector
ti = tienv[1]; env = tienv[2]::SimpleVector
result = abstract_call_method(interp, method, ti, env, false, si, sv)
(; rt, edge, effects, inferred_src) = result
(; rt, edge, effects, volatile_inf_result) = result
match = MethodMatch(ti, env, method, argtype <: method.sig)
res = nothing
sig = match.spec_types
Expand All @@ -1969,15 +1964,12 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
invokecall = InvokeCall(types, lookupsig)
const_call_result = abstract_call_method_with_const_args(interp,
result, f, arginfo, si, match, sv, invokecall)
const_result = nothing
const_result = volatile_inf_result
if const_call_result !== nothing
if (𝕃ₚ, const_call_result.rt, rt)
(; rt, effects, const_result, edge) = const_call_result
end
end
if const_result === nothing && inferred_src !== nothing
const_result = InferredResult(inferred_src)
end
rt = from_interprocedural!(interp, rt, sv, arginfo, sig)
info = InvokeCallInfo(match, const_result)
edge !== nothing && add_invoke_backedge!(sv, lookupsig, edge)
Expand Down Expand Up @@ -2101,13 +2093,13 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter,
closure::PartialOpaque, arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, check::Bool=true)
sig = argtypes_to_type(arginfo.argtypes)
result = abstract_call_method(interp, closure.source::Method, sig, Core.svec(), false, si, sv)
(; rt, edge, effects, inferred_src) = result
(; rt, edge, effects, volatile_inf_result) = result
tt = closure.typ
sigT = (unwrap_unionall(tt)::DataType).parameters[1]
match = MethodMatch(sig, Core.svec(), closure.source, sig <: rewrap_unionall(sigT, tt))
𝕃ₚ = ipo_lattice(interp)
= (𝕃ₚ)
const_result = nothing
const_result = volatile_inf_result
if !result.edgecycle
const_call_result = abstract_call_method_with_const_args(interp, result,
nothing, arginfo, si, match, sv)
Expand All @@ -2125,9 +2117,6 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter,
effects = Effects(effects; nothrow=false)
end
end
if const_result === nothing && inferred_src !== nothing
const_result = InferredResult(inferred_src)
end
rt = from_interprocedural!(interp, rt, sv, arginfo, match.spec_types)
info = OpaqueClosureCallInfo(match, const_result)
edge !== nothing && add_backedge!(sv, edge)
Expand Down
9 changes: 6 additions & 3 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,10 @@ to enable flow-sensitive analysis.
"""
const VarTable = Vector{VarState}

const CACHE_MODE_NULL = 0x00
const CACHE_MODE_GLOBAL = 0x01 << 0
const CACHE_MODE_LOCAL = 0x01 << 1
const CACHE_MODE_NULL = 0x00 # not cached, without optimization
const CACHE_MODE_GLOBAL = 0x01 << 0 # cached globally, optimization allowed
const CACHE_MODE_LOCAL = 0x01 << 1 # cached locally, optimization allowed
const CACHE_MODE_VOLATILE = 0x01 << 2 # not cached, optimization allowed

mutable struct InferenceState
#= information about this method instance =#
Expand Down Expand Up @@ -467,6 +468,8 @@ function convert_cache_mode(cache_mode::Symbol)
return CACHE_MODE_GLOBAL
elseif cache_mode === :local
return CACHE_MODE_LOCAL
elseif cache_mode === :volatile
return CACHE_MODE_VOLATILE
elseif cache_mode === :no
return CACHE_MODE_NULL
end
Expand Down
16 changes: 1 addition & 15 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,25 +75,11 @@ is_source_inferred(@nospecialize src::MaybeCompressed) =

function inlining_policy(interp::AbstractInterpreter,
@nospecialize(src), @nospecialize(info::CallInfo), stmt_flag::UInt32, mi::MethodInstance,
argtypes::Vector{Any})
_::Vector{Any})
if isa(src, MaybeCompressed)
is_source_inferred(src) || return nothing
src_inlineable = is_stmt_inline(stmt_flag) || is_inlineable(src)
return src_inlineable ? src : nothing
elseif src === nothing && is_stmt_inline(stmt_flag)
# if this statement is forced to be inlined, make an additional effort to find the
# inferred source in the local cache
# we still won't find a source for recursive call because the "single-level" inlining
# seems to be more trouble and complex than it's worth
inf_result = cache_lookup(optimizer_lattice(interp), mi, argtypes, get_inference_cache(interp))
inf_result === nothing && return nothing
src = inf_result.src
if isa(src, CodeInfo)
src_inferred = is_source_inferred(src)
return src_inferred ? src : nothing
else
return nothing
end
elseif isa(src, IRCode)
return src
elseif isa(src, SemiConcreteResult)
Expand Down
89 changes: 42 additions & 47 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -842,52 +842,57 @@ function compileable_specialization(match::MethodMatch, effects::Effects,
return compileable_specialization(mi, effects, et, info; compilesig_invokes)
end

struct CachedResult
struct InferredResult
src::Any
effects::Effects
CachedResult(@nospecialize(src), effects::Effects) = new(src, effects)
InferredResult(@nospecialize(src), effects::Effects) = new(src, effects)
end
@inline function get_cached_result(state::InliningState, mi::MethodInstance)
code = get(code_cache(state), mi, nothing)
if code isa CodeInstance
if use_const_api(code)
# in this case function can be inlined to a constant
return ConstantCase(quoted(code.rettype_const))
else
src = @atomic :monotonic code.inferred
end
src = @atomic :monotonic code.inferred
effects = decode_effects(code.ipo_purity_bits)
return CachedResult(src, effects)
return InferredResult(src, effects)
end
return CachedResult(nothing, Effects())
return InferredResult(nothing, Effects())
end
@inline function get_local_result(inf_result::InferenceResult)
effects = inf_result.ipo_effects
if is_foldable_nothrow(effects)
res = inf_result.result
if isa(res, Const) && is_inlineable_constant(res.val)
# use constant calling convention
return ConstantCase(quoted(res.val))
end
end
return InferredResult(inf_result.src, effects)
end

# the general resolver for usual and const-prop'ed calls
function resolve_todo(mi::MethodInstance, result::Union{MethodMatch,InferenceResult},
function resolve_todo(mi::MethodInstance, result::Union{Nothing,InferenceResult,VolatileInferenceResult},
argtypes::Vector{Any}, @nospecialize(info::CallInfo), flag::UInt32, state::InliningState;
invokesig::Union{Nothing,Vector{Any}}=nothing,
inferred_result::Union{Nothing,InferredResult}=nothing)
invokesig::Union{Nothing,Vector{Any}}=nothing)
et = InliningEdgeTracker(state, invokesig)

preserve_local_sources = true
if isa(result, InferenceResult)
src = result.src
effects = result.ipo_effects
if is_foldable_nothrow(effects)
res = result.result
if isa(res, Const) && is_inlineable_constant(res.val)
# use constant calling convention
add_inlining_backedge!(et, mi)
return ConstantCase(quoted(res.val))
end
end
inferred_result = get_local_result(result)
elseif isa(result, VolatileInferenceResult)
inferred_result = get_local_result(result.inf_result)
# volatile inference result can be inlined destructively
preserve_local_sources = OptimizationParams(state.interp).preserve_local_sources
else
cached_result = get_cached_result(state, mi)
if cached_result isa ConstantCase
add_inlining_backedge!(et, mi)
return cached_result
end
(; src, effects) = cached_result
inferred_result = get_cached_result(state, mi)
end
if inferred_result isa ConstantCase
add_inlining_backedge!(et, mi)
return inferred_result
end
(; src, effects) = inferred_result

# the duplicated check might have been done already within `analyze_method!`, but still
# we need it here too since we may come here directly using a constant-prop' result
Expand All @@ -901,17 +906,7 @@ function resolve_todo(mi::MethodInstance, result::Union{MethodMatch,InferenceRes
compilesig_invokes=OptimizationParams(state.interp).compilesig_invokes)

add_inlining_backedge!(et, mi)

if src isa String && inferred_result !== nothing
# if the inferred source for this globally-cached method is available,
# use it destructively as it will never be used again
src = inferred_result.inferred_src
preserve_local_sources = OptimizationParams(state.interp).preserve_local_sources
else
preserve_local_sources = true
end
ir = retrieve_ir_for_inlining(mi, src, preserve_local_sources)

return InliningTodo(mi, ir, effects)
end

Expand Down Expand Up @@ -957,7 +952,7 @@ end
function analyze_method!(match::MethodMatch, argtypes::Vector{Any},
@nospecialize(info::CallInfo), flag::UInt32, state::InliningState;
allow_typevars::Bool, invokesig::Union{Nothing,Vector{Any}}=nothing,
inferred_result::Union{Nothing,InferredResult}=nothing)
volatile_inf_result::Union{Nothing,VolatileInferenceResult}=nothing)
method = match.method
spec_types = match.spec_types

Expand Down Expand Up @@ -986,7 +981,7 @@ function analyze_method!(match::MethodMatch, argtypes::Vector{Any},
# Get the specialization for this method signature
# (later we will decide what to do with it)
mi = specialize_method(match)
return resolve_todo(mi, match, argtypes, info, flag, state; invokesig, inferred_result)
return resolve_todo(mi, volatile_inf_result, argtypes, info, flag, state; invokesig)
end

function retrieve_ir_for_inlining(mi::MethodInstance, src::String, ::Bool=true)
Expand Down Expand Up @@ -1226,8 +1221,8 @@ function handle_invoke_call!(todo::Vector{Pair{Int,Any}},
return nothing
end
end
inferred_result = result isa InferredResult ? result : nothing
item = analyze_method!(match, argtypes, info, flag, state; allow_typevars=false, invokesig, inferred_result)
volatile_inf_result = result isa VolatileInferenceResult ? result : nothing
item = analyze_method!(match, argtypes, info, flag, state; allow_typevars=false, invokesig, volatile_inf_result)
end
end
handle_single_case!(todo, ir, idx, stmt, item, true)
Expand Down Expand Up @@ -1367,8 +1362,8 @@ function handle_any_const_result!(cases::Vector{InliningCase},
if isa(result, ConstPropResult)
return handle_const_prop_result!(cases, result, argtypes, info, flag, state; allow_abstract, allow_typevars)
else
@assert result === nothing || result isa InferredResult
return handle_match!(cases, match, argtypes, info, flag, state; allow_abstract, allow_typevars, inferred_result = result)
@assert result === nothing || result isa VolatileInferenceResult
return handle_match!(cases, match, argtypes, info, flag, state; allow_abstract, allow_typevars, volatile_inf_result = result)
end
end

Expand Down Expand Up @@ -1499,14 +1494,14 @@ end
function handle_match!(cases::Vector{InliningCase},
match::MethodMatch, argtypes::Vector{Any}, @nospecialize(info::CallInfo), flag::UInt32,
state::InliningState;
allow_abstract::Bool, allow_typevars::Bool, inferred_result::Union{Nothing,InferredResult})
allow_abstract::Bool, allow_typevars::Bool, volatile_inf_result::Union{Nothing,VolatileInferenceResult})
spec_types = match.spec_types
allow_abstract || isdispatchtuple(spec_types) || return false
# We may see duplicated dispatch signatures here when a signature gets widened
# during abstract interpretation: for the purpose of inlining, we can just skip
# processing this dispatch candidate (unless unmatched type parameters are present)
!allow_typevars && any(case::InliningCase->case.sig === spec_types, cases) && return true
item = analyze_method!(match, argtypes, info, flag, state; allow_typevars, inferred_result)
item = analyze_method!(match, argtypes, info, flag, state; allow_typevars, volatile_inf_result)
item === nothing && return false
push!(cases, InliningCase(spec_types, item))
return true
Expand Down Expand Up @@ -1613,8 +1608,9 @@ function handle_opaque_closure_call!(todo::Vector{Pair{Int,Any}},
if isa(result, SemiConcreteResult)
item = semiconcrete_result_item(result, info, flag, state)
else
@assert result === nothing || result isa InferredResult
item = analyze_method!(info.match, sig.argtypes, info, flag, state; allow_typevars=false, inferred_result=result)
@assert result === nothing || result isa VolatileInferenceResult
volatile_inf_result = result
item = analyze_method!(info.match, sig.argtypes, info, flag, state; allow_typevars=false, volatile_inf_result)
end
end
handle_single_case!(todo, ir, idx, stmt, item)
Expand All @@ -1639,8 +1635,7 @@ function handle_modifyfield!_call!(ir::IRCode, idx::Int, stmt::Expr, info::Modif
end

function handle_finalizer_call!(ir::IRCode, idx::Int, stmt::Expr, info::FinalizerInfo,
state::InliningState)

state::InliningState)
# Finalizers don't return values, so if their execution is not observable,
# we can just not register them
if is_removable_if_unused(info.effects)
Expand Down
8 changes: 6 additions & 2 deletions base/compiler/stmtinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,12 @@ struct SemiConcreteResult <: ConstResult
effects::Effects
end

struct InferredResult <: ConstResult
inferred_src::CodeInfo
# XXX Technically this does not represent a result of constant inference, but rather that of
# regular edge inference. It might be more appropriate to rename `ConstResult` and
# `ConstCallInfo` to better reflect the fact that they represent either of local or
# volatile inference result.
struct VolatileInferenceResult <: ConstResult
inf_result::InferenceResult
end

"""
Expand Down
Loading

0 comments on commit 2019fc3

Please sign in to comment.