Skip to content

Commit

Permalink
Merge pull request #3956 from randombit/jack/comba-fn
Browse files Browse the repository at this point in the history
Add a generic Comba multiply/square function
  • Loading branch information
randombit authored Apr 2, 2024
2 parents 1e92459 + 67157e1 commit 9da15b9
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 215 deletions.
171 changes: 1 addition & 170 deletions src/lib/math/mp/mp_comba.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* Comba Multiplication and Squaring
*
* This file was automatically generated by ./src/scripts/dev_tools/gen_mp_comba.py on 2024-02-21
* This file was automatically generated by ./src/scripts/dev_tools/gen_mp_comba.py on 2024-03-29
*
* Botan is released under the Simplified BSD License (see license.txt)
*/
Expand Down Expand Up @@ -92,175 +92,6 @@ void bigint_comba_mul4(word z[8], const word x[4], const word y[4]) {
z[7] = w1;
}

/*
* Comba 7x7 Squaring
*/
void bigint_comba_sqr7(word z[14], const word x[7]) {
word w2 = 0, w1 = 0, w0 = 0;

word3_muladd(&w2, &w1, &w0, x[0], x[0]);
z[0] = w0;
w0 = 0;

word3_muladd_2(&w0, &w2, &w1, x[0], x[1]);
z[1] = w1;
w1 = 0;

word3_muladd_2(&w1, &w0, &w2, x[0], x[2]);
word3_muladd(&w1, &w0, &w2, x[1], x[1]);
z[2] = w2;
w2 = 0;

word3_muladd_2(&w2, &w1, &w0, x[0], x[3]);
word3_muladd_2(&w2, &w1, &w0, x[1], x[2]);
z[3] = w0;
w0 = 0;

word3_muladd_2(&w0, &w2, &w1, x[0], x[4]);
word3_muladd_2(&w0, &w2, &w1, x[1], x[3]);
word3_muladd(&w0, &w2, &w1, x[2], x[2]);
z[4] = w1;
w1 = 0;

word3_muladd_2(&w1, &w0, &w2, x[0], x[5]);
word3_muladd_2(&w1, &w0, &w2, x[1], x[4]);
word3_muladd_2(&w1, &w0, &w2, x[2], x[3]);
z[5] = w2;
w2 = 0;

word3_muladd_2(&w2, &w1, &w0, x[0], x[6]);
word3_muladd_2(&w2, &w1, &w0, x[1], x[5]);
word3_muladd_2(&w2, &w1, &w0, x[2], x[4]);
word3_muladd(&w2, &w1, &w0, x[3], x[3]);
z[6] = w0;
w0 = 0;

word3_muladd_2(&w0, &w2, &w1, x[1], x[6]);
word3_muladd_2(&w0, &w2, &w1, x[2], x[5]);
word3_muladd_2(&w0, &w2, &w1, x[3], x[4]);
z[7] = w1;
w1 = 0;

word3_muladd_2(&w1, &w0, &w2, x[2], x[6]);
word3_muladd_2(&w1, &w0, &w2, x[3], x[5]);
word3_muladd(&w1, &w0, &w2, x[4], x[4]);
z[8] = w2;
w2 = 0;

word3_muladd_2(&w2, &w1, &w0, x[3], x[6]);
word3_muladd_2(&w2, &w1, &w0, x[4], x[5]);
z[9] = w0;
w0 = 0;

word3_muladd_2(&w0, &w2, &w1, x[4], x[6]);
word3_muladd(&w0, &w2, &w1, x[5], x[5]);
z[10] = w1;
w1 = 0;

word3_muladd_2(&w1, &w0, &w2, x[5], x[6]);
z[11] = w2;
w2 = 0;

word3_muladd(&w2, &w1, &w0, x[6], x[6]);
z[12] = w0;
z[13] = w1;
}

