From d728f4abb142732c3f7e6da77a01ef488a2bf5d4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 22 Nov 2024 12:20:39 -0500 Subject: [PATCH] test: use TestExtras in LuxLib testing --- lib/LuxLib/test/Project.toml | 2 + .../test/common_ops/activation_tests.jl | 12 +++--- lib/LuxLib/test/common_ops/bias_act_tests.jl | 38 +++++++------------ lib/LuxLib/test/common_ops/conv_tests.jl | 15 ++++---- lib/LuxLib/test/common_ops/dense_tests.jl | 14 +++---- lib/LuxLib/test/common_ops/dropout_tests.jl | 25 ++++++------ .../test/normalization/batchnorm_tests.jl | 8 ++-- .../test/normalization/groupnorm_tests.jl | 6 +-- .../test/normalization/instancenorm_tests.jl | 12 +++--- .../test/normalization/layernorm_tests.jl | 6 +-- lib/LuxLib/test/shared_testsetup.jl | 2 +- 11 files changed, 62 insertions(+), 78 deletions(-) diff --git a/lib/LuxLib/test/Project.toml b/lib/LuxLib/test/Project.toml index 403bc57fb5..1e1b5c58b4 100644 --- a/lib/LuxLib/test/Project.toml +++ b/lib/LuxLib/test/Project.toml @@ -28,6 +28,7 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -61,5 +62,6 @@ Static = "0.8.4, 1" StaticArrays = "1.9.7" Statistics = "1.10" Test = "1.10" +TestExtras = "0.3.1" Tracker = "0.2.36" Zygote = "0.6.70" diff --git a/lib/LuxLib/test/common_ops/activation_tests.jl b/lib/LuxLib/test/common_ops/activation_tests.jl index 2789e7d4cb..8a2a56defb 100644 --- a/lib/LuxLib/test/common_ops/activation_tests.jl +++ b/lib/LuxLib/test/common_ops/activation_tests.jl @@ -30,18 +30,18 @@ @test eltype(y2) == T @test eltype(y3) == T - @test @inferred(apply_act(f, x)) isa Any - @test @inferred(apply_act_fast(f, x)) isa Any - @test @inferred(apply_act_fast2(f, x)) isa Any + @constinferred apply_act(f, x) + @constinferred apply_act_fast(f, x) + @constinferred apply_act_fast2(f, x) @jet apply_act_fast(f, x) @jet apply_act_fast2(f, x) - @test @inferred(Zygote.gradient(apply_act, f, x)) isa Any + @constinferred Zygote.gradient(apply_act, f, x) if f !== lisht - @test @inferred(Zygote.gradient(apply_act_fast, f, x)) isa Any + @constinferred Zygote.gradient(apply_act_fast, f, x) end - @test @inferred(Zygote.gradient(apply_act_fast2, f, x)) isa Any + @constinferred Zygote.gradient(apply_act_fast2, f, x) @test_gradients(apply_act, f, x; atol, rtol) @test_gradients(apply_act_fast, f, x; atol, rtol, skip_backends=[AutoEnzyme()]) diff --git a/lib/LuxLib/test/common_ops/bias_act_tests.jl b/lib/LuxLib/test/common_ops/bias_act_tests.jl index 4e0e51ced4..1e932f3d9a 100644 --- a/lib/LuxLib/test/common_ops/bias_act_tests.jl +++ b/lib/LuxLib/test/common_ops/bias_act_tests.jl @@ -5,12 +5,6 @@ bias_act_loss2(act, x, b) = sum(abs2, bias_activation(act, x, b)) bias_act_loss3(act, x, b) = sum(abs2, bias_activation!!(act, copy(x), b)) - struct __Fix1{F, A} - f::F - act::A - end - (f::__Fix1)(x, b) = f.f(f.act, x, b) - @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "$act, $T, $sz" for act in [ identity, relu, sigmoid, sigmoid_fast, softplus, @@ -27,9 +21,8 @@ y2 = bias_act_loss2(act, x, b) y3 = bias_act_loss3(act, x, b) - fp16 = T == Float16 - atol = fp16 ? 1.0f-2 : 1.0f-3 - rtol = fp16 ? 1.0f-2 : 1.0f-3 + atol = 1.0f-3 + rtol = 1.0f-3 @test y1≈y2 atol=atol rtol=rtol @test y1≈y3 atol=atol rtol=rtol @@ -37,28 +30,25 @@ @test eltype(y2) == T @test eltype(y3) == T - @test @inferred(bias_act_loss1(act, x, b)) isa Any - @test @inferred(bias_act_loss2(act, x, b)) isa Any - @test @inferred(bias_act_loss3(act, x, b)) isa Any + @constinferred bias_act_loss1(act, x, b) + @constinferred bias_act_loss2(act, x, b) + @constinferred bias_act_loss3(act, x, b) @jet bias_act_loss2(act, x, b) @jet bias_act_loss3(act, x, b) - if act !== lisht && T != Float16 - @test @inferred(Zygote.gradient(bias_act_loss2, act, x, b)) isa Any - @test @inferred(Zygote.gradient(bias_act_loss3, act, x, b)) isa Any + if act !== lisht + @constinferred Zygote.gradient(bias_act_loss2, act, x, b) + @constinferred Zygote.gradient(bias_act_loss3, act, x, b) end - @test_gradients(__Fix1(bias_act_loss1, act), x, b; atol, rtol, - soft_fail=fp16 ? [AutoFiniteDiff()] : []) - @test_gradients(__Fix1(bias_act_loss2, act), x, b; atol, rtol, - soft_fail=fp16 ? [AutoFiniteDiff()] : []) - @test_gradients(__Fix1(bias_act_loss3, act), x, b; atol, rtol, - soft_fail=fp16 ? [AutoFiniteDiff()] : []) + @test_gradients(bias_act_loss1, act, x, b; atol, rtol) + @test_gradients(bias_act_loss2, act, x, b; atol, rtol) + @test_gradients(bias_act_loss3, act, x, b; atol, rtol) - ∂x1, ∂b1 = Zygote.gradient(__Fix1(bias_act_loss1, act), x, b) - ∂x2, ∂b2 = Zygote.gradient(__Fix1(bias_act_loss2, act), x, b) - ∂x3, ∂b3 = Zygote.gradient(__Fix1(bias_act_loss3, act), x, b) + _, ∂x1, ∂b1 = Zygote.pullback(bias_act_loss1, act, x, b) + _, ∂x2, ∂b2 = Zygote.pullback(bias_act_loss2, act, x, b) + _, ∂x3, ∂b3 = Zygote.pullback(bias_act_loss3, act, x, b) @test ∂x1≈∂x2 atol=atol rtol=rtol @test ∂x1≈∂x3 atol=atol rtol=rtol diff --git a/lib/LuxLib/test/common_ops/conv_tests.jl b/lib/LuxLib/test/common_ops/conv_tests.jl index b58aafcd36..9fff43364c 100644 --- a/lib/LuxLib/test/common_ops/conv_tests.jl +++ b/lib/LuxLib/test/common_ops/conv_tests.jl @@ -1,5 +1,5 @@ @testsetup module ConvSetup -using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, TestExtras expand(_, i::Tuple) = i expand(N, i::Integer) = ntuple(_ -> i, N) @@ -43,20 +43,19 @@ function run_conv_testing(gen_f::Function, activation, kernel, stride, padding, @test eltype(y) == promote_type(Tw, Tx) - @test @inferred(fused_conv_bias_activation(activation, weight, x, bias, cdims)) isa Any + @constinferred fused_conv_bias_activation(activation, weight, x, bias, cdims) @jet fused_conv_bias_activation(activation, weight, x, bias, cdims) if mode != "amdgpu" && activation !== anonact - @test @inferred(Zygote.gradient( - sumabs2conv, activation, weight, x, bias, cdims - )) isa Any + @constinferred Zygote.gradient(sumabs2conv, activation, weight, x, bias, cdims) else try - @inferred(Zygote.gradient(sumabs2conv, activation, weight, x, bias, cdims)) - @test true + @constinferred Zygote.gradient(sumabs2conv, activation, weight, x, bias, cdims) catch e e isa ErrorException || rethrow() - @test_broken false + @constinferred_broken Zygote.gradient( + sumabs2conv, activation, weight, x, bias, cdims + ) end end diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index bc4d40e555..6e65b46547 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -1,5 +1,5 @@ @testsetup module DenseSetup -using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, StableRNGs +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, StableRNGs, TestExtras anonact = x -> x^3 @@ -27,14 +27,14 @@ function run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu @test y ≈ y_generic @test eltype(y) == promote_type(Tw, Tx) - @test @inferred(fused_dense_bias_activation(activation, w, x, bias)) isa Any + @constinferred fused_dense_bias_activation(activation, w, x, bias) @jet fused_dense_bias_activation(activation, w, x, bias) atol = 1.0f-3 rtol = 1.0f-3 if activation !== anonact - @test @inferred(Zygote.gradient(sumabs2dense, activation, w, x, bias)) isa Any + @constinferred Zygote.gradient(sumabs2dense, activation, w, x, bias) end skip_backends = [] @@ -117,23 +117,23 @@ end end @testitem "Fused Dense: StaticArrays" tags=[:dense] begin - using StaticArrays, NNlib + using StaticArrays, NNlib, TestExtras x = @SArray rand(2, 4) weight = @SArray rand(3, 2) bias = @SArray rand(3) - @test @inferred(fused_dense_bias_activation(relu, weight, x, bias)) isa SArray + @constinferred fused_dense_bias_activation(relu, weight, x, bias) end @testitem "Fused Dense: CPU No Scalar Indexing" tags=[:dense] begin - using JLArrays, NNlib + using JLArrays, NNlib, TestExtras x = JLArray(rand(Float32, 2, 4)) weight = JLArray(rand(Float32, 3, 2)) bias = JLArray(rand(Float32, 3)) - @test @inferred(fused_dense_bias_activation(relu, weight, x, bias)) isa JLArray + @constinferred fused_dense_bias_activation(relu, weight, x, bias) @test LuxLib.internal_operation_mode(x) isa LuxLib.GenericBroadcastOp end diff --git a/lib/LuxLib/test/common_ops/dropout_tests.jl b/lib/LuxLib/test/common_ops/dropout_tests.jl index 1ec9b4618b..e1de98c7ef 100644 --- a/lib/LuxLib/test/common_ops/dropout_tests.jl +++ b/lib/LuxLib/test/common_ops/dropout_tests.jl @@ -10,7 +10,7 @@ x = randn(rng, T, x_shape) |> aType - @test @inferred(dropout(rng, x, T(0.5), Val(true), T(2), dims)) isa Any + @constinferred dropout(rng, x, T(0.5), Val(true), T(2), dims) y, mask_, rng_ = dropout(rng, x, T(0.5), Val(true), T(2), dims) @@ -21,10 +21,10 @@ @test rng != rng_ @jet sum(first(dropout(rng, x, T(0.5), Val(true), T(2), dims))) - @test @inferred(dropout(rng, x, T(0.5), Val(true), T(2), dims)) isa Any + @constinferred dropout(rng, x, T(0.5), Val(true), T(2), dims) __f = x -> sum(first(dropout(StableRNG(0), x, 0.5, Val(true), 2.0, dims))) - @test @inferred(Zygote.gradient(__f, x)) isa Any + @constinferred Zygote.gradient(__f, x) @test_gradients(sumabs2first, dropout, rng, x, T(0.5), Val(true), T(2), dims; atol=1.0f-3, rtol=1.0f-3) @@ -54,8 +54,7 @@ end mask = rand(T, x_shape) |> aType # Update mask - @test @inferred(dropout( - rng, x, mask, T(0.5), Val(true), Val(true), T(2), :)) isa Any + @constinferred dropout(rng, x, mask, T(0.5), Val(true), Val(true), T(2), :) y, mask_, rng_ = dropout( rng, x, mask, T(0.5), Val(true), Val(true), T(2), :) @@ -69,7 +68,7 @@ end __f = (x, mask) -> sum(first(dropout( StableRNG(0), x, mask, 0.5, Val(true), Val(true), 2.0, :))) - @test @inferred(Zygote.gradient(__f, x, mask)) isa Any + @constinferred Zygote.gradient(__f, x, mask) @test_gradients(sumabs2first, dropout, rng, x, LuxTestUtils.Constant(mask), T(0.5), Val(true), Val(true), @@ -79,8 +78,7 @@ end rng, x, mask, T(0.5), Val(true), Val(true), T(2), :))) # Try using mask if possible (possible!!) - @test @inferred(dropout( - rng, x, mask, T(0.5), Val(true), Val(false), T(2), :)) isa Any + @constinferred dropout(rng, x, mask, T(0.5), Val(true), Val(false), T(2), :) y, mask_, rng_ = dropout( rng, x, mask, T(0.5), Val(true), Val(false), T(2), :) @@ -94,7 +92,7 @@ end __f = (x, mask) -> sum(first(dropout( StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, :))) - @test @inferred(Zygote.gradient(__f, x, mask)) isa Any + @constinferred Zygote.gradient(__f, x, mask) @test_gradients(sumabs2first, dropout, rng, x, LuxTestUtils.Constant(mask), @@ -107,8 +105,7 @@ end mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType # Testing Mode - @test @inferred(dropout( - rng, x, mask, T(0.5), Val(false), Val(false), T(2), :)) isa Any + @constinferred dropout(rng, x, mask, T(0.5), Val(false), Val(false), T(2), :) y, mask_, rng_ = dropout( rng, x, mask, T(0.5), Val(false), Val(false), T(2), :) @@ -135,7 +132,7 @@ end x = randn(rng, T, x_shape) |> aType - @test @inferred(alpha_dropout(rng, x, T(0.5), Val(true))) isa Any + @constinferred alpha_dropout(rng, x, T(0.5), Val(true)) y, rng_ = alpha_dropout(rng, x, T(0.5), Val(true)) @@ -146,13 +143,13 @@ end @test_broken std(y)≈std(x) atol=1.0f-2 rtol=1.0f-2 __f = x -> sum(first(alpha_dropout(StableRNG(0), x, 0.5, Val(true)))) - @test @inferred(Zygote.gradient(__f, x)) isa Any + @constinferred Zygote.gradient(__f, x) @test_gradients(sumabs2first, alpha_dropout, rng, x, T(0.5), Val(true); atol=1.0f-3, rtol=1.0f-3) @jet sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) - @test @inferred(alpha_dropout(rng, x, T(0.5), Val(false))) isa Any + @constinferred alpha_dropout(rng, x, T(0.5), Val(false)) y, rng_ = alpha_dropout(rng, x, T(0.5), Val(false)) diff --git a/lib/LuxLib/test/normalization/batchnorm_tests.jl b/lib/LuxLib/test/normalization/batchnorm_tests.jl index 58b6196c1a..d47c542d63 100644 --- a/lib/LuxLib/test/normalization/batchnorm_tests.jl +++ b/lib/LuxLib/test/normalization/batchnorm_tests.jl @@ -1,5 +1,5 @@ @testsetup module BatchNormSetup -using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static, TestExtras function setup_batchnorm(gen_f, aType, T, sz; affine::Bool=true, track_stats::Bool) x = gen_f(T, sz) |> aType @@ -69,8 +69,7 @@ function run_batchnorm_testing(gen_f, T, sz, training, affine, track_stats, act, end end - @test @inferred(batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon)) isa - Any + @constinferred batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon) @jet batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon) @test y isa aType{T, length(sz)} @@ -91,8 +90,7 @@ function run_batchnorm_testing(gen_f, T, sz, training, affine, track_stats, act, if anonact !== act lfn = (x, sc, b, rm, rv, tr, act, ϵ) -> sum(first(batchnorm( x, sc, b, rm, rv, tr, act, ϵ))) - @test @inferred(Zygote.gradient( - lfn, x, scale, bias, rm, rv, training, act, epsilon)) isa Any + @constinferred Zygote.gradient(lfn, x, scale, bias, rm, rv, training, act, epsilon) end end diff --git a/lib/LuxLib/test/normalization/groupnorm_tests.jl b/lib/LuxLib/test/normalization/groupnorm_tests.jl index c103595f99..891c68715b 100644 --- a/lib/LuxLib/test/normalization/groupnorm_tests.jl +++ b/lib/LuxLib/test/normalization/groupnorm_tests.jl @@ -1,5 +1,5 @@ @testsetup module GroupNormSetup -using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static, StableRNGs +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static, StableRNGs, TestExtras using LuxTestUtils: check_approx function setup_groupnorm(rng, aType, T, sz, affine) @@ -58,12 +58,12 @@ function run_groupnorm_testing(T, sz, groups, affine, act, aType, mode, ongpu) @test ∂bias≈∂bias_simple atol=atol rtol=rtol end - @test @inferred(groupnorm(x, scale, bias, groups, act, epsilon)) isa Any + @constinferred groupnorm(x, scale, bias, groups, act, epsilon) @jet groupnorm(x, scale, bias, groups, act, epsilon) if anonact !== act lfn = (x, sc, b, g, act, ϵ) -> sum(groupnorm(x, sc, b, g, act, ϵ)) - @test @inferred(Zygote.gradient(lfn, x, scale, bias, groups, act, epsilon)) isa Any + @constinferred Zygote.gradient(lfn, x, scale, bias, groups, act, epsilon) end @test y isa aType{T, length(sz)} diff --git a/lib/LuxLib/test/normalization/instancenorm_tests.jl b/lib/LuxLib/test/normalization/instancenorm_tests.jl index dd999ff09b..cc8b1e81b6 100644 --- a/lib/LuxLib/test/normalization/instancenorm_tests.jl +++ b/lib/LuxLib/test/normalization/instancenorm_tests.jl @@ -1,5 +1,5 @@ @testsetup module InstanceNormSetup -using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, TestExtras is_training(::Val{training}) where {training} = training @@ -24,12 +24,12 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType) atol = 1.0f-2 rtol = 1.0f-2 - @test @inferred(instancenorm(x, scale, bias, training, act, epsilon)) isa Any + @constinferred instancenorm(x, scale, bias, training, act, epsilon) @jet instancenorm(x, scale, bias, training, act, epsilon) if anonact !== act && is_training(training) lfn = (x, sc, b, act, ϵ) -> sum(first(instancenorm(x, sc, b, Val(true), act, ϵ))) - @test @inferred(Zygote.gradient(lfn, x, scale, bias, act, epsilon)) isa Any + @constinferred Zygote.gradient(lfn, x, scale, bias, act, epsilon) end @test y isa aType{T, length(sz)} @@ -46,15 +46,13 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType) y, nt = instancenorm(x, scale, bias, rm, rv, training, act, T(0.1), epsilon) - @test @inferred(instancenorm( - x, scale, bias, rm, rv, training, act, T(0.1), epsilon)) isa Any + @constinferred instancenorm(x, scale, bias, rm, rv, training, act, T(0.1), epsilon) @jet instancenorm(x, scale, bias, rm, rv, training, act, T(0.1), epsilon) if anonact !== act && is_training(training) lfn = (x, sc, b, rm, rv, act, m, ϵ) -> sum(first(instancenorm( x, sc, b, rm, rv, Val(true), act, m, ϵ))) - @test @inferred(Zygote.gradient( - lfn, x, scale, bias, rm, rv, act, T(0.1), epsilon)) isa Any + @constinferred Zygote.gradient(lfn, x, scale, bias, rm, rv, act, T(0.1), epsilon) end @test y isa aType{T, length(sz)} diff --git a/lib/LuxLib/test/normalization/layernorm_tests.jl b/lib/LuxLib/test/normalization/layernorm_tests.jl index 6b82390a4e..940e95c06b 100644 --- a/lib/LuxLib/test/normalization/layernorm_tests.jl +++ b/lib/LuxLib/test/normalization/layernorm_tests.jl @@ -1,5 +1,5 @@ @testsetup module LayerNormSetup -using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Statistics +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Statistics, TestExtras using LuxTestUtils: check_approx function setup_layernorm(gen_f, aType, T, x_size, affine_shape, expand_dims::Bool=true) @@ -40,7 +40,7 @@ function run_layernorm_testing_core( epsilon = LuxLib.Utils.default_epsilon(T) _f = (args...) -> layernorm(args..., act, dims, epsilon) - @test @inferred(layernorm(x, scale, bias, act, dims, epsilon)) isa Any + @constinferred layernorm(x, scale, bias, act, dims, epsilon) @jet layernorm(x, scale, bias, act, dims, epsilon) y = _f(x, scale, bias) @@ -61,7 +61,7 @@ function run_layernorm_testing_core( if anonact !== act lfn = (x, sc, b, act, dim, ϵ) -> sum(layernorm(x, sc, b, act, dim, ϵ)) - @test @inferred(Zygote.gradient(lfn, x, scale, bias, act, dims, epsilon)) isa Any + @constinferred Zygote.gradient(lfn, x, scale, bias, act, dims, epsilon) end end diff --git a/lib/LuxLib/test/shared_testsetup.jl b/lib/LuxLib/test/shared_testsetup.jl index 77cdab4702..c2072420f9 100644 --- a/lib/LuxLib/test/shared_testsetup.jl +++ b/lib/LuxLib/test/shared_testsetup.jl @@ -2,7 +2,7 @@ import Reexport: @reexport using LuxLib, MLDataDevices -@reexport using LuxTestUtils, StableRNGs, Test, Enzyme, Zygote, NNlib +@reexport using LuxTestUtils, StableRNGs, Test, Enzyme, Zygote, NNlib, TestExtras LuxTestUtils.jet_target_modules!(["LuxLib"])