Make RNN layers accept in => out #1886

Merged · 2 commits · Feb 21, 2022
9 changes: 6 additions & 3 deletions docs/src/models/recurrence.md
@@ -65,16 +65,19 @@ The `Recur` wrapper stores the state between runs in the `m.state` field.
If we use the `RNN(2, 5)` constructor – as opposed to `RNNCell` – you'll see that it's simply a wrapped cell.

```julia
julia> RNN(2, 5)
Recur(RNNCell(2, 5, tanh))
julia> RNN(2, 5) # or equivalently RNN(2 => 5)
Recur(
RNNCell(2 => 5, tanh), # 45 parameters
) # Total: 4 trainable arrays, 45 parameters,
# plus 1 non-trainable, 5 parameters, summarysize 412 bytes.
```

Equivalent to the `RNN` stateful constructor, `LSTM` and `GRU` are also available.

Using these tools, we can now build the model shown in the above diagram with:

```julia
m = Chain(RNN(2, 5), Dense(5 => 1))
m = Chain(RNN(2 => 5), Dense(5 => 1))
```
In this example, each output has only one component.
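
As a quick check of the updated docs (an editor's sketch, not part of this diff; it assumes Flux with this PR applied), the model above can be run over a short sequence like so:

```julia
using Flux

m = Chain(RNN(2 => 5), Dense(5 => 1))   # the model from the docs above
xs = [rand(Float32, 2) for _ in 1:4]    # a sequence of 4 length-2 feature vectors
Flux.reset!(m)                          # clear the hidden state before a new sequence
ys = [m(x) for x in xs]                 # each ys[t] is a 1-element output vector
```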

5 changes: 5 additions & 0 deletions src/deprecations.jl
@@ -27,3 +27,8 @@ Bilinear(in1::Integer, in2::Integer, out::Integer, σ = identity; kw...) =
Bilinear((in1, in2) => out, σ; kw...)
Embedding(in::Integer, out::Integer; kw...) = Embedding(in => out; kw...)

RNNCell(in::Integer, out::Integer, σ = tanh; kw...) = RNNCell(in => out, σ; kw...)
LSTMCell(in::Integer, out::Integer; kw...) = LSTMCell(in => out; kw...)

GRUCell(in::Integer, out::Integer; kw...) = GRUCell(in => out; kw...)
GRUv3Cell(in::Integer, out::Integer; kw...) = GRUv3Cell(in => out; kw...)
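
These shims keep the old positional constructors working by forwarding to the new `Pair` methods. A minimal sketch of the effect (an editor's note, not part of the diff; it assumes Flux with this PR applied, and qualifies `RNNCell` since it may not be exported):

```julia
using Flux

c1 = Flux.RNNCell(2 => 5)   # new syntax
c2 = Flux.RNNCell(2, 5)     # old syntax, forwarded through the method above
size(c1.Wi) == size(c2.Wi) == (5, 2)   # both build a 5×2 input weight matrix
```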
52 changes: 26 additions & 26 deletions src/layers/recurrent.jl
@@ -100,7 +100,7 @@ struct RNNCell{F,A,V,S}
state0::S
end

RNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform, initb=zeros32, init_state=zeros32) =
RNNCell((in, out)::Pair, σ=tanh; init=Flux.glorot_uniform, initb=zeros32, init_state=zeros32) =
RNNCell(σ, init(out, in), init(out, out), initb(out), init_state(out,1))

function (m::RNNCell{F,A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {F,A,V,T}
@@ -113,26 +113,26 @@ end
@functor RNNCell

function Base.show(io::IO, l::RNNCell)
print(io, "RNNCell(", size(l.Wi, 2), ", ", size(l.Wi, 1))
print(io, "RNNCell(", size(l.Wi, 2), " => ", size(l.Wi, 1))
l.σ == identity || print(io, ", ", l.σ)
print(io, ")")
end
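
# --- Editor's sketch (not part of this diff): the `(in, out)::Pair` destructuring
# used by the constructors above, shown on a standalone example. `2 => 5` is a
# `Pair`, and the method signature unpacks it directly into `in` and `out`.
# `ToyCell` and its `init` default are hypothetical names used only for illustration.
struct ToyCell{A}
    W::A
end
ToyCell((in, out)::Pair; init = (dims...) -> randn(Float32, dims...)) = ToyCell(init(out, in))
size(ToyCell(2 => 5).W) == (5, 2)   # weights follow the same `init(out, in)` convention
# --- end of editor's sketch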

"""
RNN(in::Integer, out::Integer, σ = tanh)
RNN(in => out, σ = tanh)

The most basic recurrent layer; essentially acts as a `Dense` layer, but with the
output fed back into the input each time step.

The parameters `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as a `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`.
The arguments `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as an `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`.

This constructor is syntactic sugar for `Recur(RNNCell(a...))`, and so RNNs are stateful. Note that the state shape can change depending on the inputs, and so it is good to `reset!` the model between inference calls if the batch size changes. See the examples below.

# Examples
```jldoctest
julia> r = RNN(3, 5)
julia> r = RNN(3 => 5)
Recur(
RNNCell(3, 5, tanh), # 50 parameters
RNNCell(3 => 5, tanh), # 50 parameters
) # Total: 4 trainable arrays, 50 parameters,
# plus 1 non-trainable, 5 parameters, summarysize 432 bytes.

@@ -150,9 +150,9 @@ julia> r(rand(Float32, 3, 10)) |> size # batch size of 10
Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. See the following example:

```julia
julia> r = RNN(3, 5)
julia> r = RNN(3 => 5)
Recur(
RNNCell(3, 5, tanh), # 50 parameters
RNNCell(3 => 5, tanh), # 50 parameters
) # Total: 4 trainable arrays, 50 parameters,
# plus 1 non-trainable, 5 parameters, summarysize 432 bytes.

@@ -187,7 +187,7 @@ struct LSTMCell{A,V,S}
state0::S
end

function LSTMCell(in::Integer, out::Integer;
function LSTMCell((in, out)::Pair{<:Integer, <:Integer};
init = glorot_uniform,
initb = zeros32,
init_state = zeros32)
@@ -208,15 +208,15 @@ end
@functor LSTMCell

Base.show(io::IO, l::LSTMCell) =
print(io, "LSTMCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷4, ")")
print(io, "LSTMCell(", size(l.Wi, 2), " => ", size(l.Wi, 1)÷4, ")")

"""
LSTM(in::Integer, out::Integer)
LSTM(in => out)

[Long Short Term Memory](https://www.researchgate.net/publication/13853244_Long_Short-term_Memory)
recurrent layer. Behaves like an RNN but generally exhibits a longer memory span over sequences.

The parameters `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as a `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`.
The arguments `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as an `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`.

This constructor is syntactic sugar for `Recur(LSTMCell(a...))`, and so LSTMs are stateful. Note that the state shape can change depending on the inputs, and so it is good to `reset!` the model between inference calls if the batch size changes. See the examples below.

@@ -225,9 +225,9 @@ for a good overview of the internals.

# Examples
```jldoctest
julia> l = LSTM(3, 5)
julia> l = LSTM(3 => 5)
Recur(
LSTMCell(3, 5), # 190 parameters
LSTMCell(3 => 5), # 190 parameters
) # Total: 5 trainable arrays, 190 parameters,
# plus 2 non-trainable, 10 parameters, summarysize 1.062 KiB.

@@ -261,7 +261,7 @@ struct GRUCell{A,V,S}
state0::S
end

GRUCell(in, out; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
GRUCell((in, out)::Pair; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
GRUCell(init(out * 3, in), init(out * 3, out), initb(out * 3), init_state(out,1))

function (m::GRUCell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {A,V,T}
@@ -276,16 +276,16 @@ end
@functor GRUCell

Base.show(io::IO, l::GRUCell) =
print(io, "GRUCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷3, ")")
print(io, "GRUCell(", size(l.Wi, 2), " => ", size(l.Wi, 1)÷3, ")")

"""
GRU(in::Integer, out::Integer)
GRU(in => out)

[Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v1) layer. Behaves like an
RNN but generally exhibits a longer memory span over sequences. This implements
the variant proposed in v1 of the referenced paper.

The parameters `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as a `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`.
The integer arguments `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as an `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`.

This constructor is syntactic sugar for `Recur(GRUCell(a...))`, and so GRUs are stateful. Note that the state shape can change depending on the inputs, and so it is good to `reset!` the model between inference calls if the batch size changes. See the examples below.

@@ -294,9 +294,9 @@ for a good overview of the internals.

# Examples
```jldoctest
julia> g = GRU(3, 5)
julia> g = GRU(3 => 5)
Recur(
GRUCell(3, 5), # 140 parameters
GRUCell(3 => 5), # 140 parameters
) # Total: 4 trainable arrays, 140 parameters,
# plus 1 non-trainable, 5 parameters, summarysize 792 bytes.

@@ -325,7 +325,7 @@ struct GRUv3Cell{A,V,S}
state0::S
end

GRUv3Cell(in, out; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
GRUv3Cell((in, out)::Pair; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
GRUv3Cell(init(out * 3, in), init(out * 2, out), initb(out * 3),
init(out, out), init_state(out,1))

@@ -341,16 +341,16 @@ end
@functor GRUv3Cell

Base.show(io::IO, l::GRUv3Cell) =
print(io, "GRUv3Cell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷3, ")")
print(io, "GRUv3Cell(", size(l.Wi, 2), " => ", size(l.Wi, 1)÷3, ")")

"""
GRUv3(in::Integer, out::Integer)
GRUv3(in => out)

[Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v3) layer. Behaves like an
RNN but generally exhibits a longer memory span over sequences. This implements
the variant proposed in v3 of the referenced paper.

The parameters `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as a `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`.
The arguments `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as an `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`.

This constructor is syntactic sugar for `Recur(GRUv3Cell(a...))`, and so GRUv3s are stateful. Note that the state shape can change depending on the inputs, and so it is good to `reset!` the model between inference calls if the batch size changes. See the examples below.

@@ -359,9 +359,9 @@ for a good overview of the internals.

# Examples
```jldoctest
julia> g = GRUv3(3, 5)
julia> g = GRUv3(3 => 5)
Recur(
GRUv3Cell(3, 5), # 140 parameters
GRUv3Cell(3 => 5), # 140 parameters
) # Total: 5 trainable arrays, 140 parameters,
# plus 1 non-trainable, 5 parameters, summarysize 848 bytes.

14 changes: 7 additions & 7 deletions test/layers/recurrent.jl
@@ -2,7 +2,7 @@
@testset "BPTT-1D" begin
seq = [rand(Float32, 2) for i = 1:3]
for r ∈ [RNN,]
rnn = r(2, 3)
rnn = r(2 => 3)
Flux.reset!(rnn)
grads_seq = gradient(Flux.params(rnn)) do
sum(rnn.(seq)[3])
@@ -24,7 +24,7 @@ end
@testset "BPTT-2D" begin
seq = [rand(Float32, (2, 1)) for i = 1:3]
for r ∈ [RNN,]
rnn = r(2, 3)
rnn = r(2 => 3)
Flux.reset!(rnn)
grads_seq = gradient(Flux.params(rnn)) do
sum(rnn.(seq)[3])
@@ -44,7 +44,7 @@ end

@testset "BPTT-3D" begin
seq = rand(Float32, (2, 1, 3))
rnn = RNN(2, 3)
rnn = RNN(2 => 3)
Flux.reset!(rnn)
grads_seq = gradient(Flux.params(rnn)) do
sum(rnn(seq)[:, :, 3])
@@ -70,9 +70,9 @@ end

@testset "RNN-shapes" begin
@testset for R in [RNN, GRU, LSTM, GRUv3]
m1 = R(3, 5)
m2 = R(3, 5)
m3 = R(3, 5)
m1 = R(3 => 5)
m2 = R(3 => 5)
m3 = R(3, 5) # leave one to test the silently deprecated "," not "=>" notation
x1 = rand(Float32, 3)
x2 = rand(Float32, 3, 1)
x3 = rand(Float32, 3, 1, 2)
@@ -90,7 +90,7 @@ end

@testset "RNN-input-state-eltypes" begin
@testset for R in [RNN, GRU, LSTM, GRUv3]
m = R(3, 5)
m = R(3 => 5)
x = rand(Float64, 3, 1)
Flux.reset!(m)
@test_throws MethodError m(x)
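
For reference, a minimal sketch of the shape behaviour these tests exercise, outside the test harness (an editor's note, not part of the diff; it assumes Flux with this PR applied):

```julia
using Flux

m = RNN(3 => 5)
Flux.reset!(m); size(m(rand(Float32, 3)))        # (5,)       – single feature vector
Flux.reset!(m); size(m(rand(Float32, 3, 7)))     # (5, 7)     – batch of 7
Flux.reset!(m); size(m(rand(Float32, 3, 7, 4)))  # (5, 7, 4)  – batch of 7 over 4 time steps
```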