Skip to content

Commit

Permalink
Simplify dword logic in mp layer
Browse files Browse the repository at this point in the history
  • Loading branch information
randombit committed Mar 28, 2024
1 parent 5515966 commit 2145ac4
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 91 deletions.
28 changes: 18 additions & 10 deletions src/lib/math/mp/mp_asmi.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,31 @@
#define BOTAN_MP_ASM_INTERNAL_H_

#include <botan/types.h>
#include <botan/internal/mul128.h>

#if BOTAN_MP_WORD_BITS == 64
#include <botan/internal/donna128.h>
#include <botan/internal/mul128.h>
#endif

namespace Botan {

#if(BOTAN_MP_WORD_BITS == 32)
#define BOTAN_MP_DWORD uint64_t
// clang-format off
#if BOTAN_MP_WORD_BITS == 32
typedef uint64_t dword;
#define BOTAN_HAS_NATIVE_DWORD

#elif(BOTAN_MP_WORD_BITS == 64)
#elif BOTAN_MP_WORD_BITS == 64
#if defined(BOTAN_TARGET_HAS_NATIVE_UINT128)
#define BOTAN_MP_DWORD uint128_t
typedef uint128_t dword;
#define BOTAN_HAS_NATIVE_DWORD
#else
// No native 128-bit integer type; emulate dword with donna128 (which multiplies via mul64x64_128)
typedef donna128 dword;
#endif

#else
#error BOTAN_MP_WORD_BITS must be 32 or 64
#endif
// clang-format on

#if defined(BOTAN_USE_GCC_INLINE_ASM)

Expand Down Expand Up @@ -66,8 +74,8 @@ inline word word_madd2(word a, word b, word* c) {

return a;

#elif defined(BOTAN_MP_DWORD)
const BOTAN_MP_DWORD s = static_cast<BOTAN_MP_DWORD>(a) * b + *c;
#elif defined(BOTAN_HAS_NATIVE_DWORD)
const dword s = static_cast<dword>(a) * b + *c;
*c = static_cast<word>(s >> BOTAN_MP_WORD_BITS);
return static_cast<word>(s);
#else
Expand Down Expand Up @@ -121,8 +129,8 @@ inline word word_madd3(word a, word b, word c, word* d) {

return a;

#elif defined(BOTAN_MP_DWORD)
const BOTAN_MP_DWORD s = static_cast<BOTAN_MP_DWORD>(a) * b + c + *d;
#elif defined(BOTAN_HAS_NATIVE_DWORD)
const dword s = static_cast<dword>(a) * b + c + *d;
*d = static_cast<word>(s >> BOTAN_MP_WORD_BITS);
return static_cast<word>(s);
#else
Expand Down
8 changes: 4 additions & 4 deletions src/lib/math/mp/mp_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -667,8 +667,8 @@ inline word bigint_divop_vartime(word n1, word n0, word d) {
throw Invalid_Argument("bigint_divop_vartime divide by zero");
}

#if defined(BOTAN_MP_DWORD)
return static_cast<word>(((static_cast<BOTAN_MP_DWORD>(n1) << BOTAN_MP_WORD_BITS) | n0) / d);
#if defined(BOTAN_HAS_NATIVE_DWORD)
return static_cast<word>(((static_cast<dword>(n1) << BOTAN_MP_WORD_BITS) | n0) / d);
#else

word high = n1 % d;
Expand Down Expand Up @@ -699,8 +699,8 @@ inline word bigint_modop_vartime(word n1, word n0, word d) {
throw Invalid_Argument("bigint_modop_vartime divide by zero");
}

#if defined(BOTAN_MP_DWORD)
return ((static_cast<BOTAN_MP_DWORD>(n1) << BOTAN_MP_WORD_BITS) | n0) % d;
#if defined(BOTAN_HAS_NATIVE_DWORD)
return ((static_cast<dword>(n1) << BOTAN_MP_WORD_BITS) | n0) % d;
#else
word z = bigint_divop_vartime(n1, n0, d);
word dummy = 0;
Expand Down
80 changes: 41 additions & 39 deletions src/lib/pubkey/curve25519/donna.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,22 +82,24 @@ inline void fadd_sub(uint64_t x[5], uint64_t y[5]) {
fdifference_backwards(x, tmp); // does x - z
}

// Mask selecting the low 51 bits of a value (2^51 - 1): 0x7ffffffffffff is
// 13 hex digits, i.e. 3 + 48 = 51 set bits.
// NOTE(review): the name suggests a 63-bit mask, but the value is a 51-bit
// mask, matching the carry_shift(..., 51) / ">> 51" steps it is paired with.
// Consider renaming (e.g. MASK_51) in a follow-up that updates every use.
const uint64_t MASK_63 = 0x7ffffffffffff;

/* Multiply a number by a scalar: out = in * scalar */
inline void fscalar_product(uint64_t out[5], const uint64_t in[5], const uint64_t scalar) {
uint128_t a = uint128_t(in[0]) * scalar;
out[0] = a & 0x7ffffffffffff;
out[0] = a & MASK_63;

a = uint128_t(in[1]) * scalar + carry_shift(a, 51);
out[1] = a & 0x7ffffffffffff;
out[1] = a & MASK_63;

a = uint128_t(in[2]) * scalar + carry_shift(a, 51);
out[2] = a & 0x7ffffffffffff;
out[2] = a & MASK_63;

a = uint128_t(in[3]) * scalar + carry_shift(a, 51);
out[3] = a & 0x7ffffffffffff;
out[3] = a & MASK_63;

a = uint128_t(in[4]) * scalar + carry_shift(a, 51);
out[4] = a & 0x7ffffffffffff;
out[4] = a & MASK_63;

out[0] += carry_shift(a, 51) * 19;
}
Expand Down Expand Up @@ -139,23 +141,23 @@ inline void fmul(uint64_t out[5], const uint64_t in[5], const uint64_t in2[5]) {
t2 += r4 * s3 + r3 * s4;
t3 += r4 * s4;

r0 = t0 & 0x7ffffffffffff;
r0 = t0 & MASK_63;
t1 += carry_shift(t0, 51);
r1 = t1 & 0x7ffffffffffff;
r1 = t1 & MASK_63;
t2 += carry_shift(t1, 51);
r2 = t2 & 0x7ffffffffffff;
r2 = t2 & MASK_63;
t3 += carry_shift(t2, 51);
r3 = t3 & 0x7ffffffffffff;
r3 = t3 & MASK_63;
t4 += carry_shift(t3, 51);
r4 = t4 & 0x7ffffffffffff;
r4 = t4 & MASK_63;
uint64_t c = carry_shift(t4, 51);

r0 += c * 19;
c = r0 >> 51;
r0 = r0 & 0x7ffffffffffff;
r0 = r0 & MASK_63;
r1 += c;
c = r1 >> 51;
r1 = r1 & 0x7ffffffffffff;
r1 = r1 & MASK_63;
r2 += c;

out[0] = r0;
Expand Down Expand Up @@ -185,23 +187,23 @@ inline void fsquare(uint64_t out[5], const uint64_t in[5], size_t count = 1) {
uint128_t t3 = uint128_t(d0) * r3 + uint128_t(d1) * r2 + uint128_t(r4) * (d419);
uint128_t t4 = uint128_t(d0) * r4 + uint128_t(d1) * r3 + uint128_t(r2) * (r2);

r0 = t0 & 0x7ffffffffffff;
r0 = t0 & MASK_63;
t1 += carry_shift(t0, 51);
r1 = t1 & 0x7ffffffffffff;
r1 = t1 & MASK_63;
t2 += carry_shift(t1, 51);
r2 = t2 & 0x7ffffffffffff;
r2 = t2 & MASK_63;
t3 += carry_shift(t2, 51);
r3 = t3 & 0x7ffffffffffff;
r3 = t3 & MASK_63;
t4 += carry_shift(t3, 51);
r4 = t4 & 0x7ffffffffffff;
r4 = t4 & MASK_63;
uint64_t c = carry_shift(t4, 51);

r0 += c * 19;
c = r0 >> 51;
r0 = r0 & 0x7ffffffffffff;
r0 = r0 & MASK_63;
r1 += c;
c = r1 >> 51;
r1 = r1 & 0x7ffffffffffff;
r1 = r1 & MASK_63;
r2 += c;
}

Expand All @@ -214,11 +216,11 @@ inline void fsquare(uint64_t out[5], const uint64_t in[5], size_t count = 1) {

/* Take a little-endian, 32-byte number and expand it into polynomial form */
inline void fexpand(uint64_t* out, const uint8_t* in) {
out[0] = load_le<uint64_t>(in, 0) & 0x7ffffffffffff;
out[1] = (load_le<uint64_t>(in + 6, 0) >> 3) & 0x7ffffffffffff;
out[2] = (load_le<uint64_t>(in + 12, 0) >> 6) & 0x7ffffffffffff;
out[3] = (load_le<uint64_t>(in + 19, 0) >> 1) & 0x7ffffffffffff;
out[4] = (load_le<uint64_t>(in + 24, 0) >> 12) & 0x7ffffffffffff;
out[0] = load_le<uint64_t>(in, 0) & MASK_63;
out[1] = (load_le<uint64_t>(in + 6, 0) >> 3) & MASK_63;
out[2] = (load_le<uint64_t>(in + 12, 0) >> 6) & MASK_63;
out[3] = (load_le<uint64_t>(in + 19, 0) >> 1) & MASK_63;
out[4] = (load_le<uint64_t>(in + 24, 0) >> 12) & MASK_63;
}

/* Take a fully reduced polynomial form number and contract it into a
Expand All @@ -233,15 +235,15 @@ inline void fcontract(uint8_t* out, const uint64_t input[5]) {

for(size_t i = 0; i != 2; ++i) {
t1 += t0 >> 51;
t0 &= 0x7ffffffffffff;
t0 &= MASK_63;
t2 += t1 >> 51;
t1 &= 0x7ffffffffffff;
t1 &= MASK_63;
t3 += t2 >> 51;
t2 &= 0x7ffffffffffff;
t2 &= MASK_63;
t4 += t3 >> 51;
t3 &= 0x7ffffffffffff;
t3 &= MASK_63;
t0 += (t4 >> 51) * 19;
t4 &= 0x7ffffffffffff;
t4 &= MASK_63;
}

/* now t is between 0 and 2^255-1, properly carried. */
Expand All @@ -250,15 +252,15 @@ inline void fcontract(uint8_t* out, const uint64_t input[5]) {
t0 += 19;

t1 += t0 >> 51;
t0 &= 0x7ffffffffffff;
t0 &= MASK_63;
t2 += t1 >> 51;
t1 &= 0x7ffffffffffff;
t1 &= MASK_63;
t3 += t2 >> 51;
t2 &= 0x7ffffffffffff;
t2 &= MASK_63;
t4 += t3 >> 51;
t3 &= 0x7ffffffffffff;
t3 &= MASK_63;
t0 += (t4 >> 51) * 19;
t4 &= 0x7ffffffffffff;
t4 &= MASK_63;

/* now between 19 and 2^255-1 in both cases, and offset by 19. */

Expand All @@ -271,14 +273,14 @@ inline void fcontract(uint8_t* out, const uint64_t input[5]) {
/* now between 2^255 and 2^256-20, and offset by 2^255. */

t1 += t0 >> 51;
t0 &= 0x7ffffffffffff;
t0 &= MASK_63;
t2 += t1 >> 51;
t1 &= 0x7ffffffffffff;
t1 &= MASK_63;
t3 += t2 >> 51;
t2 &= 0x7ffffffffffff;
t2 &= MASK_63;
t4 += t3 >> 51;
t3 &= 0x7ffffffffffff;
t4 &= 0x7ffffffffffff;
t3 &= MASK_63;
t4 &= MASK_63;

store_le(out,
combine_lower(t0, 0, t1, 51),
Expand Down
49 changes: 30 additions & 19 deletions src/lib/utils/donna128.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,48 +9,51 @@
#define BOTAN_CURVE25519_DONNA128_H_

#include <botan/internal/mul128.h>
#include <type_traits>

namespace Botan {

class donna128 final {
public:
donna128(uint64_t ll = 0, uint64_t hh = 0) {
constexpr donna128(uint64_t ll = 0, uint64_t hh = 0) {
l = ll;
h = hh;
}

donna128(const donna128&) = default;
donna128& operator=(const donna128&) = default;

friend donna128 operator>>(const donna128& x, size_t shift) {
template <typename T>
constexpr friend donna128 operator>>(const donna128& x, T shift) {
donna128 z = x;
if(shift > 0) {
const uint64_t carry = z.h << (64 - shift);
const uint64_t carry = z.h << static_cast<size_t>(64 - shift);
z.h = (z.h >> shift);
z.l = (z.l >> shift) | carry;
}
return z;
}

friend donna128 operator<<(const donna128& x, size_t shift) {
template <typename T>
constexpr friend donna128 operator<<(const donna128& x, T shift) {
donna128 z = x;
if(shift > 0) {
const uint64_t carry = z.l >> (64 - shift);
const uint64_t carry = z.l >> static_cast<size_t>(64 - shift);
z.l = (z.l << shift);
z.h = (z.h << shift) | carry;
}
return z;
}

friend uint64_t operator&(const donna128& x, uint64_t mask) { return x.l & mask; }
constexpr friend uint64_t operator&(const donna128& x, uint64_t mask) { return x.l & mask; }

uint64_t operator&=(uint64_t mask) {
constexpr uint64_t operator&=(uint64_t mask) {
h = 0;
l &= mask;
return l;
}

donna128& operator+=(const donna128& x) {
constexpr donna128& operator+=(const donna128& x) {
l += x.l;
h += x.h;

Expand All @@ -59,54 +62,62 @@ class donna128 final {
return *this;
}

donna128& operator+=(uint64_t x) {
constexpr donna128& operator+=(uint64_t x) {
l += x;
const uint64_t carry = (l < x);
h += carry;
return *this;
}

uint64_t lo() const { return l; }
constexpr uint64_t lo() const { return l; }

uint64_t hi() const { return h; }
constexpr uint64_t hi() const { return h; }

constexpr operator uint64_t() const { return l; }

private:
uint64_t h = 0, l = 0;
};

inline donna128 operator*(const donna128& x, uint64_t y) {
template <std::unsigned_integral T>
constexpr inline donna128 operator*(const donna128& x, T y) {
BOTAN_ARG_CHECK(x.hi() == 0, "High 64 bits of donna128 set to zero during multiply");

uint64_t lo = 0, hi = 0;
mul64x64_128(x.lo(), y, &lo, &hi);
mul64x64_128(x.lo(), static_cast<uint64_t>(y), &lo, &hi);
return donna128(lo, hi);
}

inline donna128 operator*(uint64_t y, const donna128& x) {
template <std::unsigned_integral T>
constexpr inline donna128 operator*(T y, const donna128& x) {
return x * y;
}

inline donna128 operator+(const donna128& x, const donna128& y) {
constexpr inline donna128 operator+(const donna128& x, const donna128& y) {
donna128 z = x;
z += y;
return z;
}

inline donna128 operator+(const donna128& x, uint64_t y) {
constexpr inline donna128 operator+(const donna128& x, uint64_t y) {
donna128 z = x;
z += y;
return z;
}

inline donna128 operator|(const donna128& x, const donna128& y) {
constexpr inline donna128 operator|(const donna128& x, const donna128& y) {
return donna128(x.lo() | y.lo(), x.hi() | y.hi());
}

inline uint64_t carry_shift(const donna128& a, size_t shift) {
constexpr inline donna128 operator|(const donna128& x, uint64_t y) {
return donna128(x.lo() | y, x.hi());
}

constexpr inline uint64_t carry_shift(const donna128& a, size_t shift) {
return (a >> shift).lo();
}

inline uint64_t combine_lower(const donna128& a, size_t s1, const donna128& b, size_t s2) {
constexpr inline uint64_t combine_lower(const donna128& a, size_t s1, const donna128& b, size_t s2) {
donna128 z = (a >> s1) | (b << s2);
return z.lo();
}
Expand Down
Loading

0 comments on commit 2145ac4

Please sign in to comment.