Skip to content

Commit

Permalink
feat: add vjp and jvp for Enzyme
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 25, 2024
1 parent 8f098a4 commit ff8e926
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 20 deletions.
10 changes: 10 additions & 0 deletions ext/LuxEnzymeExt/LuxEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@ annotate_function(::AutoEnzyme{<:Any, A}, f::F) where {F, A} = A(f)

include("training.jl")

include("autodiff.jl")
include("batched_autodiff.jl")

@concrete struct OOPFunctionWrapper
f
end

function (f::OOPFunctionWrapper)(y, args...)
copyto!(y, f.f(args...))
return
end

end
37 changes: 37 additions & 0 deletions ext/LuxEnzymeExt/autodiff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
function Lux.AutoDiffInternalImpl.jacobian_vector_product_impl(
f::F, ad::AutoEnzyme, x, u, p) where {F}
ad = normalize_backend(True(), ad)
@assert ADTypes.mode(ad) isa ForwardMode "JVPs are only supported in forward mode."
return only(
Enzyme.autodiff(ad.mode, annotate_function(ad, f), Duplicated(x, u), Const(p))
)
end

function Lux.AutoDiffInternalImpl.jacobian_vector_product_impl(
f::F, ad::AutoEnzyme, x, u) where {F}
ad = normalize_backend(True(), ad)
@assert ADTypes.mode(ad) isa ForwardMode "JVPs are only supported in forward mode."
return only(Enzyme.autodiff(ad.mode, annotate_function(ad, f), Duplicated(x, u)))
end

function Lux.AutoDiffInternalImpl.vector_jacobian_product_impl(
f::F, ad::AutoEnzyme, x, v, p) where {F}
ad = normalize_backend(False(), ad)
@assert ADTypes.mode(ad) isa ReverseMode "VJPs are only supported in reverse mode."
dx = zero(x)
# XXX: without the copy it overwrites the `v` with zeros
Enzyme.autodiff(ad.mode, annotate_function(ad, OOPFunctionWrapper(f)),
Duplicated(similar(v), copy(v)), Duplicated(x, dx), Const(p))
return dx
end

function Lux.AutoDiffInternalImpl.vector_jacobian_product_impl(
f::F, ad::AutoEnzyme, x, v) where {F}
ad = normalize_backend(False(), ad)
@assert ADTypes.mode(ad) isa ReverseMode "VJPs are only supported in reverse mode."
dx = zero(x)
# XXX: without the copy it overwrites the `v` with zeros
Enzyme.autodiff(ad.mode, annotate_function(ad, OOPFunctionWrapper(f)),
Duplicated(similar(v), copy(v)), Duplicated(x, dx))
return dx
end
9 changes: 0 additions & 9 deletions ext/LuxEnzymeExt/batched_autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,3 @@ function make_zero!(partials, idxs)
end
return partials[1:length(idxs)]
end

@concrete struct OOPFunctionWrapper
f
end

function (f::OOPFunctionWrapper)(y, x)
copyto!(y, f.f(x))
return
end
30 changes: 19 additions & 11 deletions src/autodiff/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ products efficiently using mixed-mode AD.
## Backends & AD Packages
| Supported Backends | Packages Needed |
| :----------------- | :-------------- |
| `AutoZygote` | `Zygote.jl` |
| Supported Backends | Packages Needed | Notes |
| :----------------- | :-------------- | :--------------------------------------------- |
| `AutoZygote` | `Zygote.jl` | |
| `AutoEnzyme` | `Enzyme.jl` | Not compatible with ChainRules based Nested AD |
!!! warning
Expand All @@ -32,9 +33,12 @@ function vector_jacobian_product(::F, backend::AbstractADType, _, __) where {F}
throw(ArgumentError("`vector_jacobian_product` is not implemented for `$(backend)`."))
end

function vector_jacobian_product(f::F, backend::AutoZygote, x, u) where {F}
assert_backend_loaded(:vector_jacobian_product, backend)
return AutoDiffInternalImpl.vector_jacobian_product(f, backend, x, u)
for implemented_backend in (:AutoZygote, :AutoEnzyme)
@eval function vector_jacobian_product(
f::F, backend::$implemented_backend, x, u) where {F}
assert_backend_loaded(:vector_jacobian_product, backend)
return AutoDiffInternalImpl.vector_jacobian_product(f, backend, x, u)
end
end

@doc doc"""
Expand All @@ -46,9 +50,10 @@ products efficiently using mixed-mode AD.
## Backends & AD Packages
| Supported Backends | Packages Needed |
| :----------------- | :--------------- |
| `AutoForwardDiff` | |
| Supported Backends | Packages Needed | Notes |
| :----------------- | :-------------- | :--------------------------------------------- |
| `AutoForwardDiff` | | |
| `AutoEnzyme` | `Enzyme.jl` | Not compatible with ChainRules based Nested AD |
!!! warning
Expand All @@ -71,8 +76,11 @@ function jacobian_vector_product(::F, backend::AbstractADType, _, __) where {F}
throw(ArgumentError("`jacobian_vector_product` is not implemented for `$(backend)`."))
end

function jacobian_vector_product(f::F, backend::AutoForwardDiff, x, u) where {F}
return AutoDiffInternalImpl.jacobian_vector_product(f, backend, x, u)
for implemented_backend in (:AutoEnzyme, :AutoForwardDiff)
@eval function jacobian_vector_product(
f::F, backend::$(implemented_backend), x, u) where {F}
return AutoDiffInternalImpl.jacobian_vector_product(f, backend, x, u)
end
end

"""
Expand Down
11 changes: 11 additions & 0 deletions src/autodiff/jac_products.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,17 @@ function vector_jacobian_product(f::F, backend::AbstractADType, x, u) where {F}
return vector_jacobian_product_impl(f, backend, x, u)
end

for fType in AD_CONVERTIBLE_FUNCTIONS
@eval function vector_jacobian_product(f::$(fType), backend::AbstractADType, x, u)
f̂, y = rewrite_autodiff_call(f)
return vector_jacobian_product_impl(f̂, backend, x, u, y)
end
end

function vector_jacobian_product_impl(f::F, backend::AbstractADType, x, u, y) where {F}
return vector_jacobian_product_impl(Base.Fix2(f, y), backend, x, u)
end

# JVP Implementation
function jacobian_vector_product(f::F, backend::AbstractADType, x, u) where {F}
return jacobian_vector_product_impl(f, backend, x, u)
Expand Down

0 comments on commit ff8e926

Please sign in to comment.