Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Swinitz/improved ntt #629

Merged
merged 11 commits into from
Oct 20, 2024
Merged

Swinitz/improved ntt #629

merged 11 commits into from
Oct 20, 2024

Conversation

ShanieWinitz
Copy link
Contributor

improved ntt

@@ -302,7 +302,7 @@ TYPED_TEST(FieldApiTest, ntt)
int seed = time(0);
srand(seed);
const bool inplace = rand() % 2;
const int logn = rand() % 16 + 3;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we support 2,4 and 8?

uint64_t bit_reverse(uint64_t i, uint32_t logn)
{
uint32_t rev = 0;
for (uint32_t j = 0; j < logn; ++j) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe the following code will be more efficient
rev_idx = ((rev_idx >> 1) & 0x5555555555555555) | ((rev_idx & 0x5555555555555555) << 1); // bit rev single bits
rev_idx = ((rev_idx >> 2) & 0x3333333333333333) | ((rev_idx & 0x3333333333333333) << 2); // bit rev 2 bits chunk
rev_idx = ((rev_idx >> 4) & 0x0F0F0F0F0F0F0F0F) | ((rev_idx & 0x0F0F0F0F0F0F0F0F) << 4); // bit rev 4 bits chunk
rev_idx = ((rev_idx >> 8) & 0x00FF00FF00FF00FF) | ((rev_idx & 0x00FF00FF00FF00FF) << 8); // bit rev 8 bits chunk
rev_idx =
((rev_idx >> 16) & 0x0000FFFF0000FFFF) | ((rev_idx & 0x0000FFFF0000FFFF) << 16); // bit rev 16 bits chunk
rev_idx = (rev_idx >> 32) | (rev_idx << 32); // bit rev 32 bits chunk
rev_idx = rev_idx >> (64 - m_bit_size); // Align rev_idx to the LSB


// Function to decrement the counter for a given task and check if it is ready to execute. if so, return true
bool decrement_counter(NttTaskCoordinates ntt_task_coordinates);
uint32_t get_dependent_subntt_count(uint32_t hierarchy_0_layer_idx)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add inline

{
return dependent_subntt_count[hierarchy_0_layer_idx];
}
uint32_t get_nof_hierarchy_0_layers() { return nof_hierarchy_0_layers; }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

inline

* @return `true` if the buffer is full, `false` otherwise.
*/
template <typename S, typename E>
bool NttTasksManager<S, E>::is_full() const
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

inline

* @return `true` if the buffer is empty and there are no pending tasks, `false` otherwise.
*/
template <typename S, typename E>
bool NttTasksManager<S, E>::is_empty() const
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

inline

* @param index Reference to the index to be incremented.
*/
template <typename S, typename E>
void NttTasksManager<S, E>::increment(size_t& index)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

inline

*
* Moves the given index to the previous position in the circular buffer, wrapping around if necessary.
*
* @param index Reference to the index to be decremented.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

inline

template <typename S, typename E>
eIcicleError NttTasksManager<S, E>::push_task(const NttTaskCoordinates& ntt_task_coordinates)
{
if (is_full()) { return eIcicleError::OUT_OF_MEMORY; }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are we suppose to reach here? if we do is it a bug in user code or your code?

template <typename S, typename E>
NttTaskCoordinates* NttTasksManager<S, E>::get_slot_for_next_task_coordinates()
{
if (is_full()) { return nullptr; }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same question

* @return `true` if there are tasks to do, `false` otherwise.
*/
template <typename S, typename E>
bool NttTasksManager<S, E>::tasks_to_do() const
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

inline

* @return `true` if there are available tasks, `false` otherwise.
*/
template <typename S, typename E>
bool NttTasksManager<S, E>::available_tasks() const
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

inline

bool NttTasksManager<S, E>::handle_completed(NttTask<S, E>* completed_task, uint32_t nof_subntts_l1)
{
bool task_dispatched = false;
NttTaskCoordinates task_c = *completed_task->get_coordinates();
Copy link
Contributor

@mickeyasa mickeyasa Oct 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you avoid copying the coordinates multiple times?
can't you use the corrdinates directly from the task?

if (counters[task_c.hierarchy_1_layer_idx].decrement_counter(task_c)) {
if (task_c.hierarchy_0_layer_idx < nof_hierarchy_0_layers - 1) {
NttTaskCoordinates* next_task_c_ptr = nullptr;
uint32_t nof_new_ready_tasks =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const

(task_c.hierarchy_0_layer_idx == nof_hierarchy_0_layers - 1)
? 1
: counters[task_c.hierarchy_1_layer_idx].get_dependent_subntt_count(task_c.hierarchy_0_layer_idx + 1);
uint32_t stride = nof_subntts_l1 / nof_new_ready_tasks;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const

}
} else {
// Reorder the output
NttTaskCoordinates* next_task_c_ptr = nullptr;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need to init to null

// Update dependencies in counters
if (counters[task_c.hierarchy_1_layer_idx].decrement_counter(task_c)) {
if (task_c.hierarchy_0_layer_idx < nof_hierarchy_0_layers - 1) {
NttTaskCoordinates* next_task_c_ptr = nullptr;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need to reset to null

@@ -97,21 +116,540 @@ namespace ntt_cpu {
s_ntt_domain.coset_index[temp_twiddles[i]] = i;
}
s_ntt_domain.twiddles = std::move(temp_twiddles); // Assign twiddles using unique_ptr

// Winograd 8
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please write a function per init winograd X

* @method NttSubLogn(uint32_t logn) Initializes the struct based on the given `logn`.
*/
struct NttSubLogn {
uint32_t logn; // Original log_size of the problem
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const

*/
struct NttSubLogn {
uint32_t logn; // Original log_size of the problem
uint64_t size; // Original size of the problem
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const

NttTask() : ntt_data(nullptr) {}

void execute();
NttTaskCoordinates* get_coordinates() const
Copy link
Contributor

@mickeyasa mickeyasa Oct 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

inline all short functions

hierarchy_0_cpu_ntt();
} else {
// if all hierarchy_0_subntts are done, and at least 2 layers in hierarchy 0 - reorder the subntt's output
reorder_and_refactor_if_needed();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can't we do that on the same task?
i.e. reorder and then continue with the same task to execute the ntt.

}
}
}
std::copy(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't it inplace?

template <typename S, typename E>
eIcicleError NttTask<S, E>::reorder_and_refactor_if_needed()
{
uint32_t columns_batch_reps = ntt_data->config.columns_batch ? ntt_data->config.batch_size : 1;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we use the algorithm from vec ops?

std::vector<uint32_t> index_in_mem(8);
uint32_t stride = ntt_data->config.columns_batch ? ntt_data->config.batch_size : 1;
for (uint32_t i = 0; i < 8; i++) {
index_in_mem[i] = stride * idx_in_mem(ntt_task_coordinates, i);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we really need to call the idx_in_mem every time?
isn't it fixed?

apply_coset_multiplication(current_elements, index_in_mem, CpuNttDomain<S>::s_ntt_domain.get_twiddles());
}

T = current_elements[index_in_mem[3]] - current_elements[index_in_mem[7]];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure if we tryed this but, can idx_in_mem be an array of pointers(E*) pointing to the element in current_elements?

? CpuNttDomain<S>::s_ntt_domain.get_winograd8_twiddles()
: CpuNttDomain<S>::s_ntt_domain.get_winograd8_twiddles_inv();

E T;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

by T, you mean tmp?

current_elements[index_in_mem[1]] = current_elements[index_in_mem[3]] + T;
current_elements[index_in_mem[3]] = current_elements[index_in_mem[3]] - T;
T = current_elements[index_in_mem[5]] + current_elements[index_in_mem[7]];
current_elements[index_in_mem[5]] = current_elements[index_in_mem[5]] - current_elements[index_in_mem[7]];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't we have a -= operator thus avoiding the move?

apply_coset_multiplication(current_elements, index_in_mem, CpuNttDomain<S>::s_ntt_domain.get_twiddles());
}

/* Stage s00 */
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not using a for loop?


/* Stage s16 */

current_elements[index_in_mem[0]] = temp_1[0] + temp_1[16];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this function can be done inplace avoiding all the move operations

? hierarchy_1_subntt_elements + batch
: hierarchy_1_subntt_elements + batch * original_size;
for (uint32_t elem = 0; elem < hierarchy_0_subntt_size; elem++) {
uint64_t elem_mem_idx = stride * idx_in_mem(ntt_task_coordinates, elem);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need to call idx_in_mem every time

for (uint64_t i = 0; i < subntt_size; ++i) {
// rev = NttUtils<S, E>::bit_reverse(i, subntt_log_size);
rev = bit_reverse(i, subntt_log_size);
i_mem_idx = idx_in_mem(ntt_task_coordinates, i);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a more efficient way to run without idx_in_mem called every time

uint32_t step = (subntt_size / len) * (CpuNttDomain<S>::s_ntt_domain.get_max_size() >> subntt_size_log);
for (uint32_t i = 0; i < subntt_size; i += len) {
for (uint32_t j = 0; j < half_len; ++j) {
uint64_t u_mem_idx = stride * idx_in_mem(ntt_task_coordinates, i + j);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

idx_in_mem is very expensive - can we avoid it?

Copy link
Contributor

@mickeyasa mickeyasa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please fix the performance comment on the next version

@mickeyasa mickeyasa merged commit f93b5eb into main Oct 20, 2024
28 checks passed
@mickeyasa mickeyasa deleted the swinitz/improved_ntt branch October 20, 2024 09:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants