Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update autodiff::dual to give the same results as amrex::Real #1612

Merged
merged 4 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions unit_test/burn_cell/ci-benchmarks/chamulak_VODE_unit_test.out
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Initializing AMReX (24.06-22-g731014ff3eed)...
AMReX (24.06-22-g731014ff3eed) initialized
Initializing AMReX (24.07-16-gdcb9cc0383dc)...
AMReX (24.07-16-gdcb9cc0383dc) initialized
starting the single zone burn...
Maximum Time (s): 0.01585
State Density (g/cm^3): 1000000000
Expand All @@ -13,21 +13,21 @@ RHS at t = 0
ash 0.01230280576
------------------------------------
successful? 1
- Hnuc = 5.277400893e+17
- added e = 8.364680415e+15
- final T = 1433712612
- Hnuc = 5.277406331e+17
- added e = 8.364689034e+15
- final T = 1433713030
------------------------------------
e initial = 1.253426044e+18
e final = 1.261790725e+18
e final = 1.261790733e+18
------------------------------------
new mass fractions:
C12 0.9657895158
C12 0.9657894806
O16 1e-30
ash 0.03421048417
ash 0.03421051942
------------------------------------
species creation rates:
omegadot(C12): -2.158390168
omegadot(O16): 9.946124747e-44
omegadot(ash): 2.158390168
omegadot(C12): -2.158392392
omegadot(O16): 8.840999775e-44
omegadot(ash): 2.158392392
number of steps taken: 381
AMReX (24.06-22-g731014ff3eed) finalized
AMReX (24.07-16-gdcb9cc0383dc) finalized
68 changes: 58 additions & 10 deletions util/autodiff/autodiff/forward/dual/dual.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,7 @@ AUTODIFF_DEVICE_FUNC constexpr auto negative(U&& expr)
template<typename U>
AUTODIFF_DEVICE_FUNC constexpr auto inverse(U&& expr)
{
static_assert(isExpr<U>);
static_assert(isExpr<U> || isArithmetic<U>);
if constexpr (isInvExpr<U>)
return inner(expr);
else return InvExpr<PreventExprRef<U>>{ expr };
Expand Down Expand Up @@ -780,9 +780,11 @@ AUTODIFF_DEVICE_FUNC constexpr auto operator+(L&& l, R&& r)
// ADDITION EXPRESSION CASE: (-x) + (-y) => -(x + y)
if constexpr (isNegExpr<L> && isNegExpr<R>)
return -( inner(l) + inner(r) );
#if !defined(AUTODIFF_STRICT_ASSOCIATIVITY)
// ADDITION EXPRESSION CASE: expr + number => number + expr (number always on the left)
else if constexpr (isExpr<L> && isArithmetic<R>)
return std::forward<R>(r) + std::forward<L>(l);
#endif
// DEFAULT ADDITION EXPRESSION
else return AddExpr<PreventExprRef<L>, PreventExprRef<R>>{ l, r };
}
Expand All @@ -799,18 +801,22 @@ AUTODIFF_DEVICE_FUNC constexpr auto operator*(L&& l, R&& r)
// MULTIPLICATION EXPRESSION CASE: (-expr) * (-expr) => expr * expr
if constexpr (isNegExpr<L> && isNegExpr<R>)
return inner(l) * inner(r);
#if !defined(AUTODIFF_STRICT_ASSOCIATIVITY)
// // MULTIPLICATION EXPRESSION CASE: (1 / expr) * (1 / expr) => 1 / (expr * expr)
else if constexpr (isInvExpr<L> && isInvExpr<R>)
return inverse(inner(l) * inner(r));
// // MULTIPLICATION EXPRESSION CASE: expr * number => number * expr
else if constexpr (isExpr<L> && isArithmetic<R>)
return std::forward<R>(r) * std::forward<L>(l);
#endif
// // MULTIPLICATION EXPRESSION CASE: number * (-expr) => (-number) * expr
else if constexpr (isArithmetic<L> && isNegExpr<R>)
return (-l) * inner(r);
#if !defined(AUTODIFF_STRICT_ASSOCIATIVITY)
// // MULTIPLICATION EXPRESSION CASE: number * (number * expr) => (number * number) * expr
else if constexpr (isArithmetic<L> && isNumberDualMulExpr<R>)
return (l * left(r)) * right(r);
#endif
// MULTIPLICATION EXPRESSION CASE: number * dual => NumberDualMulExpr
else if constexpr (isArithmetic<L> && isDual<R>)
return NumberDualMulExpr<PreventExprRef<L>, PreventExprRef<R>>{ l, r };
Expand Down Expand Up @@ -845,9 +851,7 @@ AUTODIFF_DEVICE_FUNC constexpr auto operator-(L&& l, R&& r)
template<typename L, typename R, Requires<isOperable<L, R>> = true>
AUTODIFF_DEVICE_FUNC constexpr auto operator/(L&& l, R&& r)
{
if constexpr (isArithmetic<R>)
return std::forward<L>(l) * (One<L>() / std::forward<R>(r));
else return std::forward<L>(l) * inverse(std::forward<R>(r));
return std::forward<L>(l) * inverse(std::forward<R>(r));
}

