From ef2a610272728f2b0ada70d72aaac72938e0a773 Mon Sep 17 00:00:00 2001 From: xaellison Date: Mon, 12 Apr 2021 03:14:10 -0400 Subject: [PATCH] Fix sort overwriting values in target array (#823) --- src/sorting.jl | 6 +++++- test/sorting.jl | 10 +++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/sorting.jl b/src/sorting.jl index 5be5bc280a..b7f1453d75 100644 --- a/src/sorting.jl +++ b/src/sorting.jl @@ -225,7 +225,11 @@ elements spaced by `stride`. Good for sampling pivot values as well as short sor end sync_threads() if 1 <= buddy <= L && threadIdx().x <= L - if (threadIdx().x < buddy) != flex_lt(swap[threadIdx().x], buddy_val, false, lt, by) + is_left = threadIdx().x < buddy + # flex_lt needs to handle equivalence in opposite ways for the + # two threads in each swap pair. Otherwise, if there are two + # different values with the same by, one will overwrite the other + if is_left != flex_lt(swap[threadIdx().x], buddy_val, is_left, lt, by) swap[threadIdx().x] = buddy_val end end diff --git a/test/sorting.jl b/test/sorting.jl index be1663e063..1369d758d5 100644 --- a/test/sorting.jl +++ b/test/sorting.jl @@ -252,7 +252,15 @@ end # using a `by` argument @test check_sort(Float32, 100000; by=x->abs(x - 0.5)) - @test check_sort(Float64, (4, 100000); by=x->8*x-round(8*x), dims=2) + @test check_sort!(Float32, (100000, 4); by=x->abs(x - 0.5), dims=1) + @test check_sort!(Float32, (4, 100000); by=x->abs(x - 0.5), dims=2) + @test check_sort!(Float64, 400000; by=x->8*x-round(8*x)) + @test check_sort!(Float64, (100000, 4); by=x->8*x-round(8*x), dims=1) + @test check_sort!(Float64, (4, 100000); by=x->8*x-round(8*x), dims=2) + # target bubble sort by using sub-blocksize input: + @test check_sort!(Int, 200; by=x->x % 2) + @test check_sort!(Int, 200; by=x->x % 3) + @test check_sort!(Int, 200; by=x->x % 4) end end