diff --git a/Project.toml b/Project.toml
index fc0f7b63ab..177b58e27e 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "Lux"
 uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
 authors = ["Avik Pal and contributors"]
-version = "0.4.8"
+version = "0.4.9"
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -10,6 +10,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
+JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
diff --git a/src/autodiff.jl b/src/autodiff.jl
index 080c16c4c4..ccf635819f 100644
--- a/src/autodiff.jl
+++ b/src/autodiff.jl
@@ -3,6 +3,7 @@ ChainRulesCore.@non_differentiable replicate(::Any)
 ChainRulesCore.@non_differentiable update_statistics(::Any, ::Any, ::Any, ::Any, ::Any,
                                                      ::Any, ::Any)
 ChainRulesCore.@non_differentiable generate_dropout_mask(::Any, ::Any, ::Any, ::Any)
+ChainRulesCore.@non_differentiable _get_reshape_dims(::Any, ::Any)
 ChainRulesCore.@non_differentiable compute_adaptive_pooling_dims(::Any, ::Any)
 ChainRulesCore.@non_differentiable glorot_normal(::Any...)
 ChainRulesCore.@non_differentiable glorot_uniform(::Any...)
@@ -31,7 +32,63 @@ function ChainRulesCore.rrule(::typeof(dropout), rng::AbstractRNG, x::AbstractAr
     return (y, mask, rng), dropout_pullback
 end
 
-# Activation Rrules
+# Utilities
+
+function ChainRulesCore.rrule(::typeof(_reshape_into_proper_shape), ::Nothing, y)
+    function _reshape_into_proper_shape_pullback(dx)
+        return NoTangent(), NoTangent(), NoTangent()
+    end
+    return nothing, _reshape_into_proper_shape_pullback
+end
+
+function ChainRulesCore.rrule(::typeof(_reshape_into_proper_shape), x, y)
+    res = _reshape_into_proper_shape(x, y)
+    function _reshape_into_proper_shape_pullback(dx)
+        return NoTangent(), reshape(dx, size(x)), NoTangent()
+    end
+    return res, _reshape_into_proper_shape_pullback
+end
+
+function ChainRulesCore.rrule(::typeof(merge), nt1::NamedTuple{F1},
+                              nt2::NamedTuple{F2}) where {F1, F2}
+    y = merge(nt1, nt2)
+    function merge_pullback(dy)
+        dnt1 = NamedTuple((f1 => (f1 in F2 ? NoTangent() : getproperty(dy, f1))
+                           for f1 in F1))
+        dnt2 = NamedTuple((f2 => getproperty(dy, f2) for f2 in F2))
+        return (NoTangent(), dnt1, dnt2)
+    end
+    return y, merge_pullback
+end
+
+function ChainRulesCore.rrule(::typeof(vec), x::AbstractMatrix)
+    y = vec(x)
+    vec_pullback(dy) = NoTangent(), reshape(dy, size(x))
+    return y, vec_pullback
+end
+
+function ChainRulesCore.rrule(::typeof(convert), T::DataType, x::AbstractMatrix)
+    y = convert(T, x)
+    function convert_pullback(dy)
+        if dy isa NoTangent || dy isa ZeroTangent
+            dx = dy
+        else
+            dx = convert(typeof(x), dy)
+        end
+        return NoTangent(), NoTangent(), dx
+    end
+    return y, convert_pullback
+end
+
+function ChainRulesCore.rrule(::typeof(collect), v::Vector)
+    y = collect(v)
+    function collect_pullback(dy)
+        return NoTangent(), dy
+    end
+    return y, collect_pullback
+end
+
+# Activation rrules
 function ChainRulesCore.rrule(::typeof(applyactivation), f::cudnnValidActivationTypes,
                               x::CuArray{T}) where {T <: CUDNNFloat}
     mode = getCUDNNActivationMode(f)
diff --git a/src/layers/normalize.jl b/src/layers/normalize.jl
index e265af9553..ec53a5856d 100644
--- a/src/layers/normalize.jl
+++ b/src/layers/normalize.jl
@@ -68,6 +68,10 @@ Use [`Lux.testmode`](@ref) during inference.
 m = Chain(Dense(784 => 64), BatchNorm(64, relu), Dense(64 => 10), BatchNorm(10))
 ```
 
+!!! warning
+
+    Passing a batch size of 1 during training will result in NaNs.
+
 See also [`GroupNorm`](@ref)
 """
 struct BatchNorm{affine, track_stats, F1, F2, F3, N} <:
@@ -90,9 +94,13 @@ function BatchNorm(chs::Int, activation=identity; init_bias=zeros32, init_scale=
 end
 
 function initialparameters(rng::AbstractRNG, l::BatchNorm{affine}) where {affine}
-    return affine ? (scale=l.init_scale(rng, l.chs), bias=l.init_bias(rng, l.chs)) :
-           NamedTuple()
+    if affine
+        return (scale=l.init_scale(rng, l.chs), bias=l.init_bias(rng, l.chs))
+    else
+        return (scale=nothing, bias=nothing)
+    end
 end
+
 function initialstates(rng::AbstractRNG,
                        l::BatchNorm{affine, track_stats}) where {affine, track_stats}
     return if track_stats
@@ -109,9 +117,6 @@ function statelength(l::BatchNorm{affine, track_stats}) where {affine, track_sta
 end
 
 function (BN::BatchNorm)(x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N}
-    @assert size(x, N - 1) == BN.chs
-    @assert !istraining(st)||size(x, N) > 1 "During `training`, `BatchNorm` can't handle Batch Size == 1"
-
     x_normalized, xmean, xvar = normalization(x, st.running_mean, st.running_var, ps.scale,
                                               ps.bias, BN.activation,
                                               collect([1:(N - 2); N]), st.training,
@@ -278,8 +283,11 @@ function GroupNorm(chs::Integer, groups::Integer, activation=identity; init_bias
 end
 
 function initialparameters(rng::AbstractRNG, l::GroupNorm{affine}) where {affine}
-    return affine ? (scale=l.init_scale(rng, l.chs), bias=l.init_bias(rng, l.chs)) :
-           NamedTuple()
+    if affine
+        return (scale=l.init_scale(rng, l.chs), bias=l.init_bias(rng, l.chs))
+    else
+        return (scale=nothing, bias=nothing)
+    end
 end
 
 function initialstates(rng::AbstractRNG,
@@ -300,9 +308,6 @@ end
 
 function (GN::GroupNorm)(x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N}
     sz = size(x)
-    @assert N > 2
-    @assert sz[N - 1] == GN.chs
-
    x_ = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ GN.groups, GN.groups, sz[N])
 
     x_normalized, xmean, xvar = normalization(x_, st.running_mean, st.running_var, ps.scale,
diff --git a/src/nnlib.jl b/src/nnlib.jl
index 1c887ec2bb..e0aa664f46 100644
--- a/src/nnlib.jl
+++ b/src/nnlib.jl
@@ -33,8 +33,8 @@ Performs BatchNorm/GroupNorm/InstanceNorm based on input configuration
                                running_var::Union{Nothing, AbstractVector{T}},
                                scale::Union{Nothing, AbstractVector{T}},
                                bias::Union{Nothing, AbstractVector{T}}, activation,
-                               reduce_dims, t::Val, momentum::T=T(0.1), epsilon::T=T(1e-5);
-                               kwargs...) where {T, N}
+                               reduce_dims, t::Val, momentum::T=T(0.1),
+                               epsilon::T=T(1e-5)) where {T, N}
     running_mean_reshaped = _reshape_into_proper_shape(running_mean, x)
     running_var_reshaped = _reshape_into_proper_shape(running_var, x)
     scale_reshaped = _reshape_into_proper_shape(scale, x)
@@ -44,15 +44,16 @@ Performs BatchNorm/GroupNorm/InstanceNorm based on input configuration
                                                                  scale_reshaped,
                                                                  bias_reshaped, activation,
                                                                  reduce_dims, t, momentum,
-                                                                 epsilon; kwargs...)
+                                                                 epsilon)
     return x_norm, _safe_vec(running_mean_), _safe_vec(running_var_)
 end
 
 @generated function normalization_forward(x::AbstractArray{T, N}, running_mean::RM,
                                           running_var::RV, scale::S, bias::B,
                                           activation::A, reduce_dims, ::Val{training},
-                                          momentum::T=T(0.1f0), epsilon::T=T(1.0f-5);
-                                          kwargs...) where {RM, RV, S, B, T, N, A, training}
+                                          momentum::T=T(0.1f0),
+                                          epsilon::T=T(1.0f-5)) where {RM, RV, S, B, T, N,
+                                                                       A, training}
     calls = []
     if !training
         if RM == Nothing
@@ -79,16 +80,16 @@ end
 
     expr = if S != Nothing
         if A == typeof(identity)
-            :(result = @. scale * (x - batchmean) / sqrt(batchvar + epsilon) + bias)
+            :(result = scale .* (x .- batchmean) ./ sqrt.(batchvar .+ epsilon) .+ bias)
         else
-            :(result = @. activation(scale * (x - batchmean) / sqrt(batchvar + epsilon) +
-                                     bias))
+            :(result = activation.(scale .* (x .- batchmean) ./
+                                   sqrt.(batchvar .+ epsilon) .+ bias))
         end
     else
         if A == typeof(identity)
-            :(result = @. (x - batchmean) / sqrt(batchvar + epsilon))
+            :(result = (x .- batchmean) ./ sqrt.(batchvar .+ epsilon))
        else
-            :(result = @. activation((x - batchmean) / sqrt(batchvar + epsilon)))
+            :(result = activation.((x .- batchmean) ./ sqrt.(batchvar .+ epsilon)))
        end
    end
    push!(calls, expr)
@@ -115,7 +116,7 @@ end
 
 @inline function generate_dropout_mask(rng::AbstractRNG, x, p, q; dims=:)
     realfptype = float(real(eltype(x)))
     y = rand!(rng, similar(x, realfptype, _dropout_shape(x, dims)))
-    @. y = _dropout_kernel(y, p, q)
+    y .= _dropout_kernel.(y, p, q)
     return y
 end
diff --git a/test/autodiff.jl b/test/autodiff.jl
index c921dde374..ed32ac2f7f 100644
--- a/test/autodiff.jl
+++ b/test/autodiff.jl
@@ -30,3 +30,16 @@ end
 
     @test gs_x_1 == gs_x_2
 end
+
+@testset "_reshape_into_proper_shape" begin
+    x = randn(rng, Float32, 3, 2)
+    y = randn(rng, Float32, 2, 2, 6, 2)
+
+    @test size(Lux._reshape_into_proper_shape(x, y)) == (1, 1, 6, 1)
+    @inferred Lux._reshape_into_proper_shape(x, y)
+
+    gs_1 = Zygote.gradient(x -> sum(Lux._reshape_into_proper_shape(x, y)), x)[1]
+    gs_2 = Zygote.gradient(x -> sum(reshape(x, (1, 1, 6, 1))), x)[1]
+
+    @test gs_1 == gs_2
+end
diff --git a/test/layers/normalize.jl b/test/layers/normalize.jl
index 1f94cfa2c4..73f4cf3a99 100644
--- a/test/layers/normalize.jl
+++ b/test/layers/normalize.jl
@@ -6,65 +6,72 @@ rng = Random.default_rng()
 Random.seed!(rng, 0)
 
 @testset "BatchNorm" begin
-    let m = BatchNorm(2), x = [1.0f0 3.0f0 5.0f0
-                               2.0f0 4.0f0 6.0f0]
-        println(m)
-        ps, st = Lux.setup(rng, m)
-
-        @test Lux.parameterlength(m) == Lux.parameterlength(ps)
-        @test Lux.statelength(m) == Lux.statelength(st)
-
-        @test ps.bias == [0, 0] # init_bias(2)
-        @test ps.scale == [1, 1] # init_scale(2)
-
-        y, st_ = pullback(m, x, ps, st)[1]
-        @test isapprox(y, [-1.22474 0 1.22474; -1.22474 0 1.22474], atol=1.0e-5)
-        # julia> x
-        # 2×3 Array{Float64,2}:
-        # 1.0 3.0 5.0
-        # 2.0 4.0 6.0
-        #
-        # mean of batch will be
-        # (1. + 3. + 5.) / 3 = 3
-        # (2. + 4. + 6.) / 3 = 4
-        #
-        # ∴ update rule with momentum:
-        # .1 * 3 + 0 = .3
-        # .1 * 4 + 0 = .4
-        @test st_.running_mean ≈ reshape([0.3, 0.4], 2, 1)
-
-        # julia> .1 .* var(x, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
-        # 2×1 Array{Float64,2}:
-        # 1.3
-        # 1.3
-        @test st_.running_var ≈
-              0.1 .* var(x; dims=2, corrected=false) .* (3 / 2) .+ 0.9 .* [1.0, 1.0]
-
-        st_ = Lux.testmode(st_)
-        x′ = m(x, ps, st_)[1]
-        @test isapprox(x′[1], (1 .- 0.3) / sqrt(1.3), atol=1.0e-5)
-
-        @inferred m(x, ps, st)
-
-        run_JET_tests(m, x, ps, st)
-
-        test_gradient_correctness_fdm((x, ps) -> sum(first(m(x, ps, st))), x, ps;
-                                      atol=1.0f-3, rtol=1.0f-3)
-    end
-
-    let m = BatchNorm(2; track_stats=false), x = [1.0f0 3.0f0 5.0f0; 2.0f0 4.0f0 6.0f0]
+    m = BatchNorm(2)
+    x = [1.0f0 3.0f0 5.0f0
+         2.0f0 4.0f0 6.0f0]
+    println(m)
+    ps, st = Lux.setup(rng, m)
+
+    @test Lux.parameterlength(m) == Lux.parameterlength(ps)
+    @test Lux.statelength(m) == Lux.statelength(st)
+
+    @test ps.bias == [0, 0] # init_bias(2)
+    @test ps.scale == [1, 1] # init_scale(2)
+
+    y, st_ = pullback(m, x, ps, st)[1]
+    @test isapprox(y, [-1.22474 0 1.22474; -1.22474 0 1.22474], atol=1.0e-5)
+    # julia> x
+    # 2×3 Array{Float64,2}:
+    # 1.0 3.0 5.0
+    # 2.0 4.0 6.0
+
+    # mean of batch will be
+    # (1. + 3. + 5.) / 3 = 3
+    # (2. + 4. + 6.) / 3 = 4
+
+    # ∴ update rule with momentum:
+    # .1 * 3 + 0 = .3
+    # .1 * 4 + 0 = .4
+    @test st_.running_mean ≈ reshape([0.3, 0.4], 2, 1)
+
+    # julia> .1 .* var(x, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
+    # 2×1 Array{Float64,2}:
+    # 1.3
+    # 1.3
+    @test st_.running_var ≈
+          0.1 .* var(x; dims=2, corrected=false) .* (3 / 2) .+ 0.9 .* [1.0, 1.0]
+
+    st_ = Lux.testmode(st_)
+    x_ = m(x, ps, st_)[1]
+    @test isapprox(x_[1], (1 .- 0.3) / sqrt(1.3), atol=1.0e-5)
+
+    @inferred m(x, ps, st)
+
+    run_JET_tests(m, x, ps, st)
+
+    test_gradient_correctness_fdm((x, ps) -> sum(first(m(x, ps, st))), x, ps; atol=1.0f-3,
+                                  rtol=1.0f-3)
+
+    for affine in (true, false)
+        m = BatchNorm(2; affine, track_stats=false)
+        x = [1.0f0 3.0f0 5.0f0; 2.0f0 4.0f0 6.0f0]
         println(m)
         ps, st = Lux.setup(rng, m)
 
         @inferred m(x, ps, st)
         run_JET_tests(m, x, ps, st)
-        test_gradient_correctness_fdm((x, ps) -> sum(first(m(x, ps, st))), x, ps;
-                                      atol=1.0f-3, rtol=1.0f-3)
-    end
-
-    # with activation function
-    let m = BatchNorm(2, sigmoid), x = [1.0f0 3.0f0 5.0f0
-                                        2.0f0 4.0f0 6.0f0]
+        if affine
+            test_gradient_correctness_fdm((x, ps) -> sum(first(m(x, ps, st))), x, ps;
+                                          atol=1.0f-3, rtol=1.0f-3)
+        else
+            test_gradient_correctness_fdm(x -> sum(first(m(x, ps, st))), x; atol=1.0f-3,
+                                          rtol=1.0f-3)
+        end
+
+        # with activation function
+        m = BatchNorm(2, sigmoid; affine)
+        x = [1.0f0 3.0f0 5.0f0
+             2.0f0 4.0f0 6.0f0]
         println(m)
         ps, st = Lux.setup(rng, m)
         st = Lux.testmode(st)
@@ -75,18 +82,16 @@
         @inferred m(x, ps, st)
 
         run_JET_tests(m, x, ps, st)
-        test_gradient_correctness_fdm((x, ps) -> sum(first(m(x, ps, st))), x, ps;
-                                      atol=1.0f-3, rtol=1.0f-3)
-    end
-
-    let m = BatchNorm(2), x = reshape(Float32.(1:6), 3, 2, 1)
-        println(m)
-        ps, st = Lux.setup(rng, m)
-        st = Lux.trainmode(st)
-        @test_throws AssertionError m(x, ps, st)[1]
-    end
+        if affine
+            test_gradient_correctness_fdm((x, ps) -> sum(first(m(x, ps, st))), x, ps;
+                                          atol=1.0f-3, rtol=1.0f-3)
+        else
+            test_gradient_correctness_fdm(x -> sum(first(m(x, ps, st))), x; atol=1.0f-3,
+                                          rtol=1.0f-3)
+        end
 
-    let m = BatchNorm(32), x = randn(Float32, 416, 416, 32, 1)
+        m = BatchNorm(32; affine)
+        x = randn(Float32, 416, 416, 32, 1)
         println(m)
         ps, st = Lux.setup(rng, m)
         st = Lux.testmode(st)
@@ -101,76 +106,82 @@ end
     # begin tests
     squeeze(x) = dropdims(x; dims=tuple(findall(size(x) .== 1)...)) # To remove all singular dimensions
 
-    let m = GroupNorm(4, 2; track_stats=true),
-        sizes = (3, 4, 2),
-        x = reshape(collect(1:prod(sizes)), sizes)
-
-        println(m)
-        x = Float32.(x)
-        ps, st = Lux.setup(rng, m)
-        @test Lux.parameterlength(m) == Lux.parameterlength(ps)
-        @test Lux.statelength(m) == Lux.statelength(st)
-        @test ps.bias == [0, 0, 0, 0] # init_bias(32)
-        @test ps.scale == [1, 1, 1, 1] # init_scale(32)
-
-        y, st_ = pullback(m, x, ps, st)[1]
-
-        # julia> x
-        # [:, :, 1] =
-        # 1.0 4.0 7.0 10.0
-        # 2.0 5.0 8.0 11.0
-        # 3.0 6.0 9.0 12.0
-        #
-        # [:, :, 2] =
-        # 13.0 16.0 19.0 22.0
-        # 14.0 17.0 20.0 23.0
-        # 15.0 18.0 21.0 24.0
-        #
-        # mean will be
-        # (1. + 2. + 3. + 4. + 5. + 6.) / 6 = 3.5
-        # (7. + 8. + 9. + 10. + 11. + 12.) / 6 = 9.5
-        #
-        # (13. + 14. + 15. + 16. + 17. + 18.) / 6 = 15.5
-        # (19. + 20. + 21. + 22. + 23. + 24.) / 6 = 21.5
-        #
-        # mean =
-        # 3.5 15.5
-        # 9.5 21.5
-        #
-        # ∴ update rule with momentum:
-        # (1. - .1) * 0 + .1 * (3.5 + 15.5) / 2 = 0.95
-        # (1. - .1) * 0 + .1 * (9.5 + 21.5) / 2 = 1.55
-        @test st_.running_mean ≈ [0.95, 1.55]
-        n = prod(size(x)) ÷ m.groups ÷ size(x)[end]
-        corr = n / (n - 1)
-        z = reshape(x, 3, 2, 2, 2)
-        variance = var(z; dims=(1, 2), corrected=false)
-        @test st_.running_var ≈ 0.1 * corr * vec(mean(variance; dims=4)) .+ 0.9 * 1
-
-        st__ = Lux.testmode(st_)
-        y, st__ = m(x, ps, st__)
-        out = (z .- reshape(st_.running_mean, 1, 1, 2, 1)) ./
-              sqrt.(reshape(st_.running_var, 1, 1, 2, 1) .+ 1.0f-5)
-        @test y≈reshape(out, size(x)) atol=1.0e-5
-
-        @inferred m(x, ps, st)
-        run_JET_tests(m, x, ps, st)
-        test_gradient_correctness_fdm(ps -> sum(first(m(x, ps, st))), ps; atol=1.0f-3,
-                                      rtol=1.0f-3)
-    end
-
-    let m = GroupNorm(2, 2; track_stats=false), x = randn(rng, Float32, 3, 2, 1)
+    m = GroupNorm(4, 2; track_stats=true)
+    sizes = (3, 4, 2)
+    x = reshape(collect(1:prod(sizes)), sizes)
+
+    println(m)
+    x = Float32.(x)
+    ps, st = Lux.setup(rng, m)
+    @test Lux.parameterlength(m) == Lux.parameterlength(ps)
+    @test Lux.statelength(m) == Lux.statelength(st)
+    @test ps.bias == [0, 0, 0, 0] # init_bias(32)
+    @test ps.scale == [1, 1, 1, 1] # init_scale(32)
+
+    y, st_ = pullback(m, x, ps, st)[1]
+
+    # julia> x
+    # [:, :, 1] =
+    # 1.0 4.0 7.0 10.0
+    # 2.0 5.0 8.0 11.0
+    # 3.0 6.0 9.0 12.0
+    #
+    # [:, :, 2] =
+    # 13.0 16.0 19.0 22.0
+    # 14.0 17.0 20.0 23.0
+    # 15.0 18.0 21.0 24.0
+    #
+    # mean will be
+    # (1. + 2. + 3. + 4. + 5. + 6.) / 6 = 3.5
+    # (7. + 8. + 9. + 10. + 11. + 12.) / 6 = 9.5
+    #
+    # (13. + 14. + 15. + 16. + 17. + 18.) / 6 = 15.5
+    # (19. + 20. + 21. + 22. + 23. + 24.) / 6 = 21.5
+    #
+    # mean =
+    # 3.5 15.5
+    # 9.5 21.5
+    #
+    # ∴ update rule with momentum:
+    # (1. - .1) * 0 + .1 * (3.5 + 15.5) / 2 = 0.95
+    # (1. - .1) * 0 + .1 * (9.5 + 21.5) / 2 = 1.55
+    @test st_.running_mean ≈ [0.95, 1.55]
+    n = prod(size(x)) ÷ m.groups ÷ size(x)[end]
+    corr = n / (n - 1)
+    z = reshape(x, 3, 2, 2, 2)
+    variance = var(z; dims=(1, 2), corrected=false)
+    @test st_.running_var ≈ 0.1 * corr * vec(mean(variance; dims=4)) .+ 0.9 * 1
+
+    st__ = Lux.testmode(st_)
+    y, st__ = m(x, ps, st__)
+    out = (z .- reshape(st_.running_mean, 1, 1, 2, 1)) ./
+          sqrt.(reshape(st_.running_var, 1, 1, 2, 1) .+ 1.0f-5)
+    @test y≈reshape(out, size(x)) atol=1.0e-5
+
+    @inferred m(x, ps, st)
+    run_JET_tests(m, x, ps, st)
+    test_gradient_correctness_fdm(ps -> sum(first(m(x, ps, st))), ps; atol=1.0f-3,
+                                  rtol=1.0f-3)
+
+    for affine in (true, false)
+        m = GroupNorm(2, 2; affine, track_stats=false)
+        x = randn(rng, Float32, 3, 2, 1)
         println(m)
         ps, st = Lux.setup(rng, m)
 
         @inferred m(x, ps, st)
         run_JET_tests(m, x, ps, st)
-        test_gradient_correctness_fdm((x, ps) -> sum(first(m(x, ps, st))), x, ps;
-                                      atol=1.0f-3, rtol=1.0f-3)
-    end
-
-    # with activation function
-    let m = GroupNorm(2, 2, sigmoid), x = randn(rng, Float32, 3, 2, 1)
+        if affine
+            test_gradient_correctness_fdm((x, ps) -> sum(first(m(x, ps, st))), x, ps;
+                                          atol=1.0f-3, rtol=1.0f-3)
+        else
+            test_gradient_correctness_fdm(x -> sum(first(m(x, ps, st))), x; atol=1.0f-3,
+                                          rtol=1.0f-3)
+        end
+
+        # with activation function
+        m = GroupNorm(2, 2, sigmoid; affine)
+        x = randn(rng, Float32, 3, 2, 1)
         println(m)
         ps, st = Lux.setup(rng, m)
         st = Lux.testmode(st)
@@ -179,11 +190,16 @@
         @inferred m(x, ps, st)
 
         run_JET_tests(m, x, ps, st)
-        test_gradient_correctness_fdm((x, ps) -> sum(first(m(x, ps, st))), x, ps;
-                                      atol=1.0f-3, rtol=1.0f-3)
-    end
+        if affine
+            test_gradient_correctness_fdm((x, ps) -> sum(first(m(x, ps, st))), x, ps;
+                                          atol=1.0f-3, rtol=1.0f-3)
+        else
+            test_gradient_correctness_fdm(x -> sum(first(m(x, ps, st))), x; atol=1.0f-3,
+                                          rtol=1.0f-3)
+        end
 
-    let m = GroupNorm(32, 16), x = randn(Float32, 416, 416, 32, 1)
+        m = GroupNorm(32, 16; affine)
+        x = randn(rng, Float32, 416, 416, 32, 1)
         println(m)
         ps, st = Lux.setup(rng, m)
         st = Lux.testmode(st)
diff --git a/test/test_utils.jl b/test/test_utils.jl
index 856646df58..b4695584c6 100644
--- a/test/test_utils.jl
+++ b/test/test_utils.jl
@@ -1,4 +1,4 @@
-using FiniteDifferences, JET, Lux, Random, Zygote
+using FiniteDifferences, JET, Lux, Random, Test, Zygote
 
 function Base.isapprox(nt1::NamedTuple{fields}, nt2::NamedTuple{fields};
                        kwargs...) where {fields}