Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable sequence TBE CPU via AVX #2195

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 37 additions & 79 deletions fbgemm_gpu/codegen/embedding_forward_quantized_cpu_template.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,13 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
const float* indice_weights_acc = indice_weights.data_ptr<float>();
{% endif %}

using float16 = uint16_t;
using bfloat16 = uint16_t;
using fbgemm_out_t = typename std::conditional<
std::is_same<output_t, at::Half>::value,
float16,
std::conditional<std::is_same<output_t, at::BFloat16>::value, bfloat16, float>::type >::type;

AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_", [&] {
const auto* indices_acc = indices.data_ptr<index_t>();
const auto* offsets_acc = offsets.data_ptr<index_t>();
Expand All @@ -208,10 +215,9 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{

auto* output_acc = output.data_ptr<output_t>();
int32_t num_indices_m_1 = indices.numel() - 1;

int32_t D_start_ = 0;
for (const auto t : c10::irange(T)) {

for (const auto t : c10::irange(T)) {
{% if not nobag %}
const auto* D_offsets_acc = D_offsets.data_ptr<int32_t>();
const int32_t D_start = D_offsets_acc[t];
Expand All @@ -226,27 +232,29 @@ for (const auto t : c10::irange(T)) {
const auto& weight_tensor = (placement == PlacementType::HOST) ? dev_weights : uvm_weights;
weights_acc = weight_tensor.data_ptr<uint8_t>();
const uint8_t* weights = &weights_acc[weights_offsets_acc[t]];
auto weight_ty = static_cast<SparseType>(weights_tys_acc[t]);
const auto weight_ty = static_cast<SparseType>(weights_tys_acc[t]);
// default to 1 byte alignment for CPU TBE
const int32_t D_bytes = nbit::padded_row_size_in_bytes(D, weight_ty, row_alignment);

int tt;
for (tt = t + 1; tt < T && weights_offsets_acc[tt] == weights_offsets_acc[t]; ++tt);
size_t num_rows = ((tt == T ? weight_tensor.numel() : weights_offsets_acc[tt]) - weights_offsets_acc[t]) / D_bytes;
const size_t num_rows = ((tt == T ? weight_tensor.numel() : weights_offsets_acc[tt]) - weights_offsets_acc[t]) / D_bytes;
const index_t* offsets_begin_ptr = offsets_acc + t * B;

using float16 = uint16_t;
using bfloat16 = uint16_t;
using fbgemm_out_t = typename std::conditional<
std::is_same<output_t, at::Half>::value,
float16,
std::conditional<std::is_same<output_t, at::BFloat16>::value, bfloat16, float>::type >::type;

bool success = true;
bool has_weight = {{ "true" if weighted else "false" }};
bool normalize_by_lengths = static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN;
const bool has_weight = {{ "true" if weighted else "false" }};
const bool normalize_by_lengths = static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN;

const index_t index_size = offsets_acc[(t + 1) * B] - *offsets_begin_ptr;

{% if nobag %}
// Create virtual offsets for the nobag case. Lengths are all ones.
const auto offsets_nobag = at::arange(*offsets_begin_ptr, offsets_acc[(t + 1) * B] + 1, offsets.options());
const index_t* offsets_nobag_ptr = offsets_nobag.data_ptr<index_t>();
TORCH_CHECK(offsets_nobag.numel() == index_size + 1);
TORCH_CHECK(offsets_nobag_ptr[index_size] - offsets_nobag_ptr[0] == index_size);
{% endif %}

index_t index_size = offsets_acc[(t + 1) * B] - *offsets_begin_ptr;
const float* indice_weights_ptr = nullptr;
{% if weighted %}
indice_weights_ptr = indice_weights_acc + *offsets_begin_ptr;
Expand All @@ -259,32 +267,18 @@ for (const auto t : c10::irange(T)) {
/*prefetch=*/16,
/*is_weight_positional=*/false,
/*use_offsets=*/true,
{% if not nobag %}
/*output_stride=*/total_D,
{% else %}
/*output_stride=*/D,
{% endif %}
/*output_stride=*/{{ "total_D" if not nobag else "D" }},
/*input_stride=*/D_bytes / sizeof(float),
{% if not nobag %}
/*scale_bias_last=*/false,
/*no_bag=*/false,
/*is_bf16_out=*/output_is_bf16);
{% else %}
/*scale_bias_last=*/false,
/*no_bag=*/true,
/*is_bf16_out=*/output_is_bf16);
{% endif %}
success = kernel(
{% if not nobag %}
B,
{% else %}
index_size,
{% endif %}
{{ "B" if not nobag else "index_size"}},
index_size,
num_rows,
reinterpret_cast<const float*>(weights),
indices_acc + *offsets_begin_ptr,
offsets_begin_ptr,
{{ "offsets_begin_ptr" if not nobag else "offsets_nobag_ptr" }},
indice_weights_ptr,
reinterpret_cast<fbgemm_out_t*>(output_acc + D_start));
} else if (weight_ty == SparseType::FP16) {
Expand All @@ -295,32 +289,18 @@ for (const auto t : c10::irange(T)) {
/*prefetch=*/16,
/*is_weight_positional=*/false,
/*use_offsets=*/true,
{% if not nobag %}
/*output_stride=*/total_D,
{% else %}
/*output_stride=*/D,
{% endif %}
/*output_stride=*/{{ "total_D" if not nobag else "D" }},
/*input_stride=*/D_bytes / sizeof(float16),
{% if not nobag %}
/*scale_bias_last=*/false,
/*no_bag=*/false,
/*is_bf16_out=*/output_is_bf16);
{% else %}
/*scale_bias_last=*/false,
/*no_bag=*/true,
/*is_bf16_out=*/output_is_bf16);
{% endif %}
success = kernel(
{% if not nobag %}
B,
{% else %}
index_size,
{% endif %}
{{ "B" if not nobag else "index_size"}},
index_size,
num_rows,
reinterpret_cast<const float16*>(weights),
indices_acc + *offsets_begin_ptr,
offsets_begin_ptr,
{{ "offsets_begin_ptr" if not nobag else "offsets_nobag_ptr" }},
indice_weights_ptr,
reinterpret_cast<fbgemm_out_t*>(output_acc + D_start));
} else if (weight_ty == SparseType::FP8) {
Expand All @@ -330,22 +310,18 @@ for (const auto t : c10::irange(T)) {
normalize_by_lengths,
/*is_weight_positional=*/false,
/*use_offsets=*/true,
{% if not nobag %}
/*output_stride=*/total_D,
{% else %}
/*output_stride=*/D,
{% endif %}
/*output_stride=*/{{ "total_D" if not nobag else "D" }},
/*input_stride=*/D_bytes / sizeof(uint8_t),
/*exponent_bits=*/fp8_exponent_bits,
/*exponent_bias=*/fp8_exponent_bias,
/*is_bf16_out=*/output_is_bf16);
success = kernel(
B,
{{ "B" if not nobag else "index_size"}},
index_size,
num_rows,
weights,
indices_acc + *offsets_begin_ptr,
offsets_begin_ptr,
{{ "offsets_begin_ptr" if not nobag else "offsets_nobag_ptr" }},
indice_weights_ptr,
reinterpret_cast<fbgemm_out_t*>(output_acc + D_start));
} else if (weight_ty == SparseType::INT8) {
Expand All @@ -356,32 +332,18 @@ for (const auto t : c10::irange(T)) {
/*prefetch=*/16,
/*is_weight_positional=*/false,
/*use_offsets=*/true,
{% if not nobag %}
/*output_stride=*/total_D,
{% else %}
/*output_stride=*/D,
{% endif %}
/*output_stride=*/{{ "total_D" if not nobag else "D" }},
/*input_stride=*/D_bytes / sizeof(uint8_t),
{% if not nobag %}
/*scale_bias_last=*/false,
/*no_bag=*/false,
/*is_bf16_out=*/output_is_bf16);
{% else %}
/*scale_bias_last=*/false,
/*no_bag=*/true,
/*is_bf16_out=*/output_is_bf16);
{% endif %}
success = kernel(
{% if not nobag %}
B,
{% else %}
index_size,
{% endif %}
{{ "B" if not nobag else "index_size"}},
index_size,
num_rows,
weights,
indices_acc + *offsets_begin_ptr,
offsets_begin_ptr,
{{ "offsets_begin_ptr" if not nobag else "offsets_nobag_ptr" }},
indice_weights_ptr,
reinterpret_cast<fbgemm_out_t*>(output_acc + D_start));
} else if (weight_ty == SparseType::INT4 || weight_ty == SparseType::INT2) {
Expand All @@ -404,21 +366,17 @@ for (const auto t : c10::irange(T)) {
/*prefetch=*/16,
/*is_weight_positional=*/false,
/*use_offsets=*/true,
{% if not nobag %}
/*output_stride=*/total_D,
{% else %}
/*output_stride=*/D,
{% endif %}
/*output_stride=*/{{ "total_D" if not nobag else "D" }},
/*input_stride=*/D_bytes / sizeof(uint8_t),
/*scale_bias_last=*/false,
/*is_bf16_out=*/output_is_bf16);
success = kernel(
B,
{{ "B" if not nobag else "index_size"}},
index_size,
num_rows,
weights,
indices_acc + *offsets_begin_ptr,
offsets_begin_ptr,
{{ "offsets_begin_ptr" if not nobag else "offsets_nobag_ptr" }},
indice_weights_ptr,
reinterpret_cast<fbgemm_out_t*>(output_acc + D_start));
} else {
Expand Down
8 changes: 0 additions & 8 deletions fbgemm_gpu/test/failures_dict_fast.json
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,6 @@
"comment": "",
"status": "xfail"
},
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_forward_cpu_bf16_out": {
"comment": "",
"status": "xfail"
},
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_forward_gpu_no_cache": {
"comment": "",
"status": "xfail"
Expand Down Expand Up @@ -75,10 +71,6 @@
"comment": "",
"status": "xfail"
},
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_forward_cpu_bf16_out": {
"comment": "",
"status": "xfail"
},
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_forward_gpu_no_cache": {
"comment": "",
"status": "xfail"
Expand Down
102 changes: 8 additions & 94 deletions fbgemm_gpu/test/split_table_batched_embeddings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4526,6 +4526,12 @@ def execute_nbit_forward_( # noqa C901
nbit_weights_ty=get_nbit_weights_ty(),
use_array_for_index_remapping=st.booleans(),
do_pruning=st.booleans(),
pooling_mode=st.sampled_from(
[PoolingMode.SUM, PoolingMode.NONE, PoolingMode.MEAN]
),
output_dtype=st.sampled_from(
[SparseType.FP32, SparseType.FP16, SparseType.BF16]
),
)
@settings(
verbosity=VERBOSITY,
Expand All @@ -4537,6 +4543,8 @@ def test_nbit_forward_cpu(
nbit_weights_ty: Optional[SparseType],
use_array_for_index_remapping: bool,
do_pruning: bool,
pooling_mode: PoolingMode,
output_dtype: SparseType,
) -> None:
use_cpu = True
T = random.randint(1, 50)
Expand All @@ -4549,27 +4557,7 @@ def test_nbit_forward_cpu(
# cache_algorithm is don't care as we don't use cache.
cache_algorithm = CacheAlgorithm.LRU

pooling_mode = random.choice(
[
PoolingMode.SUM,
PoolingMode.MEAN,
PoolingMode.NONE,
]
)
mixed = random.choice([True, False])
if pooling_mode == PoolingMode.NONE:
nbit_weights_ty = random.choice(
[
SparseType.FP32,
SparseType.FP16,
# CPU sequence embedding does not support FP8/INT4/INT2 yet
# SparseType.FP8,
SparseType.INT8,
# SparseType.INT4,
# SparseType.INT2,
]
)

if pooling_mode == PoolingMode.SUM:
weighted = random.choice([True, False])
else:
Expand All @@ -4582,81 +4570,7 @@ def test_nbit_forward_cpu(
else:
weights_ty: SparseType = nbit_weights_ty
mixed_weights_ty = False
output_dtype = random.choice(
(
[SparseType.BF16]
if weights_ty in [SparseType.INT4, SparseType.INT2]
else []
)
+ [SparseType.FP32, SparseType.FP16]
)
self.execute_nbit_forward_(
T,
D,
B,
log_E,
L,
weighted,
mixed,
pooling_mode,
weights_ty,
use_cache,
cache_algorithm,
use_cpu,
use_array_for_index_remapping,
do_pruning,
mixed_weights_ty,
output_dtype,
)

@given(
nbit_weights_ty=get_nbit_weights_ty(),
use_array_for_index_remapping=st.booleans(),
do_pruning=st.booleans(),
)
@settings(
verbosity=VERBOSITY,
max_examples=MAX_EXAMPLES_LONG_RUNNING,
deadline=None,
)
def test_nbit_forward_cpu_bf16_out(
self,
nbit_weights_ty: Optional[SparseType],
use_array_for_index_remapping: bool,
do_pruning: bool,
) -> None:
use_cpu = True
T = random.randint(1, 50)
B = random.randint(0, 128)
L = random.randint(0, 32)
D = random.randint(2, 2048)
log_E = random.randint(2, 4)

use_cache = False
# cache_algorithm is don't care as we don't use cache.
cache_algorithm = CacheAlgorithm.LRU

pooling_mode = random.choice(
[
PoolingMode.SUM,
PoolingMode.MEAN,
]
)
mixed = random.choice([True, False])

if pooling_mode == PoolingMode.SUM:
weighted = random.choice([True, False])
else:
weighted = False

if nbit_weights_ty is None:
# don't care when mixed type is used.
weights_ty: SparseType = SparseType.INT8
mixed_weights_ty = True
else:
weights_ty: SparseType = nbit_weights_ty
mixed_weights_ty = False
output_dtype = SparseType.BF16
self.execute_nbit_forward_(
T,
D,
Expand Down
Loading