-
-
Notifications
You must be signed in to change notification settings - Fork 122
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add rrule for scatter #297
Conversation
ba00b44
to
ee99e66
Compare
bump |
Is this ready for review? 'cause it is marked as a draft and tests are not passing, so didn't take a look at it yet |
Sorry, there are still about 6 test fails need to be solved. It's not yet ready, but soon. |
Do we manage to add a rrule for a generic |
@CarloLucibello I try to unify all interfaces and extract the generic function. You may check it if it is what you want. |
Sorry, I need some help here. The last 6 test cases are failed, but I cannot identify why they failed. julia> using NNlib
julia> using Zygote
julia> dst = [3 3 4 4 5;
5 5 6 6 7];
julia> src = ones(Int, 2, 3, 4);
julia> idx = [1 2 3 4;
4 2 1 3;
3 5 5 3];
julia> gradient(x -> sum(scatter!(min, dst, x, idx)), src)
([1 1 1; 1 1 1]
[1 1 1; 1 1 1]
[1 1 1; 1 1 1]
[1 1 1; 1 1 1],)
julia> gradient(x -> sum(scatter!(max, dst, x, idx)), src)
([0 0 0; 0 0 0]
[0 0 0; 0 0 0]
[0 0 0; 0 0 0]
[0 0 0; 0 0 0],) I am thinking about if there are some details I didn't figure out in the implementation or something else. |
Looking at the failed test ∂src: Test Failed at /home/runner/work/NNlib.jl/NNlib.jl/test/test_utils.jl:48
Expression: ≈(g_ad, g_fd, atol = atol, rtol = rtol)
Evaluated: [-0.0 -0.4161468365471424 -0.0 -0.6536436208636119; -0.0 -0.4161468365471424 -0.9899924966004454 -0.6536436208636119; -0.0 -0.0 -0.9899924966004454 -0.6536436208636119] ≈
[1.0579138513722123e-15 -0.20807542784028984 1.0579138513722123e-15 -0.6536436208635197; 1.0579138513722123e-15 -0.20807542784028984 -0.9899924966005978 -0.3268202821273808; 1.0579138513722123e-15 1.0579138513722123e-15 -0.9899924966003905 -0.3268202821273808] (atol=1.0e-6, rtol=1.0e-6)
Stacktrace:
[1] gradtest(f::var"#172#184"{typeof(max)}, xs::Matrix{Float64}; atol::Float64, rtol::Float64, fkwargs::NamedTuple{(), Tuple{}}, check_rrule::Bool, check_broadcast::Bool, skip::Bool, broken::Bool)
@ Main ~/work/NNlib.jl/NNlib.jl/test/test_utils.jl:48
[2] gradtest(::Function, ::Matrix{Float64}, ::Vararg{Any, N} where N)
@ Main ~/work/NNlib.jl/NNlib.jl/test/test_utils.jl:18
[3] macro expansion
@ ~/work/NNlib.jl/NNlib.jl/test/scatter.jl:193 [inlined] it seems that some of the entries of the zygote gradient are exactly twice their true values (computed by finite difference) |
Yes, I did see that, but I just cannot figure out the reason why it cause twice. |
any updates on this? |
Well...............I try many experiments, but still cannot fix the bug. |
AppVeyor fails anyway? |
ok, let's merge this and work on min/max later |
No description provided.