From 8a16dd871d09a2ca88de381a2761b790060c747b Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 9 Oct 2024 17:35:30 +0200 Subject: [PATCH 1/2] Adaptive Enzyme batch size --- .../ext/DifferentiationInterfaceEnzymeExt/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl index 2ebfa52e8..82ec0e245 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl @@ -1,5 +1,5 @@ # until https://github.com/EnzymeAD/Enzyme.jl/pull/1545 is merged -DI.pick_batchsize(::AutoEnzyme, dimension::Integer) = Val(16) +DI.pick_batchsize(::AutoEnzyme, dimension::Integer) = Val(min(dimension, 16)) ## Annotations From 59444720319191a3dd4b1ec59a3e6ea3ee21fcec Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 9 Oct 2024 17:58:29 +0200 Subject: [PATCH 2/2] Use Enzyme.gradient whenever possible --- DifferentiationInterface/Project.toml | 2 +- .../reverse_onearg.jl | 43 ++++++------------- 2 files changed, 13 insertions(+), 32 deletions(-) diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 33aad1405..ce7502c68 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.6.8" +version = "0.6.9" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl index 67152f97f..452822467 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl @@ -219,14 +219,8 @@ function DI.gradient( contexts::Vararg{Context,C}, ) where {F,C} f_and_df = get_f_and_df(f, backend) - grad = make_zero(x) - autodiff( - reverse_noprimal(backend), - f_and_df, - Active, - Duplicated(x, grad), - map(translate, contexts)..., - ) + ders = gradient(reverse_noprimal(backend), f_and_df, x, map(translate, contexts)...) + grad = first(ders) return grad end @@ -237,14 +231,10 @@ function DI.value_and_gradient( contexts::Vararg{Context,C}, ) where {F,C} f_and_df = get_f_and_df(f, backend) - grad = make_zero(x) - _, y = autodiff( - reverse_withprimal(backend), - f_and_df, - Active, - Duplicated(x, grad), - map(translate, contexts)..., + ders, y = gradient( + reverse_withprimal(backend), f_and_df, x, map(translate, contexts)... ) + grad = first(ders) return y, grad end @@ -272,13 +262,8 @@ function DI.gradient( contexts::Vararg{Context,C}, ) where {F,C} f_and_df = get_f_and_df(f, backend) - grad = make_zero(x) - autodiff( - reverse_noprimal(backend), - f_and_df, - Duplicated(x, grad), - map(translate, contexts)..., - ) + ders = gradient(reverse_noprimal(backend), f_and_df, x, map(translate, contexts)...) + grad = first(ders) return grad end @@ -300,7 +285,7 @@ function DI.gradient!( Duplicated(x, grad_righttype), map(translate, contexts)..., ) - grad isa typeof(x) || copyto!(grad, grad_righttype) + grad === grad_righttype || copyto!(grad, grad_righttype) return grad end @@ -312,14 +297,10 @@ function DI.value_and_gradient( contexts::Vararg{Context,C}, ) where {F,C} f_and_df = get_f_and_df(f, backend) - grad = make_zero(x) - _, y = autodiff( - reverse_withprimal(backend), - f_and_df, - Active, - Duplicated(x, grad), - map(translate, contexts)..., + ders, y = gradient( + reverse_withprimal(backend), f_and_df, x, map(translate, contexts)... ) + grad = first(ders) return y, grad end @@ -341,7 +322,7 @@ function DI.value_and_gradient!( Duplicated(x, grad_righttype), map(translate, contexts)..., ) - grad isa typeof(x) || copyto!(grad, grad_righttype) + grad === grad_righttype || copyto!(grad, grad_righttype) return y, grad end