Skip to content

Commit

Permalink
Semi-concrete IR interpreter
Browse files Browse the repository at this point in the history
  • Loading branch information
Keno authored and Ian Atol committed Sep 1, 2022
1 parent 71131c9 commit 3557af8
Show file tree
Hide file tree
Showing 18 changed files with 709 additions and 143 deletions.
289 changes: 191 additions & 98 deletions base/compiler/abstractinterpretation.jl

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion base/compiler/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ include("compiler/stmtinfo.jl")

include("compiler/abstractinterpretation.jl")
include("compiler/typeinfer.jl")
include("compiler/optimize.jl") # TODO: break this up further + extract utilities
include("compiler/optimize.jl")

# required for bootstrap because sort.jl uses extrema
# to decide whether to dispatch to counting sort.
Expand Down
55 changes: 31 additions & 24 deletions base/compiler/inferenceresult.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,36 @@ function is_forwardable_argtype(@nospecialize x)
isa(x, PartialOpaque)
end

function va_process_argtypes(given_argtypes::Vector{Any}, mi::MethodInstance,
condargs::Union{Vector{Tuple{Int,Int}}, Nothing}=nothing)
isva = mi.def.isva
nargs = Int(mi.def.nargs)
if isva || isvarargtype(given_argtypes[end])
isva_given_argtypes = Vector{Any}(undef, nargs)
for i = 1:(nargs - isva)
isva_given_argtypes[i] = argtype_by_index(given_argtypes, i)
end
if isva
if length(given_argtypes) < nargs && isvarargtype(given_argtypes[end])
last = length(given_argtypes)
else
last = nargs
end
isva_given_argtypes[nargs] = tuple_tfunc(given_argtypes[last:end])
# invalidate `Conditional` imposed on varargs
if condargs !== nothing
for (slotid, i) in condargs
if slotid last
isva_given_argtypes[i] = widenconditional(isva_given_argtypes[i])
end
end
end
end
return isva_given_argtypes
end
return given_argtypes
end

