fix: more enzyme support
avik-pal committed Nov 18, 2024
1 parent fe9ac31 · commit b279050
Showing 5 changed files with 23 additions and 13 deletions.
lib/LuxLib/src/impl/batchnorm.jl (10 additions, 2 deletions)
@@ -115,7 +115,11 @@ function apply_batchnorm_scale_bias_act_cpu!(
     if size(y, 1) == 1
         apply_batchnorm_scale_bias_act_2d_serial_cpu!(y, γ′, β′, x, σ)
     else
-        apply_batchnorm_scale_bias_act_3d_threaded_cpu!(y, γ′, β′, x, σ)
+        if Utils.within_enzyme_autodiff()
+            apply_batchnorm_scale_bias_act_3d_serial_cpu!(y, γ′, β′, x, σ)
+        else
+            apply_batchnorm_scale_bias_act_3d_threaded_cpu!(y, γ′, β′, x, σ)
+        end
     end
 end

@@ -160,7 +164,11 @@ function apply_batchnorm_scale_bias_cpu!(y::AbstractArray{yT, 3}, γ′::Abstrac
     if size(y, 1) == 1
         apply_batchnorm_scale_bias_2d_serial_cpu!(y, γ′, β′, x)
     else
-        apply_batchnorm_scale_bias_3d_threaded_cpu!(y, γ′, β′, x)
+        if Utils.within_enzyme_autodiff()
+            apply_batchnorm_scale_bias_3d_serial_cpu!(y, γ′, β′, x)
+        else
+            apply_batchnorm_scale_bias_3d_threaded_cpu!(y, γ′, β′, x)
+        end
     end
 end

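For context: both hunks above route Enzyme autodiff calls to the serial CPU kernels instead of the Threads.@threads ones. A minimal, self-contained sketch of that dispatch pattern, where the kernel bodies are illustrative stand-ins rather than LuxLib's actual batchnorm kernels:

# Stand-ins for the *_serial_cpu! / *_threaded_cpu! kernels in the diff above.
serial_kernel!(y, x) = (y .= abs2.(x); y)   # plain broadcast: safe under Enzyme

function threaded_kernel!(y, x)
    Threads.@threads for i in eachindex(y, x)
        @inbounds y[i] = abs2(x[i])
    end
    return y
end

# Stub for the helper introduced in utils.jl below.
within_enzyme_autodiff() = false

function apply_kernel!(y, x)
    if within_enzyme_autodiff()
        serial_kernel!(y, x)    # serial path while Enzyme is differentiating
    else
        threaded_kernel!(y, x)  # multithreaded fast path otherwise
    end
    return y
end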
lib/LuxLib/src/traits.jl (2 additions, 3 deletions)
@@ -82,7 +82,7 @@ using Hwloc: Hwloc
 using Static: static, False, True
 
 using ..LuxLib: DISABLE_LOOP_VECTORIZATION
-using ..Utils: is_extension_loaded, safe_minimum, unsafe_known
+using ..Utils: is_extension_loaded, safe_minimum, unsafe_known, within_enzyme_autodiff
 
 const CRC = ChainRulesCore
@@ -136,8 +136,7 @@ CRC.@non_differentiable explicit_blas_loaded()
     use_octavian() = False()
 else
     function use_octavian()
-        unsafe_known(is_extension_loaded(Val(:Enzyme))) && EnzymeCore.within_autodiff() &&
-            return False()
+        within_enzyme_autodiff() && return False()
        return is_extension_loaded(Val(:Octavian)) & is_x86_64() &
               (INTEL_HARDWARE | AMD_RYZEN_HARDWARE)
     end
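The use_octavian edit swaps the inlined two-line Enzyme check for the new helper; note the trait still returns a compile-time False()/True() from Static.jl, so the matmul backend choice stays type-stable. A rough sketch of how such a static trait drives dispatch, with illustrative names that are not LuxLib API:

using LinearAlgebra: mul!
using Static: True, False

# Dispatch on the *type* of the static Bool, so the branch is resolved at
# compile time. Both bodies use mul! purely for illustration; think of the
# True path as the Octavian call.
matmul_impl!(::True, C, A, B) = mul!(C, A, B)   # "fast" Octavian-style path
matmul_impl!(::False, C, A, B) = mul!(C, A, B)  # safe generic fallback

# After this commit, use_octavian() yields False() inside Enzyme autodiff,
# so matmul_impl!(use_octavian(), C, A, B) takes the safe path while
# gradients are being computed.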
lib/LuxLib/src/utils.jl (6 additions, 2 deletions)
@@ -284,6 +284,11 @@ within_autodiff(::AbstractArray{<:ForwardDiff.Dual}) = True()
 
 CRC.rrule(::typeof(within_autodiff), x) = True(), _ -> (∂∅, ∂∅)
 
+function within_enzyme_autodiff()
+    unsafe_known(is_extension_loaded(Val(:Enzyme))) && return EnzymeCore.within_autodiff()
+    return false
+end
+
 static_training_mode(::Nothing, args...) = within_autodiff_vararg(args...)
 
 function static_training_mode(
@@ -330,8 +335,7 @@ CRC.@non_differentiable static_training_mode_check(::Any...)
 else
     @inline function can_loopvec_args(args...)
         # Avoid loop vectorization inside Enzyme autodiff calls
-        unsafe_known(is_extension_loaded(Val(:Enzyme))) && EnzymeCore.within_autodiff() &&
-            return false
+        within_enzyme_autodiff() && return false
         return can_loopvec_args_check(is_extension_loaded(Val(:LoopVectorization)), args...)
     end
 end
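within_enzyme_autodiff is the one new helper in this commit: it consults EnzymeCore.within_autodiff() only when the Enzyme extension is actually loaded, and returns plain false otherwise. A standalone approximation, where the extension check is mocked but the rest mirrors the diff above:

using EnzymeCore: EnzymeCore

# Mocked stand-in for unsafe_known(is_extension_loaded(Val(:Enzyme))).
enzyme_extension_loaded() = true

function within_enzyme_autodiff()
    # Only query EnzymeCore when Enzyme is loaded, so non-Enzyme users
    # never hit this code path.
    enzyme_extension_loaded() && return EnzymeCore.within_autodiff()
    return false
end

# Call sites in this commit then collapse to a single guard, e.g.
#   within_enzyme_autodiff() && return false     # can_loopvec_args
#   within_enzyme_autodiff() && return False()   # use_octavian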
test/layers/conv_tests.jl (0 additions, 3 deletions)
@@ -263,9 +263,6 @@ end
 
    broken_backends = Any[AutoTracker()]
    umode == :nearest || push!(broken_backends, AutoReverseDiff())
-   if VERSION < v"1.11-"
-       push!(broken_backends, AutoEnzyme())
-   end
    @test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3, rtol=1.0f-3,
        broken_backends)
 end
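The deleted block had marked AutoEnzyme() as a broken backend for these upsample gradient tests on Julia versions before 1.11; with the serial fallbacks above in place, the commit evidently expects Enzyme to pass on all supported versions, so the version gate goes away.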
test/layers/normalize_tests.jl (5 additions, 3 deletions)
@@ -44,7 +44,8 @@
 
     @jet m(x, ps, st)
     @test_gradients(sumabs2first, m, x, ps, st; atol=1.0f-3,
-        rtol=1.0f-3, skip_backends=[AutoFiniteDiff(), AutoEnzyme()])
+        rtol=1.0f-3, skip_backends=[AutoFiniteDiff()],
+        broken_backends=[AutoEnzyme()])
 
     @testset for affine in (true, false)
         m = BatchNorm(2; affine, track_stats=false)
@@ -54,7 +55,7 @@
 
     @jet m(x, ps, Lux.testmode(st))
     @test_gradients(sumabs2first, m, x, ps, st; atol=1.0f-3,
-        rtol=1.0f-3, skip_backends=[AutoFiniteDiff(), AutoEnzyme()])
+        rtol=1.0f-3, skip_backends=[AutoFiniteDiff()])
 
     # with activation function
     m = BatchNorm(2, sigmoid; affine)
@@ -68,7 +69,8 @@
         sigmoid.((x .- st_.running_mean) ./ sqrt.(st_.running_var .+ m.epsilon))
     @jet m(x, ps, Lux.testmode(st))
     @test_gradients(sumabs2first, m, x, ps, st; atol=1.0f-3,
-        rtol=1.0f-3, skip_backends=[AutoFiniteDiff(), AutoEnzyme()])
+        rtol=1.0f-3, skip_backends=[AutoFiniteDiff()],
+        broken_backends=[AutoEnzyme()])
 
     m = BatchNorm(32; affine)
     x = randn(Float32, 416, 416, 32, 1) |> aType
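The recurring edit in this file moves AutoEnzyme() out of skip_backends and into broken_backends, or drops it entirely in the second case. Assuming the usual LuxTestUtils semantics, a skipped backend is never exercised, while a broken one still runs and is recorded as an expected failure, so these tests will flag loudly once the remaining Enzyme cases start passing.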
