From d8cfa86ece9bc7345a31805cb4c4717efeb04a2b Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sun, 20 Feb 2022 16:05:54 -0500
Subject: [PATCH 1/2] let RNN layers accept in => out

---
 docs/src/models/recurrence.md |  9 ++++--
 src/deprecations.jl           |  5 ++++
 src/layers/recurrent.jl       | 52 +++++++++++++++++------------------
 test/layers/recurrent.jl      | 14 +++++-----
 4 files changed, 44 insertions(+), 36 deletions(-)

diff --git a/docs/src/models/recurrence.md b/docs/src/models/recurrence.md
index ee296dc5d9..d92f6472c2 100644
--- a/docs/src/models/recurrence.md
+++ b/docs/src/models/recurrence.md
@@ -65,8 +65,11 @@ The `Recur` wrapper stores the state between runs in the `m.state` field.
 If we use the `RNN(2, 5)` constructor – as opposed to `RNNCell` – you'll see that it's simply a wrapped cell.
 
 ```julia
-julia> RNN(2, 5)
-Recur(RNNCell(2, 5, tanh))
+julia> RNN(2, 5)  # or equivalently RNN(2 => 5)
+Recur(
+  RNNCell(2 => 5, tanh),                # 45 parameters
+)                 # Total: 4 trainable arrays, 45 parameters,
+                  # plus 1 non-trainable, 5 parameters, summarysize 412 bytes.
 ```
 
 Equivalent to the `RNN` stateful constructor, `LSTM` and `GRU` are also available.
@@ -74,7 +77,7 @@ Equivalent to the `RNN` stateful constructor, `LSTM` and `GRU` are also availabl
 Using these tools, we can now build the model shown in the above diagram with:
 
 ```julia
-m = Chain(RNN(2, 5), Dense(5 => 1))
+m = Chain(RNN(2 => 5), Dense(5 => 1))
 ```
 
 In this example, each output has only one component.
diff --git a/src/deprecations.jl b/src/deprecations.jl
index 479709eab7..e258f41897 100644
--- a/src/deprecations.jl
+++ b/src/deprecations.jl
@@ -27,3 +27,8 @@ Bilinear(in1::Integer, in2::Integer, out::Integer, σ = identity; kw...) =
   Bilinear((in1, in2) => out, σ; kw...)
 
 Embedding(in::Integer, out::Integer; kw...) = Embedding(in => out; kw...)
+RNNCell(in::Integer, out::Integer, σ = tanh; kw...) = RNNCell(in => out, σ; kw...)
+LSTMCell(in::Integer, out::Integer; kw...) = LSTMCell(in => out; kw...)
+
+GRUCell(in::Integer, out::Integer; kw...) = GRUCell(in => out; kw...)
+GRUv3Cell(in::Integer, out::Integer; kw...) = GRUv3Cell(in => out; kw...)
diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index 6e10a75cef..9929c65a32 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -100,7 +100,7 @@ struct RNNCell{F,A,V,S}
   state0::S
 end
 
-RNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform, initb=zeros32, init_state=zeros32) =
+RNNCell((in, out)::Pair, σ=tanh; init=Flux.glorot_uniform, initb=zeros32, init_state=zeros32) =
   RNNCell(σ, init(out, in), init(out, out), initb(out), init_state(out,1))
 
 function (m::RNNCell{F,A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {F,A,V,T}
@@ -113,26 +113,26 @@ end
 @functor RNNCell
 
 function Base.show(io::IO, l::RNNCell)
-  print(io, "RNNCell(", size(l.Wi, 2), ", ", size(l.Wi, 1))
+  print(io, "RNNCell(", size(l.Wi, 2), " => ", size(l.Wi, 1))
   l.σ == identity || print(io, ", ", l.σ)
   print(io, ")")
 end
 
 """
-    RNN(in::Integer, out::Integer, σ = tanh)
+    RNN(in => out, σ = tanh)
 
 The most basic recurrent layer; essentially acts as a `Dense` layer, but with the output fed back into the input each time step.
 
-The parameters `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as a `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`.
+The arguments `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as an `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`.
 
 This constructor is syntactic sugar for `Recur(RNNCell(a...))`, and so RNNs are stateful. Note that the state shape can change depending on the inputs, and so it is good to `reset!` the model between inference calls if the batch size changes. See the examples below.
 
 # Examples
 ```jldoctest
-julia> r = RNN(3, 5)
+julia> r = RNN(3 => 5)
 Recur(
-  RNNCell(3, 5, tanh),                  # 50 parameters
+  RNNCell(3 => 5, tanh),                # 50 parameters
 )                 # Total: 4 trainable arrays, 50 parameters,
                   # plus 1 non-trainable, 5 parameters, summarysize 432 bytes.
@@ -150,9 +150,9 @@ julia> r(rand(Float32, 3, 10)) |> size # batch size of 10
     Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. See the following example:
 
     ```julia
-    julia> r = RNN(3, 5)
+    julia> r = RNN(3 => 5)
     Recur(
-      RNNCell(3, 5, tanh),              # 50 parameters
+      RNNCell(3 => 5, tanh),            # 50 parameters
     )                 # Total: 4 trainable arrays, 50 parameters,
                       # plus 1 non-trainable, 5 parameters, summarysize 432 bytes.
@@ -187,7 +187,7 @@ struct LSTMCell{A,V,S}
   state0::S
 end
 
-function LSTMCell(in::Integer, out::Integer;
+function LSTMCell((in, out)::Pair{<:Integer, <:Integer};
                   init = glorot_uniform,
                   initb = zeros32,
                   init_state = zeros32)
@@ -208,15 +208,15 @@ end
 @functor LSTMCell
 
 Base.show(io::IO, l::LSTMCell) =
-  print(io, "LSTMCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷4, ")")
+  print(io, "LSTMCell(", size(l.Wi, 2), " => ", size(l.Wi, 1)÷4, ")")
 
 """
-    LSTM(in::Integer, out::Integer)
+    LSTM(in => out)
 
 [Long Short Term Memory](https://www.researchgate.net/publication/13853244_Long_Short-term_Memory) recurrent layer. Behaves like an RNN but generally exhibits a longer memory span over sequences.
 
-The parameters `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as a `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`.
+The arguments `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as an `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`.
 
 This constructor is syntactic sugar for `Recur(LSTMCell(a...))`, and so LSTMs are stateful. Note that the state shape can change depending on the inputs, and so it is good to `reset!` the model between inference calls if the batch size changes. See the examples below.
 
 See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
 for a good overview of the internals.
 
 # Examples
 ```jldoctest
-julia> l = LSTM(3, 5)
+julia> l = LSTM(3 => 5)
 Recur(
-  LSTMCell(3, 5),                       # 190 parameters
+  LSTMCell(3 => 5),                     # 190 parameters
 )                 # Total: 5 trainable arrays, 190 parameters,
                   # plus 2 non-trainable, 10 parameters, summarysize 1.062 KiB.
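The pattern behind these signature changes is plain Julia argument destructuring: `(in, out)::Pair` binds the two halves of a `Pair`, so a call like `Cell(2 => 5)` arrives as `in = 2`, `out = 5`, and the one-line methods added to `src/deprecations.jl` keep the old positional calls working by forwarding them. A minimal sketch of the same pattern outside Flux (`ToyCell` and its `init` keyword are illustrative stand-ins, not part of the patch):

```julia
# Toy stand-in for the real cells: the `(in, out)::Pair` signature
# destructures a Pair, so ToyCell(2 => 5) binds in = 2 and out = 5.
struct ToyCell{A}
    Wi::A
end

ToyCell((in, out)::Pair; init = randn) = ToyCell(init(out, in))

# Deprecation shim in the style of src/deprecations.jl: the old
# positional form forwards to the new Pair method.
ToyCell(in::Integer, out::Integer; kw...) = ToyCell(in => out; kw...)

c1 = ToyCell(2 => 5)         # new notation
c2 = ToyCell(2, 5)           # old notation, forwarded without a warning
size(c1.Wi) == size(c2.Wi)   # true: both weight matrices are 5×2
```

Because the shims emit no depwarn, old positional calls keep working silently, which is why the test changes later in the patch deliberately leave one `R(3, 5)` call in place.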
@@ -261,7 +261,7 @@ struct GRUCell{A,V,S}
   state0::S
 end
 
-GRUCell(in, out; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
+GRUCell((in, out)::Pair; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
   GRUCell(init(out * 3, in), init(out * 3, out), initb(out * 3), init_state(out,1))
 
 function (m::GRUCell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {A,V,T}
@@ -276,16 +276,16 @@ end
 @functor GRUCell
 
 Base.show(io::IO, l::GRUCell) =
-  print(io, "GRUCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷3, ")")
+  print(io, "GRUCell(", size(l.Wi, 2), " => ", size(l.Wi, 1)÷3, ")")
 
 """
-    GRU(in::Integer, out::Integer)
+    GRU(in => out)
 
 [Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v1) layer. Behaves like an RNN but generally exhibits a longer memory span over sequences. This implements the variant proposed in v1 of the referenced paper.
 
-The parameters `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as a `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`.
+The integer arguments `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as an `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`.
 
 This constructor is syntactic sugar for `Recur(GRUCell(a...))`, and so GRUs are stateful. Note that the state shape can change depending on the inputs, and so it is good to `reset!` the model between inference calls if the batch size changes. See the examples below.
 
 See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
 for a good overview of the internals.
 
 # Examples
 ```jldoctest
-julia> g = GRU(3, 5)
+julia> g = GRU(3 => 5)
 Recur(
-  GRUCell(3, 5),                        # 140 parameters
+  GRUCell(3 => 5),                      # 140 parameters
 )                 # Total: 4 trainable arrays, 140 parameters,
                   # plus 1 non-trainable, 5 parameters, summarysize 792 bytes.
@@ -325,7 +325,7 @@ struct GRUv3Cell{A,V,S}
   state0::S
 end
 
-GRUv3Cell(in, out; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
+GRUv3Cell((in, out)::Pair; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
   GRUv3Cell(init(out * 3, in), init(out * 2, out), initb(out * 3),
             init(out, out), init_state(out,1))
 
@@ -341,16 +341,16 @@ end
 @functor GRUv3Cell
 
 Base.show(io::IO, l::GRUv3Cell) =
-  print(io, "GRUv3Cell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷3, ")")
+  print(io, "GRUv3Cell(", size(l.Wi, 2), " => ", size(l.Wi, 1)÷3, ")")
 
 """
-    GRUv3(in::Integer, out::Integer)
+    GRUv3(in => out)
 
 [Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v3) layer. Behaves like an RNN but generally exhibits a longer memory span over sequences. This implements the variant proposed in v3 of the referenced paper.
 
-The parameters `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as a `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`.
+The arguments `in` and `out` describe the size of the feature vectors passed as input and as output. That is, it accepts a vector of length `in` or a batch of vectors represented as an `in x B` matrix and outputs a vector of length `out` or a batch of vectors of size `out x B`.
 
 This constructor is syntactic sugar for `Recur(GRUv3Cell(a...))`, and so GRUv3s are stateful. Note that the state shape can change depending on the inputs, and so it is good to `reset!` the model between inference calls if the batch size changes. See the examples below.
@@ -359,9 +359,9 @@ for a good overview of the internals.
 # Examples
 ```jldoctest
-julia> g = GRUv3(3, 5)
+julia> g = GRUv3(3 => 5)
 Recur(
-  GRUv3Cell(3, 5),                      # 140 parameters
+  GRUv3Cell(3 => 5),                    # 140 parameters
 )                 # Total: 5 trainable arrays, 140 parameters,
                   # plus 1 non-trainable, 5 parameters, summarysize 848 bytes.
diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl
index f10a2449c9..d5695aeba1 100644
--- a/test/layers/recurrent.jl
+++ b/test/layers/recurrent.jl
@@ -2,7 +2,7 @@
 @testset "BPTT-1D" begin
   seq = [rand(Float32, 2) for i = 1:3]
   for r ∈ [RNN,]
-    rnn = r(2, 3)
+    rnn = r(2 => 3)
     Flux.reset!(rnn)
     grads_seq = gradient(Flux.params(rnn)) do
       sum(rnn.(seq)[3])
@@ -24,7 +24,7 @@ end
 @testset "BPTT-2D" begin
   seq = [rand(Float32, (2, 1)) for i = 1:3]
   for r ∈ [RNN,]
-    rnn = r(2, 3)
+    rnn = r(2 => 3)
     Flux.reset!(rnn)
     grads_seq = gradient(Flux.params(rnn)) do
       sum(rnn.(seq)[3])
@@ -44,7 +44,7 @@ end
 
 @testset "BPTT-3D" begin
   seq = rand(Float32, (2, 1, 3))
-  rnn = RNN(2, 3)
+  rnn = RNN(2 => 3)
   Flux.reset!(rnn)
   grads_seq = gradient(Flux.params(rnn)) do
     sum(rnn(seq)[:, :, 3])
@@ -70,9 +70,9 @@ end
 
 @testset "RNN-shapes" begin
   @testset for R in [RNN, GRU, LSTM, GRUv3]
-    m1 = R(3, 5)
-    m2 = R(3, 5)
-    m3 = R(3, 5)
+    m1 = R(3 => 5)
+    m2 = R(3 => 5)
+    m3 = R(3, 5)  # leave one positional, to exercise the silently deprecated R(in, out) notation
     x1 = rand(Float32, 3)
     x2 = rand(Float32, 3, 1)
     x3 = rand(Float32, 3, 1, 2)
@@ -90,7 +90,7 @@ end
 
 @testset "RNN-input-state-eltypes" begin
   @testset for R in [RNN, GRU, LSTM, GRUv3]
-    m = R(3, 5)
+    m = R(3 => 5)
     x = rand(Float64, 3, 1)
     Flux.reset!(m)
     @test_throws MethodError m(x)

From 25b4c63388b5aff824faab2d94d7701c0eb0e94f Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sun, 20 Feb 2022 19:36:14 -0500
Subject: [PATCH 2/2] make all signatures agree

---
 src/layers/recurrent.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index 9929c65a32..02d9b07089 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -187,7 +187,7 @@ struct LSTMCell{A,V,S}
   state0::S
 end
 
-function LSTMCell((in, out)::Pair{<:Integer, <:Integer};
+function LSTMCell((in, out)::Pair;
                   init = glorot_uniform,
                   initb = zeros32,
                   init_state = zeros32)
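Taken together, the two patches leave every recurrent constructor accepting `in => out`, with the old `in, out` form still accepted through the shims. A short usage sketch of the behaviour the docstrings above describe (assuming only the Flux API shown in this patch, `RNN` and `Flux.reset!`, with illustrative data):

```julia
using Flux

r = RNN(3 => 5)                   # new Pair notation; RNN(3, 5) still works via the shim

r(rand(Float32, 3)) |> size       # (5,)     a single sample
r(rand(Float32, 3, 10)) |> size   # (5, 10)  a batch of 10; the hidden state is now 5×10

Flux.reset!(r)                    # drop the 5×10 state before the batch size changes,
r(rand(Float32, 3)) |> size       # (5,)     so a single sample again gives a length-5 output
```

Skipping the `Flux.reset!` call here would feed a length-3 vector into a model still holding a 5×10 state, the "unexpected behavior" the RNN docstring warns about.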