Skip to content

Commit

Permalink
inference×effects: improve the const-prop' heuristics
Browse files Browse the repository at this point in the history
This commit improves the heuristics to judge const-prop' profitability
with new effect property `:const_prop_profitable_args`. This is supposed
to supplement our primary const-prop' heuristic based on inlining cost
and is supposed to be a general fix for the type-stability issues discussed
at e.g. #45952 and #46430 (eliminating the need for manual
`@constprop :aggressive` clutter in such situations).

The new effect property `:const_prop_profitable_args` tracks call
arguments that can be considered to shape up generated code if
their constant information is available. Currently this commit
exploits the following const-prop' profitabilities:
- `Val(x)`-profitability: as `Val` generally encodes constant information
  into the type domain, it is generally profitable to constant prop' `x`
  if the constructed `Val(x)` is used later (e.g. for dispatch).
  This basically tries to exploit const-prop' profitability in the
  following kind of case:
  ```julia
  kernel(::Val{1}, args...) = ...
  kernel(::Val{2}, args...) = ...

  function profitable1(x::Int, args...)
      kernel(Val(x), args...)
  end
  ```
  This allows the compiler to perform const-prop' for cases like #45952
  even if the primary heuristic based on inlining cost gets confused.
- branching-profitability: constant branch condition is generally very
  profitable as it can shape up generated code as well as narrow down
  the return type inference by cutting off the dead branch.
  ```julia
  function profitable2(raise::Bool, args...)
      v = op(args...)
      if v === nothing && raise
          return nothing
      end
      return v
  end
  ```

Currently this commit passes all the test cases and actually improves
the targeted type stabilities, but it is not yet ideal: it seems to be a
bit too aggressive (right now this commit strictly increases the chances
of const-propagation). I'd like to further tweak this heuristic to
preserve latency in general cases.
  • Loading branch information
aviatesk committed Jan 10, 2023
1 parent 548aee6 commit 36b9be3
Show file tree
Hide file tree
Showing 7 changed files with 191 additions and 31 deletions.
107 changes: 85 additions & 22 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
all_effects = Effects(all_effects; nothrow=false)
end

rettype = from_interprocedural!(𝕃ₚ, rettype, sv, arginfo, conditionals)
(; rt, effects) = from_interprocedural!(𝕃ₚ, rettype, all_effects, sv, arginfo, conditionals)

# Also considering inferring the compilation signature for this method, so
# it is available to the compiler in case it ends up needing it.
Expand All @@ -223,32 +223,32 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
method = match.method
sig = match.spec_types
mi = specialize_method(match; preexisting=true)
if mi !== nothing && !const_prop_methodinstance_heuristic(interp, match, mi, arginfo, sv)
if mi !== nothing && !const_prop_methodinstance_heuristic(interp, mi, arginfo, Effects(), sv)
csig = get_compileable_sig(method, sig, match.sparams)
if csig !== nothing && csig !== sig
abstract_call_method(interp, method, csig, match.sparams, multiple_matches, StmtInfo(false), sv)
end
end
end

if call_result_unused(si) && !(rettype === Bottom)
if call_result_unused(si) && !(rt === Bottom)
add_remark!(interp, sv, "Call result type was widened because the return value is unused")
# We're mainly only here because the optimizer might want this code,
# but we ourselves locally don't typically care about it locally
# (beyond checking if it always throws).
# So avoid adding an edge, since we don't want to bother attempting
# to improve our result even if it does change (to always throw),
# and avoid keeping track of a more complex result type.
rettype = Any
rt = Any
end
add_call_backedges!(interp, rettype, all_effects, edges, matches, atype, sv)
add_call_backedges!(interp, rt, effects, edges, matches, atype, sv)
if !isempty(sv.pclimitations) # remove self, if present
delete!(sv.pclimitations, sv)
for caller in sv.callers_in_cycle
delete!(sv.pclimitations, caller)
end
end
return CallMeta(rettype, all_effects, info)
return CallMeta(rt, effects, info)
end

