diff --git a/src/utils.jl b/src/utils.jl
index 8de8408ca0..45617fa495 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -12,7 +12,6 @@ using Static: Static, StaticBool, StaticInteger, StaticSymbol
 using StaticArraysCore: SMatrix, SVector
 
 using LuxCore: LuxCore, AbstractLuxLayer
-using MLDataDevices: get_device
 using NNlib: NNlib
 
 const CRC = ChainRulesCore
@@ -162,11 +161,13 @@ add!!(x::Number, y::Number) = x + y
 add!!(::Nothing, ::Nothing) = nothing
 
 function init_rnn_hidden_state(rng::AbstractRNG, rnn, x::AbstractMatrix)
-    # TODO: Once we support moving `rng` to the device, we can directly initialize on the
-    # device
-    return rnn.init_state(rng, rnn.out_dims, Base.size(x, 2)) |> get_device(x)
+    y = similar(x, rnn.out_dims, Base.size(x, 2))
+    copyto!(y, rnn.init_state(rng, size(y)...))
+    return y
 end
 
+@non_differentiable init_rnn_hidden_state(::Any...)
+
 function init_trainable_rnn_hidden_state(hidden_state::AbstractVector, x::AbstractMatrix)
     return repeat(hidden_state, 1, Base.size(x, 2))
 end