Enable sequence TBE CPU via AVX (#2195)
Summary:

Instead of using the reference implementation for sequence embedding on CPU, this
diff directs TBE to invoke the AVX implementation of pooled TBE by forcing pooling
factors of 1 (i.e., passing `at::arange(index_size + 1)` as offsets). The performance
gained from using the AVX implementation outweighs the overhead incurred in creating
the new offsets.
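
As a rough illustration of the trick (not part of this commit; the helper name below is hypothetical), handing the pooled kernel offsets of `[0, 1, ..., index_size]` makes every bag hold exactly one index, so SUM pooling reproduces the sequence (no-pooling) output row by row:

```cpp
// Illustrative sketch only, not the committed code.
#include <ATen/ATen.h>

// Build virtual offsets [0, 1, ..., index_size] so each "bag" contains exactly
// one index; a pooled SUM kernel then emits one embedding row per index.
at::Tensor make_virtual_offsets(const at::Tensor& indices) {
  const auto index_size = indices.numel();
  return at::arange(index_size + 1, indices.options());
}
```

In the template change below, the equivalent per-table offsets are built with `at::arange(*offsets_begin_ptr, offsets_acc[(t + 1) * B] + 1, offsets.options())` and sanity-checked before the kernel call.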

Reviewed By: jspark1105

Differential Revision: D51918878
Sarunya Pumma authored and facebook-github-bot committed Dec 7, 2023
1 parent f8def44 commit c11f5c7
Showing 3 changed files with 45 additions and 181 deletions.
116 changes: 37 additions & 79 deletions fbgemm_gpu/codegen/embedding_forward_quantized_cpu_template.cpp
@@ -200,6 +200,13 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
const float* indice_weights_acc = indice_weights.data_ptr<float>();
{% endif %}

using float16 = uint16_t;
using bfloat16 = uint16_t;
using fbgemm_out_t = typename std::conditional<
std::is_same<output_t, at::Half>::value,
float16,
std::conditional<std::is_same<output_t, at::BFloat16>::value, bfloat16, float>::type >::type;

AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_", [&] {
const auto* indices_acc = indices.data_ptr<index_t>();
const auto* offsets_acc = offsets.data_ptr<index_t>();
@@ -208,10 +215,9 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{

auto* output_acc = output.data_ptr<output_t>();
int32_t num_indices_m_1 = indices.numel() - 1;

int32_t D_start_ = 0;
for (const auto t : c10::irange(T)) {

for (const auto t : c10::irange(T)) {
{% if not nobag %}
const auto* D_offsets_acc = D_offsets.data_ptr<int32_t>();
const int32_t D_start = D_offsets_acc[t];
@@ -226,27 +232,29 @@ for (const auto t : c10::irange(T)) {
const auto& weight_tensor = (placement == PlacementType::HOST) ? dev_weights : uvm_weights;
weights_acc = weight_tensor.data_ptr<uint8_t>();
const uint8_t* weights = &weights_acc[weights_offsets_acc[t]];
auto weight_ty = static_cast<SparseType>(weights_tys_acc[t]);
const auto weight_ty = static_cast<SparseType>(weights_tys_acc[t]);
// default to 1 byte alignment for CPU TBE
const int32_t D_bytes = nbit::padded_row_size_in_bytes(D, weight_ty, row_alignment);

int tt;
for (tt = t + 1; tt < T && weights_offsets_acc[tt] == weights_offsets_acc[t]; ++tt);
size_t num_rows = ((tt == T ? weight_tensor.numel() : weights_offsets_acc[tt]) - weights_offsets_acc[t]) / D_bytes;
const size_t num_rows = ((tt == T ? weight_tensor.numel() : weights_offsets_acc[tt]) - weights_offsets_acc[t]) / D_bytes;
const index_t* offsets_begin_ptr = offsets_acc + t * B;

using float16 = uint16_t;
using bfloat16 = uint16_t;
using fbgemm_out_t = typename std::conditional<
std::is_same<output_t, at::Half>::value,
float16,
std::conditional<std::is_same<output_t, at::BFloat16>::value, bfloat16, float>::type >::type;

bool success = true;
bool has_weight = {{ "true" if weighted else "false" }};
bool normalize_by_lengths = static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN;
const bool has_weight = {{ "true" if weighted else "false" }};
const bool normalize_by_lengths = static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN;

const index_t index_size = offsets_acc[(t + 1) * B] - *offsets_begin_ptr;

{% if nobag %}
// Create virtual offsets for the nobag case. Lengths are all ones.
const auto offsets_nobag = at::arange(*offsets_begin_ptr, offsets_acc[(t + 1) * B] + 1, offsets.options());
const index_t* offsets_nobag_ptr = offsets_nobag.data_ptr<index_t>();
TORCH_CHECK(offsets_nobag.numel() == index_size + 1);
TORCH_CHECK(offsets_nobag_ptr[index_size] - offsets_nobag_ptr[0] == index_size);
{% endif %}

index_t index_size = offsets_acc[(t + 1) * B] - *offsets_begin_ptr;
const float* indice_weights_ptr = nullptr;
{% if weighted %}
indice_weights_ptr = indice_weights_acc + *offsets_begin_ptr;
@@ -259,32 +267,18 @@ for (const auto t : c10::irange(T)) {
/*prefetch=*/16,
/*is_weight_positional=*/false,
/*use_offsets=*/true,
{% if not nobag %}
/*output_stride=*/total_D,
{% else %}
/*output_stride=*/D,
{% endif %}
/*output_stride=*/{{ "total_D" if not nobag else "D" }},
/*input_stride=*/D_bytes / sizeof(float),
{% if not nobag %}
/*scale_bias_last=*/false,
/*no_bag=*/false,
/*is_bf16_out=*/output_is_bf16);
{% else %}
/*scale_bias_last=*/false,
/*no_bag=*/true,
/*is_bf16_out=*/output_is_bf16);
{% endif %}
success = kernel(
{% if not nobag %}
B,
{% else %}
index_size,
{% endif %}
{{ "B" if not nobag else "index_size"}},
index_size,
num_rows,
reinterpret_cast<const float*>(weights),
indices_acc + *offsets_begin_ptr,
offsets_begin_ptr,
{{ "offsets_begin_ptr" if not nobag else "offsets_nobag_ptr" }},
indice_weights_ptr,
reinterpret_cast<fbgemm_out_t*>(output_acc + D_start));
} else if (weight_ty == SparseType::FP16) {
@@ -295,32 +289,18 @@ for (const auto t : c10::irange(T)) {
/*prefetch=*/16,
/*is_weight_positional=*/false,
/*use_offsets=*/true,
{% if not nobag %}
/*output_stride=*/total_D,
{% else %}
/*output_stride=*/D,
{% endif %}
/*output_stride=*/{{ "total_D" if not nobag else "D" }},
/*input_stride=*/D_bytes / sizeof(float16),
{% if not nobag %}
/*scale_bias_last=*/false,
/*no_bag=*/false,
/*is_bf16_out=*/output_is_bf16);
{% else %}
/*scale_bias_last=*/false,
/*no_bag=*/true,
/*is_bf16_out=*/output_is_bf16);
{% endif %}
success = kernel(
{% if not nobag %}
B,
{% else %}
index_size,
{% endif %}
{{ "B" if not nobag else "index_size"}},
index_size,
num_rows,
reinterpret_cast<const float16*>(weights),
indices_acc + *offsets_begin_ptr,
offsets_begin_ptr,
{{ "offsets_begin_ptr" if not nobag else "offsets_nobag_ptr" }},
indice_weights_ptr,
reinterpret_cast<fbgemm_out_t*>(output_acc + D_start));
} else if (weight_ty == SparseType::FP8) {
@@ -330,22 +310,18 @@ for (const auto t : c10::irange(T)) {
normalize_by_lengths,
/*is_weight_positional=*/false,
/*use_offsets=*/true,
{% if not nobag %}
/*output_stride=*/total_D,
{% else %}
/*output_stride=*/D,
{% endif %}
/*output_stride=*/{{ "total_D" if not nobag else "D" }},
/*input_stride=*/D_bytes / sizeof(uint8_t),
/*exponent_bits=*/fp8_exponent_bits,
/*exponent_bias=*/fp8_exponent_bias,
/*is_bf16_out=*/output_is_bf16);
success = kernel(
B,
{{ "B" if not nobag else "index_size"}},
index_size,
num_rows,
weights,
indices_acc + *offsets_begin_ptr,
offsets_begin_ptr,
{{ "offsets_begin_ptr" if not nobag else "offsets_nobag_ptr" }},
indice_weights_ptr,
reinterpret_cast<fbgemm_out_t*>(output_acc + D_start));
} else if (weight_ty == SparseType::INT8) {
@@ -356,32 +332,18 @@ for (const auto t : c10::irange(T)) {
/*prefetch=*/16,
/*is_weight_positional=*/false,
/*use_offsets=*/true,
{% if not nobag %}
/*output_stride=*/total_D,
{% else %}
/*output_stride=*/D,
{% endif %}
/*output_stride=*/{{ "total_D" if not nobag else "D" }},
/*input_stride=*/D_bytes / sizeof(uint8_t),
{% if not nobag %}
/*scale_bias_last=*/false,
/*no_bag=*/false,
/*is_bf16_out=*/output_is_bf16);
{% else %}
/*scale_bias_last=*/false,
/*no_bag=*/true,
/*is_bf16_out=*/output_is_bf16);
{% endif %}
success = kernel(
{% if not nobag %}
B,
{% else %}
index_size,
{% endif %}
{{ "B" if not nobag else "index_size"}},
index_size,
num_rows,
weights,
indices_acc + *offsets_begin_ptr,
offsets_begin_ptr,
{{ "offsets_begin_ptr" if not nobag else "offsets_nobag_ptr" }},
indice_weights_ptr,
reinterpret_cast<fbgemm_out_t*>(output_acc + D_start));
} else if (weight_ty == SparseType::INT4 || weight_ty == SparseType::INT2) {
@@ -404,21 +366,17 @@ for (const auto t : c10::irange(T)) {
/*prefetch=*/16,
/*is_weight_positional=*/false,
/*use_offsets=*/true,
{% if not nobag %}
/*output_stride=*/total_D,
{% else %}
/*output_stride=*/D,
{% endif %}
/*output_stride=*/{{ "total_D" if not nobag else "D" }},
/*input_stride=*/D_bytes / sizeof(uint8_t),
/*scale_bias_last=*/false,
/*is_bf16_out=*/output_is_bf16);
success = kernel(
B,
{{ "B" if not nobag else "index_size"}},
index_size,
num_rows,
weights,
indices_acc + *offsets_begin_ptr,
offsets_begin_ptr,
{{ "offsets_begin_ptr" if not nobag else "offsets_nobag_ptr" }},
indice_weights_ptr,
reinterpret_cast<fbgemm_out_t*>(output_acc + D_start));
} else {
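The per-dtype branches above all feed the same fbgemm kernels, which consume Half/BFloat16 outputs as raw 16-bit storage. The following self-contained sketch mirrors the `fbgemm_out_t` alias hoisted to the top of the dispatch (an approximation written for clarity, not the generated code):

```cpp
// Sketch of the output-type mapping used by the template above: at::Half and
// at::BFloat16 outputs are handed to fbgemm as raw uint16_t, everything else as float.
#include <cstdint>
#include <type_traits>

#include <ATen/ATen.h>

using float16  = std::uint16_t;  // raw IEEE-754 half bits
using bfloat16 = std::uint16_t;  // raw bfloat16 bits

template <typename output_t>
using fbgemm_out_t = typename std::conditional<
    std::is_same<output_t, at::Half>::value,
    float16,
    typename std::conditional<
        std::is_same<output_t, at::BFloat16>::value,
        bfloat16,
        float>::type>::type;

static_assert(std::is_same<fbgemm_out_t<at::Half>, std::uint16_t>::value,
              "Half outputs are reinterpreted as uint16_t");
static_assert(std::is_same<fbgemm_out_t<float>, float>::value,
              "FP32 outputs pass through unchanged");
```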
8 changes: 0 additions & 8 deletions fbgemm_gpu/test/failures_dict_fast.json
@@ -35,10 +35,6 @@
"comment": "",
"status": "xfail"
},
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_forward_cpu_bf16_out": {
"comment": "",
"status": "xfail"
},
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_forward_gpu_no_cache": {
"comment": "",
"status": "xfail"
@@ -75,10 +71,6 @@
"comment": "",
"status": "xfail"
},
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_forward_cpu_bf16_out": {
"comment": "",
"status": "xfail"
},
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_forward_gpu_no_cache": {
"comment": "",
"status": "xfail"
102 changes: 8 additions & 94 deletions fbgemm_gpu/test/split_table_batched_embeddings_test.py
@@ -4526,6 +4526,12 @@ def execute_nbit_forward_( # noqa C901
nbit_weights_ty=get_nbit_weights_ty(),
use_array_for_index_remapping=st.booleans(),
do_pruning=st.booleans(),
pooling_mode=st.sampled_from(
[PoolingMode.SUM, PoolingMode.NONE, PoolingMode.MEAN]
),
output_dtype=st.sampled_from(
[SparseType.FP32, SparseType.FP16, SparseType.BF16]
),
)
@settings(
verbosity=VERBOSITY,
@@ -4537,6 +4543,8 @@ def test_nbit_forward_cpu(
nbit_weights_ty: Optional[SparseType],
use_array_for_index_remapping: bool,
do_pruning: bool,
pooling_mode: PoolingMode,
output_dtype: SparseType,
) -> None:
use_cpu = True
T = random.randint(1, 50)
@@ -4549,27 +4557,7 @@ def test_nbit_forward_cpu(
# cache_algorithm is don't care as we don't use cache.
cache_algorithm = CacheAlgorithm.LRU

pooling_mode = random.choice(
[
PoolingMode.SUM,
PoolingMode.MEAN,
PoolingMode.NONE,
]
)
mixed = random.choice([True, False])
if pooling_mode == PoolingMode.NONE:
nbit_weights_ty = random.choice(
[
SparseType.FP32,
SparseType.FP16,
# CPU sequence embedding does not support FP8/INT4/INT2 yet
# SparseType.FP8,
SparseType.INT8,
# SparseType.INT4,
# SparseType.INT2,
]
)

if pooling_mode == PoolingMode.SUM:
weighted = random.choice([True, False])
else:
@@ -4582,81 +4570,7 @@ def test_nbit_forward_cpu(
else:
weights_ty: SparseType = nbit_weights_ty
mixed_weights_ty = False
output_dtype = random.choice(
(
[SparseType.BF16]
if weights_ty in [SparseType.INT4, SparseType.INT2]
else []
)
+ [SparseType.FP32, SparseType.FP16]
)
self.execute_nbit_forward_(
T,
D,
B,
log_E,
L,
weighted,
mixed,
pooling_mode,
weights_ty,
use_cache,
cache_algorithm,
use_cpu,
use_array_for_index_remapping,
do_pruning,
mixed_weights_ty,
output_dtype,
)

@given(
nbit_weights_ty=get_nbit_weights_ty(),
use_array_for_index_remapping=st.booleans(),
do_pruning=st.booleans(),
)
@settings(
verbosity=VERBOSITY,
max_examples=MAX_EXAMPLES_LONG_RUNNING,
deadline=None,
)
def test_nbit_forward_cpu_bf16_out(
self,
nbit_weights_ty: Optional[SparseType],
use_array_for_index_remapping: bool,
do_pruning: bool,
) -> None:
use_cpu = True
T = random.randint(1, 50)
B = random.randint(0, 128)
L = random.randint(0, 32)
D = random.randint(2, 2048)
log_E = random.randint(2, 4)

use_cache = False
# cache_algorithm is don't care as we don't use cache.
cache_algorithm = CacheAlgorithm.LRU

pooling_mode = random.choice(
[
PoolingMode.SUM,
PoolingMode.MEAN,
]
)
mixed = random.choice([True, False])

if pooling_mode == PoolingMode.SUM:
weighted = random.choice([True, False])
else:
weighted = False

if nbit_weights_ty is None:
# don't care when mixed type is used.
weights_ty: SparseType = SparseType.INT8
mixed_weights_ty = True
else:
weights_ty: SparseType = nbit_weights_ty
mixed_weights_ty = False
output_dtype = SparseType.BF16
self.execute_nbit_forward_(
T,
D,
