diff --git a/src/lib/utils/mul128.h b/src/lib/utils/mul128.h index fb91f548fff..5b2fc418b4b 100644 --- a/src/lib/utils/mul128.h +++ b/src/lib/utils/mul128.h @@ -51,22 +51,17 @@ constexpr inline void mul64x64_128(uint64_t a, uint64_t b, uint64_t* lo, uint64_ const uint32_t b_hi = (b >> HWORD_BITS); const uint32_t b_lo = (b & HWORD_MASK); - uint64_t x0 = static_cast(a_hi) * b_hi; - uint64_t x1 = static_cast(a_lo) * b_hi; - uint64_t x2 = static_cast(a_hi) * b_lo; - uint64_t x3 = static_cast(a_lo) * b_lo; + const uint64_t x0 = static_cast(a_hi) * b_hi; + const uint64_t x1 = static_cast(a_lo) * b_hi; + const uint64_t x2 = static_cast(a_hi) * b_lo; + const uint64_t x3 = static_cast(a_lo) * b_lo; - // this cannot overflow as (2^32-1)^2 + 2^32-1 < 2^64-1 - x2 += x3 >> HWORD_BITS; + // this cannot overflow as (2^32-1)^2 + 2^32-1 + 2^32-1 = 2^64-1 + const uint64_t middle = x2 + (x3 >> HWORD_BITS) + (x1 & HWORD_MASK); - // this one can overflow - x2 += x1; - - // propagate the carry if any - x0 += static_cast(static_cast(x2 < x1)) << HWORD_BITS; - - *hi = x0 + (x2 >> HWORD_BITS); - *lo = ((x2 & HWORD_MASK) << HWORD_BITS) + (x3 & HWORD_MASK); + // likewise these cannot overflow + *hi = x0 + (middle >> HWORD_BITS) + (x1 >> HWORD_BITS); + *lo = (middle << HWORD_BITS) + (x3 & HWORD_MASK); #endif }