
Commit

test: try fixing more tests
avik-pal committed Nov 19, 2024
1 parent 6f58da1 commit 677b2ac
Showing 11 changed files with 78 additions and 121 deletions.
2 changes: 1 addition & 1 deletion lib/LuxLib/src/traits.jl
@@ -82,7 +82,7 @@ using Hwloc: Hwloc
using Static: static, False, True

using ..LuxLib: DISABLE_LOOP_VECTORIZATION
using ..Utils: is_extension_loaded, safe_minimum, unsafe_known, within_enzyme_autodiff
using ..Utils: is_extension_loaded, safe_minimum, within_enzyme_autodiff

const CRC = ChainRulesCore

44 changes: 15 additions & 29 deletions lib/LuxLib/test/common_ops/dropout_tests.jl
@@ -2,7 +2,7 @@
rng = StableRNG(12345)

@testset "$mode" for (mode, aType, ongpu, fp64) in MODES
@testset "$T, $x_shape, $dims" for T in (Float16, Float32, Float64),
@testset "$T, $x_shape, $dims" for T in (Float32, Float64),
x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)),
dims in (:, 1, (1, 2))

@@ -26,12 +26,8 @@
__f = x -> sum(first(dropout(StableRNG(0), x, 0.5, Val(true), 2.0, dims)))
@test @inferred(Zygote.gradient(__f, x)) isa Any

__f = let rng = rng, T = T
x -> sum(first(dropout(rng, x, T(0.5), Val(true), T(2), dims)))
end
@test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []),
broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : []))
@test_gradients(sumabs2first,
dropout, rng, x, T(0.5), Val(true), T(2), dims; atol=1.0f-3, rtol=1.0f-3)

y, mask_, rng_ = dropout(rng, x, T(0.5), Val(false), T(2), dims)

@@ -49,7 +45,7 @@ end
rng = StableRNG(12345)

@testset "$mode" for (mode, aType, ongpu, fp64) in MODES
@testset "$T: $x_shape" for T in (Float16, Float32, Float64),
@testset "$T: $x_shape" for T in (Float32, Float64),
x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1))

!fp64 && T == Float64 && continue
@@ -75,12 +71,9 @@ end
StableRNG(0), x, mask, 0.5, Val(true), Val(true), 2.0, :)))
@test @inferred(Zygote.gradient(__f, x, mask)) isa Any

__f = let rng = rng, mask = mask, p = T(0.5), invp = T(2)
x -> sum(first(dropout(rng, x, mask, p, Val(true), Val(true), invp, :)))
end
@test_gradients(__f, x; atol=1.0f-3,
rtol=1.0f-3,
soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []))
@test_gradients(sumabs2first,
dropout, rng, x, LuxTestUtils.Constant(mask), T(0.5), Val(true), Val(true),
T(2), :; atol=1.0f-3, rtol=1.0f-3)

@jet sum(first(dropout(
rng, x, mask, T(0.5), Val(true), Val(true), T(2), :)))
@@ -103,14 +96,11 @@ end
StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, :)))
@test @inferred(Zygote.gradient(__f, x, mask)) isa Any

__f = let rng = rng, mask = mask, p = T(0.5), invp = T(2)
x -> sum(first(dropout(rng, x, mask, p, Val(true), Val(false), invp, :)))
end

soft_fail = T == Float16 ? Any[AutoFiniteDiff()] : []
skip_backends = length(x_shape) == 5 ? [AutoEnzyme()] : []

@test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, soft_fail, skip_backends)
@test_gradients(sumabs2first,
dropout, rng, x, LuxTestUtils.Constant(mask),
T(0.5), Val(true), Val(false), T(2), :;
broken_backends=length(x_shape) > 2 ? [AutoEnzyme()] : [],
atol=1.0f-3, rtol=1.0f-3)

@jet sum(first(dropout(
rng, x, mask, T(0.5), Val(true), Val(false), T(2), :)))
@@ -138,7 +128,7 @@ end
rng = StableRNG(12345)

@testset "$mode" for (mode, aType, ongpu, fp64) in MODES
@testset "$T: $x_shape" for T in (Float16, Float32, Float64),
@testset "$T: $x_shape" for T in (Float32, Float64),
x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1))

!fp64 && T == Float64 && continue
@@ -158,12 +148,8 @@ end
__f = x -> sum(first(alpha_dropout(StableRNG(0), x, 0.5, Val(true))))
@test @inferred(Zygote.gradient(__f, x)) isa Any

__f = let rng = rng
x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true))))
end
@test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []),
broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : []))
@test_gradients(sumabs2first,
alpha_dropout, rng, x, T(0.5), Val(true); atol=1.0f-3, rtol=1.0f-3)

@jet sum(first(alpha_dropout(rng, x, T(0.5), Val(true))))
@test @inferred(alpha_dropout(rng, x, T(0.5), Val(false))) isa Any
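Aside: a minimal runnable sketch of the reducer pattern the rewritten tests use. sumabs2first is copied from the diff; toy_dropout is a hypothetical stand-in with dropout's (output, mask, rng) return shape, not LuxLib's implementation.

    sumabs2first(f::F, args...) where {F} = sum(abs2, first(f(args...)))

    # Hypothetical stand-in: returns a tuple like (y, mask, rng), as dropout does.
    toy_dropout(rng, x, p, ::Val{true}, invp, dims) = (x .* invp, similar(x), rng)

    x = reshape(collect(Float32, 1:6) ./ 6, 2, 3)
    loss = sumabs2first(toy_dropout, nothing, x, 0.5f0, Val(true), 2.0f0, :)
    # == sum(abs2, first(toy_dropout(nothing, x, 0.5f0, Val(true), 2.0f0, :)))

Calls such as @test_gradients(sumabs2first, dropout, rng, x, T(0.5), Val(true), T(2), dims; ...) above differentiate exactly this composition with respect to the array arguments, which avoids rebuilding a closure for every parameter combination.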
27 changes: 7 additions & 20 deletions lib/LuxLib/test/normalization/batchnorm_tests.jl
@@ -34,6 +34,8 @@ anonact = x -> x^3

is_training(::Val{training}) where {training} = training

sumabs2first(f::F, args...) where {F} = sum(abs2, first(f(args...)))

function run_batchnorm_testing(
gen_f, T, sz, training, affine, track_stats, act, aType, mode, ongpu)
epsilon = eps(T)^(5 // 7)
@@ -43,9 +45,8 @@ function run_batchnorm_testing(
y_simple, nt_simple = batchnorm_fallback(
x, scale, bias, rm, rv, training, act, T(0.9), epsilon)

fp16 = T == Float16
atol = fp16 ? 1.0f-2 : 1.0f-3
rtol = fp16 ? 1.0f-2 : 1.0f-3
atol = 1.0f-3
rtol = 1.0f-3

    @test y ≈ y_simple atol=atol rtol=rtol
if track_stats
@@ -84,22 +85,8 @@ function run_batchnorm_testing(
skip_backends = []
act === relu && push!(skip_backends, AutoFiniteDiff())

soft_fail = if fp16
if Sys.iswindows()
[AutoTracker(), AutoFiniteDiff(), AutoReverseDiff(), AutoForwardDiff()]
else
true
end
else
false
end

broken_backends = Sys.iswindows() && fp16 ? [AutoEnzyme()] : []

__f = (args...) -> sum(first(batchnorm(
args..., rm, rv, training, act, T(0.9), epsilon)))
@test_gradients(__f, x, scale, bias; atol, rtol, skip_backends, soft_fail,
broken_backends)
@test_gradients(sumabs2first, batchnorm, x, scale, bias, Constant(rm),
Constant(rv), training, act, T(0.9), epsilon; atol, rtol, skip_backends)
end

if anonact !== act
Expand All @@ -111,7 +98,7 @@ function run_batchnorm_testing(
end

const ALL_TEST_CONFIGS = Iterators.product(
[Float16, Float32, Float64], ((4, 4, 6, 2), (8, 2), (4, 4, 4, 3, 2)),
[Float32, Float64], ((4, 4, 6, 2), (8, 2), (4, 4, 4, 3, 2)),
(Val(true), Val(false)), (true, false), (true, false),
(identity, relu, tanh_fast, sigmoid_fast, anonact))

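Aside: a hedged sketch of the Constant marker pattern adopted in these tests, assuming the LuxTestUtils API used throughout this diff (function first, then its arguments, with non-differentiable arguments wrapped in Constant). masked_scale is a made-up example, not a LuxLib function.

    using LuxTestUtils: @test_gradients, Constant

    # Only x and scale are differentiated; Constant(m) marks m as fixed,
    # mirroring Constant(rm) / Constant(rv) for the running statistics above.
    masked_scale(x, scale, m) = sum(abs2, scale .* x .+ m)

    x, scale, m = rand(Float32, 4), rand(Float32, 4), rand(Float32, 4)
    @test_gradients(masked_scale, x, scale, Constant(m); atol=1.0f-3, rtol=1.0f-3)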
40 changes: 17 additions & 23 deletions lib/LuxLib/test/normalization/groupnorm_tests.jl
@@ -28,6 +28,8 @@ anonact = x -> x^3

is_training(::Val{training}) where {training} = training

sumabs2groupnorm(args...) = sum(abs2, groupnorm(args...))

function run_groupnorm_testing(T, sz, groups, affine, act, aType, mode, ongpu)
_f = (args...) -> groupnorm(args..., groups, act, epsilon)
_f2 = (args...) -> groupnorm_fallback(args..., groups, act, epsilon)
@@ -38,25 +40,22 @@ function run_groupnorm_testing(T, sz, groups, affine, act, aType, mode, ongpu)
y = _f(x, scale, bias)
y_simple = _f2(x, scale, bias)

fp16 = T == Float16
atol = fp16 ? 1.0f-2 : 1.0f-3
rtol = fp16 ? 1.0f-2 : 1.0f-3
atol = 1.0f-3
rtol = 1.0f-3

    @test y ≈ y_simple atol=atol rtol=rtol

# Check the rrules
if !fp16
        ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias)
        ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum ∘ _f2, x, scale, bias)
if length(sz) == 5 && !ongpu
@test_softfail check_approx(∂x, ∂x_simple; atol, rtol)
else
            @test ∂x ≈ ∂x_simple atol=atol rtol=rtol
end
if affine
            @test ∂scale ≈ ∂scale_simple atol=atol rtol=rtol
            @test ∂bias ≈ ∂bias_simple atol=atol rtol=rtol
end
    ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias)
    ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum ∘ _f2, x, scale, bias)
if length(sz) == 5 && !ongpu
@test_softfail check_approx(∂x, ∂x_simple; atol, rtol)
else
        @test ∂x ≈ ∂x_simple atol=atol rtol=rtol
end
if affine
        @test ∂scale ≈ ∂scale_simple atol=atol rtol=rtol
        @test ∂bias ≈ ∂bias_simple atol=atol rtol=rtol
end

@test @inferred(groupnorm(x, scale, bias, groups, act, epsilon)) isa Any
@@ -70,16 +69,11 @@ function run_groupnorm_testing(T, sz, groups, affine, act, aType, mode, ongpu)
@test y isa aType{T, length(sz)}
@test size(y) == sz

soft_fail = fp16 ? fp16 : [AutoFiniteDiff()]

if affine
__f = (args...) -> sum(groupnorm(args..., groups, act, epsilon))
@test_gradients(__f, x, scale, bias; atol, rtol, soft_fail,
skip_backends=[AutoEnzyme()])
end
@test_gradients(sumabs2groupnorm, x, scale, bias, groups, act, epsilon; atol, rtol,
soft_fail=[AutoFiniteDiff()])
end

const ALL_TEST_CONFIGS = Iterators.product([Float16, Float32, Float64],
const ALL_TEST_CONFIGS = Iterators.product([Float32, Float64],
(
(6, 2),
(4, 6, 2),
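Aside: a minimal sketch of the sum ∘ f cross-check used above, with two algebraically equivalent stand-ins in place of groupnorm and groupnorm_fallback (hypothetical functions, for illustration only).

    using Zygote

    center_scale(x, scale, bias) = scale .* (x .- sum(x) / length(x)) .+ bias
    center_scale_fallback(x, scale, bias) = scale .* x .- scale .* (sum(x) / length(x)) .+ bias

    x, scale, bias = rand(Float32, 6), rand(Float32, 6), rand(Float32, 6)
    ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ center_scale, x, scale, bias)
    ∂x′, ∂scale′, ∂bias′ = Zygote.gradient(sum ∘ center_scale_fallback, x, scale, bias)

    # The pullbacks of the two implementations agree to floating-point tolerance.
    @assert ∂x ≈ ∂x′ && ∂scale ≈ ∂scale′ && ∂bias ≈ ∂bias′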
38 changes: 14 additions & 24 deletions lib/LuxLib/test/normalization/instancenorm_tests.jl
@@ -12,18 +12,17 @@ end

anonact = x -> x^3

function run_instancenorm_testing(gen_f, T, sz, training, act, aType, mode, ongpu)
_f = (args...) -> first(instancenorm(args..., training, act, epsilon))
sumabs2instancenorm(args...) = sum(abs2, first(instancenorm(args...)))

function run_instancenorm_testing(gen_f, T, sz, training, act, aType)
epsilon = LuxLib.Utils.default_epsilon(T)
x, scale, bias = setup_instancenorm(gen_f, aType, T, sz)

# First test without running stats
y, nt = instancenorm(x, scale, bias, training, act, epsilon)

fp16 = T == Float16
atol = fp16 ? 1.0f-2 : 1.0f-3
rtol = fp16 ? 1.0f-2 : 1.0f-3
atol = 1.0f-3
rtol = 1.0f-3

@test @inferred(instancenorm(x, scale, bias, training, act, epsilon)) isa Any
@jet instancenorm(x, scale, bias, training, act, epsilon)
@@ -37,9 +36,8 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType, mode, ongp
@test size(y) == sz

if is_training(training)
__f = (args...) -> sum(first(instancenorm(args..., training, act, epsilon)))
soft_fail = fp16 ? fp16 : [AutoFiniteDiff()]
@test_gradients(__f, x, scale, bias; atol, rtol, soft_fail)
@test_gradients(sumabs2instancenorm, x, scale, bias, training, act, epsilon;
atol, rtol, soft_fail=[AutoFiniteDiff()])
end

# Now test with running stats
@@ -63,16 +61,13 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType, mode, ongp
@test size(y) == sz

if is_training(training)
__f = (args...) -> sum(first(instancenorm(
args..., rm, rv, training, act, T(0.1), epsilon)))
soft_fail = fp16 ? fp16 : [AutoFiniteDiff()]
skip_backends = [AutoEnzyme()]
@test_gradients(__f, x, scale, bias; atol, rtol, soft_fail, skip_backends)
@test_gradients(sumabs2instancenorm, x, scale, bias, Constant(rm), Constant(rv),
training, act, T(0.1), epsilon; atol, rtol, soft_fail=[AutoFiniteDiff()])
end
end

const ALL_TEST_CONFIGS = Iterators.product(
[Float16, Float32, Float64], ((4, 4, 6, 2), (3, 4, 2), (4, 4, 4, 3, 2)),
[Float32, Float64], ((4, 4, 6, 2), (3, 4, 2), (4, 4, 4, 3, 2)),
(Val(true), Val(false)), (identity, relu, tanh_fast, sigmoid_fast, anonact))

const TEST_BLOCKS = collect(Iterators.partition(
@@ -87,8 +82,7 @@ end
@testset "$mode" for (mode, aType, ongpu, fp64) in MODES
@testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[1]
!fp64 && T == Float64 && continue
run_instancenorm_testing(
generate_fixed_array, T, sz, training, act, aType, mode, ongpu)
run_instancenorm_testing(generate_fixed_array, T, sz, training, act, aType)
end
end
end
@@ -98,8 +92,7 @@ end
@testset "$mode" for (mode, aType, ongpu, fp64) in MODES
@testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[2]
!fp64 && T == Float64 && continue
run_instancenorm_testing(
generate_fixed_array, T, sz, training, act, aType, mode, ongpu)
run_instancenorm_testing(generate_fixed_array, T, sz, training, act, aType)
end
end
end
@@ -109,8 +102,7 @@ end
@testset "$mode" for (mode, aType, ongpu, fp64) in MODES
@testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[3]
!fp64 && T == Float64 && continue
run_instancenorm_testing(
generate_fixed_array, T, sz, training, act, aType, mode, ongpu)
run_instancenorm_testing(generate_fixed_array, T, sz, training, act, aType)
end
end
end
@@ -120,8 +112,7 @@ end
@testset "$mode" for (mode, aType, ongpu, fp64) in MODES
@testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[4]
!fp64 && T == Float64 && continue
run_instancenorm_testing(
generate_fixed_array, T, sz, training, act, aType, mode, ongpu)
run_instancenorm_testing(generate_fixed_array, T, sz, training, act, aType)
end
end
end
@@ -131,8 +122,7 @@ end
@testset "$mode" for (mode, aType, ongpu, fp64) in MODES
@testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[5]
!fp64 && T == Float64 && continue
run_instancenorm_testing(
generate_fixed_array, T, sz, training, act, aType, mode, ongpu)
run_instancenorm_testing(generate_fixed_array, T, sz, training, act, aType)
end
end
end
20 changes: 7 additions & 13 deletions lib/LuxLib/test/normalization/layernorm_tests.jl
@@ -33,6 +33,8 @@ function run_layernorm_testing(gen_f, aType, T, x_size, affine_shape, act, ongpu
end
end

sumabs2layernorm(args...) = sum(abs2, layernorm(args...))

function run_layernorm_testing_core(
aType, T, x_size, affine_shape, act, dims, x, scale, bias)
epsilon = LuxLib.Utils.default_epsilon(T)
Expand All @@ -51,19 +53,11 @@ function run_layernorm_testing_core(
@test check_approx(std(y; dims), 1; atol=1e-1, rtol=1e-1)
end

fp16 = T == Float16
atol = fp16 ? 1.0f-2 : 1.0f-3
rtol = fp16 ? 1.0f-2 : 1.0f-3
atol = 1.0f-3
rtol = 1.0f-3

soft_fail = fp16 ? fp16 : [AutoFiniteDiff()]
if affine_shape !== nothing
__f = (args...) -> sum(_f(args...))
@test_gradients(__f, x, scale, bias; atol, rtol, soft_fail,
skip_backends=[AutoEnzyme()])
else
__f = x -> sum(_f(x, scale, bias))
@test_gradients(__f, x; atol, rtol, soft_fail, skip_backends=[AutoEnzyme()])
end
@test_gradients(sumabs2layernorm, x, scale, bias, act, dims, epsilon; atol, rtol,
soft_fail=[AutoFiniteDiff()])

if anonact !== act
lfn = (x, sc, b, act, dim, ϵ) -> sum(layernorm(x, sc, b, act, dim, ϵ))
Expand All @@ -75,7 +69,7 @@ anonact = x -> x^3

const ALL_TEST_CONFIGS = Any[]

for T in (Float16, Float32, Float64),
for T in (Float32, Float64),
x_shape in ((3, 3, 2, 1), (2, 2, 2, 1), (2, 3, 2, 2)),
affine_shape in (nothing, x_shape[1:3], (1, 1, 1), (1, 1, x_shape[3])),
act in (identity, relu, tanh_fast, sigmoid_fast, anonact)
4 changes: 3 additions & 1 deletion lib/LuxLib/test/shared_testsetup.jl
@@ -79,6 +79,8 @@ function generate_fixed_array(::Type{T}, sz) where {T}
end
generate_fixed_array(::Type{T}, sz::Int) where {T} = T.(collect(1:sz) ./ sz)

export MODES, StableRNG, generate_fixed_array, BACKEND_GROUP
sumabs2first(f::F, args...) where {F} = sum(abs2, first(f(args...)))

export MODES, StableRNG, generate_fixed_array, BACKEND_GROUP, sumabs2first

end
4 changes: 3 additions & 1 deletion lib/LuxTestUtils/src/LuxTestUtils.jl
@@ -44,12 +44,14 @@ catch err
end

include("test_softfail.jl")
include("utils.jl")
include("autodiff.jl")
include("jet.jl")

include("utils.jl")

export AutoEnzyme, AutoFiniteDiff, AutoTracker, AutoForwardDiff, AutoReverseDiff, AutoZygote
export test_gradients, @test_gradients
export Constant
export @jet, jet_target_modules!
export @test_softfail
