Skip to content

Commit

Permalink
fix: enzyme support for pooling
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 18, 2024
1 parent 272e4f0 commit fe9ac31
Show file tree
Hide file tree
Showing 11 changed files with 33 additions and 32 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,14 @@ LuxZygoteExt = "Zygote"
ADTypes = "1.10"
Adapt = "4.1"
ArgCheck = "2.3"
ArrayInterface = "7.10"
ArrayInterface = "7.17.1"
CUDA = "5.3.2"
ChainRulesCore = "1.24"
Compat = "4.16"
ComponentArrays = "0.15.18"
ConcreteStructs = "0.2.3"
DispatchDoctor = "0.4.12"
Enzyme = "0.13.14"
Enzyme = "0.13.15"
EnzymeCore = "0.8.5"
FastClosures = "0.3.2"
Flux = "0.14.25"
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ ChainRulesCore = "1.24"
ComponentArrays = "0.15.18"
Documenter = "1.4"
DocumenterVitepress = "0.1.3"
Enzyme = "0.13.14"
Enzyme = "0.13.15"
FiniteDiff = "2.23.1"
ForwardDiff = "0.10.36"
Functors = "0.5"
Expand Down
2 changes: 1 addition & 1 deletion lib/LuxLib/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ ChainRulesCore = "1.24"
Compat = "4.16"
CpuId = "0.3"
DispatchDoctor = "0.4.12"
Enzyme = "0.13.14"
Enzyme = "0.13.15"
EnzymeCore = "0.8.5"
FastClosures = "0.3.2"
ForwardDiff = "0.10.36"
Expand Down
2 changes: 1 addition & 1 deletion lib/LuxLib/test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ BLISBLAS = "0.1"
BenchmarkTools = "1.5"
ChainRulesCore = "1.24"
ComponentArrays = "0.15.18"
Enzyme = "0.13.14"
Enzyme = "0.13.15"
EnzymeCore = "0.8.5"
ExplicitImports = "1.9.0"
ForwardDiff = "0.10.36"
Expand Down
4 changes: 2 additions & 2 deletions lib/LuxTestUtils/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ MLDataDevices = {path = "../MLDataDevices"}

[compat]
ADTypes = "1.10"
ArrayInterface = "7.9"
ArrayInterface = "7.17.1"
ChainRulesCore = "1.24.0"
ComponentArrays = "0.15.18"
DispatchDoctor = "0.4.12"
Enzyme = "0.13.14"
Enzyme = "0.13.15"
FiniteDiff = "2.23.1"
ForwardDiff = "0.10.36"
Functors = "0.5"
Expand Down
10 changes: 1 addition & 9 deletions lib/LuxTestUtils/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,20 +53,12 @@ end
function flatten_gradient_computable(f, nt)
if needs_gradient(nt)
x_flat, re = Optimisers.destructure(nt)
_f = x -> f(Functors.fmap(aos_to_soa, re(x)))
_f = x -> f(Functors.fmap(ArrayInterface.aos_to_soa, re(x)))
return _f, x_flat, re
end
return nothing, nothing, nothing
end

# XXX: We can use ArrayInterface after https://github.com/JuliaArrays/ArrayInterface.jl/pull/457
aos_to_soa(x) = x
function aos_to_soa(x::AbstractArray{<:ReverseDiff.TrackedReal, N}) where {N}
y = length(x) > 1 ? reduce(vcat, x) : reduce(vcat, [x[1], x[1]])[1:1]
return reshape(y, size(x))
end
aos_to_soa(x::AbstractArray{<:Tracker.TrackedReal, N}) where {N} = Tracker.collect(x)

function needs_gradient(y)
leaves = Functors.fleaves(y)
isempty(leaves) && return false
Expand Down
1 change: 1 addition & 0 deletions src/Lux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using ArrayInterface: ArrayInterface
using ChainRulesCore: ChainRulesCore, NoTangent, @thunk
using Compat: @compat
using ConcreteStructs: @concrete
using EnzymeCore: EnzymeRules
using FastClosures: @closure
using Functors: Functors, fmap
using GPUArraysCore: @allowscalar
Expand Down
15 changes: 11 additions & 4 deletions src/layers/pooling.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
abstract type AbstractPoolMode end

CRC.@non_differentiable (::AbstractPoolMode)(::Any...)
(m::AbstractPoolMode)(x) = calculate_pool_dims(m, x)

function calculate_pool_dims end

CRC.@non_differentiable calculate_pool_dims(::Any...)
EnzymeRules.inactive(::typeof(calculate_pool_dims), ::Any...) = true

