diff --git a/src/projection.jl b/src/projection.jl index 8eba26353..811802536 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -283,7 +283,11 @@ end (project::ProjectTo{<:Tangent{<:Ref}})(dx::Tangent) = project(Ref(first(backing(dx)))) function (project::ProjectTo{<:Tangent{<:Ref}})(dx::Ref) dy = project.x(dx[]) - return project_type(project)(; x=dy) + if dy isa AbstractZero + return NoTangent() + else + return project_type(project)(; x=dy) + end end # Since this works like a zero-array in broadcasting, it should also accept a number: (project::ProjectTo{<:Tangent{<:Ref}})(dx::Number) = project(Ref(dx)) @@ -321,7 +325,11 @@ function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::Tuple) end # Here map will fail if the lengths don't match, but gives a much less helpful error: dy = map((f, x) -> f(x), project.elements, dx) - return project_type(project)(dy...) + if all(d -> d isa AbstractZero, dy) + return NoTangent() + else + return project_type(project)(dy...) + end end function (project::ProjectTo{<:Tangent{<:NamedTuple}})(dx::NamedTuple) dy = _project_namedtuple(backing(project), dx) @@ -370,7 +378,11 @@ function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::AbstractArray) end dy = reshape(dx, axes(project.elements)) # allows for dx::OffsetArray dz = ntuple(i -> project.elements[i](dy[i]), length(project.elements)) - return project_type(project)(dz...) + if all(d -> d isa AbstractZero, dy) + return NoTangent() + else + return project_type(project)(dz...) + end end diff --git a/test/projection.jl b/test/projection.jl index 3e70772ac..d6a8633bb 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -143,16 +143,19 @@ struct NoSuperType end @test ProjectTo(Ref(true)) isa ProjectTo{NoTangent} @test ProjectTo(Ref([false]')) isa ProjectTo{NoTangent} + + @test ProjectTo(Ref(1.0))(Ref(NoTangent())) === NoTangent() # collapse all-zero end @testset "Base: Tuple" begin pt1 = ProjectTo((1.0,)) - @test pt1((1 + im,)) == Tangent{Tuple{Float64}}(1.0,) - @test pt1(pt1((1,))) == pt1(pt1((1,))) # accepts correct Tangent - @test pt1(Tangent{Any}(1)) == pt1((1,)) # accepts Tangent{Any} + @test @inferred(pt1((1 + im,))) == Tangent{Tuple{Float64}}(1.0,) + @test @inferred(pt1(pt1((1,)))) == pt1(pt1((1,))) # accepts correct Tangent + @test @inferred(pt1(Tangent{Any}(1))) == pt1((1,)) # accepts Tangent{Any} @test pt1([1,]) == Tangent{Tuple{Float64}}(1.0,) # accepts Vector - @test pt1(NoTangent()) === NoTangent() - @test pt1(ZeroTangent()) === ZeroTangent() + @test @inferred(pt1(NoTangent())) === NoTangent() + @test @inferred(pt1(ZeroTangent())) === ZeroTangent() + @test @inferred(pt1((NoTangent(),))) === NoTangent() # collapse all-zero @test_throws Exception pt1([1, 2]) # DimensionMismatch, wrong length @test_throws Exception pt1([])