diff --git a/test/compiler.jl b/test/compiler.jl index 6bc172fc7..c97a50f61 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -111,21 +111,35 @@ end end @testset "issue #922" begin - # checks whether getproperty gets accumulated correctly - # instead of defining a test function as in the issue, compare the two pullbacks - function two_svds(X::StridedMatrix{<:Union{Real, Complex}}) - return svd(X).U * svd(X).V' - end + # checks whether getproperty gets accumulated correctly + # instead of defining a test function as in the issue, compare the two pullbacks + function two_svds(X::StridedMatrix{<:Union{Real, Complex}}) + return svd(X).U * svd(X).V' + end - function one_svd(X::StridedMatrix{<:Union{Real, Complex}}) - F = svd(X) - return F.U * F.V' - end + function one_svd(X::StridedMatrix{<:Union{Real, Complex}}) + F = svd(X) + return F.U * F.V' + end - Δoutput = randn(3,2) - X = randn(3,2) + Δoutput = randn(3,2) + X = randn(3,2) + + d_two = Zygote.pullback(two_svds, X)[2](Δoutput) + d_one = Zygote.pullback(one_svd, X)[2](Δoutput) + @test d_one == d_two +end + +# this test fails if adjoint for literal_getproperty is added +# https://github.com/FluxML/Zygote.jl/issues/922#issuecomment-804128905 +@testset "overloaded getproperty" begin + struct MyStruct + a + b + end + Base.getproperty(ms::MyStruct, s::Symbol) = s === :c ? ms.a + ms.b : getfield(ms, s) + sumall(ms::MyStruct) = ms.a + ms.b + ms.c - d_two = Zygote.pullback(two_svds, X)[2](Δoutput) - d_one = Zygote.pullback(one_svd, X)[2](Δoutput) - @test d_one == d_two + ms = MyStruct(1, 2) + @test Zygote.gradient(sumall, ms) == ((a = 2, b = 2),) end