Skip to content

Commit

Permalink
optimizer: enhance SROA, handle partially-initialized allocations
Browse files Browse the repository at this point in the history
During adding more test cases for our SROA pass, I found our SROA doesn't
handle allocation sites with uninitialized fields at all.
This commit is based on #42833 and tries to handle such "unsafe" allocations,
if there are safe `setfield!` definitions.

For example, this commit allows the allocation `r = Ref{Int}()` to be
eliminated in the following example (adapted from <https://hackmd.io/bZz8k6SHQQuNUW-Vs7rqfw?view>):
```julia
julia> code_typed() do
           r = Ref{Int}()
           r[] = 42
           b = sin(r[])
           return b
       end |> only
```

This commit comes with a plenty of basic test cases for our SROA pass also.
  • Loading branch information
aviatesk committed Nov 1, 2021
1 parent a121721 commit a2742c9
Show file tree
Hide file tree
Showing 3 changed files with 228 additions and 38 deletions.
2 changes: 1 addition & 1 deletion base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ function run_passes(ci::CodeInfo, sv::OptimizationState)
@timeit "Inlining" ir = ssa_inlining_pass!(ir, ir.linetable, sv.inlining, ci.propagate_inbounds)
# @timeit "verify 2" verify_ir(ir)
@timeit "compact 2" ir = compact!(ir)
@timeit "SROA" ir = getfield_elim_pass!(ir)
@timeit "SROA" ir = sroa_pass!(ir)
@timeit "ADCE" ir = adce_pass!(ir)
@timeit "type lift" ir = type_lift_pass!(ir)
@timeit "compact 3" ir = compact!(ir)
Expand Down
80 changes: 45 additions & 35 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,22 @@ function compute_value_for_block(ir::IRCode, domtree::DomTree, allblocks::Vector
end

function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse, phinodes::IdDict{Int, SSAValue}, fidx::Int, use_idx::Int)
# Find the first dominating def
def, stmtblock, curblock = find_def_for_use(ir, domtree, allblocks, du, use_idx)
if def == 0
if !haskey(phinodes, curblock)
# If this happens, we need to search the predecessors for defs. Which
# one doesn't matter - if it did, we'd have had a phinode
return compute_value_for_block(ir, domtree, allblocks, du, phinodes, fidx, first(ir.cfg.blocks[stmtblock].preds))
end
# The use is the phinode
return phinodes[curblock]
else
return val_for_def_expr(ir, def, fidx)
end
end

# find the first dominating def for the given use
function find_def_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse, use_idx::Int)
stmtblock = block_for_inst(ir.cfg, use_idx)
curblock = find_curblock(domtree, allblocks, stmtblock)
local def = 0
Expand All @@ -90,17 +105,7 @@ function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{I
end
end
end
if def == 0
if !haskey(phinodes, curblock)
# If this happens, we need to search the predecessors for defs. Which
# one doesn't matter - if it did, we'd have had a phinode
return compute_value_for_block(ir, domtree, allblocks, du, phinodes, fidx, first(ir.cfg.blocks[stmtblock].preds))
end
# The use is the phinode
return phinodes[curblock]
else
return val_for_def_expr(ir, def, fidx)
end
return def, stmtblock, curblock
end

function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#),
Expand Down Expand Up @@ -538,7 +543,7 @@ function perform_lifting!(compact::IncrementalCompact,
end

"""
getfield_elim_pass!(ir::IRCode) -> newir::IRCode
sroa_pass!(ir::IRCode) -> newir::IRCode
`getfield` elimination pass, a.k.a. Scalar Replacements of Aggregates optimization.
Expand All @@ -555,7 +560,7 @@ its argument).
In a case when all usages are fully eliminated, `struct` allocation may also be erased as
a result of dead code elimination.
"""
function getfield_elim_pass!(ir::IRCode)
function sroa_pass!(ir::IRCode)
compact = IncrementalCompact(ir)
defuses = IdDict{Int, Tuple{IdSet{Int}, SSADefUse}}()
lifting_cache = IdDict{Pair{AnySSAValue, Any}, AnySSAValue}()
Expand Down Expand Up @@ -784,7 +789,6 @@ function getfield_elim_pass!(ir::IRCode)
typ = typ::DataType
# Partition defuses by field
fielddefuse = SSADefUse[SSADefUse() for _ = 1:fieldcount(typ)]
ok = true
for use in defuse.uses
stmt = ir[SSAValue(use)]
# We may have discovered above that this use is dead
Expand All @@ -793,47 +797,52 @@ function getfield_elim_pass!(ir::IRCode)
# the use in that case.
stmt === nothing && continue
field = try_compute_fieldidx_stmt(compact, stmt::Expr, typ)
field === nothing && (ok = false; break)
field === nothing && @goto skip
push!(fielddefuse[field].uses, use)
end
ok || continue
for use in defuse.defs
field = try_compute_fieldidx_stmt(compact, ir[SSAValue(use)]::Expr, typ)
field === nothing && (ok = false; break)
field === nothing && @goto skip
push!(fielddefuse[field].defs, use)
end
ok || continue
# Check that the defexpr has defined values for all the fields
# we're accessing. In the future, we may want to relax this,
# but we should come up with semantics for well defined semantics
# for uninitialized fields first.
for (fidx, du) in pairs(fielddefuse)
ndefuse = length(fielddefuse)
blocks = Vector{Tuple{#=phiblocks=# Vector{Int}, #=allblocks=# Vector{Int}}}(undef, ndefuse)
for fidx in 1:ndefuse
du = fielddefuse[fidx]
isempty(du.uses) && continue
push!(du.defs, idx)
ldu = compute_live_ins(ir.cfg, du)
phiblocks = Int[]
if !isempty(ldu.live_in_bbs)
phiblocks = idf(ir.cfg, ldu, domtree)
end
allblocks = sort(vcat(phiblocks, ldu.def_bbs))
blocks[fidx] = phiblocks, allblocks
if fidx + 1 > length(defexpr.args)
ok = false
break
for use in du.uses
def = find_def_for_use(ir, domtree, allblocks, du, use)[1]
(def == 0 || def == idx) && @goto skip
end
end
end
ok || continue
preserve_uses = IdDict{Int, Vector{Any}}((idx=>Any[] for idx in IdSet{Int}(defuse.ccall_preserve_uses)))
# Everything accounted for. Go field by field and perform idf
for (fidx, du) in pairs(fielddefuse)
for fidx in 1:ndefuse
du = fielddefuse[fidx]
ftyp = fieldtype(typ, fidx)
if !isempty(du.uses)
push!(du.defs, idx)
ldu = compute_live_ins(ir.cfg, du)
phiblocks = Int[]
if !isempty(ldu.live_in_bbs)
phiblocks = idf(ir.cfg, ldu, domtree)
end
phiblocks, allblocks = blocks[fidx]
phinodes = IdDict{Int, SSAValue}()
for b in phiblocks
n = PhiNode()
phinodes[b] = insert_node!(ir, first(ir.cfg.blocks[b].stmts),
NewInstruction(n, ftyp))
end
# Now go through all uses and rewrite them
allblocks = sort(vcat(phiblocks, ldu.def_bbs))
for stmt in du.uses
ir[SSAValue(stmt)] = compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, stmt)
end
Expand All @@ -855,7 +864,6 @@ function getfield_elim_pass!(ir::IRCode)
stmt == idx && continue
ir[SSAValue(stmt)] = nothing
end
continue
end
isempty(defuse.ccall_preserve_uses) && continue
push!(intermediaries, idx)
Expand All @@ -870,6 +878,8 @@ function getfield_elim_pass!(ir::IRCode)
old_preserves..., new_preserves...)
ir[SSAValue(use)] = new_expr
end

@label skip
end

return ir
Expand Down Expand Up @@ -919,14 +929,14 @@ In addition to a simple DCE for unused values and allocations,
this pass also nullifies `typeassert` calls that can be proved to be no-op,
in order to allow LLVM to emit simpler code down the road.
Note that this pass is more effective after SROA optimization (i.e. `getfield_elim_pass!`),
Note that this pass is more effective after SROA optimization (i.e. `sroa_pass!`),
since SROA often allows this pass to:
- eliminate allocation of object whose field references are all replaced with scalar values, and
- nullify `typeassert` call whose first operand has been replaced with a scalar value
(, which may have introduced new type information that inference did not understand)
Also note that currently this pass _needs_ to run after `getfield_elim_pass!`, because
the `typeassert` elimination depends on the transformation within `getfield_elim_pass!`
Also note that currently this pass _needs_ to run after `sroa_pass!`, because
the `typeassert` elimination depends on the transformation within `sroa_pass!`
which redirects references of `typeassert`ed value to the corresponding `PiNode`.
"""
function adce_pass!(ir::IRCode)
Expand Down
184 changes: 182 additions & 2 deletions test/compiler/irpasses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,186 @@ end

# Tests for SROA

import Core.Compiler: argextype, singleton_type
const EMPTY_SPTYPES = Core.Compiler.EMPTY_SLOTTYPES
function iscall((src, f)::Tuple{Core.CodeInfo,Function}, @nospecialize(x))
return iscall(x) do @nospecialize x
singleton_type(argextype(x, src, EMPTY_SPTYPES)) === f
end
end
iscall(pred, @nospecialize(x)) = Meta.isexpr(x, :call) && pred(x.args[1])

struct ImmutableXYZ; x; y; z; end
mutable struct MutableXYZ; x; y; z; end

# should optimize away very basic cases
let src = code_typed((Any,Any,Any)) do x, y, z
xyz = ImmutableXYZ(x, y, z)
xyz.x, xyz.y, xyz.z
end |> only |> first
@test !any(src.code) do @nospecialize x
Meta.isexpr(x, :new)
end
end
let src = code_typed((Any,Any,Any)) do x, y, z
xyz = MutableXYZ(x, y, z)
xyz.x, xyz.y, xyz.z
end |> only |> first
@test !any(src.code) do @nospecialize x
Meta.isexpr(x, :new)
end
end

# should handle simple mutabilities
let src = code_typed((Any,Any,Any)) do x, y, z
xyz = MutableXYZ(x, y, z)
xyz.y = 42
xyz.x, xyz.y, xyz.z
end |> only |> first
@test !any(src.code) do @nospecialize x
Meta.isexpr(x, :new)
end
@test any(src.code) do @nospecialize x
iscall((src, tuple), x) &&
x.args[2:end] == Any[#=x=# Core.Argument(2), 42, #=x=# Core.Argument(4)]
end
end
let src = code_typed((Any,Any,Any)) do x, y, z
xyz = MutableXYZ(x, y, z)
xyz.x, xyz.z = xyz.z, xyz.x
xyz.x, xyz.y, xyz.z
end |> only |> first
@test !any(src.code) do @nospecialize x
Meta.isexpr(x, :new)
end
@test any(src.code) do @nospecialize x
iscall((src, tuple), x) &&
x.args[2:end] == Any[#=z=# Core.Argument(4), #=y=# Core.Argument(3), #=x=# Core.Argument(2)]
end
end
# FIXME currently SROA requires "safe" allocation site, i.e. without any uninitialized fields
# circumvent uninitialized fields as far as there is a solid `setfield!` definition
let src = code_typed() do
r = Ref{Any}()
r[] = 42
return r[]
end |> only |> first
@test !any(src.code) do @nospecialize x
Meta.isexpr(x, :new)
end
end
let src = code_typed((Bool,)) do cond
r = Ref{Any}()
if cond
r[] = 42
return r[]
else
r[] = 32
return r[]
end
end |> only |> first
@test !any(src.code) do @nospecialize x
Meta.isexpr(x, :new)
end
end
let src = code_typed((Bool,)) do cond
r = Ref{Any}()
if cond
r[] = 42
else
r[] = 32
end
return r[]
end |> only |> first
@test_broken !any(src.code) do @nospecialize x
Meta.isexpr(x, :new)
end
end
let
src = code_typed((Bool,)) do cond
r = Ref{Any}()
if cond
r[] = 42
end
return r[]
end |> only |> first
@test any(src.code) do @nospecialize x
Meta.isexpr(x, :new)
end
end

# should include alias analysis to some extent
struct ImmutableOuter{T}; x::T; y::T; z::T; end
mutable struct MutableOuter{T}; x::T; y::T; z::T; end
let src = code_typed((Any,Any,Any)) do x, y, z
xyz = ImmutableXYZ(x, y, z)
outer = ImmutableOuter(xyz, xyz, xyz)
outer.x.x, outer.y.y, outer.z.z
end |> only |> first
@test !any(src.code) do @nospecialize x
Meta.isexpr(x, :new)
end
@test any(src.code) do @nospecialize x
iscall((src, tuple), x) &&
x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=y=# Core.Argument(4)]
end
end
let src = code_typed((Any,Any,Any)) do x, y, z
xyz = ImmutableXYZ(x, y, z)
# #42831 forms ::PartialStruct(ImmutableOuter{Any}, Any[ImmutableXYZ, ImmutableXYZ, ImmutableXYZ])
# so the succeeding `getproperty`s are type stable and inlined
outer = ImmutableOuter{Any}(xyz, xyz, xyz)
outer.x.x, outer.y.y, outer.z.z
end |> only |> first
@test !any(src.code) do @nospecialize x
Meta.isexpr(x, :new)
end
@test any(src.code) do @nospecialize x
iscall((src, tuple), x) &&
x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=y=# Core.Argument(4)]
end
end
# currently our SROA can't handle nested mutable objects
let src = code_typed((Any,Any,Any)) do x, y, z
xyz = MutableXYZ(x, y, z)
outer = ImmutableOuter(xyz, xyz, xyz)
outer.x.x, outer.y.y, outer.z.z
end |> only |> first
@test_broken !any(src.code) do @nospecialize x
Meta.isexpr(x, :new)
end
end
let src = code_typed((Any,Any,Any)) do x, y, z
xyz = ImmutableXYZ(x, y, z)
outer = MutableOuter(xyz, xyz, xyz)
outer.x.x, outer.y.y, outer.z.z
end |> only |> first
@test_broken !any(src.code) do @nospecialize x
Meta.isexpr(x, :new)
end
end

# should work nicely with inlining to optimize away a complicated case
# adapted from http://wiki.luajit.org/Allocation-Sinking-Optimization#implementation%5B
struct Point
x::Float64
y::Float64
end
#=@inline=# add(a::Point, b::Point) = Point(a.x + b.x, a.y + b.y)
function compute()
a = Point(1.5, 2.5)
b = Point(2.25, 4.75)
for i in 0:(100000000-1)
a = add(add(a, b), b)
end
a.x, a.y
end
let src = first(only(code_typed(compute, ())))
@test !any(src.code) do @nospecialize x
Meta.isexpr(x, :new)
end
end

mutable struct Foo30594; x::Float64; end
Base.copy(x::Foo30594) = Foo30594(x.x)
function add!(p::Foo30594, off::Foo30594)
Expand Down Expand Up @@ -180,7 +360,7 @@ let m = Meta.@lower 1 + 1
src.ssaflags = fill(Int32(0), nstmts)
ir = Core.Compiler.inflate_ir(src, Any[], Any[Any, Any])
@test Core.Compiler.verify_ir(ir) === nothing
ir = @test_nowarn Core.Compiler.getfield_elim_pass!(ir)
ir = @test_nowarn Core.Compiler.sroa_pass!(ir)
@test Core.Compiler.verify_ir(ir) === nothing
end

Expand Down Expand Up @@ -384,7 +564,7 @@ exc39508 = ErrorException("expected")
end
@test test39508() === exc39508

let # `getfield_elim_pass!` should work with constant globals
let # `sroa_pass!` should work with constant globals
# immutable pass
src = @eval Module() begin
const REF_FLD = :x
Expand Down

0 comments on commit a2742c9

Please sign in to comment.