Fix buffer overflow in torch.sort (pytorch#111672)
By updating fbgemm submodule
Add a regression test for it (though it could probably be limited to CPU only, since the reproducer only works when num_threads is 1)

Also, update call sites of `fbgemm::GenerateEmbeddingSpMDM` to pass `isbf16` twice, to match the API change introduced in pytorch/FBGEMM#1851

Fixes pytorch#111189 and pytorch#111710

Pull Request resolved: pytorch#111672
Approved by: https://github.com/Skylion007
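
For context, the failure mode is easiest to see outside the test harness. Below is a minimal standalone reproducer sketch adapted from the regression test added to test/test_sort_and_select.py in this commit; the tensor size, fill values, and int64 dtype mirror that test, and, as noted above, the issue only reproduces with a single thread.

```python
import torch

# The corruption only reproduces single-threaded, so pin the thread count.
prev_num_threads = torch.get_num_threads()
torch.set_num_threads(1)
try:
    # 32768 elements of -1, with the dtype's max value in the first 100 slots.
    x = torch.full((32768,), -1, dtype=torch.int64)
    x[:100] = torch.iinfo(x.dtype).max
    # After sorting, the result should contain exactly the two values written above.
    unique_vals = x.sort().values.unique()
    assert unique_vals.numel() == 2, f"unexpected values after sort: {unique_vals}"
finally:
    torch.set_num_threads(prev_num_threads)
```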
malfet authored and Skylion007 committed Nov 14, 2023
1 parent 20247dd commit d481f1f
Showing 3 changed files with 21 additions and 5 deletions.
10 changes: 6 additions & 4 deletions aten/src/ATen/native/EmbeddingBag.cpp
@@ -234,7 +234,7 @@ index_select_add(
     offsets_data = offsets_include_last.data();
   }
 #if defined(USE_FBGEMM)
-  constexpr bool isbf16 = std::is_same<data_t, at::Half>::value ? false : true;
+  constexpr bool isbf16 = std::is_same_v<data_t, at::Half> ? false : true;
   auto kernel_16bit_index_t = fbgemm_kernel_cache
       ? fbgemm_kernel_cache
             ->getCallback</* has_weight */ false, index_t, uint16_t>(ddim)
@@ -245,7 +245,8 @@ index_select_add(
             /* prefetch */ 16,
             /* is_weight_positional */ false,
             /* use_offsets */ true,
-            /* isbf16*/ isbf16);
+            /* is_bf16_out */ isbf16,
+            /* is_bf16_in */ isbf16);
   at::parallel_for(
       0, output_size, 1, [&](index_t start_idx, index_t end_idx) {
         bool success = kernel_16bit_index_t(
@@ -607,7 +608,7 @@ index_select_scale_add(
   auto* scale_data_fp32 = scale_fp32.mutable_data_ptr<float>();
 
 #if defined(USE_FBGEMM)
-  constexpr bool isbf16 = std::is_same<data_t, at::Half>::value ? false : true;
+  constexpr bool isbf16 = std::is_same_v<data_t, at::Half> ? false : true;
   if constexpr (isbf16) {
     fbgemm::Bfloat16ToFloat_simd(
         reinterpret_cast<const fbgemm::bfloat16*>(scale_data),
@@ -629,7 +630,8 @@ index_select_scale_add(
             /* prefetch */ 16,
             /* is_weight_positional */ false,
             /* use_offsets */ true,
-            /* isbf16*/ isbf16);
+            /* is_bf16_out */ isbf16,
+            /* is_bf16_in */ isbf16);
   at::parallel_for(
       0, output_size, 1, [&](index_t start_idx, index_t end_idx) {
         bool success = kernel_16bit_index_t(
14 changes: 14 additions & 0 deletions test/test_sort_and_select.py
@@ -1122,6 +1122,20 @@ def test_isin_different_devices(self, device, dtype):
         with self.assertRaises(RuntimeError):
             torch.isin(c, d)
 
+    @dtypes(*integral_types())
+    def test_sort_overflow(self, device, dtype):
+        " Regression test for https://github.com/pytorch/pytorch/issues/111189 "
+        prev_num_threads = torch.get_num_threads()
+        try:
+            low = 0 if dtype == torch.uint8 else -1
+            x = torch.full((32768,), low, dtype=dtype, device=device)
+            x[:100] = torch.iinfo(x.dtype).max
+            torch.set_num_threads(1)
+            uv = x.sort().values.unique()
+            self.assertEqual(uv.size(0), 2)
+        finally:
+            torch.set_num_threads(prev_num_threads)
+
 
 instantiate_device_type_tests(TestSortAndSelect, globals())
 
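One way to run just the new test from a source checkout (a hedged example invocation, not part of the commit; the device/dtype-instantiated names such as test_sort_overflow_cpu_int64 are matched by substring):

```
pytest test/test_sort_and_select.py -k test_sort_overflow -v
```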
2 changes: 1 addition & 1 deletion third_party/fbgemm
Submodule fbgemm updated 453 files
