merge into NNlib and CUDA? #32
Currently, NNlib doesn't depend on ChainRules, so it may be better to move the adjoint definitions to Flux directly. Instead of

```julia
using Zygote: @adjoint

@adjoint function scatter_add!(ys::AbstractArray, us::AbstractArray, xs::AbstractArray)
    ys_ = copy(ys)
    scatter_add!(ys_, us, xs)
    ys_, Δ -> (Δ, gather(Δ, xs), nothing)
end
```

have this (assuming I understood the semantics correctly):

```julia
using ChainRulesCore: NoTangent
import ChainRulesCore: rrule

# Reverse-pass functions that can also be used on their own
∇scatter_add_ys!(Δ, xs) = Δ
∇scatter_add_us!(Δ, xs) = gather(Δ, xs)

function rrule(::typeof(scatter_add!), ys, us, xs)
    ys_ = copy(ys)
    scatter_add!(ys_, us, xs)
    # ChainRules pullbacks also return a tangent for the function itself
    pullback(Δ) = (NoTangent(), ∇scatter_add_ys!(Δ, xs), ∇scatter_add_us!(Δ, xs), NoTangent())
    return ys_, pullback
end
```

If this looks good to everyone, I can try it out in NNlib / CUDA / Flux this weekend or next.
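To make the proposed split concrete without pulling in NNlib, ScatterNNlib, Zygote or ChainRules, here is a minimal sketch. `scatter_add!` and `gather` below are simplified 1-D stand-ins (integer index vector), not the library implementations, and `∇scatter_add_ys`, `∇scatter_add_us` and `scatter_add_with_pullback` are hypothetical names. It shows that the forward pass and the two reverse-pass functions are independent pieces that a pullback merely glues together.

```julia
# Simplified stand-in: ys[idx[i]] += us[i] for every i
function scatter_add!(ys::AbstractVector, us::AbstractVector, idx::AbstractVector{<:Integer})
    for i in eachindex(us, idx)
        ys[idx[i]] += us[i]
    end
    return ys
end

# Simplified stand-in: out[i] = src[idx[i]]
gather(src::AbstractVector, idx::AbstractVector{<:Integer}) = [src[j] for j in idx]

# Hypothetical reverse-pass functions (non-mutating)
∇scatter_add_ys(Δ, idx) = Δ               # ys is passed through unchanged
∇scatter_add_us(Δ, idx) = gather(Δ, idx)  # us[i] was added into position idx[i]

# Combine the forward pass and the reverse-pass functions into (output, pullback)
function scatter_add_with_pullback(ys, us, idx)
    out = scatter_add!(copy(ys), us, idx)  # copy so the caller's ys is untouched
    pullback(Δ) = (∇scatter_add_ys(Δ, idx), ∇scatter_add_us(Δ, idx), nothing)
    return out, pullback
end

# Usage
ys = zeros(3); us = [1.0, 2.0, 3.0]; idx = [1, 3, 1]
out, back = scatter_add_with_pullback(ys, us, idx)
dys, dus, _ = back([10.0, 20.0, 30.0])
```

Here `out` is `[4.0, 0.0, 2.0]` (positions 1 and 3 receive `1.0 + 3.0` and `2.0`), and `dus` picks up the incoming gradient at the position each element of `us` was scattered to.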
It makes sense to me. If there is anything I can help with, just let me know.
@dfdx Some gradient functions need intermediate values from the forward pass. If we split the definitions, some backward functions would need an extra argument to receive those values instead of recalculating them.
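A scatter-max is a typical case where the backward pass needs a forward intermediate: the gradient only flows to the elements that won the maximum. The sketch below is self-contained and hypothetical (1-D, integer indices, empty slots left at `-Inf`, ties sending gradient to every maximal element); `scatter_max`, `∇scatter_max_us` and `scatter_max_with_pullback` are illustrative names, not the ScatterNNlib or NNlib API. The reverse-pass function takes the forward output `out` as an extra argument, so the forward pass is computed only once.

```julia
# Hypothetical 1-D scatter-max: out[j] = max of us[i] over all i with idx[i] == j.
# Slots that receive no value keep typemin(T), i.e. -Inf for floats.
function scatter_max(us::AbstractVector{T}, idx::AbstractVector{<:Integer}, n::Integer) where {T<:AbstractFloat}
    out = fill(typemin(T), n)
    for i in eachindex(us, idx)
        out[idx[i]] = max(out[idx[i]], us[i])
    end
    return out
end

# Reverse pass: uses the forward output `out` to find which us[i] won,
# instead of recomputing the maxima. Ties all receive the gradient.
function ∇scatter_max_us(Δ, us, idx, out)
    return [us[i] == out[idx[i]] ? Δ[idx[i]] : zero(eltype(Δ)) for i in eachindex(us, idx)]
end

function scatter_max_with_pullback(us, idx, n)
    out = scatter_max(us, idx, n)  # forward pass, computed once
    pullback(Δ) = (∇scatter_max_us(Δ, us, idx, out), nothing, nothing)
    return out, pullback
end

# Usage
us = [1.0, 5.0, 3.0]; idx = [1, 1, 2]
out, back = scatter_max_with_pullback(us, idx, 2)
dus, _, _ = back([10.0, 20.0])
```

Only `us[2]` (the maximum of slot 1) and `us[3]` (the only value in slot 2) receive gradient; `us[1]` gets zero.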
The only correction to the comments above is that NNlib's rules are currently being moved into NNlib itself (FluxML/NNlib.jl#242), so the scatter rules should go there as well. @yuehhua it would be nice if you could file the PR to NNlib yourself so that you preserve authorship.
True, and in general some refactoring of the forward-pass functions may be needed. But in ScatterNNlib all adjoint definitions follow the same simple-to-split pattern. All in all, it looks much easier to have separate forward and reverse pass functions and combine them into a pullback than to have only a pullback and try to extract the forward and reverse passes from it. This is essentially the reason Yota.jl (and perhaps any non-pullback-based library) still doesn't use ChainRules.jl. (I hope it doesn't sound like a selfish argument :))
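The asymmetry described above can be illustrated with a small generic helper: given a forward function and one reverse-pass function per argument, building a pullback is mechanical, while the separate pieces remain callable on their own (for example by a tracing AD that records the forward call and emits the reverse calls itself). This is a hypothetical sketch; `with_pullback`, `scale`, `∇scale_x` and `∇scale_w` are illustrative names, not part of ChainRules, Zygote or Yota.

```julia
# Hypothetical helper: build (output, pullback) from a forward function `f`
# and a tuple of reverse-pass functions, each with signature ∇(Δ, args...)
function with_pullback(f, ∇s::Tuple, args...)
    y = f(args...)
    pullback(Δ) = map(g -> g(Δ, args...), ∇s)  # one gradient per argument
    return y, pullback
end

# Example primitive with separately defined reverse passes
scale(x, w) = w .* x
∇scale_x(Δ, x, w) = Δ .* w
∇scale_w(Δ, x, w) = Δ .* x

# Usage: combined into a pullback...
x = [1.0, 2.0]; w = [3.0, 4.0]
y, back = with_pullback(scale, (∇scale_x, ∇scale_w), x, w)
dx, dw = back([1.0, 1.0])

# ...or called directly, without any pullback closure
dx_direct = ∇scale_x([1.0, 1.0], x, w)
```

Going the other way, recovering `∇scale_x` and `∇scale_w` from an opaque pullback closure, has no comparably simple general recipe.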
Record current status:
All migrations are complete! Thank you, everyone.
Amazing and relentless work, thanks @yuehhua!
Hi,
in FluxML/Flux.jl#1431 there was some talk about making the primitives defined here more widely available in the ecosystem. To do this, the Zygote and CUDA dependencies should be dropped, because they could be an unnecessary and heavy load for other packages. Therefore, we should take the following steps:
@yuehhua does this plan make sense?
cc @dfdx @jeremiedb @chengchingwen