From d9f20747d2f9b06150044b7f0125d66b833222d8 Mon Sep 17 00:00:00 2001
From: Shuhei Kadowaki <aviatesk@gmail.com>
Date: Wed, 23 Mar 2022 16:31:51 +0900
Subject: [PATCH] optimizer: refactor SROA test cases

Mostly adapted from #43888.
Should be more robust and cover more cases.
---
 test/compiler/irpasses.jl | 167 +++++++++++++++++++++++++++++++-------
 1 file changed, 139 insertions(+), 28 deletions(-)

diff --git a/test/compiler/irpasses.jl b/test/compiler/irpasses.jl
index 045cf833944c2..587d4402f554a 100644
--- a/test/compiler/irpasses.jl
+++ b/test/compiler/irpasses.jl
@@ -2,7 +2,9 @@
 
 using Test
 using Base.Meta
-using Core: PhiNode, SSAValue, GotoNode, PiNode, QuoteNode, ReturnNode, GotoIfNot
+import Core:
+    CodeInfo, Argument, SSAValue, GotoNode, GotoIfNot, PiNode, PhiNode,
+    QuoteNode, ReturnNode
 
 include(normpath(@__DIR__, "irutils.jl"))
 
@@ -12,7 +14,7 @@ include(normpath(@__DIR__, "irutils.jl"))
 ## Test that domsort doesn't mangle single-argument phis (#29262)
 let m = Meta.@lower 1 + 1
     @assert Meta.isexpr(m, :thunk)
