From 33412b791582e822f0012d2293787190aa051fe3 Mon Sep 17 00:00:00 2001 From: Jack Lloyd Date: Mon, 16 Dec 2024 18:03:43 -0500 Subject: [PATCH] Use a better approach for modular subtraction For pcurves this improves overall ECDH/ECDSA performance from anywhere from 5% to up to 15% (!) depending on the curve. Remove the templatized length variant of bigint_mod_sub; this only existed as a performance hack for EC_Point and is much less relevant with pcurves in use. --- src/lib/math/bigint/big_ops2.cpp | 8 +---- src/lib/math/mp/mp_core.h | 35 +++---------------- .../math/pcurves/pcurves_impl/pcurves_impl.h | 7 +++- src/tests/test_pcurves.cpp | 18 ++++++++-- 4 files changed, 28 insertions(+), 40 deletions(-) diff --git a/src/lib/math/bigint/big_ops2.cpp b/src/lib/math/bigint/big_ops2.cpp index 1b4589f5f13..c768c5a829f 100644 --- a/src/lib/math/bigint/big_ops2.cpp +++ b/src/lib/math/bigint/big_ops2.cpp @@ -105,13 +105,7 @@ BigInt& BigInt::mod_sub(const BigInt& s, const BigInt& mod, secure_vector& ws.resize(mod_sw); } - if(mod_sw == 4) { - bigint_mod_sub_n<4>(mutable_data(), s._data(), mod._data(), ws.data()); - } else if(mod_sw == 6) { - bigint_mod_sub_n<6>(mutable_data(), s._data(), mod._data(), ws.data()); - } else { - bigint_mod_sub(mutable_data(), s._data(), mod._data(), mod_sw, ws.data()); - } + bigint_mod_sub(mutable_data(), s._data(), mod._data(), mod_sw, ws.data()); return (*this); } diff --git a/src/lib/math/mp/mp_core.h b/src/lib/math/mp/mp_core.h index 90d947447ab..6c9247e92b2 100644 --- a/src/lib/math/mp/mp_core.h +++ b/src/lib/math/mp/mp_core.h @@ -737,38 +737,13 @@ inline constexpr int32_t bigint_sub_abs(W z[], const W x[], size_t x_size, const */ template inline constexpr void bigint_mod_sub(W t[], const W s[], const W mod[], size_t mod_sw, W ws[]) { - // is t < s or not? - const auto is_lt = bigint_ct_is_lt(t, mod_sw, s, mod_sw); + // ws = t - s + const W borrow = bigint_sub3(ws, t, mod_sw, s, mod_sw); - // ws = p - s - const W borrow = bigint_sub3(ws, mod, mod_sw, s, mod_sw); + // Conditionally add back the modulus + bigint_cnd_add(borrow, ws, mod, mod_sw); - // Compute either (t - s) or (t + (p - s)) depending on mask - const W carry = bigint_cnd_addsub(is_lt, t, ws, s, mod_sw); - - if(!std::is_constant_evaluated()) { - BOTAN_DEBUG_ASSERT(borrow == 0 && carry == 0); - } - - BOTAN_UNUSED(carry, borrow); -} - -template -inline constexpr void bigint_mod_sub_n(W t[], const W s[], const W mod[], W ws[]) { - // is t < s or not? - const auto is_lt = bigint_ct_is_lt(t, N, s, N); - - // ws = p - s - const W borrow = bigint_sub3(ws, mod, N, s, N); - - // Compute either (t - s) or (t + (p - s)) depending on mask - const W carry = bigint_cnd_addsub(is_lt, t, ws, s, N); - - if(!std::is_constant_evaluated()) { - BOTAN_DEBUG_ASSERT(borrow == 0 && carry == 0); - } - - BOTAN_UNUSED(carry, borrow); + copy_mem(t, ws, mod_sw); } /** diff --git a/src/lib/math/pcurves/pcurves_impl/pcurves_impl.h b/src/lib/math/pcurves/pcurves_impl/pcurves_impl.h index ef857978255..3882e63f402 100644 --- a/src/lib/math/pcurves/pcurves_impl/pcurves_impl.h +++ b/src/lib/math/pcurves/pcurves_impl/pcurves_impl.h @@ -135,7 +135,12 @@ class IntMod final { return Self(r); } - friend constexpr Self operator-(const Self& a, const Self& b) { return a + b.negate(); } + friend constexpr Self operator-(const Self& a, const Self& b) { + std::array r; + word carry = bigint_sub3(r.data(), a.data(), N, b.data(), N); + bigint_cnd_add(carry, r.data(), N, P.data(), N); + return Self(r); + } /// Return (*this) divided by 2 Self div2() const { diff --git a/src/tests/test_pcurves.cpp b/src/tests/test_pcurves.cpp index 7a300dbfe8c..510e1a79171 100644 --- a/src/tests/test_pcurves.cpp +++ b/src/tests/test_pcurves.cpp @@ -156,8 +156,22 @@ class Pcurve_Arithmetic_Tests final : public Test { const auto g_plus_g = g_one + g_one; result.test_eq("2*g == g+g", g_two.to_affine().serialize(), g_plus_g.to_affine().serialize()); - result.confirm("Scalar::zero is zero", curve->scalar_zero().is_zero()); - result.confirm("Scalar::one is not zero", !curve->scalar_one().is_zero()); + result.confirm("Scalar::zero is zero", zero.is_zero()); + result.confirm("(zero+zero) is zero", (zero + zero).is_zero()); + result.confirm("(zero*zero) is zero", (zero * zero).is_zero()); + result.confirm("(zero-zero) is zero", (zero - zero).is_zero()); + + const auto neg_zero = zero.negate(); + result.confirm("zero.negate() is zero", neg_zero.is_zero()); + + result.confirm("(zero+nz) is zero", (zero + neg_zero).is_zero()); + result.confirm("(nz+nz) is zero", (neg_zero + neg_zero).is_zero()); + result.confirm("(nz+zero) is zero", (neg_zero + zero).is_zero()); + + result.confirm("Scalar::one is not zero", !one.is_zero()); + result.confirm("(one-one) is zero", (one - one).is_zero()); + result.confirm("(one+one.negate()) is zero", (one + one.negate()).is_zero()); + result.confirm("(one.negate()+one) is zero", (one.negate() + one).is_zero()); for(size_t i = 0; i != 16; ++i) { const auto pt = curve->mul_by_g(curve->random_scalar(rng), rng).to_affine();