-
Notifications
You must be signed in to change notification settings - Fork 508
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix heap-buffer-overflow in
radix_sort_parallel
(#2075)
Summary: Setting `histogram_ps[RDX_HIST_SIZE * (nthreads - 1) + 127] = offset;` in `combine_prefix_sum_for_msb` is guaranteed to result in `heap-buffer-overflow` if bucket is not empty during the scatter stage (as all values of `histogram_ps` should be strictly less than `element_count` Factor out common code from `RadixSortTest.cc` into `test_tempalte` and add regression test for buffer overflow, which before the test will fail as follows: ``` [ RUN ] cpuKernelTest.raidx_sort_heap_overflow /home/nshulga/git/pytorch/FBGEMM/test/RadixSortTest.cc:36: Failure Expected equality of these values: expected_keys Which is: { 2, 3, 5, -1, -1, 2147483647, 2147483647, 2147483647 } keys Which is: { -1, -1, -1, -1, -1, -1, -1, -1 } /home/nshulga/git/pytorch/FBGEMM/test/RadixSortTest.cc:37: Failure Expected equality of these values: expected_values Which is: { 1, 4, 6, 7, 8, 2, 3, 5 } values Which is: { 2147483647, 4, 6, 7, 8, 6, 7, 8 } [ FAILED ] cpuKernelTest.raidx_sort_heap_overflow (0 ms) ``` Will fix pytorch/pytorch#111189 once FBGEMM is updated to the correct version Reviewed By: kit1980, jianyuh Differential Revision: D50256504 Pulled By: malfet
- Loading branch information
1 parent
924f310
commit 621d854
Showing
2 changed files
with
57 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters