Skip to content

Commit

Permalink
bitonic median selection
Browse files Browse the repository at this point in the history
performance
  • Loading branch information
xaellison committed Apr 11, 2021
1 parent 59180e7 commit 9766333
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 3 deletions.
51 changes: 51 additions & 0 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ version = "0.1.0"
[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

[[BenchmarkTools]]
deps = ["JSON", "Logging", "Printf", "Statistics", "UUIDs"]
git-tree-sha1 = "068fda9b756e41e6c75da7b771e6f89fa8a43d15"
uuid = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
version = "0.7.0"

[[CEnum]]
git-tree-sha1 = "215a9aa4a1f23fbd05b92769fdd62559488d70e9"
uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Expand All @@ -38,6 +44,12 @@ git-tree-sha1 = "44e9f638aa9ed1ad58885defc568c133010140aa"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.9.37"

[[CodeTracking]]
deps = ["InteractiveUtils", "UUIDs"]
git-tree-sha1 = "8ad457cfeb0bca98732c97958ef81000a543e73e"
uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2"
version = "1.0.5"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "4fecfd5485d3c5de4003e19f00c6898cccd40667"
Expand Down Expand Up @@ -75,6 +87,15 @@ git-tree-sha1 = "10407a39b87f29d47ebaca8edbc75d7c302ff93e"
uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
version = "0.1.3"

[[FileWatching]]
uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee"

[[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays"]
git-tree-sha1 = "31939159aeb8ffad1d4d8ee44d07f8558273120a"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.11.7"

[[GPUArrays]]
deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"]
git-tree-sha1 = "9c95b2fd5c16bc7f97371e9f92f0fef77e0f5957"
Expand All @@ -96,6 +117,18 @@ git-tree-sha1 = "a431f5f2ca3f4feef3bd7a5e94b8b8d4f2f647a0"
uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
version = "1.2.0"

[[JSON]]
deps = ["Dates", "Mmap", "Parsers", "Unicode"]
git-tree-sha1 = "81690084b6198a2e1da36fcfda16eeca9f9f24e4"
uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
version = "0.21.1"

[[JuliaInterpreter]]
deps = ["CodeTracking", "InteractiveUtils", "Random", "UUIDs"]
git-tree-sha1 = "ccc489088d6bc4b5265e043e3fbb1baad5025cf7"
uuid = "aa1ae85d-cabe-5617-a682-6adf51b2e16a"
version = "0.8.11"

[[LLVM]]
deps = ["CEnum", "Libdl", "Printf", "Unicode"]
git-tree-sha1 = "b616937c31337576360cb9fb872ec7633af7b194"
Expand Down Expand Up @@ -132,6 +165,12 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[LoweredCodeUtils]]
deps = ["JuliaInterpreter"]
git-tree-sha1 = "8c96709706ce27471655247ad9a931447d16dd62"
uuid = "6f1432cf-f94c-5a45-995e-cdbf5db27b0b"
version = "1.2.9"

[[MacroTools]]
deps = ["Markdown", "Random"]
git-tree-sha1 = "6a8a2a625ab0dea913aba95c11370589e0239ff0"
Expand Down Expand Up @@ -172,6 +211,12 @@ git-tree-sha1 = "4fa2ba51070ec13fcc7517db714445b4ab986bdf"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.4.0"

[[Parsers]]
deps = ["Dates"]
git-tree-sha1 = "c8abc88faa3f7a3950832ac5d6e690881590d6dc"
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
version = "1.1.0"

[[Pkg]]
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Expand Down Expand Up @@ -205,6 +250,12 @@ git-tree-sha1 = "4036a3bd08ac7e968e27c203d45f5fff15020621"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "1.1.3"

[[Revise]]
deps = ["CodeTracking", "Distributed", "FileWatching", "JuliaInterpreter", "LibGit2", "LoweredCodeUtils", "OrderedCollections", "Pkg", "REPL", "Requires", "UUIDs", "Unicode"]
git-tree-sha1 = "b72fa706920b1421d581525de9f4e442b95ba254"
uuid = "295af30f-e4ad-537b-8983-00126c2a3abe"
version = "3.1.14"

[[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"

Expand Down
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ version = "3.0.0"
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
CompilerSupportLibraries_jll = "e66e0078-7015-5450-92f7-15fbd957f2ae"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
Expand All @@ -24,6 +26,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RandomNumbers = "e6cf234a-135c-5ec9-84dd-332b85af5143"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand Down
55 changes: 52 additions & 3 deletions src/sorting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,55 @@ end


# Sorting
"""
Finds the median of `vals` starting after `lo` and going for `blockDim().x`
elements spaced by `stride`. Performs bitonic sort in shmem, returns middle value.
Faster than bubble sort, but not as flexible. Does not modify `vals`
"""
function bitonic_median(vals :: AbstractArray{T}, swap, lo, L, stride, lt::F1, by::F2) where {T,F1,F2}
sync_threads()
bitonic_lt(i1, i2) = @inbounds flex_lt(swap[i1 + 1], swap[i2 + 1], false, lt, by)

@inbounds swap[threadIdx().x] = vals[lo + threadIdx().x * stride]
sync_threads()
old_val = zero(eltype(swap))

log_blockDim = begin
out = 0
k = blockDim().x
while k > 1
k = k >> 1
out += 1
end
out
end

log_k = 1
while log_k <= log_blockDim
k = 1 << log_k
j = k ÷ 2

while j > 0
i = threadIdx().x - 1
l = xor(i, j)
to_swap = (i & k) == 0 && bitonic_lt(l, i) || (i & k) != 0 && bitonic_lt(i, l)
to_swap = to_swap == (i < l)

if to_swap
@inbounds old_val = swap[l + 1]
end
sync_threads()
if to_swap
@inbounds swap[i + 1] = old_val
end
sync_threads()
j = j ÷ 2
end
log_k += 1
end
sync_threads()
return @inbounds swap[blockDim().x ÷ 2]
end

"""
Performs bubble sort on `vals` starting after `lo` and going for min(`L`, `blockDim().x`)
Expand Down Expand Up @@ -311,15 +360,15 @@ function qsort_kernel(vals::AbstractArray{T,N}, lo, hi, parity, sync::Val{S}, sy
view(vals, idxs...)
end

# step 1: single block bubble sort. It'll either finish sorting a subproblem or
# step 1: single block sort. It'll either finish sorting a subproblem or
# help select a pivot value
bubble_sort(slice, swap, lo, L, L <= blockDim().x ? 1 : L ÷ blockDim().x, lt, by)

if L <= blockDim().x
bubble_sort(slice, swap, lo, L, 1, lt, by)
return
end

pivot = @inbounds slice[lo + (blockDim().x ÷ 2) * (L ÷ blockDim().x)]
pivot = bitonic_median(slice, swap, lo, L, L ÷ blockDim().x, lt, by)

# step 2: use pivot to partition into batches
call_batch_partition(slice, pivot, swap, b_sums, lo, hi, parity, sync, lt, by)
Expand Down

0 comments on commit 9766333

Please sign in to comment.