Skip to content

Commit

Permalink
Collapse zeros for Tuple & Ref tangents (JuliaDiff#565)
Browse files Browse the repository at this point in the history
* collapse zeros for Tuple & Ref tangents

* v1.15.4
  • Loading branch information
mcabbott authored Sep 5, 2022
1 parent 4ce6418 commit 07dfe60
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesCore"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.15.3"
version = "1.15.4"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
18 changes: 15 additions & 3 deletions src/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,11 @@ end
(project::ProjectTo{<:Tangent{<:Ref}})(dx::Tangent) = project(Ref(first(backing(dx))))
function (project::ProjectTo{<:Tangent{<:Ref}})(dx::Ref)
dy = project.x(dx[])
return project_type(project)(; x=dy)
if dy isa AbstractZero
return NoTangent()
else
return project_type(project)(; x=dy)
end
end
# Since this works like a zero-array in broadcasting, it should also accept a number:
(project::ProjectTo{<:Tangent{<:Ref}})(dx::Number) = project(Ref(dx))
Expand Down Expand Up @@ -321,7 +325,11 @@ function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::Tuple)
end
# Here map will fail if the lengths don't match, but gives a much less helpful error:
dy = map((f, x) -> f(x), project.elements, dx)
return project_type(project)(dy...)
if all(d -> d isa AbstractZero, dy)
return NoTangent()
else
return project_type(project)(dy...)
end
end
function (project::ProjectTo{<:Tangent{<:NamedTuple}})(dx::NamedTuple)
dy = _project_namedtuple(backing(project), dx)
Expand Down Expand Up @@ -370,7 +378,11 @@ function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::AbstractArray)
end
dy = reshape(dx, axes(project.elements)) # allows for dx::OffsetArray
dz = ntuple(i -> project.elements[i](dy[i]), length(project.elements))
return project_type(project)(dz...)
if all(d -> d isa AbstractZero, dy)
return NoTangent()
else
return project_type(project)(dz...)
end
end


Expand Down
13 changes: 8 additions & 5 deletions test/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,19 @@ struct NoSuperType end

@test ProjectTo(Ref(true)) isa ProjectTo{NoTangent}
@test ProjectTo(Ref([false]')) isa ProjectTo{NoTangent}

@test ProjectTo(Ref(1.0))(Ref(NoTangent())) === NoTangent() # collapse all-zero
end

@testset "Base: Tuple" begin
pt1 = ProjectTo((1.0,))
@test pt1((1 + im,)) == Tangent{Tuple{Float64}}(1.0,)
@test pt1(pt1((1,))) == pt1(pt1((1,))) # accepts correct Tangent
@test pt1(Tangent{Any}(1)) == pt1((1,)) # accepts Tangent{Any}
@test @inferred(pt1((1 + im,))) == Tangent{Tuple{Float64}}(1.0,)
@test @inferred(pt1(pt1((1,)))) == pt1(pt1((1,))) # accepts correct Tangent
@test @inferred(pt1(Tangent{Any}(1))) == pt1((1,)) # accepts Tangent{Any}
@test pt1([1,]) == Tangent{Tuple{Float64}}(1.0,) # accepts Vector
@test pt1(NoTangent()) === NoTangent()
@test pt1(ZeroTangent()) === ZeroTangent()
@test @inferred(pt1(NoTangent())) === NoTangent()
@test @inferred(pt1(ZeroTangent())) === ZeroTangent()
@test @inferred(pt1((NoTangent(),))) === NoTangent() # collapse all-zero

@test_throws Exception pt1([1, 2]) # DimensionMismatch, wrong length
@test_throws Exception pt1([])
Expand Down

0 comments on commit 07dfe60

Please sign in to comment.