From fb5e71559aafba1977910460d38f6d63e3821b2d Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Tue, 7 Dec 2021 14:16:38 +0900 Subject: [PATCH] optimizer: run SROA multiple times to handle more nested loads --- base/compiler/ssair/passes.jl | 105 ++++++++++++++++++++++++++-------- test/compiler/irpasses.jl | 44 ++++++++++---- 2 files changed, 114 insertions(+), 35 deletions(-) diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index 08d3e59a8e3f7..5f4efac520aba 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -329,13 +329,32 @@ struct LiftedValue end const LiftedLeaves = IdDict{Any, Union{Nothing,LiftedValue}} +# NOTE we use `IdSet{Int}` instead of `BitSet` for in these passes since they work on IR after inlining, +# which can be very large sometimes, and program counters in question are often very sparse +const SPCSet = IdSet{Int} + +mutable struct NestedLoads + maybe::Union{Nothing,SPCSet} + NestedLoads() = new(nothing) +end +function record_nested_load!(nested_loads::NestedLoads, pc::Int) + maybe = nested_loads.maybe + maybe === nothing && (maybe = nested_loads.maybe = SPCSet()) + push!(maybe::SPCSet, pc) +end +function is_nested_load(nested_loads::NestedLoads, pc::Int) + maybe = nested_loads.maybe + maybe === nothing && return false + return pc in maybe::SPCSet +end + # try to compute lifted values that can replace `getfield(x, field)` call # where `x` is an immutable struct that are defined at any of `leaves` -function lift_leaves(compact::IncrementalCompact, - @nospecialize(result_t), field::Int, leaves::Vector{Any}) +function lift_leaves!(compact::IncrementalCompact, leaves::Vector{Any}, + @nospecialize(result_t), field::Int, nested_loads::NestedLoads) # For every leaf, the lifted value lifted_leaves = LiftedLeaves() - maybe_undef = false + local maybe_undef = false for leaf in leaves cache_key = leaf if isa(leaf, AnySSAValue) @@ -382,11 +401,19 @@ function lift_leaves(compact::IncrementalCompact, ocleaf = simple_walk(compact, ocleaf) end ocdef, _ = walk_to_def(compact, ocleaf) - if isexpr(ocdef, :new_opaque_closure) && isa(field, Int) && 1 ≤ field ≤ length(ocdef.args)-5 + if isexpr(ocdef, :new_opaque_closure) && 1 ≤ field ≤ length(ocdef.args)-5 lift_arg!(compact, leaf, cache_key, ocdef, 5+field, lifted_leaves) continue end return nothing + elseif isa(def, Expr) && is_known_call(def, getfield, compact) + if isa(leaf, SSAValue) + struct_typ = unwrap_unionall(widenconst(compact_exprtype(compact, def.args[2]))) + if ismutabletype(struct_typ) + record_nested_load!(nested_loads, leaf.id) + end + end + return nothing else typ = compact_exprtype(compact, leaf) if !isa(typ, Const) @@ -611,10 +638,6 @@ function perform_lifting!(compact::IncrementalCompact, return stmt_val # N.B. should never happen end -# NOTE we use `IdSet{Int}` instead of `BitSet` for in these passes since they work on IR after inlining, -# which can be very large sometimes, and program counters in question are often very sparse -const SPCSet = IdSet{Int} - """ sroa_pass!(ir::IRCode) -> newir::IRCode @@ -633,10 +656,11 @@ its argument). In a case when all usages are fully eliminated, `struct` allocation may also be erased as a result of succeeding dead code elimination. """ -function sroa_pass!(ir::IRCode) +function sroa_pass!(ir::IRCode, optional_opts::Bool = true) compact = IncrementalCompact(ir) defuses = nothing # will be initialized once we encounter mutability in order to reduce dynamic allocations lifting_cache = IdDict{Pair{AnySSAValue, Any}, AnySSAValue}() + nested_loads = NestedLoads() # tracks nested `getfield(getfield(...)::Mutable, ...)::Immutable` for ((_, idx), stmt) in compact # check whether this statement is `getfield` / `setfield!` (or other "interesting" statement) isa(stmt, Expr) || continue @@ -691,7 +715,9 @@ function sroa_pass!(ir::IRCode) if defuses === nothing defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}() end - mid, defuse = get!(defuses, defidx, (SPCSet(), SSADefUse())) + mid, defuse = get!(defuses, defidx) do + SPCSet(), SSADefUse() + end push!(defuse.ccall_preserve_uses, idx) union!(mid, intermediaries) end @@ -704,16 +730,17 @@ function sroa_pass!(ir::IRCode) compact[idx] = new_expr end continue - # TODO: This isn't the best place to put these - elseif is_known_call(stmt, typeassert, compact) - canonicalize_typeassert!(compact, idx, stmt) - continue - elseif is_known_call(stmt, (===), compact) - lift_comparison!(compact, idx, stmt, lifting_cache) - continue - # elseif is_known_call(stmt, isa, compact) - # TODO do a similar optimization as `lift_comparison!` for `===` else + if optional_opts + # TODO: This isn't the best place to put these + if is_known_call(stmt, typeassert, compact) + canonicalize_typeassert!(compact, idx, stmt) + elseif is_known_call(stmt, (===), compact) + lift_comparison!(compact, idx, stmt, lifting_cache) + # elseif is_known_call(stmt, isa, compact) + # TODO do a similar optimization as `lift_comparison!` for `===` + end + end continue end @@ -749,7 +776,9 @@ function sroa_pass!(ir::IRCode) if defuses === nothing defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}() end - mid, defuse = get!(defuses, def.id, (SPCSet(), SSADefUse())) + mid, defuse = get!(defuses, def.id) do + SPCSet(), SSADefUse() + end if is_setfield push!(defuse.defs, idx) else @@ -771,7 +800,7 @@ function sroa_pass!(ir::IRCode) isempty(leaves) && continue result_t = compact_exprtype(compact, SSAValue(idx)) - lifted_result = lift_leaves(compact, result_t, field, leaves) + lifted_result = lift_leaves!(compact, leaves, result_t, field, nested_loads) lifted_result === nothing && continue lifted_leaves, any_undef = lifted_result @@ -807,20 +836,23 @@ function sroa_pass!(ir::IRCode) used_ssas = copy(compact.used_ssas) simple_dce!(compact, (x::SSAValue) -> used_ssas[x.id] -= 1) ir = complete(compact) - sroa_mutables!(ir, defuses, used_ssas) - return ir + return sroa_mutables!(ir, defuses, used_ssas, nested_loads) else simple_dce!(compact) return complete(compact) end end -function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int}) +function sroa_mutables!(ir::IRCode, + defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int}, + nested_loads::NestedLoads) # Compute domtree, needed below, now that we have finished compacting the IR. # This needs to be after we iterate through the IR with `IncrementalCompact` # because removing dead blocks can invalidate the domtree. @timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks) + nested_mloads = NestedLoads() # tracks nested `getfield(getfield(...)::Mutable, ...)::Mutable` + local any_eliminated = any_meliminated = false for (idx, (intermediaries, defuse)) in defuses intermediaries = collect(intermediaries) # Check if there are any uses we did not account for. If so, the variable @@ -836,7 +868,19 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse nleaves == nuses_total || continue # Find the type for this allocation defexpr = ir[SSAValue(idx)] - isexpr(defexpr, :new) || continue + isa(defexpr, Expr) || continue + if !isexpr(defexpr, :new) + if is_known_call(defexpr, getfield, ir) + val = defexpr.args[2] + if isa(val, SSAValue) + struct_typ = unwrap_unionall(widenconst(argextype(val, ir))) + if ismutabletype(struct_typ) + record_nested_load!(nested_mloads, idx) + end + end + end + continue + end newidx = idx typ = ir.stmts[newidx][:type] if isa(typ, UnionAll) @@ -900,6 +944,12 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse # Now go through all uses and rewrite them for stmt in du.uses ir[SSAValue(stmt)] = compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, stmt) + if !any_eliminated + any_eliminated |= is_nested_load(nested_loads, stmt) + end + if !any_meliminated + any_meliminated |= is_nested_load(nested_mloads, stmt) + end end if !isbitstype(ftyp) if preserve_uses !== nothing @@ -938,6 +988,11 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse @label skip end + if any_eliminated || any_meliminated + return sroa_pass!(compact!(ir), false) + else + return ir + end end """ diff --git a/test/compiler/irpasses.jl b/test/compiler/irpasses.jl index 2151d938b525f..b4d46ebf4258a 100644 --- a/test/compiler/irpasses.jl +++ b/test/compiler/irpasses.jl @@ -90,6 +90,9 @@ iscall(pred::Function, @nospecialize(x)) = Meta.isexpr(x, :call) && pred(x.args[ struct ImmutableXYZ; x; y; z; end mutable struct MutableXYZ; x; y; z; end +struct ImmutableOuter{T}; x::T; y::T; z::T; end +mutable struct MutableOuter{T}; x::T; y::T; z::T; end + # should optimize away very basic cases let src = code_typed1((Any,Any,Any)) do x, y, z xyz = ImmutableXYZ(x, y, z) @@ -198,9 +201,8 @@ let src = code_typed1((Bool,Bool,Any,Any)) do c1, c2, x, y @test any(isnew, src.code) end -# should include a simple alias analysis -struct ImmutableOuter{T}; x::T; y::T; z::T; end -mutable struct MutableOuter{T}; x::T; y::T; z::T; end +# alias analysis +# -------------- let src = code_typed1((Any,Any,Any)) do x, y, z xyz = ImmutableXYZ(x, y, z) outer = ImmutableOuter(xyz, xyz, xyz) @@ -227,9 +229,11 @@ let src = code_typed1((Any,Any,Any)) do x, y, z x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=y=# Core.Argument(4)] end end - -# FIXME our analysis isn't yet so powerful at this moment: may be unable to handle nested objects well -# OK: mutable(immutable(...)) case +# FIXME? in order to handle nested mutable `getfield` calls, we run SROA iteratively until +# any nested mutable `getfield` calls become no longer eliminatable: +# it's probably not the most efficient option and we may want to introduce some sort of +# alias analysis and eliminates all the loads at once. +# mutable(immutable(...)) case let src = code_typed1((Any,Any,Any)) do x, y, z xyz = MutableXYZ(x, y, z) t = (xyz,) @@ -260,21 +264,41 @@ let # this is a simple end to end test case, which demonstrates allocation elimi # compiled code for `simple_sroa`, otherwise everything can be folded even without SROA @test @allocated(simple_sroa(s)) == 0 end -# FIXME: immutable(mutable(...)) case +# immutable(mutable(...)) case let src = code_typed1((Any,Any,Any)) do x, y, z xyz = ImmutableXYZ(x, y, z) outer = MutableOuter(xyz, xyz, xyz) outer.x.x, outer.y.y, outer.z.z end - @test_broken !any(isnew, src.code) + @test !any(isnew, src.code) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=y=# Core.Argument(4)] + end end -# FIXME: mutable(mutable(...)) case +# mutable(mutable(...)) case let src = code_typed1((Any,Any,Any)) do x, y, z xyz = MutableXYZ(x, y, z) outer = MutableOuter(xyz, xyz, xyz) outer.x.x, outer.y.y, outer.z.z end - @test_broken !any(isnew, src.code) + @test !any(isnew, src.code) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=y=# Core.Argument(4)] + end +end +let src = code_typed1((Any,Any,Any)) do x, y, z + xyz = MutableXYZ(x, y, z) + inner = MutableOuter(xyz, xyz, xyz) + outer = MutableOuter(inner, inner, inner) + outer.x.x.x, outer.y.y.y, outer.z.z.z + end + @test !any(isnew, src.code) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=y=# Core.Argument(4)] + end end # should work nicely with inlining to optimize away a complicated case