Skip to content

Commit

Permalink
Merge pull request #12303 from JuliaLang/kms/selectperm
Browse files Browse the repository at this point in the history
RFC: quicksort refactor, addition of selectperm
  • Loading branch information
kmsquire committed Jul 24, 2015
2 parents a7b8159 + 3b6221f commit 4c7a9d0
Show file tree
Hide file tree
Showing 4 changed files with 206 additions and 82 deletions.
3 changes: 3 additions & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ export
ObjectIdDict,
OrdinalRange,
Pair,
PartialQuickSort,
PollingFileWatcher,
ProcessGroup,
QuickSort,
Expand Down Expand Up @@ -594,6 +595,8 @@ export
sort!,
sort,
sortcols,
selectperm,
selectperm!,
sortperm,
sortperm!,
sortrows,
Expand Down
210 changes: 136 additions & 74 deletions base/sort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,17 @@ export # also exported by Base
# order & algorithm:
sort,
sort!,
selectperm,
selectperm!,
sortperm,
sortperm!,
sortrows,
sortcols,
# algorithms:
InsertionSort,
QuickSort,
MergeSort
MergeSort,
PartialQuickSort

export # not exported by Base
Algorithm,
Expand All @@ -55,73 +58,17 @@ issorted(itr;
lt=isless, by=identity, rev::Bool=false, order::Ordering=Forward) =
issorted(itr, ord(lt,by,rev,order))

function select!(v::AbstractVector, k::Int, lo::Int, hi::Int, o::Ordering)
lo <= k <= hi || throw(ArgumentError("select index $k is out of range $lo:$hi"))
@inbounds while lo < hi
if hi-lo == 1
if lt(o, v[hi], v[lo])
v[lo], v[hi] = v[hi], v[lo]
end
return v[k]
end
pivot = v[(lo+hi)>>>1]
i, j = lo, hi
while true
while lt(o, v[i], pivot); i += 1; end
while lt(o, pivot, v[j]); j -= 1; end
i <= j || break
v[i], v[j] = v[j], v[i]
i += 1; j -= 1
end
if k <= j
hi = j
elseif i <= k
lo = i
else
return pivot
end
end
return v[lo]
end

function select!(v::AbstractVector, r::OrdinalRange, lo::Int, hi::Int, o::Ordering)
isempty(r) && (return v[r])
a, b = extrema(r)
lo <= a <= b <= hi || throw(ArgumentError("selection $r is out of range $lo:$hi"))
@inbounds while true
if lo == a && hi == b
sort!(v, lo, hi, DEFAULT_UNSTABLE, o)
return v[r]
end
pivot = v[(lo+hi)>>>1]
i, j = lo, hi
while true
while lt(o, v[i], pivot); i += 1; end
while lt(o, pivot, v[j]); j -= 1; end
i <= j || break
v[i], v[j] = v[j], v[i]
i += 1; j -= 1
end
if b <= j
hi = j
elseif i <= a
lo = i
else
a <= j && select!(v, a, lo, j, o)
b >= i && select!(v, b, i, hi, o)
sort!(v, a, b, DEFAULT_UNSTABLE, o)
return v[r]
end
end
function select!(v::AbstractVector, k::Union{Int,OrdinalRange}, o::Ordering)
sort!(v, 1, length(v), PartialQuickSort(k), o)
v[k]
end

select!(v::AbstractVector, k::Union{Int,OrdinalRange}, o::Ordering) = select!(v,k,1,length(v),o)
select!(v::AbstractVector, k::Union{Int,OrdinalRange};
lt=isless, by=identity, rev::Bool=false, order::Ordering=Forward) =
select!(v, k, ord(lt,by,rev,order))

select(v::AbstractVector, k::Union{Int,OrdinalRange}; kws...) = select!(copy(v), k; kws...)


# reference on sorted binary search:
# http://www.tbray.org/ongoing/When/200x/2003/03/22/Binary

Expand Down Expand Up @@ -248,6 +195,10 @@ immutable InsertionSortAlg <: Algorithm end
immutable QuickSortAlg <: Algorithm end
immutable MergeSortAlg <: Algorithm end

immutable PartialQuickSort{T <: Union(Int,OrdinalRange)} <: Algorithm
k::T
end

