Skip to content

Commit

Permalink
parallelize pruned_array_lookup_cpu (#1904)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1904

As titled: parallelize pruned_array_lookup_cpu by wrapping its outer loop over T in at::parallel_for, so that the iterations are distributed across multiple threads.

Differential Revision: D47923801

fbshipit-source-id: ae22deab752de2996476c1b46b077420e8808c15
  • Loading branch information
Feixiong Zhang authored and facebook-github-bot committed Aug 2, 2023
1 parent 410d264 commit 73eb2fb
Showing 1 changed file with 21 additions and 18 deletions.
39 changes: 21 additions & 18 deletions fbgemm_gpu/codegen/embedding_forward_quantized_cpu_template.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include <ATen/ATen.h>
#include <ATen/Context.h>
#include <ATen/Parallel.h>

#include "fbgemm_gpu/cpu_utils.h"
#include "fbgemm_gpu/dispatch_macros.h"
Expand Down Expand Up @@ -524,24 +525,26 @@ Tensor pruned_array_lookup_cpu(

const auto index_remappings_acc = index_remappings.data_ptr<int32_t>();
const auto index_remappings_offsets_acc = index_remappings_offsets.data_ptr<int64_t>();
for (const auto t : c10::irange(T)) {
int64_t index_remappings_start = index_remappings_offsets_acc[t];
int64_t index_remappings_end = index_remappings_offsets_acc[t + 1];
int64_t capacity = index_remappings_end - index_remappings_start;
int32_t indices_start = offsets_acc[t * B];
int32_t indices_end = offsets_acc[(t + 1) * B];
if (capacity > 0) {
for (const auto i : c10::irange(indices_start,indices_end)) {
int32_t idx = indices_acc[i];
dense_indices_acc[i] = index_remappings_acc[index_remappings_start + idx];
}
} else {
std::memcpy(
dense_indices_acc + indices_start,
indices_acc + indices_start,
(indices_end - indices_start) * sizeof(int32_t));
}
}
at::parallel_for(0, T, 1, [&](int64_t begin, int64_t end) {
for (const auto t : c10::irange(begin, end)) {
int64_t index_remappings_start = index_remappings_offsets_acc[t];
int64_t index_remappings_end = index_remappings_offsets_acc[t + 1];
int64_t capacity = index_remappings_end - index_remappings_start;
int32_t indices_start = offsets_acc[t * B];
int32_t indices_end = offsets_acc[(t + 1) * B];
if (capacity > 0) {
for (const auto i : c10::irange(indices_start,indices_end)) {
int32_t idx = indices_acc[i];
dense_indices_acc[i] = index_remappings_acc[index_remappings_start + idx];
}
} else {
std::memcpy(
dense_indices_acc + indices_start,
indices_acc + indices_start,
(indices_end - indices_start) * sizeof(int32_t));
}
}
});
return dense_indices;
}

Expand Down

0 comments on commit 73eb2fb

Please sign in to comment.