Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perform inference using optimizer-derived type information #56687

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Compiler/src/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,7 @@ mutable struct IRInterpretationState
callstack #::Vector{AbsIntState}
frameid::Int
parentid::Int
new_call_inferred::Bool

function IRInterpretationState(interp::AbstractInterpreter,
spec_info::SpecInfo, ir::IRCode, mi::MethodInstance, argtypes::Vector{Any},
Expand All @@ -850,7 +851,7 @@ mutable struct IRInterpretationState
edges = Any[]
callstack = AbsIntState[]
return new(spec_info, ir, mi, WorldWithRange(world, valid_worlds), curridx, argtypes_refined, ir.sptypes, tpdum,
ssa_refined, lazyreachability, tasks, edges, callstack, 0, 0)
ssa_refined, lazyreachability, tasks, edges, callstack, 0, 0, #=new_call_inferred=#false)
end
end

Expand Down
116 changes: 95 additions & 21 deletions Compiler/src/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,48 @@ struct InliningState{Interp<:AbstractInterpreter}
edges::Vector{Any}
world::UInt
interp::Interp
opt_cache::IdDict{MethodInstance,CodeInstance}
end
function InliningState(sv::InferenceState, interp::AbstractInterpreter)
return InliningState(sv.edges, frame_world(sv), interp)
function InliningState(sv::InferenceState, interp::AbstractInterpreter,
opt_cache::IdDict{MethodInstance,CodeInstance}=IdDict{MethodInstance,CodeInstance}())
return InliningState(sv.edges, frame_world(sv), interp, opt_cache)
end
function InliningState(interp::AbstractInterpreter)
return InliningState(Any[], get_inference_world(interp), interp)
function InliningState(interp::AbstractInterpreter,
opt_cache::IdDict{MethodInstance,CodeInstance}=IdDict{MethodInstance,CodeInstance}())
return InliningState(Any[], get_inference_world(interp), interp, opt_cache)
end

struct OptimizerCache{CodeCache}
wvc::WorldView{CodeCache}
owner
opt_cache::IdDict{MethodInstance,CodeInstance}
function OptimizerCache(
wvc::WorldView{CodeCache},
owner,
opt_cache::IdDict{MethodInstance,CodeInstance}) where CodeCache
@nospecialize owner
new{CodeCache}(wvc, owner, opt_cache)
end
end
function get((; wvc, owner, opt_cache)::OptimizerCache, mi::MethodInstance, default)
if haskey(opt_cache, mi)
codeinst = opt_cache[mi]
if (codeinst.min_world ≤ wvc.worlds.min_world &&
wvc.worlds.max_world ≤ codeinst.max_world &&
codeinst.owner === owner)
@assert isdefined(codeinst, :inferred) && codeinst.inferred === nothing
return codeinst
end
end
return get(wvc, mi, default)
end

# get `code_cache(::AbstractInterpreter)` from `state::InliningState`
code_cache(state::InliningState) = WorldView(code_cache(state.interp), state.world)
function code_cache(state::InliningState)
cache = WorldView(code_cache(state.interp), state.world)
owner = cache_owner(state.interp)
return OptimizerCache(cache, owner, state.opt_cache)
end

mutable struct OptimizationState{Interp<:AbstractInterpreter}
linfo::MethodInstance
Expand All @@ -168,13 +200,15 @@ mutable struct OptimizationState{Interp<:AbstractInterpreter}
bb_vartables::Vector{Union{Nothing,VarTable}}
insert_coverage::Bool
end
function OptimizationState(sv::InferenceState, interp::AbstractInterpreter)
inlining = InliningState(sv, interp)
function OptimizationState(sv::InferenceState, interp::AbstractInterpreter,
opt_cache::IdDict{MethodInstance,CodeInstance}=IdDict{MethodInstance,CodeInstance}())
inlining = InliningState(sv, interp, opt_cache)
return OptimizationState(sv.linfo, sv.src, nothing, sv.stmt_info, sv.mod,
sv.sptypes, sv.slottypes, inlining, sv.cfg,
sv.unreachable, sv.bb_vartables, sv.insert_coverage)
end
function OptimizationState(mi::MethodInstance, src::CodeInfo, interp::AbstractInterpreter)
function OptimizationState(mi::MethodInstance, src::CodeInfo, interp::AbstractInterpreter,
opt_cache::IdDict{MethodInstance,CodeInstance}=IdDict{MethodInstance,CodeInstance}())
# prepare src for running optimization passes if it isn't already
nssavalues = src.ssavaluetypes
if nssavalues isa Int
Expand All @@ -194,7 +228,7 @@ function OptimizationState(mi::MethodInstance, src::CodeInfo, interp::AbstractIn
mod = isa(def, Method) ? def.module : def
# Allow using the global MI cache, but don't track edges.
# This method is mostly used for unit testing the optimizer
inlining = InliningState(interp)
inlining = InliningState(interp, opt_cache)
cfg = compute_basic_blocks(src.code)
unreachable = BitSet()
bb_vartables = Union{VarTable,Nothing}[]
Expand Down Expand Up @@ -999,7 +1033,7 @@ end

# run the optimization work
function optimize(interp::AbstractInterpreter, opt::OptimizationState, caller::InferenceResult)
@timeit "optimizer" ir = run_passes_ipo_safe(opt.src, opt)
@timeit "optimizer" ir = run_passes_ipo_safe(interp, opt, caller)
ipo_dataflow_analysis!(interp, opt, ir, caller)
return finish(interp, opt, ir, caller)
end
Expand All @@ -1019,27 +1053,25 @@ matchpass(optimize_until::Int, stage, _) = optimize_until == stage
matchpass(optimize_until::String, _, name) = optimize_until == name
matchpass(::Nothing, _, _) = false

function run_passes_ipo_safe(
ci::CodeInfo,
sv::OptimizationState,
optimize_until = nothing, # run all passes by default
)
function run_passes_ipo_safe(interp::AbstractInterpreter, sv::OptimizationState, result::InferenceResult;
optimize_until = nothing) # run all passes by default
ci = sv.src
__stage__ = 0 # used by @pass
# NOTE: The pass name MUST be unique for `optimize_until::String` to work
@pass "convert" ir = convert_to_ircode(ci, sv)
@pass "slot2reg" ir = slot2reg(ir, ci, sv)
# TODO: Domsorting can produce an updated domtree - no need to recompute here
@pass "compact 1" ir = compact!(ir)
@pass "Inlining" ir = ssa_inlining_pass!(ir, sv.inlining, ci.propagate_inbounds)
# @timeit "verify 2" verify_ir(ir)
@pass "compact 2" ir = compact!(ir)
@pass "SROA" ir = sroa_pass!(ir, sv.inlining)
@pass "ADCE" (ir, made_changes) = adce_pass!(ir, sv.inlining)
if made_changes
@pass "compact 3" ir = compact!(ir, true)
end
@pass "ADCE" ir, changed = adce_pass!(ir, sv.inlining)
@pass "compact 3" changed && (
ir = compact!(ir, true))
@pass "optinf" optinf_worthwhile(ir) && (
ir = optinf!(ir, interp, sv, result))
if is_asserts()
@timeit "verify 3" begin
@timeit "verify" begin
verify_ir(ir, true, false, optimizer_lattice(sv.inlining.interp), sv.linfo)
verify_linetable(ir.debuginfo, length(ir.stmts))
end
Expand All @@ -1048,6 +1080,48 @@ function run_passes_ipo_safe(
return ir
end

# If the optimizer derives new type information (as implied by `IR_FLAG_REFINED`),
# and this new type information is available for the arguments of a call expression,
# further optimizations may be possible by performing irinterp on the optimized IR.
function optinf_worthwhile(ir::IRCode)
@assert isempty(ir.new_nodes) "expected compacted IRCode"
for i = 1:length(ir.stmts)
inst = ir[SSAValue(i)]
if has_flag(inst, IR_FLAG_REFINED)
if isexpr(inst[:stmt], :call)
return true
end
end
end
return false
end

function optinf!(ir::IRCode, interp::AbstractInterpreter, sv::OptimizationState, result::InferenceResult)
ci = sv.src
spec_info = SpecInfo(ci)
world = get_inference_world(interp)
min_world, max_world = first(result.valid_worlds), last(result.valid_worlds)
irsv = IRInterpretationState(interp, spec_info, ir, result.linfo, ir.argtypes,
world, min_world, max_world)
rt, (nothrow, noub) = ir_abstract_constant_propagation(interp, irsv)
if irsv.new_call_inferred
ir = ssa_inlining_pass!(ir, sv.inlining, ci.propagate_inbounds)
ir = compact!(ir)
effects = result.effects
if nothrow
effects = Effects(effects; nothrow=true)
end
if noub
effects = Effects(effects; noub=ALWAYS_TRUE)
end
result.effects = effects
result.exc_result = refine_exception_type(result.exc_result, effects)
⋤ = strictneqpartialorder(ipo_lattice(interp))
result.result = rt ⋤ result.result ? rt : result.result
end
return ir
end

function strip_trailing_junk!(code::Vector{Any}, ssavaluetypes::Vector{Any}, ssaflags::Vector, debuginfo::DebugInfoStream, cfg::CFG, info::Vector{CallInfo})
# Remove `nothing`s at the end, we don't handle them well
# (we expect the last instruction to be a terminator)
Expand Down
4 changes: 2 additions & 2 deletions Compiler/src/ssair/EscapeAnalysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ using Base: # Base definitions
isempty, length, max, min, missing, println, push!, pushfirst!,
!, !==, &, *, +, -, :, <, <<, >, |, , , , , , , ,
using ..Compiler: # Compiler specific definitions
AbstractLattice, Compiler, IRCode, IR_FLAG_NOTHROW,
@show, AbstractLattice, Compiler, IRCode, IR_FLAG_NOTHROW,
argextype, fieldcount_noerror, has_flag, intrinsic_nothrow, is_meta_expr_head,
is_identity_free_argtype, isexpr, setfield!_nothrow, singleton_type, try_compute_field,
try_compute_fieldidx, widenconst
try_compute_fieldidx, widenconst,

function include(x::String)
if !isdefined(Base, :end_base_include)
Expand Down
3 changes: 2 additions & 1 deletion Compiler/src/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1709,7 +1709,8 @@ function reprocess_phi_node!(𝕃ₒ::AbstractLattice, compact::IncrementalCompa

# There's only one predecessor left - just replace it
v = phi.values[1]
if !⊑(𝕃ₒ, compact[compact.ssa_rename[old_idx]][:type], argextype(v, compact))
⋤ = strictneqpartialorder(𝕃ₒ)
if argextype(v, compact) ⋤ compact[compact.ssa_rename[old_idx]][:type]
v = Refined(v)
end
compact.ssa_rename[old_idx] = v
Expand Down
8 changes: 6 additions & 2 deletions Compiler/src/ssair/irinterp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo, sstate::St
call = abstract_call(interp, arginfo, si, irsv)::Future
Future{Any}(call, interp, irsv) do call, interp, irsv
irsv.ir.stmts[irsv.curridx][:info] = call.info
irsv.new_call_inferred |= true
nothing
end
return call
Expand Down Expand Up @@ -204,7 +205,8 @@ function reprocess_instruction!(interp::AbstractInterpreter, inst::Instruction,
# Handled at the very end
return false
elseif isa(stmt, PiNode)
rt = tmeet(typeinf_lattice(interp), argextype(stmt.val, ir), widenconst(stmt.typ))
= join(typeinf_lattice(interp))
rt = argextype(stmt.val, ir) widenconst(stmt.typ)
elseif stmt === nothing
return false
elseif isa(stmt, GlobalRef)
Expand All @@ -226,7 +228,9 @@ function reprocess_instruction!(interp::AbstractInterpreter, inst::Instruction,
inst[:stmt] = quoted(rt.val)
end
return true
elseif !(typeinf_lattice(interp), inst[:type], rt)
end
= strictneqpartialorder(typeinf_lattice(interp))
if rt inst[:type]
inst[:type] = rt
return true
end
Expand Down
36 changes: 26 additions & 10 deletions Compiler/src/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -989,9 +989,10 @@ function lift_keyvalue_get!(compact::IncrementalCompact, idx::Int, stmt::Expr,
lifted_leaves === nothing && return

result_t = Union{}
= join(𝕃ₒ)
for v in values(lifted_leaves)
v === nothing && return
result_t = tmerge(𝕃ₒ, result_t, argextype(v.val, compact))
result_t = result_t argextype(v.val, compact)
end

(lifted_val, nest) = perform_lifting!(compact,
Expand All @@ -1001,8 +1002,12 @@ function lift_keyvalue_get!(compact::IncrementalCompact, idx::Int, stmt::Expr,
compact[idx] = lifted_val === nothing ? nothing : Expr(:call, GlobalRef(Core, :tuple), lifted_val.val)
finish_phi_nest!(compact, nest)
if lifted_val !== nothing
if !(𝕃ₒ, compact[SSAValue(idx)][:type], tuple_tfunc(𝕃ₒ, Any[result_t]))
add_flag!(compact[SSAValue(idx)], IR_FLAG_REFINED)
stmttype = tuple_tfunc(𝕃ₒ, Any[result_t])
inst = compact[SSAValue(idx)]
= strictneqpartialorder(𝕃ₒ)
if stmttype inst[:type]
inst[:type] = stmttype
add_flag!(inst, IR_FLAG_REFINED)
end
end

Expand Down Expand Up @@ -1440,19 +1445,23 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing)
lifted_leaves, any_undef = lifted_result

result_t = Union{}
= join(𝕃ₒ)
for v in values(lifted_leaves)
v === nothing && continue
result_t = tmerge(𝕃ₒ, result_t, argextype(v.val, compact))
result_t = result_t argextype(v.val, compact)
end

(lifted_val, nest) = perform_lifting!(compact,
visited_philikes, field, result_t, lifted_leaves, val, lazydomtree)

should_delete_node = false
line = compact[SSAValue(idx)][:line]
if lifted_val !== nothing && !(𝕃ₒ, compact[SSAValue(idx)][:type], result_t)
inst = compact[SSAValue(idx)]
line = inst[:line]
= strictneqpartialorder(𝕃ₒ)
if lifted_val !== nothing && result_t inst[:type]
compact[idx] = lifted_val === nothing ? nothing : lifted_val.val
add_flag!(compact[SSAValue(idx)], IR_FLAG_REFINED)
inst[:type] = result_t
add_flag!(inst, IR_FLAG_REFINED)
elseif lifted_val === nothing || isa(lifted_val.val, AnySSAValue)
# Save some work in a later compaction, by inserting this into the renamer now,
# but only do this if we didn't set the REFINED flag, to save work for irinterp
Expand Down Expand Up @@ -1855,9 +1864,15 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int,Tuple{SPCSet,SSADefUse}}
for use in du.uses
if use.kind === :getfield
inst = ir[SSAValue(use.idx)]
inst[:stmt] = compute_value_for_use(ir, domtree, allblocks,
newvalue = compute_value_for_use(ir, domtree, allblocks,
du, phinodes, fidx, use.idx)
add_flag!(inst, IR_FLAG_REFINED)
inst[:stmt] = newvalue
newvaluetyp = argextype(newvalue, ir)
= strictneqpartialorder(𝕃ₒ)
if newvaluetyp inst[:type]
inst[:type] = newvaluetyp
add_flag!(inst, IR_FLAG_REFINED)
end
elseif use.kind === :isdefined
continue # already rewritten if possible
elseif use.kind === :nopreserve
Expand All @@ -1878,11 +1893,12 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int,Tuple{SPCSet,SSADefUse}}
for b in phiblocks
n = ir[phinodes[b]][:stmt]::PhiNode
result_t = Bottom
= join(𝕃ₒ)
for p in ir.cfg.blocks[b].preds
push!(n.edges, p)
v = compute_value_for_block(ir, domtree, allblocks, du, phinodes, fidx, p)
push!(n.values, v)
result_t = tmerge(𝕃ₒ, result_t, argextype(v, ir))
result_t = result_t argextype(v, ir)
end
ir[phinodes[b]][:type] = result_t
end
Expand Down
Loading