Skip to content

Commit

Permalink
update vector ntt and the required twiddle factors
Browse files Browse the repository at this point in the history
  • Loading branch information
Lucas-Wye committed Jan 21, 2025
1 parent b5dccaa commit ec86041
Show file tree
Hide file tree
Showing 8 changed files with 341 additions and 56 deletions.
53 changes: 53 additions & 0 deletions tests/eval/_ntt/gen_vector_ntt_tw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
def gen_tw_for_vector_ntt(l, w_one, prime_p):
    """Print, and return, the per-layer twiddle factors for a 2**l-point NTT.

    One layer is produced for every butterfly span ``m = 2, 4, ..., n``
    (``n = 2**l``), i.e. ``l`` layers in total.  Each layer holds ``n // 2``
    factors, one per butterfly.  Within a layer the butterflies are grouped
    by their offset ``j`` inside a span (``0 <= j < m // 2``): all ``n // m``
    butterflies sharing the same ``j`` use the factor
    ``w_one ** (j * (n // m)) mod prime_p`` and are emitted consecutively.

    The table is written to stdout as comma-separated values, each layer
    preceded by a ``// layer #k`` line so the text can be pasted into a C
    array initializer.

    Args:
        l: log2 of the transform size; must be a non-negative integer.
        w_one: base of the twiddle factors (presumably a primitive n-th root
            of unity modulo ``prime_p`` -- this is not checked here).
        prime_p: modulus applied to every factor.

    Returns:
        A list of ``l`` layers, each a list of ``n // 2`` integers, in the
        same order as they are printed.

    Raises:
        ValueError: if ``l`` is negative.
    """
    if l < 0:
        raise ValueError(f"l must be non-negative, got {l}")
    n = 1 << l

    layers = []
    m = 2
    while m <= n:
        # n // m is both the exponent step between consecutive j values and
        # the number of butterflies (one per span) that share a given j.
        stride = n // m
        layer = []
        for j in range(m // 2):
            layer.extend([pow(w_one, j * stride, prime_p)] * stride)
        layers.append(layer)
        m *= 2

    print(f"\nfor ntt {n}")
    for layer_index, layer in enumerate(layers):
        print(f"// layer #{layer_index}")
        # Every value is followed by ", " (including the last one), then the
        # line is closed and a blank separator line is emitted.
        print("".join(f"{tw}, " for tw in layer))
        print()

    return layers

if __name__ == '__main__':
    # (log2 of the transform size, base twiddle factor) for each table to
    # emit; every table uses the modulus 12289.
    # NOTE(review): each base is presumably a primitive 2**l-th root of unity
    # mod 12289 -- confirm before adding new sizes.
    modulus = 12289
    table_params = (
        (6, 7311),
        (7, 12149),
        (8, 8340),
        (9, 3400),
        (10, 10302),
        (12, 1331),
    )
    for log2_n, base_root in table_params:
        gen_tw_for_vector_ntt(log2_n, base_root, modulus)

106 changes: 79 additions & 27 deletions tests/eval/_ntt/ntt.c
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
#include <assert.h>
#include <stdio.h>

// #define USERN 32
// #define DEBUG

// array is of length n=2^l, p is a prime number
// roots is of length l, where g = roots[0] satisfies that
// g^(2^l) == 1 mod p and g^(2^(l-1)) == -1 mod p
// roots[i] = g^(2^i) (hence roots[l - 1] = -1)
//
// 32bit * n <= VLEN * 8 => n <= VLEN / 4
void ntt(const int *array, int l, const int *twindle, int p, int *dst) {
void ntt(const int *array, int l, const int *twiddle, int p, int *dst) {
// prepare an array of permutation indices
assert(l <= 16);

int n = 1 << l;
int g = twindle[0];

// registers:
// v8-15: array
Expand All @@ -29,7 +31,7 @@ void ntt(const int *array, int l, const int *twindle, int p, int *dst) {
:
: "r"(n));

// prepare the permutation list
// prepare the bit-reversal permutation list
for (int k = 0; 2 * k < l; k++) {
asm("vand.vx v8, v4, %0\n"
"vsub.vv v4, v4, v8\n"
Expand All @@ -46,38 +48,88 @@ void ntt(const int *array, int l, const int *twindle, int p, int *dst) {
: "r"(1 << k), "r"(l - 1 - 2 * k), "r"(1 << (l - k - 1)));
}

// perform bit-reversal for input coefficients
asm("vsetvli zero, %0, e32, m8, tu, mu\n"
"vle32.v v16, 0(%1)\n"
"vrgatherei16.vv v8, v16, v4\n"
"vse32.v v8, 0(%2)\n"

:
: "r"(n), "r"(array), "r"(dst));

// generate permutation list (0, 2, 4, ..., 1, 3, 5, ...)
asm("vsetvli zero, %0, e16, m4, tu, mu\n"
"vid.v v4\n"
"vsrl.vx v0, v4, %1\n" // (0, 0, 0, 0, ..., 1, 1, 1, 1, ...)
"vand.vx v4, v4, %2\n" // (0, 1, 2, 3, ..., 0, 1, 2, 3, ...)
"vsll.vi v4, v4, 1\n"
"vadd.vv v4, v4, v0\n"

// set v16 to all 1
"vxor.vv v16, v16, v16\n"
"vadd.vi v16, v16, 1\n"
:
: "r"(n), "r"(array));
: "r"(n), "r"(l-1), "r"((n / 2 - 1)), "r"(n / 2));

#ifdef DEBUG
int tmp1[USERN];// c
int tmp2[USERN];// c
int tmp3[USERN];// c
#endif

for (int k = 0; k < l; k++) {
asm( // prepare coefficients in v16-23
"vid.v v24\n" // v24-31[i] = i
"vand.vx v24, v24, %1\n" // v24-31[i] = i & (1 << k)
"vmsne.vi v0, v24, 0\n" // vm0[i] = i & (1 << k) != 0
"vmul.vx v16, v16, %2, v0.t\n" // v16-23[i] = w^(???)

// prepare shifted elements in v24-31
"vslideup.vx v24, v8, %3\n" // shift the first 2^(l-k) elements to tail
"vsetvli zero, %3, e32, m8, tu, mu\n" // last n - 2^(l-k) elements
"vslidedown.vx v24, v8, %4\n"

// mul and add
asm(
// "n" mode
"vsetvli zero, %0, e32, m8, tu, mu\n"
"vmul.vv v24, v24, v16\n"
"vrem.vx v24, v24, %5\n"
"vadd.vv v8, v8, v24\n" // TODO: will it overflow?
// load coefficients
"vle32.v v16, 0(%4)\n"
// perform permutation for coefficient
"vrgatherei16.vv v8, v16, v4\n"
// save coefficients
"vse32.v v8, 0(%4)\n"

// "n/2" mode
"vsetvli zero, %1, e32, m4, tu, mu\n"
// load twiddle factors
"vle32.v v16, 0(%2)\n"
// load half coefficients
"vle32.v v8, 0(%4)\n"
"vle32.v v12, 0(%5)\n"

#ifdef DEBUG
"vse32.v v8, 0(%6)\n"// c
"vse32.v v12, 0(%7)\n"// c
"vse32.v v16, 0(%8)\n"// c
#endif

// butterfly operation
"vmul.vv v12, v12, v16\n"
"vrem.vx v12, v12, %3\n"
"vadd.vv v16, v8, v12\n" // TODO: will it overflow?
"vsub.vv v20, v8, v12\n"
// save half coefficients
"vse32.v v16, 0(%4)\n"
"vse32.v v20, 0(%5)\n"
:
: "r"(n), /* %1 */ "r"(1 << k), /* %2 */ "r"(twindle[l - 1 - k]),
/* %3 */ "r"(n - (1 << (l - k))),
/* %4 */ "r"(1 << (l - k)), /* %5 */ "r"(p));
: /* %0 */ "r"(n),
/* %1 */ "r"(n / 2),
/* %2 */ "r"(twiddle + k * (n / 2)),
/* %3 */ "r"(p),
"r"(dst),
"r"(dst + (n / 2))
#ifdef DEBUG
, "r"(tmp1), "r"(tmp2), "r"(tmp3)
#endif
);
#ifdef DEBUG
for(int k = 0; k < USERN; k++) {
printf("(%x, %x, %x)\n", tmp1[k], tmp2[k], tmp3[k]);
}
#endif
}
asm("vse32.v v8, 0(%0)\n" : : "r"(dst));
}
// deal with modular
asm("vsetvli zero, %0, e32, m8, tu, mu\n"
"vle32.v v16, 0(%1)\n"
"vrem.vx v8, v16, %2\n"
"vse32.v v8, 0(%1)\n"

:
: "r"(n), "r"(dst), "r"(p));
}
39 changes: 35 additions & 4 deletions tests/eval/_ntt/ntt_1024_main.c

Large diffs are not rendered by default.

28 changes: 25 additions & 3 deletions tests/eval/_ntt/ntt_128_main.c
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

#include <stdio.h>

void ntt(const int *array, int l, const int *twindle, int p, int *dst);
void ntt(const int *array, int l, const int *twiddle, int p, int *dst);

void test() {
const int l = 7;
Expand All @@ -21,10 +21,32 @@ void test() {
9032, 9131, 11715, 6662, 3423, 10027, 5436, 4259, 999, 3316,
11164, 5597, 6578, 800, 8242, 6952, 2288, 1481, 6770, 11948,
8938, 10813, 11107, 1362, 4510, 9388, 8840, 10557};
const int twindle[7] = {12149, 7311, 5860, 4134, 8246, 1479, 12288};
// const int twiddle[7] = {12149, 7311, 5860, 4134, 8246, 1479, 12288};
const int twiddle[] = {
// layer #0
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,

// layer #1
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479,

// layer #2
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146,

// layer #3
1, 1, 1, 1, 1, 1, 1, 1, 4134, 4134, 4134, 4134, 4134, 4134, 4134, 4134, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 11567, 11567, 11567, 11567, 11567, 11567, 11567, 11567, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 6553, 6553, 6553, 6553, 6553, 6553, 6553, 6553, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 1305, 1305, 1305, 1305, 1305, 1305, 1305, 1305,

// layer #4
1, 1, 1, 1, 5860, 5860, 5860, 5860, 4134, 4134, 4134, 4134, 3621, 3621, 3621, 3621, 8246, 8246, 8246, 8246, 1212, 1212, 1212, 1212, 11567, 11567, 11567, 11567, 8785, 8785, 8785, 8785, 1479, 1479, 1479, 1479, 3195, 3195, 3195, 3195, 6553, 6553, 6553, 6553, 9744, 9744, 9744, 9744, 5146, 5146, 5146, 5146, 10643, 10643, 10643, 10643, 1305, 1305, 1305, 1305, 3542, 3542, 3542, 3542,

// layer #5
1, 1, 7311, 7311, 5860, 5860, 3006, 3006, 4134, 4134, 5023, 5023, 3621, 3621, 2625, 2625, 8246, 8246, 8961, 8961, 1212, 1212, 563, 563, 11567, 11567, 5728, 5728, 8785, 8785, 4821, 4821, 1479, 1479, 10938, 10938, 3195, 3195, 9545, 9545, 6553, 6553, 6461, 6461, 9744, 9744, 11340, 11340, 5146, 5146, 5777, 5777, 10643, 10643, 9314, 9314, 1305, 1305, 4591, 4591, 3542, 3542, 2639, 2639,

// layer #6
1, 12149, 7311, 8736, 5860, 2963, 3006, 9275, 4134, 11112, 5023, 9542, 3621, 9198, 2625, 1170, 8246, 726, 8961, 11227, 1212, 2366, 563, 7203, 11567, 2768, 5728, 9154, 8785, 11289, 4821, 955, 1479, 1853, 10938, 4805, 3195, 7393, 9545, 3201, 6553, 4255, 6461, 4846, 9744, 12208, 11340, 9970, 5146, 4611, 5777, 2294, 10643, 9238, 9314, 10963, 1305, 1635, 4591, 8577, 3542, 7969, 2639, 11499,
};
const int p = 12289;
int dst[128];
ntt(arr, l, twindle, p, dst);
ntt(arr, l, twiddle, p, dst);

// for (int i = 0; i < n; i++) {
// printf("%d", dst[i]);
Expand Down
31 changes: 28 additions & 3 deletions tests/eval/_ntt/ntt_256_main.c
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

#include <stdio.h>

void ntt(const int *array, int l, const int *twindle, int p, int *dst);
void ntt(const int *array, int l, const int *twiddle, int p, int *dst);

void test() {
const int l = 8;
Expand Down Expand Up @@ -34,10 +34,35 @@ void test() {
6270, 4938, 6206, 1003, 596, 11173, 9858, 4825, 7940, 794,
7477, 10146, 7203, 4729, 5741, 4603, 1806, 7034, 8772, 10435,
10777, 1359, 630, 11059, 8005, 225};
const int twindle[8] = {8340, 12149, 7311, 5860, 4134, 8246, 1479, 12288};
// const int twiddle[8] = {8340, 12149, 7311, 5860, 4134, 8246, 1479, 12288};
const int twiddle[] = {
// layer #0
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,

// layer #1
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479,

// layer #2
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146,

// layer #3
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4134, 4134, 4134, 4134, 4134, 4134, 4134, 4134, 4134, 4134, 4134, 4134, 4134, 4134, 4134, 4134, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 11567, 11567, 11567, 11567, 11567, 11567, 11567, 11567, 11567, 11567, 11567, 11567, 11567, 11567, 11567, 11567, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 6553, 6553, 6553, 6553, 6553, 6553, 6553, 6553, 6553, 6553, 6553, 6553, 6553, 6553, 6553, 6553, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 1305, 1305, 1305, 1305, 1305, 1305, 1305, 1305, 1305, 1305, 1305, 1305, 1305, 1305, 1305, 1305,

// layer #4
1, 1, 1, 1, 1, 1, 1, 1, 5860, 5860, 5860, 5860, 5860, 5860, 5860, 5860, 4134, 4134, 4134, 4134, 4134, 4134, 4134, 4134, 3621, 3621, 3621, 3621, 3621, 3621, 3621, 3621, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 8246, 1212, 1212, 1212, 1212, 1212, 1212, 1212, 1212, 11567, 11567, 11567, 11567, 11567, 11567, 11567, 11567, 8785, 8785, 8785, 8785, 8785, 8785, 8785, 8785, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 1479, 3195, 3195, 3195, 3195, 3195, 3195, 3195, 3195, 6553, 6553, 6553, 6553, 6553, 6553, 6553, 6553, 9744, 9744, 9744, 9744, 9744, 9744, 9744, 9744, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 5146, 10643, 10643, 10643, 10643, 10643, 10643, 10643, 10643, 1305, 1305, 1305, 1305, 1305, 1305, 1305, 1305, 3542, 3542, 3542, 3542, 3542, 3542, 3542, 3542,

// layer #5
1, 1, 1, 1, 7311, 7311, 7311, 7311, 5860, 5860, 5860, 5860, 3006, 3006, 3006, 3006, 4134, 4134, 4134, 4134, 5023, 5023, 5023, 5023, 3621, 3621, 3621, 3621, 2625, 2625, 2625, 2625, 8246, 8246, 8246, 8246, 8961, 8961, 8961, 8961, 1212, 1212, 1212, 1212, 563, 563, 563, 563, 11567, 11567, 11567, 11567, 5728, 5728, 5728, 5728, 8785, 8785, 8785, 8785, 4821, 4821, 4821, 4821, 1479, 1479, 1479, 1479, 10938, 10938, 10938, 10938, 3195, 3195, 3195, 3195, 9545, 9545, 9545, 9545, 6553, 6553, 6553, 6553, 6461, 6461, 6461, 6461, 9744, 9744, 9744, 9744, 11340, 11340, 11340, 11340, 5146, 5146, 5146, 5146, 5777, 5777, 5777, 5777, 10643, 10643, 10643, 10643, 9314, 9314, 9314, 9314, 1305, 1305, 1305, 1305, 4591, 4591, 4591, 4591, 3542, 3542, 3542, 3542, 2639, 2639, 2639, 2639,

// layer #6
1, 1, 12149, 12149, 7311, 7311, 8736, 8736, 5860, 5860, 2963, 2963, 3006, 3006, 9275, 9275, 4134, 4134, 11112, 11112, 5023, 5023, 9542, 9542, 3621, 3621, 9198, 9198, 2625, 2625, 1170, 1170, 8246, 8246, 726, 726, 8961, 8961, 11227, 11227, 1212, 1212, 2366, 2366, 563, 563, 7203, 7203, 11567, 11567, 2768, 2768, 5728, 5728, 9154, 9154, 8785, 8785, 11289, 11289, 4821, 4821, 955, 955, 1479, 1479, 1853, 1853, 10938, 10938, 4805, 4805, 3195, 3195, 7393, 7393, 9545, 9545, 3201, 3201, 6553, 6553, 4255, 4255, 6461, 6461, 4846, 4846, 9744, 9744, 12208, 12208, 11340, 11340, 9970, 9970, 5146, 5146, 4611, 4611, 5777, 5777, 2294, 2294, 10643, 10643, 9238, 9238, 9314, 9314, 10963, 10963, 1305, 1305, 1635, 1635, 4591, 4591, 8577, 8577, 3542, 3542, 7969, 7969, 2639, 2639, 11499, 11499,

// layer #7
1, 8340, 12149, 12144, 7311, 8011, 8736, 9048, 5860, 11336, 2963, 10530, 3006, 480, 9275, 6534, 4134, 6915, 11112, 2731, 5023, 10908, 9542, 9005, 3621, 5067, 9198, 3382, 2625, 5791, 1170, 334, 8246, 2396, 726, 8652, 8961, 5331, 11227, 3289, 1212, 6522, 2366, 8595, 563, 1022, 7203, 4388, 11567, 130, 2768, 6378, 5728, 4177, 9154, 5092, 8785, 12171, 11289, 4231, 4821, 9821, 955, 1428, 1479, 8993, 1853, 6747, 10938, 1673, 4805, 11560, 3195, 3748, 7393, 3707, 9545, 9447, 3201, 4632, 6553, 2837, 4255, 8357, 6461, 9764, 4846, 9408, 9744, 10092, 12208, 355, 11340, 11745, 9970, 2426, 5146, 4452, 4611, 3459, 5777, 7300, 2294, 10276, 10643, 11462, 9238, 5179, 9314, 12280, 10963, 1260, 1305, 7935, 1635, 7399, 4591, 8705, 8577, 10200, 3542, 9813, 7969, 2548, 2639, 11950, 11499, 10593,
};
const int p = 12289;
int dst[256];
ntt(arr, l, twindle, p, dst);
ntt(arr, l, twiddle, p, dst);

// for (int i = 0; i < n; i++) {
// printf("%d", dst[i]);
Expand Down
Loading

0 comments on commit ec86041

Please sign in to comment.