Skip to content

Commit

Permalink
fix: init hidden state for reactant
Browse files Browse the repository at this point in the history
[skip tests] [skip docs] [skip ci]
  • Loading branch information
avik-pal committed Nov 8, 2024
1 parent ed0d75c commit 33a4268
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ using Static: Static, StaticBool, StaticInteger, StaticSymbol
using StaticArraysCore: SMatrix, SVector

using LuxCore: LuxCore, AbstractLuxLayer
using MLDataDevices: get_device
using NNlib: NNlib

const CRC = ChainRulesCore
Expand Down Expand Up @@ -162,11 +161,13 @@ add!!(x::Number, y::Number) = x + y
add!!(::Nothing, ::Nothing) = nothing

function init_rnn_hidden_state(rng::AbstractRNG, rnn, x::AbstractMatrix)
# TODO: Once we support moving `rng` to the device, we can directly initialize on the
# device
return rnn.init_state(rng, rnn.out_dims, Base.size(x, 2)) |> get_device(x)
y = similar(x, rnn.out_dims, Base.size(x, 2))
copyto!(y, rnn.init_state(rng, size(y)...))
return y
end

@non_differentiable init_rnn_hidden_state(::Any...)

function init_trainable_rnn_hidden_state(hidden_state::AbstractVector, x::AbstractMatrix)
return repeat(hidden_state, 1, Base.size(x, 2))
end
Expand Down

0 comments on commit 33a4268

Please sign in to comment.