Skip to content

Commit

Permalink
debug further
Browse files Browse the repository at this point in the history
  • Loading branch information
simeonschaub committed Dec 21, 2020
1 parent 5df14d2 commit a205b17
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions test/zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ function gradcheck(f, xs...; rtol = 1e-6, atol = 1e-6)
#return all(isapprox.(grad_zygote, grad_finite_difference; rtol = rtol, atol = atol))
for (grad_zygote, grad_finite_difference) in zip(grad_zygote, grad_finite_difference)
@test isapprox(grad_zygote, grad_finite_difference; rtol = rtol, atol = atol)
if !isapprox(grad_zygote, grad_finite_difference; rtol = rtol, atol = atol)
display(grad_zygote - grad_finite_difference)
println()
end
end
end

Expand All @@ -23,11 +27,11 @@ gradtest((x, W, b) -> elu.(W*x .+ b, 2), 5, (2,5), 2)
gradtest((x, W, b) -> elu.(W*x .+ b, 2), (5,3), (2,5), 2)

# tests for https://github.com/FluxML/Zygote.jl/issues/758
gradient(xs -> sum(selu.(xs)), [1_000, 10_000]) == ([1.0507009873554805, 1.0507009873554805],)
gradient(x -> selu(x), 1_000) == (1.0507009873554805,)
gradient(xs -> sum(elu.(xs, 2)), [1_000, 10_000]) == ([1., 1.],)
gradient(x -> elu(x, 2), 1_000) == (1.,)
gradient(x -> elu(x, 2), -1) == (2*exp(-1),)
@test gradient(xs -> sum(selu.(xs)), [1_000, 10_000])[1] [1.0507009873554805, 1.0507009873554805] rtol=1e-8
@test gradient(x -> selu(x), 1_000) == (1.0507009873554805,)
@test gradient(xs -> sum(elu.(xs, 2)), [1_000, 10_000]) == ([1., 1.],)
@test gradient(x -> elu(x, 2), 1_000) == (1.,)
@test gradient(x -> elu(x, 2), -1) == (2*exp(-1),)
gradcheck(x->sum(selu.(x)),[100., 1_000.])
gradcheck(x->sum(elu.(x, 3.5)),[100., 1_000.])
gradcheck(x->sum(elu.(x, 3.5)),[1_000., 10_000.]) # for elu the tests are passing but for selu not, interesting
Expand Down

0 comments on commit a205b17

Please sign in to comment.