diff --git a/.gitmodules b/.gitmodules index 282746ed0b53e2..2a098164aac49b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -105,7 +105,7 @@ [submodule "third_party/fbgemm"] ignore = dirty path = third_party/fbgemm - url = https://github.com/pytorch/fbgemm + url = https://github.com/DamianSzwichtenberg/fbgemm [submodule "third_party/foxi"] ignore = dirty path = third_party/foxi diff --git a/aten/src/ATen/native/cpu/SortingKernel.cpp b/aten/src/ATen/native/cpu/SortingKernel.cpp index 2d98c31d331553..14346ad5339ebc 100644 --- a/aten/src/ATen/native/cpu/SortingKernel.cpp +++ b/aten/src/ATen/native/cpu/SortingKernel.cpp @@ -1,15 +1,21 @@ #define TORCH_ASSERT_NO_OPERATORS + +#include + #include #include #include #include #include #include +#include +#include #include #include #include #include #include +#include namespace at::native { @@ -83,6 +89,39 @@ struct KeyValueCompDesc { } }; +static void parallel_sort1d_kernel( + const TensorBase& values, + const TensorBase& indices) { + // this kernel does not care about `stable` parameter as radix sort + // used here is a stable sorting algorithm + AT_DISPATCH_INTEGRAL_TYPES(values.scalar_type(), "parallel_sort1d_kernel", [&] { + int64_t elements = values.numel(); + scalar_t* keys = values.data_ptr(); + int64_t* vals = indices.data_ptr(); + std::vector tmp_keys(elements); + std::vector tmp_vals(elements); + const auto [sorted_keys, sorted_vals] = fbgemm::radix_sort_parallel( + keys, + vals, + tmp_keys.data(), + tmp_vals.data(), + elements, + std::numeric_limits::max(), + values.scalar_type() != ScalarType::Byte); + + const bool sorted_in_place = keys == sorted_keys; + if (!sorted_in_place) { + const auto common_size = values.numel(); + const int num_threads = at::get_num_threads(); + at::parallel_for(0, common_size, at::internal::GRAIN_SIZE / num_threads, [&](int64_t begin, int64_t end) { + const auto job_size = end - begin; + vec::map([](vec::Vectorized x) -> vec::Vectorized { return x; }, keys + begin, sorted_keys + begin, job_size); + vec::map([](vec::Vectorized x) -> vec::Vectorized { return x; }, vals + begin, sorted_vals + begin, job_size); + }); + } + }); +} + static void sort_kernel( const TensorBase& self, const TensorBase& values, @@ -97,6 +136,15 @@ static void sort_kernel( // https://github.com/pytorch/pytorch/issues/91420 return; } + // TODO(dszwicht): Should we add here check for `stable` param? + // Radix sort is a stable sorting algorithm. + if (fbgemm::is_radix_sort_accelerated_with_openmp() && + values.dim() == 1 && values.numel() >= at::internal::GRAIN_SIZE && + at::isIntegralType(values.scalar_type(), /*includeBool=*/false) && + !descending) { + parallel_sort1d_kernel(values, indices); + return; + } _dim_apply( values, indices, dim, "sort_cpu", [&]( diff --git a/test/test_sort_and_select.py b/test/test_sort_and_select.py index d8d7e7aaed10af..cb86bc8340b574 100644 --- a/test/test_sort_and_select.py +++ b/test/test_sort_and_select.py @@ -8,7 +8,7 @@ from itertools import permutations, product from torch.testing import make_tensor -from torch.testing._internal.common_dtype import all_types, all_types_and, floating_types_and +from torch.testing._internal.common_dtype import all_types, all_types_and, floating_types_and, integral_types from torch.testing._internal.common_utils import \ (TestCase, run_tests, slowTest) from torch.testing._internal.common_device_type import \ @@ -250,6 +250,15 @@ def test_sort_1d_output_discontiguous(self, device, dtype): self.assertEqual(indices, indices_cont) self.assertEqual(values, values_cont) + @slowTest + @onlyCPU + @dtypes(*integral_types()) + def test_sort_1d_parallel(self, device, dtype): + low = 0 if dtype == torch.uint8 else -128 + tensor = torch.randint(low=low, high=127, size=(100000, ), device=device, dtype=dtype) + vals, _ = torch.sort(tensor, stable=True) + self.assertEqual(True, torch.all(vals[:-1] <= vals[1:])) + @dtypes(torch.float32) def test_topk_1d_output_discontiguous(self, device, dtype): tensor = torch.randn(12, device=device, dtype=dtype) diff --git a/third_party/fbgemm b/third_party/fbgemm index 03b2046676707d..9aea7db4fbaa65 160000 --- a/third_party/fbgemm +++ b/third_party/fbgemm @@ -1 +1 @@ -Subproject commit 03b2046676707da64504e898490ab46104d4682a +Subproject commit 9aea7db4fbaa65bc8518e1d6a54d053203cacbe4