Skip to content

Commit

Permalink
Merge 8cbd17e into e89a2b8
Browse files Browse the repository at this point in the history
  • Loading branch information
rfourquet authored Dec 17, 2017
2 parents e89a2b8 + 8cbd17e commit 308f35a
Show file tree
Hide file tree
Showing 7 changed files with 373 additions and 150 deletions.
14 changes: 6 additions & 8 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1296,25 +1296,23 @@ function empty!(a::Vector)
return a
end

_memcmp(a, b, len) = ccall(:memcmp, Int32, (Ptr{Void}, Ptr{Void}, Csize_t), a, b, len) % Int

# use memcmp for lexcmp on byte arrays
function lexcmp(a::Array{UInt8,1}, b::Array{UInt8,1})
c = ccall(:memcmp, Int32, (Ptr{UInt8}, Ptr{UInt8}, UInt),
a, b, min(length(a),length(b)))
c = _memcmp(a, b, min(length(a),length(b)))
return c < 0 ? -1 : c > 0 ? +1 : cmp(length(a),length(b))
end

const BitIntegerArray{N} = Union{map(T->Array{T,N}, BitInteger_types)...} where N
# use memcmp for == on bit integer types
function ==(a::Arr, b::Arr) where Arr <: BitIntegerArray
size(a) == size(b) && 0 == ccall(
:memcmp, Int32, (Ptr{Void}, Ptr{Void}, UInt), a, b, sizeof(eltype(Arr)) * length(a))
end
==(a::Arr, b::Arr) where {Arr <: BitIntegerArray} =
size(a) == size(b) && 0 == _memcmp(a, b, sizeof(eltype(Arr)) * length(a))

# this is ~20% faster than the generic implementation above for very small arrays
function ==(a::Arr, b::Arr) where Arr <: BitIntegerArray{1}
len = length(a)
len == length(b) && 0 == ccall(
:memcmp, Int32, (Ptr{Void}, Ptr{Void}, UInt), a, b, sizeof(eltype(Arr)) * len)
len == length(b) && 0 == _memcmp(a, b, sizeof(eltype(Arr)) * len)
end

"""
Expand Down
40 changes: 24 additions & 16 deletions base/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ IndexStyle(::Type{<:BitArray}) = IndexLinear()
## aux functions ##

const _msk64 = ~UInt64(0)
@inline _div64(l) = l >>> 6
@inline _div64(l) = l >> 6
@inline _mod64(l) = l & 63
@inline _msk_end(l::Integer) = _msk64 >>> _mod64(-l)
@inline _msk_end(B::BitArray) = _msk_end(length(B))
Expand Down Expand Up @@ -636,6 +636,10 @@ end

@inline function unsafe_bitsetindex!(Bc::Array{UInt64}, x::Bool, i::Int)
i1, i2 = get_chunks_id(i)
_unsafe_bitsetindex!(Bc, x, i1, i2)
end

@inline function _unsafe_bitsetindex!(Bc::Array{UInt64}, x::Bool, i1::Int, i2::Int)
u = UInt64(1) << i2
@inbounds begin
c = Bc[i1]
Expand Down Expand Up @@ -1438,22 +1442,17 @@ circshift!(B::BitVector, i::Integer) = circshift!(B, B, i)

## count & find ##

function count(B::BitArray)
function bitcount(Bc::Vector{UInt64})
n = 0
Bc = B.chunks
@inbounds for i = 1:length(Bc)
n += count_ones(Bc[i])
end
return n
end

# returns the index of the next non-zero element, or 0 if all zeros
function findnext(B::BitArray, start::Integer)
start > 0 || throw(BoundsError(B, start))
start > length(B) && return 0

Bc = B.chunks
count(B::BitArray) = bitcount(B.chunks)

function unsafe_bitfindnext(Bc::Vector{UInt64}, start::Integer)
chunk_start = _div64(start-1)+1
within_chunk_start = _mod64(start-1)
mask = _msk64 << within_chunk_start
Expand All @@ -1471,6 +1470,14 @@ function findnext(B::BitArray, start::Integer)
end
return 0
end

# returns the index of the next non-zero element, or 0 if all zeros
function findnext(B::BitArray, start::Integer)
start > 0 || throw(BoundsError(B, start))
start > length(B) && return 0
unsafe_bitfindnext(B.chunks, start)
end

#findfirst(B::BitArray) = findnext(B, 1) ## defined in array.jl

# aux function: same as findnext(~B, start), but performed without temporaries
Expand Down Expand Up @@ -1527,13 +1534,7 @@ function findnext(testf::Function, B::BitArray, start::Integer)
end
#findfirst(testf::Function, B::BitArray) = findnext(testf, B, 1) ## defined in array.jl

# returns the index of the previous non-zero element, or 0 if all zeros
function findprev(B::BitArray, start::Integer)
start > 0 || return 0
start > length(B) && throw(BoundsError(B, start))

Bc = B.chunks

function unsafe_bitfindprev(Bc::Vector{UInt64}, start::Integer)
chunk_start = _div64(start-1)+1
mask = _msk_end(start)

Expand All @@ -1551,6 +1552,13 @@ function findprev(B::BitArray, start::Integer)
return 0
end

# returns the index of the previous non-zero element, or 0 if all zeros
function findprev(B::BitArray, start::Integer)
start > 0 || return 0
start > length(B) && throw(BoundsError(B, start))
unsafe_bitfindprev(B.chunks, start)
end

function findprevnot(B::BitArray, start::Integer)
start > 0 || return 0
start > length(B) && throw(BoundsError(B, start))
Expand Down
Loading

0 comments on commit 308f35a

Please sign in to comment.