Skip to content

Commit

Permalink
optimizer: supports callsite annotations of inlining, fixes #18773
Browse files Browse the repository at this point in the history
Enable `@inline`/`@noinline` annotations on function callsites.
From #40754.

Now `@inline` and `@noinline` can be applied to a code block and then
the compiler will try to (not) inline calls within the block:
```julia
@inline f(...) # The compiler will try to inline `f`

@inline f(...) + g(...) # The compiler will try to inline `f`, `g` and `+`

@inline f(args...) = ... # Of course annotations on a definition is still allowed
```

Here are couple of notes on how those callsite annotations will work:
- callsite annotation always has the precedence over the annotation
  applied to the definition of the called function, whichever we use
  `@inline`/`@noinline`:
  ```julia
  @inline function explicit_inline(args...)
      # body
  end

  let
      @noinline explicit_inline(args...) # this call will not be inlined
  end
  ```
- when callsite annotations are nested, the innermost annotations has
  the precedence
  ```julia
  @noinline let a0, b0 = ...
      a = @inline f(a0)  # the compiler will try to inline this call
      b = notinlined(b0) # the compiler will NOT try to inline this call
      return a, b
  end
  ```
They're both tested and included in documentations.
  • Loading branch information
aviatesk committed Aug 13, 2021
1 parent 6ea0b78 commit 66d549b
Show file tree
Hide file tree
Showing 15 changed files with 312 additions and 67 deletions.
11 changes: 7 additions & 4 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter, result::Me
return nothing
end
mi = mi::MethodInstance
if !force && !const_prop_methodinstance_heuristic(interp, method, mi)
if !force && !const_prop_methodinstance_heuristic(interp, match, mi)
add_remark!(interp, sv, "[constprop] Disabled by method instance heuristic")
return nothing
end
Expand Down Expand Up @@ -692,7 +692,8 @@ end
# This is a heuristic to avoid trying to const prop through complicated functions
# where we would spend a lot of time, but are probably unlikely to get an improved
# result anyway.
function const_prop_methodinstance_heuristic(interp::AbstractInterpreter, method::Method, mi::MethodInstance)
function const_prop_methodinstance_heuristic(interp::AbstractInterpreter, match::MethodMatch, mi::MethodInstance)
method = match.method
if method.is_for_opaque_closure
# Not inlining an opaque closure can be very expensive, so be generous
# with the const-prop-ability. It is quite possible that we can't infer
Expand All @@ -710,7 +711,8 @@ function const_prop_methodinstance_heuristic(interp::AbstractInterpreter, method
if isdefined(code, :inferred) && !cache_inlineable
cache_inf = code.inferred
if !(cache_inf === nothing)
cache_inlineable = inlining_policy(interp)(cache_inf) !== nothing
# TODO maybe we want to respect callsite `@inline`/`@noinline` annotations here ?
cache_inlineable = inlining_policy(interp)(cache_inf, nothing, match) !== nothing
end
end
if !cache_inlineable
Expand Down Expand Up @@ -1889,7 +1891,8 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
if isa(fname, SlotNumber)
changes = StateUpdate(fname, VarState(Any, false), changes, false)
end
elseif hd === :inbounds || hd === :meta || hd === :loopinfo || hd === :code_coverage_effect
elseif hd === :code_coverage_effect ||
(hd !== :boundscheck && hd !== nothing && is_meta_expr_head(hd)) # :boundscheck can be narrowed to Bool
# these do not generate code
else
t = abstract_eval_statement(interp, stmt, changes, frame)
Expand Down
50 changes: 44 additions & 6 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,20 @@ struct InliningState{S <: Union{EdgeTracker, Nothing}, T, P}
policy::P
end

function default_inlining_policy(@nospecialize(src))
function default_inlining_policy(@nospecialize(src), stmt_flag::Union{Nothing,UInt8}, match::Union{MethodMatch,InferenceResult})
if isa(src, CodeInfo) || isa(src, Vector{UInt8})
src_inferred = ccall(:jl_ir_flag_inferred, Bool, (Any,), src)
src_inlineable = ccall(:jl_ir_flag_inlineable, Bool, (Any,), src)
src_inlineable = is_stmt_inline(stmt_flag) || ccall(:jl_ir_flag_inlineable, Bool, (Any,), src)
return src_inferred && src_inlineable ? src : nothing
end
if isa(src, OptimizationState) && isdefined(src, :ir)
return src.src.inlineable ? src.ir : nothing
elseif isa(src, OptimizationState) && isdefined(src, :ir)
return (is_stmt_inline(stmt_flag) || src.src.inlineable) ? src.ir : nothing
elseif src === nothing && is_stmt_inline(stmt_flag) && isa(match, MethodMatch)
# when the source isn't available at this moment, try to re-infer and inline it
# HACK in order to avoid cycles here, we disable inlining and makes sure the following inference never comes here
# TODO sort out `AbstractInterpreter` interface to handle this well, and also inference should try to keep the source if the statement will be inlined
interp = NativeInterpreter(; opt_params = OptimizationParams(; inlining = false))
src, rt = typeinf_code(interp, match.method, match.spec_types, match.sparams, true)
return src
end
return nothing
end
Expand Down Expand Up @@ -129,6 +135,10 @@ const SLOT_USEDUNDEF = 32 # slot has uses that might raise UndefVarError
# This statement was marked as @inbounds by the user. If replaced by inlining,
# any contained boundschecks may be removed
const IR_FLAG_INBOUNDS = 0x01
# This statement was marked as @inline by the user
const IR_FLAG_INLINE = 0x01 << 1
# This statement was marked as @noinline by the user
const IR_FLAG_NOINLINE = 0x01 << 2
# This statement may be removed if its result is unused. In particular it must
# thus be both pure and effect free.
const IR_FLAG_EFFECT_FREE = 0x01 << 4
Expand Down Expand Up @@ -174,6 +184,11 @@ function isinlineable(m::Method, me::OptimizationState, params::OptimizationPara
return inlineable
end

is_stmt_inline(stmt_flag::UInt8) = stmt_flag & IR_FLAG_INLINE != 0
is_stmt_inline(::Nothing) = false
is_stmt_noinline(stmt_flag::UInt8) = stmt_flag & IR_FLAG_NOINLINE != 0
is_stmt_noinline(::Nothing) = false # not used for now

# These affect control flow within the function (so may not be removed
# if there is no usage within the function), but don't affect the purity
# of the function as a whole.
Expand Down Expand Up @@ -360,6 +375,7 @@ function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, sv::
renumber_ir_elements!(code, changemap, labelmap)

inbounds_depth = 0 # Number of stacked inbounds
inline_flags = BitVector()
meta = Any[]
flags = fill(0x00, length(code))
for i = 1:length(code)
Expand All @@ -374,16 +390,38 @@ function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, sv::
inbounds_depth -= 1
end
stmt = nothing
elseif isexpr(stmt, :inline)
if stmt.args[1]::Bool
push!(inline_flags, true)
else
pop!(inline_flags)
end
stmt = nothing
elseif isexpr(stmt, :noinline)
if stmt.args[1]::Bool
push!(inline_flags, false)
else
pop!(inline_flags)
end
stmt = nothing
else
stmt = normalize(stmt, meta)
end
code[i] = stmt
if !(stmt === nothing)
if stmt !== nothing
if inbounds_depth > 0
flags[i] |= IR_FLAG_INBOUNDS
end
if !isempty(inline_flags)
if last(inline_flags)
flags[i] |= IR_FLAG_INLINE
else
flags[i] |= IR_FLAG_NOINLINE
end
end
end
end
@assert isempty(inline_flags) "malformed meta flags"
strip_trailing_junk!(ci, code, stmtinfo, flags)
cfg = compute_basic_blocks(code)
types = Any[]
Expand Down
63 changes: 33 additions & 30 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,7 @@ function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::
argexprs::Vector{Any}, atypes::Vector{Any}, arginfos::Vector{Any},
arg_start::Int, istate::InliningState)

flag = ir.stmts[idx][:flag]
new_argexprs = Any[argexprs[arg_start]]
new_atypes = Any[atypes[arg_start]]
# loop over original arguments and flatten any known iterators
Expand Down Expand Up @@ -659,8 +660,9 @@ function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::
info = call.info
handled = false
if isa(info, ConstCallInfo)
if maybe_handle_const_call!(ir, state1.id, new_stmt, info, new_sig,
call.rt, istate, false, todo)
if !is_stmt_noinline(flag) && maybe_handle_const_call!(
ir, state1.id, new_stmt, info, new_sig,call.rt, istate, flag, false, todo)

handled = true
else
info = info.call
Expand All @@ -671,7 +673,7 @@ function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::
MethodMatchInfo[info] : info.matches
# See if we can inline this call to `iterate`
analyze_single_call!(ir, todo, state1.id, new_stmt,
new_sig, call.rt, info, istate)
new_sig, call.rt, info, istate, flag)
end
if i != length(thisarginfo.each)
valT = getfield_tfunc(call.rt, Const(1))
Expand Down Expand Up @@ -719,16 +721,16 @@ function compileable_specialization(et::Union{EdgeTracker, Nothing}, (; linfo)::
return mi
end

function resolve_todo(todo::InliningTodo, state::InliningState)
spec = todo.spec::DelayedInliningSpec
function resolve_todo(todo::InliningTodo, state::InliningState, flag::UInt8)
(; match) = todo.spec::DelayedInliningSpec

#XXX: update_valid_age!(min_valid[1], max_valid[1], sv)
isconst, src = false, nothing
if isa(spec.match, InferenceResult)
let inferred_src = spec.match.src
if isa(match, InferenceResult)
let inferred_src = match.src
if isa(inferred_src, Const)
if !is_inlineable_constant(inferred_src.val)
return compileable_specialization(state.et, spec.match)
return compileable_specialization(state.et, match)
end
isconst, src = true, quoted(inferred_src.val)
else
Expand Down Expand Up @@ -756,12 +758,10 @@ function resolve_todo(todo::InliningTodo, state::InliningState)
return ConstantCase(src)
end

if src !== nothing
src = state.policy(src)
end
src = state.policy(src, flag, match)

if src === nothing
return compileable_specialization(et, spec.match)
return compileable_specialization(et, match)
end

if isa(src, IRCode)
Expand All @@ -772,9 +772,9 @@ function resolve_todo(todo::InliningTodo, state::InliningState)
return InliningTodo(todo.mi, src)
end

function resolve_todo(todo::UnionSplit, state::InliningState)
function resolve_todo(todo::UnionSplit, state::InliningState, flag::UInt8)
UnionSplit(todo.fully_covered, todo.atype,
Pair{Any,Any}[sig=>resolve_todo(item, state) for (sig, item) in todo.cases])
Pair{Any,Any}[sig=>resolve_todo(item, state, flag) for (sig, item) in todo.cases])
end

