diff --git a/lib/LuxLib/test/common_ops/activation_tests.jl b/lib/LuxLib/test/common_ops/activation_tests.jl index f1a190c21..550f6bb5a 100644 --- a/lib/LuxLib/test/common_ops/activation_tests.jl +++ b/lib/LuxLib/test/common_ops/activation_tests.jl @@ -1,8 +1,17 @@ @testitem "Activation Functions" tags=[:misc] setup=[SharedTestSetup] begin + using Enzyme + rng = StableRNG(1234) apply_act(f::F, x) where {F} = sum(abs2, f.(x)) - apply_act_fast(f::F, x) where {F} = sum(abs2, fast_activation!!(f, copy(x))) + function apply_act_fast(f::F, x) where {F} + if Enzyme.within_autodiff() + y = similar(x) + y .= x + return sum(abs2, fast_activation!!(f, y)) + end + return sum(abs2, fast_activation!!(f, copy(x))) + end apply_act_fast2(f::F, x) where {F} = sum(abs2, fast_activation(f, x)) @testset "$mode" for (mode, aType, ongpu, fp64) in MODES diff --git a/test/layers/basic_tests.jl b/test/layers/basic_tests.jl index 02442b222..3adea7323 100644 --- a/test/layers/basic_tests.jl +++ b/test/layers/basic_tests.jl @@ -242,6 +242,8 @@ end @testitem "Bilinear" setup=[SharedTestSetup] tags=[:core_layers] begin rng = StableRNG(12345) + skip_backends = VERSION < v"1.11-" ? [AutoEnzyme()] : [] + @testset "$mode" for (mode, aType, dev, ongpu) in MODES @testset "SkipConnection recombinator" begin d = Dense(2 => 2) @@ -255,7 +257,8 @@ end @test size(layer(x, ps, st)[1]) == (3, 1) @jet layer(x, ps, st) - @test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3) + @test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3, + skip_backends) d = Dense(2 => 2) display(d) @@ -268,7 +271,8 @@ end @test size(layer(x, ps, st)[1]) == (3, 1) @jet layer(x, ps, st) - @test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3) + @test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3, + skip_backends) d = Dense(2 => 3) display(d) @@ -281,7 +285,8 @@ end @test size(layer(x, ps, st)[1]) == (5, 7, 11) @jet layer(x, ps, st) - @test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3) + @test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3, + skip_backends) end @testset "Two-streams zero sum" begin @@ -296,7 +301,8 @@ end @test LuxCore.outputsize(layer, (x, y), rng) == (3,) @jet layer((x, y), ps, st) - @test_gradients(sumabs2first, layer, (x, y), ps, st; atol=1.0f-3, rtol=1.0f-3) + @test_gradients(sumabs2first, layer, (x, y), ps, st; atol=1.0f-3, rtol=1.0f-3, + skip_backends) end @testset "Inner interactions" begin @@ -307,7 +313,8 @@ end @test size(layer(x, ps, st)[1]) == (3, 1) @jet layer(x, ps, st) - @test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3) + @test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3, + skip_backends) x = randn(Float32, 2, 1) |> aType layer = Bilinear(2 => 3) @@ -316,7 +323,8 @@ end @test size(layer(x, ps, st)[1]) == (3, 1) @jet layer(x, ps, st) - @test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3) + @test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3, + skip_backends) end end end