diff --git a/src/functor.jl b/src/functor.jl index d05489104f..9977303cb6 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -36,16 +36,14 @@ Possible values include: """ trainmode!(m, mode = true) = mode isa Bool ? testmode!(m, !mode) : testmode!(m, mode) +params!(p::Params, x::DenseArray{<:Number}, seen = IdSet()) = Functors.isleaf(x) && push!(p, x) + function params!(p::Params, x, seen = IdSet()) - if x isa AbstractArray{<:Number} && Functors.isleaf(x) - return push!(p, x) - elseif x in seen - nothing - else - push!(seen, x) - for child in trainable(x) - params!(p, child, seen) - end + x in seen && return + + push!(seen, x) + for child in trainable(x) + params!(p, child, seen) end end