Skip to content

Commit

Permalink
Update autodiff::dual to give the same results as amrex::Real (#1612
Browse files Browse the repository at this point in the history
)

In my testing with converting the aprox* rates to use autodiff, I found a lot of instances where code templated with autodiff::dual would give slightly different values than the same code templated with amrex::Real. It turns out to be due to some optimizations autodiff makes to speed up evaluation, which this PR disables to avoid any weird inconsistencies in the future.
  • Loading branch information
yut23 authored Jul 17, 2024
1 parent 80dffdc commit a5d2e29
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 22 deletions.
24 changes: 12 additions & 12 deletions unit_test/burn_cell/ci-benchmarks/chamulak_VODE_unit_test.out
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Initializing AMReX (24.06-22-g731014ff3eed)...
AMReX (24.06-22-g731014ff3eed) initialized
Initializing AMReX (24.07-16-gdcb9cc0383dc)...
AMReX (24.07-16-gdcb9cc0383dc) initialized
starting the single zone burn...
Maximum Time (s): 0.01585
State Density (g/cm^3): 1000000000
Expand All @@ -13,21 +13,21 @@ RHS at t = 0
ash 0.01230280576
------------------------------------
successful? 1
- Hnuc = 5.277400893e+17
- added e = 8.364680415e+15
- final T = 1433712612
- Hnuc = 5.277406331e+17
- added e = 8.364689034e+15
- final T = 1433713030
------------------------------------
e initial = 1.253426044e+18
e final = 1.261790725e+18
e final = 1.261790733e+18
------------------------------------
new mass fractions:
C12 0.9657895158
C12 0.9657894806
O16 1e-30
ash 0.03421048417
ash 0.03421051942
------------------------------------
species creation rates:
omegadot(C12): -2.158390168
omegadot(O16): 9.946124747e-44
omegadot(ash): 2.158390168
omegadot(C12): -2.158392392
omegadot(O16): 8.840999775e-44
omegadot(ash): 2.158392392
number of steps taken: 381
AMReX (24.06-22-g731014ff3eed) finalized
AMReX (24.07-16-gdcb9cc0383dc) finalized
68 changes: 58 additions & 10 deletions util/autodiff/autodiff/forward/dual/dual.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,7 @@ AUTODIFF_DEVICE_FUNC constexpr auto negative(U&& expr)
template<typename U>
AUTODIFF_DEVICE_FUNC constexpr auto inverse(U&& expr)
{
static_assert(isExpr<U>);
static_assert(isExpr<U> || isArithmetic<U>);
if constexpr (isInvExpr<U>)
return inner(expr);
else return InvExpr<PreventExprRef<U>>{ expr };
Expand Down Expand Up @@ -780,9 +780,11 @@ AUTODIFF_DEVICE_FUNC constexpr auto operator+(L&& l, R&& r)
// ADDITION EXPRESSION CASE: (-x) + (-y) => -(x + y)
if constexpr (isNegExpr<L> && isNegExpr<R>)
return -( inner(l) + inner(r) );
#if !defined(AUTODIFF_STRICT_ASSOCIATIVITY)
// ADDITION EXPRESSION CASE: expr + number => number + expr (number always on the left)
else if constexpr (isExpr<L> && isArithmetic<R>)
return std::forward<R>(r) + std::forward<L>(l);
#endif
// DEFAULT ADDITION EXPRESSION
else return AddExpr<PreventExprRef<L>, PreventExprRef<R>>{ l, r };
}
Expand All @@ -799,18 +801,22 @@ AUTODIFF_DEVICE_FUNC constexpr auto operator*(L&& l, R&& r)
// MULTIPLICATION EXPRESSION CASE: (-expr) * (-expr) => expr * expr
if constexpr (isNegExpr<L> && isNegExpr<R>)
return inner(l) * inner(r);
#if !defined(AUTODIFF_STRICT_ASSOCIATIVITY)
// // MULTIPLICATION EXPRESSION CASE: (1 / expr) * (1 / expr) => 1 / (expr * expr)
else if constexpr (isInvExpr<L> && isInvExpr<R>)
return inverse(inner(l) * inner(r));
// // MULTIPLICATION EXPRESSION CASE: expr * number => number * expr
else if constexpr (isExpr<L> && isArithmetic<R>)
return std::forward<R>(r) * std::forward<L>(l);
#endif
// // MULTIPLICATION EXPRESSION CASE: number * (-expr) => (-number) * expr
else if constexpr (isArithmetic<L> && isNegExpr<R>)
return (-l) * inner(r);
#if !defined(AUTODIFF_STRICT_ASSOCIATIVITY)
// // MULTIPLICATION EXPRESSION CASE: number * (number * expr) => (number * number) * expr
else if constexpr (isArithmetic<L> && isNumberDualMulExpr<R>)
return (l * left(r)) * right(r);
#endif
// MULTIPLICATION EXPRESSION CASE: number * dual => NumberDualMulExpr
else if constexpr (isArithmetic<L> && isDual<R>)
return NumberDualMulExpr<PreventExprRef<L>, PreventExprRef<R>>{ l, r };
Expand Down Expand Up @@ -845,9 +851,7 @@ AUTODIFF_DEVICE_FUNC constexpr auto operator-(L&& l, R&& r)
template<typename L, typename R, Requires<isOperable<L, R>> = true>
AUTODIFF_DEVICE_FUNC constexpr auto operator/(L&& l, R&& r)
{
if constexpr (isArithmetic<R>)
return std::forward<L>(l) * (One<L>() / std::forward<R>(r));
else return std::forward<L>(l) * inverse(std::forward<R>(r));
return std::forward<L>(l) * inverse(std::forward<R>(r));
}