/*
* Comba 7x7 Multiplication
*/
void bigint_comba_mul7(word z[14], const word x[7], const word y[7]) {
word w2 = 0, w1 = 0, w0 = 0;

word3_muladd(&w2, &w1, &w0, x[0], y[0]);
z[0] = w0;
w0 = 0;

word3_muladd(&w0, &w2, &w1, x[0], y[1]);
word3_muladd(&w0, &w2, &w1, x[1], y[0]);
z[1] = w1;
w1 = 0;

word3_muladd(&w1, &w0, &w2, x[0], y[2]);
word3_muladd(&w1, &w0, &w2, x[1], y[1]);
word3_muladd(&w1, &w0, &w2, x[2], y[0]);
z[2] = w2;
w2 = 0;

word3_muladd(&w2, &w1, &w0, x[0], y[3]);
word3_muladd(&w2, &w1, &w0, x[1], y[2]);
word3_muladd(&w2, &w1, &w0, x[2], y[1]);
word3_muladd(&w2, &w1, &w0, x[3], y[0]);
z[3] = w0;
w0 = 0;

word3_muladd(&w0, &w2, &w1, x[0], y[4]);
word3_muladd(&w0, &w2, &w1, x[1], y[3]);
word3_muladd(&w0, &w2, &w1, x[2], y[2]);
word3_muladd(&w0, &w2, &w1, x[3], y[1]);
word3_muladd(&w0, &w2, &w1, x[4], y[0]);
z[4] = w1;
w1 = 0;

word3_muladd(&w1, &w0, &w2, x[0], y[5]);
word3_muladd(&w1, &w0, &w2, x[1], y[4]);
word3_muladd(&w1, &w0, &w2, x[2], y[3]);
word3_muladd(&w1, &w0, &w2, x[3], y[2]);
word3_muladd(&w1, &w0, &w2, x[4], y[1]);
word3_muladd(&w1, &w0, &w2, x[5], y[0]);
z[5] = w2;
w2 = 0;

word3_muladd(&w2, &w1, &w0, x[0], y[6]);
word3_muladd(&w2, &w1, &w0, x[1], y[5]);
word3_muladd(&w2, &w1, &w0, x[2], y[4]);
word3_muladd(&w2, &w1, &w0, x[3], y[3]);
word3_muladd(&w2, &w1, &w0, x[4], y[2]);
word3_muladd(&w2, &w1, &w0, x[5], y[1]);
word3_muladd(&w2, &w1, &w0, x[6], y[0]);
z[6] = w0;
w0 = 0;

word3_muladd(&w0, &w2, &w1, x[1], y[6]);
word3_muladd(&w0, &w2, &w1, x[2], y[5]);
word3_muladd(&w0, &w2, &w1, x[3], y[4]);
word3_muladd(&w0, &w2, &w1, x[4], y[3]);
word3_muladd(&w0, &w2, &w1, x[5], y[2]);
word3_muladd(&w0, &w2, &w1, x[6], y[1]);
z[7] = w1;
w1 = 0;

word3_muladd(&w1, &w0, &w2, x[2], y[6]);
word3_muladd(&w1, &w0, &w2, x[3], y[5]);
word3_muladd(&w1, &w0, &w2, x[4], y[4]);
word3_muladd(&w1, &w0, &w2, x[5], y[3]);
word3_muladd(&w1, &w0, &w2, x[6], y[2]);
z[8] = w2;
w2 = 0;

word3_muladd(&w2, &w1, &w0, x[3], y[6]);
word3_muladd(&w2, &w1, &w0, x[4], y[5]);
word3_muladd(&w2, &w1, &w0, x[5], y[4]);
word3_muladd(&w2, &w1, &w0, x[6], y[3]);
z[9] = w0;
w0 = 0;

word3_muladd(&w0, &w2, &w1, x[4], y[6]);
word3_muladd(&w0, &w2, &w1, x[5], y[5]);
word3_muladd(&w0, &w2, &w1, x[6], y[4]);
z[10] = w1;
w1 = 0;

word3_muladd(&w1, &w0, &w2, x[5], y[6]);
word3_muladd(&w1, &w0, &w2, x[6], y[5]);
z[11] = w2;
w2 = 0;

word3_muladd(&w2, &w1, &w0, x[6], y[6]);
z[12] = w0;
z[13] = w1;
}

