Skip to content

Commit

Permalink
Allow inlining methods with unmatched type parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
Ian Atol committed Apr 22, 2022
1 parent 3cff21e commit c3da8e6
Show file tree
Hide file tree
Showing 8 changed files with 255 additions and 73 deletions.
69 changes: 42 additions & 27 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -359,11 +359,17 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
boundscheck = :off
end
end
if !validate_sparams(sparam_vals)
sparam_vals = insert_node_here!(compact,
effect_free(NewInstruction(Expr(:call, Core._compute_sparams, item.mi.def, argexprs...), SimpleVector, topline)))
end
# If the iterator already moved on to the next basic block,
# temporarily re-open in again.
local return_value
sig = def.sig
# Special case inlining that maintains the current basic block if there's only one BB in the target
new_new_offset = length(compact.new_new_nodes)
late_fixup_offset = length(compact.late_fixup)
if spec.linear_inline_eligible
#compact[idx] = nothing
inline_compact = IncrementalCompact(compact, spec.ir, compact.result_idx)
Expand All @@ -372,7 +378,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
# face of rename_arguments! mutating in place - should figure out
# something better eventually.
inline_compact[idx′] = nothing
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, compact)
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, inline_compact)
if isa(stmt′, ReturnNode)
val = stmt′.val
return_value = SSAValue(idx′)
Expand All @@ -383,7 +389,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
end
inline_compact[idx′] = stmt′
end
just_fixup!(inline_compact)
just_fixup!(inline_compact, new_new_offset, late_fixup_offset)
compact.result_idx = inline_compact.result_idx
else
bb_offset, post_bb_id = popfirst!(todo_bbs)
Expand All @@ -397,7 +403,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
inline_compact = IncrementalCompact(compact, spec.ir, compact.result_idx)
for ((_, idx′), stmt′) in inline_compact
inline_compact[idx′] = nothing
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, compact)
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, inline_compact)
if isa(stmt′, ReturnNode)
if isdefined(stmt′, :val)
val = stmt′.val
Expand Down Expand Up @@ -428,7 +434,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
end
inline_compact[idx′] = stmt′
end
just_fixup!(inline_compact)
just_fixup!(inline_compact, new_new_offset, late_fixup_offset)
compact.result_idx = inline_compact.result_idx
compact.active_result_bb = inline_compact.active_result_bb
if length(pn.edges) == 1
Expand Down Expand Up @@ -896,8 +902,7 @@ function analyze_method!(match::MethodMatch, argtypes::Vector{Any},
end
end

# Bail out if any static parameters are left as TypeVar
validate_sparams(match.sparams) || return nothing
#validate_sparams(match.sparams) || return nothing

et = state.et

Expand Down Expand Up @@ -1104,7 +1109,7 @@ function inline_invoke!(
argtypes = invoke_rewrite(sig.argtypes)
if isa(result, ConstPropResult)
(; mi) = item = InliningTodo(result.result, argtypes)
validate_sparams(mi.sparam_vals) || return nothing
# validate_sparams(mi.sparam_vals) || return nothing
if argtypes_to_type(argtypes) <: mi.def.sig
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
handle_single_case!(ir, idx, stmt, item, todo, state.params, true)
Expand Down Expand Up @@ -1327,7 +1332,7 @@ function handle_const_prop_result!(
(; mi) = item = InliningTodo(result.result, argtypes)
spec_types = mi.specTypes
allow_abstract || isdispatchtuple(spec_types) || return false
validate_sparams(mi.sparam_vals) || return false
#validate_sparams(mi.sparam_vals) || return false
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
item === nothing && return false
push!(cases, InliningCase(spec_types, item))
Expand Down Expand Up @@ -1365,7 +1370,6 @@ function handle_const_opaque_closure_call!(
sig::Signature, state::InliningState, todo::Vector{Pair{Int, Any}})
item = InliningTodo(result.result, sig.argtypes)
isdispatchtuple(item.mi.specTypes) || return
validate_sparams(item.mi.sparam_vals) || return
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
handle_single_case!(ir, idx, stmt, item, todo, state.params)
return nothing
Expand Down Expand Up @@ -1545,38 +1549,49 @@ function late_inline_special_case!(
end

function ssa_substitute!(idx::Int, @nospecialize(val), arg_replacements::Vector{Any},
@nospecialize(spsig), spvals::SimpleVector,
@nospecialize(spsig), spvals::Union{SimpleVector, SSAValue},
linetable_offset::Int32, boundscheck::Symbol, compact::IncrementalCompact)
compact.result[idx][:flag] &= ~IR_FLAG_INBOUNDS
compact.result[idx][:line] += linetable_offset
return ssa_substitute_op!(val, arg_replacements, spsig, spvals, boundscheck)
return ssa_substitute_op!(val, arg_replacements, spsig, spvals, boundscheck, compact, idx)
end

function ssa_substitute_op!(@nospecialize(val), arg_replacements::Vector{Any},
@nospecialize(spsig), spvals::SimpleVector, boundscheck::Symbol)
@nospecialize(spsig), spvals::Union{SimpleVector, SSAValue}, boundscheck::Symbol,
compact::IncrementalCompact, idx::Int)
if isa(val, Argument)
return arg_replacements[val.n]
end
if isa(val, Expr)
e = val::Expr
head = e.head
if head === :static_parameter
return quoted(spvals[e.args[1]::Int])
if isa(spvals, SimpleVector)
return quoted(spvals[e.args[1]::Int])
else
ret = insert_node!(compact, SSAValue(idx),
effect_free(NewInstruction(Expr(:call, Core._svec_ref, false, spvals, e.args[1]), Any)))
return ret
end
elseif head === :cfunction
@assert !isa(spsig, UnionAll) || !isempty(spvals)
e.args[3] = ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), e.args[3], spsig, spvals)
e.args[4] = svec(Any[
ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), argt, spsig, spvals)
for argt in e.args[4]::SimpleVector ]...)
if isa(spvals, SimpleVector)
@assert !isa(spsig, UnionAll) || !isempty(spvals)
e.args[3] = ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), e.args[3], spsig, spvals)
e.args[4] = svec(Any[
ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), argt, spsig, spvals)
for argt in e.args[4]::SimpleVector ]...)
end
elseif head === :foreigncall
@assert !isa(spsig, UnionAll) || !isempty(spvals)
for i = 1:length(e.args)
if i == 2
e.args[2] = ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), e.args[2], spsig, spvals)
elseif i == 3
e.args[3] = svec(Any[
ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), argt, spsig, spvals)
for argt in e.args[3]::SimpleVector ]...)
if isa(spvals, SimpleVector)
@assert !isa(spsig, UnionAll) || !isempty(spvals)
for i = 1:length(e.args)
if i == 2
e.args[2] = ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), e.args[2], spsig, spvals)
elseif i == 3
e.args[3] = svec(Any[
ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), argt, spsig, spvals)
for argt in e.args[3]::SimpleVector ]...)
end
end
end
elseif head === :boundscheck
Expand All @@ -1591,7 +1606,7 @@ function ssa_substitute_op!(@nospecialize(val), arg_replacements::Vector{Any},
end
urs = userefs(val)
for op in urs
op[] = ssa_substitute_op!(op[], arg_replacements, spsig, spvals, boundscheck)
op[] = ssa_substitute_op!(op[], arg_replacements, spsig, spvals, boundscheck, compact, idx)
end
return urs[]
end
115 changes: 77 additions & 38 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -631,16 +631,13 @@ mutable struct IncrementalCompact
perm = my_sortperm(Int[code.new_nodes.info[i].pos for i in 1:length(code.new_nodes)])
new_len = length(code.stmts) + length(code.new_nodes)
ssa_rename = Any[SSAValue(i) for i = 1:new_len]
new_new_used_ssas = Vector{Int}()
late_fixup = Vector{Int}()
bb_rename = Vector{Int}()
new_new_nodes = NewNodeStream()
pending_nodes = NewNodeStream()
pending_perm = Int[]
return new(code, parent.result,
parent.result_bbs, ssa_rename, bb_rename, bb_rename, parent.used_ssas,
late_fixup, perm, 1,
new_new_nodes, new_new_used_ssas, pending_nodes, pending_perm,
parent.late_fixup, perm, 1,
parent.new_new_nodes, parent.new_new_used_ssas, pending_nodes, pending_perm,
1, result_offset, parent.active_result_bb, false, false, false)
end
end
Expand Down Expand Up @@ -1469,62 +1466,104 @@ function maybe_erase_unused!(
return false
end

function fixup_phinode_values!(compact::IncrementalCompact, old_values::Vector{Any})
struct FixedNode
node::Any
needs_fixup::Bool
FixedNode(@nospecialize(node), needs_fixup::Bool) = new(node, needs_fixup)
end

function fixup_phinode_values!(compact::IncrementalCompact, old_values::Vector{Any}, reify_new_nodes::Bool)
values = Vector{Any}(undef, length(old_values))
needs_fixup = false
for i = 1:length(old_values)
isassigned(old_values, i) || continue
val = old_values[i]
if isa(val, Union{OldSSAValue, NewSSAValue})
val = fixup_node(compact, val)
if isa(val, OldSSAValue)
val = compact.ssa_rename[val.id]
if isa(val, SSAValue)
compact.used_ssas[val.id] += 1
end
elseif isa(val, NewSSAValue)
if reify_new_nodes
val = SSAValue(length(compact.result) + val.id)
else
needs_fixup = true
end
end
values[i] = val
end
values
return FixedNode(values, needs_fixup)
end

function fixup_node(compact::IncrementalCompact, @nospecialize(stmt))
function fixup_node(compact::IncrementalCompact, @nospecialize(stmt), reify_new_nodes::Bool)
if isa(stmt, PhiNode)
return PhiNode(stmt.edges, fixup_phinode_values!(compact, stmt.values))
(;node, needs_fixup) = fixup_phinode_values!(compact, stmt.values, reify_new_nodes)
return FixedNode(PhiNode(stmt.edges, node), needs_fixup)
elseif isa(stmt, PhiCNode)
return PhiCNode(fixup_phinode_values!(compact, stmt.values))
(;node, needs_fixup) = fixup_phinode_values!(compact, stmt.values, reify_new_nodes)
return FixedNode(PhiCNode(node), needs_fixup)
elseif isa(stmt, NewSSAValue)
return SSAValue(length(compact.result) + stmt.id)
elseif isa(stmt, OldSSAValue)
val = compact.ssa_rename[stmt.id]
if isa(val, SSAValue)
# If `val.id` is greater than the length of `compact.result` or
# `compact.used_ssas`, this SSA value is in `new_new_nodes`, so
# don't count the use
compact.used_ssas[val.id] += 1
if reify_new_nodes
return FixedNode(SSAValue(length(compact.result) + stmt.id), false)
else
return FixedNode(stmt, true)
end
return val
elseif isa(stmt, OldSSAValue)
return FixedNode(compact.ssa_rename[stmt.id], false)
else
urs = userefs(stmt)
needs_fixup = false
for ur in urs
val = ur[]
if isa(val, Union{NewSSAValue, OldSSAValue})
ur[] = fixup_node(compact, val)
if isa(val, NewSSAValue)
if reify_new_nodes
val = SSAValue(length(compact.result) + val.id)
else
needs_fixup = true
end
elseif isa(val, OldSSAValue)
val = compact.ssa_rename[val.id]
end
if isa(val, SSAValue) && val.id <= length(compact.used_ssas)
# If `val.id` is greater than the length of `compact.result` or
# `compact.used_ssas`, this SSA value is in `new_new_nodes`, so
# don't count the use
compact.used_ssas[val.id] += 1
end
ur[] = val
end
return urs[]
return FixedNode(urs[], needs_fixup)
end
end

function just_fixup!(compact::IncrementalCompact)
resize!(compact.used_ssas, length(compact.result))
append!(compact.used_ssas, compact.new_new_used_ssas)
empty!(compact.new_new_used_ssas)
for idx in compact.late_fixup
function just_fixup!(compact::IncrementalCompact, new_new_nodes_offset::Union{Int, Nothing} = nothing, late_fixup_offset::Union{Int, Nothing}=nothing)
if new_new_nodes_offset === late_fixup_offset === nothing # only do this appending in non_dce_finish!
resize!(compact.used_ssas, length(compact.result))
append!(compact.used_ssas, compact.new_new_used_ssas)
empty!(compact.new_new_used_ssas)
end
off = late_fixup_offset === nothing ? 1 : (late_fixup_offset+1)
set_off = off
for i in off:length(compact.late_fixup)
idx = compact.late_fixup[i]
stmt = compact.result[idx][:inst]
new_stmt = fixup_node(compact, stmt)
(stmt === new_stmt) || (compact.result[idx][:inst] = new_stmt)
end
for idx in 1:length(compact.new_new_nodes)
node = compact.new_new_nodes.stmts[idx]
stmt = node[:inst]
new_stmt = fixup_node(compact, stmt)
if new_stmt !== stmt
node[:inst] = new_stmt
(;node, needs_fixup) = fixup_node(compact, stmt, late_fixup_offset === nothing)
(stmt === node) || (compact.result[idx][:inst] = node)
if needs_fixup
compact.late_fixup[set_off] = idx
set_off += 1
end
end
if late_fixup_offset !== nothing
resize!(compact.late_fixup, set_off-1)
end
off = new_new_nodes_offset === nothing ? 1 : (new_new_nodes_offset+1)
for idx in off:length(compact.new_new_nodes)
new_node = compact.new_new_nodes.stmts[idx]
stmt = new_node[:inst]
(;node) = fixup_node(compact, stmt, late_fixup_offset === nothing)
if node !== stmt
new_node[:inst] = node
end
end
end
Expand Down
Loading

0 comments on commit c3da8e6

Please sign in to comment.