Skip to content

Commit

Permalink
optimizer: run SROA multiple times to handle more nested loads
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk committed Dec 7, 2021
1 parent 30fe8cc commit fb5e715
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 35 deletions.
105 changes: 80 additions & 25 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -329,13 +329,32 @@ struct LiftedValue
end
const LiftedLeaves = IdDict{Any, Union{Nothing,LiftedValue}}

# NOTE we use `IdSet{Int}` instead of `BitSet` for in these passes since they work on IR after inlining,
# which can be very large sometimes, and program counters in question are often very sparse
const SPCSet = IdSet{Int}

mutable struct NestedLoads
maybe::Union{Nothing,SPCSet}
NestedLoads() = new(nothing)
end
function record_nested_load!(nested_loads::NestedLoads, pc::Int)
maybe = nested_loads.maybe
maybe === nothing && (maybe = nested_loads.maybe = SPCSet())
push!(maybe::SPCSet, pc)
end
function is_nested_load(nested_loads::NestedLoads, pc::Int)
maybe = nested_loads.maybe
maybe === nothing && return false
return pc in maybe::SPCSet
end

# try to compute lifted values that can replace `getfield(x, field)` call
# where `x` is an immutable struct that are defined at any of `leaves`
function lift_leaves(compact::IncrementalCompact,
@nospecialize(result_t), field::Int, leaves::Vector{Any})
function lift_leaves!(compact::IncrementalCompact, leaves::Vector{Any},
@nospecialize(result_t), field::Int, nested_loads::NestedLoads)
# For every leaf, the lifted value
lifted_leaves = LiftedLeaves()
maybe_undef = false
local maybe_undef = false
for leaf in leaves
cache_key = leaf
if isa(leaf, AnySSAValue)
Expand Down Expand Up @@ -382,11 +401,19 @@ function lift_leaves(compact::IncrementalCompact,
ocleaf = simple_walk(compact, ocleaf)
end
ocdef, _ = walk_to_def(compact, ocleaf)
if isexpr(ocdef, :new_opaque_closure) && isa(field, Int) && 1 field length(ocdef.args)-5
if isexpr(ocdef, :new_opaque_closure) && 1 field length(ocdef.args)-5
lift_arg!(compact, leaf, cache_key, ocdef, 5+field, lifted_leaves)
continue
end
return nothing
elseif isa(def, Expr) && is_known_call(def, getfield, compact)
if isa(leaf, SSAValue)
struct_typ = unwrap_unionall(widenconst(compact_exprtype(compact, def.args[2])))
if ismutabletype(struct_typ)
record_nested_load!(nested_loads, leaf.id)
end
end
return nothing
else
typ = compact_exprtype(compact, leaf)
if !isa(typ, Const)
Expand Down Expand Up @@ -611,10 +638,6 @@ function perform_lifting!(compact::IncrementalCompact,
return stmt_val # N.B. should never happen
end

# NOTE we use `IdSet{Int}` instead of `BitSet` for in these passes since they work on IR after inlining,
# which can be very large sometimes, and program counters in question are often very sparse
const SPCSet = IdSet{Int}

"""
sroa_pass!(ir::IRCode) -> newir::IRCode
Expand All @@ -633,10 +656,11 @@ its argument).
In a case when all usages are fully eliminated, `struct` allocation may also be erased as
a result of succeeding dead code elimination.
"""
function sroa_pass!(ir::IRCode)
function sroa_pass!(ir::IRCode, optional_opts::Bool = true)
compact = IncrementalCompact(ir)
defuses = nothing # will be initialized once we encounter mutability in order to reduce dynamic allocations
lifting_cache = IdDict{Pair{AnySSAValue, Any}, AnySSAValue}()
nested_loads = NestedLoads() # tracks nested `getfield(getfield(...)::Mutable, ...)::Immutable`
for ((_, idx), stmt) in compact
# check whether this statement is `getfield` / `setfield!` (or other "interesting" statement)
isa(stmt, Expr) || continue
Expand Down Expand Up @@ -691,7 +715,9 @@ function sroa_pass!(ir::IRCode)
if defuses === nothing
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}()
end
mid, defuse = get!(defuses, defidx, (SPCSet(), SSADefUse()))
mid, defuse = get!(defuses, defidx) do
SPCSet(), SSADefUse()
end
push!(defuse.ccall_preserve_uses, idx)
union!(mid, intermediaries)
end
Expand All @@ -704,16 +730,17 @@ function sroa_pass!(ir::IRCode)
compact[idx] = new_expr
end
continue
# TODO: This isn't the best place to put these
elseif is_known_call(stmt, typeassert, compact)
canonicalize_typeassert!(compact, idx, stmt)
continue
elseif is_known_call(stmt, (===), compact)
lift_comparison!(compact, idx, stmt, lifting_cache)
continue
# elseif is_known_call(stmt, isa, compact)
# TODO do a similar optimization as `lift_comparison!` for `===`
else
if optional_opts
# TODO: This isn't the best place to put these
if is_known_call(stmt, typeassert, compact)
canonicalize_typeassert!(compact, idx, stmt)
elseif is_known_call(stmt, (===), compact)
lift_comparison!(compact, idx, stmt, lifting_cache)
# elseif is_known_call(stmt, isa, compact)
# TODO do a similar optimization as `lift_comparison!` for `===`
end
end
continue
end

Expand Down Expand Up @@ -749,7 +776,9 @@ function sroa_pass!(ir::IRCode)
if defuses === nothing
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}()
end
mid, defuse = get!(defuses, def.id, (SPCSet(), SSADefUse()))
mid, defuse = get!(defuses, def.id) do
SPCSet(), SSADefUse()
end
if is_setfield
push!(defuse.defs, idx)
else
Expand All @@ -771,7 +800,7 @@ function sroa_pass!(ir::IRCode)
isempty(leaves) && continue

