From 90adb0f90b28c39c5171ac2446c033afeb2f6f27 Mon Sep 17 00:00:00 2001 From: Zhi Chen <62574124+zhichen3@users.noreply.github.com> Date: Tue, 25 Jun 2024 13:17:52 -0400 Subject: [PATCH] Fast exp algorithm implementation. (#1586) A fast exp algorithm implementation following: Ref: 1. A Fast, Compact Approximation of the Exponential Function by Schraudolph 1999 (https://doi.org/10.1162/089976699300016467). 2. On a Fast Compact Approximation of the Exponential Function by Cawley 2000 (https://doi.org/10.1162/089976600300015033). 3. https://gist.github.com/jrade/293a73f89dfef51da6522428c857802d 4. https://github.com/ekmett/approximate/blob/master/cbits/fast.c There are two versions: jrade_exp: It is roughly 6 times faster than std::exp and has reasonable accuracy across the whole range (~2% relative error); we follow Ref (3). ekmett_exp: It is about twice as slow as jrade_exp but still ~3 times faster than std::exp, and it gives much better accuracy, ~0.1% relative error for most cases. The main driver is fast_exp, which also includes a simple Taylor approximation, 1+x, when |x| < 0.1. It currently defaults to jrade_exp, but one can switch to ekmett_exp for better accuracy. We also use the memcpy approach to avoid undefined behavior from type-punning through a union (pointed out by Eric); this approach is demonstrated in Ref 3, and we adapt it for ekmett_exp.
--- nse_solver/make_table/burn_cell.H | 3 +- unit_test/test_ase/make_table/burn_cell.H | 3 +- util/approx_math/approx_math.H | 189 ++++++++++++++++++ util/approx_math/test_fast_exp/GNUmakefile | 39 ++++ util/approx_math/test_fast_exp/Make.package | 1 + util/approx_math/test_fast_exp/main.cpp | 60 ++++++ .../approx_math/test_fast_exp/test_fast_exp.H | 96 +++++++++ 7 files changed, 389 insertions(+), 2 deletions(-) create mode 100644 util/approx_math/test_fast_exp/GNUmakefile create mode 100644 util/approx_math/test_fast_exp/Make.package create mode 100644 util/approx_math/test_fast_exp/main.cpp create mode 100644 util/approx_math/test_fast_exp/test_fast_exp.H diff --git a/nse_solver/make_table/burn_cell.H b/nse_solver/make_table/burn_cell.H index 472f81a193..4ed3d89896 100644 --- a/nse_solver/make_table/burn_cell.H +++ b/nse_solver/make_table/burn_cell.H @@ -53,7 +53,8 @@ void burn_cell_c() // find the nse state const bool assume_ye_is_valid = true; - amrex::Real eps = 1.e-10; + amrex::Real eps = 1.e-10_rt; + use_hybrid_solver = 1; auto nse_state = get_actual_nse_state(state, eps, assume_ye_is_valid); diff --git a/unit_test/test_ase/make_table/burn_cell.H b/unit_test/test_ase/make_table/burn_cell.H index a542b2be6e..168ac8cdc7 100644 --- a/unit_test/test_ase/make_table/burn_cell.H +++ b/unit_test/test_ase/make_table/burn_cell.H @@ -58,7 +58,8 @@ void burn_cell_c() // find the nse state const bool assume_ye_is_valid = true; - amrex::Real eps = 1.e-10; + amrex::Real eps = 1.e-10_rt; + use_hybrid_solver = 1; auto nse_state = get_actual_nse_state(state, eps, assume_ye_is_valid); diff --git a/util/approx_math/approx_math.H b/util/approx_math/approx_math.H index 5dea510c17..773032f198 100644 --- a/util/approx_math/approx_math.H +++ b/util/approx_math/approx_math.H @@ -3,6 +3,8 @@ #include #include +#include +#include using namespace amrex::literals; @@ -81,4 +83,191 @@ amrex::Real fast_atan(const amrex::Real x) { return fast_atan_1(x); } + +/// +/// A fast implementation 
of exp with single/double precision input +/// This gives reasonable accuracy across all range, ~2% error +/// +/// This version uses memcpy to avoid potential undefined behavior +/// from type punning through a union in Ref (1) and (2) +/// +/// Code is obtained from Ref (3): +/// Ref: +/// 1) A Fast, Compact Approximation of the Exponential Function +/// by Schraudolph 1999 +/// 2) On a Fast Compact Approximation of the Exponential Function +/// by Cawley 2000 +/// 3) https://gist.github.com/jrade/293a73f89dfef51da6522428c857802d +/// + +AMREX_GPU_HOST_DEVICE AMREX_INLINE +float jrade_exp(const float x) { + /// For single precision input. + /// a = 2^23 / ln2 + /// + /// b = 2^23 * (x0 - C), where x0 = 127 is the exponent bias + /// C = (ln(ln2 + 2/e) - ln2 - ln(ln2)) / ln2 = 0.04367744890362246 + /// This is a constant shift term chosen to minimize maximum relative error + /// + /// Let C = ln(3/(8*ln2) + 0.5)/ln2 = 0.0579848147254; + /// in order to minimize RMS relative error. + /// + + constexpr float a = gcem::pow(2.0F, 23) / 0.6931471805599453F; + constexpr float b = gcem::pow(2.0F, 23) * (127.0F - 0.04367744890362246F); + float y = a * x + b; + + // + // Return 0 for large negative number + // Return Inf for large positive number + // + + constexpr float c = gcem::pow(2.0F, 23); + constexpr float d = gcem::pow(2.0F, 23) * 255.0F; + if (y < c || y > d) { + y = (y < c) ? 0.0F : d; + } + + auto n = static_cast(y); + memcpy(&y, &n, 4); + return y; +} + + +AMREX_GPU_HOST_DEVICE AMREX_INLINE +double jrade_exp(const double x) { + /// For double precision input. + /// a = 2^52 / ln2 + /// + /// b = 2^52 * (x0 - C), where x0 = 1023 is the exponent bias + /// C = (ln(ln2 + 2/e) - ln2 - ln(ln2)) / ln2 = 0.04367744890362246 + /// This is a constant shift term chosen to minimize maximum relative error + /// + /// Let C = ln(3/(8*ln2) + 0.5)/ln2 = 0.0579848147254; + /// in order to minimize RMS relative error. 
+ /// + + constexpr double a = gcem::pow(2.0, 52) / 0.6931471805599453; + constexpr double b = gcem::pow(2.0, 52) * (1023.0 - 0.04367744890362246); + double y = a * x + b; + + // + // Return 0 for large negative number + // Return Inf for large positive number + // + + constexpr double c = gcem::pow(2.0, 52); + constexpr double d = gcem::pow(2.0, 52) * 2047.0; + if (y < c || y > d) { + y = (y < c) ? 0.0 : d; + } + + auto n = static_cast(y); + memcpy(&y, &n, 8); + return y; +} + + +/// +/// This is a more accurate than jrade_exp +/// but it is roughly twice as slow. +/// +/// This uses the identity exp(x) = exp(x/2) / exp(-x/2), +/// so there is extra factor of 0.5 in a +/// +/// This comes from: +/// https://github.com/ekmett/approximate/blob/master/cbits/fast.c +/// + + +AMREX_GPU_HOST_DEVICE AMREX_INLINE +float ekmett_exp(const float x) { + /// + /// For single precision input + /// + + constexpr float a = gcem::pow(2.0F, 23) * 0.5F / 0.6931471805599453F; + + // For minimizing max relative error + constexpr float b = gcem::pow(2.0F, 23) * (127.0F - 0.04367744890362246F); + + float u = a * x + b; + float v = b - a * x; + + // + // Return 0 for large negative number + // Return Inf for large positive number + // + + constexpr float c = gcem::pow(2.0F, 23); + constexpr float d = gcem::pow(2.0F, 23) * 255.0F; + if (u < c || u > d) { + u = (u < c) ? 
0.0F : d; + } + + auto n = static_cast(u); + auto m = static_cast(v); + + memcpy(&u, &n, 4); + memcpy(&v, &m, 4); + + return u / v; +} + + +AMREX_GPU_HOST_DEVICE AMREX_INLINE +double ekmett_exp(const double x) { + /// + /// For double precision input + /// + + constexpr double a = gcem::pow(2.0, 52) * 0.5 / 0.6931471805599453; + + // For minimizing max relative error + constexpr double b = gcem::pow(2.0, 52) * (1023.0 - 0.04367744890362246); + + double u = a * x + b; + double v = b - a * x; + + // + // Return 0 for large negative number + // Return Inf for large positive number + // + + constexpr double c = gcem::pow(2.0, 52); + constexpr double d = gcem::pow(2.0, 52) * 2047.0; + if (u < c || u > d) { + u = (u < c) ? 0.0 : d; + } + + auto n = static_cast(u); + auto m = static_cast(v); + + memcpy(&u, &n, 8); + memcpy(&v, &m, 8); + + return u / v; +} + + +AMREX_GPU_HOST_DEVICE AMREX_INLINE +amrex::Real fast_exp(const amrex::Real x) { + /// + /// Implementation of fast exponential. + /// This combines Taylor series when x < 0.1 + /// and the fast exponential function algorithm from various sources: + /// jrade: https://gist.github.com/jrade/293a73f89dfef51da6522428c857802d + /// ekmett: https://github.com/ekmett/approximate/blob/master/cbits/fast.c + /// + + // Use Taylor if number is smaller than 0.1 + // Minor performance hit, but much better accuracy when x < 0.1 + + if (std::abs(x) < 0.1_rt) { + return 1.0_rt + x; + } + + return jrade_exp(x); + // return ekmett_exp(x); +} #endif diff --git a/util/approx_math/test_fast_exp/GNUmakefile b/util/approx_math/test_fast_exp/GNUmakefile new file mode 100644 index 0000000000..9a46591a74 --- /dev/null +++ b/util/approx_math/test_fast_exp/GNUmakefile @@ -0,0 +1,39 @@ +PRECISION = DOUBLE +PROFILE = FALSE + +DEBUG = FALSE + +DIM = 3 + +COMP = gnu + +USE_MPI = FALSE +USE_OMP = FALSE + +USE_REACT = TRUE + +EBASE = main + +# define the location of the Microphysics top directory +MICROPHYSICS_HOME ?= ../../.. 
+ +# This sets the EOS directory +EOS_DIR := helmholtz + +# This sets the network directory +NETWORK_DIR := aprox21 + +CONDUCTIVITY_DIR := stellar + +INTEGRATOR_DIR = VODE + +ifeq ($(USE_CUDA), TRUE) + INTEGRATOR_DIR := VODE +endif + +EXTERN_SEARCH += . + +Bpack := ../Make.package ./Make.package +Blocs := ../ . + +include $(MICROPHYSICS_HOME)/unit_test/Make.unit_test diff --git a/util/approx_math/test_fast_exp/Make.package b/util/approx_math/test_fast_exp/Make.package new file mode 100644 index 0000000000..6b4b865e8f --- /dev/null +++ b/util/approx_math/test_fast_exp/Make.package @@ -0,0 +1 @@ +CEXE_sources += main.cpp diff --git a/util/approx_math/test_fast_exp/main.cpp b/util/approx_math/test_fast_exp/main.cpp new file mode 100644 index 0000000000..2d7139631c --- /dev/null +++ b/util/approx_math/test_fast_exp/main.cpp @@ -0,0 +1,60 @@ +#include + +int main() { + + // Accuracy tests + + // Test values including edge cases and some typical values + + test_fast_exp_accuracy(0.0_rt); + test_fast_exp_accuracy(0.001_rt); + test_fast_exp_accuracy(0.01_rt); + test_fast_exp_accuracy(0.1_rt); + test_fast_exp_accuracy(0.5_rt); + test_fast_exp_accuracy(1.0_rt); + test_fast_exp_accuracy(5.0_rt); + test_fast_exp_accuracy(10.0_rt); + test_fast_exp_accuracy(15.0_rt); + test_fast_exp_accuracy(20.0_rt); + test_fast_exp_accuracy(30.0_rt); + test_fast_exp_accuracy(50.0_rt); + test_fast_exp_accuracy(100.0_rt); + test_fast_exp_accuracy(500.0_rt); + + test_fast_exp_accuracy(-0.001_rt); + test_fast_exp_accuracy(-0.01_rt); + test_fast_exp_accuracy(-0.1_rt); + test_fast_exp_accuracy(-0.5_rt); + test_fast_exp_accuracy(-1.0_rt); + test_fast_exp_accuracy(-5.0_rt); + test_fast_exp_accuracy(-10.0_rt); + test_fast_exp_accuracy(-15.0_rt); + test_fast_exp_accuracy(-20.0_rt); + test_fast_exp_accuracy(-30.0_rt); + test_fast_exp_accuracy(-50.0_rt); + test_fast_exp_accuracy(-100.0_rt); + test_fast_exp_accuracy(-500.0_rt); + + std::cout << "Accuracy tests passed!" 
<< std::endl; + + // Now performance test + + int iters = 5; + amrex::Real test_value = 160.0_rt; + test_fast_exp_speed(100, iters, test_value); + + iters = 10; + test_fast_exp_speed(100, iters, test_value); + + iters = 20; + test_fast_exp_speed(100, iters, test_value); + + iters = 30; + test_fast_exp_speed(100, iters, test_value); + + iters = 50; + test_fast_exp_speed(100, iters, test_value); + + // iters = 70; + // test_fast_exp_speed(100, iters, test_value); +} diff --git a/util/approx_math/test_fast_exp/test_fast_exp.H b/util/approx_math/test_fast_exp/test_fast_exp.H new file mode 100644 index 0000000000..b42b5582e6 --- /dev/null +++ b/util/approx_math/test_fast_exp/test_fast_exp.H @@ -0,0 +1,96 @@ +#ifndef TEST_FAST_EXP_H +#define TEST_FAST_EXP_H + +#include +#include +#include +#include +#include +#include + + +template +void test_fast_exp_accuracy(T x) { + //This tests fast_exp accuracy + + T fast_exp_result = fast_exp(x); + T std_exp_result = std::exp(x); + + // Print results + std::cout << "x: " << x + << " fast_exp: " << fast_exp_result + << " std::exp: " << std_exp_result << std::endl; + + T abs_err = std::abs(fast_exp_result - std_exp_result); + T rel_err = abs_err / std_exp_result; + auto rtol = static_cast(0.05); + + std::cout << "absolute error: " << abs_err << std::endl; + std::cout << "relative error: " << rel_err << std::endl; + + assert(rel_err < rtol); +} + + +template +void test_fast_exp_speed(int loops, int iter, T x) { + // This tests fast_exp performance + + std::cout << "Testing with loops: " << loops + << " with iter: " << iter + << " with initial x: " << x << std::endl; + + auto fac = static_cast(1.e-4); + + auto start = std::chrono::high_resolution_clock::now(); + { + T x_in; + T result; + for (int m = 0; m < loops; ++m) { + x_in = x; + result = 0.0; + for (int i = 0; i < iter; ++i) { + for (int j = 0; j < iter; ++j) { + for (int k = 0; k < iter; ++k) { + result += std::exp(x_in); + x_in -= fac * x_in; + } + } + } + } + // don't let the 
compiler elide this side-effect-free loop (at the cost of a memory write) + volatile T volatile_result; + volatile_result = result; + } + auto end = std::chrono::high_resolution_clock::now(); + std::chrono::duration std_exp_duration = end - start; + + + start = std::chrono::high_resolution_clock::now(); + { + T x_in; + T result; + for (int m = 0; m < loops; ++m) { + x_in = x; + result = 0.0; + for (int i = 0; i < iter; ++i) { + for (int j = 0; j < iter; ++j) { + for (int k = 0; k < iter; ++k) { + result += fast_exp(x_in); + x_in -= fac * x_in; + } + } + } + } + // don't let the compiler elide this side-effect-free loop (at the cost of a memory write) + volatile T volatile_result; + volatile_result = result; + } + end = std::chrono::high_resolution_clock::now(); + std::chrono::duration fast_exp_duration = end - start; + + + std::cout << "fast_exp duration: " << fast_exp_duration.count() << " seconds\n"; + std::cout << "std::exp duration: " << std_exp_duration.count() << " seconds\n"; +} +#endif