Adding GRUv3 support. #1675
Conversation
I added tests to …
Thanks for the contribution! Looking forward to it!
I've left a couple of starter comments. Could you please also add CUDA tests?
state0::S
end

GRUv3Cell(in, out; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
Needs an activation
I'm following the exact constructor of the GRU currently in Flux. If we want to add activations here, it would make sense to add them for the original GRU and LSTM as well, for consistency.
AFAIK the activations in LSTM/GRU are very specifically chosen. That's why they are currently not options, and we should probably keep that consistent.
Worth pointing out that TF and JAX-based libraries do allow you to customize the activation. I presume PyTorch doesn't because it lacks a non-CuDNN path for its GPU RNN backend. That said, this would be better as a separate PR that changes every RNN layer.
Interesting, is it just the output activation or all of them?
All of them, I believe; see TensorFlow's GRU.
All of them, IIUC. There's a distinction between "activation" functions (by default `tanh`) and "gate" functions (by default `sigmoid`).
For a non-Google implementation, here's MXNet.
Prior art in Flux: #964
Yeah, I agree that we should have the activations in Flux generally across all layers.
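For the sake of discussion, a rough sketch of what a configurable cell could look like. This is purely illustrative: `ConfigurableGRUCell`, `activation`, and `gate_fn` are invented names, not this PR and not the current Flux API; the split mirrors the activation / recurrent_activation options in Keras-style GRUs.

```julia
using Flux

# Illustrative sketch only, not this PR and not the current Flux API.
# `activation` is the candidate-state nonlinearity, `gate_fn` the gate nonlinearity.
struct ConfigurableGRUCell{A,V,S,F,G}
    Wi::A
    Wh::A
    b::V
    state0::S
    activation::F   # default tanh, used for h̃
    gate_fn::G      # default σ, used for r and z
end

ConfigurableGRUCell(in::Integer, out::Integer; activation = tanh, gate_fn = σ,
                    init = Flux.glorot_uniform, initb = Flux.zeros32,
                    init_state = Flux.zeros32) =
    ConfigurableGRUCell(init(out * 3, in), init(out * 3, out), initb(out * 3),
                        init_state(out, 1), activation, gate_fn)
```

The forward pass would then call `m.gate_fn` where the current cells hard-code `σ`, and `m.activation` where they hard-code `tanh`.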
See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
"""
GRUv3(a...; ka...) = Recur(GRUv3Cell(a...; ka...))
How do the versions differ API-wise? Does this need any extra terms?
We shouldn't. The main difference is that the `Wh` matrix needs to be split so we can apply the reset vector appropriately, but this doesn't require any extra parameters for the constructor.
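To make that concrete, here is a rough sketch of the recurrent weight shapes involved. The shapes follow this PR's constructor; the variable names (`Wh_v1`, `Wh_v3`, `Wh_h̃_v3`) are just for illustration.

```julia
# Illustrative only: recurrent weight shapes for `out` hidden units.
out = 5

# v1 (current Flux GRUCell): one fused matrix, with reset, update and candidate blocks stacked.
Wh_v1 = rand(Float32, out * 3, out)

# v3 (this PR): only reset and update stay fused; the candidate block is a separate matrix
# so the reset vector can be applied to h before the matmul, i.e. Wh_h̃ * (r .* h).
Wh_v3   = rand(Float32, out * 2, out)
Wh_h̃_v3 = rand(Float32, out, out)
```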
There seems to be a lot of code duplication; probably adding a symbol parameter to the type …
Right. That is definitely the question I had. The conversation got derailed a bit in #1671, as we started discussing CuDNN support. If we were to go with the parametric version, we would have to figure out how to do the operation in the line I tagged you in with the current struct layout (it is this line). I'm not sure how that would work, tbh, without extra unnecessary operations or adding a new parametric type for Wh.
gx, gh = m.Wi*x, m.Wh*h
r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))
h̃ = tanh.(gate(gx, o, 3) .+ (m.Wh_h̃ * (r .* h)) .+ gate(b, o, 3))
@CarloLucibello This line.
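For readers following along, `gate` here is Flux's small internal indexing helper for slicing the fused weight, bias, and product arrays into per-gate blocks. Roughly (paraphrased, so treat the exact signatures as approximate):

```julia
# Roughly what Flux's internal `gate` helper does (approximate, not copied from the source):
gate(h::Int, n::Int) = (1:h) .+ h * (n - 1)              # index range of the n-th block
gate(x::AbstractVector, h, n) = @view x[gate(h, n)]      # n-th block of a fused vector (e.g. b)
gate(x::AbstractMatrix, h, n) = view(x, gate(h, n), :)   # n-th row-block of a fused matrix/product
```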
@DhairyaLGandhi CUDA tests added.
It might make sense to have a function that is a no-op for the regular GRU layer (without changing the struct) and splits the array for v3.
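If I read that suggestion right, both cells would keep the fused `out*3 × out` layout and a small helper would hand back the pieces each forward pass needs. A sketch with invented names (note it assumes GRUv3Cell kept a fused `Wh` like the regular cell, which is not what this PR currently does):

```julia
# Sketch of the "no-op for GRU, split for GRUv3" idea; names invented here.
# Both cells keep a single fused Wh of size (out*3, out); the split is just views, no copies.
recurrent_weights(m::Flux.GRUCell) = m.Wh          # no-op: forward keeps using gate(gh, o, k)

function recurrent_weights(m::GRUv3Cell)
    out = size(m.Wh, 2)
    Wh_gates = view(m.Wh, 1:2out, :)                # reset + update blocks
    Wh_h̃     = view(m.Wh, 2out+1:3out, :)           # candidate block, applied to r .* h
    return Wh_gates, Wh_h̃
end
```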
I remember having issues with views + Zygote a while ago. Is that still the case? (I've not tried since v0.10.x.) If we can't use views, the split function would create a larger memory footprint, no? I think @darsnack had some thoughts on this in terms of clarity of API.
Should be the same memory-wise.
I think I misunderstood what you suggested. This would be in the constructor, right? Then I agree memory would be the same. We would have to add an extra parametric type to separate Wi and Wh (as they could be different depending on the mode). So the struct would be:

```julia
struct GRUCell{M, Ai, Ah, V, S}
    Wi::Ai
    Wh::Ah
    b::V
    state0::S
end
```

where `M` is the mode. My instinct is to put this first to make dispatch very clear, but we could put it elsewhere.
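For illustration, the dispatch this enables could look roughly like the following. It is a sketch only, and it assumes the `:v3` cell stores its extra matrix as the second element of a tuple-valued `Wh`, which is exactly the kind of layout question still open here.

```julia
# Sketch only, not the PR's code. With the mode symbol M as the first type parameter,
# each variant owns just its candidate-state method while everything else is shared.
candidate(m::GRUCell{:v1}, gx, gh, r, h, o) =
    tanh.(gate(gx, o, 3) .+ r .* gate(gh, o, 3) .+ gate(m.b, o, 3))

# Assumes Wh = (Wh_gates, Wh_h̃) for the :v3 layout.
candidate(m::GRUCell{:v3}, gx, gh, r, h, o) =
    tanh.(gate(gx, o, 3) .+ m.Wh[2] * (r .* h) .+ gate(m.b, o, 3))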
Repeating what I mentioned from the other thread: we can still reduce code duplication. As an example:

```julia
function _gru_output(Wi, Wh, b, x, h)
    o = size(h, 1)                     # hidden size, i.e. the per-gate block length
    gx, gh = Wi * x, Wh * h
    r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
    z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))
    return r, z
end
```

Then use it in both cells' forward passes.
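One wrinkle with factoring it this way: both candidate-state computations still need `gx` (and v1 also needs `gh`), so the helper probably wants to return those products as well rather than have callers recompute them. A sketch, with `_gru_gates` as an assumed name:

```julia
# Sketch with assumed names (not the PR's final code): return the fused products alongside
# r and z so callers don't recompute Wi*x and Wh*h for the candidate-state term.
function _gru_gates(Wi, Wh, b, x, h)
    o = size(h, 1)
    gx, gh = Wi * x, Wh * h
    r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
    z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))
    return gx, gh, r, z
end
```

Each cell would then keep only its own candidate line plus the final interpolation `h′ = (1 .- z) .* h̃ .+ z .* h`.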
Exploring the other extreme for a second, I wonder if we could make …
Maybe I'm misunderstanding. That line of the forward for the GRUv3Cell would then look like:

```julia
h̃ = tanh.(gate(gx, o, 3) .+ r .* gate(gh, o, 3) .+ (m.Wh_h̃ * (r .* h)) .+ gate(b, o, 3))
```

How would we turn off …
If I understand it correctly, …
Sorry, that was my mistake! I missed that v1 also has a middle term and assumed it was another case of https://julialang.zulipchat.com/#narrow/stream/238249-machine-learning/topic/Elman.20RNN.20definition.
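For reference, the two candidate-state definitions being contrasted here, matching the Julia lines above:

```latex
% v1 (current Flux GRU): reset gate applied after the hidden-to-hidden product
\tilde{h}_t = \tanh\left(W_{x\tilde{h}}\, x_t + r_t \odot (W_{h\tilde{h}}\, h_{t-1}) + b_{\tilde{h}}\right)

% v3 (this PR): reset gate applied to the hidden state before the product
\tilde{h}_t = \tanh\left(W_{x\tilde{h}}\, x_t + W_{h\tilde{h}}\,(r_t \odot h_{t-1}) + b_{\tilde{h}}\right)
```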
GRUv3Cell(init(out * 3, in), init(out * 2, out), initb(out * 3),
          init(out, out), init_state(out,1))

function (m::GRUv3Cell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {A,V,T}
Do we need to remove the types on the input and parameters?
I believe these were introduced as part of #1521. We should tackle them separately for all recurrent cells.
Maybe we should have a central issue that collects all the updates to recurrent cells mentioned in this PR and the related issue?
I'm definitely invested in the recurrent architectures for Flux, so I'd like to help. But keeping track of all the outstanding issues is out of scope for what I can use my time for right now.
Yeah, let's get this through and litigate general changes to the RNN interface in a separate issue.
What needs to be done to push this PR through?
Updating docs Co-authored-by: Kyle Daruwalla <[email protected]>
Updating docs Co-authored-by: Kyle Daruwalla <[email protected]>
Updating docs Co-authored-by: Kyle Daruwalla <[email protected]>
Updating docs Co-authored-by: Kyle Daruwalla <[email protected]>
Running CI to see what it thinks. Otherwise I think the only other change would be to squash some of those intermediate update commits.
Can you add an entry to NEWS.md?
@darsnack Where should I put the announcement in NEWS.md?
As a new bullet under 0.12.7 (we're missing a 0.12.6 entry for some reason).
Also, how would I squash those commits? I've not had to do that before.
https://stackoverflow.com/questions/35703556/what-does-it-mean-to-squash-commits-in-git
It's just 3 more commits than necessary, so I don't think you need to bother. In the future, you can go to the "Files" tab of the PR to add review suggestions to a single commit batch (instead of accepting each suggestion as a separate commit).
Looks good, just gotta wait for CI to pass.
bors r+
Build succeeded:
As per the starting discussion in #1671, we should provide support for variations on the GRU and LSTM cell.
In this PR, I added support for the GRU found in v3 of the original GRU paper. Current support in Flux is for v1 only. TensorFlow supports several variations, this being one of them.
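A minimal usage sketch of the new layer, assuming it ends up exported and wrapped in `Recur` like the existing `GRU`:

```julia
using Flux

m = GRUv3(3, 5)                        # 3 input features, 5 hidden units
xs = [rand(Float32, 3) for _ in 1:10]  # a length-10 sequence
hs = [m(x) for x in xs]                # Recur carries the hidden state across calls
Flux.reset!(m)                         # reset the state before the next sequence
```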
While the feature is added and usable in this PR, this is only a first pass at a design and could use further iterations. Some questions I have:
PR Checklist