/*
* Comba 6x6 Squaring
*/
Expand Down
43 changes: 40 additions & 3 deletions src/lib/math/mp/mp_core.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
* MPI Algorithms
* (C) 1999-2010,2018 Jack Lloyd
* (C) 1999-2010,2018,2024 Jack Lloyd
* 2006 Luca Piccarreta
* 2016 Matthias Gierlings
*
Expand Down Expand Up @@ -751,20 +751,57 @@ inline constexpr auto bigint_modop_vartime(W n1, W n0, W d) -> W {
return (n0 - z);
}

/*
* Comba Fixed Length Multiplication
*/
template <size_t N, WordType W>
constexpr inline void comba_mul(W z[2 * N], const W x[N], const W y[N]) {
W w2 = 0, w1 = 0, w0 = 0;

for(size_t i = 0; i != 2 * N; ++i) {
const size_t start = i + 1 < N ? 0 : i + 1 - N;
const size_t end = std::min(N, i + 1);

for(size_t j = start; j != end; ++j) {
word3_muladd(&w2, &w1, &w0, x[j], y[i - j]);
}
z[i] = w0;
w0 = w1;
w1 = w2;
w2 = 0;
}
}

template <size_t N, WordType W>
constexpr inline void comba_sqr(W z[2 * N], const W x[N]) {
W w2 = 0, w1 = 0, w0 = 0;

for(size_t i = 0; i != 2 * N; ++i) {
const size_t start = i + 1 < N ? 0 : i + 1 - N;
const size_t end = std::min(N, i + 1);

for(size_t j = start; j != end; ++j) {
word3_muladd(&w2, &w1, &w0, x[j], x[i - j]);
}
z[i] = w0;
w0 = w1;
w1 = w2;
w2 = 0;
}
}

/*
* Comba Multiplication / Squaring
*/
BOTAN_FUZZER_API void bigint_comba_mul4(word z[8], const word x[4], const word y[4]);
BOTAN_FUZZER_API void bigint_comba_mul6(word z[12], const word x[6], const word y[6]);
BOTAN_FUZZER_API void bigint_comba_mul7(word z[14], const word x[7], const word y[7]);
BOTAN_FUZZER_API void bigint_comba_mul8(word z[16], const word x[8], const word y[8]);
BOTAN_FUZZER_API void bigint_comba_mul9(word z[18], const word x[9], const word y[9]);
BOTAN_FUZZER_API void bigint_comba_mul16(word z[32], const word x[16], const word y[16]);
BOTAN_FUZZER_API void bigint_comba_mul24(word z[48], const word x[24], const word y[24]);

