diff --git a/.github/workflows/AD.yml b/.github/workflows/AD.yml index 541ccdd9..070ede1a 100644 --- a/.github/workflows/AD.yml +++ b/.github/workflows/AD.yml @@ -25,6 +25,7 @@ jobs: - Tracker - ReverseDiff - Zygote + - Enzyme steps: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@v1 diff --git a/src/interface.jl b/src/interface.jl index 5487d01a..8ed1dabd 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -27,6 +27,7 @@ struct ForwardDiffAD <: ADBackend end struct ReverseDiffAD <: ADBackend end struct TrackerAD <: ADBackend end struct ZygoteAD <: ADBackend end +struct EnzymeAD <: ADBackend end const ADBACKEND = Ref(:forwarddiff) setadbackend(backend_sym::Symbol) = setadbackend(Val(backend_sym)) @@ -34,6 +35,7 @@ setadbackend(::Val{:forwarddiff}) = ADBACKEND[] = :forwarddiff setadbackend(::Val{:reversediff}) = ADBACKEND[] = :reversediff setadbackend(::Val{:tracker}) = ADBACKEND[] = :tracker setadbackend(::Val{:zygote}) = ADBACKEND[] = :zygote +setadbackend(::Val{:enzyme}) = ADBACKEND[] = :enzyme ADBackend() = ADBackend(ADBACKEND[]) ADBackend(T::Symbol) = ADBackend(Val(T)) @@ -41,6 +43,7 @@ ADBackend(::Val{:forwarddiff}) = ForwardDiffAD ADBackend(::Val{:reversediff}) = ReverseDiffAD ADBackend(::Val{:tracker}) = TrackerAD ADBackend(::Val{:zygote}) = ZygoteAD +setadbackend(::Val{:enzyme}) = EnzymeAD ADBackend(::Val) = error("The requested AD backend is not available. Make sure to load all required packages.") ###################### diff --git a/test/Project.toml b/test/Project.toml index 9840609d..fe081eed 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -3,6 +3,7 @@ ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" @@ -20,6 +21,7 @@ ChainRulesTestUtils = "0.7, 1" ChangesOfVariables = "0.1" Combinatorics = "1.0.2" DistributionsAD = "0.6.3" +Enzyme = "0.10.14" FiniteDifferences = "0.11, 0.12" ForwardDiff = "0.10.12" Functors = "0.1, 0.2, 0.3" diff --git a/test/ad/utils.jl b/test/ad/utils.jl index 6bf8365f..0a9b7734 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -39,5 +39,20 @@ function test_ad(f, x, broken = (); rtol = 1e-6, atol = 1e-6) end end - return + if AD == "All" || AD == "Enzyme" + # `broken` keyword to `@test` requires Julia >= 1.7 + if :EnzymeReverse in broken + @test collect(Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff rtol=rtol atol=atol + @test_broken Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff rtol=rtol atol=atol + elseif :EnzymeForward in broken + @test_broken collect(Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff rtol=rtol atol=atol + @test Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff rtol=rtol atol=atol + elseif :Enzyme in broken + @test_broken collect(Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff rtol=rtol atol=atol + @test_broken Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff rtol=rtol atol=atol + else + @test collect(Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff rtol=rtol atol=atol + @test Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff rtol=rtol atol=atol + end + end end diff --git a/test/runtests.jl b/test/runtests.jl index 7fcd6dfc..d99f77b2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using Bijectors using ChainRulesTestUtils using Combinatorics using DistributionsAD +using Enzyme using FiniteDifferences using ForwardDiff using Functors