Skip to content

Commit

Permalink
Support passing a workspace vector throughout sorting methods and use…
Browse files Browse the repository at this point in the history
… this feature in sort(A; dims) (#45330)

Co-authored-by: Lilith Hafner <[email protected]>
  • Loading branch information
LilithHafner and Lilith Hafner authored Jun 2, 2022
1 parent b46c14e commit bd8dbc3
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 36 deletions.
95 changes: 59 additions & 36 deletions base/sort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ module Sort
import ..@__MODULE__, ..parentmodule
const Base = parentmodule(@__MODULE__)
using .Base.Order
using .Base: copymutable, LinearIndices, length, (:), iterate,
using .Base: copymutable, LinearIndices, length, (:), iterate, elsize,
eachindex, axes, first, last, similar, zip, OrdinalRange, firstindex, lastindex,
AbstractVector, @inbounds, AbstractRange, @eval, @inline, Vector, @noinline,
AbstractMatrix, AbstractUnitRange, isless, identity, eltype, >, <, <=, >=, |, +, -, *, !,
Expand Down Expand Up @@ -599,12 +599,13 @@ function sort!(v::AbstractVector, lo::Integer, hi::Integer, a::QuickSortAlg, o::
return v
end

function sort!(v::AbstractVector, lo::Integer, hi::Integer, a::MergeSortAlg, o::Ordering, t=similar(v,0))
function sort!(v::AbstractVector{T}, lo::Integer, hi::Integer, a::MergeSortAlg, o::Ordering,
t0::Union{AbstractVector{T}, Nothing}=nothing) where T
@inbounds if lo < hi
hi-lo <= SMALL_THRESHOLD && return sort!(v, lo, hi, SMALL_ALGORITHM, o)

m = midpoint(lo, hi)
(length(t) < m-lo+1) && resize!(t, m-lo+1)
t = workspace(v, t0, m-lo+1)

sort!(v, lo, m, a, o, t)
sort!(v, m+1, hi, a, o, t)
Expand Down Expand Up @@ -731,7 +732,8 @@ end

# For AbstractVector{Bool}, counting sort is always best.
# This is an implementation of counting sort specialized for Bools.
function sort!(v::AbstractVector{<:Bool}, lo::Integer, hi::Integer, a::AdaptiveSort, o::Ordering)
function sort!(v::AbstractVector{B}, lo::Integer, hi::Integer, a::AdaptiveSort, o::Ordering,
t::Union{AbstractVector{B}, Nothing}=nothing) where {B <: Bool}
first = lt(o, false, true) ? false : lt(o, true, false) ? true : return v
count = 0
@inbounds for i in lo:hi
Expand All @@ -744,6 +746,10 @@ function sort!(v::AbstractVector{<:Bool}, lo::Integer, hi::Integer, a::AdaptiveS
v
end

workspace(v::AbstractVector, ::Nothing, len::Integer) = similar(v, len)
function workspace(v::AbstractVector{T}, t::AbstractVector{T}, len::Integer) where T
length(t) < len ? resize!(t, len) : t
end
maybe_unsigned(x::Integer) = x # this is necessary to avoid calling unsigned on BigInt
maybe_unsigned(x::BitSigned) = unsigned(x)
function _extrema(v::AbstractVector, lo::Integer, hi::Integer, o::Ordering)
Expand All @@ -755,10 +761,11 @@ function _extrema(v::AbstractVector, lo::Integer, hi::Integer, o::Ordering)
end
mn, mx
end
function sort!(v::AbstractVector, lo::Integer, hi::Integer, a::AdaptiveSort, o::Ordering)
function sort!(v::AbstractVector{T}, lo::Integer, hi::Integer, a::AdaptiveSort, o::Ordering,
t::Union{AbstractVector{T}, Nothing}=nothing) where T
# if the sorting task is not UIntMappable, then we can't radix sort or sort_int_range!
# so we skip straight to the fallback algorithm which is comparison based.
U = UIntMappable(eltype(v), o)
U = UIntMappable(T, o)
U === nothing && return sort!(v, lo, hi, a.fallback, o)

# to avoid introducing excessive detection costs for the trivial sorting problem
Expand All @@ -783,7 +790,7 @@ function sort!(v::AbstractVector, lo::Integer, hi::Integer, a::AdaptiveSort, o::
# UInt128 does not support fast bit shifting so we never
# dispatch to radix sort but we may still perform count sort
if sizeof(U) > 8
if eltype(v) <: Integer && o isa DirectOrdering
if T <: Integer && o isa DirectOrdering
v_min, v_max = _extrema(v, lo, hi, Forward)
v_range = maybe_unsigned(v_max-v_min)
v_range == 0 && return v # all same
Expand All @@ -799,7 +806,7 @@ function sort!(v::AbstractVector, lo::Integer, hi::Integer, a::AdaptiveSort, o::

v_min, v_max = _extrema(v, lo, hi, o)
lt(o, v_min, v_max) || return v # all same
if eltype(v) <: Integer && o isa DirectOrdering
if T <: Integer && o isa DirectOrdering
R = o === Reverse
v_range = maybe_unsigned(R ? v_min-v_max : v_max-v_min)
if v_range < div(lenm1, 2) # count sort will be superior if v's range is very small
Expand Down Expand Up @@ -849,7 +856,7 @@ function sort!(v::AbstractVector, lo::Integer, hi::Integer, a::AdaptiveSort, o::
u[i] -= u_min
end

u2 = radix_sort!(u, lo, hi, bits, similar(u))
u2 = radix_sort!(u, lo, hi, bits, reinterpret(U, workspace(v, t, hi)))
uint_unmap!(v, u2, lo, hi, o, u_min)
end

Expand All @@ -860,8 +867,14 @@ defalg(v::AbstractArray{<:Union{Number, Missing}}) = DEFAULT_UNSTABLE
defalg(v::AbstractArray{Missing}) = DEFAULT_UNSTABLE # for method disambiguation
defalg(v::AbstractArray{Union{}}) = DEFAULT_UNSTABLE # for method disambiguation

function sort!(v::AbstractVector, alg::Algorithm, order::Ordering)
sort!(v,firstindex(v),lastindex(v),alg,order)
function sort!(v::AbstractVector{T}, alg::Algorithm,
order::Ordering, t::Union{AbstractVector{T}, Nothing}=nothing) where T
sort!(v, firstindex(v), lastindex(v), alg, order, t)
end

function sort!(v::AbstractVector{T}, lo::Integer, hi::Integer, alg::Algorithm,
order::Ordering, t::Union{AbstractVector{T}, Nothing}=nothing) where T
sort!(v, lo, hi, alg, order)
end

"""
Expand Down Expand Up @@ -904,13 +917,14 @@ julia> v = [(1, "c"), (3, "a"), (2, "b")]; sort!(v, by = x -> x[2]); v
(1, "c")
```
"""
function sort!(v::AbstractVector;
function sort!(v::AbstractVector{T};
alg::Algorithm=defalg(v),
lt=isless,
by=identity,
rev::Union{Bool,Nothing}=nothing,
order::Ordering=Forward)
sort!(v, alg, ord(lt,by,rev,order))
order::Ordering=Forward,
workspace::Union{AbstractVector{T}, Nothing}=nothing) where T
sort!(v, alg, ord(lt,by,rev,order), workspace)
end

# sort! for vectors of few unique integers
Expand Down Expand Up @@ -1098,7 +1112,8 @@ function sortperm(v::AbstractVector;
lt=isless,
by=identity,
rev::Union{Bool,Nothing}=nothing,
order::Ordering=Forward)
order::Ordering=Forward,
workspace::Union{AbstractVector, Nothing}=nothing)
ordr = ord(lt,by,rev,order)
if ordr === Forward && isa(v,Vector) && eltype(v)<:Integer
n = length(v)
Expand All @@ -1112,7 +1127,7 @@ function sortperm(v::AbstractVector;
end
end
p = copymutable(eachindex(v))
sort!(p, alg, Perm(ordr,v))
sort!(p, alg, Perm(ordr,v), workspace)
end


Expand All @@ -1139,13 +1154,14 @@ julia> v[p]
3
```
"""
function sortperm!(x::AbstractVector{<:Integer}, v::AbstractVector;
function sortperm!(x::AbstractVector{T}, v::AbstractVector;
alg::Algorithm=DEFAULT_UNSTABLE,
lt=isless,
by=identity,
rev::Union{Bool,Nothing}=nothing,
order::Ordering=Forward,
initialized::Bool=false)
initialized::Bool=false,
workspace::Union{AbstractVector{T}, Nothing}=nothing) where T <: Integer
if axes(x,1) != axes(v,1)
throw(ArgumentError("index vector must have the same length/indices as the source vector, $(axes(x,1)) != $(axes(v,1))"))
end
Expand All @@ -1154,7 +1170,7 @@ function sortperm!(x::AbstractVector{<:Integer}, v::AbstractVector;
x[i] = i
end
end
sort!(x, alg, Perm(ord(lt,by,rev,order),v))
sort!(x, alg, Perm(ord(lt,by,rev,order),v), workspace)
end

# sortperm for vectors of few unique integers
Expand Down Expand Up @@ -1212,33 +1228,34 @@ julia> sort(A, dims = 2)
1 2
```
"""
function sort(A::AbstractArray;
function sort(A::AbstractArray{T};
dims::Integer,
alg::Algorithm=DEFAULT_UNSTABLE,
lt=isless,
by=identity,
rev::Union{Bool,Nothing}=nothing,
order::Ordering=Forward)
order::Ordering=Forward,
workspace::Union{AbstractVector{T}, Nothing}=similar(A, 0)) where T
dim = dims
order = ord(lt,by,rev,order)
n = length(axes(A, dim))
if dim != 1
pdims = (dim, setdiff(1:ndims(A), dim)...) # put the selected dimension first
Ap = permutedims(A, pdims)
Av = vec(Ap)
sort_chunks!(Av, n, alg, order)
sort_chunks!(Av, n, alg, order, workspace)
permutedims(Ap, invperm(pdims))
else
Av = A[:]
sort_chunks!(Av, n, alg, order)
sort_chunks!(Av, n, alg, order, workspace)
reshape(Av, axes(A))
end
end

@noinline function sort_chunks!(Av, n, alg, order)
@noinline function sort_chunks!(Av, n, alg, order, t)
inds = LinearIndices(Av)
for s = first(inds):n:last(inds)
sort!(Av, s, s+n-1, alg, order)
sort!(Av, s, s+n-1, alg, order, t)
end
Av
end
Expand Down Expand Up @@ -1272,13 +1289,14 @@ julia> sort!(A, dims = 2); A
3 4
```
"""
function sort!(A::AbstractArray;
function sort!(A::AbstractArray{T};
dims::Integer,
alg::Algorithm=defalg(A),
lt=isless,
by=identity,
rev::Union{Bool,Nothing}=nothing,
order::Ordering=Forward)
order::Ordering=Forward,
workspace::Union{AbstractVector{T}, Nothing}=nothing) where T
ordr = ord(lt, by, rev, order)
nd = ndims(A)
k = dims
Expand All @@ -1288,7 +1306,7 @@ function sort!(A::AbstractArray;
remdims = ntuple(i -> i == k ? 1 : axes(A, i), nd)
for idx in CartesianIndices(remdims)
Av = view(A, ntuple(i -> i == k ? Colon() : idx[i], nd)...)
sort!(Av, alg, ordr)
sort!(Av, alg, ordr, workspace)
end
A
end
Expand Down Expand Up @@ -1505,10 +1523,11 @@ issignleft(o::ForwardOrdering, x::Floats) = lt(o, x, zero(x))
issignleft(o::ReverseOrdering, x::Floats) = lt(o, x, -zero(x))
issignleft(o::Perm, i::Integer) = issignleft(o.order, o.data[i])

function fpsort!(v::AbstractVector, a::Algorithm, o::Ordering)
function fpsort!(v::AbstractVector, a::Algorithm, o::Ordering,
t::Union{AbstractVector, Nothing}=nothing)
# fpsort!'s optimizations speed up comparisons, of which there are O(nlogn).
# The overhead is O(n). For n < 10, it's not worth it.
length(v) < 10 && return sort!(v, firstindex(v), lastindex(v), SMALL_ALGORITHM, o)
length(v) < 10 && return sort!(v, firstindex(v), lastindex(v), SMALL_ALGORITHM, o, t)

i, j = lo, hi = specials2end!(v,a,o)
@inbounds while true
Expand All @@ -1518,19 +1537,23 @@ function fpsort!(v::AbstractVector, a::Algorithm, o::Ordering)
v[i], v[j] = v[j], v[i]
i += 1; j -= 1
end
sort!(v, lo, j, a, left(o))
sort!(v, i, hi, a, right(o))
sort!(v, lo, j, a, left(o), t)
sort!(v, i, hi, a, right(o), t)
return v
end


fpsort!(v::AbstractVector, a::Sort.PartialQuickSort, o::Ordering) =
sort!(v, firstindex(v), lastindex(v), a, o)

sort!(v::FPSortable, a::Algorithm, o::DirectOrdering) =
fpsort!(v, a, o)
sort!(v::AbstractVector{<:Union{Signed, Unsigned}}, a::Algorithm, o::Perm{<:DirectOrdering,<:FPSortable}) =
fpsort!(v, a, o)
function sort!(v::FPSortable, a::Algorithm, o::DirectOrdering,
t::Union{FPSortable, Nothing}=nothing)
fpsort!(v, a, o, t)
end
function sort!(v::AbstractVector{<:Union{Signed, Unsigned}}, a::Algorithm,
o::Perm{<:DirectOrdering,<:FPSortable}, t::Union{AbstractVector, Nothing}=nothing)
fpsort!(v, a, o, t)
end

end # module Sort.Float

Expand Down
28 changes: 28 additions & 0 deletions test/sorting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,34 @@ end
end
end

@testset "workspace()" begin
for v in [[1, 2, 3], [0.0]]
for t0 in vcat([nothing], [similar(v,i) for i in 1:5]), len in 0:5
t = Base.Sort.workspace(v, t0, len)
@test eltype(t) == eltype(v)
@test length(t) >= len
@test firstindex(t) == 1
end
end
end

@testset "sort(x; workspace=w) " begin
for n in [1,10,100,1000]
v = rand(n)
w = [0.0]
@test sort(v) == sort(v; workspace=w)
@test sort!(copy(v)) == sort!(copy(v); workspace=w)
@test sortperm(v) == sortperm(v; workspace=[4])
@test sortperm!(Vector{Int}(undef, n), v) == sortperm!(Vector{Int}(undef, n), v; workspace=[4])

n > 100 && continue
M = rand(n, n)
@test sort(M; dims=2) == sort(M; dims=2, workspace=w)
@test sort!(copy(M); dims=1) == sort!(copy(M); dims=1, workspace=w)
end
end


@testset "searchsorted" begin
numTypes = [ Int8, Int16, Int32, Int64, Int128,
UInt8, UInt16, UInt32, UInt64, UInt128,
Expand Down

8 comments on commit bd8dbc3

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Executing the daily package evaluation, I will reply here when finished:

@nanosoldier runtests(ALL, isdaily = true)

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your package evaluation job has completed - possible new issues were detected. A full report can be found here.

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Executing the daily package evaluation, I will reply here when finished:

@nanosoldier runtests(ALL, isdaily = true)

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your package evaluation job has completed - possible issues were detected. A full report can be found here.

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Executing the daily package evaluation, I will reply here when finished:

@nanosoldier runtests(ALL, isdaily = true)

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your package evaluation job has completed - possible issues were detected. A full report can be found here.

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Executing the daily package evaluation, I will reply here when finished:

@nanosoldier runtests(ALL, isdaily = true)

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your package evaluation job has completed - possible issues were detected. A full report can be found here.

Please sign in to comment.