const InsertionSort = InsertionSortAlg()
const QuickSort = QuickSortAlg()
const MergeSort = MergeSortAlg()
Expand All @@ -274,10 +225,20 @@ function sort!(v::AbstractVector, lo::Int, hi::Int, ::InsertionSortAlg, o::Order
return v
end

function sort!(v::AbstractVector, lo::Int, hi::Int, a::QuickSortAlg, o::Ordering)
@inbounds while lo < hi
hi-lo <= SMALL_THRESHOLD && return sort!(v, lo, hi, SMALL_ALGORITHM, o)
# selectpivot!
#
# Given 3 locations in an array (lo, mi, and hi), sort v[lo], v[mi], v[hi]) and
# choose the middle value as a pivot
#
# Upon return, the pivot is in v[lo], and v[hi] is guaranteed to be
# greater than the pivot

@inline function selectpivot!(v::AbstractVector, lo::Int, hi::Int, o::Ordering)
@inbounds begin
mi = (lo+hi)>>>1

# sort the values in v[lo], v[mi], v[hi]

if lt(o, v[mi], v[lo])
v[mi], v[lo] = v[lo], v[mi]
end
Expand All @@ -288,17 +249,43 @@ function sort!(v::AbstractVector, lo::Int, hi::Int, a::QuickSortAlg, o::Ordering
v[hi], v[mi] = v[mi], v[hi]
end
end
v[mi], v[lo] = v[lo], v[mi]
i, j = lo, hi

# move v[mi] to v[lo] and use it as the pivot
v[lo], v[mi] = v[mi], v[lo]
pivot = v[lo]
while true
i += 1; j -= 1;
while lt(o, v[i], pivot); i += 1; end;
while lt(o, pivot, v[j]); j -= 1; end;
i >= j && break
v[i], v[j] = v[j], v[i]
end
v[j], v[lo] = v[lo], v[j]
end

# return the pivot
return pivot
end

# partition!
#
# select a pivot, and partition v according to the pivot

function partition!(v::AbstractVector, lo::Int, hi::Int, o::Ordering)
pivot = selectpivot!(v, lo, hi, o)
# pivot == v[lo], v[hi] > pivot
i, j = lo, hi
@inbounds while true
i += 1; j -= 1
while lt(o, v[i], pivot); i += 1; end;
while lt(o, pivot, v[j]); j -= 1; end;
i >= j && break
v[i], v[j] = v[j], v[i]
end
v[j], v[lo] = pivot, v[j]

# v[j] == pivot
# v[k] >= pivot for k > j
# v[i] <= pivot for i < j
return j
end

function sort!(v::AbstractVector, lo::Int, hi::Int, a::QuickSortAlg, o::Ordering)
@inbounds while lo < hi
hi-lo <= SMALL_THRESHOLD && return sort!(v, lo, hi, SMALL_ALGORITHM, o)
j = partition!(v, lo, hi, o)
if j-lo < hi-j
# recurse on the smaller chunk
# this is necessary to preserve O(log(n))
Expand Down Expand Up @@ -351,6 +338,53 @@ function sort!(v::AbstractVector, lo::Int, hi::Int, a::MergeSortAlg, o::Ordering
return v
end

function sort!(v::AbstractVector, lo::Int, hi::Int, a::PartialQuickSort{Int},
o::Ordering)
@inbounds while lo < hi
hi-lo <= SMALL_THRESHOLD && return sort!(v, lo, hi, SMALL_ALGORITHM, o)
j = partition!(v, lo, hi, o)
if j >= a.k
# we don't need to sort anything bigger than j
hi = j-1
elseif j-lo < hi-j
# recurse on the smaller chunk
# this is necessary to preserve O(log(n))
# stack space in the worst case (rather than O(n))
lo < (j-1) && sort!(v, lo, j-1, a, o)
lo = j+1
else
(j+1) < hi && sort!(v, j+1, hi, a, o)
hi = j-1
end
end
return v
end


function sort!{T<:OrdinalRange}(v::AbstractVector, lo::Int, hi::Int, a::PartialQuickSort{T},
o::Ordering)
@inbounds while lo < hi
hi-lo <= SMALL_THRESHOLD && return sort!(v, lo, hi, SMALL_ALGORITHM, o)
j = partition!(v, lo, hi, o)

if j <= first(a.k)
lo = j+1
elseif j >= last(a.k)
hi = j-1
else
if j-lo < hi-j
lo < (j-1) && sort!(v, lo, j-1, a, o)
lo = j+1
else
hi > (j+1) && sort!(v, j+1, hi, a, o)
hi = j-1
end
end
end
return v
end


## generic sorting methods ##

defalg(v::AbstractArray) = DEFAULT_STABLE
Expand All @@ -369,6 +403,30 @@ end

sort(v::AbstractVector; kws...) = sort!(copy(v); kws...)


## selectperm: the permutation to sort the first k elements of an array ##

selectperm(v::AbstractVector, k::Union(Integer,OrdinalRange); kwargs...) =
selectperm!(Vector{eltype(k)}(length(v)), v, k; kwargs..., initialized=false)

function selectperm!{I<:Integer}(ix::AbstractVector{I}, v::AbstractVector,
k::Union(Int, OrdinalRange);
lt::Function=isless,
by::Function=identity,
rev::Bool=false,
order::Ordering=Forward,
initialized::Bool=false)
if !initialized
@inbounds for i = 1:length(ix)
ix[i] = i
end
end

# do partial quicksort
sort!(ix, PartialQuickSort(k), Perm(ord(lt, by, rev, order), v))
return ix[k]
end

## sortperm: the permutation to sort an array ##

function sortperm(v::AbstractVector;
Expand Down Expand Up @@ -499,6 +557,10 @@ function fpsort!(v::AbstractVector, a::Algorithm, o::Ordering)
return v
end


fpsort!(v::AbstractVector, a::Sort.PartialQuickSort, o::Ordering) =
sort!(v, 1, length(v), a, o)

sort!{T<:Floats}(v::AbstractVector{T}, a::Algorithm, o::DirectOrdering) = fpsort!(v,a,o)
sort!{O<:DirectOrdering,T<:Floats}(v::Vector{Int}, a::Algorithm, o::Perm{O,Vector{T}}) = fpsort!(v,a,o)

Expand Down
37 changes: 36 additions & 1 deletion doc/stdlib/sort.rst
Original file line number Diff line number Diff line change
Expand Up @@ -206,14 +206,33 @@ Order-Related Functions
Variant of ``select!`` which copies ``v`` before partially sorting it, thereby
returning the same thing as ``select!`` but leaving ``v`` unmodified.

.. function:: selectperm(v, k, [alg=<algorithm>,] [by=<transform>,] [lt=<comparison>,] [rev=false])

Return a partial permutation of the the vector ``v``, according to the order
specified by ``by``, ``lt`` and ``rev``, so that ``v[output]`` returns the
first ``k`` (or range of adjacent values if ``k`` is a range) values of a
fully sorted version of ``v``. If ``k`` is a single index (Integer), an
array of the first ``k`` indices is returned; if ``k`` is a range, an array
of those indices is returned. Note that the handling of integer values for
``k`` is different from ``select`` in that it returns a vector of ``k``
elements instead of just the ``k`` th element. Also note that this is
equivalent to, but more efficient than, calling ``sortperm(...)[k]``

.. function:: selectperm!(ix, v, k, [alg=<algorithm>,] [by=<transform>,] [lt=<comparison>,] [rev=false,] [initialized=false])

Like ``selectperm``, but accepts a preallocated index vector ``ix``. If
``initialized`` is ``false`` (the default), ix is initialized to contain the
values ``1:length(ix)``.


Sorting Algorithms
------------------

There are currently three sorting algorithms available in base Julia:
There are currently four sorting algorithms available in base Julia:

- ``InsertionSort``
- ``QuickSort``
- ``PartialQuickSort(k)``
- ``MergeSort``

``InsertionSort`` is an O(n^2) stable sorting algorithm. It is efficient
Expand All @@ -225,6 +244,22 @@ equal will not remain in the same order in which they originally
appeared in the array to be sorted. ``QuickSort`` is the default
algorithm for numeric values, including integers and floats.

``PartialQuickSort(k)`` is similar to ``QuickSort``, but the output array
is only sorted up to index ``k`` if ``k`` is an integer, or in the range
of ``k`` if ``k`` is an ``OrdinalRange``. For example::

x = rand(1:500, 100)
k = 50
k2 = 50:100
s = sort(x; alg=QuickSort)
ps = sort(x; alg=PartialQuickSort(k))
qs = sort(x; alg=PartialQuickSort(k2))
map(issorted, (s, ps, qs)) # => (true, false, false)
map(x->issorted(x[1:k]), (s, ps, qs)) # => (true, true, false)
map(x->issorted(x[k2]), (s, ps, qs)) # => (true, false, true)
s[1:k] == ps[1:k] # => true
s[k2] == qs[k2] # => true

``MergeSort`` is an O(n log n) stable sorting algorithm but is not
in-place – it requires a temporary array of half the size of the
input array – and is typically not quite as fast as ``QuickSort``.
Expand Down
Loading

0 comments on commit 4c7a9d0

Please sign in to comment.