function validate_sparams(sparams::SimpleVector)
Expand All @@ -785,7 +785,7 @@ function validate_sparams(sparams::SimpleVector)
end

function analyze_method!(match::MethodMatch, atypes::Vector{Any},
state::InliningState, @nospecialize(stmttyp))
state::InliningState, @nospecialize(stmttyp), flag::UInt8)
method = match.method
methsig = method.sig

Expand All @@ -805,7 +805,7 @@ function analyze_method!(match::MethodMatch, atypes::Vector{Any},

et = state.et

if !state.params.inlining
if !state.params.inlining || is_stmt_noinline(flag)
return compileable_specialization(et, match)
end

Expand All @@ -819,7 +819,7 @@ function analyze_method!(match::MethodMatch, atypes::Vector{Any},
# If we don't have caches here, delay resolving this MethodInstance
# until the batch inlining step (or an external post-processing pass)
state.mi_cache === nothing && return todo
return resolve_todo(todo, state)
return resolve_todo(todo, state, flag)
end

function InliningTodo(mi::MethodInstance, ir::IRCode)
Expand Down Expand Up @@ -1044,7 +1044,7 @@ is_builtin(s::Signature) =
s.ft Builtin

function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, (; match, result)::InvokeCallInfo,
state::InliningState, todo::Vector{Pair{Int, Any}})
state::InliningState, todo::Vector{Pair{Int, Any}}, flag::UInt8)
stmt = ir.stmts[idx][:inst]
calltype = ir.stmts[idx][:type]

