Skip to content

Commit

Permalink
test: workaround Enzyme warning
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 19, 2024
1 parent 677b2ac commit d586e10
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 9 deletions.
8 changes: 5 additions & 3 deletions lib/LuxLib/test/common_ops/activation_tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
@testitem "Activation Functions" tags=[:misc] setup=[SharedTestSetup] begin
using Enzyme

rng = StableRNG(1234)

apply_act(f::F, x) where {F} = sum(abs2, f.(x))
Expand Down Expand Up @@ -41,9 +43,9 @@
end
@test @inferred(Zygote.gradient(apply_act_fast2, f, x)) isa Any

@test_gradients(Base.Fix1(apply_act, f), x; atol, rtol)
@test_gradients(Base.Fix1(apply_act_fast, f), x; atol, rtol)
@test_gradients(Base.Fix1(apply_act_fast2, f), x; atol, rtol)
@test_gradients(apply_act, f, x; atol, rtol)
@test_gradients(apply_act_fast, f, x; atol, rtol, skip_backends=[AutoEnzyme()])
@test_gradients(apply_act_fast2, f, x; atol, rtol)

∂x1 = Zygote.gradient(apply_act, f, x)[2]
∂x2 = Zygote.gradient(apply_act_fast, f, x)[2]
Expand Down
20 changes: 14 additions & 6 deletions test/layers/basic_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,8 @@ end
@testitem "Bilinear" setup=[SharedTestSetup] tags=[:core_layers] begin
rng = StableRNG(12345)

skip_backends = VERSION < v"1.11-" ? [AutoEnzyme()] : []

@testset "$mode" for (mode, aType, dev, ongpu) in MODES
@testset "SkipConnection recombinator" begin
d = Dense(2 => 2)
Expand All @@ -255,7 +257,8 @@ end

@test size(layer(x, ps, st)[1]) == (3, 1)
@jet layer(x, ps, st)
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
skip_backends)

d = Dense(2 => 2)
display(d)
Expand All @@ -268,7 +271,8 @@ end

@test size(layer(x, ps, st)[1]) == (3, 1)
@jet layer(x, ps, st)
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
skip_backends)

d = Dense(2 => 3)
display(d)
Expand All @@ -281,7 +285,8 @@ end

@test size(layer(x, ps, st)[1]) == (5, 7, 11)
@jet layer(x, ps, st)
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
skip_backends)
end

@testset "Two-streams zero sum" begin
Expand All @@ -296,7 +301,8 @@ end

@test LuxCore.outputsize(layer, (x, y), rng) == (3,)
@jet layer((x, y), ps, st)
@test_gradients(sumabs2first, layer, (x, y), ps, st; atol=1.0f-3, rtol=1.0f-3)
@test_gradients(sumabs2first, layer, (x, y), ps, st; atol=1.0f-3, rtol=1.0f-3,
skip_backends)
end

@testset "Inner interactions" begin
Expand All @@ -307,7 +313,8 @@ end

@test size(layer(x, ps, st)[1]) == (3, 1)
@jet layer(x, ps, st)
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
skip_backends)

x = randn(Float32, 2, 1) |> aType
layer = Bilinear(2 => 3)
Expand All @@ -316,7 +323,8 @@ end

@test size(layer(x, ps, st)[1]) == (3, 1)
@jet layer(x, ps, st)
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
skip_backends)
end
end
end
Expand Down

0 comments on commit d586e10

Please sign in to comment.