Merge pull request #44 from oxinabox/patch-2
Improve ChainRules references in the readme.
mohamed82008 authored Apr 21, 2021
2 parents e241526 + 6e09bef commit 016f701
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions README.md
@@ -104,18 +104,27 @@ r.minimizer

## Custom gradient / adjoint

To specify a custom gradient or adjoint rule for the function `f` above, the following can be used:
A custom gradient rule for a function should be defined using ChainRulesCore's `rrule`.
For example, the following can be used for the function `f` defined above:

```julia
using ChainRulesCore

function ChainRulesCore.rrule(::typeof(f), x::AbstractVector)
    val = f(x)
    grad = [0.0, 1 / (2 * sqrt(x[2]))]
    val, Δ -> (nothing, Δ * grad)
    val, Δ -> (NO_FIELDS, Δ * grad)
end
```

You can check it is correct in your tests using [ChainRulesTestUtils.jl](https://github.com/JuliaDiff/ChainRulesTestUtils.jl/).
```julia
using ChainRulesTestUtils
test_rrule(f, [1.2, 3.6])
```
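
As a dependency-free sanity check alongside `test_rrule`, the hand-written gradient can also be compared against a central finite difference. The sketch below assumes `f(x) = sqrt(x[2])`, which is consistent with the gradient used in the rule above but is not shown in this excerpt; `analytic_grad` and `fd_gradient` are hypothetical helper names introduced only for illustration.

```julia
# Assumption: this definition of f matches the gradient used in the rrule above.
f(x) = sqrt(x[2])

# Hand-written gradient, identical to the expression in the rrule.
analytic_grad(x) = [0.0, 1 / (2 * sqrt(x[2]))]

# Hypothetical helper: central finite-difference approximation of the gradient.
function fd_gradient(f, x; h = 1e-6)
    g = similar(x, Float64)
    for i in eachindex(x)
        e = zeros(length(x))   # perturbation along coordinate i
        e[i] = h
        g[i] = (f(x .+ e) - f(x .- e)) / (2h)
    end
    return g
end

x = [1.2, 3.6]
isapprox(analytic_grad(x), fd_gradient(f, x); rtol = 1e-6)
```

This only checks the gradient formula at one point; `test_rrule` additionally exercises the `rrule` interface itself.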

For full details on `rrule`s, see the [ChainRules documentation](https://juliadiff.org/ChainRulesCore.jl/stable/).

## Hack to use other automatic differentiation backends

For specific functions, if you want to use `ForwardDiff` instead of `Zygote`, one way to do this is to define an `rrule` that uses `ForwardDiff` to compute the gradient or Jacobian, e.g.:
@@ -126,6 +135,6 @@ using ChainRulesCore, ForwardDiff
function ChainRulesCore.rrule(::typeof(f), x::AbstractVector)
    val = f(x)
    grad = ForwardDiff.gradient(f, x)
    val, Δ -> (nothing, Δ * grad)
    val, Δ -> (NO_FIELDS, Δ * grad)
end
```
