diff --git a/cpp/include/cudf/fixed_point/fixed_point.hpp b/cpp/include/cudf/fixed_point/fixed_point.hpp index 542e2b3c5c8..4445af6c5a8 100644 --- a/cpp/include/cudf/fixed_point/fixed_point.hpp +++ b/cpp/include/cudf/fixed_point/fixed_point.hpp @@ -85,41 +85,7 @@ constexpr inline auto is_supported_construction_value_type() namespace detail { /** - * @brief Recursively computes integer exponentiation - * - * @note This is intended to be run at compile time - * - * @tparam Rep Representation type for return type - * @tparam Base The base to be exponentiated - * @param exp The exponent to be used for exponentiation - * @return Result of `Base` to the power of `exponent` of type `Rep` - */ -template -CUDF_HOST_DEVICE inline constexpr Rep get_power(int32_t exp) -{ - // Compute power recursively - return (exp > 0) ? Rep(Base) * get_power(exp - 1) : 1; -} - -/** - * @brief Implementation of integer exponentiation by array lookup - * - * @tparam Rep Representation type for return type - * @tparam Base The base to be exponentiated - * @tparam Exponents The exponents for the array entries - * @param exponent The exponent to be used for exponentiation - * @return Result of `Base` to the power of `exponent` of type `Rep` - */ -template -CUDF_HOST_DEVICE inline Rep ipow_impl(int32_t exponent, cuda::std::index_sequence) -{ - // Compute powers at compile time, storing into array - static constexpr Rep powers[] = {get_power(Exponents)...}; - return powers[exponent]; -} - -/** - * @brief A function for integer exponentiation by array lookup + * @brief A function for integer exponentiation by squaring. * * @tparam Rep Representation type for return type * @tparam Base The base to be exponentiated @@ -134,16 +100,22 @@ template = 0 && "integer exponentiation with negative exponent is not possible."); - if constexpr (Base == numeric::Radix::BASE_2) { - return static_cast(1) << exponent; - } else { // BASE_10 - // Build index sequence for building power array at compile time - static constexpr auto max_exp = cuda::std::numeric_limits::digits10; - static constexpr auto exponents = cuda::std::make_index_sequence{}; - - // Get compile-time result - return ipow_impl(Base)>(exponent, exponents); + + if constexpr (Base == numeric::Radix::BASE_2) { return static_cast(1) << exponent; } + + // Note: Including an array here introduces too much register pressure + // https://simple.wikipedia.org/wiki/Exponentiation_by_squaring + // This is the iterative equivalent of the recursive definition (faster) + // Quick-bench for squaring: http://quick-bench.com/Wg7o7HYQC9FW5M0CO0wQAjSwP_Y + if (exponent == 0) { return static_cast(1); } + auto extra = static_cast(1); + auto square = static_cast(Base); + while (exponent > 1) { + if (exponent & 1) { extra *= square; } + exponent >>= 1; + square *= square; } + return square * extra; } /** @brief Function that performs a `right shift` scale "times" on the `val`