Enable sequence TBE CPU via AVX (pytorch#2195)
Summary:

Instead of using the reference implementation for sequence embedding on CPU,
this diff directs TBE to invoke the AVX implementation by forcing
pooling factors of 1 (i.e., passing `at::arange(index_size)` as
offsets). The performance gained from using the AVX implementation
outweighs the overhead incurred in creating the new offsets.
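
A sequence (nobag) lookup is equivalent to a pooled lookup in which every
bag holds exactly one index, which is why `at::arange` works as the offsets.
A minimal PyTorch sketch of this equivalence (illustration only, not the
FBGEMM kernel; the tensors are made-up toy data):

import torch

weights = torch.randn(10, 4)          # 10 embedding rows, dimension 4
indices = torch.tensor([3, 7, 3, 1])  # sequence lookup: one output row per index

# Sequence (nobag) output: gather one embedding row per index.
seq_out = weights[indices]

# Same result from a pooled kernel with pooling factor 1: arange() starts
# a new bag at every position, so each bag holds a single index and "sum"
# pooling reduces to a plain gather.
offsets = torch.arange(indices.numel())  # [0, 1, 2, 3]
bag_out = torch.nn.functional.embedding_bag(indices, weights, offsets, mode="sum")

assert torch.equal(seq_out, bag_out)

The C++ template below does the same thing in the nobag path: it builds
virtual offsets with `at::arange(0, index_size, offsets.options())` and
passes them (`offsets_nobag_ptr`) to the AVX bag kernels.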

Differential Revision: D51918878
Sarunya Pumma authored and facebook-github-bot committed Dec 7, 2023
1 parent a75b43f commit c2f64ec
Showing 2 changed files with 38 additions and 83 deletions.
114 changes: 35 additions & 79 deletions fbgemm_gpu/codegen/embedding_forward_quantized_cpu_template.cpp
@@ -200,6 +200,13 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
const float* indice_weights_acc = indice_weights.data_ptr<float>();
{% endif %}

using float16 = uint16_t;
using bfloat16 = uint16_t;
using fbgemm_out_t = typename std::conditional<
std::is_same<output_t, at::Half>::value,
float16,
std::conditional<std::is_same<output_t, at::BFloat16>::value, bfloat16, float>::type >::type;

AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_", [&] {
const auto* indices_acc = indices.data_ptr<index_t>();
const auto* offsets_acc = offsets.data_ptr<index_t>();
@@ -208,10 +215,9 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{

auto* output_acc = output.data_ptr<output_t>();
int32_t num_indices_m_1 = indices.numel() - 1;

int32_t D_start_ = 0;
for (const auto t : c10::irange(T)) {

for (const auto t : c10::irange(T)) {
{% if not nobag %}
const auto* D_offsets_acc = D_offsets.data_ptr<int32_t>();
const int32_t D_start = D_offsets_acc[t];
@@ -226,27 +232,27 @@ for (const auto t : c10::irange(T)) {
const auto& weight_tensor = (placement == PlacementType::HOST) ? dev_weights : uvm_weights;
weights_acc = weight_tensor.data_ptr<uint8_t>();
const uint8_t* weights = &weights_acc[weights_offsets_acc[t]];
auto weight_ty = static_cast<SparseType>(weights_tys_acc[t]);
const auto weight_ty = static_cast<SparseType>(weights_tys_acc[t]);
// default to 1 byte alignment for CPU TBE
const int32_t D_bytes = nbit::padded_row_size_in_bytes(D, weight_ty, row_alignment);

int tt;
for (tt = t + 1; tt < T && weights_offsets_acc[tt] == weights_offsets_acc[t]; ++tt);
size_t num_rows = ((tt == T ? weight_tensor.numel() : weights_offsets_acc[tt]) - weights_offsets_acc[t]) / D_bytes;
const size_t num_rows = ((tt == T ? weight_tensor.numel() : weights_offsets_acc[tt]) - weights_offsets_acc[t]) / D_bytes;
const index_t* offsets_begin_ptr = offsets_acc + t * B;

using float16 = uint16_t;
using bfloat16 = uint16_t;
using fbgemm_out_t = typename std::conditional<
std::is_same<output_t, at::Half>::value,
float16,
std::conditional<std::is_same<output_t, at::BFloat16>::value, bfloat16, float>::type >::type;

bool success = true;
bool has_weight = {{ "true" if weighted else "false" }};
bool normalize_by_lengths = static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN;
const bool has_weight = {{ "true" if weighted else "false" }};
const bool normalize_by_lengths = static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN;

const index_t index_size = offsets_acc[(t + 1) * B] - *offsets_begin_ptr;

{% if nobag %}
// Create virtual offsets for the nobag case. Lengths are all ones.
const auto offsets_nobag = at::arange(0, index_size, offsets.options());
const index_t* offsets_nobag_ptr = offsets_nobag.data_ptr<index_t>();
{% endif %}

index_t index_size = offsets_acc[(t + 1) * B] - *offsets_begin_ptr;
const float* indice_weights_ptr = nullptr;
{% if weighted %}
indice_weights_ptr = indice_weights_acc + *offsets_begin_ptr;
@@ -259,32 +265,18 @@ for (const auto t : c10::irange(T)) {
/*prefetch=*/16,
/*is_weight_positional=*/false,
/*use_offsets=*/true,
{% if not nobag %}
/*output_stride=*/total_D,
{% else %}
/*output_stride=*/D,
{% endif %}
/*output_stride=*/{{ "total_D" if not nobag else "D" }},
/*input_stride=*/D_bytes / sizeof(float),
{% if not nobag %}
/*scale_bias_last=*/false,
/*no_bag=*/false,
/*is_bf16_out=*/output_is_bf16);
{% else %}
/*scale_bias_last=*/false,
/*no_bag=*/true,
/*is_bf16_out=*/output_is_bf16);
{% endif %}
success = kernel(
{% if not nobag %}
B,
{% else %}
index_size,
{% endif %}
{{ "B" if not nobag else "index_size"}},
index_size,
num_rows,
reinterpret_cast<const float*>(weights),
indices_acc + *offsets_begin_ptr,
offsets_begin_ptr,
{{ "offsets_begin_ptr" if not nobag else "offsets_nobag_ptr" }},
indice_weights_ptr,
reinterpret_cast<fbgemm_out_t*>(output_acc + D_start));
} else if (weight_ty == SparseType::FP16) {
@@ -295,32 +287,18 @@ for (const auto t : c10::irange(T)) {
/*prefetch=*/16,
/*is_weight_positional=*/false,
/*use_offsets=*/true,
{% if not nobag %}
/*output_stride=*/total_D,
{% else %}
/*output_stride=*/D,
{% endif %}
/*output_stride=*/{{ "total_D" if not nobag else "D" }},
/*input_stride=*/D_bytes / sizeof(float16),
{% if not nobag %}
/*scale_bias_last=*/false,
/*no_bag=*/false,
/*is_bf16_out=*/output_is_bf16);
{% else %}
/*scale_bias_last=*/false,
/*no_bag=*/true,
/*is_bf16_out=*/output_is_bf16);
{% endif %}
success = kernel(
{% if not nobag %}
B,
{% else %}
index_size,
{% endif %}
{{ "B" if not nobag else "index_size"}},
index_size,
num_rows,
reinterpret_cast<const float16*>(weights),
indices_acc + *offsets_begin_ptr,
offsets_begin_ptr,
{{ "offsets_begin_ptr" if not nobag else "offsets_nobag_ptr" }},
indice_weights_ptr,
reinterpret_cast<fbgemm_out_t*>(output_acc + D_start));
} else if (weight_ty == SparseType::FP8) {
@@ -330,22 +308,18 @@ for (const auto t : c10::irange(T)) {
normalize_by_lengths,
/*is_weight_positional=*/false,
/*use_offsets=*/true,
{% if not nobag %}
/*output_stride=*/total_D,
{% else %}
/*output_stride=*/D,
{% endif %}
/*output_stride=*/{{ "total_D" if not nobag else "D" }},
/*input_stride=*/D_bytes / sizeof(uint8_t),
/*exponent_bits=*/fp8_exponent_bits,
/*exponent_bias=*/fp8_exponent_bias,
/*is_bf16_out=*/output_is_bf16);
success = kernel(
B,
{{ "B" if not nobag else "index_size"}},
index_size,
num_rows,
weights,
indices_acc + *offsets_begin_ptr,
offsets_begin_ptr,
{{ "offsets_begin_ptr" if not nobag else "offsets_nobag_ptr" }},
indice_weights_ptr,
reinterpret_cast<fbgemm_out_t*>(output_acc + D_start));
} else if (weight_ty == SparseType::INT8) {
@@ -356,32 +330,18 @@ for (const auto t : c10::irange(T)) {
/*prefetch=*/16,
/*is_weight_positional=*/false,
/*use_offsets=*/true,
{% if not nobag %}
/*output_stride=*/total_D,
{% else %}
/*output_stride=*/D,
{% endif %}
/*output_stride=*/{{ "total_D" if not nobag else "D" }},
/*input_stride=*/D_bytes / sizeof(uint8_t),
{% if not nobag %}
/*scale_bias_last=*/false,
/*no_bag=*/false,
/*is_bf16_out=*/output_is_bf16);
{% else %}
/*scale_bias_last=*/false,
/*no_bag=*/true,
/*is_bf16_out=*/output_is_bf16);
{% endif %}
success = kernel(
{% if not nobag %}
B,
{% else %}
index_size,
{% endif %}
{{ "B" if not nobag else "index_size"}},
index_size,
num_rows,
weights,
indices_acc + *offsets_begin_ptr,
offsets_begin_ptr,
{{ "offsets_begin_ptr" if not nobag else "offsets_nobag_ptr" }},
indice_weights_ptr,
reinterpret_cast<fbgemm_out_t*>(output_acc + D_start));
} else if (weight_ty == SparseType::INT4 || weight_ty == SparseType::INT2) {
@@ -404,21 +364,17 @@ for (const auto t : c10::irange(T)) {
/*prefetch=*/16,
/*is_weight_positional=*/false,
/*use_offsets=*/true,
{% if not nobag %}
/*output_stride=*/total_D,
{% else %}
/*output_stride=*/D,
{% endif %}
/*output_stride=*/{{ "total_D" if not nobag else "D" }},
/*input_stride=*/D_bytes / sizeof(uint8_t),
/*scale_bias_last=*/false,
/*is_bf16_out=*/output_is_bf16);
success = kernel(
B,
{{ "B" if not nobag else "index_size"}},
index_size,
num_rows,
weights,
indices_acc + *offsets_begin_ptr,
offsets_begin_ptr,
{{ "offsets_begin_ptr" if not nobag else "offsets_nobag_ptr" }},
indice_weights_ptr,
reinterpret_cast<fbgemm_out_t*>(output_acc + D_start));
} else {
7 changes: 3 additions & 4 deletions fbgemm_gpu/test/split_table_batched_embeddings_test.py
@@ -4562,11 +4562,10 @@ def test_nbit_forward_cpu(
[
SparseType.FP32,
SparseType.FP16,
# CPU sequence embedding does not support FP8/INT4/INT2 yet
# SparseType.FP8,
SparseType.FP8,
SparseType.INT8,
# SparseType.INT4,
# SparseType.INT2,
SparseType.INT4,
SparseType.INT2,
]
)

