Skip to content

Commit

Permalink
Merge #824
Browse files Browse the repository at this point in the history
824: move NNlib rules out of Zygote r=CarloLucibello a=simeonschaub

Partner to FluxML/NNlib.jl#242. I talked a bit about how to go about releasing this there, would appreciate any feedback and suggestions.

Co-authored-by: Simeon Schaub <[email protected]>
  • Loading branch information
bors[bot] and simeonschaub authored Nov 27, 2020
2 parents 7d4b0be + 3299f7a commit 423383f
Show file tree
Hide file tree
Showing 6 changed files with 2 additions and 201 deletions.
6 changes: 1 addition & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.5.13"
version = "0.6.0"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Expand All @@ -15,7 +14,6 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Expand All @@ -25,15 +23,13 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
AbstractFFTs = "0.5"
ArrayLayouts = "0.1, 0.2, 0.3, 0.4"
ChainRules = "0.7.16"
DiffRules = "1.0"
FillArrays = "0.8, 0.9, 0.10"
ForwardDiff = "0.10"
IRTools = "0.4"
LoopVectorization = "0.8.15"
MacroTools = "0.5"
NNlib = "0.7"
NaNMath = "0.3"
Requires = "0.5, 1.0"
SpecialFunctions = "0.10, 1.0"
Expand Down
2 changes: 0 additions & 2 deletions src/Zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ module Zygote

using LinearAlgebra, Statistics
using LinearAlgebra: copytri!, AbstractTriangular
using ArrayLayouts: MemoryLayout, AbstractColumnMajor

import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback, literal_getproperty

Expand Down Expand Up @@ -35,7 +34,6 @@ include("lib/base.jl")
include("lib/array.jl")
include("lib/buffer.jl")
include("lib/broadcast.jl")
include("lib/nnlib.jl")
include("lib/forward.jl")
include("lib/utils.jl")
include("lib/range.jl")
Expand Down
6 changes: 0 additions & 6 deletions src/lib/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

using Base.Broadcast
using Base.Broadcast: Broadcasted, AbstractArrayStyle, broadcasted, materialize
using NNlib

# There's a saying that debugging code is about twice as hard as writing it in
# the first place. So if you're as clever as you can be when writing code, how
Expand Down Expand Up @@ -89,11 +88,6 @@ end

@adjoint broadcasted(::typeof(identity), x::Numeric) = x, Δ -> (nothing, Δ)

@adjoint function broadcasted(::typeof(σ), x::Numeric)
y = σ.(x)
y, ȳ -> (nothing, ȳ .* conj.(y .* (1 .- y)))
end

@adjoint function broadcasted(::typeof(tanh), x::Numeric)
y = tanh.(x)
y, ȳ -> (nothing, ȳ .* conj.(1 .- y.^2))
Expand Down
103 changes: 0 additions & 103 deletions src/lib/nnlib.jl

This file was deleted.

1 change: 0 additions & 1 deletion test/forward/forward.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using Zygote, Test
using NNlib: relu

D(f, x) = pushforward(f, x)(1)

Expand Down
85 changes: 1 addition & 84 deletions test/gradcheck.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using Zygote, NNlib, Test, Random, LinearAlgebra, Statistics, FillArrays,
using Zygote, Test, Random, LinearAlgebra, Statistics, FillArrays,
AbstractFFTs, FFTW, Distances
using Zygote: gradient
using NNlib: conv, ∇conv_data, depthwiseconv, batched_mul
using Base.Broadcast: broadcast_shape
using LoopVectorization: vmap
using Distributed: pmap
Expand Down Expand Up @@ -92,37 +91,10 @@ end
@test gradtest((x, W, b) -> identity.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> identity.(W*x .+ b), (5,3), (2,5), 2)

@test gradtest((x, W, b) -> relu.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> relu.(W*x .+ b), (5,3), (2,5), 2)
@test gradtest((x, W, b) -> selu.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> selu.(W*x .+ b), (5,3), (2,5), 2)
@test gradtest((x, W, b) -> elu.(W*x .+ b, 2), 5, (2,5), 2)
@test gradtest((x, W, b) -> elu.(W*x .+ b, 2), (5,3), (2,5), 2)