//=====================================================================================================================
Expand Down Expand Up @@ -1000,13 +1004,27 @@ AUTODIFF_DEVICE_FUNC constexpr void assign(Dual<T, G>& self, U&& other)
}
// ASSIGN AN ADDITION EXPRESSION: self = expr + expr
else if constexpr (isAddExpr<U>) {
#if defined(AUTODIFF_STRICT_ASSOCIATIVITY)
assign(self, other.l);
assignAdd(self, other.r);
#else
// this reordering saves a few FLOPs when other.l is arithmetic,
// but otherwise breaks the left-to-right associativity of a+b+c+...
assign(self, other.r);
assignAdd(self, other.l);
#endif
}
// ASSIGN A MULTIPLICATION EXPRESSION: self = expr * expr
else if constexpr (isMulExpr<U>) {
#if defined(AUTODIFF_STRICT_ASSOCIATIVITY)
assign(self, other.l);
assignMul(self, other.r);
#else
// this reordering saves a few FLOPs when other.l is arithmetic,
// but otherwise breaks the left-to-right associativity of a*b*c*...
assign(self, other.r);
assignMul(self, other.l);
#endif
}
// ASSIGN A POWER EXPRESSION: self = pow(expr)
else if constexpr (isPowExpr<U>) {
Expand Down Expand Up @@ -1042,13 +1060,23 @@ AUTODIFF_DEVICE_FUNC constexpr void assign(Dual<T, G>& self, U&& other, Dual<T,
}
// ASSIGN AN ADDITION EXPRESSION: self = expr + expr
else if constexpr (isAddExpr<U>) {
#if defined(AUTODIFF_STRICT_ASSOCIATIVITY)
assign(self, other.l, tmp);
assignAdd(self, other.r, tmp);
#else
assign(self, other.r, tmp);
assignAdd(self, other.l, tmp);
#endif
}
// ASSIGN A MULTIPLICATION EXPRESSION: self = expr * expr
else if constexpr (isMulExpr<U>) {
#if defined(AUTODIFF_STRICT_ASSOCIATIVITY)
assign(self, other.l, tmp);
assignMul(self, other.r, tmp);
#else
assign(self, other.r, tmp);
assignMul(self, other.l, tmp);
#endif
}
// ASSIGN A POWER EXPRESSION: self = pow(expr, expr)
else if constexpr (isPowExpr<U>) {
Expand Down Expand Up @@ -1091,11 +1119,13 @@ AUTODIFF_DEVICE_FUNC constexpr void assignAdd(Dual<T, G>& self, U&& other)
self.val += other.l * other.r.val;
self.grad += other.l * other.r.grad;
}
#if !defined(AUTODIFF_STRICT_ASSOCIATIVITY)
// ASSIGN-ADD AN ADDITION EXPRESSION: self += expr + expr
else if constexpr (isAddExpr<U>) {
assignAdd(self, other.l);
assignAdd(self, other.r);
}
#endif
// ASSIGN-ADD ALL OTHER EXPRESSIONS
else {
Dual<T, G> tmp;
Expand All @@ -1112,11 +1142,13 @@ AUTODIFF_DEVICE_FUNC constexpr void assignAdd(Dual<T, G>& self, U&& other, Dual<
if constexpr (isNegExpr<U>) {
assignSub(self, other.r, tmp);
}
#if !defined(AUTODIFF_STRICT_ASSOCIATIVITY)
// ASSIGN-ADD AN ADDITION EXPRESSION: self += expr + expr
else if constexpr (isAddExpr<U>) {
assignAdd(self, other.l, tmp);
assignAdd(self, other.r, tmp);
}
#endif
// ASSIGN-ADD ALL OTHER EXPRESSIONS
else {
assign(tmp, other);
Expand Down Expand Up @@ -1153,11 +1185,13 @@ AUTODIFF_DEVICE_FUNC constexpr void assignSub(Dual<T, G>& self, U&& other)
self.val -= other.l * other.r.val;
self.grad -= other.l * other.r.grad;
}
#if !defined(AUTODIFF_STRICT_ASSOCIATIVITY)
// ASSIGN-SUBTRACT AN ADDITION EXPRESSION: self -= expr + expr
else if constexpr (isAddExpr<U>) {
assignSub(self, other.l);
assignSub(self, other.r);
}
#endif
// ASSIGN-SUBTRACT ALL OTHER EXPRESSIONS
else {
Dual<T, G> tmp;
Expand All @@ -1174,11 +1208,13 @@ AUTODIFF_DEVICE_FUNC constexpr void assignSub(Dual<T, G>& self, U&& other, Dual<
if constexpr (isNegExpr<U>) {
assignAdd(self, other.r, tmp);
}
#if !defined(AUTODIFF_STRICT_ASSOCIATIVITY)
// ASSIGN-SUBTRACT AN ADDITION EXPRESSION: self -= expr + expr
else if constexpr (isAddExpr<U>) {
assignSub(self, other.l, tmp);
assignSub(self, other.r, tmp);
}
#endif
// ASSIGN-SUBTRACT ALL OTHER EXPRESSIONS
else {
assign(tmp, other);
Expand Down Expand Up @@ -1214,6 +1250,11 @@ AUTODIFF_DEVICE_FUNC constexpr void assignMul(Dual<T, G>& self, U&& other)
assignMul(self, other.r);
negate(self);
}
// ASSIGN-MULTIPLY AN INVERSE EXPRESSION: self *= 1/expr
else if constexpr (isInvExpr<U>) {
assignDiv(self, other.r);
}
#if !defined(AUTODIFF_STRICT_ASSOCIATIVITY)
// ASSIGN-MULTIPLY A NUMBER-DUAL MULTIPLICATION EXPRESSION: self *= number * dual
else if constexpr (isNumberDualMulExpr<U>) {
assignMul(self, other.r);
Expand All @@ -1224,6 +1265,7 @@ AUTODIFF_DEVICE_FUNC constexpr void assignMul(Dual<T, G>& self, U&& other)
assignMul(self, other.l);
assignMul(self, other.r);
}
#endif
// ASSIGN-MULTIPLY ALL OTHER EXPRESSIONS
else {
Dual<T, G> tmp;
Expand All @@ -1241,11 +1283,13 @@ AUTODIFF_DEVICE_FUNC constexpr void assignMul(Dual<T, G>& self, U&& other, Dual<
assignMul(self, other.r, tmp);
negate(self);
}
#if !defined(AUTODIFF_STRICT_ASSOCIATIVITY)
// ASSIGN-MULTIPLY A MULTIPLICATION EXPRESSION: self *= expr * expr
else if constexpr (isMulExpr<U>) {
assignMul(self, other.l, tmp);
assignMul(self, other.r, tmp);
}
#endif
// ASSIGN-MULTIPLY ALL OTHER EXPRESSIONS
else {
assign(tmp, other);
Expand All @@ -1271,10 +1315,10 @@ AUTODIFF_DEVICE_FUNC constexpr void assignDiv(Dual<T, G>& self, U&& other)
}
// ASSIGN-DIVIDE A DUAL NUMBER: self /= dual
else if constexpr (isDual<U>) {
const T aux = One<T>() / other.val; // to avoid aliasing when self === other
self.val *= aux;
const T aux = other.val; // to avoid aliasing when self === other
self.val /= aux;
self.grad -= self.val * other.grad;
self.grad *= aux;
self.grad /= aux;
}
// ASSIGN-DIVIDE A NEGATIVE EXPRESSION: self /= (-expr)
else if constexpr (isNegExpr<U>) {
Expand All @@ -1285,6 +1329,7 @@ AUTODIFF_DEVICE_FUNC constexpr void assignDiv(Dual<T, G>& self, U&& other)
else if constexpr (isInvExpr<U>) {
assignMul(self, other.r);
}
#if !defined(AUTODIFF_STRICT_ASSOCIATIVITY)
// ASSIGN-DIVIDE A NUMBER-DUAL MULTIPLICATION EXPRESSION: self /= number * dual
else if constexpr (isNumberDualMulExpr<U>) {
assignDiv(self, other.r);
Expand All @@ -1295,6 +1340,7 @@ AUTODIFF_DEVICE_FUNC constexpr void assignDiv(Dual<T, G>& self, U&& other)
assignDiv(self, other.l);
assignDiv(self, other.r);
}
#endif
// ASSIGN-DIVIDE ALL OTHER EXPRESSIONS
else {
Dual<T, G> tmp;
Expand All @@ -1316,11 +1362,13 @@ AUTODIFF_DEVICE_FUNC constexpr void assignDiv(Dual<T, G>& self, U&& other, Dual<
else if constexpr (isInvExpr<U>) {
assignMul(self, other.r, tmp);
}
#if !defined(AUTODIFF_STRICT_ASSOCIATIVITY)
// ASSIGN-DIVIDE A MULTIPLICATION EXPRESSION: self /= expr * expr
else if constexpr (isMulExpr<U>) {
assignDiv(self, other.l, tmp);
assignDiv(self, other.r, tmp);
}
#endif
// ASSIGN-DIVIDE ALL OTHER EXPRESSIONS
else {
assign(tmp, other);
Expand All @@ -1339,9 +1387,9 @@ AUTODIFF_DEVICE_FUNC constexpr void assignPow(Dual<T, G>& self, U&& other)
{
// ASSIGN-POW A NUMBER: self = pow(self, number)
if constexpr (isArithmetic<U>) {
const T aux = pow(self.val, other - 1);
self.grad *= other * aux;
self.val = aux * self.val;
const T aux = pow(self.val, other);
self.grad *= other * aux / self.val;
self.val = aux;
}
// ASSIGN-POW A DUAL NUMBER: self = pow(self, dual)
else if constexpr (isDual<U>) {
Expand Down
3 changes: 3 additions & 0 deletions util/microphysics_autodiff.H
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@

// required for AMREX_GPU_HOST_DEVICE, which is used via AUTODIFF_DEVICE_FUNC
#include <AMReX_GpuQualifiers.H>
// disable some optimizations that break standard left-to-right operator
// associativity, giving slightly different results with Dual vs. double
#define AUTODIFF_STRICT_ASSOCIATIVITY
#include <autodiff/forward/dual.hpp>
#include <autodiff/forward/utils/derivative.hpp>

Expand Down

0 comments on commit a5d2e29

Please sign in to comment.