From dd6d086a7f0344161bb38138b3faa8295b8f4259 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Tue, 7 Dec 2021 14:16:38 +0900 Subject: [PATCH] optimizer: run SROA multiple times to handle more nested loads --- base/compiler/ssair/passes.jl | 123 ++++++++++++++++++++++++---------- test/compiler/irpasses.jl | 51 +++++++++++--- 2 files changed, 130 insertions(+), 44 deletions(-) diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index 7d33b37141fef3..f15b01373a373a 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -29,12 +29,14 @@ SSADefUse() = SSADefUse(Int[], Int[], Int[]) compute_live_ins(cfg::CFG, du::SSADefUse) = compute_live_ins(cfg, du.defs, du.uses) -function try_compute_field_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr) - field = stmt.args[3] +try_compute_field_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr) = + try_compute_field(ir, stmt.args[3]) + +function try_compute_field(ir::Union{IncrementalCompact,IRCode}, @nospecialize(field)) # fields are usually literals, handle them manually if isa(field, QuoteNode) field = field.value - elseif isa(field, Int) + elseif isa(field, Int) || isa(field, Symbol) # try to resolve other constants, e.g. global reference else field = argextype(field, ir) @@ -44,8 +46,7 @@ function try_compute_field_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr return nothing end end - isa(field, Union{Int, Symbol}) || return nothing - return field + return isa(field, Union{Int, Symbol}) ? field : nothing end function try_compute_fieldidx_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr, typ::DataType) @@ -167,7 +168,7 @@ function collect_leaves(compact::IncrementalCompact, @nospecialize(val), @nospec end function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#), - callback = (@nospecialize(pi), @nospecialize(idx)) -> false) + callback = (@nospecialize(x), @nospecialize(idx)) -> false) while true if isa(defssa, OldSSAValue) if already_inserted(compact, defssa) @@ -335,10 +336,29 @@ struct LiftedValue end const LiftedLeaves = IdDict{Any, Union{Nothing,LiftedValue}} +# NOTE we use `IdSet{Int}` instead of `BitSet` for in these passes since they work on IR after inlining, +# which can be very large sometimes, and program counters in question are often very sparse +const SPCSet = IdSet{Int} + +mutable struct NestedLoads + maybe::Union{Nothing,SPCSet} + NestedLoads() = new(nothing) +end +function record_nested_load!(nested_loads::NestedLoads, pc::Int) + maybe = nested_loads.maybe + maybe === nothing && (maybe = nested_loads.maybe = SPCSet()) + push!(maybe::SPCSet, pc) +end +function is_nested_load(nested_loads::NestedLoads, pc::Int) + maybe = nested_loads.maybe + maybe === nothing && return false + return pc in maybe::SPCSet +end + # try to compute lifted values that can replace `getfield(x, field)` call # where `x` is an immutable struct that are defined at any of `leaves` -function lift_leaves(compact::IncrementalCompact, - @nospecialize(result_t), field::Int, leaves::Vector{Any}) +function lift_leaves!(compact::IncrementalCompact, leaves::Vector{Any}, + @nospecialize(result_t), field::Int, nested_loads::NestedLoads) # For every leaf, the lifted value lifted_leaves = LiftedLeaves() maybe_undef = false @@ -388,11 +408,19 @@ function lift_leaves(compact::IncrementalCompact, ocleaf = simple_walk(compact, ocleaf) end ocdef, _ = walk_to_def(compact, ocleaf) - if isexpr(ocdef, :new_opaque_closure) && isa(field, Int) && 1 ≤ field ≤ length(ocdef.args)-5 + if isexpr(ocdef, :new_opaque_closure) && 1 ≤ field ≤ length(ocdef.args)-5 lift_arg!(compact, leaf, cache_key, ocdef, 5+field, lifted_leaves) continue end return nothing + elseif is_known_call(def, getfield, compact) + if isa(leaf, SSAValue) + struct_typ = unwrap_unionall(widenconst(argextype(def.args[2], compact))) + if ismutabletype(struct_typ) + record_nested_load!(nested_loads, leaf.id) + end + end + return nothing else typ = argextype(leaf, compact) if !isa(typ, Const) @@ -586,7 +614,7 @@ function perform_lifting!(compact::IncrementalCompact, end val = lifted_val.x if isa(val, AnySSAValue) - callback = (@nospecialize(pi), @nospecialize(idx)) -> true + callback = (@nospecialize(x), @nospecialize(idx)) -> true val = simple_walk(compact, val, callback) end push!(new_node.values, val) @@ -617,10 +645,6 @@ function perform_lifting!(compact::IncrementalCompact, return stmt_val # N.B. should never happen end -# NOTE we use `IdSet{Int}` instead of `BitSet` for in these passes since they work on IR after inlining, -# which can be very large sometimes, and program counters in question are often very sparse -const SPCSet = IdSet{Int} - """ sroa_pass!(ir::IRCode) -> newir::IRCode @@ -639,10 +663,11 @@ its argument). In a case when all usages are fully eliminated, `struct` allocation may also be erased as a result of succeeding dead code elimination. """ -function sroa_pass!(ir::IRCode) +function sroa_pass!(ir::IRCode, optional_opts::Bool = true) compact = IncrementalCompact(ir) defuses = nothing # will be initialized once we encounter mutability in order to reduce dynamic allocations lifting_cache = IdDict{Pair{AnySSAValue, Any}, AnySSAValue}() + nested_loads = NestedLoads() # tracks nested `getfield(getfield(...)::Mutable, ...)::Immutable` for ((_, idx), stmt) in compact # check whether this statement is `getfield` / `setfield!` (or other "interesting" statement) isa(stmt, Expr) || continue @@ -670,7 +695,7 @@ function sroa_pass!(ir::IRCode) preserved_arg = stmt.args[pidx] isa(preserved_arg, SSAValue) || continue let intermediaries = SPCSet() - callback = function (@nospecialize(pi), @nospecialize(ssa)) + callback = function (@nospecialize(x), @nospecialize(ssa)) push!(intermediaries, ssa.id) return false end @@ -698,7 +723,9 @@ function sroa_pass!(ir::IRCode) if defuses === nothing defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}() end - mid, defuse = get!(defuses, defidx, (SPCSet(), SSADefUse())) + mid, defuse = get!(defuses, defidx) do + SPCSet(), SSADefUse() + end push!(defuse.ccall_preserve_uses, idx) union!(mid, intermediaries) end @@ -708,16 +735,17 @@ function sroa_pass!(ir::IRCode) compact[idx] = form_new_preserves(stmt, preserved, new_preserves) end continue - # TODO: This isn't the best place to put these - elseif is_known_call(stmt, typeassert, compact) - canonicalize_typeassert!(compact, idx, stmt) - continue - elseif is_known_call(stmt, (===), compact) - lift_comparison!(compact, idx, stmt, lifting_cache) - continue - # elseif is_known_call(stmt, isa, compact) - # TODO do a similar optimization as `lift_comparison!` for `===` else + if optional_opts + # TODO: This isn't the best place to put these + if is_known_call(stmt, typeassert, compact) + canonicalize_typeassert!(compact, idx, stmt) + elseif is_known_call(stmt, (===), compact) + lift_comparison!(compact, idx, stmt, lifting_cache) + # elseif is_known_call(stmt, isa, compact) + # TODO do a similar optimization as `lift_comparison!` for `===` + end + end continue end @@ -743,7 +771,7 @@ function sroa_pass!(ir::IRCode) if ismutabletype(struct_typ) isa(val, SSAValue) || continue let intermediaries = SPCSet() - callback = function (@nospecialize(pi), @nospecialize(ssa)) + callback = function (@nospecialize(x), @nospecialize(ssa)) push!(intermediaries, ssa.id) return false end @@ -753,7 +781,9 @@ function sroa_pass!(ir::IRCode) if defuses === nothing defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}() end - mid, defuse = get!(defuses, def.id, (SPCSet(), SSADefUse())) + mid, defuse = get!(defuses, def.id) do + SPCSet(), SSADefUse() + end if is_setfield push!(defuse.defs, idx) else @@ -775,7 +805,7 @@ function sroa_pass!(ir::IRCode) isempty(leaves) && continue result_t = argextype(SSAValue(idx), compact) - lifted_result = lift_leaves(compact, result_t, field, leaves) + lifted_result = lift_leaves!(compact, leaves, result_t, field, nested_loads) lifted_result === nothing && continue lifted_leaves, any_undef = lifted_result @@ -811,21 +841,25 @@ function sroa_pass!(ir::IRCode) used_ssas = copy(compact.used_ssas) simple_dce!(compact, (x::SSAValue) -> used_ssas[x.id] -= 1) ir = complete(compact) - sroa_mutables!(ir, defuses, used_ssas) - return ir + return sroa_mutables!(ir, defuses, used_ssas, nested_loads) else simple_dce!(compact) return complete(compact) end end -function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int}) +function sroa_mutables!(ir::IRCode, + defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int}, + nested_loads::NestedLoads) # Compute domtree, needed below, now that we have finished compacting the IR. # This needs to be after we iterate through the IR with `IncrementalCompact` # because removing dead blocks can invalidate the domtree. @timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks) - for (idx, (intermediaries, defuse)) in defuses + nested_mloads = NestedLoads() # tracks nested `getfield(getfield(...)::Mutable, ...)::Mutable` + local any_eliminated = false + # NOTE eliminate from innermost definitions, so that we can track elimination of nested `getfield` + for (idx, (intermediaries, defuse)) in sort!(collect(defuses); by=first, rev=true) intermediaries = collect(intermediaries) # Check if there are any uses we did not account for. If so, the variable # escapes and we cannot eliminate the allocation. This works, because we're guaranteed @@ -840,7 +874,19 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse nleaves == nuses_total || continue # Find the type for this allocation defexpr = ir[SSAValue(idx)] - isexpr(defexpr, :new) || continue + isa(defexpr, Expr) || continue + if !isexpr(defexpr, :new) + if is_known_call(defexpr, getfield, ir) + val = defexpr.args[2] + if isa(val, SSAValue) + struct_typ = unwrap_unionall(widenconst(argextype(val, ir))) + if ismutabletype(struct_typ) + record_nested_load!(nested_mloads, idx) + end + end + end + continue + end newidx = idx typ = ir.stmts[newidx][:type] if isa(typ, UnionAll) @@ -910,6 +956,10 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse # Now go through all uses and rewrite them for stmt in du.uses ir[SSAValue(stmt)] = compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, stmt) + if !any_eliminated + any_eliminated |= (is_nested_load(nested_loads, stmt) || + is_nested_load(nested_mloads, stmt)) + end end if !isbitstype(ftyp) if preserve_uses !== nothing @@ -946,6 +996,11 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse @label skip end + if any_eliminated + return sroa_pass!(compact!(ir), false) + else + return ir + end end function form_new_preserves(origex::Expr, intermediates::Vector{Int}, new_preserves::Vector{Any}) diff --git a/test/compiler/irpasses.jl b/test/compiler/irpasses.jl index 05fb890f91848a..2d1d85e3df97fd 100644 --- a/test/compiler/irpasses.jl +++ b/test/compiler/irpasses.jl @@ -90,6 +90,9 @@ iscall(pred::Function, @nospecialize(x)) = Meta.isexpr(x, :call) && pred(x.args[ struct ImmutableXYZ; x; y; z; end mutable struct MutableXYZ; x; y; z; end +struct ImmutableOuter{T}; x::T; y::T; z::T; end +mutable struct MutableOuter{T}; x::T; y::T; z::T; end + # should optimize away very basic cases let src = code_typed1((Any,Any,Any)) do x, y, z xyz = ImmutableXYZ(x, y, z) @@ -198,9 +201,8 @@ let src = code_typed1((Bool,Bool,Any,Any)) do c1, c2, x, y @test any(isnew, src.code) end -# should include a simple alias analysis -struct ImmutableOuter{T}; x::T; y::T; z::T; end -mutable struct MutableOuter{T}; x::T; y::T; z::T; end +# alias analysis +# -------------- let src = code_typed1((Any,Any,Any)) do x, y, z xyz = ImmutableXYZ(x, y, z) outer = ImmutableOuter(xyz, xyz, xyz) @@ -227,9 +229,11 @@ let src = code_typed1((Any,Any,Any)) do x, y, z x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=y=# Core.Argument(4)] end end - -# FIXME our analysis isn't yet so powerful at this moment: may be unable to handle nested objects well -# OK: mutable(immutable(...)) case +# FIXME? in order to handle nested mutable `getfield` calls, we run SROA iteratively until +# any nested mutable `getfield` calls become no longer eliminatable: +# it's probably not the most efficient option and we may want to introduce some sort of +# alias analysis and eliminates all the loads at once. +# mutable(immutable(...)) case let src = code_typed1((Any,Any,Any)) do x, y, z xyz = MutableXYZ(x, y, z) t = (xyz,) @@ -260,21 +264,48 @@ let # this is a simple end to end test case, which demonstrates allocation elimi # compiled code for `simple_sroa`, otherwise everything can be folded even without SROA @test @allocated(simple_sroa(s)) == 0 end -# FIXME: immutable(mutable(...)) case +# immutable(mutable(...)) case let src = code_typed1((Any,Any,Any)) do x, y, z xyz = ImmutableXYZ(x, y, z) outer = MutableOuter(xyz, xyz, xyz) outer.x.x, outer.y.y, outer.z.z end - @test_broken !any(isnew, src.code) + @test !any(isnew, src.code) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=y=# Core.Argument(4)] + end end -# FIXME: mutable(mutable(...)) case +# mutable(mutable(...)) case let src = code_typed1((Any,Any,Any)) do x, y, z xyz = MutableXYZ(x, y, z) outer = MutableOuter(xyz, xyz, xyz) outer.x.x, outer.y.y, outer.z.z end - @test_broken !any(isnew, src.code) + @test !any(isnew, src.code) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=y=# Core.Argument(4)] + end +end +let src = code_typed1((Any,Any,Any)) do x, y, z + xyz = MutableXYZ(x, y, z) + inner = MutableOuter(xyz, xyz, xyz) + outer = MutableOuter(inner, inner, inner) + outer.x.x.x, outer.y.y.y, outer.z.z.z + end + @test !any(isnew, src.code) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=y=# Core.Argument(4)] + end +end +let # NOTE `sroa_mutables!` eliminate from innermost definitions, so that it should be able + # to fully eliminate this insanely nested example + src = code_typed1((Int,)) do x + (Ref(Ref(Ref(Ref(Ref(Ref(Ref(Ref(Ref(Ref((x))))))))))))[][][][][][][][][][] + end + @test !any(isnew, src.code) end # should work nicely with inlining to optimize away a complicated case