Expand All @@ -1058,17 +1058,17 @@ function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, (; match, result):
atypes = atypes[4:end]
pushfirst!(atypes, atype0)

if isa(result, InferenceResult)
if isa(result, InferenceResult) && !is_stmt_noinline(flag)
(; mi) = item = InliningTodo(result, atypes, calltype)
validate_sparams(mi.sparam_vals) || return nothing
if argtypes_to_type(atypes) <: mi.def.sig
state.mi_cache !== nothing && (item = resolve_todo(item, state))
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
handle_single_case!(ir, stmt, idx, item, true, todo)
return nothing
end
end

result = analyze_method!(match, atypes, state, calltype)
result = analyze_method!(match, atypes, state, calltype, flag)
handle_single_case!(ir, stmt, idx, result, true, todo)
return nothing
end
Expand Down Expand Up @@ -1163,7 +1163,7 @@ end

function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int, @nospecialize(stmt),
sig::Signature, @nospecialize(calltype), infos::Vector{MethodMatchInfo},
state::InliningState)
state::InliningState, flag::UInt8)
cases = Pair{Any, Any}[]
signature_union = Union{}
only_method = nothing # keep track of whether there is one matching method
Expand Down Expand Up @@ -1197,7 +1197,7 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int
fully_covered = false
continue
end
case = analyze_method!(match, sig.atypes, state, calltype)
case = analyze_method!(match, sig.atypes, state, calltype, flag)
if case === nothing
fully_covered = false
continue
Expand All @@ -1224,7 +1224,7 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int
match = meth[1]
end
fully_covered = true
case = analyze_method!(match, sig.atypes, state, calltype)
case = analyze_method!(match, sig.atypes, state, calltype, flag)
case === nothing && return
push!(cases, Pair{Any,Any}(match.spec_types, case))
end
Expand All @@ -1246,7 +1246,7 @@ end

