Skip to content

Commit

Permalink
Fast exp algorithm implementation. (#1586)
Browse files Browse the repository at this point in the history
A fast exp algorithm implementation following:
Ref:

1. A Fast, Compact Approximation of the Exponential Function by Schraudolph 1999 (https://doi.org/10.1162/089976699300016467).

2. On a Fast Compact Approximation of the Exponential Function by Cawley 2000 (https://doi.org/10.1162/089976600300015033).

3. https://gist.github.com/jrade/293a73f89dfef51da6522428c857802d

4. https://github.com/ekmett/approximate/blob/master/cbits/fast.c

There are two versions:

jrade_exp: Its roughly 6-times faster than std::exp. It has reasonable accuracy across all range (roughly ~2% relative error), we follow Ref (3)

ekmett_exp: It is about twice as slow as jrade_exp, but still ~ 3 times faster than std::exp, but it gives much better accuracy, ~0.1% relative error for most cases.

The main driver is fast_exp, which also included a simple taylor approximation, 1+x, when x < 0.1, currently it is defaulted to use jrade_exp but one can just switch to ekmett_exp for better accuracy.

We also use memcpy approach to avoid undefined behavior from type-punning when using union (pointed out by Eric), and this approach is demonstrated in Ref 3, which we adapt for ekmett_exp.
  • Loading branch information
zhichen3 authored Jun 25, 2024
1 parent e27c00f commit 90adb0f
Show file tree
Hide file tree
Showing 7 changed files with 389 additions and 2 deletions.
3 changes: 2 additions & 1 deletion nse_solver/make_table/burn_cell.H
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ void burn_cell_c()
// find the nse state

const bool assume_ye_is_valid = true;
amrex::Real eps = 1.e-10;
amrex::Real eps = 1.e-10_rt;

use_hybrid_solver = 1;

auto nse_state = get_actual_nse_state(state, eps, assume_ye_is_valid);
Expand Down
3 changes: 2 additions & 1 deletion unit_test/test_ase/make_table/burn_cell.H
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ void burn_cell_c()
// find the nse state

const bool assume_ye_is_valid = true;
amrex::Real eps = 1.e-10;
amrex::Real eps = 1.e-10_rt;

use_hybrid_solver = 1;

auto nse_state = get_actual_nse_state(state, eps, assume_ye_is_valid);
Expand Down
189 changes: 189 additions & 0 deletions util/approx_math/approx_math.H
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

#include <AMReX_REAL.H>
#include <microphysics_math.H>
#include <cstdint>
#include <cstring>

using namespace amrex::literals;

Expand Down Expand Up @@ -81,4 +83,191 @@ amrex::Real fast_atan(const amrex::Real x) {
return fast_atan_1(x);
}


///
/// A fast implementation of exp with single/double precision input
/// This gives reasonable accuracy across all range, ~2% error
///
/// This version uses memcpy to avoid potential undefined behavior
/// from type punning through a union in Ref (1) and (2)
///
/// Code is obtained from Ref (3):
/// Ref:
/// 1) A Fast, Compact Approximation of the Exponential Function
/// by Schraudolph 1999
/// 2) On a Fast Compact Approximation of the Exponential Function
/// by Cawley 2000
/// 3) https://gist.github.com/jrade/293a73f89dfef51da6522428c857802d
///

AMREX_GPU_HOST_DEVICE AMREX_INLINE
float jrade_exp(const float x) {
/// For single precision input.
/// a = 2^23 / ln2
///
/// b = 2^23 * (x0 - C), where x0 = 127 is the exponent bias
/// C = (ln(ln2 + 2/e) - ln2 - ln(ln2)) / ln2 = 0.04367744890362246
/// This is a constant shift term chosen to minimize maximum relative error
///
/// Let C = ln(3/(8*ln2) + 0.5)/ln2 = 0.0579848147254;
/// in order to minimize RMS relative error.
///

constexpr float a = gcem::pow(2.0F, 23) / 0.6931471805599453F;
constexpr float b = gcem::pow(2.0F, 23) * (127.0F - 0.04367744890362246F);
float y = a * x + b;

//
// Return 0 for large negative number
// Return Inf for large positive number
//

constexpr float c = gcem::pow(2.0F, 23);
constexpr float d = gcem::pow(2.0F, 23) * 255.0F;
if (y < c || y > d) {
y = (y < c) ? 0.0F : d;
}

auto n = static_cast<std::int32_t>(y);
memcpy(&y, &n, 4);
return y;
}


AMREX_GPU_HOST_DEVICE AMREX_INLINE
double jrade_exp(const double x) {
/// For double precision input.
/// a = 2^52 / ln2
///
/// b = 2^52 * (x0 - C), where x0 = 1023 is the exponent bias
/// C = (ln(ln2 + 2/e) - ln2 - ln(ln2)) / ln2 = 0.04367744890362246
/// This is a constant shift term chosen to minimize maximum relative error
///
/// Let C = ln(3/(8*ln2) + 0.5)/ln2 = 0.0579848147254;
/// in order to minimize RMS relative error.
///

constexpr double a = gcem::pow(2.0, 52) / 0.6931471805599453;
constexpr double b = gcem::pow(2.0, 52) * (1023.0 - 0.04367744890362246);
double y = a * x + b;

//
// Return 0 for large negative number
// Return Inf for large positive number
//

constexpr double c = gcem::pow(2.0, 52);
constexpr double d = gcem::pow(2.0, 52) * 2047.0;
if (y < c || y > d) {
y = (y < c) ? 0.0 : d;
}

auto n = static_cast<std::int64_t>(y);
memcpy(&y, &n, 8);
return y;
}


///
/// This is a more accurate than jrade_exp
/// but it is roughly twice as slow.
///
/// This uses the identity exp(x) = exp(x/2) / exp(-x/2),
/// so there is extra factor of 0.5 in a
///
/// This comes from:
/// https://github.com/ekmett/approximate/blob/master/cbits/fast.c
///


AMREX_GPU_HOST_DEVICE AMREX_INLINE
float ekmett_exp(const float x) {
///
/// For single precision input
///

constexpr float a = gcem::pow(2.0F, 23) * 0.5F / 0.6931471805599453F;

// For minimizing max relative error
constexpr float b = gcem::pow(2.0F, 23) * (127.0F - 0.04367744890362246F);

float u = a * x + b;
float v = b - a * x;

//
// Return 0 for large negative number
// Return Inf for large positive number
//

constexpr float c = gcem::pow(2.0F, 23);
constexpr float d = gcem::pow(2.0F, 23) * 255.0F;
if (u < c || u > d) {
u = (u < c) ? 0.0F : d;
}

auto n = static_cast<std::int32_t>(u);
auto m = static_cast<std::int32_t>(v);

memcpy(&u, &n, 4);
memcpy(&v, &m, 4);

return u / v;
}


AMREX_GPU_HOST_DEVICE AMREX_INLINE
double ekmett_exp(const double x) {
///
/// For double precision input
///

constexpr double a = gcem::pow(2.0, 52) * 0.5 / 0.6931471805599453;

// For minimizing max relative error
constexpr double b = gcem::pow(2.0, 52) * (1023.0 - 0.04367744890362246);

double u = a * x + b;
double v = b - a * x;

//
// Return 0 for large negative number
// Return Inf for large positive number
//

constexpr double c = gcem::pow(2.0, 52);
constexpr double d = gcem::pow(2.0, 52) * 2047.0;
if (u < c || u > d) {
u = (u < c) ? 0.0 : d;
}

auto n = static_cast<std::int64_t>(u);
auto m = static_cast<std::int64_t>(v);

memcpy(&u, &n, 8);
memcpy(&v, &m, 8);

return u / v;
}


AMREX_GPU_HOST_DEVICE AMREX_INLINE
amrex::Real fast_exp(const amrex::Real x) {
///
/// Implementation of fast exponential.
/// This combines Taylor series when x < 0.1
/// and the fast exponential function algorithm from various sources:
/// jrade: https://gist.github.com/jrade/293a73f89dfef51da6522428c857802d
/// ekmett: https://github.com/ekmett/approximate/blob/master/cbits/fast.c
///

// Use Taylor if number is smaller than 0.1
// Minor performance hit, but much better accuracy when x < 0.1

if (std::abs(x) < 0.1_rt) {
return 1.0_rt + x;
}

return jrade_exp(x);
// return ekmett_exp(x);
}
#endif
39 changes: 39 additions & 0 deletions util/approx_math/test_fast_exp/GNUmakefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
PRECISION = DOUBLE
PROFILE = FALSE

DEBUG = FALSE

DIM = 3

COMP = gnu

USE_MPI = FALSE
USE_OMP = FALSE

USE_REACT = TRUE

EBASE = main

# define the location of the Microphysics top directory
MICROPHYSICS_HOME ?= ../../..

# This sets the EOS directory
EOS_DIR := helmholtz

# This sets the network directory
NETWORK_DIR := aprox21

CONDUCTIVITY_DIR := stellar

INTEGRATOR_DIR = VODE

ifeq ($(USE_CUDA), TRUE)
INTEGRATOR_DIR := VODE
endif

EXTERN_SEARCH += .

Bpack := ../Make.package ./Make.package
Blocs := ../ .

include $(MICROPHYSICS_HOME)/unit_test/Make.unit_test
1 change: 1 addition & 0 deletions util/approx_math/test_fast_exp/Make.package
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CEXE_sources += main.cpp
60 changes: 60 additions & 0 deletions util/approx_math/test_fast_exp/main.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#include <test_fast_exp.H>

int main() {

// Accuracy tests

// Test values including edge cases and some typical values

test_fast_exp_accuracy(0.0_rt);
test_fast_exp_accuracy(0.001_rt);
test_fast_exp_accuracy(0.01_rt);
test_fast_exp_accuracy(0.1_rt);
test_fast_exp_accuracy(0.5_rt);
test_fast_exp_accuracy(1.0_rt);
test_fast_exp_accuracy(5.0_rt);
test_fast_exp_accuracy(10.0_rt);
test_fast_exp_accuracy(15.0_rt);
test_fast_exp_accuracy(20.0_rt);
test_fast_exp_accuracy(30.0_rt);
test_fast_exp_accuracy(50.0_rt);
test_fast_exp_accuracy(100.0_rt);
test_fast_exp_accuracy(500.0_rt);

test_fast_exp_accuracy(-0.001_rt);
test_fast_exp_accuracy(-0.01_rt);
test_fast_exp_accuracy(-0.1_rt);
test_fast_exp_accuracy(-0.5_rt);
test_fast_exp_accuracy(-1.0_rt);
test_fast_exp_accuracy(-5.0_rt);
test_fast_exp_accuracy(-10.0_rt);
test_fast_exp_accuracy(-15.0_rt);
test_fast_exp_accuracy(-20.0_rt);
test_fast_exp_accuracy(-30.0_rt);
test_fast_exp_accuracy(-50.0_rt);
test_fast_exp_accuracy(-100.0_rt);
test_fast_exp_accuracy(-500.0_rt);

std::cout << "Accuracy tests passed!" << std::endl;

// Now performance test

int iters = 5;
amrex::Real test_value = 160.0_rt;
test_fast_exp_speed(100, iters, test_value);

iters = 10;
test_fast_exp_speed(100, iters, test_value);

iters = 20;
test_fast_exp_speed(100, iters, test_value);

iters = 30;
test_fast_exp_speed(100, iters, test_value);

iters = 50;
test_fast_exp_speed(100, iters, test_value);

// iters = 70;
// test_fast_exp_speed(100, iters, test_value);
}
Loading

0 comments on commit 90adb0f

Please sign in to comment.