From fe9ac311d450982d466107d0b4810d5225fd7071 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Mon, 18 Nov 2024 13:55:26 -0500
Subject: [PATCH] fix: enzyme support for pooling

---
 Project.toml                  |  4 ++--
 docs/Project.toml             |  2 +-
 lib/LuxLib/Project.toml       |  2 +-
 lib/LuxLib/test/Project.toml  |  2 +-
 lib/LuxTestUtils/Project.toml |  4 ++--
 lib/LuxTestUtils/src/utils.jl | 10 +---------
 src/Lux.jl                    |  1 +
 src/layers/pooling.jl         | 15 +++++++++++----
 test/Project.toml             |  2 +-
 test/enzyme_tests.jl          | 14 +++++++++++---
 test/layers/pooling_tests.jl  |  9 +--------
 11 files changed, 33 insertions(+), 32 deletions(-)

diff --git a/Project.toml b/Project.toml
index 1b8ac950b0..c9d61e12da 100644
--- a/Project.toml
+++ b/Project.toml
@@ -69,14 +69,14 @@ LuxZygoteExt = "Zygote"
 ADTypes = "1.10"
 Adapt = "4.1"
 ArgCheck = "2.3"
-ArrayInterface = "7.10"
+ArrayInterface = "7.17.1"
 CUDA = "5.3.2"
 ChainRulesCore = "1.24"
 Compat = "4.16"
 ComponentArrays = "0.15.18"
 ConcreteStructs = "0.2.3"
 DispatchDoctor = "0.4.12"
-Enzyme = "0.13.14"
+Enzyme = "0.13.15"
 EnzymeCore = "0.8.5"
 FastClosures = "0.3.2"
 Flux = "0.14.25"
diff --git a/docs/Project.toml b/docs/Project.toml
index 52c3844a17..01eb7c2014 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -37,7 +37,7 @@ ChainRulesCore = "1.24"
 ComponentArrays = "0.15.18"
 Documenter = "1.4"
 DocumenterVitepress = "0.1.3"
-Enzyme = "0.13.14"
+Enzyme = "0.13.15"
 FiniteDiff = "2.23.1"
 ForwardDiff = "0.10.36"
 Functors = "0.5"
diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml
index 5cb751d7db..30def2d5da 100644
--- a/lib/LuxLib/Project.toml
+++ b/lib/LuxLib/Project.toml
@@ -65,7 +65,7 @@ ChainRulesCore = "1.24"
 Compat = "4.16"
 CpuId = "0.3"
 DispatchDoctor = "0.4.12"
-Enzyme = "0.13.14"
+Enzyme = "0.13.15"
 EnzymeCore = "0.8.5"
 FastClosures = "0.3.2"
 ForwardDiff = "0.10.36"
diff --git a/lib/LuxLib/test/Project.toml b/lib/LuxLib/test/Project.toml
index 6386cf83e2..3a0d145754 100644
--- a/lib/LuxLib/test/Project.toml
+++ b/lib/LuxLib/test/Project.toml
@@ -38,7 +38,7 @@ BLISBLAS = "0.1"
 BenchmarkTools = "1.5"
 ChainRulesCore = "1.24"
 ComponentArrays = "0.15.18"
-Enzyme = "0.13.14"
+Enzyme = "0.13.15"
 EnzymeCore = "0.8.5"
 ExplicitImports = "1.9.0"
 ForwardDiff = "0.10.36"
diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml
index 690929e31e..76b0bfeb2b 100644
--- a/lib/LuxTestUtils/Project.toml
+++ b/lib/LuxTestUtils/Project.toml
@@ -26,11 +26,11 @@ MLDataDevices = {path = "../MLDataDevices"}

 [compat]
 ADTypes = "1.10"
-ArrayInterface = "7.9"
+ArrayInterface = "7.17.1"
 ChainRulesCore = "1.24.0"
 ComponentArrays = "0.15.18"
 DispatchDoctor = "0.4.12"
-Enzyme = "0.13.14"
+Enzyme = "0.13.15"
 FiniteDiff = "2.23.1"
 ForwardDiff = "0.10.36"
 Functors = "0.5"
diff --git a/lib/LuxTestUtils/src/utils.jl b/lib/LuxTestUtils/src/utils.jl
index 1da442eb3f..e9587e9858 100644
--- a/lib/LuxTestUtils/src/utils.jl
+++ b/lib/LuxTestUtils/src/utils.jl
@@ -53,20 +53,12 @@ end
 function flatten_gradient_computable(f, nt)
     if needs_gradient(nt)
         x_flat, re = Optimisers.destructure(nt)
-        _f = x -> f(Functors.fmap(aos_to_soa, re(x)))
+        _f = x -> f(Functors.fmap(ArrayInterface.aos_to_soa, re(x)))
         return _f, x_flat, re
     end
     return nothing, nothing, nothing
 end

-# XXX: We can use ArrayInterface after https://github.com/JuliaArrays/ArrayInterface.jl/pull/457
-aos_to_soa(x) = x
-function aos_to_soa(x::AbstractArray{<:ReverseDiff.TrackedReal, N}) where {N}
-    y = length(x) > 1 ? reduce(vcat, x) : reduce(vcat, [x[1], x[1]])[1:1]
-    return reshape(y, size(x))
-end
-aos_to_soa(x::AbstractArray{<:Tracker.TrackedReal, N}) where {N} = Tracker.collect(x)
-
 function needs_gradient(y)
     leaves = Functors.fleaves(y)
     isempty(leaves) && return false
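
Why the lib/LuxTestUtils/src/utils.jl hunk above can simply delete the local aos_to_soa methods: they existed only because ArrayInterface.aos_to_soa could not yet repack an AbstractArray of ReverseDiff.TrackedReal or Tracker.TrackedReal elements, the shape produced when the reconstructor returned by Optimisers.destructure rebuilds parameters element by element. JuliaArrays/ArrayInterface.jl#457 (the PR cited in the removed XXX comment) added that support upstream, which is also why both ArrayInterface compat bounds jump to 7.17.1. Below is a minimal sketch of the semantics being relied on; the concrete values are illustrative, not taken from the test suite.

    using ArrayInterface

    # Plain arrays pass through unchanged, matching the deleted generic
    # fallback aos_to_soa(x) = x:
    x = rand(Float32, 3)
    y = ArrayInterface.aos_to_soa(x)   # y === x

    # For arrays of tracked scalars, ArrayInterface >= 7.17.1 instead
    # repacks the elements into a single tracked array, which is what the
    # deleted ReverseDiff/Tracker methods used to do by hand.
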
diff --git a/src/Lux.jl b/src/Lux.jl
index 525e331fa6..c31e2aeadb 100644
--- a/src/Lux.jl
+++ b/src/Lux.jl
@@ -8,6 +8,7 @@ using ArrayInterface: ArrayInterface
 using ChainRulesCore: ChainRulesCore, NoTangent, @thunk
 using Compat: @compat
 using ConcreteStructs: @concrete
+using EnzymeCore: EnzymeRules
 using FastClosures: @closure
 using Functors: Functors, fmap
 using GPUArraysCore: @allowscalar
diff --git a/src/layers/pooling.jl b/src/layers/pooling.jl
index f29bc8db41..819aaaeebf 100644
--- a/src/layers/pooling.jl
+++ b/src/layers/pooling.jl
@@ -1,6 +1,11 @@
 abstract type AbstractPoolMode end

-CRC.@non_differentiable (::AbstractPoolMode)(::Any...)
+(m::AbstractPoolMode)(x) = calculate_pool_dims(m, x)
+
+function calculate_pool_dims end
+
+CRC.@non_differentiable calculate_pool_dims(::Any...)
+EnzymeRules.inactive(::typeof(calculate_pool_dims), ::Any...) = true

 @concrete struct GenericPoolMode <: AbstractPoolMode
     kernel_size <: Tuple{Vararg{IntegerType}}
@@ -9,17 +14,19 @@ CRC.@non_differentiable (::AbstractPoolMode)(::Any...)
     dilation <: Tuple{Vararg{IntegerType}}
 end

-(m::GenericPoolMode)(x) = PoolDims(x, m.kernel_size; padding=m.pad, m.stride, m.dilation)
+function calculate_pool_dims(m::GenericPoolMode, x)
+    return PoolDims(x, m.kernel_size; padding=m.pad, m.stride, m.dilation)
+end

 struct GlobalPoolMode <: AbstractPoolMode end

-(::GlobalPoolMode)(x) = PoolDims(x, size(x)[1:(end - 2)])
+calculate_pool_dims(::GlobalPoolMode, x) = PoolDims(x, size(x)[1:(end - 2)])

 @concrete struct AdaptivePoolMode <: AbstractPoolMode
     out_size <: Tuple{Vararg{IntegerType}}
 end

-function (m::AdaptivePoolMode)(x)
+function calculate_pool_dims(m::AdaptivePoolMode, x)
     in_size = size(x)[1:(end - 2)]
     stride = in_size .÷ m.out_size
     kernel_size = in_size .- (m.out_size .- 1) .* stride
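
The src/layers/pooling.jl rewrite above is the core of the fix. Computing PoolDims is purely structural (it inspects only array sizes and pooling hyperparameters, so no gradient should ever flow through it), but both CRC.@non_differentiable and Enzyme's inactivity marking attach far more cleanly to a named generic function than to the callable-struct signature (::AbstractPoolMode)(::Any...). Routing every pool mode through calculate_pool_dims gives both AD systems a single function to ignore. Below is a stripped-down sketch of the same pattern; the Window and window_len names are hypothetical stand-ins for the Lux types.

    using EnzymeCore: EnzymeRules
    import Enzyme

    struct Window
        k::Int
    end

    # Purely structural helper: it inspects lengths only and carries no
    # derivative information.
    window_len(w::Window, x) = min(w.k, length(x))
    (w::Window)(x) = window_len(w, x)

    # The same declaration the patch adds for calculate_pool_dims; Enzyme
    # keys off the existence of this method and skips differentiating
    # through window_len entirely.
    EnzymeRules.inactive(::typeof(window_len), ::Any...) = true

    loss(w, x) = sum(abs2, view(x, 1:window_len(w, x)))

    x = rand(4)
    dx = Enzyme.make_zero(x)
    Enzyme.autodiff(Enzyme.Reverse, loss, Enzyme.Active,
        Enzyme.Const(Window(3)), Enzyme.Duplicated(x, dx))
    # dx[1:3] == 2 .* x[1:3]; dx[4] stays zero.
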
diff --git a/test/Project.toml b/test/Project.toml
index 7440902d72..d308dbfb4f 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -48,7 +48,7 @@ ChainRulesCore = "1.24"
 ComponentArrays = "0.15.18"
 DispatchDoctor = "0.4.12"
 Documenter = "1.4"
-Enzyme = "0.13.14"
+Enzyme = "0.13.15"
 ExplicitImports = "1.9.0"
 ForwardDiff = "0.10.36"
 Functors = "0.5"
diff --git a/test/enzyme_tests.jl b/test/enzyme_tests.jl
index 0a1fd0e903..4895acee67 100644
--- a/test/enzyme_tests.jl
+++ b/test/enzyme_tests.jl
@@ -10,8 +10,11 @@ generic_loss_function(model, x, ps, st) = sum(abs2, first(model(x, ps, st)))

 function compute_enzyme_gradient(model, x, ps, st)
     dx = Enzyme.make_zero(x)
     dps = Enzyme.make_zero(ps)
-    Enzyme.autodiff(Reverse, generic_loss_function, Active, Const(model),
-        Duplicated(x, dx), Duplicated(ps, dps), Const(st))
+    Enzyme.autodiff(
+        Enzyme.set_runtime_activity(Reverse),
+        generic_loss_function, Active, Const(model),
+        Duplicated(x, dx), Duplicated(ps, dps), Const(st)
+    )
     return dx, dps
 end
@@ -40,7 +43,8 @@ const MODELS_LIST = [
     (Chain(Conv((3, 3), 2 => 3, gelu), Conv((3, 3), 3 => 1, gelu)), rand(Float32, 5, 5, 2, 2)),
     (Chain(Conv((4, 4), 2 => 2; pad=SamePad()), MeanPool((5, 5); pad=SamePad())), rand(Float32, 5, 5, 2, 2)),
     (Chain(Conv((3, 3), 2 => 3, relu; pad=SamePad()), MaxPool((2, 2))), rand(Float32, 5, 5, 2, 2)),
-    (Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 2)),
+    # XXX: https://github.com/EnzymeAD/Enzyme.jl/issues/2105
+    # (Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 2)),
     # XXX: https://github.com/LuxDL/Lux.jl/issues/1024
     # (Bilinear((2, 2) => 3), randn(Float32, 2, 3)),
     (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3)),
@@ -83,6 +87,8 @@ end
         ongpu && continue

         @testset "[$(i)] $(nameof(typeof(model)))" for (i, (model, x)) in enumerate(MODELS_LIST)
+            display(model)
+
             ps, st = Lux.setup(rng, model) |> dev
             x = x |> aType

@@ -107,6 +113,8 @@ end
         ongpu && continue

         @testset "[$(i)] $(nameof(typeof(model)))" for (i, (model, x)) in enumerate(MODELS_LIST)
+            display(model)
+
             ps, st = Lux.setup(rng, model)
             ps = ComponentArray(ps)
             st = st |> dev
diff --git a/test/layers/pooling_tests.jl b/test/layers/pooling_tests.jl
index ed26419c2e..522dbe8d0e 100644
--- a/test/layers/pooling_tests.jl
+++ b/test/layers/pooling_tests.jl
@@ -39,13 +39,6 @@
         @test_gradients(sumabs2first, layer, y, ps, st; atol=1.0f-3,
             rtol=1.0f-3, broken_backends)

-        broken_backends2 = broken_backends
-        if VERSION ≥ v"1.11-"
-            push!(broken_backends2, AutoEnzyme())
-        elseif ltype == :LPPool
-            push!(broken_backends2, AutoEnzyme())
-        end
-
         layer = getfield(Lux, global_ltype)()
         display(layer)
         ps, st = Lux.setup(rng, layer) |> dev
@@ -54,7 +47,7 @@
         @test layer(x, ps, st)[1] == nnlib_op[ltype](x, PoolDims(x, size(x)[1:2]))
         @jet layer(x, ps, st)
         @test_gradients(sumabs2first, layer, x, ps, st; atol=1.0f-3,
-            rtol=1.0f-3, broken_backends)
+            rtol=1.0f-3, broken_backends)

         layer = getfield(Lux, ltype)((2, 2))
         display(layer)
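
A final note on the test/enzyme_tests.jl harness: the reverse mode is now wrapped as Enzyme.set_runtime_activity(Reverse), which asks Enzyme to resolve activity (whether a value carries derivatives or is constant) at run time rather than erroring when its static analysis cannot decide, a situation that mixed Const/Duplicated signatures like the one in compute_enzyme_gradient can trigger. Below is a self-contained sketch of the call shape; the toy model and loss are illustrative, not the suite's.

    using Enzyme

    toy_loss(model, x) = sum(abs2, model.w .* x)

    model = (; w=rand(Float32, 4))   # held constant via Const
    x = rand(Float32, 4)
    dx = Enzyme.make_zero(x)         # shadow that accumulates the gradient

    Enzyme.autodiff(Enzyme.set_runtime_activity(Enzyme.Reverse),
        toy_loss, Enzyme.Active, Enzyme.Const(model), Enzyme.Duplicated(x, dx))

    # dx == 2 .* model.w .^ 2 .* x, the gradient of toy_loss with
    # respect to x.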