diff --git a/ext/LuxEnzymeExt/LuxEnzymeExt.jl b/ext/LuxEnzymeExt/LuxEnzymeExt.jl index 01ffe068c..71969be97 100644 --- a/ext/LuxEnzymeExt/LuxEnzymeExt.jl +++ b/ext/LuxEnzymeExt/LuxEnzymeExt.jl @@ -24,6 +24,16 @@ annotate_function(::AutoEnzyme{<:Any, A}, f::F) where {F, A} = A(f) include("training.jl") +include("autodiff.jl") include("batched_autodiff.jl") +@concrete struct OOPFunctionWrapper + f +end + +function (f::OOPFunctionWrapper)(y, args...) + copyto!(y, f.f(args...)) + return +end + end diff --git a/ext/LuxEnzymeExt/autodiff.jl b/ext/LuxEnzymeExt/autodiff.jl new file mode 100644 index 000000000..c3ff2413c --- /dev/null +++ b/ext/LuxEnzymeExt/autodiff.jl @@ -0,0 +1,37 @@ +function Lux.AutoDiffInternalImpl.jacobian_vector_product_impl( + f::F, ad::AutoEnzyme, x, u, p) where {F} + ad = normalize_backend(True(), ad) + @assert ADTypes.mode(ad) isa ForwardMode "JVPs are only supported in forward mode." + return only( + Enzyme.autodiff(ad.mode, annotate_function(ad, f), Duplicated(x, u), Const(p)) + ) +end + +function Lux.AutoDiffInternalImpl.jacobian_vector_product_impl( + f::F, ad::AutoEnzyme, x, u) where {F} + ad = normalize_backend(True(), ad) + @assert ADTypes.mode(ad) isa ForwardMode "JVPs are only supported in forward mode." + return only(Enzyme.autodiff(ad.mode, annotate_function(ad, f), Duplicated(x, u))) +end + +function Lux.AutoDiffInternalImpl.vector_jacobian_product_impl( + f::F, ad::AutoEnzyme, x, v, p) where {F} + ad = normalize_backend(False(), ad) + @assert ADTypes.mode(ad) isa ReverseMode "VJPs are only supported in reverse mode." + dx = zero(x) + # XXX: without the copy it overwrites the `v` with zeros + Enzyme.autodiff(ad.mode, annotate_function(ad, OOPFunctionWrapper(f)), + Duplicated(similar(v), copy(v)), Duplicated(x, dx), Const(p)) + return dx +end + +function Lux.AutoDiffInternalImpl.vector_jacobian_product_impl( + f::F, ad::AutoEnzyme, x, v) where {F} + ad = normalize_backend(False(), ad) + @assert ADTypes.mode(ad) isa ReverseMode "VJPs are only supported in reverse mode." + dx = zero(x) + # XXX: without the copy it overwrites the `v` with zeros + Enzyme.autodiff(ad.mode, annotate_function(ad, OOPFunctionWrapper(f)), + Duplicated(similar(v), copy(v)), Duplicated(x, dx)) + return dx +end diff --git a/ext/LuxEnzymeExt/batched_autodiff.jl b/ext/LuxEnzymeExt/batched_autodiff.jl index b116b396b..32cf8c2c3 100644 --- a/ext/LuxEnzymeExt/batched_autodiff.jl +++ b/ext/LuxEnzymeExt/batched_autodiff.jl @@ -85,12 +85,3 @@ function make_zero!(partials, idxs) end return partials[1:length(idxs)] end - -@concrete struct OOPFunctionWrapper - f -end - -function (f::OOPFunctionWrapper)(y, x) - copyto!(y, f.f(x)) - return -end diff --git a/src/autodiff/api.jl b/src/autodiff/api.jl index 95625f853..3bc1a907f 100644 --- a/src/autodiff/api.jl +++ b/src/autodiff/api.jl @@ -7,9 +7,10 @@ products efficiently using mixed-mode AD. ## Backends & AD Packages -| Supported Backends | Packages Needed | -| :----------------- | :-------------- | -| `AutoZygote` | `Zygote.jl` | +| Supported Backends | Packages Needed | Notes | +| :----------------- | :-------------- | :--------------------------------------------- | +| `AutoZygote` | `Zygote.jl` | | +| `AutoEnzyme` | `Enzyme.jl` | Not compatible with ChainRules based Nested AD | !!! warning @@ -32,9 +33,12 @@ function vector_jacobian_product(::F, backend::AbstractADType, _, __) where {F} throw(ArgumentError("`vector_jacobian_product` is not implemented for `$(backend)`.")) end -function vector_jacobian_product(f::F, backend::AutoZygote, x, u) where {F} - assert_backend_loaded(:vector_jacobian_product, backend) - return AutoDiffInternalImpl.vector_jacobian_product(f, backend, x, u) +for implemented_backend in (:AutoZygote, :AutoEnzyme) + @eval function vector_jacobian_product( + f::F, backend::$implemented_backend, x, u) where {F} + assert_backend_loaded(:vector_jacobian_product, backend) + return AutoDiffInternalImpl.vector_jacobian_product(f, backend, x, u) + end end @doc doc""" @@ -46,9 +50,10 @@ products efficiently using mixed-mode AD. ## Backends & AD Packages -| Supported Backends | Packages Needed | -| :----------------- | :--------------- | -| `AutoForwardDiff` | | +| Supported Backends | Packages Needed | Notes | +| :----------------- | :-------------- | :--------------------------------------------- | +| `AutoForwardDiff` | | | +| `AutoEnzyme` | `Enzyme.jl` | Not compatible with ChainRules based Nested AD | !!! warning @@ -71,8 +76,11 @@ function jacobian_vector_product(::F, backend::AbstractADType, _, __) where {F} throw(ArgumentError("`jacobian_vector_product` is not implemented for `$(backend)`.")) end -function jacobian_vector_product(f::F, backend::AutoForwardDiff, x, u) where {F} - return AutoDiffInternalImpl.jacobian_vector_product(f, backend, x, u) +for implemented_backend in (:AutoEnzyme, :AutoForwardDiff) + @eval function jacobian_vector_product( + f::F, backend::$(implemented_backend), x, u) where {F} + return AutoDiffInternalImpl.jacobian_vector_product(f, backend, x, u) + end end """ diff --git a/src/autodiff/jac_products.jl b/src/autodiff/jac_products.jl index 97785f144..8c5d0544b 100644 --- a/src/autodiff/jac_products.jl +++ b/src/autodiff/jac_products.jl @@ -3,6 +3,17 @@ function vector_jacobian_product(f::F, backend::AbstractADType, x, u) where {F} return vector_jacobian_product_impl(f, backend, x, u) end +for fType in AD_CONVERTIBLE_FUNCTIONS + @eval function vector_jacobian_product(f::$(fType), backend::AbstractADType, x, u) + f̂, y = rewrite_autodiff_call(f) + return vector_jacobian_product_impl(f̂, backend, x, u, y) + end +end + +function vector_jacobian_product_impl(f::F, backend::AbstractADType, x, u, y) where {F} + return vector_jacobian_product_impl(Base.Fix2(f, y), backend, x, u) +end + # JVP Implementation function jacobian_vector_product(f::F, backend::AbstractADType, x, u) where {F} return jacobian_vector_product_impl(f, backend, x, u)