
Adding GRUv3 support. #1675

Merged: 10 commits, Aug 2, 2021
NEWS.md (3 additions, 0 deletions)
@@ -1,5 +1,8 @@
# Flux Release Notes

## v0.12.7
* Added support for [`GRUv3`](https://github.com/FluxML/Flux.jl/pull/1675)

## v0.12.5
* Added option to configure [`groups`](https://github.com/FluxML/Flux.jl/pull/1531) in `Conv`.

src/Flux.jl (1 addition, 1 deletion)
@@ -11,7 +11,7 @@ using Zygote: Params, @adjoint, gradient, pullback, @nograd
export gradient

export Chain, Dense, Maxout, SkipConnection, Parallel, flatten,
-  RNN, LSTM, GRU,
+  RNN, LSTM, GRU, GRUv3,
SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
src/layers/recurrent.jl (56 additions, 5 deletions)
@@ -183,6 +183,15 @@ end

# GRU

function _gru_output(Wi, Wh, b, x, h)
o = size(h, 1)
gx, gh = Wi*x, Wh*h
r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))

return gx, gh, r, z
end
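
For context when reading this helper: `gate` is a small slicing utility defined earlier in src/layers/recurrent.jl (outside this diff). A rough sketch of how it behaves, inferred from its use here (the real definition may differ in detail, e.g. by returning views):

```julia
# Rough sketch of Flux's internal `gate` helper (not part of this diff).
gate(h::Integer, n::Integer) = (1:h) .+ h*(n - 1)     # indices of the n-th block of size h
gate(x::AbstractVector, h, n) = x[gate(h, n)]         # n-th block of a stacked vector
gate(x::AbstractMatrix, h, n) = x[gate(h, n), :]      # n-th row-block of a stacked matrix

# So with `o = size(h, 1)`, `gate(gx, o, 1)` and `gate(gx, o, 2)` pick the reset- and
# update-gate pre-activations out of the stacked product `Wi * x`.
```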

struct GRUCell{A,V,S}
Wi::A
Wh::A
@@ -195,9 +204,7 @@ GRUCell(in, out; init = glorot_uniform, initb = zeros32, init_state = zeros32) =

function (m::GRUCell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {A,V,T}
b, o = m.b, size(h, 1)
-  gx, gh = m.Wi*x, m.Wh*h
-  r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
-  z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))
+  gx, gh, r, z = _gru_output(m.Wi, m.Wh, b, x, h)
h̃ = tanh.(gate(gx, o, 3) .+ r .* gate(gh, o, 3) .+ gate(b, o, 3))
h′ = (1 .- z) .* h̃ .+ z .* h
sz = size(x)
@@ -212,8 +219,9 @@ Base.show(io::IO, l::GRUCell) =
"""
GRU(in::Integer, out::Integer)

-[Gated Recurrent Unit](https://arxiv.org/abs/1406.1078) layer. Behaves like an
-RNN but generally exhibits a longer memory span over sequences.
+[Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v1) layer. Behaves like an
+RNN but generally exhibits a longer memory span over sequences. This implements
+the variant proposed in v1 of the referenced paper.

See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
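
A minimal usage sketch for the existing `GRU` layer (not part of the diff; shapes mirror those in the tests below):

```julia
using Flux

gru = GRU(3, 5)                  # 3 input features, 5 hidden units
Flux.reset!(gru)                 # reset the hidden state to the initial state
y  = gru(rand(Float32, 3))       # one time step; `y` has length 5
ys = [gru(rand(Float32, 3)) for t in 1:10]   # state is carried across successive calls
```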
@@ -233,6 +241,49 @@ function Base.getproperty(m::GRUCell, sym::Symbol)
end
end


# GRU v3

struct GRUv3Cell{A,V,S}
Wi::A
Wh::A
b::V
Wh_h̃::A
state0::S
end

GRUv3Cell(in, out; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
Member: Needs an activation

Contributor Author: I'm following the exact constructor of the GRU currently in Flux. If we want to add activations here, it would make sense to add them for the original GRU and LSTM as well, for consistency.

Member: AFAIK the activations in LSTM/GRU are very specifically chosen. That's why they are currently not options, and we should probably keep that consistent.

Member (@ToucheSir, Jul 23, 2021): Worth pointing out that TF- and JAX-based libraries do allow you to customize the activation. I presume PyTorch doesn't because it lacks a non-CuDNN path for its GPU RNN backend. That said, this would be better as a separate PR that changes every RNN layer.

Member: Interesting, is it just the output activation or all of them?

Contributor Author: All of them, I believe: see TensorFlow's GRU.

Member: All of them, IIUC. There's a distinction between "activation" functions (by default tanh) and "gate" functions (by default sigmoid).

Member: For a non-Google implementation, here's MXNet.

Member: Prior art in Flux: #964

Member: Yeah, I agree that we should have the activations in Flux generally, across all layers.

GRUv3Cell(init(out * 3, in), init(out * 2, out), initb(out * 3),
init(out, out), init_state(out,1))
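
To make the review discussion above concrete, here is a purely hypothetical sketch of a gate/activation-configurable cell in the TF/Keras style. Nothing here is part of this PR or of Flux's API: `MyGRUv3Cell` and its keywords are invented names, the defaults keep the fixed tanh/σ choices discussed above, and it leans on Flux internals (`gate`, `zeros32`) that are assumed to be importable at this version.

```julia
using Flux
using Flux: gate, zeros32        # internal Flux helpers, assumed importable here

# Hypothetical only: a GRUv3-style cell with a configurable candidate "activation"
# (default tanh) and "gate" function (default σ), mirroring the TF/Keras split.
struct MyGRUv3Cell{F,G,A,V,S}
  activation::F    # applied to the candidate state h̃
  gatefn::G        # applied to the reset and update gates
  Wi::A
  Wh::A
  b::V
  Wh_h̃::A
  state0::S
end

MyGRUv3Cell(in, out; activation = tanh, gatefn = σ,
            init = Flux.glorot_uniform, initb = zeros32, init_state = zeros32) =
  MyGRUv3Cell(activation, gatefn,
              init(out * 3, in), init(out * 2, out), initb(out * 3),
              init(out, out), init_state(out, 1))

function (m::MyGRUv3Cell)(h, x)
  b, o = m.b, size(h, 1)
  gx, gh = m.Wi * x, m.Wh * h
  r = m.gatefn.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
  z = m.gatefn.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))
  h̃ = m.activation.(gate(gx, o, 3) .+ (m.Wh_h̃ * (r .* h)) .+ gate(b, o, 3))
  h′ = (1 .- z) .* h̃ .+ z .* h
  return h′, reshape(h′, :, size(x)[2:end]...)
end

Flux.@functor MyGRUv3Cell
```

Hooking it up would then mirror the `Recur(m::GRUv3Cell) = Recur(m, m.state0)` line further down; as noted in the thread, any real change along these lines belongs in a separate PR covering RNN/LSTM/GRU uniformly.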

function (m::GRUv3Cell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {A,V,T}
Member: Do we need to remove the types on the input and parameters?

Member: I believe these were introduced as part of #1521. We should tackle them separately for all recurrent cells.

Contributor Author: Maybe we should have a central issue detailing all the updates to recurrent cells that the discussion in this PR and related issues has brought up?

Contributor Author: I'm definitely invested in the recurrent architectures for Flux, so I would like to help. But tracking all the outstanding issues is out of scope for what I can spend my time on right now.

Member: Yeah, let's get this through and litigate general changes to the RNN interface in a separate issue.

b, o = m.b, size(h, 1)
gx, gh, r, z = _gru_output(m.Wi, m.Wh, b, x, h)
h̃ = tanh.(gate(gx, o, 3) .+ (m.Wh_h̃ * (r .* h)) .+ gate(b, o, 3))
Contributor Author: @CarloLucibello This line.

h′ = (1 .- z) .* h̃ .+ z .* h
sz = size(x)
return h′, reshape(h′, :, sz[2:end]...)
end

@functor GRUv3Cell

Base.show(io::IO, l::GRUv3Cell) =
print(io, "GRUv3Cell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷3, ")")

"""
GRUv3(in::Integer, out::Integer)

[Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v3) layer. Behaves like an
RNN but generally exhibits a longer memory span over sequences. This implements
the variant proposed in v3 of the referenced paper.

See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
"""
GRUv3(a...; ka...) = Recur(GRUv3Cell(a...; ka...))
Member: How do the versions differ API-wise? Does this need any extra terms?

Contributor Author: It shouldn't need any. The main difference is that the Wh matrix needs to be split so we can apply the reset vector appropriately, but that doesn't require any extra parameters in the constructor.
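
In equations, the only difference between the two variants is where the reset gate is applied when forming the candidate state; this is exactly what the two `h̃` lines above compute. A sketch in generic GRU notation (parameter names here are illustrative, not Flux field names):

```latex
\begin{aligned}
r_t &= \sigma(W_{xr} x_t + W_{hr} h_{t-1} + b_r) \\
z_t &= \sigma(W_{xz} x_t + W_{hz} h_{t-1} + b_z) \\
\tilde h_t &= \tanh\big(W_{xh} x_t + r_t \odot (W_{hh} h_{t-1}) + b_h\big) \quad \text{(GRU, v1 variant)} \\
\tilde h_t &= \tanh\big(W_{xh} x_t + W_{hh}\,(r_t \odot h_{t-1}) + b_h\big) \quad \text{(GRUv3)} \\
h_t &= (1 - z_t) \odot \tilde h_t + z_t \odot h_{t-1}
\end{aligned}
```

In the code, the candidate-state matrix for GRUv3 is the separate `Wh_h̃` parameter, which is why `Wh` has `out * 2` rows in `GRUv3Cell` instead of the `out * 3` used by `GRUCell`.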

Recur(m::GRUv3Cell) = Recur(m, m.state0)


@adjoint function Broadcast.broadcasted(f::Recur, args...)
Zygote.∇map(__context__, f, args...)
end
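
For reference, a minimal usage sketch of the new layer (not part of the diff); the constructor and call behave exactly like `GRU`, as the tests below exercise:

```julia
using Flux

m = GRUv3(10, 5)                 # 10 input features, 5 hidden units, as in the GPU test
Flux.reset!(m)
y = m(rand(Float32, 10))         # single step: length-5 output

Flux.reset!(m)                   # reset before switching to batched input
Y = m(rand(Float32, 10, 8))      # batched step: 5×8 output

# Gradients flow through the stateful wrapper just as for RNN/GRU/LSTM:
x = rand(Float32, 10)
grads = gradient(m -> sum(m(x)), m)
```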
test/cuda/curnn.jl (2 additions, 2 deletions)
@@ -1,6 +1,6 @@
using Flux, CUDA, Test

-@testset for R in [RNN, GRU, LSTM]
+@testset for R in [RNN, GRU, LSTM, GRUv3]
m = R(10, 5) |> gpu
x = gpu(rand(10))
(m̄,) = gradient(m -> sum(m(x)), m)
@@ -12,7 +12,7 @@ using Flux, CUDA, Test
end

@testset "RNN" begin
-  @testset for R in [RNN, GRU, LSTM], batch_size in (1, 5)
+  @testset for R in [RNN, GRU, LSTM, GRUv3], batch_size in (1, 5)
rnn = R(10, 5)
curnn = fmap(gpu, rnn)

test/layers/recurrent.jl (3 additions, 3 deletions)
@@ -43,7 +43,7 @@ end
end

@testset "RNN-shapes" begin
-  @testset for R in [RNN, GRU, LSTM]
+  @testset for R in [RNN, GRU, LSTM, GRUv3]
m1 = R(3, 5)
m2 = R(3, 5)
x1 = rand(Float32, 3)
@@ -58,10 +58,10 @@ end
end

@testset "RNN-input-state-eltypes" begin
-  @testset for R in [RNN, GRU, LSTM]
+  @testset for R in [RNN, GRU, LSTM, GRUv3]
m = R(3, 5)
x = rand(Float64, 3, 1)
Flux.reset!(m)
@test_throws MethodError m(x)
end
end
end