From c11f5c730b47706f7cf6f99551536cac97e78013 Mon Sep 17 00:00:00 2001 From: Sarunya Pumma Date: Thu, 7 Dec 2023 14:02:27 -0800 Subject: [PATCH] Enable sequence TBE CPU via AVX (#2195) Summary: Instead of using the ref implementation for sequence embedding on CPU, this diff directs TBE to invoke the AVX implementation of pooled TBE by forcing pooling factors of 1 (i.e., passing `at::arange(index_size + 1)` as offsets). The performance gained from using the AVX implementation offsets the overhead incurred in creating the new offsets. Reviewed By: jspark1105 Differential Revision: D51918878 --- ...bedding_forward_quantized_cpu_template.cpp | 116 ++++++------------ fbgemm_gpu/test/failures_dict_fast.json | 8 -- .../split_table_batched_embeddings_test.py | 102 ++------------- 3 files changed, 45 insertions(+), 181 deletions(-) diff --git a/fbgemm_gpu/codegen/embedding_forward_quantized_cpu_template.cpp b/fbgemm_gpu/codegen/embedding_forward_quantized_cpu_template.cpp index 652fc894cf..a61ecec7d3 100644 --- a/fbgemm_gpu/codegen/embedding_forward_quantized_cpu_template.cpp +++ b/fbgemm_gpu/codegen/embedding_forward_quantized_cpu_template.cpp @@ -200,6 +200,13 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ const float* indice_weights_acc = indice_weights.data_ptr(); {% endif %} + using float16 = uint16_t; + using bfloat16 = uint16_t; + using fbgemm_out_t = typename std::conditional< + std::is_same::value, + float16, + std::conditional::value, bfloat16, float>::type >::type; + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_", [&] { const auto* indices_acc = indices.data_ptr(); const auto* offsets_acc = offsets.data_ptr(); @@ -208,10 +215,9 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ auto* output_acc = output.data_ptr(); int32_t num_indices_m_1 = indices.numel() - 1; - int32_t D_start_ = 0; -for (const auto t : c10::irange(T)) 
{ + for (const auto t : c10::irange(T)) { {% if not nobag %} const auto* D_offsets_acc = D_offsets.data_ptr(); const int32_t D_start = D_offsets_acc[t]; @@ -226,27 +232,29 @@ for (const auto t : c10::irange(T)) { const auto& weight_tensor = (placement == PlacementType::HOST) ? dev_weights : uvm_weights; weights_acc = weight_tensor.data_ptr(); const uint8_t* weights = &weights_acc[weights_offsets_acc[t]]; - auto weight_ty = static_cast(weights_tys_acc[t]); + const auto weight_ty = static_cast(weights_tys_acc[t]); // default to 1 byte alignment for CPU TBE const int32_t D_bytes = nbit::padded_row_size_in_bytes(D, weight_ty, row_alignment); int tt; for (tt = t + 1; tt < T && weights_offsets_acc[tt] == weights_offsets_acc[t]; ++tt); - size_t num_rows = ((tt == T ? weight_tensor.numel() : weights_offsets_acc[tt]) - weights_offsets_acc[t]) / D_bytes; + const size_t num_rows = ((tt == T ? weight_tensor.numel() : weights_offsets_acc[tt]) - weights_offsets_acc[t]) / D_bytes; const index_t* offsets_begin_ptr = offsets_acc + t * B; - using float16 = uint16_t; - using bfloat16 = uint16_t; - using fbgemm_out_t = typename std::conditional< - std::is_same::value, - float16, - std::conditional::value, bfloat16, float>::type >::type; - bool success = true; - bool has_weight = {{ "true" if weighted else "false" }}; - bool normalize_by_lengths = static_cast(pooling_mode) == PoolingMode::MEAN; + const bool has_weight = {{ "true" if weighted else "false" }}; + const bool normalize_by_lengths = static_cast(pooling_mode) == PoolingMode::MEAN; + + const index_t index_size = offsets_acc[(t + 1) * B] - *offsets_begin_ptr; + + {% if nobag %} + // Create virtual offsets for the nobag case. Lengths are all ones. 
+ const auto offsets_nobag = at::arange(*offsets_begin_ptr, offsets_acc[(t + 1) * B] + 1, offsets.options()); + const index_t* offsets_nobag_ptr = offsets_nobag.data_ptr(); + TORCH_CHECK(offsets_nobag.numel() == index_size + 1); + TORCH_CHECK(offsets_nobag_ptr[index_size] - offsets_nobag_ptr[0] == index_size); + {% endif %} - index_t index_size = offsets_acc[(t + 1) * B] - *offsets_begin_ptr; const float* indice_weights_ptr = nullptr; {% if weighted %} indice_weights_ptr = indice_weights_acc + *offsets_begin_ptr; @@ -259,32 +267,18 @@ for (const auto t : c10::irange(T)) { /*prefetch=*/16, /*is_weight_positional=*/false, /*use_offsets=*/true, - {% if not nobag %} - /*output_stride=*/total_D, - {% else %} - /*output_stride=*/D, - {% endif %} + /*output_stride=*/{{ "total_D" if not nobag else "D" }}, /*input_stride=*/D_bytes / sizeof(float), - {% if not nobag %} /*scale_bias_last=*/false, /*no_bag=*/false, /*is_bf16_out=*/output_is_bf16); - {% else %} - /*scale_bias_last=*/false, - /*no_bag=*/true, - /*is_bf16_out=*/output_is_bf16); - {% endif %} success = kernel( - {% if not nobag %} - B, - {% else %} - index_size, - {% endif %} + {{ "B" if not nobag else "index_size"}}, index_size, num_rows, reinterpret_cast(weights), indices_acc + *offsets_begin_ptr, - offsets_begin_ptr, + {{ "offsets_begin_ptr" if not nobag else "offsets_nobag_ptr" }}, indice_weights_ptr, reinterpret_cast(output_acc + D_start)); } else if (weight_ty == SparseType::FP16) { @@ -295,32 +289,18 @@ for (const auto t : c10::irange(T)) { /*prefetch=*/16, /*is_weight_positional=*/false, /*use_offsets=*/true, - {% if not nobag %} - /*output_stride=*/total_D, - {% else %} - /*output_stride=*/D, - {% endif %} + /*output_stride=*/{{ "total_D" if not nobag else "D" }}, /*input_stride=*/D_bytes / sizeof(float16), - {% if not nobag %} /*scale_bias_last=*/false, /*no_bag=*/false, /*is_bf16_out=*/output_is_bf16); - {% else %} - /*scale_bias_last=*/false, - /*no_bag=*/true, - /*is_bf16_out=*/output_is_bf16); - {% 
endif %} success = kernel( - {% if not nobag %} - B, - {% else %} - index_size, - {% endif %} + {{ "B" if not nobag else "index_size"}}, index_size, num_rows, reinterpret_cast(weights), indices_acc + *offsets_begin_ptr, - offsets_begin_ptr, + {{ "offsets_begin_ptr" if not nobag else "offsets_nobag_ptr" }}, indice_weights_ptr, reinterpret_cast(output_acc + D_start)); } else if (weight_ty == SparseType::FP8) { @@ -330,22 +310,18 @@ for (const auto t : c10::irange(T)) { normalize_by_lengths, /*is_weight_positional=*/false, /*use_offsets=*/true, - {% if not nobag %} - /*output_stride=*/total_D, - {% else %} - /*output_stride=*/D, - {% endif %} + /*output_stride=*/{{ "total_D" if not nobag else "D" }}, /*input_stride=*/D_bytes / sizeof(uint8_t), /*exponent_bits=*/fp8_exponent_bits, /*exponent_bias=*/fp8_exponent_bias, /*is_bf16_out=*/output_is_bf16); success = kernel( - B, + {{ "B" if not nobag else "index_size"}}, index_size, num_rows, weights, indices_acc + *offsets_begin_ptr, - offsets_begin_ptr, + {{ "offsets_begin_ptr" if not nobag else "offsets_nobag_ptr" }}, indice_weights_ptr, reinterpret_cast(output_acc + D_start)); } else if (weight_ty == SparseType::INT8) { @@ -356,32 +332,18 @@ for (const auto t : c10::irange(T)) { /*prefetch=*/16, /*is_weight_positional=*/false, /*use_offsets=*/true, - {% if not nobag %} - /*output_stride=*/total_D, - {% else %} - /*output_stride=*/D, - {% endif %} + /*output_stride=*/{{ "total_D" if not nobag else "D" }}, /*input_stride=*/D_bytes / sizeof(uint8_t), - {% if not nobag %} /*scale_bias_last=*/false, /*no_bag=*/false, /*is_bf16_out=*/output_is_bf16); - {% else %} - /*scale_bias_last=*/false, - /*no_bag=*/true, - /*is_bf16_out=*/output_is_bf16); - {% endif %} success = kernel( - {% if not nobag %} - B, - {% else %} - index_size, - {% endif %} + {{ "B" if not nobag else "index_size"}}, index_size, num_rows, weights, indices_acc + *offsets_begin_ptr, - offsets_begin_ptr, + {{ "offsets_begin_ptr" if not nobag else 
"offsets_nobag_ptr" }}, indice_weights_ptr, reinterpret_cast(output_acc + D_start)); } else if (weight_ty == SparseType::INT4 || weight_ty == SparseType::INT2) { @@ -404,21 +366,17 @@ for (const auto t : c10::irange(T)) { /*prefetch=*/16, /*is_weight_positional=*/false, /*use_offsets=*/true, - {% if not nobag %} - /*output_stride=*/total_D, - {% else %} - /*output_stride=*/D, - {% endif %} + /*output_stride=*/{{ "total_D" if not nobag else "D" }}, /*input_stride=*/D_bytes / sizeof(uint8_t), /*scale_bias_last=*/false, /*is_bf16_out=*/output_is_bf16); success = kernel( - B, + {{ "B" if not nobag else "index_size"}}, index_size, num_rows, weights, indices_acc + *offsets_begin_ptr, - offsets_begin_ptr, + {{ "offsets_begin_ptr" if not nobag else "offsets_nobag_ptr" }}, indice_weights_ptr, reinterpret_cast(output_acc + D_start)); } else { diff --git a/fbgemm_gpu/test/failures_dict_fast.json b/fbgemm_gpu/test/failures_dict_fast.json index 5711a75052..982ed03dcc 100644 --- a/fbgemm_gpu/test/failures_dict_fast.json +++ b/fbgemm_gpu/test/failures_dict_fast.json @@ -35,10 +35,6 @@ "comment": "", "status": "xfail" }, - "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_forward_cpu_bf16_out": { - "comment": "", - "status": "xfail" - }, "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_forward_gpu_no_cache": { "comment": "", "status": "xfail" @@ -75,10 +71,6 @@ "comment": "", "status": "xfail" }, - "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_forward_cpu_bf16_out": { - "comment": "", - "status": "xfail" - }, "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_forward_gpu_no_cache": { "comment": "", "status": "xfail" diff --git a/fbgemm_gpu/test/split_table_batched_embeddings_test.py b/fbgemm_gpu/test/split_table_batched_embeddings_test.py index ce8e41e630..037de61815 100644 --- a/fbgemm_gpu/test/split_table_batched_embeddings_test.py +++ b/fbgemm_gpu/test/split_table_batched_embeddings_test.py @@ -4526,6 +4526,12 @@ def 
execute_nbit_forward_( # noqa C901 nbit_weights_ty=get_nbit_weights_ty(), use_array_for_index_remapping=st.booleans(), do_pruning=st.booleans(), + pooling_mode=st.sampled_from( + [PoolingMode.SUM, PoolingMode.NONE, PoolingMode.MEAN] + ), + output_dtype=st.sampled_from( + [SparseType.FP32, SparseType.FP16, SparseType.BF16] + ), ) @settings( verbosity=VERBOSITY, @@ -4537,6 +4543,8 @@ def test_nbit_forward_cpu( nbit_weights_ty: Optional[SparseType], use_array_for_index_remapping: bool, do_pruning: bool, + pooling_mode: PoolingMode, + output_dtype: SparseType, ) -> None: use_cpu = True T = random.randint(1, 50) @@ -4549,27 +4557,7 @@ def test_nbit_forward_cpu( # cache_algorithm is don't care as we don't use cache. cache_algorithm = CacheAlgorithm.LRU - pooling_mode = random.choice( - [ - PoolingMode.SUM, - PoolingMode.MEAN, - PoolingMode.NONE, - ] - ) mixed = random.choice([True, False]) - if pooling_mode == PoolingMode.NONE: - nbit_weights_ty = random.choice( - [ - SparseType.FP32, - SparseType.FP16, - # CPU sequence embedding does not support FP8/INT4/INT2 yet - # SparseType.FP8, - SparseType.INT8, - # SparseType.INT4, - # SparseType.INT2, - ] - ) - if pooling_mode == PoolingMode.SUM: weighted = random.choice([True, False]) else: @@ -4582,81 +4570,7 @@ def test_nbit_forward_cpu( else: weights_ty: SparseType = nbit_weights_ty mixed_weights_ty = False - output_dtype = random.choice( - ( - [SparseType.BF16] - if weights_ty in [SparseType.INT4, SparseType.INT2] - else [] - ) - + [SparseType.FP32, SparseType.FP16] - ) - self.execute_nbit_forward_( - T, - D, - B, - log_E, - L, - weighted, - mixed, - pooling_mode, - weights_ty, - use_cache, - cache_algorithm, - use_cpu, - use_array_for_index_remapping, - do_pruning, - mixed_weights_ty, - output_dtype, - ) - - @given( - nbit_weights_ty=get_nbit_weights_ty(), - use_array_for_index_remapping=st.booleans(), - do_pruning=st.booleans(), - ) - @settings( - verbosity=VERBOSITY, - max_examples=MAX_EXAMPLES_LONG_RUNNING, - 
deadline=None, - ) - def test_nbit_forward_cpu_bf16_out( - self, - nbit_weights_ty: Optional[SparseType], - use_array_for_index_remapping: bool, - do_pruning: bool, - ) -> None: - use_cpu = True - T = random.randint(1, 50) - B = random.randint(0, 128) - L = random.randint(0, 32) - D = random.randint(2, 2048) - log_E = random.randint(2, 4) - - use_cache = False - # cache_algorithm is don't care as we don't use cache. - cache_algorithm = CacheAlgorithm.LRU - - pooling_mode = random.choice( - [ - PoolingMode.SUM, - PoolingMode.MEAN, - ] - ) - mixed = random.choice([True, False]) - - if pooling_mode == PoolingMode.SUM: - weighted = random.choice([True, False]) - else: - weighted = False - if nbit_weights_ty is None: - # don't care when mixed type is used. - weights_ty: SparseType = SparseType.INT8 - mixed_weights_ty = True - else: - weights_ty: SparseType = nbit_weights_ty - mixed_weights_ty = False - output_dtype = SparseType.BF16 self.execute_nbit_forward_( T, D,