Commit 0eefb89

test: cleanup conv tests
avik-pal committed Nov 19, 2024
1 parent 1752587 commit 0eefb89
Showing 5 changed files with 20 additions and 22 deletions.
2 changes: 1 addition & 1 deletion lib/LuxLib/test/common_ops/activation_tests.jl
@@ -8,7 +8,7 @@
@testset "$mode" for (mode, aType, ongpu, fp64) in MODES
    @testset "$f: $T" for f in [identity, relu, sigmoid, sigmoid_fast, softplus,
            logsigmoid, gelu, swish, lisht, tanh, tanh_fast],
-       T in [Float16, Float32, Float64]
+       T in [Float32, Float64]

        !fp64 && T == Float64 && continue

2 changes: 1 addition & 1 deletion lib/LuxLib/test/common_ops/bias_act_tests.jl
@@ -15,7 +15,7 @@
@testset "$act, $T, $sz" for act in [
        identity, relu, sigmoid, sigmoid_fast, softplus,
        logsigmoid, gelu, swish, lisht, tanh, tanh_fast],
-   T in [Float16, Float32, Float64],
+   T in [Float32, Float64],
    sz in [(2, 2, 3, 4), (4, 5)]

    !fp64 && T == Float64 && continue
30 changes: 14 additions & 16 deletions lib/LuxLib/test/common_ops/conv_tests.jl
@@ -14,6 +14,8 @@ end

calc_padding(pad, ::NTuple{N}, dilation, stride) where {N} = expand(Val(2 * N), pad)

+sumabs2conv(args...) = sum(abs2, fused_conv_bias_activation(args...))
+
function run_conv_testing(gen_f::Function, activation, kernel, stride, padding,
        hasbias, groups, Tw, Tx, aType, mode, ongpu)
    weight = convfilter(gen_f, Tw, kernel, 4 => 8; groups) |> aType
@@ -28,9 +30,8 @@ function run_conv_testing(gen_f::Function, activation, kernel, stride, padding,

    generic_testing = !(mode == "amdgpu" && (Tx == Float64 || Tw == Float64))

-   fp16 = Tx == Float16 || Tw == Float16
-   atol = fp16 ? 1.0f-1 : 1.0f-3
-   rtol = fp16 ? 1.0f-1 : 1.0f-3
+   atol = 1.0f-3
+   rtol = 1.0f-3

    if generic_testing
        y_generic = LuxLib.Impl.conv(x, weight, cdims)
@@ -45,36 +46,33 @@ function run_conv_testing(gen_f::Function, activation, kernel, stride, padding,
    @test @inferred(fused_conv_bias_activation(activation, weight, x, bias, cdims)) isa Any
    @jet fused_conv_bias_activation(activation, weight, x, bias, cdims)

-   __f = (σ, w, x, b, cdims) -> sum(abs2, fused_conv_bias_activation(σ, w, x, b, cdims))
-
-   if mode != "amdgpu" && activation !== anonact && !fp16
-       @test @inferred(Zygote.gradient(__f, activation, weight, x, bias, cdims)) isa Any
+   if mode != "amdgpu" && activation !== anonact
+       @test @inferred(Zygote.gradient(
+           sumabs2conv, activation, weight, x, bias, cdims
+       )) isa Any
    else
        try
-           @inferred(Zygote.gradient(__f, activation, weight, x, bias, cdims))
+           @inferred(Zygote.gradient(sumabs2conv, activation, weight, x, bias, cdims))
            @test true
        catch e
            e isa ErrorException || rethrow()
            @test_broken false
        end
    end

-   __f_grad = let activation = activation, cdims = cdims
-       (w, x, b) -> __f(activation, w, x, b, cdims)
-   end
-
-   skip_backends = Any[AutoEnzyme()]
+   skip_backends = []
    mp = Tx != Tw
    mp && push!(skip_backends, AutoReverseDiff())
    ((mp && ongpu) || (mode == "amdgpu" && (Tx == Float64 || Tw == Float64))) &&
        push!(skip_backends, AutoTracker())
-   @test_gradients(__f_grad, weight, x, bias; atol, rtol, skip_backends, soft_fail=fp16)
+
+   @test_gradients(sumabs2conv, activation, weight, x, bias, cdims; atol, rtol,
+       skip_backends)
end

anonact = x -> gelu(x)

-const ELTYPES = [(Float16, Float16), (Float32, Float16), (Float32, Float32),
-   (Float32, Float64), (Float64, Float64)]
+const ELTYPES = [(Float32, Float32), (Float32, Float64), (Float64, Float64)]
const ACTIVATIONS = [
    identity, tanh, tanh_fast, sigmoid, sigmoid_fast, relu, gelu, swish, anonact]

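For reference, a minimal standalone sketch of the pattern the reworked test now exercises: the top-level sumabs2conv helper wraps fused_conv_bias_activation in a scalar loss so the same callable can be handed to @inferred, Zygote.gradient, and @test_gradients. The array shapes, the relu activation, and the DenseConvDims construction below are illustrative assumptions, not values taken from the diff.

using LuxLib, NNlib, Zygote

# Same helper the test file defines: a scalar loss over the fused convolution.
sumabs2conv(args...) = sum(abs2, fused_conv_bias_activation(args...))

# Illustrative shapes: H × W × C_in × N input, 3×3 kernel, 4 => 8 channels.
x = randn(Float32, 8, 8, 4, 2)
weight = randn(Float32, 3, 3, 4, 8)
bias = randn(Float32, 8)              # assumed vector bias of length C_out
cdims = DenseConvDims(x, weight)

loss = sumabs2conv(relu, weight, x, bias, cdims)

# Reverse-mode gradients w.r.t. weight, input, and bias via Zygote.
gs = Zygote.gradient((w, xx, b) -> sumabs2conv(relu, w, xx, b, cdims), weight, x, bias)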
6 changes: 3 additions & 3 deletions test/enzyme_tests.jl
@@ -56,8 +56,6 @@ const MODELS_LIST = Any[
    (Chain(StatefulRecurrentCell(LSTMCell(3 => 5)), StatefulRecurrentCell(LSTMCell(5 => 3))), rand(Float32, 3, 2)),
    (StatefulRecurrentCell(GRUCell(3 => 5)), rand(Float32, 3, 10)),
    (Chain(StatefulRecurrentCell(GRUCell(3 => 5)), StatefulRecurrentCell(GRUCell(5 => 3))), rand(Float32, 3, 10)),
-   (Chain(Dense(2, 4), BatchNorm(4)), randn(Float32, 2, 3)),
-   (Chain(Conv((3, 3), 2 => 6, tanh), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)),
    (Chain(Dense(2, 4), GroupNorm(4, 2, gelu)), randn(Float32, 2, 3)),
    (Chain(Dense(2, 4), GroupNorm(4, 2)), randn(Float32, 2, 3)),
    (Chain(Conv((3, 3), 2 => 6), GroupNorm(6, 3)), randn(Float32, 6, 6, 2, 2)),
@@ -71,9 +69,11 @@ if VERSION < v"1.11-"
    # Only fails on CI
    push!(
        MODELS_LIST, Any[
+           (Chain(Dense(2, 4), BatchNorm(4)), randn(Float32, 2, 3)),
            (Chain(Dense(2, 4), BatchNorm(4, gelu)), randn(Float32, 2, 3)),
            (Chain(Dense(2, 4), BatchNorm(4, gelu; track_stats=false)), randn(Float32, 2, 3)),
-           (Chain(Conv((3, 3), 2 => 6), BatchNorm(6)), randn(Float32, 6, 6, 2, 2))
+           (Chain(Conv((3, 3), 2 => 6), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)),
+           (Chain(Conv((3, 3), 2 => 6, tanh), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)),
        ]
    )
end
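For context, a hypothetical sketch of how one (model, input) pair from MODELS_LIST might be driven; the actual test loop lives elsewhere in enzyme_tests.jl and is not part of this diff, and the loss definition and Enzyme call pattern here are assumptions for illustration only.

using Lux, Random, Enzyme

# One entry that stays in the unconditional part of MODELS_LIST.
model = Chain(Dense(2, 4), GroupNorm(4, 2, gelu))
x = randn(Float32, 2, 3)

rng = Random.default_rng()
ps, st = Lux.setup(rng, model)

# Scalar loss over the model output; `first` drops the returned state.
loss(p) = sum(abs2, first(model(x, p, st)))

# Reverse-mode gradient w.r.t. the parameters, accumulated into dps.
dps = Enzyme.make_zero(ps)
Enzyme.autodiff(Enzyme.Reverse, loss, Enzyme.Active, Enzyme.Duplicated(ps, dps))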
2 changes: 1 addition & 1 deletion test/runtests.jl
@@ -136,7 +136,7 @@ const RETESTITEMS_NWORKER_THREADS = parse(

        ReTestItems.runtests(Lux;
            tags=(tag == "all" ? nothing : [Symbol(tag)]), testitem_timeout=2400,
-           nworkers, nworker_threads=RETESTITEMS_NWORKER_THREADS, retries=2,
+           nworkers, nworker_threads=RETESTITEMS_NWORKER_THREADS, retries=2
        )
    end
end
