Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Semi-concrete IR interpreter #44803

Merged
merged 1 commit into from
Sep 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
289 changes: 191 additions & 98 deletions base/compiler/abstractinterpretation.jl

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion base/compiler/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ include("compiler/stmtinfo.jl")

include("compiler/abstractinterpretation.jl")
include("compiler/typeinfer.jl")
include("compiler/optimize.jl") # TODO: break this up further + extract utilities
include("compiler/optimize.jl")

# required for bootstrap because sort.jl uses extrema
# to decide whether to dispatch to counting sort.
Expand Down
55 changes: 31 additions & 24 deletions base/compiler/inferenceresult.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,36 @@ function is_forwardable_argtype(@nospecialize x)
isa(x, PartialOpaque)
end

function va_process_argtypes(given_argtypes::Vector{Any}, mi::MethodInstance,
ianatol marked this conversation as resolved.
Show resolved Hide resolved
condargs::Union{Vector{Tuple{Int,Int}}, Nothing}=nothing)
isva = mi.def.isva
nargs = Int(mi.def.nargs)
if isva || isvarargtype(given_argtypes[end])
isva_given_argtypes = Vector{Any}(undef, nargs)
for i = 1:(nargs - isva)
isva_given_argtypes[i] = argtype_by_index(given_argtypes, i)
end
if isva
if length(given_argtypes) < nargs && isvarargtype(given_argtypes[end])
last = length(given_argtypes)
else
last = nargs
end
isva_given_argtypes[nargs] = tuple_tfunc(given_argtypes[last:end])
# invalidate `Conditional` imposed on varargs
if condargs !== nothing
for (slotid, i) in condargs
if slotid ≥ last
isva_given_argtypes[i] = widenconditional(isva_given_argtypes[i])
end
end
end
end
return isva_given_argtypes
end
return given_argtypes
end