@concrete struct GenericPoolMode <: AbstractPoolMode
kernel_size <: Tuple{Vararg{IntegerType}}
Expand All @@ -9,17 +14,19 @@ CRC.@non_differentiable (::AbstractPoolMode)(::Any...)
dilation <: Tuple{Vararg{IntegerType}}
end

(m::GenericPoolMode)(x) = PoolDims(x, m.kernel_size; padding=m.pad, m.stride, m.dilation)
function calculate_pool_dims(m::GenericPoolMode, x)
return PoolDims(x, m.kernel_size; padding=m.pad, m.stride, m.dilation)
end

struct GlobalPoolMode <: AbstractPoolMode end

(::GlobalPoolMode)(x) = PoolDims(x, size(x)[1:(end - 2)])
calculate_pool_dims(::GlobalPoolMode, x) = PoolDims(x, size(x)[1:(end - 2)])

@concrete struct AdaptivePoolMode <: AbstractPoolMode
out_size <: Tuple{Vararg{IntegerType}}
end

function (m::AdaptivePoolMode)(x)
function calculate_pool_dims(m::AdaptivePoolMode, x)
in_size = size(x)[1:(end - 2)]
stride = in_size m.out_size
kernel_size = in_size .- (m.out_size .- 1) .* stride
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ ChainRulesCore = "1.24"
ComponentArrays = "0.15.18"
DispatchDoctor = "0.4.12"
Documenter = "1.4"
Enzyme = "0.13.14"
Enzyme = "0.13.15"
ExplicitImports = "1.9.0"
ForwardDiff = "0.10.36"
Functors = "0.5"
Expand Down
14 changes: 11 additions & 3 deletions test/enzyme_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@ generic_loss_function(model, x, ps, st) = sum(abs2, first(model(x, ps, st)))
function compute_enzyme_gradient(model, x, ps, st)
dx = Enzyme.make_zero(x)
dps = Enzyme.make_zero(ps)
Enzyme.autodiff(Reverse, generic_loss_function, Active, Const(model),
Duplicated(x, dx), Duplicated(ps, dps), Const(st))
Enzyme.autodiff(
Enzyme.set_runtime_activity(Reverse),
generic_loss_function, Active, Const(model),
Duplicated(x, dx), Duplicated(ps, dps), Const(st)
)
return dx, dps
end

Expand Down Expand Up @@ -40,7 +43,8 @@ const MODELS_LIST = [
(Chain(Conv((3, 3), 2 => 3, gelu), Conv((3, 3), 3 => 1, gelu)), rand(Float32, 5, 5, 2, 2)),
(Chain(Conv((4, 4), 2 => 2; pad=SamePad()), MeanPool((5, 5); pad=SamePad())), rand(Float32, 5, 5, 2, 2)),
(Chain(Conv((3, 3), 2 => 3, relu; pad=SamePad()), MaxPool((2, 2))), rand(Float32, 5, 5, 2, 2)),
(Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 2)),
# XXX: https://github.com/EnzymeAD/Enzyme.jl/issues/2105
# (Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 2)),
# XXX: https://github.com/LuxDL/Lux.jl/issues/1024
# (Bilinear((2, 2) => 3), randn(Float32, 2, 3)),
(SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3)),
Expand Down Expand Up @@ -83,6 +87,8 @@ end
ongpu && continue

@testset "[$(i)] $(nameof(typeof(model)))" for (i, (model, x)) in enumerate(MODELS_LIST)
display(model)

ps, st = Lux.setup(rng, model) |> dev
x = x |> aType

Expand All @@ -107,6 +113,8 @@ end
ongpu && continue

@testset "[$(i)] $(nameof(typeof(model)))" for (i, (model, x)) in enumerate(MODELS_LIST)
display(model)

ps, st = Lux.setup(rng, model)
ps = ComponentArray(ps)
st = st |> dev
Expand Down
9 changes: 1 addition & 8 deletions test/layers/pooling_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,6 @@
@test_gradients(sumabs2first, layer, y, ps, st; atol=1.0f-3, rtol=1.0f-3,
broken_backends)

broken_backends2 = broken_backends
if VERSION v"1.11-"
push!(broken_backends2, AutoEnzyme())
elseif ltype == :LPPool
push!(broken_backends2, AutoEnzyme())
end

layer = getfield(Lux, global_ltype)()
display(layer)
ps, st = Lux.setup(rng, layer) |> dev
Expand All @@ -54,7 +47,7 @@
@test layer(x, ps, st)[1] == nnlib_op[ltype](x, PoolDims(x, size(x)[1:2]))
@jet layer(x, ps, st)
@test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3,
rtol=1.0f-3, broken_backends=broken_backends2)
rtol=1.0f-3, broken_backends)

layer = getfield(Lux, ltype)((2, 2))
display(layer)
Expand Down

0 comments on commit fe9ac31

Please sign in to comment.