struct FailedMethodMatch
Expand Down Expand Up @@ -351,15 +351,24 @@ function find_matching_methods(argtypes::Vector{Any}, @nospecialize(atype), meth
end
end

# Result bundle produced by `from_interprocedural!`: the return-type lattice
# element together with the call effects, both already adjusted for the
# caller's frame. Callers destructure it by field name, i.e.
# `(; rt, effects) = from_interprocedural!(...)`, so the field names are part
# of the interface.
struct InterproceduralResult
    rt                # lattice element; left untyped on purpose
    effects::Effects
    # `@nospecialize` keeps the constructor from being specialized on `rt`.
    function InterproceduralResult(@nospecialize(rt), effects::Effects)
        return new(rt, effects)
    end
end

"""
from_interprocedural!(𝕃ₚ::AbstractLattice, rt, sv::InferenceState, arginfo::ArgInfo, maybecondinfo) -> newrt
from_interprocedural!(𝕃ₚ::AbstractLattice, rt, effects::Effects,
sv::InferenceState, arginfo::ArgInfo, maybecondinfo) -> InterproceduralResult
Converts inter-procedural return type `rt` into a local lattice element `newrt`,
that is appropriate in the context of current local analysis frame `sv`, especially:
Converts extended lattice element `rt` and `effects::Effects` that represent inferred
return type and method call effects into new lattice element and `Effects` that are
appropriate in the context of current local analysis frame `sv`, especially:
- unwraps `rt::LimitedAccuracy` and collects its limitations into the current frame `sv`
- converts boolean `rt` to new boolean `newrt` in a way `newrt` can propagate extra conditional
refinement information, e.g. translating `rt::InterConditional` into `newrt::Conditional`
that holds a type constraint information about a variable in `sv`
- recomputes `effects.const_prop_profitable_args` so that they are imposed on call arguments of `sv`
This function _should_ be used wherever we propagate results returned from
`abstract_call_method` or `abstract_call_method_with_const_args`.
Expand All @@ -371,7 +380,8 @@ In such cases `maybecondinfo` should be either of:
When we deal with multiple `MethodMatch`es, it's better to precompute `maybecondinfo` by
`tmerge`ing argument signature type of each method call.
"""
function from_interprocedural!(𝕃ₚ::AbstractLattice, @nospecialize(rt), sv::InferenceState, arginfo::ArgInfo, @nospecialize(maybecondinfo))
function from_interprocedural!(𝕃ₚ::AbstractLattice, @nospecialize(rt), effects::Effects,
sv::InferenceState, arginfo::ArgInfo, @nospecialize(maybecondinfo))
rt = collect_limitations!(rt, sv)
if isa(rt, InterMustAlias)
rt = from_intermustalias(rt, arginfo)
Expand All @@ -383,7 +393,23 @@ function from_interprocedural!(𝕃ₚ::AbstractLattice, @nospecialize(rt), sv::
end
end
@assert !(rt isa InterConditional || rt isa InterMustAlias) "invalid lattice element returned from inter-procedural context"
return rt
if effects.const_prop_profitable_args !== NO_PROFITABLE_ARGS
argsbits = 0x00
fargs = arginfo.fargs
if fargs !== nothing
for i = 1:length(fargs)
if is_const_prop_profitable_arg(effects, i)
arg = fargs[i]
if is_call_argument(arg, sv) && 1 ≤ slot_id(arg) ≤ 8
argsbits |= 0x01 << (slot_id(arg)-1)
end
end
end
end
const_prop_profitable_args = ConstPropProfitableArgs(argsbits)
effects = Effects(effects; const_prop_profitable_args)
end
return InterproceduralResult(rt, effects)
end

function collect_limitations!(@nospecialize(typ), sv::InferenceState)
Expand Down Expand Up @@ -1018,9 +1044,8 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter,
end
end
# try constant prop'
inf_cache = get_inference_cache(interp)
𝕃ᵢ = typeinf_lattice(interp)
inf_result = cache_lookup(𝕃ᵢ, mi, arginfo.argtypes, inf_cache)
inf_result = cache_lookup(𝕃ᵢ, mi, arginfo.argtypes, get_inference_cache(interp))
if inf_result === nothing
# if there might be a cycle, check to make sure we don't end up
# calling ourselves here.
Expand Down Expand Up @@ -1087,7 +1112,7 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter,
return nothing
end
mi = mi::MethodInstance
if !force && !const_prop_methodinstance_heuristic(interp, match, mi, arginfo, sv)
if !force && !const_prop_methodinstance_heuristic(interp, mi, arginfo, result.effects, sv)
add_remark!(interp, sv, "[constprop] Disabled by method instance heuristic")
return nothing
end
Expand Down Expand Up @@ -1239,8 +1264,8 @@ end
# where we would spend a lot of time, but are probably unlikely to get an improved
# result anyway.
function const_prop_methodinstance_heuristic(interp::AbstractInterpreter,
match::MethodMatch, mi::MethodInstance, arginfo::ArgInfo, sv::InferenceState)
method = match.method
mi::MethodInstance, arginfo::ArgInfo, effects::Effects, sv::InferenceState)
method = mi.def::Method
if method.is_for_opaque_closure
# Not inlining an opaque closure can be very expensive, so be generous
# with the const-prop-ability. It is quite possible that we can't infer
Expand All @@ -1264,6 +1289,8 @@ function const_prop_methodinstance_heuristic(interp::AbstractInterpreter,
elseif is_stmt_noinline(flag)
# this call won't be inlined, thus this constant-prop' will most likely be unfruitful
return false
elseif any_const_prop_profitable_args(effects, arginfo.argtypes)
return true
else
code = get(code_cache(interp), mi, nothing)
if isdefined(code, :inferred)
Expand All @@ -1272,7 +1299,6 @@ function const_prop_methodinstance_heuristic(interp::AbstractInterpreter,
else
inferred = code.inferred
end
# TODO propagate a specific `CallInfo` that conveys information about this call
if inlining_policy(interp, inferred, NoCallInfo(), IR_FLAG_NULL, mi, arginfo.argtypes) !== nothing
return true
end
Expand All @@ -1282,6 +1308,21 @@ function const_prop_methodinstance_heuristic(interp::AbstractInterpreter,
return false # the cache isn't inlineable, so this constant-prop' will most likely be unfruitful
end

# Returns `true` when at least one argument that `effects` marks as
# const-prop' profitable actually carries constant information (`Const`,
# after widening any conditional) in `argtypes`; returns `false` otherwise,
# including when `effects` records no profitable arguments at all.
function any_const_prop_profitable_args(effects::Effects, argtypes::Vector{Any})
    # fast path: nothing was recorded as profitable
    effects.const_prop_profitable_args !== NO_PROFITABLE_ARGS || return false
    for i = 1:length(argtypes)
        argtype = widenconditional(argtypes[i])
        isa(argtype, Const) || continue
        is_const_prop_profitable_arg(effects, i) && return true
    end
    return false
end

# This is only for use with `Conditional`.
# In general, usage of this is wrong.
ssa_def_slot(@nospecialize(arg), sv::IRCode) = nothing
Expand Down Expand Up @@ -1921,11 +1962,10 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
(; rt, effects, const_result, edge) = const_call_result
end
end
rt = from_interprocedural!(𝕃ₚ, rt, sv, arginfo, sig)
effects = Effects(effects; nonoverlayed=!overlayed)
info = InvokeCallInfo(match, const_result)
(; rt, effects) = from_interprocedural!(𝕃ₚ, rt, effects, sv, arginfo, sig)
edge !== nothing && add_invoke_backedge!(sv, lookupsig, edge)
return CallMeta(rt, effects, info)
return CallMeta(rt, effects, InvokeCallInfo(match, const_result))
end

function invoke_rewrite(xs::Vector{Any})
Expand Down Expand Up @@ -2040,6 +2080,20 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
val = _pure_eval_call(f, arginfo)
return CallMeta(val === nothing ? Type : val, EFFECTS_TOTAL, MethodResultPure())
end
elseif la == 2 && istoptype(f, :Val)
# `Val` generally encodes constant information into the type domain, so there is
# generally a high profitability for constant propagation if the argument of the
# `Val` constructor is a call argument
fargs = arginfo.fargs
if fargs !== nothing
arg = arginfo.fargs[2]
if is_call_argument(arg, sv) && !isempty(sv.ssavalue_uses[sv.currpc])
if 1 ≤ slot_id(arg) ≤ 8
const_prop_profitable_args = ConstPropProfitableArgs(0x01 << (slot_id(arg)-1))
merge_effects!(interp, sv, Effects(EFFECTS_TOTAL; const_prop_profitable_args))
end
end
end
end
atype = argtypes_to_type(argtypes)
return abstract_call_gf_by_type(interp, f, arginfo, si, atype, sv, max_methods)
Expand Down Expand Up @@ -2073,7 +2127,7 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter,
effects = Effects(effects; nothrow=false)
end
end
rt = from_interprocedural!(𝕃ₚ, rt, sv, arginfo, match.spec_types)
(; rt, effects) = from_interprocedural!(𝕃ₚ, rt, effects, sv, arginfo, match.spec_types)
info = OpaqueClosureCallInfo(match, const_result)
edge !== nothing && add_backedge!(sv, edge)
return CallMeta(rt, effects, info)
Expand Down Expand Up @@ -2496,7 +2550,7 @@ function abstract_eval_foreigncall(interp::AbstractInterpreter, e::Expr, vtypes:
override.terminates_globally ? true : effects.terminates,
override.notaskstate ? true : effects.notaskstate,
override.inaccessiblememonly ? ALWAYS_TRUE : effects.inaccessiblememonly,
effects.nonoverlayed)
effects.nonoverlayed, effects.const_prop_profitable_args)
end
return RTEffects(t, effects)
end
Expand Down Expand Up @@ -2865,6 +2919,15 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
@goto branch
elseif isa(stmt, GotoIfNot)
condx = stmt.cond
if is_call_argument(condx, frame)
# if this condition object is a call argument, there will be a high
# profitability for constant-propagating it, since it can shape up
# the generated code by cutting off the dead branch entirely
if 1 ≤ slot_id(condx) ≤ 8
const_prop_profitable_args = ConstPropProfitableArgs(0x01 << (slot_id(condx)-1))
merge_effects!(interp, frame, Effects(EFFECTS_TOTAL; const_prop_profitable_args))
end
end
condt = abstract_eval_value(interp, condx, currstate, frame)
if condt === Bottom
ssavaluetypes[currpc] = Bottom
Expand Down
34 changes: 26 additions & 8 deletions base/compiler/effects.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Bitset recording which call arguments are considered profitable for
# constant propagation. Bit `i-1` of `argsbits` corresponds to argument `i`
# (see `is_const_prop_profitable_arg`, which tests `0x01 << (arg-1)`), so a
# `UInt8` can describe at most the first 8 arguments.
struct ConstPropProfitableArgs
argsbits::UInt8
end

"""
effects::Effects
Expand Down Expand Up @@ -63,6 +67,7 @@ struct Effects
notaskstate::Bool
inaccessiblememonly::UInt8
nonoverlayed::Bool
const_prop_profitable_args::ConstPropProfitableArgs
noinbounds::Bool
function Effects(
consistent::UInt8,
Expand All @@ -72,6 +77,7 @@ struct Effects
notaskstate::Bool,
inaccessiblememonly::UInt8,
nonoverlayed::Bool,
const_prop_profitable_args::ConstPropProfitableArgs = NO_PROFITABLE_ARGS,
noinbounds::Bool = true)
return new(
consistent,
Expand All @@ -81,6 +87,7 @@ struct Effects
notaskstate,
inaccessiblememonly,
nonoverlayed,
const_prop_profitable_args,
noinbounds)
end
end
Expand All @@ -98,6 +105,9 @@ const EFFECT_FREE_IF_INACCESSIBLEMEMONLY = 0x01 << 1
# :inaccessiblememonly bits
const INACCESSIBLEMEM_OR_ARGMEMONLY = 0x01 << 1

# :const_prop_profitable_args bits
const NO_PROFITABLE_ARGS = ConstPropProfitableArgs(0x00)

const EFFECTS_TOTAL = Effects(ALWAYS_TRUE, ALWAYS_TRUE, true, true, true, ALWAYS_TRUE, true)
const EFFECTS_THROWS = Effects(ALWAYS_TRUE, ALWAYS_TRUE, false, true, true, ALWAYS_TRUE, true)
const EFFECTS_UNKNOWN = Effects(ALWAYS_FALSE, ALWAYS_FALSE, false, false, false, ALWAYS_FALSE, true) # unknown mostly, but it's not overlayed at least (e.g. it's not a call)
Expand All @@ -111,6 +121,7 @@ function Effects(e::Effects = EFFECTS_UNKNOWN′;
notaskstate::Bool = e.notaskstate,
inaccessiblememonly::UInt8 = e.inaccessiblememonly,
nonoverlayed::Bool = e.nonoverlayed,
const_prop_profitable_args::ConstPropProfitableArgs = e.const_prop_profitable_args,
noinbounds::Bool = e.noinbounds)
return Effects(
consistent,
Expand All @@ -120,6 +131,7 @@ function Effects(e::Effects = EFFECTS_UNKNOWN′;
notaskstate,
inaccessiblememonly,
nonoverlayed,
const_prop_profitable_args,
noinbounds)
end

Expand All @@ -132,6 +144,7 @@ function merge_effects(old::Effects, new::Effects)
merge_effectbits(old.notaskstate, new.notaskstate),
merge_effectbits(old.inaccessiblememonly, new.inaccessiblememonly),
merge_effectbits(old.nonoverlayed, new.nonoverlayed),
merge_effectbits(old.const_prop_profitable_args, new.const_prop_profitable_args),
merge_effectbits(old.noinbounds, new.noinbounds))
end

Expand All @@ -142,6 +155,7 @@ function merge_effectbits(old::UInt8, new::UInt8)
return old | new
end
merge_effectbits(old::Bool, new::Bool) = old & new
# Merging const-prop' profitability takes the union of the per-argument bits:
# an argument is profitable in the merged result if either side marked it.
function merge_effectbits(old::ConstPropProfitableArgs, new::ConstPropProfitableArgs)
    merged = old.argsbits | new.argsbits
    return ConstPropProfitableArgs(merged)
end

is_consistent(effects::Effects) = effects.consistent === ALWAYS_TRUE
is_effect_free(effects::Effects) = effects.effect_free === ALWAYS_TRUE
Expand Down Expand Up @@ -177,14 +191,17 @@ is_effect_free_if_inaccessiblememonly(effects::Effects) = !iszero(effects.effect

is_inaccessiblemem_or_argmemonly(effects::Effects) = effects.inaccessiblememonly === INACCESSIBLEMEM_OR_ARGMEMONLY

# Tests whether the `arg`-th call argument (1-based) is flagged as const-prop'
# profitable in `effects`, i.e. whether bit `arg-1` of the `argsbits` bitset
# is set.
function is_const_prop_profitable_arg(effects::Effects, arg::Int)
    mask = 0x01 << (arg-1)
    return !iszero(effects.const_prop_profitable_args.argsbits & mask)
end

# Packs an `Effects` value into a `UInt32`. Each field is placed at a fixed
# bit offset:
#
#   offset  0: consistent
#   offset  3: effect_free
#   offset  5: nothrow
#   offset  6: terminates
#   offset  7: notaskstate
#   offset  8: inaccessiblememonly
#   offset 10: nonoverlayed
#   offset 11: const_prop_profitable_args.argsbits (a `UInt8`, so it may
#              occupy bits 11 through 18)
#
# The function previously contained two consecutive `return` statements; the
# first (stale) one did not include `const_prop_profitable_args`, which made
# the intended encoding unreachable. Only the complete encoding is kept.
#
# NOTE(review): the matching decoder masks this last field with `0x7f`, which
# would drop the highest of the 8 `argsbits` bits on a round trip — confirm
# whether `0xff` is intended there.
function encode_effects(e::Effects)
    return ((e.consistent % UInt32) << 0) |
           ((e.effect_free % UInt32) << 3) |
           ((e.nothrow % UInt32) << 5) |
           ((e.terminates % UInt32) << 6) |
           ((e.notaskstate % UInt32) << 7) |
           ((e.inaccessiblememonly % UInt32) << 8) |
           ((e.nonoverlayed % UInt32) << 10) |
           ((e.const_prop_profitable_args.argsbits % UInt32) << 11)
end

function decode_effects(e::UInt32)
Expand All @@ -195,7 +212,8 @@ function decode_effects(e::UInt32)
_Bool((e >> 6) & 0x01),
_Bool((e >> 7) & 0x01),
UInt8((e >> 8) & 0x03),
_Bool((e >> 10) & 0x01))
_Bool((e >> 10) & 0x01),
ConstPropProfitableArgs(UInt8((e >> 11) & 0x7f)))
end

struct EffectsOverride
Expand Down
2 changes: 2 additions & 0 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -531,3 +531,5 @@ function narguments(sv::InferenceState)
nargs = length(sv.result.argtypes) - isva
return nargs
end
# Returns `true` when `x` is a `SlotNumber` whose slot id does not exceed the
# count returned by `narguments(sv)`, i.e. when `x` refers to one of the call
# arguments of the frame `sv` rather than to some other slot.
# The comparison operator had been lost, leaving `slot_id(x) narguments(sv)`,
# which does not parse; it is restored as `≤` here.
is_call_argument(@nospecialize(x), sv::InferenceState) =
    isa(x, SlotNumber) && slot_id(x) ≤ narguments(sv)
12 changes: 11 additions & 1 deletion base/compiler/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,23 @@ _topmod(m::Module) = ccall(:jl_base_relative_to, Any, (Any,), m)::Module

# Returns `true` when `f` is the function bound to the constant global `name`
# in the top module associated with the module that defines `typeof(f)`,
# and `false` otherwise.
# The method-table name is compared first so that unrelated functions are
# rejected without looking up the global. A leftover duplicate
# `if tn.mt.name === name` line used to sit before the `mn` check and left
# the `if`/`end` pairs unbalanced; only the single check is kept.
function istopfunction(@nospecialize(f), name::Symbol)
    tn = typeof(f).name
    mn = tn.mt.name
    if mn === name
        top = _topmod(tn.module)
        return isdefined(top, name) && isconst(top, name) && f === getglobal(top, name)
    end
    return false
end

# Type-level analogue of `istopfunction`: returns `true` when `T` is exactly
# the object bound to the constant global `name` in the top module associated
# with the module that defines the type, and `false` otherwise.
# `T` is unwrapped from any `UnionAll` first so that its type name can be
# inspected, but the final identity check is made against `T` itself.
function istoptype(@nospecialize(T), name::Symbol)
    unwrapped = unwrap_unionall(T)
    isa(unwrapped, DataType) || return false
    typename = unwrapped.name
    typename.name === name || return false
    top = _topmod(typename.module)
    isdefined(top, name) || return false
    isconst(top, name) || return false
    return T === getglobal(top, name)
end

#######
# AST #
#######
Expand Down
7 changes: 7 additions & 0 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1128,3 +1128,10 @@ end
import Base.Broadcast: BroadcastStyle, DefaultArrayStyle
@test Base.infer_effects(BroadcastStyle, (DefaultArrayStyle{1},DefaultArrayStyle{2},)) |>
Core.Compiler.is_foldable

# Copies `x` into a freshly allocated array of the same shape and element
# type, then returns either the real part of that copy (`isreal=true`, the
# default) or the copy itself. The two branches return different array types,
# so the inferred return type depends on the constant value of `isreal`
# reaching the branch.
function f44330(x; isreal=true)
    buf = similar(x)
    buf .= x
    if isreal
        return real(buf)
    end
    return buf
end
@inferred f44330(randn(ComplexF64, 1))
Loading

0 comments on commit 36b9be3

Please sign in to comment.