Skip to content

Commit

Permalink
Fix sort overwriting values in target array (#823)
Browse files Browse the repository at this point in the history
  • Loading branch information
xaellison authored Apr 12, 2021
1 parent 3399099 commit ef2a610
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
6 changes: 5 additions & 1 deletion src/sorting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,11 @@ elements spaced by `stride`. Good for sampling pivot values as well as short sor
end
sync_threads()
if 1 <= buddy <= L && threadIdx().x <= L
if (threadIdx().x < buddy) != flex_lt(swap[threadIdx().x], buddy_val, false, lt, by)
is_left = threadIdx().x < buddy
# flex_lt needs to handle equivalence in opposite ways for the
# two threads in each swap pair. Otherwise, if there are two
# different values with the same by, one will overwrite the other
if is_left != flex_lt(swap[threadIdx().x], buddy_val, is_left, lt, by)
swap[threadIdx().x] = buddy_val
end
end
Expand Down
10 changes: 9 additions & 1 deletion test/sorting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,15 @@ end

# using a `by` argument
@test check_sort(Float32, 100000; by=x->abs(x - 0.5))
@test check_sort(Float64, (4, 100000); by=x->8*x-round(8*x), dims=2)
@test check_sort!(Float32, (100000, 4); by=x->abs(x - 0.5), dims=1)
@test check_sort!(Float32, (4, 100000); by=x->abs(x - 0.5), dims=2)
@test check_sort!(Float64, 400000; by=x->8*x-round(8*x))
@test check_sort!(Float64, (100000, 4); by=x->8*x-round(8*x), dims=1)
@test check_sort!(Float64, (4, 100000); by=x->8*x-round(8*x), dims=2)
# target bubble sort by using sub-blocksize input:
@test check_sort!(Int, 200; by=x->x % 2)
@test check_sort!(Int, 200; by=x->x % 3)
@test check_sort!(Int, 200; by=x->x % 4)
end

end

0 comments on commit ef2a610

Please sign in to comment.