result_t = compact_exprtype(compact, SSAValue(idx))
lifted_result = lift_leaves(compact, result_t, field, leaves)
lifted_result = lift_leaves!(compact, leaves, result_t, field, nested_loads)
lifted_result === nothing && continue
lifted_leaves, any_undef = lifted_result

Expand Down Expand Up @@ -807,20 +836,23 @@ function sroa_pass!(ir::IRCode)
used_ssas = copy(compact.used_ssas)
simple_dce!(compact, (x::SSAValue) -> used_ssas[x.id] -= 1)
ir = complete(compact)
sroa_mutables!(ir, defuses, used_ssas)
return ir
return sroa_mutables!(ir, defuses, used_ssas, nested_loads)
else
simple_dce!(compact)
return complete(compact)
end
end

function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int})
function sroa_mutables!(ir::IRCode,
defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int},
nested_loads::NestedLoads)
# Compute domtree, needed below, now that we have finished compacting the IR.
# This needs to be after we iterate through the IR with `IncrementalCompact`
# because removing dead blocks can invalidate the domtree.
@timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks)

nested_mloads = NestedLoads() # tracks nested `getfield(getfield(...)::Mutable, ...)::Mutable`
local any_eliminated = any_meliminated = false
for (idx, (intermediaries, defuse)) in defuses
intermediaries = collect(intermediaries)
# Check if there are any uses we did not account for. If so, the variable
Expand All @@ -836,7 +868,19 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
nleaves == nuses_total || continue
# Find the type for this allocation
defexpr = ir[SSAValue(idx)]
isexpr(defexpr, :new) || continue
isa(defexpr, Expr) || continue
if !isexpr(defexpr, :new)
if is_known_call(defexpr, getfield, ir)
val = defexpr.args[2]
if isa(val, SSAValue)
struct_typ = unwrap_unionall(widenconst(argextype(val, ir)))
if ismutabletype(struct_typ)
record_nested_load!(nested_mloads, idx)
end
end
end
continue
end
newidx = idx
typ = ir.stmts[newidx][:type]
if isa(typ, UnionAll)
Expand Down Expand Up @@ -900,6 +944,12 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
# Now go through all uses and rewrite them
for stmt in du.uses
ir[SSAValue(stmt)] = compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, stmt)
if !any_eliminated
any_eliminated |= is_nested_load(nested_loads, stmt)
end
if !any_meliminated
any_meliminated |= is_nested_load(nested_mloads, stmt)
end
end
if !isbitstype(ftyp)
if preserve_uses !== nothing
Expand Down Expand Up @@ -938,6 +988,11 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse

