Skip to content

Commit

Permalink
test: workaround Enzyme warning
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 19, 2024
1 parent 677b2ac commit 7a0140e
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 7 deletions.
11 changes: 10 additions & 1 deletion lib/LuxLib/test/common_ops/activation_tests.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
@testitem "Activation Functions" tags=[:misc] setup=[SharedTestSetup] begin
using Enzyme

rng = StableRNG(1234)

apply_act(f::F, x) where {F} = sum(abs2, f.(x))
apply_act_fast(f::F, x) where {F} = sum(abs2, fast_activation!!(f, copy(x)))
function apply_act_fast(f::F, x) where {F}
if Enzyme.within_autodiff()
y = similar(x)
y .= x
return sum(abs2, fast_activation!!(f, y))
end
return sum(abs2, fast_activation!!(f, copy(x)))
end
apply_act_fast2(f::F, x) where {F} = sum(abs2, fast_activation(f, x))

@testset "$mode" for (mode, aType, ongpu, fp64) in MODES
Expand Down
20 changes: 14 additions & 6 deletions test/layers/basic_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,8 @@ end
@testitem "Bilinear" setup=[SharedTestSetup] tags=[:core_layers] begin
rng = StableRNG(12345)

skip_backends = VERSION < v"1.11-" ? [AutoEnzyme()] : []

@testset "$mode" for (mode, aType, dev, ongpu) in MODES
@testset "SkipConnection recombinator" begin
d = Dense(2 => 2)
Expand All @@ -255,7 +257,8 @@ end

@test size(layer(x, ps, st)[1]) == (3, 1)
@jet layer(x, ps, st)
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
skip_backends)

d = Dense(2 => 2)
display(d)
Expand All @@ -268,7 +271,8 @@ end

@test size(layer(x, ps, st)[1]) == (3, 1)
@jet layer(x, ps, st)
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
skip_backends)

d = Dense(2 => 3)
display(d)
Expand All @@ -281,7 +285,8 @@ end

@test size(layer(x, ps, st)[1]) == (5, 7, 11)
@jet layer(x, ps, st)
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
skip_backends)
end

@testset "Two-streams zero sum" begin
Expand All @@ -296,7 +301,8 @@ end

@test LuxCore.outputsize(layer, (x, y), rng) == (3,)
@jet layer((x, y), ps, st)
@test_gradients(sumabs2first, layer, (x, y), ps, st; atol=1.0f-3, rtol=1.0f-3)
@test_gradients(sumabs2first, layer, (x, y), ps, st; atol=1.0f-3, rtol=1.0f-3,
skip_backends)
end

@testset "Inner interactions" begin
Expand All @@ -307,7 +313,8 @@ end

@test size(layer(x, ps, st)[1]) == (3, 1)
@jet layer(x, ps, st)
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
skip_backends)

x = randn(Float32, 2, 1) |> aType
layer = Bilinear(2 => 3)
Expand All @@ -316,7 +323,8 @@ end

@test size(layer(x, ps, st)[1]) == (3, 1)
@jet layer(x, ps, st)
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3)
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
skip_backends)
end
end
end
Expand Down

0 comments on commit 7a0140e

Please sign in to comment.