diff --git a/test/common_ops/dropout_tests.jl b/test/common_ops/dropout_tests.jl index 19db98c5..6cf90d5f 100644 --- a/test/common_ops/dropout_tests.jl +++ b/test/common_ops/dropout_tests.jl @@ -105,9 +105,11 @@ end soft_fail = T == Float16 ? Any[AutoFiniteDiff()] : [] skip_backends = length(x_shape) == 5 ? [AutoEnzyme()] : [] + broken_backends = T == Float16 && Sys.iswindows() && length(x_shape) != 5 ? + [AutoEnzyme()] : [] test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, soft_fail, skip_backends, - broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) + broken_backends) @jet sum(first(dropout( rng, x, mask, T(0.5), Val(true), Val(false), T(2), :)))