Skip to content

Commit

Permalink
Use Enzyme.gradient whenever possible
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Oct 9, 2024
1 parent 8a16dd8 commit 5944472
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 32 deletions.
2 changes: 1 addition & 1 deletion DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DifferentiationInterface"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.6.8"
version = "0.6.9"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,14 +219,8 @@ function DI.gradient(
contexts::Vararg{Context,C},
) where {F,C}
f_and_df = get_f_and_df(f, backend)
grad = make_zero(x)
autodiff(
reverse_noprimal(backend),
f_and_df,
Active,
Duplicated(x, grad),
map(translate, contexts)...,
)
ders = gradient(reverse_noprimal(backend), f_and_df, x, map(translate, contexts)...)
grad = first(ders)
return grad
end

Expand All @@ -237,14 +231,10 @@ function DI.value_and_gradient(
contexts::Vararg{Context,C},
) where {F,C}
f_and_df = get_f_and_df(f, backend)
grad = make_zero(x)
_, y = autodiff(
reverse_withprimal(backend),
f_and_df,
Active,
Duplicated(x, grad),
map(translate, contexts)...,
ders, y = gradient(
reverse_withprimal(backend), f_and_df, x, map(translate, contexts)...
)
grad = first(ders)
return y, grad
end

Expand Down Expand Up @@ -272,13 +262,8 @@ function DI.gradient(
contexts::Vararg{Context,C},
) where {F,C}
f_and_df = get_f_and_df(f, backend)
grad = make_zero(x)
autodiff(
reverse_noprimal(backend),
f_and_df,
Duplicated(x, grad),
map(translate, contexts)...,
)
ders = gradient(reverse_noprimal(backend), f_and_df, x, map(translate, contexts)...)
grad = first(ders)
return grad
end

Expand All @@ -300,7 +285,7 @@ function DI.gradient!(
Duplicated(x, grad_righttype),
map(translate, contexts)...,
)
grad isa typeof(x) || copyto!(grad, grad_righttype)
grad === grad_righttype || copyto!(grad, grad_righttype)
return grad
end

Expand All @@ -312,14 +297,10 @@ function DI.value_and_gradient(
contexts::Vararg{Context,C},
) where {F,C}
f_and_df = get_f_and_df(f, backend)
grad = make_zero(x)
_, y = autodiff(
reverse_withprimal(backend),
f_and_df,
Active,
Duplicated(x, grad),
map(translate, contexts)...,
ders, y = gradient(
reverse_withprimal(backend), f_and_df, x, map(translate, contexts)...
)
grad = first(ders)
return y, grad
end

Expand All @@ -341,7 +322,7 @@ function DI.value_and_gradient!(
Duplicated(x, grad_righttype),
map(translate, contexts)...,
)
grad isa typeof(x) || copyto!(grad, grad_righttype)
grad === grad_righttype || copyto!(grad, grad_righttype)
return y, grad
end

Expand Down

0 comments on commit 5944472

Please sign in to comment.