Skip to content

Commit

Permalink
Improve Enzyme batch size and gradient (#557)
Browse files Browse the repository at this point in the history
* Adaptive Enzyme batch size

* Use Enzyme.gradient whenever possible
  • Loading branch information
gdalle authored Oct 9, 2024
1 parent 3470e14 commit f663225
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 33 deletions.
2 changes: 1 addition & 1 deletion DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DifferentiationInterface"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.6.8"
version = "0.6.9"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,14 +219,8 @@ function DI.gradient(
contexts::Vararg{Context,C},
) where {F,C}
f_and_df = get_f_and_df(f, backend)
grad = make_zero(x)
autodiff(
reverse_noprimal(backend),
f_and_df,
Active,
Duplicated(x, grad),
map(translate, contexts)...,
)
ders = gradient(reverse_noprimal(backend), f_and_df, x, map(translate, contexts)...)
grad = first(ders)
return grad
end

Expand All @@ -237,14 +231,10 @@ function DI.value_and_gradient(
contexts::Vararg{Context,C},
) where {F,C}
f_and_df = get_f_and_df(f, backend)
grad = make_zero(x)
_, y = autodiff(
reverse_withprimal(backend),
f_and_df,
Active,
Duplicated(x, grad),
map(translate, contexts)...,
ders, y = gradient(
reverse_withprimal(backend), f_and_df, x, map(translate, contexts)...
)
grad = first(ders)
return y, grad
end

Expand Down Expand Up @@ -272,13 +262,8 @@ function DI.gradient(
contexts::Vararg{Context,C},
) where {F,C}
f_and_df = get_f_and_df(f, backend)
grad = make_zero(x)
autodiff(
reverse_noprimal(backend),
f_and_df,
Duplicated(x, grad),
map(translate, contexts)...,
)
ders = gradient(reverse_noprimal(backend), f_and_df, x, map(translate, contexts)...)
grad = first(ders)
return grad
end

Expand All @@ -300,7 +285,7 @@ function DI.gradient!(
Duplicated(x, grad_righttype),
map(translate, contexts)...,
)
grad isa typeof(x) || copyto!(grad, grad_righttype)
grad === grad_righttype || copyto!(grad, grad_righttype)
return grad
end

Expand All @@ -312,14 +297,10 @@ function DI.value_and_gradient(
contexts::Vararg{Context,C},
) where {F,C}
f_and_df = get_f_and_df(f, backend)
grad = make_zero(x)
_, y = autodiff(
reverse_withprimal(backend),
f_and_df,
Active,
Duplicated(x, grad),
map(translate, contexts)...,
ders, y = gradient(
reverse_withprimal(backend), f_and_df, x, map(translate, contexts)...
)
grad = first(ders)
return y, grad
end

Expand All @@ -341,7 +322,7 @@ function DI.value_and_gradient!(
Duplicated(x, grad_righttype),
map(translate, contexts)...,
)
grad isa typeof(x) || copyto!(grad, grad_righttype)
grad === grad_righttype || copyto!(grad, grad_righttype)
return y, grad
end

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# until https://github.com/EnzymeAD/Enzyme.jl/pull/1545 is merged
DI.pick_batchsize(::AutoEnzyme, dimension::Integer) = Val(16)
DI.pick_batchsize(::AutoEnzyme, dimension::Integer) = Val(min(dimension, 16))

## Annotations

Expand Down

0 comments on commit f663225

Please sign in to comment.