Add rrule for scatter #297

Merged
merged 1 commit into from
May 14, 2021


yuehhua
Member

@yuehhua yuehhua commented Mar 15, 2021

No description provided.

@yuehhua yuehhua force-pushed the scatter-ad branch 2 times, most recently from ba00b44 to ee99e66 Compare March 22, 2021 12:45
@yuehhua
Member Author

yuehhua commented Mar 31, 2021

bump

@CarloLucibello
Member

Is this ready for review? It is marked as a draft and tests are not passing, so I haven't taken a look at it yet.

@yuehhua
Member Author

yuehhua commented Mar 31, 2021

Sorry, there are still about 6 test failures that need to be resolved. It's not ready yet, but it will be soon.

@CarloLucibello
Member

Can we manage to add an rrule for a generic op? Probably not straightforward; we can think about it later.

@yuehhua
Member Author

yuehhua commented Apr 1, 2021

@CarloLucibello I tried to unify all the interfaces and extract a generic function. Please check whether it is what you want.

@yuehhua
Member Author

yuehhua commented Apr 2, 2021

Sorry, I need some help here. The last 6 test cases fail, but I cannot identify why.
The gradient looks right to me. Check the following:

julia> using NNlib

julia> using Zygote

julia> dst = [3 3 4 4 5;
              5 5 6 6 7];

julia> src = ones(Int, 2, 3, 4);

julia> idx = [1 2 3 4;
              4 2 1 3;
              3 5 5 3];

julia> gradient(x -> sum(scatter!(min, dst, x, idx)), src)
([1 1 1; 1 1 1]

[1 1 1; 1 1 1]

[1 1 1; 1 1 1]

[1 1 1; 1 1 1],)

julia> gradient(x -> sum(scatter!(max, dst, x, idx)), src)
([0 0 0; 0 0 0]

[0 0 0; 0 0 0]

[0 0 0; 0 0 0]

[0 0 0; 0 0 0],)

I am wondering whether I missed some detail in the implementation, or whether the cause is something else.

@CarloLucibello
Member

CarloLucibello commented Apr 7, 2021

Looking at the failed test

∂src: Test Failed at /home/runner/work/NNlib.jl/NNlib.jl/test/test_utils.jl:48
  Expression: (g_ad, g_fd, atol = atol, rtol = rtol)
   Evaluated: [-0.0 -0.4161468365471424 -0.0 -0.6536436208636119; -0.0 -0.4161468365471424 -0.9899924966004454 -0.6536436208636119; -0.0 -0.0 -0.9899924966004454 -0.6536436208636119]  
[1.0579138513722123e-15 -0.20807542784028984 1.0579138513722123e-15 -0.6536436208635197; 1.0579138513722123e-15 -0.20807542784028984 -0.9899924966005978 -0.3268202821273808; 1.0579138513722123e-15 1.0579138513722123e-15 -0.9899924966003905 -0.3268202821273808] (atol=1.0e-6, rtol=1.0e-6)
Stacktrace:
 [1] gradtest(f::var"#172#184"{typeof(max)}, xs::Matrix{Float64}; atol::Float64, rtol::Float64, fkwargs::NamedTuple{(), Tuple{}}, check_rrule::Bool, check_broadcast::Bool, skip::Bool, broken::Bool)
   @ Main ~/work/NNlib.jl/NNlib.jl/test/test_utils.jl:48
 [2] gradtest(::Function, ::Matrix{Float64}, ::Vararg{Any, N} where N)
   @ Main ~/work/NNlib.jl/NNlib.jl/test/test_utils.jl:18
 [3] macro expansion
   @ ~/work/NNlib.jl/NNlib.jl/test/scatter.jl:193 [inlined]

It seems that some entries of the Zygote gradient are exactly twice their true values (as computed by finite differences).
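One possible reading of the factor of two, not confirmed anywhere in this thread, is that the doubled entries correspond to tied source values. When two inputs both attain a maximum, a central finite difference on either one sees only the upward perturbation change the output, so it reports half of the one-sided slope. The sketch below shows this in plain Julia, without NNlib or Zygote; `central_fd` and `f` are hypothetical names chosen for illustration and are not taken from the test suite.

```julia
# Hypothetical helper: central finite difference of f at x along coordinate i.
function central_fd(f, x::Vector{Float64}, i::Int; h::Float64 = 1e-6)
    xp = copy(x); xp[i] += h
    xm = copy(x); xm[i] -= h
    return (f(xp) - f(xm)) / (2h)
end

# Assumed stand-in for the tested function: a smooth function of a max reduction.
f(x) = sin(maximum(x))

x_tied = [2.0, 2.0]      # both entries attain the maximum
x_distinct = [2.0, 1.0]  # only the first entry attains the maximum

fd_tied = central_fd(f, x_tied, 1)
fd_distinct = central_fd(f, x_distinct, 1)

println(fd_tied)      # close to cos(2) / 2
println(fd_distinct)  # close to cos(2)
```

If a rule assigns the full cotangent to every tied entry, it would report roughly cos(2) for the tied case, which is twice the finite-difference value.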

@yuehhua
Member Author

yuehhua commented Apr 8, 2021

Yes, I saw that, but I cannot figure out why the values are doubled.
In the implementation of scatter! for max and min, I compare the result dst to src; where the values are equal, Δ is placed at that position.
To my surprise, the same implementation causes test failures only for ∂src, not for ∂dst.
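A minimal sketch of the rule described here, assuming a 1-D case and written in plain Julia rather than against the NNlib code in this PR. `scatter_max` and `scatter_max_pullback_src` are hypothetical names for illustration only. The example makes one property of the equality comparison visible: every source entry that ties for the maximum of its slot receives the full Δ.

```julia
# Hypothetical 1-D scatter with max: out[idx[k]] = max(out[idx[k]], src[k]).
function scatter_max(dst::Vector{Float64}, src::Vector{Float64}, idx::Vector{Int})
    out = copy(dst)
    for k in eachindex(src)
        out[idx[k]] = max(out[idx[k]], src[k])
    end
    return out
end

# Pullback w.r.t. src as described above: where a source entry equals the
# result in its target slot, pass that slot's Δ through; otherwise zero.
function scatter_max_pullback_src(out, src, idx, Δ)
    return [src[k] == out[idx[k]] ? Δ[idx[k]] : 0.0 for k in eachindex(src)]
end

dst = [0.0, 0.0]
src = [2.0, 2.0, 3.0]   # the first two entries tie for slot 1
idx = [1, 1, 2]

out = scatter_max(dst, src, idx)                    # [2.0, 3.0]
Δ = ones(length(out))
∂src = scatter_max_pullback_src(out, src, idx, Δ)   # [1.0, 1.0, 1.0]
```

Both tied entries get a gradient of 1.0, so slot 1 contributes a total of 2.0 to ∂src even though only one unit of cotangent flowed into it.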

@CarloLucibello
Member

any updates on this?

@yuehhua
Member Author

yuehhua commented May 5, 2021

Well, I have tried many experiments, but I still cannot fix the bug.
The scatter rrules for max/min still do not pass the tests. Maybe let the others go in first?

@yuehhua yuehhua marked this pull request as ready for review May 10, 2021 05:02
@yuehhua
Member Author

yuehhua commented May 10, 2021

AppVeyor fails anyway?

@CarloLucibello
Member

ok, let's merge this and work on min/max later

@CarloLucibello CarloLucibello merged commit ed98dd8 into FluxML:master May 14, 2021
@yuehhua yuehhua deleted the scatter-ad branch May 15, 2021 03:36

2 participants