# In theory, there could be a `cache` containing a matching `InferenceResult`
# for the provided `linfo` and `given_argtypes`. The purpose of this function is
# to return a valid value for `cache_lookup(linfo, argtypes, cache).argtypes`,
Expand Down Expand Up @@ -56,30 +86,7 @@ function matching_cache_argtypes(
end
given_argtypes[i] = widenconditional(argtype)
end
isva = def.isva
if isva || isvarargtype(given_argtypes[end])
isva_given_argtypes = Vector{Any}(undef, nargs)
for i = 1:(nargs - isva)
isva_given_argtypes[i] = argtype_by_index(given_argtypes, i)
end
if isva
if length(given_argtypes) < nargs && isvarargtype(given_argtypes[end])
last = length(given_argtypes)
else
last = nargs
end
isva_given_argtypes[nargs] = tuple_tfunc(given_argtypes[last:end])
# invalidate `Conditional` imposed on varargs
if condargs !== nothing
for (slotid, i) in condargs
if slotid ≥ last
isva_given_argtypes[i] = widenconditional(isva_given_argtypes[i])
end
end
end
end
given_argtypes = isva_given_argtypes
end
given_argtypes = va_process_argtypes(given_argtypes, linfo, condargs)
@assert length(given_argtypes) == nargs
for i in 1:nargs
given_argtype = given_argtypes[i]
Expand Down
18 changes: 13 additions & 5 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,12 @@ function in(idx::Int, bsbmp::BitSetBoundedMinPrioritySet)
return idx in bsbmp.elems
end

function append!(bsbmp::BitSetBoundedMinPrioritySet, itr)
for val in itr
push!(bsbmp, val)
end
end

mutable struct InferenceState
#= information about this method instance =#
linfo::MethodInstance
Expand Down Expand Up @@ -209,8 +215,10 @@ Effects(state::InferenceState) = state.ipo_effects
function merge_effects!(::AbstractInterpreter, caller::InferenceState, effects::Effects)
caller.ipo_effects = merge_effects(caller.ipo_effects, effects)
end

merge_effects!(interp::AbstractInterpreter, caller::InferenceState, callee::InferenceState) =
merge_effects!(interp, caller, Effects(callee))
merge_effects!(interp::AbstractInterpreter, caller::IRCode, effects::Effects) = nothing

is_effect_overridden(sv::InferenceState, effect::Symbol) = is_effect_overridden(sv.linfo, effect)
function is_effect_overridden(linfo::MethodInstance, effect::Symbol)
Expand All @@ -226,15 +234,15 @@ function InferenceResult(
return _InferenceResult(linfo, arginfo)
end

add_remark!(::AbstractInterpreter, sv::InferenceState, remark) = return
add_remark!(::AbstractInterpreter, sv::Union{InferenceState, IRCode}, remark) = return

function bail_out_toplevel_call(::AbstractInterpreter, @nospecialize(callsig), sv::InferenceState)
return sv.restrict_abstract_call_sites && !isdispatchtuple(callsig)
function bail_out_toplevel_call(::AbstractInterpreter, @nospecialize(callsig), sv::Union{InferenceState, IRCode})
return isa(sv, InferenceState) && sv.restrict_abstract_call_sites && !isdispatchtuple(callsig)
end
function bail_out_call(::AbstractInterpreter, @nospecialize(rt), sv::InferenceState)
function bail_out_call(::AbstractInterpreter, @nospecialize(rt), sv::Union{InferenceState, IRCode})
return rt === Any
end
function bail_out_apply(::AbstractInterpreter, @nospecialize(rt), sv::InferenceState)
function bail_out_apply(::AbstractInterpreter, @nospecialize(rt), sv::Union{InferenceState, IRCode})
return rt === Any
end

Expand Down
3 changes: 2 additions & 1 deletion base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ import Core.Compiler: # Core.Compiler specific definitions
isbitstype, isexpr, is_meta_expr_head, println, widenconst, argextype, singleton_type,
fieldcount_noerror, try_compute_field, try_compute_fieldidx, hasintersect, ⊑,
intrinsic_nothrow, array_builtin_common_typecheck, arrayset_typecheck,
setfield!_nothrow, alloc_array_ndims, check_effect_free!
setfield!_nothrow, alloc_array_ndims, stmt_effect_free, check_effect_free!,
SemiConcreteResult

include(x) = _TOP_MOD.include(@__MODULE__, x)
if _TOP_MOD === Core.Compiler
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/EscapeAnalysis/interprocedural.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import Core.Compiler:
call_sig, argtypes_to_type, is_builtin, is_return_type, istopfunction, validate_sparams,
specialize_method, invoke_rewrite

const Linfo = Union{MethodInstance,InferenceResult}
const Linfo = Union{MethodInstance,InferenceResult,SemiConcreteResult}
struct CallInfo
linfos::Vector{Linfo}
nothrow::Bool
Expand Down
1 change: 1 addition & 0 deletions base/compiler/ssair/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ include("compiler/ssair/verify.jl")
include("compiler/ssair/legacy.jl")
include("compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl")
include("compiler/ssair/passes.jl")
include("compiler/ssair/irinterp.jl")
11 changes: 11 additions & 0 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1371,6 +1371,8 @@ function compute_inlining_cases(info::ConstCallInfo,
push!(cases, InliningCase(result.mi.specTypes, case))
elseif isa(result, ConstPropResult)
handled_all_cases &= handle_const_prop_result!(result, argtypes, flag, state, cases, #=allow_abstract=#true)
elseif isa(result, SemiConcreteResult)
handled_all_cases &= handle_semi_concrete_result!(result, cases, #=allow_abstract=#true)
else
@assert result === nothing
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, #=allow_abstract=#true, #=allow_typevars=#false)
Expand Down Expand Up @@ -1434,6 +1436,15 @@ function handle_const_prop_result!(
return true
end

function handle_semi_concrete_result!(result::SemiConcreteResult, cases::Vector{InliningCase}, allow_abstract::Bool = false)
mi = result.mi
spec_types = mi.specTypes
allow_abstract || isdispatchtuple(spec_types) || return false
validate_sparams(mi.sparam_vals) || return false
push!(cases, InliningCase(spec_types, InliningTodo(mi, result.ir, result.effects)))
return true
end

function concrete_result_item(result::ConcreteResult, state::InliningState)
if !isdefined(result, :result) || !is_inlineable_constant(result.result)
case = compileable_specialization(state.et, result.mi, result.effects)
Expand Down
11 changes: 9 additions & 2 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1050,15 +1050,22 @@ function renumber_ssa2!(@nospecialize(stmt), ssanums::Vector{Any}, used_ssas::Ve
end

# Used in inlining before we start compacting - Only works at the CFG level
function kill_edge!(bbs::Vector{BasicBlock}, from::Int, to::Int)
function kill_edge!(bbs::Vector{BasicBlock}, from::Int, to::Int, callback=nothing)
preds, succs = bbs[to].preds, bbs[from].succs
deleteat!(preds, findfirst(x->x === from, preds)::Int)
deleteat!(succs, findfirst(x->x === to, succs)::Int)
if length(preds) == 0
for succ in copy(bbs[to].succs)
kill_edge!(bbs, to, succ)
kill_edge!(bbs, to, succ, callback)
end
end
if callback !== nothing
callback(from, to)
end
end

function kill_edge!(ir::IRCode, from::Int, to::Int, callback=nothing)
kill_edge!(ir.cfg.blocks, from, to, callback)
end

# N.B.: from and to are non-renamed indices
Expand Down
Loading