function maybe_handle_const_call!(ir::IRCode, idx::Int, stmt::Expr,
info::ConstCallInfo, sig::Signature, @nospecialize(calltype),
state::InliningState,
state::InliningState, flag::UInt8,
isinvoke::Bool, todo::Vector{Pair{Int, Any}})
# when multiple matches are found, bail out and later inliner will union-split this signature
# TODO effectively use multiple constant analysis results here
Expand All @@ -1258,7 +1258,7 @@ function maybe_handle_const_call!(ir::IRCode, idx::Int, stmt::Expr,
validate_sparams(mi.sparam_vals) || return true
mthd_sig = mi.def.sig
mistypes = mi.specTypes
state.mi_cache !== nothing && (item = resolve_todo(item, state))
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
if sig.atype <: mthd_sig
handle_single_case!(ir, stmt, idx, item, isinvoke, todo)
return true
Expand Down Expand Up @@ -1296,6 +1296,8 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
info = info.info
end

flag = ir.stmts[idx][:flag]

# Inference determined this couldn't be analyzed. Don't question it.
if info === false
continue
Expand All @@ -1305,23 +1307,24 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
# it'll have performed a specialized analysis for just this case. Use its
# result.
if isa(info, ConstCallInfo)
if maybe_handle_const_call!(ir, idx, stmt, info, sig, calltype, state, sig.f === Core.invoke, todo)
if !is_stmt_noinline(flag) && maybe_handle_const_call!(
ir, idx, stmt, info, sig, calltype, state, flag, sig.f === Core.invoke, todo)
continue
else
info = info.call
end
end

if isa(info, OpaqueClosureCallInfo)
result = analyze_method!(info.match, sig.atypes, state, calltype)
result = analyze_method!(info.match, sig.atypes, state, calltype, flag)
handle_single_case!(ir, stmt, idx, result, false, todo)
continue
end

# Handle invoke
if sig.f === Core.invoke
if isa(info, InvokeCallInfo)
inline_invoke!(ir, idx, sig, info, state, todo)
inline_invoke!(ir, idx, sig, info, state, todo, flag)
end
continue
end
Expand All @@ -1335,7 +1338,7 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
continue
end

analyze_single_call!(ir, todo, idx, stmt, sig, calltype, infos, state)
analyze_single_call!(ir, todo, idx, stmt, sig, calltype, infos, state, flag)
end
todo
end
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ function maybe_compress_codeinfo(interp::AbstractInterpreter, linfo::MethodInsta
nslots = length(ci.slotflags)
resize!(ci.slottypes::Vector{Any}, nslots)
resize!(ci.slotnames, nslots)
return ccall(:jl_compress_ir, Any, (Any, Any), def, ci)
return ccall(:jl_compress_ir, Vector{UInt8}, (Any, Any), def, ci)
else
return ci
end
Expand Down
5 changes: 3 additions & 2 deletions base/compiler/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ end

# Meta expression head, these generally can't be deleted even when they are
# in a dead branch but can be ignored when analyzing uses/liveness.
is_meta_expr_head(head::Symbol) = (head === :inbounds || head === :boundscheck || head === :meta || head === :loopinfo)
is_meta_expr_head(head::Symbol) = (head === :inbounds || head === :boundscheck || head === :meta ||
head === :loopinfo || head === :inline || head === :noinline)

sym_isless(a::Symbol, b::Symbol) = ccall(:strcmp, Int32, (Ptr{UInt8}, Ptr{UInt8}), a, b) < 0

Expand Down Expand Up @@ -188,7 +189,7 @@ function specialize_method(method::Method, @nospecialize(atypes), sparams::Simpl
if preexisting
# check cached specializations
# for an existing result stored there
return ccall(:jl_specializations_lookup, Any, (Any, Any), method, atypes)
return ccall(:jl_specializations_lookup, Any, (Any, Any), method, atypes)::Union{Nothing,MethodInstance}
end
return ccall(:jl_specializations_get_linfo, Ref{MethodInstance}, (Any, Any, Any), method, atypes, sparams)
end
Expand Down
Loading

0 comments on commit 66d549b

Please sign in to comment.