Skip to content

Commit

Permalink
Merge #955
Browse files Browse the repository at this point in the history
955: fix adjoint for sum r=DhairyaLGandhi a=simeonschaub

addresses the second part of #897

Co-authored-by: Simeon Schaub <[email protected]>
  • Loading branch information
bors[bot] and simeonschaub authored Apr 30, 2021
2 parents 11d3c2d + 148ebeb commit 0f7e606
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
11 changes: 3 additions & 8 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -273,14 +273,9 @@ end
sum(xs, dims = dims), Δ -> (nothing,)
end

_normalize_kws(kws::NamedTuple) = kws
_normalize_kws(kws) = NamedTuple()

function _pullback(cx::AContext, kwtype, kws, ::typeof(sum), f, xs::AbstractArray)
norm_kws = _normalize_kws(kws)
@assert !haskey(norm_kws, :init) # TODO add init support (julia 1.6)
y, back = pullback(cx, (f, xs) -> sum(f.(xs); norm_kws...), f, xs)
y, ȳ -> (nothing, nothing, nothing, back(ȳ)...)
@adjoint function sum(f, xs::AbstractArray; kws...)
@assert !haskey(kws, :init) # TODO add init support (julia 1.6)
return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs)
end

@adjoint function sum(::typeof(abs2), X::AbstractArray; dims = :)
Expand Down
4 changes: 4 additions & 0 deletions test/lib/array.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
using LinearAlgebra

# issue 897
@test gradient(x -> sum(sin, Diagonal(x)), ones(2)) == ([0.5403023058681398, 0.5403023058681398],)
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ end
@testset "lib" begin
include("lib/number.jl")
include("lib/lib.jl")
include("lib/array.jl")
end

@testset "Features" begin
Expand Down

0 comments on commit 0f7e606

Please sign in to comment.