# tests for https://github.com/FluxML/Zygote.jl/issues/758
@test gradient(xs -> sum(selu.(xs)), [1_000, 10_000]) == ([1.0507009873554805, 1.0507009873554805],)
@test gradient(x -> selu(x), 1_000) == (1.0507009873554805,)
@test gradient(xs -> sum(elu.(xs, 2)), [1_000, 10_000]) == ([1., 1.],)
@test gradient(x -> elu(x, 2), 1_000) == (1.,)
@test gradient(x -> elu(x, 2), -1) == (2*exp(-1),)
@test gradcheck(x->sum(selu.(x)),[100., 1_000.])
@test gradcheck(x->sum(elu.(x, 3.5)),[100., 1_000.])
@test gradcheck(x->sum(elu.(x, 3.5)),[1_000., 10_000.]) # for elu the tests are passing but for selu not, interesting
# numerical instability even for the linear part of such function, see:
# julia> ngradient(x->sum(selu.(x)),[1_000., 10_000.])
# ([1.0506591796875, 1.0506591796875],)
# julia> gradient(x->sum(selu.(x)),[1_000., 10_000.])
# ([1.0507009873554805, 1.0507009873554805],)
@test_broken gradcheck(x->sum(selu.(x)),[1_000., 10_000.])

@test gradtest((x, W, b) -> tanh.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> tanh.(W*x .+ b), (5,3), (2,5), 2)

@test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
@test gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> logσ.(W*x .+ b), (5,3), (2,5), 2)

@test gradtest((w, x) -> w'*x, randn(10, 2), randn(10))
@test gradtest((w, x) -> Adjoint(w)*x, randn(10, 2), randn(10))
@test gradtest((w, x) -> transpose(w)*x, randn(5,5), randn(5,5))
Expand Down Expand Up @@ -163,13 +135,6 @@ end
@test gradtest(x -> cumsum(x, dims=3), (3,4)) # trivial
end

@test gradtest(x -> softmax(x).*(1:3), 3)
@test gradtest(x -> softmax(x).*(1:3), (3,5))
@test gradtest(x -> softmax(x, dims=2).*(1:3), (3,5))
@test gradtest(x -> logsoftmax(x).*(1:3), 3)
@test gradtest(x -> logsoftmax(x).*(1:3), (3,5))
@test gradtest(x -> logsoftmax(x, dims=2).*(1:3), (3,5))

@test gradtest(x -> x', rand(5))

@test gradtest(det, (4, 4))
Expand Down Expand Up @@ -235,49 +200,6 @@ end
@test gradient(xs -> sum(inv, [x^2 for x in xs]), ones(2)) == ([-2, -2],)
end

@testset "conv: spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3)
x = rand(repeat([5], spatial_rank)..., 3, 2)
w = rand(repeat([3], spatial_rank)..., 3, 3)
cdims = DenseConvDims(x, w)
@test gradtest((x, w) -> conv(x, w, cdims), x, w)
@test gradtest((x, w) -> sum(conv(x, w, cdims)), x, w) # https://github.com/FluxML/Flux.jl/issues/1055

y = conv(x, w, cdims)
@test gradtest((y, w) -> ∇conv_data(y, w, cdims), y, w)
if spatial_rank == 3
@test_broken gradtest((y, w) -> sum(∇conv_data(y, w, cdims)), y, w)
else
@test gradtest((y, w) -> sum(∇conv_data(y, w, cdims)), y, w)
end

dcdims = DepthwiseConvDims(x, w)
@test gradtest((x, w) -> depthwiseconv(x, w, dcdims), x, w)

y = depthwiseconv(x, w, dcdims)
@test gradtest((y, w) -> ∇depthwiseconv_data(y, w, dcdims), y, w)
if spatial_rank == 3
@test_broken gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w)
else
@test gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w)
end
end

@testset "pooling: spatial_rank=$spatial_rank" for spatial_rank in (1, 2)
x = rand(repeat([10], spatial_rank)..., 3, 2)
pdims = PoolDims(x, 2)
@test gradtest(x -> maxpool(x, pdims), x)
@test gradtest(x -> meanpool(x, pdims), x)
@test gradtest(x -> sum(maxpool(x, pdims)), x)
@test gradtest(x -> sum(meanpool(x, pdims)), x)

#https://github.com/FluxML/NNlib.jl/issues/188
k = ntuple(_ -> 2, spatial_rank) # Kernel size of pool in ntuple format
@test gradtest(x -> maxpool(x, k), x)
@test gradtest(x -> meanpool(x, k), x)
@test gradtest(x -> sum(maxpool(x, k)), x)
@test gradtest(x -> sum(meanpool(x, k)), x)
end

@test gradtest(x -> reverse(x), rand(17))
@test gradtest(x -> reverse(x, 8), rand(17))
@test gradtest(x -> reverse(x, 8, 13), rand(17))
Expand Down Expand Up @@ -523,11 +445,6 @@ end
@test first(back(randn(rng, M, P))) isa Vector
end
end

@testset "batched matrix multiplication" begin
B = 3
@test gradtest(batched_mul, randn(rng, M, P, B), randn(rng, P, Q, B))
end
end

@testset "backsolve" begin
Expand Down

0 comments on commit 423383f

Please sign in to comment.