//=====================================================================================================================
Expand Down Expand Up @@ -1000,13 +1004,27 @@ AUTODIFF_DEVICE_FUNC constexpr void assign(Dual<T, G>& self, U&& other)
}
// ASSIGN AN ADDITION EXPRESSION: self = expr + expr
else if constexpr (isAddExpr<U>) {
#if defined(AUTODIFF_STRICT_ASSOCIATIVITY)
assign(self, other.l);
assignAdd(self, other.r);
#else
// this reordering saves a few FLOPs when other.l is arithmetic,
// but otherwise breaks the left-to-right associativity of a+b+c+...
assign(self, other.r);
assignAdd(self, other.l);
#endif
}
// ASSIGN A MULTIPLICATION EXPRESSION: self = expr * expr
else if constexpr (isMulExpr<U>) {
#if defined(AUTODIFF_STRICT_ASSOCIATIVITY)
assign(self, other.l);
assignMul(self, other.r);
#else
// this reordering saves a few FLOPs when other.l is arithmetic,
// but otherwise breaks the left-to-right associativity of a*b*c*...
assign(self, other.r);
assignMul(self, other.l);
#endif
}
// ASSIGN A POWER EXPRESSION: self = pow(expr)
else if constexpr (isPowExpr<U>) {
Expand Down Expand Up @@ -1042,13 +1060,23 @@ AUTODIFF_DEVICE_FUNC constexpr void assign(Dual<T, G>& self, U&& other, Dual<T,
}
// ASSIGN AN ADDITION EXPRESSION: self = expr + expr
else if constexpr (isAddExpr<U>) {
#if defined(AUTODIFF_STRICT_ASSOCIATIVITY)
assign(self, other.l, tmp);
assignAdd(self, other.r, tmp);
#else
assign(self, other.r, tmp);
assignAdd(self, other.l, tmp);
#endif
}
// ASSIGN A MULTIPLICATION EXPRESSION: self = expr * expr
else if constexpr (isMulExpr<U>) {
#if defined(AUTODIFF_STRICT_ASSOCIATIVITY)
assign(self, other.l, tmp);
assignMul(self, other.r, tmp);
#else
assign(self, other.r, tmp);
assignMul(self, other.l, tmp);
#endif
}
// ASSIGN A POWER EXPRESSION: self = pow(expr, expr)
else if constexpr (isPowExpr<U>) {
Expand Down Expand Up @@ -1091,11 +1119,13 @@ AUTODIFF_DEVICE_FUNC constexpr void assignAdd(Dual<T, G>& self, U&& other)
self.val += other.l * other.r.val;
self.grad += other.l * other.r.grad;
}
#if !defined(AUTODIFF_STRICT_ASSOCIATIVITY)
// ASSIGN-ADD AN ADDITION EXPRESSION: self += expr + expr
else if constexpr (isAddExpr<U>) {
assignAdd(self, other.l);
assignAdd(self, other.r);
}
#endif
// ASSIGN-ADD ALL OTHER EXPRESSIONS
else {
Dual<T, G> tmp;
Expand All @@ -1112,11 +1142,13 @@ AUTODIFF_DEVICE_FUNC constexpr void assignAdd(Dual<T, G>& self, U&& other, Dual<
if constexpr (isNegExpr<U>) {
assignSub(self, other.r, tmp);
}
#if !defined(AUTODIFF_STRICT_ASSOCIATIVITY)
// ASSIGN-ADD AN ADDITION EXPRESSION: self += expr + expr
else if constexpr (isAddExpr<U>) {
assignAdd(self, other.l, tmp);
assignAdd(self, other.r, tmp);
}
#endif
// ASSIGN-ADD ALL OTHER EXPRESSIONS
else {
assign(tmp, other);
Expand Down Expand Up @@ -1153,11 +1185,13 @@ AUTODIFF_DEVICE_FUNC constexpr void assignSub(Dual<T, G>& self, U&& other)
self.val -= other.l * other.r.val;
self.grad -= other.l * other.r.grad;
}
#if !defined(AUTODIFF_STRICT_ASSOCIATIVITY)
// ASSIGN-SUBTRACT AN ADDITION EXPRESSION: self -= expr + expr
else if constexpr (isAddExpr<U>) {
assignSub(self, other.l);
assignSub(self, other.r);
}
#endif
// ASSIGN-SUBTRACT ALL OTHER EXPRESSIONS
else {
Dual<T, G> tmp;
Expand All @@ -1174,11 +1208,13 @@ AUTODIFF_DEVICE_FUNC constexpr void assignSub(Dual<T, G>& self, U&& other, Dual<
if constexpr (isNegExpr<U>) {
assignAdd(self, other.r, tmp);
}
#if !defined(AUTODIFF_STRICT_ASSOCIATIVITY)
// ASSIGN-SUBTRACT AN ADDITION EXPRESSION: self -= expr + expr
else if constexpr (isAddExpr<U>) {
assignSub(self, other.l, tmp);
assignSub(self, other.r, tmp);
}
#endif
// ASSIGN-SUBTRACT ALL OTHER EXPRESSIONS
else {
assign(tmp, other);
Expand Down Expand Up @@ -1214,6 +1250,11 @@ AUTODIFF_DEVICE_FUNC constexpr void assignMul(Dual<T, G>& self, U&& other)
assignMul(self, other.r);
negate(self);
}
// ASSIGN-MULTIPLY AN INVERSE EXPRESSION: self *= 1/expr
else if constexpr (isInvExpr<U>) {
assignDiv(self, other.r);
}
#if !defined(AUTODIFF_STRICT_ASSOCIATIVITY)
// ASSIGN-MULTIPLY A NUMBER-DUAL MULTIPLICATION EXPRESSION: self *= number * dual
else if constexpr (isNumberDualMulExpr<U>) {
assignMul(self, other.r);
Expand All @@ -1224,6 +1265,7 @@ AUTODIFF_DEVICE_FUNC constexpr void assignMul(Dual<T, G>& self, U&& other)
assignMul(self, other.l);
assignMul(self, other.r);
}
#endif
// ASSIGN-MULTIPLY ALL OTHER EXPRESSIONS
else {
Dual<T, G> tmp;
Expand All @@ -1241,11 +1283,13 @@ AUTODIFF_DEVICE_FUNC constexpr void assignMul(Dual<T, G>& self, U&& other, Dual<
assignMul(self, other.r, tmp);
negate(self);
}
#if !defined(AUTODIFF_STRICT_ASSOCIATIVITY)
// ASSIGN-MULTIPLY A MULTIPLICATION EXPRESSION: self *= expr * expr
else if constexpr (isMulExpr<U>) {
assignMul(self, other.l, tmp);
assignMul(self, other.r, tmp);
}
#endif
// ASSIGN-MULTIPLY ALL OTHER EXPRESSIONS
else {
assign(tmp, other);
Expand All @@ -1271,10 +1315,10 @@ AUTODIFF_DEVICE_FUNC constexpr void assignDiv(Dual<T, G>& self, U&& other)
}
// ASSIGN-DIVIDE A DUAL NUMBER: self /= dual
else if constexpr (isDual<U>) {
const T aux = One<T>() / other.val; // to avoid aliasing when self === other
self.val *= aux;
const T aux = other.val; // to avoid aliasing when self === other
self.val /= aux;
self.grad -= self.val * other.grad;
self.grad *= aux;
self.grad /= aux;
}
// ASSIGN-DIVIDE A NEGATIVE EXPRESSION: self /= (-expr)
else if constexpr (isNegExpr<U>) {
Expand All @@ -1285,6 +1329,7 @@ AUTODIFF_DEVICE_FUNC constexpr void assignDiv(Dual<T, G>& self, U&& other)
else if constexpr (isInvExpr<U>) {
assignMul(self, other.r);
}
#if !defined(AUTODIFF_STRICT_ASSOCIATIVITY)
// ASSIGN-DIVIDE A NUMBER-DUAL MULTIPLICATION EXPRESSION: self /= number * dual
else if constexpr (isNumberDualMulExpr<U>) {
assignDiv(self, other.r);
Expand All @@ -1295,6 +1340,7 @@ AUTODIFF_DEVICE_FUNC constexpr void assignDiv(Dual<T, G>& self, U&& other)
assignDiv(self, other.l);
assignDiv(self, other.r);
}
#endif
// ASSIGN-DIVIDE ALL OTHER EXPRESSIONS
else {
Dual<T, G> tmp;
Expand All @@ -1316,11 +1362,13 @@ AUTODIFF_DEVICE_FUNC constexpr void assignDiv(Dual<T, G>& self, U&& other, Dual<
else if constexpr (isInvExpr<U>) {
assignMul(self, other.r, tmp);
}
#if !defined(AUTODIFF_STRICT_ASSOCIATIVITY)
// ASSIGN-DIVIDE A MULTIPLICATION EXPRESSION: self /= expr * expr
else if constexpr (isMulExpr<U>) {
assignDiv(self, other.l, tmp);
assignDiv(self, other.r, tmp);
}
#endif
// ASSIGN-DIVIDE ALL OTHER EXPRESSIONS
else {
assign(tmp, other);
Expand All @@ -1339,9 +1387,9 @@ AUTODIFF_DEVICE_FUNC constexpr void assignPow(Dual<T, G>& self, U&& other)
{
// ASSIGN-POW A NUMBER: self = pow(self, number)
if constexpr (isArithmetic<U>) {
const T aux = pow(self.val, other - 1);
self.grad *= other * aux;
self.val = aux * self.val;
const T aux = pow(self.val, other);
self.grad *= other * aux / self.val;
self.val = aux;
}
// ASSIGN-POW A DUAL NUMBER: self = pow(self, dual)
else if constexpr (isDual<U>) {
Expand Down
3 changes: 3 additions & 0 deletions util/microphysics_autodiff.H
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@

// required for AMREX_GPU_HOST_DEVICE, which is used via AUTODIFF_DEVICE_FUNC
#include <AMReX_GpuQualifiers.H>
// disable some optimizations that break standard left-to-right operator
// associativity, giving slightly different results with Dual vs. double
#define AUTODIFF_STRICT_ASSOCIATIVITY
#include <autodiff/forward/dual.hpp>
#include <autodiff/forward/utils/derivative.hpp>

Expand Down
Loading