From 07dfe6078ab2ba8861c0f78d80403097b5076a7f Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Mon, 5 Sep 2022 10:36:59 -0400
Subject: [PATCH] Collapse zeros for `Tuple` & `Ref` tangents (#565)

* collapse zeros for Tuple & Ref tangents

* v1.15.4
---
Review notes (this section is free text, not part of the commit message):

* In the `dx::AbstractArray` method for `Tuple` projectors, the all-zero
  check now inspects the projected components `dz` instead of the raw,
  reshaped input `dy`. This matches the `dx::Tuple` method, which tests
  the projected components, and lets the check run over a tuple.
* The patch text had been flattened onto two lines. Line breaks were
  restored and indentation / blank context lines were reconstructed from
  the hunk line counts; if context whitespace differs from the target
  tree, apply with `git apply --ignore-whitespace`.

 Project.toml       |  2 +-
 src/projection.jl  | 18 +++++++++++++++---
 test/projection.jl | 13 ++++++++-----
 3 files changed, 24 insertions(+), 9 deletions(-)

diff --git a/Project.toml b/Project.toml
index 3b97c30ab..ee5cf560d 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "ChainRulesCore"
 uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
-version = "1.15.3"
+version = "1.15.4"
 
 [deps]
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
diff --git a/src/projection.jl b/src/projection.jl
index 8eba26353..811802536 100644
--- a/src/projection.jl
+++ b/src/projection.jl
@@ -283,7 +283,11 @@ end
 (project::ProjectTo{<:Tangent{<:Ref}})(dx::Tangent) = project(Ref(first(backing(dx))))
 function (project::ProjectTo{<:Tangent{<:Ref}})(dx::Ref)
     dy = project.x(dx[])
-    return project_type(project)(; x=dy)
+    if dy isa AbstractZero
+        return NoTangent()
+    else
+        return project_type(project)(; x=dy)
+    end
 end
 # Since this works like a zero-array in broadcasting, it should also accept a number:
 (project::ProjectTo{<:Tangent{<:Ref}})(dx::Number) = project(Ref(dx))
@@ -321,7 +325,11 @@ function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::Tuple)
     end
     # Here map will fail if the lengths don't match, but gives a much less helpful error:
     dy = map((f, x) -> f(x), project.elements, dx)
-    return project_type(project)(dy...)
+    if all(d -> d isa AbstractZero, dy)
+        return NoTangent()
+    else
+        return project_type(project)(dy...)
+    end
 end
 function (project::ProjectTo{<:Tangent{<:NamedTuple}})(dx::NamedTuple)
     dy = _project_namedtuple(backing(project), dx)
@@ -370,7 +378,11 @@ function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::AbstractArray)
     end
     dy = reshape(dx, axes(project.elements)) # allows for dx::OffsetArray
     dz = ntuple(i -> project.elements[i](dy[i]), length(project.elements))
-    return project_type(project)(dz...)
+    if all(d -> d isa AbstractZero, dz)
+        return NoTangent()
+    else
+        return project_type(project)(dz...)
+    end
 end
 
 
diff --git a/test/projection.jl b/test/projection.jl
index 3e70772ac..d6a8633bb 100644
--- a/test/projection.jl
+++ b/test/projection.jl
@@ -143,16 +143,19 @@ struct NoSuperType end
 
         @test ProjectTo(Ref(true)) isa ProjectTo{NoTangent}
         @test ProjectTo(Ref([false]')) isa ProjectTo{NoTangent}
+
+        @test ProjectTo(Ref(1.0))(Ref(NoTangent())) === NoTangent() # collapse all-zero
     end
 
     @testset "Base: Tuple" begin
         pt1 = ProjectTo((1.0,))
-        @test pt1((1 + im,)) == Tangent{Tuple{Float64}}(1.0,)
-        @test pt1(pt1((1,))) == pt1(pt1((1,))) # accepts correct Tangent
-        @test pt1(Tangent{Any}(1)) == pt1((1,)) # accepts Tangent{Any}
+        @test @inferred(pt1((1 + im,))) == Tangent{Tuple{Float64}}(1.0,)
+        @test @inferred(pt1(pt1((1,)))) == pt1(pt1((1,))) # accepts correct Tangent
+        @test @inferred(pt1(Tangent{Any}(1))) == pt1((1,)) # accepts Tangent{Any}
         @test pt1([1,]) == Tangent{Tuple{Float64}}(1.0,) # accepts Vector
-        @test pt1(NoTangent()) === NoTangent()
-        @test pt1(ZeroTangent()) === ZeroTangent()
+        @test @inferred(pt1(NoTangent())) === NoTangent()
+        @test @inferred(pt1(ZeroTangent())) === ZeroTangent()
+        @test @inferred(pt1((NoTangent(),))) === NoTangent() # collapse all-zero
 
         @test_throws Exception pt1([1, 2]) # DimensionMismatch, wrong length
         @test_throws Exception pt1([])