diff --git a/doc/sf/ccmath.qbk b/doc/sf/ccmath.qbk index f39e388ee5..4a2b0b2fbf 100644 --- a/doc/sf/ccmath.qbk +++ b/doc/sf/ccmath.qbk @@ -184,6 +184,7 @@ All of the following functions require C++17 or greater. template inline constexpr Real fma(Real x, Real y, Real z) noexcept + Requires compiling with fma flag template inline constepxr Promoted fma(Arithmetic1 x, Arithmetic2 y, Arithmetic3 z) noexcept diff --git a/include/boost/math/ccmath/fma.hpp b/include/boost/math/ccmath/fma.hpp index 12390953e7..0ed3cc5668 100644 --- a/include/boost/math/ccmath/fma.hpp +++ b/include/boost/math/ccmath/fma.hpp @@ -13,6 +13,11 @@ #include #include +#if __has_include("immintrin.h") && defined(__X86_64__) || defined(__amd64__) +# include "immintrin.h" +# define BOOST_MATH_HAS_IMMINTRIN_H +#endif + namespace boost::math::ccmath { namespace detail { @@ -20,9 +25,7 @@ namespace detail { template inline constexpr T fma_imp(const T x, const T y, const T z) noexcept { - #if __GNUC__ < 10 - return (x * y) + z; - #else + #if defined(__GNUC__) && !defined(__clang__) && !defined(__INTEL_COMPILER) && !defined(__INTEL_LLVM_COMPILER) if constexpr (std::is_same_v) { return __builtin_fmaf(x, y, z); @@ -35,11 +38,23 @@ inline constexpr T fma_imp(const T x, const T y, const T z) noexcept { return __builtin_fmal(x, y, z); } - else // e.g. Boost.Multiprecision types where no built-in exists + #elif defined(BOOST_MATH_HAS_IMMINTRIN_H) + if constexpr (std::is_same_v) + { + return static_cast(_mm_fmadd_ps(x, y, z)); + } + else if constexpr (std::is_same_v) + { + return static_cast(_mm_fmadd_pd(x, y, z)); + } + else if constexpr (std::is_same_v) { - return (x * y) + z; + return static_cast(_mm256_fmadd_pd(x, y, z)); } #endif + + // If we can't use compiler intrinsics hope that -fma flag optimizes this call to fma instruction + return (x * y) + z; } } // Namespace detail diff --git a/test/ccmath_fma_test.cpp b/test/ccmath_fma_test.cpp index 3b7354df2e..a5aa74914d 100644 --- a/test/ccmath_fma_test.cpp +++ b/test/ccmath_fma_test.cpp @@ -17,6 +17,7 @@ #include #endif +#if !defined(BOOST_MATH_NO_CONSTEXPR_DETECTION) && !defined(BOOST_MATH_USING_BUILTIN_CONSTANT_P) template constexpr void test() { @@ -49,7 +50,6 @@ constexpr void test() } } -#if !defined(BOOST_MATH_NO_CONSTEXPR_DETECTION) && !defined(BOOST_MATH_USING_BUILTIN_CONSTANT_P) int main() { test();