BOTAN_FUZZER_API void bigint_comba_sqr4(word out[8], const word in[4]);
BOTAN_FUZZER_API void bigint_comba_sqr6(word out[12], const word in[6]);
BOTAN_FUZZER_API void bigint_comba_sqr7(word out[14], const word in[7]);
BOTAN_FUZZER_API void bigint_comba_sqr8(word out[16], const word in[8]);
BOTAN_FUZZER_API void bigint_comba_sqr9(word out[18], const word in[9]);
BOTAN_FUZZER_API void bigint_comba_sqr16(word out[32], const word in[16]);
Expand Down
4 changes: 0 additions & 4 deletions src/lib/math/mp/mp_karat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,8 +299,6 @@ void bigint_mul(word z[],
bigint_comba_mul4(z, x, y);
} else if(sized_for_comba_mul<6>(x_sw, x_size, y_sw, y_size, z_size)) {
bigint_comba_mul6(z, x, y);
} else if(sized_for_comba_mul<7>(x_sw, x_size, y_sw, y_size, z_size)) {
bigint_comba_mul7(z, x, y);
} else if(sized_for_comba_mul<8>(x_sw, x_size, y_sw, y_size, z_size)) {
bigint_comba_mul8(z, x, y);
} else if(sized_for_comba_mul<9>(x_sw, x_size, y_sw, y_size, z_size)) {
Expand Down Expand Up @@ -336,8 +334,6 @@ void bigint_sqr(word z[], size_t z_size, const word x[], size_t x_size, size_t x
bigint_comba_sqr4(z, x);
} else if(sized_for_comba_sqr<6>(x_sw, x_size, z_size)) {
bigint_comba_sqr6(z, x);
} else if(sized_for_comba_sqr<7>(x_sw, x_size, z_size)) {
bigint_comba_sqr7(z, x);
} else if(sized_for_comba_sqr<8>(x_sw, x_size, z_size)) {
bigint_comba_sqr8(z, x);
} else if(sized_for_comba_sqr<9>(x_sw, x_size, z_size)) {
Expand Down
38 changes: 2 additions & 36 deletions src/lib/pubkey/curve448/curve448_gf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,51 +148,17 @@ void reduce_after_mul(std::span<uint64_t, WORDS_448> out, std::span<const uint64
reduce_after_add(out, h_1);
}

constexpr size_t words_per_uint64 = 8 / sizeof(word);
static_assert(8 % sizeof(word) == 0); // Greetings to the future

void gf_mul(std::span<uint64_t, WORDS_448> out,
std::span<const uint64_t, WORDS_448> a,
std::span<const uint64_t, WORDS_448> b) {
std::array<uint64_t, 14> ws;
if constexpr(std::same_as<uint64_t, word>) {
// Reinterpret cast to itself to prevent compiler errors on non 64-bit systems
bigint_comba_mul7(reinterpret_cast<word*>(ws.data()),
reinterpret_cast<const word*>(a.data()),
reinterpret_cast<const word*>(b.data()));
} else {
const auto a_arr = load_le<std::array<word, words_per_uint64 * WORDS_448>>(store_le(a));
const auto b_arr = load_le<std::array<word, words_per_uint64 * WORDS_448>>(store_le(b));
auto ws_arr = std::array<word, words_per_uint64 * 14>{};

bigint_mul(ws_arr.data(),
ws_arr.size(),
a_arr.data(),
a_arr.size(),
a_arr.size(),
b_arr.data(),
b_arr.size(),
b_arr.size(),
nullptr,
0);

load_le(ws, store_le(ws_arr));
}
comba_mul<7>(ws.data(), a.data(), b.data());
reduce_after_mul(out, ws);
}

void gf_square(std::span<uint64_t, WORDS_448> out, std::span<const uint64_t, WORDS_448> a) {
std::array<uint64_t, 14> ws;

if constexpr(std::same_as<uint64_t, word>) {
// Reinterpret cast to itself to prevent compiler errors on non 64-bit systems
bigint_comba_sqr7(reinterpret_cast<word*>(ws.data()), reinterpret_cast<const word*>(a.data()));
} else {
const auto a_arr = load_le<std::array<word, words_per_uint64 * WORDS_448>>(store_le(a));
auto ws_arr = std::array<word, words_per_uint64 * 14>{};
bigint_sqr(ws_arr.data(), ws_arr.size(), a_arr.data(), a_arr.size(), a_arr.size(), nullptr, 0);
load_le(ws, store_le(ws_arr));
}
comba_sqr<7>(ws.data(), a.data());
reduce_after_mul(out, ws);
}

Expand Down
1 change: 1 addition & 0 deletions src/lib/pubkey/curve448/curve448_scalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
*/
#include <botan/internal/curve448_scalar.h>

#include <botan/internal/ct_utils.h>
#include <botan/internal/mp_core.h>

namespace Botan {
Expand Down
1 change: 0 additions & 1 deletion src/lib/pubkey/curve448/ed448/ed448_internal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#include <botan/types.h>
#include <botan/internal/ct_utils.h>
#include <botan/internal/loadstor.h>
#include <botan/internal/mp_core.h>
#include <botan/internal/shake_xof.h>
#include <botan/internal/stl_util.h>

Expand Down
2 changes: 1 addition & 1 deletion src/scripts/dev_tools/gen_mp_comba.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def main(args = None):
args = sys.argv

if len(args) <= 1:
sizes = [4, 7, 6, 8, 9, 16, 24]
sizes = [4, 6, 8, 9, 16, 24]
else:
sizes = map(int, args[1:])

Expand Down

0 comments on commit 9da15b9

Please sign in to comment.