From b279050ac909cfc9443f23e50155029d5192a9d7 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Mon, 18 Nov 2024 14:59:44 -0500
Subject: [PATCH] fix: more enzyme support

---
 lib/LuxLib/src/impl/batchnorm.jl | 12 ++++++++++--
 lib/LuxLib/src/traits.jl         |  5 ++---
 lib/LuxLib/src/utils.jl          |  8 ++++++--
 test/layers/conv_tests.jl        |  3 ---
 test/layers/normalize_tests.jl   |  8 +++++---
 5 files changed, 23 insertions(+), 13 deletions(-)

diff --git a/lib/LuxLib/src/impl/batchnorm.jl b/lib/LuxLib/src/impl/batchnorm.jl
index b15490f1fb..13c1fccfc1 100644
--- a/lib/LuxLib/src/impl/batchnorm.jl
+++ b/lib/LuxLib/src/impl/batchnorm.jl
@@ -115,7 +115,11 @@ function apply_batchnorm_scale_bias_act_cpu!(
     if size(y, 1) == 1
         apply_batchnorm_scale_bias_act_2d_serial_cpu!(y, γ′, β′, x, σ)
     else
-        apply_batchnorm_scale_bias_act_3d_threaded_cpu!(y, γ′, β′, x, σ)
+        if Utils.within_enzyme_autodiff()
+            apply_batchnorm_scale_bias_act_3d_serial_cpu!(y, γ′, β′, x, σ)
+        else
+            apply_batchnorm_scale_bias_act_3d_threaded_cpu!(y, γ′, β′, x, σ)
+        end
     end
 end
 
@@ -160,7 +164,11 @@ function apply_batchnorm_scale_bias_cpu!(y::AbstractArray{yT, 3}, γ′::Abstrac
     if size(y, 1) == 1
         apply_batchnorm_scale_bias_2d_serial_cpu!(y, γ′, β′, x)
     else
-        apply_batchnorm_scale_bias_3d_threaded_cpu!(y, γ′, β′, x)
+        if Utils.within_enzyme_autodiff()
+            apply_batchnorm_scale_bias_3d_serial_cpu!(y, γ′, β′, x)
+        else
+            apply_batchnorm_scale_bias_3d_threaded_cpu!(y, γ′, β′, x)
+        end
     end
 end
 
diff --git a/lib/LuxLib/src/traits.jl b/lib/LuxLib/src/traits.jl
index a9164f2c4d..6e7ead343f 100644
--- a/lib/LuxLib/src/traits.jl
+++ b/lib/LuxLib/src/traits.jl
@@ -82,7 +82,7 @@ using Hwloc: Hwloc
 using Static: static, False, True
 
 using ..LuxLib: DISABLE_LOOP_VECTORIZATION
-using ..Utils: is_extension_loaded, safe_minimum, unsafe_known
+using ..Utils: is_extension_loaded, safe_minimum, unsafe_known, within_enzyme_autodiff
 
 const CRC = ChainRulesCore
 
@@ -136,8 +136,7 @@ CRC.@non_differentiable explicit_blas_loaded()
     use_octavian() = False()
 else
     function use_octavian()
-        unsafe_known(is_extension_loaded(Val(:Enzyme))) && EnzymeCore.within_autodiff() &&
-            return False()
+        within_enzyme_autodiff() && return False()
         return is_extension_loaded(Val(:Octavian)) & is_x86_64() &
                (INTEL_HARDWARE | AMD_RYZEN_HARDWARE)
     end
diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl
index 14748d67f8..1ef926b93a 100644
--- a/lib/LuxLib/src/utils.jl
+++ b/lib/LuxLib/src/utils.jl
@@ -284,6 +284,11 @@ within_autodiff(::AbstractArray{<:ForwardDiff.Dual}) = True()
 
 CRC.rrule(::typeof(within_autodiff), x) = True(), _ -> (∂∅, ∂∅)
 
+function within_enzyme_autodiff()
+    unsafe_known(is_extension_loaded(Val(:Enzyme))) && return EnzymeCore.within_autodiff()
+    return false
+end
+
 static_training_mode(::Nothing, args...) = within_autodiff_vararg(args...)
 
 function static_training_mode(
@@ -330,8 +335,7 @@ CRC.@non_differentiable static_training_mode_check(::Any...)
 else
     @inline function can_loopvec_args(args...)
         # Avoid loop vectorization inside Enzyme autodiff calls
-        unsafe_known(is_extension_loaded(Val(:Enzyme))) && EnzymeCore.within_autodiff() &&
-            return false
+        within_enzyme_autodiff() && return false
         return can_loopvec_args_check(is_extension_loaded(Val(:LoopVectorization)), args...)
     end
 end
diff --git a/test/layers/conv_tests.jl b/test/layers/conv_tests.jl
index 6a98c9e63c..1f36b8f7a3 100644
--- a/test/layers/conv_tests.jl
+++ b/test/layers/conv_tests.jl
@@ -263,9 +263,6 @@ end
 
             broken_backends = Any[AutoTracker()]
             umode == :nearest || push!(broken_backends, AutoReverseDiff())
-            if VERSION < v"1.11-"
-                push!(broken_backends, AutoEnzyme())
-            end
             @test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3,
                 rtol=1.0f-3, broken_backends)
         end
diff --git a/test/layers/normalize_tests.jl b/test/layers/normalize_tests.jl
index 8ea3af96d4..9ee90b85e5 100644
--- a/test/layers/normalize_tests.jl
+++ b/test/layers/normalize_tests.jl
@@ -44,7 +44,8 @@
 
         @jet m(x, ps, st)
         @test_gradients(sumabs2first, m, x, ps, st; atol=1.0f-3,
-            rtol=1.0f-3, skip_backends=[AutoFiniteDiff(), AutoEnzyme()])
+            rtol=1.0f-3, skip_backends=[AutoFiniteDiff()],
+            broken_backends=[AutoEnzyme()])
 
         @testset for affine in (true, false)
             m = BatchNorm(2; affine, track_stats=false)
@@ -54,7 +55,7 @@
 
         @jet m(x, ps, Lux.testmode(st))
         @test_gradients(sumabs2first, m, x, ps, st; atol=1.0f-3,
-            rtol=1.0f-3, skip_backends=[AutoFiniteDiff(), AutoEnzyme()])
+            rtol=1.0f-3, skip_backends=[AutoFiniteDiff()])
 
         # with activation function
         m = BatchNorm(2, sigmoid; affine)
@@ -68,7 +69,8 @@
             sigmoid.((x .- st_.running_mean) ./ sqrt.(st_.running_var .+ m.epsilon))
         @jet m(x, ps, Lux.testmode(st))
         @test_gradients(sumabs2first, m, x, ps, st; atol=1.0f-3,
-            rtol=1.0f-3, skip_backends=[AutoFiniteDiff(), AutoEnzyme()])
+            rtol=1.0f-3, skip_backends=[AutoFiniteDiff()],
+            broken_backends=[AutoEnzyme()])
 
         m = BatchNorm(32; affine)
         x = randn(Float32, 416, 416, 32, 1) |> aType
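
Note on the pattern this patch introduces: Utils.within_enzyme_autodiff()
centralizes the "is Enzyme differentiating this call?" check. It returns
EnzymeCore.within_autodiff() when the Enzyme extension is loaded and false
otherwise, and every Enzyme-sensitive fast path (the threaded batchnorm
kernels, the Octavian matmul path, LoopVectorization) now consults it and
falls back to plain serial Julia code under Enzyme. A minimal standalone
sketch of the same guard pattern, assuming EnzymeCore as a direct dependency
(scale_bias! and its scalar γ/β arguments are illustrative, not LuxLib's
actual kernels):

    using EnzymeCore

    # Enzyme installs a rule so this returns true while it is
    # differentiating the call. LuxLib additionally guards the call
    # behind an is_extension_loaded(Val(:Enzyme)) check; here we
    # assume EnzymeCore is always available.
    within_enzyme_autodiff() = EnzymeCore.within_autodiff()

    function scale_bias!(y::AbstractArray, γ::Number, β::Number, x::AbstractArray)
        if within_enzyme_autodiff()
            # Serial fallback: Enzyme differentiates a plain loop more
            # reliably than a Threads.@threads region.
            @inbounds @simd for i in eachindex(y, x)
                y[i] = muladd(γ, x[i], β)
            end
        else
            # Outside autodiff, keep the threaded fast path.
            Threads.@threads for i in eachindex(y, x)
                @inbounds y[i] = muladd(γ, x[i], β)
            end
        end
        return y
    end

Keeping the check in a single helper keeps each call-site branch cheap and
leaves one place to extend the guard to other AD backends later.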