Skip to content

Commit

Permalink
hotfix for large ecntt (#448)
Browse files Browse the repository at this point in the history
 hotfix for large ECNTTs
  • Loading branch information
vhnatyk authored Mar 27, 2024
1 parent 2c1431d commit ef757e8
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions icicle/appUtils/ntt/ntt.cu
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,16 @@

#include <mutex>

#define IS_ECNTT std::is_same_v<E, curve_config::projective_t>

namespace ntt {

namespace {

const uint32_t MAX_NUM_THREADS = 512; // TODO: hotfix - should be 1024, currently limits shared memory size
const uint32_t MAX_THREADS_BATCH = 512; // TODO: allows 100% occupancy for scalar NTT for sm_86..sm_89
// TODO: Set MAX THREADS based on GPU arch
const uint32_t MAX_NUM_THREADS = 512; // TODO: hotfix - should be 1024, currently limits shared memory size
const uint32_t MAX_THREADS_BATCH = 512;
const uint32_t MAX_THREADS_BATCH_ECNTT =
256; // TODO: hardcodded - allows (2^18 x 64) ECNTT for sm86, decrease this to allow larger batch or ecntt length
const uint32_t MAX_SHARED_MEM_ELEMENT_SIZE = 32; // TODO: occupancy calculator, hardcoded for sm_86..sm_89
const uint32_t MAX_SHARED_MEM = MAX_SHARED_MEM_ELEMENT_SIZE * MAX_NUM_THREADS;

Expand Down Expand Up @@ -291,7 +295,8 @@ namespace ntt {

bool is_shared_mem_enabled = sizeof(E) <= MAX_SHARED_MEM_ELEMENT_SIZE;
const int log2_shmem_elems = is_shared_mem_enabled ? int(log(int(MAX_SHARED_MEM / sizeof(E))) / log(2)) : logn;
int num_threads = max(min(min(n / 2, MAX_THREADS_BATCH), 1 << (log2_shmem_elems - 1)), 1);
int max_threads_batch = IS_ECNTT ? MAX_THREADS_BATCH_ECNTT : MAX_THREADS_BATCH;
int num_threads = max(min(min(n / 2, max_threads_batch), 1 << (log2_shmem_elems - 1)), 1);
const int chunks = max(int((n / 2) / num_threads), 1);
const int total_tasks = batch_size * chunks;
int num_blocks = total_tasks;
Expand Down Expand Up @@ -651,7 +656,7 @@ namespace ntt {

const bool is_inverse = dir == NTTDir::kInverse;

if constexpr (std::is_same_v<E, curve_config::projective_t>) {
if constexpr (IS_ECNTT) {
CHK_IF_RETURN(ntt::radix2_ntt(
d_input, d_output, domain.twiddles, size, domain.max_size, batch_size, is_inverse, config.ordering, coset,
coset_index, stream));
Expand Down

0 comments on commit ef757e8

Please sign in to comment.