From 6224ed94c0a307cad0c4b53649e33552bbe1c259 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Fri, 12 Jun 2020 10:23:47 +0200 Subject: [PATCH] support for dims in sum(f,x;dims) --- src/lib/array.jl | 11 ++++++++--- test/gradcheck.jl | 50 +++++++++++++++++++++++++---------------------- 2 files changed, 35 insertions(+), 26 deletions(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index 056043323..e23b815cb 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -223,9 +223,14 @@ end end end -function _pullback(cx::AContext, ::typeof(sum), f, xs::AbstractArray) - y, back = pullback(cx, ((f, xs) -> sum(f.(xs))), f, xs) - y, ȳ -> (nothing, back(ȳ)...) +_normalize_kws(kws::NamedTuple) = kws +_normalize_kws(kws) = NamedTuple() + +function _pullback(cx::AContext, kwtype, kws, ::typeof(sum), f, xs::AbstractArray) + norm_kws = _normalize_kws(kws) + @assert !haskey(norm_kws, :init) # TODO add init support (julia 1.6) + y, back = pullback(cx, (f, xs) -> sum(f.(xs); norm_kws...), f, xs) + y, ȳ -> (nothing, nothing, nothing, back(ȳ)...) end @adjoint function sum(::typeof(abs2), X::AbstractArray; dims = :) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 03fb477e9..43eab346d 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -103,29 +103,33 @@ end @test gradtest((w, x) -> parent(w)*x, randn(5,5)', randn(5,5)) @test gradtest((w, x) -> parent(w)*x, transpose(randn(5,5)), randn(5,5)) -@test gradtest(x -> sum(x, dims = (2, 3)), (3,4,5)) -@test gradtest(x -> sum(abs2, x), randn(4, 3, 2)) -@test gradtest(x -> sum(abs2, x; dims=1), randn(4, 3, 2)) -@test gradtest(x -> sum(x[i] for i in 1:length(x)), randn(10)) -@test gradtest(x -> sum(i->x[i], 1:length(x)), randn(10)) # https://github.com/FluxML/Zygote.jl/issues/231 -@test gradtest(x -> sum((i->x[i]).(1:length(x))), randn(10)) - -# https://github.com/FluxML/Zygote.jl/issues/314 -@test gradient((x,y) -> sum(yi -> yi*x, y), 1, [1,1]) == (2, [1, 1]) -@test gradient((x,y) -> prod(yi -> yi*x, y), 1, [1,1]) == (2, [1, 1]) - -@test gradient((x,y) -> sum(map(yi -> yi*x, y)), 1, [1,1]) == (2, [1, 1]) -@test gradient((x,y) -> prod(map(yi -> yi*x, y)), 1, [1,1]) == (2, [1, 1]) - -@test gradtest(x -> prod(x, dims = (2, 3)), (3,4,5)) -@test gradtest(x -> prod(x), (3,4)) -@test gradient(x -> prod(x), (1,2,3))[1] == (6,3,2) - -@test gradtest(x -> cumsum(x, dims=2), (3,4,5)) -@test gradtest(x -> cumsum(x, dims=1), (3,)) -@test gradtest(x -> cumsum(x), (4,)) -@test gradtest(x -> cumsum(x, dims=3), (5,)) # trivial -@test gradtest(x -> cumsum(x, dims=3), (3,4)) # trivial +@testset "sum, prod, cumsum" begin + @test gradtest(x -> sum(x, dims = (2, 3)), (3,4,5)) + @test gradtest(x -> sum(abs2, x), randn(4, 3, 2)) + @test gradtest(x -> sum(abs2, x; dims=1), randn(4, 3, 2)) + @test gradtest(x -> sum(x[i] for i in 1:length(x)), randn(10)) + @test gradtest(x -> sum(i->x[i], 1:length(x)), randn(10)) # issue #231 + @test gradtest(x -> sum((i->x[i]).(1:length(x))), randn(10)) + @test gradtest(X -> sum(x -> x^2, X), randn(10)) + @test gradtest(X -> sum(sum(x -> x^2, X; dims=1)), randn(10)) # issue #681 + + # https://github.com/FluxML/Zygote.jl/issues/314 + @test gradient((x,y) -> sum(yi -> yi*x, y), 1, [1,1]) == (2, [1, 1]) + @test gradient((x,y) -> prod(yi -> yi*x, y), 1, [1,1]) == (2, [1, 1]) + + @test gradient((x,y) -> sum(map(yi -> yi*x, y)), 1, [1,1]) == (2, [1, 1]) + @test gradient((x,y) -> prod(map(yi -> yi*x, y)), 1, [1,1]) == (2, [1, 1]) + + @test gradtest(x -> prod(x, dims = (2, 3)), (3,4,5)) + @test gradtest(x -> prod(x), (3,4)) + @test gradient(x -> prod(x), (1,2,3))[1] == (6,3,2) + + @test gradtest(x -> cumsum(x, dims=2), (3,4,5)) + @test gradtest(x -> cumsum(x, dims=1), (3,)) + @test gradtest(x -> cumsum(x), (4,)) + @test gradtest(x -> cumsum(x, dims=3), (5,)) # trivial + @test gradtest(x -> cumsum(x, dims=3), (3,4)) # trivial +end @test gradtest(x -> softmax(x).*(1:3), 3) @test gradtest(x -> softmax(x).*(1:3), (3,5))