diff --git a/aten/src/ATen/native/EmbeddingBag.cpp b/aten/src/ATen/native/EmbeddingBag.cpp
index bc3e3505185ad0..cceb8baf9b0631 100644
--- a/aten/src/ATen/native/EmbeddingBag.cpp
+++ b/aten/src/ATen/native/EmbeddingBag.cpp
@@ -234,7 +234,7 @@ index_select_add(
     offsets_data = offsets_include_last.data();
   }
 #if defined(USE_FBGEMM)
-  constexpr bool isbf16 = std::is_same<data_t, at::Half>::value ? false : true;
+  constexpr bool isbf16 = std::is_same_v<data_t, at::Half> ? false : true;
   auto kernel_16bit_index_t = fbgemm_kernel_cache
       ? fbgemm_kernel_cache
             ->getCallback(ddim)
@@ -245,7 +245,8 @@ index_select_add(
           /* prefetch */ 16,
           /* is_weight_positional */ false,
           /* use_offsets */ true,
-          /* isbf16*/ isbf16);
+          /* is_bf16_out */ isbf16,
+          /* is_bf16_in */ isbf16);
   at::parallel_for(
       0, output_size, 1, [&](index_t start_idx, index_t end_idx) {
         bool success = kernel_16bit_index_t(
@@ -607,7 +608,7 @@ index_select_scale_add(
   auto* scale_data_fp32 = scale_fp32.mutable_data_ptr<float>();
 
 #if defined(USE_FBGEMM)
-  constexpr bool isbf16 = std::is_same<data_t, at::Half>::value ? false : true;
+  constexpr bool isbf16 = std::is_same_v<data_t, at::Half> ? false : true;
   if constexpr (isbf16) {
     fbgemm::Bfloat16ToFloat_simd(
         reinterpret_cast<const fbgemm::bfloat16*>(scale_data),
@@ -629,7 +630,8 @@ index_select_scale_add(
           /* prefetch */ 16,
           /* is_weight_positional */ false,
           /* use_offsets */ true,
-          /* isbf16*/ isbf16);
+          /* is_bf16_out */ isbf16,
+          /* is_bf16_in */ isbf16);
   at::parallel_for(
       0, output_size, 1, [&](index_t start_idx, index_t end_idx) {
         bool success = kernel_16bit_index_t(
diff --git a/test/test_sort_and_select.py b/test/test_sort_and_select.py
index 08b62cc1476ba4..d3b04617d2c1bb 100644
--- a/test/test_sort_and_select.py
+++ b/test/test_sort_and_select.py
@@ -1122,6 +1122,20 @@ def test_isin_different_devices(self, device, dtype):
         with self.assertRaises(RuntimeError):
             torch.isin(c, d)
 
+    @dtypes(*integral_types())
+    def test_sort_overflow(self, device, dtype):
+        " Regression test for https://github.com/pytorch/pytorch/issues/111189 "
+        prev_num_threads = torch.get_num_threads()
+        try:
+            low = 0 if dtype == torch.uint8 else -1
+            x = torch.full((32768,), low, dtype=dtype, device=device)
+            x[:100] = torch.iinfo(x.dtype).max
+            torch.set_num_threads(1)
+            uv = x.sort().values.unique()
+            self.assertEqual(uv.size(0), 2)
+        finally:
+            torch.set_num_threads(prev_num_threads)
+
 
 instantiate_device_type_tests(TestSortAndSelect, globals())
diff --git a/third_party/fbgemm b/third_party/fbgemm
index d0ee798b1f198c..70c6e83c29f672 160000
--- a/third_party/fbgemm
+++ b/third_party/fbgemm
@@ -1 +1 @@
-Subproject commit d0ee798b1f198cc51b6ddae20cf6063f6380ba3f
+Subproject commit 70c6e83c29f67278751abd0e28433c50743ccbe9