-    src = m.args[1]::Core.CodeInfo
+    src = m.args[1]::CodeInfo
     src.code = Any[
         # block 1
         Expr(:call, :opaque),
@@ -47,7 +49,7 @@ end
 # test that we don't stack-overflow in SNCA with large functions.
 let m = Meta.@lower 1 + 1
     @assert Meta.isexpr(m, :thunk)
-    src = m.args[1]::Core.CodeInfo
+    src = m.args[1]::CodeInfo
     code = Any[]
     N = 2^15
     for i in 1:2:N
@@ -73,30 +75,87 @@ end
 # SROA
 # ====
 
+import Core.Compiler: widenconst
+
+is_load_forwarded(src::CodeInfo) = !any(iscall((src, getfield)), src.code)
+is_scalar_replaced(src::CodeInfo) =
+    is_load_forwarded(src) && !any(iscall((src, setfield!)), src.code) && !any(isnew, src.code)
+
+function is_load_forwarded(@nospecialize(T), src::CodeInfo)
+    for i in 1:length(src.code)
+        x = src.code[i]
+        if iscall((src, getfield), x)
+            widenconst(argextype(x.args[1], src)) <: T && return false
+        end
+    end
+    return true
+end
+function is_scalar_replaced(@nospecialize(T), src::CodeInfo)
+    is_load_forwarded(T, src) || return false
+    for i in 1:length(src.code)
+        x = src.code[i]
+        if iscall((src, setfield!), x)
+            widenconst(argextype(x.args[1], src)) <: T && return false
+        elseif isnew(x)
+            widenconst(argextype(SSAValue(i), src)) <: T && return false
+        end
+    end
+    return true
+end
+
 struct ImmutableXYZ; x; y; z; end
 mutable struct MutableXYZ; x; y; z; end
+struct ImmutableOuter{T}; x::T; y::T; z::T; end
+mutable struct MutableOuter{T}; x::T; y::T; z::T; end
+struct ImmutableRef{T}; x::T; end
+Base.getindex(r::ImmutableRef) = r.x
+mutable struct SafeRef{T}; x::T; end
+Base.getindex(s::SafeRef) = getfield(s, 1)
+Base.setindex!(s::SafeRef, x) = setfield!(s, 1, x)
+
+# simple immutability
+# -------------------
 
-# should optimize away very basic cases
 let src = code_typed1((Any,Any,Any)) do x, y, z
         xyz = ImmutableXYZ(x, y, z)
         xyz.x, xyz.y, xyz.z
     end
-    @test !any(isnew, src.code)
+    @test is_scalar_replaced(src)
+    @test any(src.code) do @nospecialize x
+        iscall((src, tuple), x) &&
+        x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=z=# Core.Argument(4)]
+    end
 end
+let src = code_typed1((Any,Any,Any)) do x, y, z
+        xyz = (x, y, z)
+        xyz[1], xyz[2], xyz[3]
+    end
+    @test is_scalar_replaced(src)
+    @test any(src.code) do @nospecialize x
+        iscall((src, tuple), x) &&
+        x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=z=# Core.Argument(4)]
+    end
+end
+
+# simple mutability
+# -----------------
+
 let src = code_typed1((Any,Any,Any)) do x, y, z
         xyz = MutableXYZ(x, y, z)
         xyz.x, xyz.y, xyz.z
     end
-    @test !any(isnew, src.code)
+    @test is_scalar_replaced(src)
+    @test any(src.code) do @nospecialize x
+        iscall((src, tuple), x) &&
+        x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=z=# Core.Argument(4)]
+    end
 end
-
-# should handle simple mutabilities
 let src = code_typed1((Any,Any,Any)) do x, y, z
         xyz = MutableXYZ(x, y, z)
         xyz.y = 42
         xyz.x, xyz.y, xyz.z
     end
-    @test !any(isnew, src.code)
+    @test is_scalar_replaced(src)
     @test any(src.code) do @nospecialize x
         iscall((src, tuple), x) &&
         x.args[2:end] == Any[#=x=# Core.Argument(2), 42, #=x=# Core.Argument(4)]
@@ -107,19 +166,23 @@ let src = code_typed1((Any,Any,Any)) do x, y, z
         xyz.x, xyz.z = xyz.z, xyz.x
         xyz.x, xyz.y, xyz.z
     end
-    @test !any(isnew, src.code)
+    @test is_scalar_replaced(src)
     @test any(src.code) do @nospecialize x
         iscall((src, tuple), x) &&
         x.args[2:end] == Any[#=z=# Core.Argument(4), #=y=# Core.Argument(3), #=x=# Core.Argument(2)]
     end
 end
-# circumvent uninitialized fields as far as there is a solid `setfield!` definition
+
+# uninitialized fields
+# --------------------
+
+# safe cases
 let src = code_typed1() do
         r = Ref{Any}()
         r[] = 42
         return r[]
     end
-    @test !any(isnew, src.code)
+    @test is_scalar_replaced(src)
 end
 let src = code_typed1((Bool,)) do cond
         r = Ref{Any}()
@@ -131,7 +194,7 @@ let src = code_typed1((Bool,)) do cond
             return r[]
         end
     end
-    @test !any(isnew, src.code)
+    @test is_scalar_replaced(src)
 end
 let src = code_typed1((Bool,)) do cond
         r = Ref{Any}()
@@ -142,7 +205,7 @@ let src = code_typed1((Bool,)) do cond
         end
         return r[]
     end
-    @test !any(isnew, src.code)
+    @test is_scalar_replaced(src)
 end
 let src = code_typed1((Bool,Bool,Any,Any,Any)) do c1, c2, x, y, z
         r = Ref{Any}()
@@ -157,7 +220,16 @@ let src = code_typed1((Bool,Bool,Any,Any,Any)) do c1, c2, x, y, z
         end
         return r[]
     end
-    @test !any(isnew, src.code)
+    @test is_scalar_replaced(src)
+end
+
+# unsafe cases
+let src = code_typed1() do
+        r = Ref{Any}()
+        return r[]
+    end
+    @test count(isnew, src.code) == 1
+    @test count(iscall((src, getfield)), src.code) == 1
 end
 let src = code_typed1((Bool,)) do cond
         r = Ref{Any}()
@@ -167,7 +239,9 @@ let src = code_typed1((Bool,)) do cond
         return r[]
     end
     # N.B. `r` should be allocated since `cond` might be `false` and then it will be thrown
-    @test any(isnew, src.code)
+    @test count(isnew, src.code) == 1
+    @test count(iscall((src, setfield!)), src.code) == 1
+    @test count(iscall((src, getfield)), src.code) == 1
 end
 let src = code_typed1((Bool,Bool,Any,Any)) do c1, c2, x, y
         r = Ref{Any}()
@@ -181,12 +255,16 @@ let src = code_typed1((Bool,Bool,Any,Any)) do c1, c2, x, y
         return r[]
     end
     # N.B. `r` should be allocated since `c2` might be `false` and then it will be thrown
-    @test any(isnew, src.code)
+    @test count(isnew, src.code) == 1
+    @test count(iscall((src, setfield!)), src.code) == 2
+    @test count(iscall((src, getfield)), src.code) == 1
 end
 
-# should include a simple alias analysis
-struct ImmutableOuter{T}; x::T; y::T; z::T; end
-mutable struct MutableOuter{T}; x::T; y::T; z::T; end
+# aliased load forwarding
+# -----------------------
+# TODO fix broken examples with EscapeAnalysis
+
+# OK: immutable(immutable(...)) case
 let src = code_typed1((Any,Any,Any)) do x, y, z
         xyz = ImmutableXYZ(x, y, z)
         outer = ImmutableOuter(xyz, xyz, xyz)
@@ -214,22 +292,21 @@ let src = code_typed1((Any,Any,Any)) do x, y, z
     end
 end
 
-# FIXME our analysis isn't yet so powerful at this moment: may be unable to handle nested objects well
-# OK: mutable(immutable(...)) case
+# OK (mostly): immutable(mutable(...)) case
 let src = code_typed1((Any,Any,Any)) do x, y, z
         xyz = MutableXYZ(x, y, z)
         t   = (xyz,)
         v = t[1].x
         v, v, v
     end
-    @test !any(isnew, src.code)
+    @test is_scalar_replaced(src)
 end
 let src = code_typed1((Any,Any,Any)) do x, y, z
         xyz = MutableXYZ(x, y, z)
         outer = ImmutableOuter(xyz, xyz, xyz)
         outer.x.x, outer.y.y, outer.z.z
     end
-    @test !any(isnew, src.code)
+    @test is_scalar_replaced(src)
     @test any(src.code) do @nospecialize x
         iscall((src, tuple), x) &&
         x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=y=# Core.Argument(4)]
@@ -240,12 +317,27 @@ let # this is a simple end to end test case, which demonstrates allocation elimi
     # NOTE this test case isn't so robust and might be subject to future changes of the broadcasting implementation,
     # in that case you don't really need to stick to keeping this test case around
     simple_sroa(s) = broadcast(identity, Ref(s))
+    let src = code_typed1(simple_sroa, (String,))
+        @test is_scalar_replaced(src)
+    end
     s = Base.inferencebarrier("julia")::String
     simple_sroa(s)
     # NOTE don't hard-code `"julia"` in `@allocated` clause and make sure to execute the
     # compiled code for `simple_sroa`, otherwise everything can be folded even without SROA
     @test @allocated(simple_sroa(s)) == 0
 end
+let # FIXME: some nested example
+    src = code_typed1((Int,)) do x
+        Ref(Ref(x))[][]
+    end
+    @test_broken is_scalar_replaced(src)
+
+    src = code_typed1((Int,)) do x
+        Ref(Ref(Ref(Ref(Ref(Ref(Ref(Ref(Ref(Ref((x)))))))))))[][][][][][][][][][]
+    end
+    @test_broken is_scalar_replaced(src)
+end
+
 # FIXME: immutable(mutable(...)) case
 let src = code_typed1((Any,Any,Any)) do x, y, z
         xyz = ImmutableXYZ(x, y, z)
@@ -314,6 +406,25 @@ let src = code_typed1(compute_points)
     @test !any(isnew, src.code)
 end
 
+# preserve elimination
+# --------------------
+
+let src = code_typed1((String,)) do s
+        ccall(:some_ccall, Cint, (Ptr{String},), Ref(s))
+    end
+    @test count(isnew, src.code) == 0
+end
+
+# if the mutable struct is directly used, we shouldn't eliminate it
+let src = code_typed1() do
+        a = MutableXYZ(-512275808,882558299,-2133022131)
+        b = Int32(42)
+        ccall(:some_ccall, Cvoid, (MutableXYZ, Int32), a, b)
+        return a.x
+    end
+    @test count(isnew, src.code) == 1
+end
+
 # comparison lifting
 # ==================
 
@@ -454,7 +565,7 @@ end
 # A SSAValue after the compaction line
 let m = Meta.@lower 1 + 1
     @assert Meta.isexpr(m, :thunk)
-    src = m.args[1]::Core.CodeInfo
+    src = m.args[1]::CodeInfo
     src.code = Any[
         # block 1
         nothing,
@@ -517,7 +628,7 @@ end
 let m = Meta.@lower 1 + 1
     # Test that CFG simplify combines redundant basic blocks
     @assert Meta.isexpr(m, :thunk)
-    src = m.args[1]::Core.CodeInfo
+    src = m.args[1]::CodeInfo
     src.code = Any[
         Core.Compiler.GotoNode(2),
         Core.Compiler.GotoNode(3),
@@ -542,7 +653,7 @@ end
 let m = Meta.@lower 1 + 1
     # Test that CFG simplify doesn't mess up when chaining past return blocks
     @assert Meta.isexpr(m, :thunk)
-    src = m.args[1]::Core.CodeInfo
+    src = m.args[1]::CodeInfo
     src.code = Any[
         Core.Compiler.GotoIfNot(Core.Compiler.Argument(2), 3),
         Core.Compiler.GotoNode(4),
@@ -572,7 +683,7 @@ let m = Meta.@lower 1 + 1
     # Test that CFG simplify doesn't try to merge every block in a loop into
     # its predecessor
     @assert Meta.isexpr(m, :thunk)
-    src = m.args[1]::Core.CodeInfo
+    src = m.args[1]::CodeInfo
     src.code = Any[
         # Block 1
         Core.Compiler.GotoNode(2),