Adding GRUv3 support. #1675

Merged: 10 commits into FluxML:master on Aug 2, 2021

Conversation

@mkschleg (Contributor) commented Jul 23, 2021

As per the starting discussion in #1671, we should provide support for variations on the GRU and LSTM cells.

In this PR, I added support for the GRU variant found in v3 of the original GRU paper. Flux currently supports only the v1 formulation. TensorFlow supports several variations, and this is one of them.
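
For context, a minimal usage sketch (hypothetical dimensions; it assumes GRUv3 mirrors the constructor and Recur-based call pattern of the existing GRU, which is what this PR does):

using Flux

m = GRUv3(3, 5)         # 3 input features, 5 hidden units
x = rand(Float32, 3)    # a single time step
h = m(x)                # advance the state; returns the 5-element hidden state
Flux.reset!(m)          # reset the hidden state between sequences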

While the feature is added and usable in this PR, this is only a first pass at a design and could use further iteration. Some questions I have:

  • Should we have new types for each variation of these cells? (Another possibility is parametric options.)
  • Should we have a shared constructor similar to TensorFlow/PyTorch? (It might make sense to rename the current GRU to GRUv1 if we do this.)

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable
  • API changes require approval from a committer (different from the author, if applicable)

@mkschleg (Contributor, Author)

I added tests to test/layers/recurrent.jl; I'm not sure whether there are more tests I should add. I also added a docstring for the new method and updated the docstring for the original version to clarify that it implements v1 of the arXiv paper.
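
For reference, a hypothetical sketch of the kind of test added (sizes and structure are illustrative, not the exact code in test/layers/recurrent.jl):

using Flux, Test

@testset "GRUv3 smoke test" begin
  m = GRUv3(3, 5)
  x = rand(Float32, 3)
  @test size(m(x)) == (5,)                        # forward pass yields the hidden state
  gs = gradient(() -> sum(m(x)), Flux.params(m))
  @test !isempty(gs.grads)                        # gradients flow through the new cell
end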

@DhairyaLGandhi (Member) left a comment

Thanks for the contribution! Looking forward to it!

I've left a couple of starter comments. Could you please also add CUDA tests?

state0::S
end

GRUv3Cell(in, out; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
Member:

Needs an activation

Contributor (Author):

I'm following the exact constructor for the GRU currently in Flux. If we want to add activations here, it would make sense to add them for the original GRU and LSTM as well, for consistency.

Member:

AFAIK the activations in LSTM/GRU are very specifically chosen. That's why they are currently not options, and we should probably keep that consistent.

@ToucheSir (Member), Jul 23, 2021:

Worth pointing out that TF and JAX-based libraries do allow you to customize the activation. I presume PyTorch doesn't because it lacks a non-CuDNN path for its GPU RNN backend. That said, this would be better as a separate PR that changes every RNN layer.

Member:

Interesting, is it just the output activation or all of them?

Contributor (Author):

All of them, I believe: TensorFlow GRU.

Member:

All of them, IIUC. There's a distinction between "activation" functions (by default tanh) and "gate" functions (by default sigmoid).

Member:

For a non-Google implementation, here's MXNet.

Member:

Prior art in Flux: #964

Member:

Yeah, I agree that we should have the activations in Flux generally across all layers.
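
To make that concrete, a purely hypothetical sketch of what exposing the two functions could look like; the type, field, and keyword names are illustrative and not part of this PR or of Flux's current API:

using Flux: glorot_uniform, zeros32, σ

struct MyGRUCell{A,V,S,F1,F2}
  Wi::A
  Wh::A
  b::V
  state0::S
  activation::F1        # candidate-state nonlinearity, tanh by default
  gate_activation::F2   # reset/update gate nonlinearity, sigmoid by default
end

MyGRUCell(in, out; activation = tanh, gate_activation = σ, init = glorot_uniform) =
  MyGRUCell(init(out * 3, in), init(out * 3, out), zeros32(out * 3),
            zeros32(out, 1), activation, gate_activation)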

See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
"""
GRUv3(a...; ka...) = Recur(GRUv3Cell(a...; ka...))
Member:

How do the versions differ API-wise? Does this need any extra terms?

Contributor (Author):

It shouldn't need any. The main difference is that the Wh matrix needs to be split so we can apply the reset vector appropriately, but this doesn't require any extra parameters in the constructor.

@CarloLucibello (Member)

There seems to be a lot of code duplication. Adding a symbol parameter to the type, GRU{..., Mode}, and dispatching only the call method on it seems like a better alternative.
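
A rough sketch of that idea, for concreteness (hypothetical names; not the design this PR ends up adopting):

struct MyGRUCell{Mode,A,V,S}
  Wi::A
  Wh::A
  b::V
  state0::S
end

# Only the forward pass dispatches on the mode parameter:
function (m::MyGRUCell{:v1})(h, x)
  # v1 forward pass goes here
end

function (m::MyGRUCell{:v3})(h, x)
  # v3 forward pass goes here, with Wh handled per the split discussed below
end

Constructors would then pin the mode, e.g. GRU(...) building a MyGRUCell{:v1} and GRUv3(...) a MyGRUCell{:v3}.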

@mkschleg (Contributor, Author) commented Jul 23, 2021

Right. That is definitely the question I had. The conversation was derailed a bit in #1671, as we started discussing CuDNN support.

If we were to go with the parametric option, we would have to figure out how to do the operation in the line I tagged you on (it is this line) with the current struct layout. I'm not sure how that would work, to be honest, without extra unnecessary operations or adding a new type parameter for Wh.

gx, gh = m.Wi*x, m.Wh*h
r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))
h̃ = tanh.(gate(gx, o, 3) .+ (m.Wh_h̃ * (r .* h)) .+ gate(b, o, 3))
Contributor (Author):

@CarloLucibello This line.

@mkschleg (Contributor, Author)

@DhairyaLGandhi CUDA tests added.

@DhairyaLGandhi (Member)

It might make sense to have a function that is a no-op for the regular GRU layer (without changing the struct) and splits the array for v3.
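
A hypothetical sketch of that helper (the name and slicing layout are assumptions: Wh of size 3*out × out for the regular GRU, with the last out rows acting as the candidate block for v3):

_split_gates(Wh, ::Val{:v1}) = (Wh, nothing)      # no-op for the regular GRU layer

function _split_gates(Wh, ::Val{:v3})
  o = size(Wh, 2)                                 # hidden size
  return Wh[1:2o, :], Wh[2o+1:3o, :]              # reset/update block, candidate block
end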

@mkschleg (Contributor, Author)

I remember having issues with views + Zygote a while ago. Is that still the case? (I haven't tried since v0.10.x.) If we can't use views, the split function would create a larger memory footprint, no? I think @darsnack had some thoughts on this in terms of API clarity.

@DhairyaLGandhi (Member)

Should be the same memory-wise.

@mkschleg (Contributor, Author)

I think I misunderstood what you suggested. This would be in the constructor, right? Then I agree memory would be the same.

We would have to add an extra type parameter to separate Wi and Wh (as they could be different depending on the mode).

So the struct would be:

struct GRUCell{M, Ai, Ah, V, S}
  Wi::Ai
  Wh::Ah
  b::V
  state0::S
end

Where M is the mode. My instinct is to put this first to make dispatch very clear, but we could put it elsewhere.

@darsnack (Member)

Repeating what I mentioned in the other thread: GRU(...; mode = :v3)/GRU{:v3}(...) are just as verbose as GRUv3(...). At first glance, the forward passes seem quite different for the line @mkschleg highlighted. Different forward pass == different types makes sense to me here. Having conditional forward passes based on a parameter is messier IMO, and I don't think there is a function argument we can pass in that runs the two alternate passes (without that function being complex).

We can still reduce code duplication. As an example:

function _gru_output(Wi, Wh, b, x, h)
   o = size(h, 1)    # hidden size
   gx, gh = Wi*x, Wh*h
   r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
   z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))

   return r, z
end

Then use _gru_output in all the variants. IMO it makes sense to reduce duplication via the API when it results in an intuitive interpretation of what's going on under the hood. I don't think that's the case here.
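
For instance, a hypothetical sketch of how the v3 forward pass could reuse it (not the code in this PR; it assumes the _gru_output helper above and the field names Wi, Wh, b, Wh_h̃ from the diff earlier in this thread, and recomputes gx only to keep the sketch short):

function (m::GRUv3Cell)(h, x)
  b, o = m.b, size(h, 1)
  r, z = _gru_output(m.Wi, m.Wh, b, x, h)
  gx = m.Wi * x
  h̃ = tanh.(gate(gx, o, 3) .+ m.Wh_h̃ * (r .* h) .+ gate(b, o, 3))
  h′ = (1 .- z) .* h̃ .+ z .* h
  return h′, h′
end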

@ToucheSir (Member)

Exploring the other extreme for a second, I wonder if we could make GRUv3Cell the default GRU cell and set Wh_h̃ to Zeros in the GRUCell constructor. In theory, this would reduce .+ (m.Wh_h̃ * (r .* h)) to a no-op and allow us to use one codepath.

@mkschleg (Contributor, Author) commented Jul 23, 2021

Maybe I'm misunderstanding. That line of the forward for the GRUv3Cell would then look like

h̃ = tanh.(gate(gx, o, 3) .+ r .* gate(gh, o, 3) .+ (m.Wh_h̃ * (r .* h)) .+ gate(b, o, 3))

How would we turn off r .* gate(gh, o, 3) for the v1 version?

@darsnack (Member)

If I understand it correctly, Wh in v1 == [Wh Wh_h̃] in v3. So always adopting gate(Wh * h, o, 1) .+ Wh_h̃ * (r .* h) means that Wh takes on a slightly different interpretation each time. In v1, Wh is plain Wh and Wh_h̃ == Zeros(). In v3, we would initialize Wh to be the first two slices and Wh_h̃ to be the last.

@ToucheSir (Member)

Sorry, that was my mistake! Missed that the v1 also had a middle term and assumed it was another case of https://julialang.zulipchat.com/#narrow/stream/238249-machine-learning/topic/Elman.20RNN.20definition.

GRUv3Cell(init(out * 3, in), init(out * 2, out), initb(out * 3),
init(out, out), init_state(out,1))

function (m::GRUv3Cell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {A,V,T}
Member:

Do we need to remove the types on the input and parameters?

Member:

I believe these were introduced as part of #1521. We should tackle them separately for all recurrent cells.

Contributor (Author):

Maybe we should open a central issue that collects all the updates to recurrent cells that the discussion in this PR and the related issue has raised?

Contributor (Author):

I'm definitely invested in the recurrent architectures for Flux, so I would like to help. But tracking all the outstanding issues is out of scope for the time I can commit right now.

Member:

Yeah, let's get this through and litigate general changes to the RNN interface in a separate issue.

@mkschleg (Contributor, Author)

What needs to be done to push this PR through?

mkschleg and others added 6 commits July 31, 2021 12:10
@ToucheSir (Member)

Running CI to see what it thinks. Otherwise I think the only other change would be to squash some of those intermediate update commits.

@darsnack (Member) left a comment

Can you add an entry to NEWS.md?

@mkschleg (Contributor, Author) commented Aug 2, 2021

@darsnack Where should I put the announcement in NEWS.md?

@darsnack (Member) commented Aug 2, 2021

As a new bullet under 0.12.7 (we're missing a 0.12.6 entry for some reason).

@mkschleg (Contributor, Author) commented Aug 2, 2021

Also, how would I squash those commits? I've not had to do that before.

@darsnack (Member) commented Aug 2, 2021

https://stackoverflow.com/questions/35703556/what-does-it-mean-to-squash-commits-in-git

It's only 3 more commits than necessary, so I don't think you need to bother. In the future, you can go to the "Files" tab of the PR to batch review suggestions into a single commit (instead of accepting each suggestion as a separate commit).

@darsnack (Member) left a comment

Looks good, just gotta wait for CI to pass.

@darsnack (Member) commented Aug 2, 2021

bors r+

@bors (bot) commented Aug 2, 2021

Build succeeded:

bors bot merged commit 5d2a955 into FluxML:master on Aug 2, 2021.