diff --git a/icicle/appUtils/ntt/ntt.cuh b/icicle/appUtils/ntt/ntt.cuh index ad090fbe8..ed4bbb736 100644 --- a/icicle/appUtils/ntt/ntt.cuh +++ b/icicle/appUtils/ntt/ntt.cuh @@ -306,7 +306,7 @@ void ntt_inplace_batch_template( const int logn = int(log(n) / log(2)); bool is_shared_mem_enabled = sizeof(E) <= MAX_SHARED_MEM_ELEMENT_SIZE; const int log2_shmem_elems = is_shared_mem_enabled ? int(log(int(MAX_SHARED_MEM / sizeof(E))) / log(2)) : logn; - int num_threads = min(min(n / 2, MAX_THREADS_BATCH), 1 << (log2_shmem_elems - 1)); + int num_threads = max(min(min(n / 2, MAX_THREADS_BATCH), 1 << (log2_shmem_elems - 1)), 1); const int chunks = max(int((n / 2) / num_threads), 1); const int total_tasks = batch_size * chunks; int num_blocks = total_tasks; @@ -328,7 +328,7 @@ void ntt_inplace_batch_template( if (is_coset) batch_vector_mult(coset, d_inout, n, batch_size, stream); - num_threads = min(n / 2, MAX_NUM_THREADS); + num_threads = max(min(n / 2, MAX_NUM_THREADS), 1); num_blocks = (n * batch_size + num_threads - 1) / num_threads; template_normalize_kernel <<>>(d_inout, n * batch_size, S::inv_log_size(logn));