-
Notifications
You must be signed in to change notification settings - Fork 123
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Swinitz/improved ntt #629
Swinitz/improved ntt #629
Conversation
…fix inverse && coset
@@ -302,7 +302,7 @@ TYPED_TEST(FieldApiTest, ntt) | |||
int seed = time(0); | |||
srand(seed); | |||
const bool inplace = rand() % 2; | |||
const int logn = rand() % 16 + 3; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we support 2,4 and 8?
uint64_t bit_reverse(uint64_t i, uint32_t logn) | ||
{ | ||
uint32_t rev = 0; | ||
for (uint32_t j = 0; j < logn; ++j) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe the following code will be more efficient
rev_idx = ((rev_idx >> 1) & 0x5555555555555555) | ((rev_idx & 0x5555555555555555) << 1); // bit rev single bits
rev_idx = ((rev_idx >> 2) & 0x3333333333333333) | ((rev_idx & 0x3333333333333333) << 2); // bit rev 2 bits chunk
rev_idx = ((rev_idx >> 4) & 0x0F0F0F0F0F0F0F0F) | ((rev_idx & 0x0F0F0F0F0F0F0F0F) << 4); // bit rev 4 bits chunk
rev_idx = ((rev_idx >> 8) & 0x00FF00FF00FF00FF) | ((rev_idx & 0x00FF00FF00FF00FF) << 8); // bit rev 8 bits chunk
rev_idx =
((rev_idx >> 16) & 0x0000FFFF0000FFFF) | ((rev_idx & 0x0000FFFF0000FFFF) << 16); // bit rev 16 bits chunk
rev_idx = (rev_idx >> 32) | (rev_idx << 32); // bit rev 32 bits chunk
rev_idx = rev_idx >> (64 - m_bit_size); // Align rev_idx to the LSB
|
||
// Function to decrement the counter for a given task and check if it is ready to execute. if so, return true | ||
bool decrement_counter(NttTaskCoordinates ntt_task_coordinates); | ||
uint32_t get_dependent_subntt_count(uint32_t hierarchy_0_layer_idx) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add inline
{ | ||
return dependent_subntt_count[hierarchy_0_layer_idx]; | ||
} | ||
uint32_t get_nof_hierarchy_0_layers() { return nof_hierarchy_0_layers; } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
inline
* @return `true` if the buffer is full, `false` otherwise. | ||
*/ | ||
template <typename S, typename E> | ||
bool NttTasksManager<S, E>::is_full() const |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
inline
* @return `true` if the buffer is empty and there are no pending tasks, `false` otherwise. | ||
*/ | ||
template <typename S, typename E> | ||
bool NttTasksManager<S, E>::is_empty() const |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
inline
* @param index Reference to the index to be incremented. | ||
*/ | ||
template <typename S, typename E> | ||
void NttTasksManager<S, E>::increment(size_t& index) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
inline
* | ||
* Moves the given index to the previous position in the circular buffer, wrapping around if necessary. | ||
* | ||
* @param index Reference to the index to be decremented. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
inline
template <typename S, typename E> | ||
eIcicleError NttTasksManager<S, E>::push_task(const NttTaskCoordinates& ntt_task_coordinates) | ||
{ | ||
if (is_full()) { return eIcicleError::OUT_OF_MEMORY; } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
are we supposed to reach here? if we do, is it a bug in user code or your code?
template <typename S, typename E> | ||
NttTaskCoordinates* NttTasksManager<S, E>::get_slot_for_next_task_coordinates() | ||
{ | ||
if (is_full()) { return nullptr; } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same question
* @return `true` if there are tasks to do, `false` otherwise. | ||
*/ | ||
template <typename S, typename E> | ||
bool NttTasksManager<S, E>::tasks_to_do() const |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
inline
* @return `true` if there are available tasks, `false` otherwise. | ||
*/ | ||
template <typename S, typename E> | ||
bool NttTasksManager<S, E>::available_tasks() const |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
inline
bool NttTasksManager<S, E>::handle_completed(NttTask<S, E>* completed_task, uint32_t nof_subntts_l1) | ||
{ | ||
bool task_dispatched = false; | ||
NttTaskCoordinates task_c = *completed_task->get_coordinates(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you avoid copying the coordinates multiple times?
can't you use the coordinates directly from the task?
if (counters[task_c.hierarchy_1_layer_idx].decrement_counter(task_c)) { | ||
if (task_c.hierarchy_0_layer_idx < nof_hierarchy_0_layers - 1) { | ||
NttTaskCoordinates* next_task_c_ptr = nullptr; | ||
uint32_t nof_new_ready_tasks = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const
(task_c.hierarchy_0_layer_idx == nof_hierarchy_0_layers - 1) | ||
? 1 | ||
: counters[task_c.hierarchy_1_layer_idx].get_dependent_subntt_count(task_c.hierarchy_0_layer_idx + 1); | ||
uint32_t stride = nof_subntts_l1 / nof_new_ready_tasks; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const
} | ||
} else { | ||
// Reorder the output | ||
NttTaskCoordinates* next_task_c_ptr = nullptr; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no need to init to null
// Update dependencies in counters | ||
if (counters[task_c.hierarchy_1_layer_idx].decrement_counter(task_c)) { | ||
if (task_c.hierarchy_0_layer_idx < nof_hierarchy_0_layers - 1) { | ||
NttTaskCoordinates* next_task_c_ptr = nullptr; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no need to reset to null
@@ -97,21 +116,540 @@ namespace ntt_cpu { | |||
s_ntt_domain.coset_index[temp_twiddles[i]] = i; | |||
} | |||
s_ntt_domain.twiddles = std::move(temp_twiddles); // Assign twiddles using unique_ptr | |||
|
|||
// Winograd 8 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please write a function per init winograd X
* @method NttSubLogn(uint32_t logn) Initializes the struct based on the given `logn`. | ||
*/ | ||
struct NttSubLogn { | ||
uint32_t logn; // Original log_size of the problem |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const
*/ | ||
struct NttSubLogn { | ||
uint32_t logn; // Original log_size of the problem | ||
uint64_t size; // Original size of the problem |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const
NttTask() : ntt_data(nullptr) {} | ||
|
||
void execute(); | ||
NttTaskCoordinates* get_coordinates() const |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
inline all short functions
hierarchy_0_cpu_ntt(); | ||
} else { | ||
// if all hierarchy_0_subntts are done, and at least 2 layers in hierarchy 0 - reorder the subntt's output | ||
reorder_and_refactor_if_needed(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can't we do that on the same task?
i.e. reorder and then continue with the same task to execute the ntt.
} | ||
} | ||
} | ||
std::copy( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
isn't it inplace?
template <typename S, typename E> | ||
eIcicleError NttTask<S, E>::reorder_and_refactor_if_needed() | ||
{ | ||
uint32_t columns_batch_reps = ntt_data->config.columns_batch ? ntt_data->config.batch_size : 1; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we use the algorithm from vec ops?
std::vector<uint32_t> index_in_mem(8); | ||
uint32_t stride = ntt_data->config.columns_batch ? ntt_data->config.batch_size : 1; | ||
for (uint32_t i = 0; i < 8; i++) { | ||
index_in_mem[i] = stride * idx_in_mem(ntt_task_coordinates, i); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we really need to call the idx_in_mem every time?
isn't it fixed?
apply_coset_multiplication(current_elements, index_in_mem, CpuNttDomain<S>::s_ntt_domain.get_twiddles()); | ||
} | ||
|
||
T = current_elements[index_in_mem[3]] - current_elements[index_in_mem[7]]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not sure if we tried this, but can idx_in_mem be an array of pointers (E*) pointing to the elements in current_elements?
? CpuNttDomain<S>::s_ntt_domain.get_winograd8_twiddles() | ||
: CpuNttDomain<S>::s_ntt_domain.get_winograd8_twiddles_inv(); | ||
|
||
E T; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
by T, you mean tmp?
current_elements[index_in_mem[1]] = current_elements[index_in_mem[3]] + T; | ||
current_elements[index_in_mem[3]] = current_elements[index_in_mem[3]] - T; | ||
T = current_elements[index_in_mem[5]] + current_elements[index_in_mem[7]]; | ||
current_elements[index_in_mem[5]] = current_elements[index_in_mem[5]] - current_elements[index_in_mem[7]]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
don't we have a -= operator thus avoiding the move?
apply_coset_multiplication(current_elements, index_in_mem, CpuNttDomain<S>::s_ntt_domain.get_twiddles()); | ||
} | ||
|
||
/* Stage s00 */ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why not use a for loop?
|
||
/* Stage s16 */ | ||
|
||
current_elements[index_in_mem[0]] = temp_1[0] + temp_1[16]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this function can be done inplace avoiding all the move operations
? hierarchy_1_subntt_elements + batch | ||
: hierarchy_1_subntt_elements + batch * original_size; | ||
for (uint32_t elem = 0; elem < hierarchy_0_subntt_size; elem++) { | ||
uint64_t elem_mem_idx = stride * idx_in_mem(ntt_task_coordinates, elem); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we need to call idx_in_mem every time
for (uint64_t i = 0; i < subntt_size; ++i) { | ||
// rev = NttUtils<S, E>::bit_reverse(i, subntt_log_size); | ||
rev = bit_reverse(i, subntt_log_size); | ||
i_mem_idx = idx_in_mem(ntt_task_coordinates, i); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is there a more efficient way to run without idx_in_mem called every time
uint32_t step = (subntt_size / len) * (CpuNttDomain<S>::s_ntt_domain.get_max_size() >> subntt_size_log); | ||
for (uint32_t i = 0; i < subntt_size; i += len) { | ||
for (uint32_t j = 0; j < half_len; ++j) { | ||
uint64_t u_mem_idx = stride * idx_in_mem(ntt_task_coordinates, i + j); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
idx_in_mem is very expensive - can we avoid it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please fix the performance comment on the next version
improved ntt