# Patch summary (explanatory preamble only; the patch text below is unchanged).
#
# Files touched:
#   1. lib/LuxLib/test/common_ops/activation_tests.jl
#      - Adds `using Enzyme` at the top of the "Activation Functions" test item.
#      - The three `@test_gradients` calls now pass the function and `f` as
#        separate arguments (`apply_act, f, x`) instead of wrapping them in
#        `Base.Fix1(apply_act, f)`.
#      - The `apply_act_fast` gradient check additionally passes
#        `skip_backends=[AutoEnzyme()]`; the other two calls do not.
#   2. test/layers/basic_tests.jl
#      - In the "Bilinear" test item, defines `skip_backends` as
#        `[AutoEnzyme()]` when `VERSION < v"1.11-"` and as `[]` otherwise.
#      - Forwards `skip_backends` to all six `@test_gradients` calls changed by
#        this patch (hunks at old lines 255, 268, 281, 296, 307 and 316).
#
# NOTE(review): the patch does not state why the Enzyme backend is skipped in
# these places -- confirm against the accompanying commit message or issue.
# NOTE(review): hunk headers and +/- markers appear inline rather than at the
# start of their own lines, so the line structure looks collapsed (presumably an
# extraction artifact -- confirm). The original line breaks and indentation
# would need to be restored before this can be applied as a patch.
diff --git a/lib/LuxLib/test/common_ops/activation_tests.jl b/lib/LuxLib/test/common_ops/activation_tests.jl index f1a190c219..2789e7d4cb 100644 --- a/lib/LuxLib/test/common_ops/activation_tests.jl +++ b/lib/LuxLib/test/common_ops/activation_tests.jl @@ -1,4 +1,6 @@ @testitem "Activation Functions" tags=[:misc] setup=[SharedTestSetup] begin + using Enzyme + rng = StableRNG(1234) apply_act(f::F, x) where {F} = sum(abs2, f.(x)) @@ -41,9 +43,9 @@ end @test @inferred(Zygote.gradient(apply_act_fast2, f, x)) isa Any - @test_gradients(Base.Fix1(apply_act, f), x; atol, rtol) - @test_gradients(Base.Fix1(apply_act_fast, f), x; atol, rtol) - @test_gradients(Base.Fix1(apply_act_fast2, f), x; atol, rtol) + @test_gradients(apply_act, f, x; atol, rtol) + @test_gradients(apply_act_fast, f, x; atol, rtol, skip_backends=[AutoEnzyme()]) + @test_gradients(apply_act_fast2, f, x; atol, rtol) ∂x1 = Zygote.gradient(apply_act, f, x)[2] ∂x2 = Zygote.gradient(apply_act_fast, f, x)[2] diff --git a/test/layers/basic_tests.jl b/test/layers/basic_tests.jl index 02442b2226..3adea7323d 100644 --- a/test/layers/basic_tests.jl +++ b/test/layers/basic_tests.jl @@ -242,6 +242,8 @@ end @testitem "Bilinear" setup=[SharedTestSetup] tags=[:core_layers] begin rng = StableRNG(12345) + skip_backends = VERSION < v"1.11-" ? 
[AutoEnzyme()] : [] + @testset "$mode" for (mode, aType, dev, ongpu) in MODES @testset "SkipConnection recombinator" begin d = Dense(2 => 2) @@ -255,7 +257,8 @@ end @test size(layer(x, ps, st)[1]) == (3, 1) @jet layer(x, ps, st) - @test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3) + @test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3, + skip_backends) d = Dense(2 => 2) display(d) @@ -268,7 +271,8 @@ end @test size(layer(x, ps, st)[1]) == (3, 1) @jet layer(x, ps, st) - @test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3) + @test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3, + skip_backends) d = Dense(2 => 3) display(d) @@ -281,7 +285,8 @@ end @test size(layer(x, ps, st)[1]) == (5, 7, 11) @jet layer(x, ps, st) - @test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3) + @test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3, + skip_backends) end @testset "Two-streams zero sum" begin @@ -296,7 +301,8 @@ end @test LuxCore.outputsize(layer, (x, y), rng) == (3,) @jet layer((x, y), ps, st) - @test_gradients(sumabs2first, layer, (x, y), ps, st; atol=1.0f-3, rtol=1.0f-3) + @test_gradients(sumabs2first, layer, (x, y), ps, st; atol=1.0f-3, rtol=1.0f-3, + skip_backends) end @testset "Inner interactions" begin @@ -307,7 +313,8 @@ end @test size(layer(x, ps, st)[1]) == (3, 1) @jet layer(x, ps, st) - @test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3) + @test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3, + skip_backends) x = randn(Float32, 2, 1) |> aType layer = Bilinear(2 => 3) @@ -316,7 +323,8 @@ end @test size(layer(x, ps, st)[1]) == (3, 1) @jet layer(x, ps, st) - @test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3) + @test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3, + skip_backends) end end end