test: use TestExtras in LuxLib testing
avik-pal committed Nov 22, 2024
1 parent 9412297 commit d728f4a
Showing 11 changed files with 62 additions and 78 deletions.
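
Every test-file change below follows one pattern: `@test @inferred(f(args...)) isa Any` becomes `@constinferred f(args...)` from TestExtras.jl. `Test.@inferred` throws when the return type of the call cannot be inferred, so it had to be wrapped in `@test ... isa Any` just to record a result; `@constinferred` registers the pass/fail itself and additionally tests inference with constant propagation applied to the arguments. A minimal sketch of the two idioms (`stable` and `unstable` are illustrative functions, not part of this commit):

using Test, TestExtras

stable(x) = x > 0 ? one(x) : zero(x)  # return type fixed by the input type
unstable(x) = x > 0 ? 1.0 : 0         # Union{Float64, Int}: depends on the value

@testset "inference idioms" begin
    # Old idiom: @inferred throws on an inference failure, so it is wrapped
    # in @test to turn the outcome into a pass/fail entry in the test set.
    @test @inferred(stable(1.0)) isa Any

    # New idiom: @constinferred is itself a test, and it checks inference
    # with literal arguments treated as compile-time constants.
    @constinferred stable(1.0)

    # A value-dependent return type fails both checks; TestExtras can mark
    # it as a known failure, as conv_tests.jl does below.
    @constinferred_broken unstable(1.0)
end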
2 changes: 2 additions & 0 deletions lib/LuxLib/test/Project.toml
@@ -28,6 +28,7 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

@@ -61,5 +62,6 @@ Static = "0.8.4, 1"
StaticArrays = "1.9.7"
Statistics = "1.10"
Test = "1.10"
TestExtras = "0.3.1"
Tracker = "0.2.36"
Zygote = "0.6.70"
12 changes: 6 additions & 6 deletions lib/LuxLib/test/common_ops/activation_tests.jl
@@ -30,18 +30,18 @@
@test eltype(y2) == T
@test eltype(y3) == T

-@test @inferred(apply_act(f, x)) isa Any
-@test @inferred(apply_act_fast(f, x)) isa Any
-@test @inferred(apply_act_fast2(f, x)) isa Any
+@constinferred apply_act(f, x)
+@constinferred apply_act_fast(f, x)
+@constinferred apply_act_fast2(f, x)

@jet apply_act_fast(f, x)
@jet apply_act_fast2(f, x)

-@test @inferred(Zygote.gradient(apply_act, f, x)) isa Any
+@constinferred Zygote.gradient(apply_act, f, x)
if f !== lisht
-@test @inferred(Zygote.gradient(apply_act_fast, f, x)) isa Any
+@constinferred Zygote.gradient(apply_act_fast, f, x)
end
-@test @inferred(Zygote.gradient(apply_act_fast2, f, x)) isa Any
+@constinferred Zygote.gradient(apply_act_fast2, f, x)

@test_gradients(apply_act, f, x; atol, rtol)
@test_gradients(apply_act_fast, f, x; atol, rtol, skip_backends=[AutoEnzyme()])
38 changes: 14 additions & 24 deletions lib/LuxLib/test/common_ops/bias_act_tests.jl
@@ -5,12 +5,6 @@
bias_act_loss2(act, x, b) = sum(abs2, bias_activation(act, x, b))
bias_act_loss3(act, x, b) = sum(abs2, bias_activation!!(act, copy(x), b))

-struct __Fix1{F, A}
-f::F
-act::A
-end
-(f::__Fix1)(x, b) = f.f(f.act, x, b)

@testset "$mode" for (mode, aType, ongpu, fp64) in MODES
@testset "$act, $T, $sz" for act in [
identity, relu, sigmoid, sigmoid_fast, softplus,
@@ -27,38 +21,34 @@
y2 = bias_act_loss2(act, x, b)
y3 = bias_act_loss3(act, x, b)

-fp16 = T == Float16
-atol = fp16 ? 1.0f-2 : 1.0f-3
-rtol = fp16 ? 1.0f-2 : 1.0f-3
+atol = 1.0f-3
+rtol = 1.0f-3

@test y1 ≈ y2 atol=atol rtol=rtol
@test y1 ≈ y3 atol=atol rtol=rtol
@test eltype(y1) == T
@test eltype(y2) == T
@test eltype(y3) == T

-@test @inferred(bias_act_loss1(act, x, b)) isa Any
-@test @inferred(bias_act_loss2(act, x, b)) isa Any
-@test @inferred(bias_act_loss3(act, x, b)) isa Any
+@constinferred bias_act_loss1(act, x, b)
+@constinferred bias_act_loss2(act, x, b)
+@constinferred bias_act_loss3(act, x, b)

@jet bias_act_loss2(act, x, b)
@jet bias_act_loss3(act, x, b)

-if act !== lisht && T != Float16
-@test @inferred(Zygote.gradient(bias_act_loss2, act, x, b)) isa Any
-@test @inferred(Zygote.gradient(bias_act_loss3, act, x, b)) isa Any
+if act !== lisht
+@constinferred Zygote.gradient(bias_act_loss2, act, x, b)
+@constinferred Zygote.gradient(bias_act_loss3, act, x, b)
end

-@test_gradients(__Fix1(bias_act_loss1, act), x, b; atol, rtol,
-soft_fail=fp16 ? [AutoFiniteDiff()] : [])
-@test_gradients(__Fix1(bias_act_loss2, act), x, b; atol, rtol,
-soft_fail=fp16 ? [AutoFiniteDiff()] : [])
-@test_gradients(__Fix1(bias_act_loss3, act), x, b; atol, rtol,
-soft_fail=fp16 ? [AutoFiniteDiff()] : [])
+@test_gradients(bias_act_loss1, act, x, b; atol, rtol)
+@test_gradients(bias_act_loss2, act, x, b; atol, rtol)
+@test_gradients(bias_act_loss3, act, x, b; atol, rtol)

-∂x1, ∂b1 = Zygote.gradient(__Fix1(bias_act_loss1, act), x, b)
-∂x2, ∂b2 = Zygote.gradient(__Fix1(bias_act_loss2, act), x, b)
-∂x3, ∂b3 = Zygote.gradient(__Fix1(bias_act_loss3, act), x, b)
+_, ∂x1, ∂b1 = Zygote.gradient(bias_act_loss1, act, x, b)
+_, ∂x2, ∂b2 = Zygote.gradient(bias_act_loss2, act, x, b)
+_, ∂x3, ∂b3 = Zygote.gradient(bias_act_loss3, act, x, b)

@test ∂x1 ≈ ∂x2 atol=atol rtol=rtol
@test ∂x1 ≈ ∂x3 atol=atol rtol=rtol
15 changes: 7 additions & 8 deletions lib/LuxLib/test/common_ops/conv_tests.jl
@@ -1,5 +1,5 @@
@testsetup module ConvSetup
-using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib
+using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, TestExtras

expand(_, i::Tuple) = i
expand(N, i::Integer) = ntuple(_ -> i, N)
@@ -43,20 +43,19 @@ function run_conv_testing(gen_f::Function, activation, kernel, stride, padding,

@test eltype(y) == promote_type(Tw, Tx)

-@test @inferred(fused_conv_bias_activation(activation, weight, x, bias, cdims)) isa Any
+@constinferred fused_conv_bias_activation(activation, weight, x, bias, cdims)
@jet fused_conv_bias_activation(activation, weight, x, bias, cdims)

if mode != "amdgpu" && activation !== anonact
-@test @inferred(Zygote.gradient(
-sumabs2conv, activation, weight, x, bias, cdims
-)) isa Any
+@constinferred Zygote.gradient(sumabs2conv, activation, weight, x, bias, cdims)
else
try
-@inferred(Zygote.gradient(sumabs2conv, activation, weight, x, bias, cdims))
-@test true
+@constinferred Zygote.gradient(sumabs2conv, activation, weight, x, bias, cdims)
catch e
e isa ErrorException || rethrow()
-@test_broken false
+@constinferred_broken Zygote.gradient(
+sumabs2conv, activation, weight, x, bias, cdims
+)
end
end

14 changes: 7 additions & 7 deletions lib/LuxLib/test/common_ops/dense_tests.jl
@@ -1,5 +1,5 @@
@testsetup module DenseSetup
-using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, StableRNGs
+using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, StableRNGs, TestExtras

anonact = x -> x^3

@@ -27,14 +27,14 @@ function run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu
@test y ≈ y_generic
@test eltype(y) == promote_type(Tw, Tx)

-@test @inferred(fused_dense_bias_activation(activation, w, x, bias)) isa Any
+@constinferred fused_dense_bias_activation(activation, w, x, bias)
@jet fused_dense_bias_activation(activation, w, x, bias)

atol = 1.0f-3
rtol = 1.0f-3

if activation !== anonact
-@test @inferred(Zygote.gradient(sumabs2dense, activation, w, x, bias)) isa Any
+@constinferred Zygote.gradient(sumabs2dense, activation, w, x, bias)
end

skip_backends = []
@@ -117,23 +117,23 @@ end
end

@testitem "Fused Dense: StaticArrays" tags=[:dense] begin
-using StaticArrays, NNlib
+using StaticArrays, NNlib, TestExtras

x = @SArray rand(2, 4)
weight = @SArray rand(3, 2)
bias = @SArray rand(3)

-@test @inferred(fused_dense_bias_activation(relu, weight, x, bias)) isa SArray
+@constinferred fused_dense_bias_activation(relu, weight, x, bias)
end

@testitem "Fused Dense: CPU No Scalar Indexing" tags=[:dense] begin
-using JLArrays, NNlib
+using JLArrays, NNlib, TestExtras

x = JLArray(rand(Float32, 2, 4))
weight = JLArray(rand(Float32, 3, 2))
bias = JLArray(rand(Float32, 3))

-@test @inferred(fused_dense_bias_activation(relu, weight, x, bias)) isa JLArray
+@constinferred fused_dense_bias_activation(relu, weight, x, bias)
@test LuxLib.internal_operation_mode(x) isa LuxLib.GenericBroadcastOp
end

25 changes: 11 additions & 14 deletions lib/LuxLib/test/common_ops/dropout_tests.jl
@@ -10,7 +10,7 @@

x = randn(rng, T, x_shape) |> aType

-@test @inferred(dropout(rng, x, T(0.5), Val(true), T(2), dims)) isa Any
+@constinferred dropout(rng, x, T(0.5), Val(true), T(2), dims)

y, mask_, rng_ = dropout(rng, x, T(0.5), Val(true), T(2), dims)

@@ -21,10 +21,10 @@
@test rng != rng_

@jet sum(first(dropout(rng, x, T(0.5), Val(true), T(2), dims)))
-@test @inferred(dropout(rng, x, T(0.5), Val(true), T(2), dims)) isa Any
+@constinferred dropout(rng, x, T(0.5), Val(true), T(2), dims)

__f = x -> sum(first(dropout(StableRNG(0), x, 0.5, Val(true), 2.0, dims)))
-@test @inferred(Zygote.gradient(__f, x)) isa Any
+@constinferred Zygote.gradient(__f, x)

@test_gradients(sumabs2first,
dropout, rng, x, T(0.5), Val(true), T(2), dims; atol=1.0f-3, rtol=1.0f-3)
@@ -54,8 +54,7 @@ end
mask = rand(T, x_shape) |> aType

# Update mask
-@test @inferred(dropout(
-rng, x, mask, T(0.5), Val(true), Val(true), T(2), :)) isa Any
+@constinferred dropout(rng, x, mask, T(0.5), Val(true), Val(true), T(2), :)

y, mask_, rng_ = dropout(
rng, x, mask, T(0.5), Val(true), Val(true), T(2), :)
@@ -69,7 +68,7 @@

__f = (x, mask) -> sum(first(dropout(
StableRNG(0), x, mask, 0.5, Val(true), Val(true), 2.0, :)))
-@test @inferred(Zygote.gradient(__f, x, mask)) isa Any
+@constinferred Zygote.gradient(__f, x, mask)

@test_gradients(sumabs2first,
dropout, rng, x, LuxTestUtils.Constant(mask), T(0.5), Val(true), Val(true),
@@ -79,8 +78,7 @@
rng, x, mask, T(0.5), Val(true), Val(true), T(2), :)))

# Try using mask if possible (possible!!)
-@test @inferred(dropout(
-rng, x, mask, T(0.5), Val(true), Val(false), T(2), :)) isa Any
+@constinferred dropout(rng, x, mask, T(0.5), Val(true), Val(false), T(2), :)

y, mask_, rng_ = dropout(
rng, x, mask, T(0.5), Val(true), Val(false), T(2), :)
@@ -94,7 +92,7 @@

__f = (x, mask) -> sum(first(dropout(
StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, :)))
-@test @inferred(Zygote.gradient(__f, x, mask)) isa Any
+@constinferred Zygote.gradient(__f, x, mask)

@test_gradients(sumabs2first,
dropout, rng, x, LuxTestUtils.Constant(mask),
@@ -107,8 +105,7 @@
mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType

# Testing Mode
-@test @inferred(dropout(
-rng, x, mask, T(0.5), Val(false), Val(false), T(2), :)) isa Any
+@constinferred dropout(rng, x, mask, T(0.5), Val(false), Val(false), T(2), :)

y, mask_, rng_ = dropout(
rng, x, mask, T(0.5), Val(false), Val(false), T(2), :)
@@ -135,7 +132,7 @@

x = randn(rng, T, x_shape) |> aType

-@test @inferred(alpha_dropout(rng, x, T(0.5), Val(true))) isa Any
+@constinferred alpha_dropout(rng, x, T(0.5), Val(true))

y, rng_ = alpha_dropout(rng, x, T(0.5), Val(true))

@@ -146,13 +143,13 @@
@test_broken std(y) ≈ std(x) atol=1.0f-2 rtol=1.0f-2

__f = x -> sum(first(alpha_dropout(StableRNG(0), x, 0.5, Val(true))))
-@test @inferred(Zygote.gradient(__f, x)) isa Any
+@constinferred Zygote.gradient(__f, x)

@test_gradients(sumabs2first,
alpha_dropout, rng, x, T(0.5), Val(true); atol=1.0f-3, rtol=1.0f-3)

@jet sum(first(alpha_dropout(rng, x, T(0.5), Val(true))))
-@test @inferred(alpha_dropout(rng, x, T(0.5), Val(false))) isa Any
+@constinferred alpha_dropout(rng, x, T(0.5), Val(false))

y, rng_ = alpha_dropout(rng, x, T(0.5), Val(false))

8 changes: 3 additions & 5 deletions lib/LuxLib/test/normalization/batchnorm_tests.jl
@@ -1,5 +1,5 @@
@testsetup module BatchNormSetup
-using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static
+using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static, TestExtras

function setup_batchnorm(gen_f, aType, T, sz; affine::Bool=true, track_stats::Bool)
x = gen_f(T, sz) |> aType
@@ -69,8 +69,7 @@ function run_batchnorm_testing(gen_f, T, sz, training, affine, track_stats, act,
end
end

-@test @inferred(batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon)) isa
-Any
+@constinferred batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon)
@jet batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon)

@test y isa aType{T, length(sz)}
@@ -91,8 +90,7 @@ function run_batchnorm_testing(gen_f, T, sz, training, affine, track_stats, act,
if anonact !== act
lfn = (x, sc, b, rm, rv, tr, act, ϵ) -> sum(first(batchnorm(
x, sc, b, rm, rv, tr, act, ϵ)))
-@test @inferred(Zygote.gradient(
-lfn, x, scale, bias, rm, rv, training, act, epsilon)) isa Any
+@constinferred Zygote.gradient(lfn, x, scale, bias, rm, rv, training, act, epsilon)
end
end

6 changes: 3 additions & 3 deletions lib/LuxLib/test/normalization/groupnorm_tests.jl
@@ -1,5 +1,5 @@
@testsetup module GroupNormSetup
-using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static, StableRNGs
+using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static, StableRNGs, TestExtras
using LuxTestUtils: check_approx

function setup_groupnorm(rng, aType, T, sz, affine)
@@ -58,12 +58,12 @@ function run_groupnorm_testing(T, sz, groups, affine, act, aType, mode, ongpu)
@test ∂bias ≈ ∂bias_simple atol=atol rtol=rtol
end

-@test @inferred(groupnorm(x, scale, bias, groups, act, epsilon)) isa Any
+@constinferred groupnorm(x, scale, bias, groups, act, epsilon)
@jet groupnorm(x, scale, bias, groups, act, epsilon)

if anonact !== act
lfn = (x, sc, b, g, act, ϵ) -> sum(groupnorm(x, sc, b, g, act, ϵ))
-@test @inferred(Zygote.gradient(lfn, x, scale, bias, groups, act, epsilon)) isa Any
+@constinferred Zygote.gradient(lfn, x, scale, bias, groups, act, epsilon)
end

@test y isa aType{T, length(sz)}
12 changes: 5 additions & 7 deletions lib/LuxLib/test/normalization/instancenorm_tests.jl
@@ -1,5 +1,5 @@
@testsetup module InstanceNormSetup
-using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib
+using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, TestExtras

is_training(::Val{training}) where {training} = training

@@ -24,12 +24,12 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType)
atol = 1.0f-2
rtol = 1.0f-2

-@test @inferred(instancenorm(x, scale, bias, training, act, epsilon)) isa Any
+@constinferred instancenorm(x, scale, bias, training, act, epsilon)
@jet instancenorm(x, scale, bias, training, act, epsilon)

if anonact !== act && is_training(training)
lfn = (x, sc, b, act, ϵ) -> sum(first(instancenorm(x, sc, b, Val(true), act, ϵ)))
-@test @inferred(Zygote.gradient(lfn, x, scale, bias, act, epsilon)) isa Any
+@constinferred Zygote.gradient(lfn, x, scale, bias, act, epsilon)
end

@test y isa aType{T, length(sz)}
@@ -46,15 +46,13 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType)

y, nt = instancenorm(x, scale, bias, rm, rv, training, act, T(0.1), epsilon)

-@test @inferred(instancenorm(
-x, scale, bias, rm, rv, training, act, T(0.1), epsilon)) isa Any
+@constinferred instancenorm(x, scale, bias, rm, rv, training, act, T(0.1), epsilon)
@jet instancenorm(x, scale, bias, rm, rv, training, act, T(0.1), epsilon)

if anonact !== act && is_training(training)
lfn = (x, sc, b, rm, rv, act, m, ϵ) -> sum(first(instancenorm(
x, sc, b, rm, rv, Val(true), act, m, ϵ)))
-@test @inferred(Zygote.gradient(
-lfn, x, scale, bias, rm, rv, act, T(0.1), epsilon)) isa Any
+@constinferred Zygote.gradient(lfn, x, scale, bias, rm, rv, act, T(0.1), epsilon)
end

@test y isa aType{T, length(sz)}
