From e0c7cdf91bb87fbdd4c4ae381719e939038d9bf1 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Tue, 30 Nov 2021 21:38:22 +0900 Subject: [PATCH] optimizer: run SROA multiple times to handle more nested loads --- base/compiler/ssair/passes.jl | 230 ++++++++++++++++++++-------------- test/compiler/irpasses.jl | 44 +++++-- 2 files changed, 171 insertions(+), 103 deletions(-) diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index c7004b3eb5c63..5fd38b8d63929 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -323,50 +323,39 @@ function is_getfield_captures(@nospecialize(def), compact::IncrementalCompact) return oc ⊑ Core.OpaqueClosure end +struct LiftedValue + x + LiftedValue(@nospecialize x) = new(x) +end +const LiftedValues = IdDict{Any, Union{Nothing,LiftedValue}} + +# NOTE we use `IdSet{Int}` instead of `BitSet` for in these passes since they work on IR after inlining, +# which can be very large sometimes, and program counters in question are often very sparse +const SPCSet = IdSet{Int} + +mutable struct MayAlloc{S} + maybe::Union{Nothing,S} + MayAlloc{S}() where S = new{S}(nothing) +end +allocated!(x::MayAlloc{S}) where S = (x′ = x.maybe; x′ === nothing ? @noinline(x.maybe = S()) : x′) +push!(x::MayAlloc{<:IdSet}, @nospecialize(item)) = (push!(allocated!(x), item); x) +in(@nospecialize(item), x::MayAlloc{<:IdSet}) = (x′ = x.maybe; x′ !== nothing && item in x′) +get!(default::Callable, d::MayAlloc{<:IdDict}, @nospecialize(key)) = get!(default, allocated!(d), key) +setindex!(d::MayAlloc{<:IdDict}, @nospecialize(val), @nospecialize(key)) = setindex!(allocated!(d), val, key) + # try to compute lifted values that can replace `getfield(x, field)` call # where `x` is an immutable struct that are defined at any of `leaves` -function lift_leaves(compact::IncrementalCompact, - @nospecialize(result_t), field::Int, leaves::Vector{Any}) +function lift_leaves!(compact::IncrementalCompact, leaves::Vector{Any}, + @nospecialize(result_t), field::Int, nested_loads::MayAlloc{SPCSet}) # For every leaf, the lifted value - lifted_leaves = IdDict{Any, Union{Nothing,LiftedValue}}() - maybe_undef = false + lifted_leaves = MayAlloc{LiftedValues}() + local maybe_undef = false for leaf in leaves - leaf_key = leaf + cache_key = leaf if isa(leaf, AnySSAValue) - function lift_arg(ref::Core.Compiler.UseRef) - lifted = ref[] - if is_old(compact, leaf) && isa(lifted, SSAValue) - lifted = OldSSAValue(lifted.id) - end - if isa(lifted, GlobalRef) || isa(lifted, Expr) - lifted = insert_node!(compact, leaf, effect_free(NewInstruction(lifted, compact_exprtype(compact, lifted)))) - ref[] = lifted - (isa(leaf, SSAValue) && (leaf.id < compact.result_idx)) && push!(compact.late_fixup, leaf.id) - end - lifted_leaves[leaf_key] = LiftedValue(lifted) - nothing - end - function walk_leaf(@nospecialize(leaf)) - if isa(leaf, OldSSAValue) && already_inserted(compact, leaf) - leaf = compact.ssa_rename[leaf.id] - if isa(leaf, AnySSAValue) - leaf = simple_walk(compact, leaf) - end - if isa(leaf, AnySSAValue) - def = compact[leaf] - else - def = leaf - end - elseif isa(leaf, AnySSAValue) - def = compact[leaf] - else - def = leaf - end - return Pair{Any, Any}(def, leaf) - end - (def, leaf) = walk_leaf(leaf) - if is_tuple_call(compact, def) && 1 <= field < length(def.args) - lift_arg(UseRef(def, 1 + field)) + (def, leaf) = walk_to_def(compact, leaf) + if is_tuple_call(compact, def) && 1 ≤ field < length(def.args) + lift_arg!(compact, leaf, cache_key, def, 1+field, lifted_leaves) continue elseif isexpr(def, :new) typ = widenconst(types(compact)[leaf]) @@ -375,7 +364,7 @@ function lift_leaves(compact::IncrementalCompact, end (isa(typ, DataType) && !isabstracttype(typ)) || return nothing @assert !ismutabletype(typ) - if length(def.args) < 1 + field + if length(def.args) < 1+field if field > fieldcount(typ) return nothing end @@ -384,7 +373,7 @@ function lift_leaves(compact::IncrementalCompact, # On this branch, this will be a guaranteed UndefRefError. # We use the regular undef mechanic to lift this to a boolean slot maybe_undef = true - lifted_leaves[leaf_key] = nothing + lifted_leaves[cache_key] = nothing continue end return nothing @@ -398,16 +387,7 @@ function lift_leaves(compact::IncrementalCompact, end compact[leaf] = def end - lifted = def.args[1+field] - if is_old(compact, leaf) && isa(lifted, SSAValue) - lifted = OldSSAValue(lifted.id) - end - if isa(lifted, GlobalRef) || isa(lifted, Expr) - lifted = insert_node!(compact, leaf, effect_free(NewInstruction(lifted, compact_exprtype(compact, lifted)))) - def.args[1+field] = lifted - (isa(leaf, SSAValue) && (leaf.id < compact.result_idx)) && push!(compact.late_fixup, leaf.id) - end - lifted_leaves[leaf_key] = LiftedValue(lifted) + lift_arg!(compact, leaf, cache_key, def, 1+field, lifted_leaves) continue elseif is_getfield_captures(def, compact) # Walk to new_opaque_closure @@ -415,12 +395,20 @@ function lift_leaves(compact::IncrementalCompact, if isa(ocleaf, AnySSAValue) ocleaf = simple_walk(compact, ocleaf) end - ocdef, _ = walk_leaf(ocleaf) - if isexpr(ocdef, :new_opaque_closure) && isa(field, Int) && 1 <= field <= length(ocdef.args)-5 - lift_arg(UseRef(ocdef, 5 + field)) + ocdef, _ = walk_to_def(compact, ocleaf) + if isexpr(ocdef, :new_opaque_closure) && 1 <= field <= length(ocdef.args)-5 + lift_arg!(compact, leaf, cache_key, ocdef, 5+field, lifted_leaves) continue end return nothing + elseif isa(def, Expr) && is_known_call(def, getfield, compact) + if isa(leaf, SSAValue) + struct_typ = unwrap_unionall(widenconst(compact_exprtype(compact, def.args[2]))) + if ismutabletype(struct_typ) + push!(nested_loads, leaf.id) + end + end + return nothing else typ = compact_exprtype(compact, leaf) if !isa(typ, Const) @@ -445,18 +433,57 @@ function lift_leaves(compact::IncrementalCompact, else return nothing end - elseif isa(leaf, Union{Argument, Expr}) + elseif isa(leaf, Argument) || isa(leaf, Expr) return nothing end ismutable(leaf) && return nothing isdefined(leaf, field) || return nothing val = getfield(leaf, field) is_inlineable_constant(val) || return nothing - lifted_leaves[leaf_key] = LiftedValue(quoted(val)) + lifted_leaves[cache_key] = LiftedValue(quoted(val)) end + lifted_leaves = lifted_leaves.maybe + lifted_leaves === nothing && return nothing return lifted_leaves, maybe_undef end +function walk_to_def(compact::IncrementalCompact, @nospecialize(leaf)) + if isa(leaf, OldSSAValue) && already_inserted(compact, leaf) + leaf = compact.ssa_rename[leaf.id] + if isa(leaf, AnySSAValue) + leaf = simple_walk(compact, leaf) + end + if isa(leaf, AnySSAValue) + def = compact[leaf] + else + def = leaf + end + elseif isa(leaf, AnySSAValue) + def = compact[leaf] + else + def = leaf + end + return Pair{Any, Any}(def, leaf) +end + +function lift_arg!( + compact::IncrementalCompact, @nospecialize(leaf), @nospecialize(cache_key), + stmt::Expr, argidx::Int, lifted_leaves::MayAlloc{LiftedValues}) + lifted = stmt.args[argidx] + if is_old(compact, leaf) && isa(lifted, SSAValue) + lifted = OldSSAValue(lifted.id) + end + if isa(lifted, GlobalRef) || isa(lifted, Expr) + lifted = insert_node!(compact, leaf, effect_free(NewInstruction(lifted, compact_exprtype(compact, lifted)))) + stmt.args[argidx] = lifted + if isa(leaf, SSAValue) && leaf.id < compact.result_idx + push!(compact.late_fixup, leaf.id) + end + end + lifted_leaves[cache_key] = LiftedValue(lifted) + nothing +end + make_MaybeUndef(@nospecialize(typ)) = isa(typ, MaybeUndef) ? typ : MaybeUndef(typ) """ @@ -500,22 +527,21 @@ function lift_comparison!(compact::IncrementalCompact, length(leaves) ≤ 1 && return # bail out if we don't have multiple leaves # Let's check if we evaluate the comparison for each one of the leaves - lifted_leaves = nothing + lifted_leaves = MayAlloc{LiftedValues}() for leaf in leaves r = egal_tfunc(compact_exprtype(compact, leaf), cmp) if isa(r, Const) - if lifted_leaves === nothing - lifted_leaves = IdDict{Any, Union{Nothing,LiftedValue}}() - end lifted_leaves[leaf] = LiftedValue(r.val) else return # TODO In some cases it might be profitable to hoist the === here end end + lifted_leaves = lifted_leaves.maybe + lifted_leaves === nothing && return # should never happen lifted_val = perform_lifting!(compact, visited_phinodes, cmp, lifting_cache, Bool, - lifted_leaves::IdDict{Any, Union{Nothing,LiftedValue}}, val)::LiftedValue + lifted_leaves, val)::LiftedValue compact[idx] = lifted_val.x end @@ -532,15 +558,10 @@ function is_old(compact, @nospecialize(old_node_ssa)) !already_inserted(compact, old_node_ssa) end -struct LiftedValue - x - LiftedValue(@nospecialize x) = new(x) -end - function perform_lifting!(compact::IncrementalCompact, visited_phinodes::Vector{AnySSAValue}, @nospecialize(cache_key), lifting_cache::IdDict{Pair{AnySSAValue, Any}, AnySSAValue}, - @nospecialize(result_t), lifted_leaves::IdDict{Any, Union{Nothing,LiftedValue}}, @nospecialize(stmt_val)) + @nospecialize(result_t), lifted_leaves::LiftedValues, @nospecialize(stmt_val)) reverse_mapping = IdDict{AnySSAValue, Int}(ssa => id for (id, ssa) in enumerate(visited_phinodes)) # Insert PhiNodes @@ -613,10 +634,6 @@ function perform_lifting!(compact::IncrementalCompact, return stmt_val # N.B. should never happen end -# NOTE we use `IdSet{Int}` instead of `BitSet` for in these passes since they work on IR after inlining, -# which can be very large sometimes, and program counters in question are often very sparse -const SPCSet = IdSet{Int} - """ sroa_pass!(ir::IRCode) -> newir::IRCode @@ -635,10 +652,11 @@ its argument). In a case when all usages are fully eliminated, `struct` allocation may also be erased as a result of succeeding dead code elimination. """ -function sroa_pass!(ir::IRCode) +function sroa_pass!(ir::IRCode, optional_opts::Bool = true) compact = IncrementalCompact(ir) - defuses = nothing # will be initialized once we encounter mutability in order to reduce dynamic allocations + defuses = MayAlloc{IdDict{Int, Tuple{SPCSet, SSADefUse}}}() # tracks mutable-related information lifting_cache = IdDict{Pair{AnySSAValue, Any}, AnySSAValue}() + nested_loads = MayAlloc{SPCSet}() # tracks nested `getfield(getfield(...)::Mutable, ...)::Immutable` for ((_, idx), stmt) in compact # check whether this statement is `getfield` / `setfield!` (or other "interesting" statement) isa(stmt, Expr) || continue @@ -690,10 +708,9 @@ function sroa_pass!(ir::IRCode) else continue end - if defuses === nothing - defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}() + mid, defuse = get!(defuses, defidx) do + SPCSet(), SSADefUse() end - mid, defuse = get!(defuses, defidx, (SPCSet(), SSADefUse())) push!(defuse.ccall_preserve_uses, idx) union!(mid, intermediaries) end @@ -706,16 +723,17 @@ function sroa_pass!(ir::IRCode) compact[idx] = new_expr end continue - # TODO: This isn't the best place to put these - elseif is_known_call(stmt, typeassert, compact) - canonicalize_typeassert!(compact, idx, stmt) - continue - elseif is_known_call(stmt, (===), compact) - lift_comparison!(compact, idx, stmt, lifting_cache) - continue - # elseif is_known_call(stmt, isa, compact) - # TODO do a similar optimization as `lift_comparison!` for `===` else + if optional_opts + # TODO: This isn't the best place to put these + if is_known_call(stmt, typeassert, compact) + canonicalize_typeassert!(compact, idx, stmt) + elseif is_known_call(stmt, (===), compact) + lift_comparison!(compact, idx, stmt, lifting_cache) + # elseif is_known_call(stmt, isa, compact) + # TODO do a similar optimization as `lift_comparison!` for `===` + end + end continue end @@ -748,10 +766,9 @@ function sroa_pass!(ir::IRCode) def = simple_walk(compact, val, callback) # Mutable stuff here isa(def, SSAValue) || continue - if defuses === nothing - defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}() + mid, defuse = get!(defuses, def.id) do + SPCSet(), SSADefUse() end - mid, defuse = get!(defuses, def.id, (SPCSet(), SSADefUse())) if is_setfield push!(defuse.defs, idx) else @@ -773,7 +790,7 @@ function sroa_pass!(ir::IRCode) isempty(leaves) && continue result_t = compact_exprtype(compact, SSAValue(idx)) - lifted_result = lift_leaves(compact, result_t, field, leaves) + lifted_result = lift_leaves!(compact, leaves, result_t, field, nested_loads) lifted_result === nothing && continue lifted_leaves, any_undef = lifted_result @@ -800,6 +817,7 @@ function sroa_pass!(ir::IRCode) end non_dce_finish!(compact) + defuses = defuses.maybe if defuses !== nothing # now go through analyzed mutable structs and see which ones we can eliminate # NOTE copy the use count here, because `simple_dce!` may modify it and we need it @@ -809,20 +827,23 @@ function sroa_pass!(ir::IRCode) used_ssas = copy(compact.used_ssas) simple_dce!(compact, (x::SSAValue) -> used_ssas[x.id] -= 1) ir = complete(compact) - sroa_mutables!(ir, defuses, used_ssas) - return ir + return sroa_mutables!(ir, defuses, used_ssas, nested_loads) else simple_dce!(compact) return complete(compact) end end -function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int}) +function sroa_mutables!(ir::IRCode, + defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int}, + nested_loads::MayAlloc{SPCSet}) # Compute domtree, needed below, now that we have finished compacting the IR. # This needs to be after we iterate through the IR with `IncrementalCompact` # because removing dead blocks can invalidate the domtree. @timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks) + nested_mloads = MayAlloc{SPCSet}() # tracks nested `getfield(getfield(...)::Mutable, ...)::Mutable` + local any_eliminated = any_meliminated = false for (idx, (intermediaries, defuse)) in defuses intermediaries = collect(intermediaries) # Check if there are any uses we did not account for. If so, the variable @@ -838,7 +859,19 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse nleaves == nuses_total || continue # Find the type for this allocation defexpr = ir[SSAValue(idx)] - isexpr(defexpr, :new) || continue + isa(defexpr, Expr) || continue + if !isexpr(defexpr, :new) + if is_known_call(defexpr, getfield, ir) + val = defexpr.args[2] + if isa(val, SSAValue) + struct_typ = unwrap_unionall(widenconst(argextype(val, ir))) + if ismutabletype(struct_typ) + push!(nested_mloads, idx) + end + end + end + continue + end newidx = idx typ = ir.stmts[newidx][:type] if isa(typ, UnionAll) @@ -902,6 +935,12 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse # Now go through all uses and rewrite them for stmt in du.uses ir[SSAValue(stmt)] = compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, stmt) + if !any_eliminated + any_eliminated |= stmt in nested_loads + end + if !any_meliminated + any_meliminated |= stmt in nested_mloads + end end if !isbitstype(ftyp) if preserve_uses !== nothing @@ -940,6 +979,11 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse @label skip end + if any_eliminated || any_meliminated + return sroa_pass!(compact!(ir), false) + else + return ir + end end """ diff --git a/test/compiler/irpasses.jl b/test/compiler/irpasses.jl index 2151d938b525f..b4d46ebf4258a 100644 --- a/test/compiler/irpasses.jl +++ b/test/compiler/irpasses.jl @@ -90,6 +90,9 @@ iscall(pred::Function, @nospecialize(x)) = Meta.isexpr(x, :call) && pred(x.args[ struct ImmutableXYZ; x; y; z; end mutable struct MutableXYZ; x; y; z; end +struct ImmutableOuter{T}; x::T; y::T; z::T; end +mutable struct MutableOuter{T}; x::T; y::T; z::T; end + # should optimize away very basic cases let src = code_typed1((Any,Any,Any)) do x, y, z xyz = ImmutableXYZ(x, y, z) @@ -198,9 +201,8 @@ let src = code_typed1((Bool,Bool,Any,Any)) do c1, c2, x, y @test any(isnew, src.code) end -# should include a simple alias analysis -struct ImmutableOuter{T}; x::T; y::T; z::T; end -mutable struct MutableOuter{T}; x::T; y::T; z::T; end +# alias analysis +# -------------- let src = code_typed1((Any,Any,Any)) do x, y, z xyz = ImmutableXYZ(x, y, z) outer = ImmutableOuter(xyz, xyz, xyz) @@ -227,9 +229,11 @@ let src = code_typed1((Any,Any,Any)) do x, y, z x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=y=# Core.Argument(4)] end end - -# FIXME our analysis isn't yet so powerful at this moment: may be unable to handle nested objects well -# OK: mutable(immutable(...)) case +# FIXME? in order to handle nested mutable `getfield` calls, we run SROA iteratively until +# any nested mutable `getfield` calls become no longer eliminatable: +# it's probably not the most efficient option and we may want to introduce some sort of +# alias analysis and eliminates all the loads at once. +# mutable(immutable(...)) case let src = code_typed1((Any,Any,Any)) do x, y, z xyz = MutableXYZ(x, y, z) t = (xyz,) @@ -260,21 +264,41 @@ let # this is a simple end to end test case, which demonstrates allocation elimi # compiled code for `simple_sroa`, otherwise everything can be folded even without SROA @test @allocated(simple_sroa(s)) == 0 end -# FIXME: immutable(mutable(...)) case +# immutable(mutable(...)) case let src = code_typed1((Any,Any,Any)) do x, y, z xyz = ImmutableXYZ(x, y, z) outer = MutableOuter(xyz, xyz, xyz) outer.x.x, outer.y.y, outer.z.z end - @test_broken !any(isnew, src.code) + @test !any(isnew, src.code) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=y=# Core.Argument(4)] + end end -# FIXME: mutable(mutable(...)) case +# mutable(mutable(...)) case let src = code_typed1((Any,Any,Any)) do x, y, z xyz = MutableXYZ(x, y, z) outer = MutableOuter(xyz, xyz, xyz) outer.x.x, outer.y.y, outer.z.z end - @test_broken !any(isnew, src.code) + @test !any(isnew, src.code) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=y=# Core.Argument(4)] + end +end +let src = code_typed1((Any,Any,Any)) do x, y, z + xyz = MutableXYZ(x, y, z) + inner = MutableOuter(xyz, xyz, xyz) + outer = MutableOuter(inner, inner, inner) + outer.x.x.x, outer.y.y.y, outer.z.z.z + end + @test !any(isnew, src.code) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=y=# Core.Argument(4)] + end end # should work nicely with inlining to optimize away a complicated case