-
Notifications
You must be signed in to change notification settings - Fork 35
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add autodiff machinery and docs (#1593)
This adds a new header, microphysics_autodiff.H, which includes the autodiff library and sets things up for use with AMReX. It also provides the admath namespace, which is a drop-in replacement for std for functions from <cmath>, and works on both autodiff types and normal numeric types.
- Loading branch information
Showing
5 changed files
with
185 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
************************* | ||
Automatic Differentiation | ||
************************* | ||
|
||
Support for automatic differentiation is provided by the ``autodiff`` | ||
library :cite:p:`autodiff`, included under | ||
``Microphysics/util/autodiff``. We use the forward mode ``dual`` | ||
implementation, which produces the derivative of each computation along | ||
with its output value. This results in largely the same arithmetic | ||
operations as manually calculating the analytical derivative of each | ||
intermediate step, but with much less code and fewer typos. All the | ||
machinery needed for use in Microphysics is located in | ||
``Microphysics/util/microphysics_autodiff.H``. | ||
|
||
To take the derivative of some computation ``f(x)``, ``x`` | ||
must be an ``autodiff::dual``, and has to be seeded with | ||
``autodiff::seed()`` before the function is called: | ||
|
||
.. code-block:: c++ | ||
|
||
autodiff::dual x = 3.14_rt; | ||
autodiff::seed(x); | ||
autodiff::dual result = f(x); | ||
|
||
We can then use ``autodiff::val(result)`` or | ||
``static_cast<amrex::Real>(result)`` to extract the function value, | ||
and ``autodiff::derivative(result)`` to get the derivative with respect | ||
to x. Which has the advantage of working on both normal and dual | ||
numbers. | ||
|
||
Most functions can be updated to support autodiff by adding a template | ||
parameter for the numeric type (the current code calls it ``dual_t``). | ||
This should be used for any values that depend on the variables we're | ||
differentiating with respect to. Calls to functions from ``<cmath>`` as | ||
well as ``amrex::min`` and ``amrex::max`` can be replaced with ones in | ||
the ``admath`` namespace. This namespace also exports the original | ||
functions, so they work fine on normal numeric types too. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
#ifndef MICROPHYSICS_AUTODIFF_H | ||
#define MICROPHYSICS_AUTODIFF_H | ||
|
||
#include <AMReX.H> | ||
#include <AMReX_Algorithm.H> | ||
#include <AMReX_REAL.H> | ||
|
||
#include <approx_math.H> | ||
|
||
// required for AMREX_GPU_HOST_DEVICE, which is used via AUTODIFF_DEVICE_FUNC | ||
#include <AMReX_GpuQualifiers.H> | ||
#include <autodiff/forward/dual.hpp> | ||
#include <autodiff/forward/utils/derivative.hpp> | ||
|
||
// open the autodiff namespace so we can make our own changes | ||
namespace autodiff { | ||
namespace detail { | ||
|
||
// add a couple of missing math functions | ||
|
||
// natural logarithm of 1+x (std::log1p) | ||
using std::log1p; | ||
|
||
struct Log1pOp {}; | ||
|
||
template<typename R> | ||
using Log1pExpr = UnaryExpr<Log1pOp, R>; | ||
|
||
template<typename R, Requires<isExpr<R>> = true> | ||
AUTODIFF_DEVICE_FUNC constexpr auto log1p(R&& r) -> Log1pExpr<R> { return { std::forward<R>(r) }; } | ||
|
||
template<typename T, typename G> | ||
AUTODIFF_DEVICE_FUNC constexpr void apply(Dual<T, G>& self, Log1pOp) | ||
{ | ||
const T aux = One<T>() / (1.0 + self.val); | ||
self.val = log1p(self.val); | ||
self.grad *= aux; | ||
} | ||
|
||
|
||
// cube root (std::cbrt) | ||
using std::cbrt; | ||
|
||
struct CbrtOp {}; | ||
|
||
template<typename R> | ||
using CbrtExpr = UnaryExpr<CbrtOp, R>; | ||
|
||
template <typename R, Requires<isExpr<R>> = true> | ||
AUTODIFF_DEVICE_FUNC constexpr auto cbrt(R&& r) -> CbrtExpr<R> { return { std::forward<R>(r) }; } | ||
|
||
template<typename T, typename G> | ||
AUTODIFF_DEVICE_FUNC constexpr void apply(Dual<T, G>& self, CbrtOp) | ||
{ | ||
self.val = cbrt(self.val); | ||
self.grad *= 1.0 / (3.0 * self.val * self.val); | ||
} | ||
|
||
// custom functions from approx_math.H | ||
|
||
// fast_atan | ||
struct FastAtanOp {}; | ||
|
||
template<typename R> | ||
using FastAtanExpr = UnaryExpr<FastAtanOp, R>; | ||
|
||
template<typename R, Requires<isExpr<R>> = true> | ||
AUTODIFF_DEVICE_FUNC constexpr auto fast_atan(R&& r) -> FastAtanExpr<R> { return { std::forward<R>(r) }; } | ||
|
||
template<typename T, typename G> | ||
AUTODIFF_DEVICE_FUNC constexpr void apply(Dual<T, G>& self, FastAtanOp) | ||
{ | ||
const T aux = One<T>() / (1.0 + self.val * self.val); | ||
self.val = ::fast_atan(self.val); | ||
self.grad *= aux; | ||
} | ||
|
||
// fast_exp | ||
struct FastExpOp {}; | ||
|
||
template <typename R> | ||
using FastExpExpr = UnaryExpr<FastExpOp, R>; | ||
|
||
template<typename R, Requires<isExpr<R>> = true> | ||
AUTODIFF_DEVICE_FUNC constexpr auto fast_exp(R&& r) -> FastExpExpr<R> { return { std::forward<R>(r) }; } | ||
|
||
template<typename T, typename G> | ||
AUTODIFF_DEVICE_FUNC constexpr void apply(Dual<T, G>& self, FastExpOp) | ||
{ | ||
self.val = ::fast_exp(self.val); | ||
self.grad *= self.val; | ||
} | ||
|
||
} // namespace detail | ||
|
||
// Redefine dual to use amrex::Real instead of double | ||
using dual = HigherOrderDual<1, amrex::Real>; | ||
|
||
// A new namespace that has both the STL math functions and the overloads for | ||
// dual numbers, so we can write the same function name whether we're operating | ||
// on autodiff::dual or amrex::Real. | ||
namespace math_functions { | ||
|
||
using std::abs, autodiff::detail::abs; | ||
using std::acos, autodiff::detail::acos; | ||
using std::asin, autodiff::detail::asin; | ||
using std::atan, autodiff::detail::atan; | ||
using std::atan2, autodiff::detail::atan2; | ||
using std::cos, autodiff::detail::cos; | ||
using std::exp, autodiff::detail::exp; | ||
using std::log10, autodiff::detail::log10; | ||
using std::log, autodiff::detail::log; | ||
using std::pow, autodiff::detail::pow; | ||
using std::sin, autodiff::detail::sin; | ||
using std::sqrt, autodiff::detail::sqrt; | ||
using std::tan, autodiff::detail::tan; | ||
using std::cosh, autodiff::detail::cosh; | ||
using std::sinh, autodiff::detail::sinh; | ||
using std::tanh, autodiff::detail::tanh; | ||
using std::erf, autodiff::detail::erf; | ||
using std::hypot, autodiff::detail::hypot; | ||
|
||
using std::log1p, autodiff::detail::log1p; | ||
using std::cbrt, autodiff::detail::cbrt; | ||
|
||
using amrex::min, autodiff::detail::min; | ||
using amrex::max, autodiff::detail::max; | ||
|
||
using ::fast_atan, autodiff::detail::fast_atan; | ||
using ::fast_exp, autodiff::detail::fast_exp; | ||
|
||
} // namespace math_functions | ||
|
||
} // namespace autodiff | ||
|
||
namespace admath = autodiff::math_functions; | ||
|
||
#endif |