Skip to content

Commit

Permalink
collapse zeros for Tuple & Ref tangents
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Jul 16, 2022
1 parent f2e3ac5 commit 971eba9
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 8 deletions.
18 changes: 15 additions & 3 deletions src/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,11 @@ end
(project::ProjectTo{<:Tangent{<:Ref}})(dx::Tangent) = project(Ref(first(backing(dx))))
function (project::ProjectTo{<:Tangent{<:Ref}})(dx::Ref)
dy = project.x(dx[])
return project_type(project)(; x=dy)
if dy isa AbstractZero
return NoTangent()
else
return project_type(project)(; x=dy)
end
end
# Since this works like a zero-array in broadcasting, it should also accept a number:
(project::ProjectTo{<:Tangent{<:Ref}})(dx::Number) = project(Ref(dx))
Expand Down Expand Up @@ -321,7 +325,11 @@ function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::Tuple)
end
# Here map will fail if the lengths don't match, but gives a much less helpful error:
dy = map((f, x) -> f(x), project.elements, dx)
return project_type(project)(dy...)
if all(d -> d isa AbstractZero, dy)
return NoTangent()
else
return project_type(project)(dy...)
end
end
function (project::ProjectTo{<:Tangent{<:NamedTuple}})(dx::NamedTuple)
dy = _project_namedtuple(backing(project), dx)
Expand Down Expand Up @@ -370,7 +378,11 @@ function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::AbstractArray)
end
dy = reshape(dx, axes(project.elements)) # allows for dx::OffsetArray
dz = ntuple(i -> project.elements[i](dy[i]), length(project.elements))
return project_type(project)(dz...)
if all(d -> d isa AbstractZero, dy)
return NoTangent()
else
return project_type(project)(dz...)
end
end


Expand Down
13 changes: 8 additions & 5 deletions test/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,19 @@ struct NoSuperType end

@test ProjectTo(Ref(true)) isa ProjectTo{NoTangent}
@test ProjectTo(Ref([false]')) isa ProjectTo{NoTangent}

@test ProjectTo(Ref(1.0))(Ref(NoTangent())) === NoTangent() # collapse all-zero
end

@testset "Base: Tuple" begin
pt1 = ProjectTo((1.0,))
@test pt1((1 + im,)) == Tangent{Tuple{Float64}}(1.0,)
@test pt1(pt1((1,))) == pt1(pt1((1,))) # accepts correct Tangent
@test pt1(Tangent{Any}(1)) == pt1((1,)) # accepts Tangent{Any}
@test @inferred(pt1((1 + im,))) == Tangent{Tuple{Float64}}(1.0,)
@test @inferred(pt1(pt1((1,)))) == pt1(pt1((1,))) # accepts correct Tangent
@test @inferred(pt1(Tangent{Any}(1))) == pt1((1,)) # accepts Tangent{Any}
@test pt1([1,]) == Tangent{Tuple{Float64}}(1.0,) # accepts Vector
@test pt1(NoTangent()) === NoTangent()
@test pt1(ZeroTangent()) === ZeroTangent()
@test @inferred(pt1(NoTangent())) === NoTangent()
@test @inferred(pt1(ZeroTangent())) === ZeroTangent()
@test @inferred(pt1((NoTangent(),))) === NoTangent() # collapse all-zero

@test_throws Exception pt1([1, 2]) # DimensionMismatch, wrong length
@test_throws Exception pt1([])
Expand Down

0 comments on commit 971eba9

Please sign in to comment.