Skip to content

Commit

Permalink
Fix heap-buffer-overflow in radix_sort_parallel (#2075)
Browse files Browse the repository at this point in the history
Summary:
Setting `histogram_ps[RDX_HIST_SIZE * (nthreads - 1) + 127] = offset;` in `combine_prefix_sum_for_msb` is guaranteed to result in `heap-buffer-overflow` if bucket is not empty during the scatter stage (as all values of `histogram_ps` should be strictly less than `element_count`

Factor out common code from `RadixSortTest.cc` into `test_tempalte` and add regression test for buffer overflow, which before the test will fail as follows:
```
[ RUN      ] cpuKernelTest.raidx_sort_heap_overflow
/home/nshulga/git/pytorch/FBGEMM/test/RadixSortTest.cc:36: Failure
Expected equality of these values:
  expected_keys
    Which is: { 2, 3, 5, -1, -1, 2147483647, 2147483647, 2147483647 }
  keys
    Which is: { -1, -1, -1, -1, -1, -1, -1, -1 }
/home/nshulga/git/pytorch/FBGEMM/test/RadixSortTest.cc:37: Failure
Expected equality of these values:
  expected_values
    Which is: { 1, 4, 6, 7, 8, 2, 3, 5 }
  values
    Which is: { 2147483647, 4, 6, 7, 8, 6, 7, 8 }
[  FAILED  ] cpuKernelTest.raidx_sort_heap_overflow (0 ms)
```

Will fix pytorch/pytorch#111189 once FBGEMM is updated to the correct version


Reviewed By: kit1980, jianyuh

Differential Revision: D50256504

Pulled By: malfet
  • Loading branch information
malfet authored and facebook-github-bot committed Oct 16, 2023
1 parent 924f310 commit 621d854
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 42 deletions.
8 changes: 3 additions & 5 deletions src/Utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,6 @@ void combine_prefix_sum(
int64_t offset = 0;
update_prefsum_and_offset_in_range(
offset, 0, RDX_HIST_SIZE, nthreads, histogram, histogram_ps);
histogram_ps[RDX_HIST_SIZE * nthreads] = offset;
// TODO(DamianSzwichtenberg): Is assert sufficient? In most cases, it will
// work only in debug build.
assert(offset == elements_count);
Expand All @@ -641,7 +640,6 @@ void combine_prefix_sum_for_msb(
offset, 128, RDX_HIST_SIZE, nthreads, histogram, histogram_ps);
update_prefsum_and_offset_in_range(
offset, 0, 128, nthreads, histogram, histogram_ps);
histogram_ps[RDX_HIST_SIZE * (nthreads - 1) + 127] = offset;
// TODO(DamianSzwichtenberg): Is assert sufficient? In most cases, it will
// work only in debug build.
assert(offset == elements_count);
Expand Down Expand Up @@ -760,13 +758,13 @@ std::pair<K*, V*> radix_sort_parallel(
const size_t array_size = (size_t)RDX_HIST_SIZE * maxthreads;
// fixes MSVC error C2131
auto* const histogram = static_cast<int64_t*>(
fbgemm::fbgemmAlignedAlloc(64, (array_size) * sizeof(int64_t)));
fbgemm::fbgemmAlignedAlloc(64, array_size * sizeof(int64_t)));
auto* const histogram_ps = static_cast<int64_t*>(
fbgemm::fbgemmAlignedAlloc(64, (array_size + 1) * sizeof(int64_t)));
fbgemm::fbgemmAlignedAlloc(64, array_size * sizeof(int64_t)));

#else
alignas(64) int64_t histogram[RDX_HIST_SIZE * maxthreads];
alignas(64) int64_t histogram_ps[RDX_HIST_SIZE * maxthreads + 1];
alignas(64) int64_t histogram_ps[RDX_HIST_SIZE * maxthreads];
#endif
// If negative values are present, we want to perform all passes
// up to a sign bit
Expand Down
91 changes: 54 additions & 37 deletions test/RadixSortTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,53 +11,70 @@
#include <limits>

#include "fbgemm/Utils.h"
#ifdef _OPENMP
#include <omp.h>
#endif

TEST(cpuKernelTest, radix_sort_parallel_test) {
std::array<int, 8> keys = {1, 2, 4, 5, 4, 3, 2, 9};
std::array<int, 8> values = {0, 0, 0, 0, 1, 1, 1, 1};

std::array<int, 8> keys_tmp;
std::array<int, 8> values_tmp;

namespace {
template <typename T, unsigned N>
void test_template(
std::array<T, N> keys,
std::array<T, N> values,
std::array<T, N> expected_keys,
std::array<T, N> expected_values,
T max_val = std::numeric_limits<T>::max(),
bool may_be_neg = std::is_signed_v<T>) {
std::array<T, N> keys_tmp;
std::array<T, N> values_tmp;
const auto [sorted_keys, sorted_values] = fbgemm::radix_sort_parallel(
keys.data(),
values.data(),
keys_tmp.data(),
values_tmp.data(),
keys.size(),
10);

std::array<int, 8> expect_keys_tmp = {1, 2, 2, 3, 4, 4, 5, 9};
std::array<int, 8> expect_values_tmp = {0, 0, 1, 1, 0, 1, 0, 1};
EXPECT_EQ(sorted_keys, keys_tmp.data());
EXPECT_EQ(sorted_values, values_tmp.data());
EXPECT_EQ(keys_tmp, expect_keys_tmp);
EXPECT_EQ(values_tmp, expect_values_tmp);
max_val,
may_be_neg);
if (sorted_keys == keys.data()) { // even number of passes
EXPECT_EQ(expected_keys, keys);
EXPECT_EQ(expected_values, values);
} else { // odd number of passes
EXPECT_EQ(expected_keys, keys_tmp);
EXPECT_EQ(expected_values, values_tmp);
}
}

TEST(cpuKernelTest, radix_sort_parallel_test_neg_vals) {
std::array<int64_t, 8> keys = {-4, -3, 0, 1, -2, -1, 3, 2};
std::array<int64_t, 8> values = {0, 0, 0, 0, 1, 1, 1, 1};
} // anonymous namespace

std::array<int64_t, 8> keys_tmp;
std::array<int64_t, 8> values_tmp;
TEST(cpuKernelTest, radix_sort_parallel_test) {
test_template<int, 8>(
{1, 2, 4, 5, 4, 3, 2, 9},
{0, 0, 0, 0, 1, 1, 1, 1},
{1, 2, 2, 3, 4, 4, 5, 9},
{0, 0, 1, 1, 0, 1, 0, 1},
10,
false);
}

const auto [sorted_keys, sorted_values] = fbgemm::radix_sort_parallel(
keys.data(),
values.data(),
keys_tmp.data(),
values_tmp.data(),
keys.size(),
std::numeric_limits<int64_t>::max(),
/*maybe_with_neg_vals=*/true);
TEST(cpuKernelTest, radix_sort_parallel_test_neg_vals) {
test_template<int64_t, 8>(
{-4, -3, 0, 1, -2, -1, 3, 2},
{0, 0, 0, 0, 1, 1, 1, 1},
{-4, -3, -2, -1, 0, 1, 2, 3},
{0, 0, 1, 1, 0, 0, 1, 1});
}

std::array<int64_t, 8> expect_keys_tmp = {-4, -3, -2, -1, 0, 1, 2, 3};
std::array<int64_t, 8> expect_values_tmp = {0, 0, 1, 1, 0, 0, 1, 1};
if (sorted_keys == keys.data()) { // even number of passes
EXPECT_EQ(expect_keys_tmp, keys);
EXPECT_EQ(expect_values_tmp, values);
} else { // odd number of passes
EXPECT_EQ(expect_keys_tmp, keys_tmp);
EXPECT_EQ(expect_values_tmp, values_tmp);
}
TEST(cpuKernelTest, raidx_sort_heap_overflow) {
#ifdef _OPENMP
const auto orig_threads = omp_get_num_threads();
omp_set_num_threads(1);
#endif
constexpr auto max = std::numeric_limits<int>::max();
test_template<int, 8>(
{-1, max, max, -1, max, -1, -1, -1},
{1, 2, 3, 4, 5, 6, 7, 8},
{-1, -1, -1, -1, -1, max, max, max},
{1, 4, 6, 7, 8, 2, 3, 5});
#ifdef _OPENMP
omp_set_num_threads(orig_threads);
#endif
}

0 comments on commit 621d854

Please sign in to comment.