Skip to content

Commit

Permalink
add and test zero_tangent
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Sep 26, 2023
1 parent 7f99ce4 commit b3f2a51
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/ChainRulesCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMod
export frule_via_ad, rrule_via_ad
# definition helper macros
export @non_differentiable, @opt_out, @scalar_rule, @thunk, @not_implemented
export ProjectTo, canonicalize, unthunk # tangent operations
export ProjectTo, canonicalize, unthunk, zero_tangent # tangent operations
export add!!, is_inplaceable_destination # gradient accumulation operations
export ignore_derivatives, @ignore_derivatives
# tangents
Expand Down
29 changes: 29 additions & 0 deletions src/tangent_types/abstract_zero.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,32 @@ arguments.
```
"""
struct NoTangent <: AbstractZero end

"""
zero_tangent(primal)
This returns an appropriate zero tangent suitable for accumulating tangents of the primal.
For mutable composites types this is a structural []`MutableTangent`](@ref)
For `Array`s, it is applied recursively for each element.
For immutable types, this is simply [`ZeroTangent()`](@ref) as accumulation is default out-of-place for contexts where mutation does not apply.
(Where mutation is not to be supported even for mutable types, then [`ZeroTangent()`](@ref) should be used for everything)
!!! warning Exprimental
`zero_tangent`is an experimental feature, and is part of the mutation support featureset.
While this notice remains it may have changes in behavour, and interface in any _minor_ version of ChainRulesCore.
Exactly how it should be used (e.g. is it forward-mode only?)
"""
function zero_tangent end
zero_tangent(::AbstractString) = ZeroTangent()
# zero_tangent(::Number) = zero(x) # TODO: do we want this?
zero_tangent(primal::Array{<:Number}) = zero(primal) # TODO: do we want this?
zero_tangent(primal::Array) = map(zero_tangent, primal)
@generated function zero_tangent(primal)
has_mutable_tangent(primal) || return ZeroTangent() # note this takes care of tuples
zfield_exprs = map(fieldnames(primal)) do fname
fval = Expr(:call, zero_tangent, Expr(:call, getfield, :primal, QuoteNode(fname)))
Expr(:kw, fname, fval)
end
backing_expr = Expr(:tuple, Expr(:parameters, zfield_exprs...))
return :($MutableTangent{$primal}($backing_expr))
end
2 changes: 1 addition & 1 deletion src/tangent_types/structural_tangent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ function StructuralTangent{P}(nt::NamedTuple) where {P}
end
end

has_mutable_tangent(::Type{P}) where P = ismutabletype(P) && (!isabstracttype(P) && fieldcount(T) > 0)
has_mutable_tangent(::Type{P}) where P = ismutabletype(P) && (!isabstracttype(P) && fieldcount(P) > 0)


StructuralTangent{P}(tup::Tuple) where P = Tangent{P,typeof(tup)}(tup)
Expand Down
14 changes: 14 additions & 0 deletions test/tangent_types/abstract_zero.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,17 @@
@test isempty(detect_ambiguities(M))
end
end

@testset "zero_tangent" begin
mutable struct MutDemo
x::Float64
end
@test zero_tangent(MutDemo(1.5)) isa MutableTangent{MutDemo}
@test iszero(zero_tangent(MutDemo(1.5)))

@test zero_tangent((;a=1)) isa ZeroTangent

@test zero_tangent([1.0, 2.0]) == [0.0, 0.0]
@test zero_tangent([[1.0, 2.0], [3.0]]) == [[0.0, 0.0], [0.0]]
end

0 comments on commit b3f2a51

Please sign in to comment.