Make 1D integer sorting parallel #2

Closed · wants to merge 5 commits
2 changes: 1 addition & 1 deletion .gitmodules
@@ -105,7 +105,7 @@
[submodule "third_party/fbgemm"]
ignore = dirty
path = third_party/fbgemm
url = https://github.com/pytorch/fbgemm
url = https://github.com/DamianSzwichtenberg/fbgemm
[submodule "third_party/foxi"]
ignore = dirty
path = third_party/foxi
48 changes: 48 additions & 0 deletions aten/src/ATen/native/cpu/SortingKernel.cpp
@@ -1,15 +1,21 @@
#define TORCH_ASSERT_NO_OPERATORS

#include <limits>

#include <ATen/native/Sorting.h>
#include <ATen/core/TensorBase.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/NumericUtils.h>
#include <ATen/TensorIterator.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/StridedRandomAccessor.h>
#include <ATen/native/CompositeRandomAccessor.h>
#include <ATen/native/TopKImpl.h>
#include <c10/core/WrapDimMinimal.h>
#include <c10/util/irange.h>
#include <fbgemm/Utils.h>

namespace at::native {

@@ -83,6 +89,39 @@ struct KeyValueCompDesc {
}
};

static void parallel_sort1d_kernel(
const TensorBase& values,
const TensorBase& indices) {
// this kernel does not care about the `stable` parameter, as the radix sort
// used here is a stable sorting algorithm
AT_DISPATCH_INTEGRAL_TYPES(values.scalar_type(), "parallel_sort1d_kernel", [&] {
int64_t elements = values.numel();
scalar_t* keys = values.data_ptr<scalar_t>();
int64_t* vals = indices.data_ptr<int64_t>();
std::vector<scalar_t> tmp_keys(elements);
std::vector<int64_t> tmp_vals(elements);
const auto [sorted_keys, sorted_vals] = fbgemm::radix_sort_parallel(
keys,
vals,
tmp_keys.data(),
tmp_vals.data(),
elements,
std::numeric_limits<scalar_t>::max(),
values.scalar_type() != ScalarType::Byte);

const bool sorted_in_place = keys == sorted_keys;
if (!sorted_in_place) {
const auto common_size = values.numel();
const int num_threads = at::get_num_threads();
at::parallel_for(0, common_size, at::internal::GRAIN_SIZE / num_threads, [&](int64_t begin, int64_t end) {
const auto job_size = end - begin;
vec::map([](vec::Vectorized<scalar_t> x) -> vec::Vectorized<scalar_t> { return x; }, keys + begin, sorted_keys + begin, job_size);
vec::map([](vec::Vectorized<int64_t> x) -> vec::Vectorized<int64_t> { return x; }, vals + begin, sorted_vals + begin, job_size);
});
}
});
}

static void sort_kernel(
const TensorBase& self,
const TensorBase& values,
@@ -97,6 +136,15 @@ static void sort_kernel(
// https://github.com/pytorch/pytorch/issues/91420
return;
}
// TODO(dszwicht): Should we add a check for the `stable` param here?
// Radix sort is a stable sorting algorithm.
if (fbgemm::is_radix_sort_accelerated_with_openmp() &&
values.dim() == 1 && values.numel() >= at::internal::GRAIN_SIZE &&
at::isIntegralType(values.scalar_type(), /*includeBool=*/false) &&
!descending) {
parallel_sort1d_kernel(values, indices);
return;
}
_dim_apply(
values, indices, dim,
"sort_cpu", [&](
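Taken together, the new branch in `sort_kernel` only routes to `parallel_sort1d_kernel` for ascending sorts of 1D integral tensors with at least `at::internal::GRAIN_SIZE` elements, and only when fbgemm's radix sort is accelerated with OpenMP; everything else falls through to the existing `_dim_apply` path. A minimal sketch of which user-level calls would hit the fast path, assuming the default `GRAIN_SIZE` of 32768 and an OpenMP-enabled fbgemm build (the tensors below are purely illustrative):

```python
import torch

# Illustrative only: assumes at::internal::GRAIN_SIZE == 32768 (current default)
# and fbgemm::is_radix_sort_accelerated_with_openmp() returning true.
big = torch.randint(0, 1000, (100_000,), dtype=torch.int64)

# 1D, integral dtype, ascending, >= GRAIN_SIZE elements
# -> eligible for the parallel radix-sort path.
vals, idx = torch.sort(big, stable=True)

# Each of the following should fall back to the sequential _dim_apply path:
torch.sort(big, descending=True)           # descending order
torch.sort(big.float())                    # non-integral dtype
torch.sort(big.reshape(1000, 100), dim=1)  # not 1D
torch.sort(big[:1000])                     # fewer than GRAIN_SIZE elements
```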
11 changes: 10 additions & 1 deletion test/test_sort_and_select.py
@@ -8,7 +8,7 @@
from itertools import permutations, product

from torch.testing import make_tensor
from torch.testing._internal.common_dtype import all_types, all_types_and, floating_types_and
from torch.testing._internal.common_dtype import all_types, all_types_and, floating_types_and, integral_types
from torch.testing._internal.common_utils import \
(TestCase, run_tests, slowTest)
from torch.testing._internal.common_device_type import \
@@ -250,6 +250,15 @@ def test_sort_1d_output_discontiguous(self, device, dtype):
self.assertEqual(indices, indices_cont)
self.assertEqual(values, values_cont)

@slowTest
@onlyCPU
@dtypes(*integral_types())
def test_sort_1d_parallel(self, device, dtype):
low = 0 if dtype == torch.uint8 else -128
tensor = torch.randint(low=low, high=127, size=(100000, ), device=device, dtype=dtype)
vals, _ = torch.sort(tensor, stable=True)
self.assertEqual(True, torch.all(vals[:-1] <= vals[1:]))

@dtypes(torch.float32)
def test_topk_1d_output_discontiguous(self, device, dtype):
tensor = torch.randn(12, device=device, dtype=dtype)
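One observation on the new test: it asserts that the output is non-decreasing, but not that the ordering is stable, which is the property the radix-sort comments in `SortingKernel.cpp` rely on. A possible follow-up check, sketched here as a hypothetical addition rather than part of this PR: with `stable=True`, equal keys must keep their original relative order, so the returned indices must be strictly increasing wherever adjacent sorted values are equal.

```python
import torch

# Hypothetical stability check (not part of this PR): a small value range is
# used on purpose, so ties are guaranteed to occur.
tensor = torch.randint(low=0, high=8, size=(100_000,), dtype=torch.int64)
vals, idx = torch.sort(tensor, stable=True)

ties = vals[:-1] == vals[1:]           # adjacent positions holding equal keys
order_kept = idx[:-1] < idx[1:]        # original order preserved at that position
assert bool(torch.all(~ties | order_kept))  # stability must hold wherever keys tie
```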
2 changes: 1 addition & 1 deletion third_party/fbgemm
Submodule fbgemm updated 111 files