# In theory, there could be a `cache` containing a matching `InferenceResult`
# for the provided `linfo` and `given_argtypes`. The purpose of this function is
# to return a valid value for `cache_lookup(linfo, argtypes, cache).argtypes`,
Expand Down Expand Up @@ -56,30 +86,7 @@ function matching_cache_argtypes(
end
given_argtypes[i] = widenconditional(argtype)
end
isva = def.isva
if isva || isvarargtype(given_argtypes[end])
isva_given_argtypes = Vector{Any}(undef, nargs)
for i = 1:(nargs - isva)
isva_given_argtypes[i] = argtype_by_index(given_argtypes, i)
end
if isva
if length(given_argtypes) < nargs && isvarargtype(given_argtypes[end])
last = length(given_argtypes)
else
last = nargs
end
isva_given_argtypes[nargs] = tuple_tfunc(given_argtypes[last:end])
# invalidate `Conditional` imposed on varargs
if condargs !== nothing
for (slotid, i) in condargs
if slotid last
isva_given_argtypes[i] = widenconditional(isva_given_argtypes[i])
end
end
end
end
given_argtypes = isva_given_argtypes
end
given_argtypes = va_process_argtypes(given_argtypes, linfo, condargs)
@assert length(given_argtypes) == nargs
for i in 1:nargs
given_argtype = given_argtypes[i]
Expand Down
18 changes: 13 additions & 5 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,12 @@ function in(idx::Int, bsbmp::BitSetBoundedMinPrioritySet)
return idx in bsbmp.elems
end

function append!(bsbmp::BitSetBoundedMinPrioritySet, itr)
for val in itr
push!(bsbmp, val)
end
end

mutable struct InferenceState
#= information about this method instance =#
linfo::MethodInstance
Expand Down Expand Up @@ -209,8 +215,10 @@ Effects(state::InferenceState) = state.ipo_effects
function merge_effects!(::AbstractInterpreter, caller::InferenceState, effects::Effects)
caller.ipo_effects = merge_effects(caller.ipo_effects, effects)
end

merge_effects!(interp::AbstractInterpreter, caller::InferenceState, callee::InferenceState) =
merge_effects!(interp, caller, Effects(callee))
merge_effects!(interp::AbstractInterpreter, caller::IRCode, effects::Effects) = nothing

is_effect_overridden(sv::InferenceState, effect::Symbol) = is_effect_overridden(sv.linfo, effect)
function is_effect_overridden(linfo::MethodInstance, effect::Symbol)
Expand All @@ -226,15 +234,15 @@ function InferenceResult(
return _InferenceResult(linfo, arginfo)
end

add_remark!(::AbstractInterpreter, sv::InferenceState, remark) = return
add_remark!(::AbstractInterpreter, sv::Union{InferenceState, IRCode}, remark) = return

function bail_out_toplevel_call(::AbstractInterpreter, @nospecialize(callsig), sv::InferenceState)
return sv.restrict_abstract_call_sites && !isdispatchtuple(callsig)
function bail_out_toplevel_call(::AbstractInterpreter, @nospecialize(callsig), sv::Union{InferenceState, IRCode})
return isa(sv, InferenceState) && sv.restrict_abstract_call_sites && !isdispatchtuple(callsig)
end
function bail_out_call(::AbstractInterpreter, @nospecialize(rt), sv::InferenceState)
function bail_out_call(::AbstractInterpreter, @nospecialize(rt), sv::Union{InferenceState, IRCode})
return rt === Any
end
function bail_out_apply(::AbstractInterpreter, @nospecialize(rt), sv::InferenceState)
function bail_out_apply(::AbstractInterpreter, @nospecialize(rt), sv::Union{InferenceState, IRCode})
return rt === Any
end

Expand Down
3 changes: 2 additions & 1 deletion base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ import Core.Compiler: # Core.Compiler specific definitions
isbitstype, isexpr, is_meta_expr_head, println, widenconst, argextype, singleton_type,
fieldcount_noerror, try_compute_field, try_compute_fieldidx, hasintersect, ,
intrinsic_nothrow, array_builtin_common_typecheck, arrayset_typecheck,
setfield!_nothrow, alloc_array_ndims, check_effect_free!
setfield!_nothrow, alloc_array_ndims, stmt_effect_free, check_effect_free!,
SemiConcreteResult

include(x) = _TOP_MOD.include(@__MODULE__, x)
if _TOP_MOD === Core.Compiler
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/EscapeAnalysis/interprocedural.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import Core.Compiler:
call_sig, argtypes_to_type, is_builtin, is_return_type, istopfunction, validate_sparams,
specialize_method, invoke_rewrite

const Linfo = Union{MethodInstance,InferenceResult}
const Linfo = Union{MethodInstance,InferenceResult,SemiConcreteResult}
struct CallInfo
linfos::Vector{Linfo}
nothrow::Bool
Expand Down
1 change: 1 addition & 0 deletions base/compiler/ssair/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ include("compiler/ssair/verify.jl")
include("compiler/ssair/legacy.jl")
include("compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl")
include("compiler/ssair/passes.jl")
include("compiler/ssair/irinterp.jl")
11 changes: 11 additions & 0 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1371,6 +1371,8 @@ function compute_inlining_cases(info::ConstCallInfo,
push!(cases, InliningCase(result.mi.specTypes, case))
elseif isa(result, ConstPropResult)
handled_all_cases &= handle_const_prop_result!(result, argtypes, flag, state, cases, #=allow_abstract=#true)
elseif isa(result, SemiConcreteResult)
handled_all_cases &= handle_semi_concrete_result!(result, cases, #=allow_abstract=#true)
else
@assert result === nothing
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, #=allow_abstract=#true, #=allow_typevars=#false)
Expand Down Expand Up @@ -1434,6 +1436,15 @@ function handle_const_prop_result!(
return true
end

function handle_semi_concrete_result!(result::SemiConcreteResult, cases::Vector{InliningCase}, allow_abstract::Bool = false)
mi = result.mi
spec_types = mi.specTypes
allow_abstract || isdispatchtuple(spec_types) || return false
validate_sparams(mi.sparam_vals) || return false
push!(cases, InliningCase(spec_types, InliningTodo(mi, result.ir, result.effects)))
return true
end

function concrete_result_item(result::ConcreteResult, state::InliningState)
if !isdefined(result, :result) || !is_inlineable_constant(result.result)
case = compileable_specialization(state.et, result.mi, result.effects)
Expand Down
11 changes: 9 additions & 2 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1050,15 +1050,22 @@ function renumber_ssa2!(@nospecialize(stmt), ssanums::Vector{Any}, used_ssas::Ve
end

# Used in inlining before we start compacting - Only works at the CFG level
function kill_edge!(bbs::Vector{BasicBlock}, from::Int, to::Int)
function kill_edge!(bbs::Vector{BasicBlock}, from::Int, to::Int, callback=nothing)
preds, succs = bbs[to].preds, bbs[from].succs
deleteat!(preds, findfirst(x->x === from, preds)::Int)
deleteat!(succs, findfirst(x->x === to, succs)::Int)
if length(preds) == 0
for succ in copy(bbs[to].succs)
kill_edge!(bbs, to, succ)
kill_edge!(bbs, to, succ, callback)
end
end
if callback !== nothing
callback(from, to)
end
end

function kill_edge!(ir::IRCode, from::Int, to::Int, callback=nothing)
kill_edge!(ir.cfg.blocks, from, to, callback)
end

# N.B.: from and to are non-renamed indices
Expand Down
Loading

0 comments on commit 3557af8

Please sign in to comment.