From aec71c6468b1d54bb61ab675b74533c8f7767d0e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 4 Nov 2024 18:42:54 -0500 Subject: [PATCH] fix: init hidden state for reactant [skip tests] [skip docs] [skip ci] --- src/utils.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 8de8408ca0..45617fa495 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -12,7 +12,6 @@ using Static: Static, StaticBool, StaticInteger, StaticSymbol using StaticArraysCore: SMatrix, SVector using LuxCore: LuxCore, AbstractLuxLayer -using MLDataDevices: get_device using NNlib: NNlib const CRC = ChainRulesCore @@ -162,11 +161,13 @@ add!!(x::Number, y::Number) = x + y add!!(::Nothing, ::Nothing) = nothing function init_rnn_hidden_state(rng::AbstractRNG, rnn, x::AbstractMatrix) - # TODO: Once we support moving `rng` to the device, we can directly initialize on the - # device - return rnn.init_state(rng, rnn.out_dims, Base.size(x, 2)) |> get_device(x) + y = similar(x, rnn.out_dims, Base.size(x, 2)) + copyto!(y, rnn.init_state(rng, size(y)...)) + return y end +@non_differentiable init_rnn_hidden_state(::Any...) + function init_trainable_rnn_hidden_state(hidden_state::AbstractVector, x::AbstractMatrix) return repeat(hidden_state, 1, Base.size(x, 2)) end