@label skip
end
if any_eliminated || any_meliminated
return sroa_pass!(compact!(ir), false)
else
return ir
end
end

"""
Expand Down
44 changes: 34 additions & 10 deletions test/compiler/irpasses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ iscall(pred::Function, @nospecialize(x)) = Meta.isexpr(x, :call) && pred(x.args[
struct ImmutableXYZ; x; y; z; end
mutable struct MutableXYZ; x; y; z; end

struct ImmutableOuter{T}; x::T; y::T; z::T; end
mutable struct MutableOuter{T}; x::T; y::T; z::T; end

# should optimize away very basic cases
let src = code_typed1((Any,Any,Any)) do x, y, z
xyz = ImmutableXYZ(x, y, z)
Expand Down Expand Up @@ -198,9 +201,8 @@ let src = code_typed1((Bool,Bool,Any,Any)) do c1, c2, x, y
@test any(isnew, src.code)
end

# should include a simple alias analysis
struct ImmutableOuter{T}; x::T; y::T; z::T; end
mutable struct MutableOuter{T}; x::T; y::T; z::T; end
# alias analysis
# --------------
let src = code_typed1((Any,Any,Any)) do x, y, z
xyz = ImmutableXYZ(x, y, z)
outer = ImmutableOuter(xyz, xyz, xyz)
Expand All @@ -227,9 +229,11 @@ let src = code_typed1((Any,Any,Any)) do x, y, z
x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=y=# Core.Argument(4)]
end
end

# FIXME our analysis isn't yet so powerful at this moment: may be unable to handle nested objects well
# OK: mutable(immutable(...)) case
# FIXME? in order to handle nested mutable `getfield` calls, we run SROA iteratively until
# any nested mutable `getfield` calls become no longer eliminatable:
# it's probably not the most efficient option and we may want to introduce some sort of
# alias analysis and eliminates all the loads at once.
# mutable(immutable(...)) case
let src = code_typed1((Any,Any,Any)) do x, y, z
xyz = MutableXYZ(x, y, z)
t = (xyz,)
Expand Down Expand Up @@ -260,21 +264,41 @@ let # this is a simple end to end test case, which demonstrates allocation elimi
# compiled code for `simple_sroa`, otherwise everything can be folded even without SROA
@test @allocated(simple_sroa(s)) == 0
end
# FIXME: immutable(mutable(...)) case
# immutable(mutable(...)) case
let src = code_typed1((Any,Any,Any)) do x, y, z
xyz = ImmutableXYZ(x, y, z)
outer = MutableOuter(xyz, xyz, xyz)
outer.x.x, outer.y.y, outer.z.z
end
@test_broken !any(isnew, src.code)
@test !any(isnew, src.code)
@test any(src.code) do @nospecialize x
iscall((src, tuple), x) &&
x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=y=# Core.Argument(4)]
end
end
# FIXME: mutable(mutable(...)) case
# mutable(mutable(...)) case
let src = code_typed1((Any,Any,Any)) do x, y, z
xyz = MutableXYZ(x, y, z)
outer = MutableOuter(xyz, xyz, xyz)
outer.x.x, outer.y.y, outer.z.z
end
@test_broken !any(isnew, src.code)
@test !any(isnew, src.code)
@test any(src.code) do @nospecialize x
iscall((src, tuple), x) &&
x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=y=# Core.Argument(4)]
end
end
let src = code_typed1((Any,Any,Any)) do x, y, z
xyz = MutableXYZ(x, y, z)
inner = MutableOuter(xyz, xyz, xyz)
outer = MutableOuter(inner, inner, inner)
outer.x.x.x, outer.y.y.y, outer.z.z.z
end
@test !any(isnew, src.code)
@test any(src.code) do @nospecialize x
iscall((src, tuple), x) &&
x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=y=# Core.Argument(4)]
end
end

# should work nicely with inlining to optimize away a complicated case
Expand Down

0 comments on commit fb5e715

Please sign in to comment.