Skip to content

Commit

Permalink
Fix algorithms that assume linear indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
timholy committed Sep 18, 2015
1 parent 66eb856 commit 7bcc947
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 24 deletions.
21 changes: 18 additions & 3 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,10 @@ end
## copy between abstract arrays - generally more efficient
## since a single index variable can be used.

function copy!(dest::AbstractArray, src::AbstractArray)
copy!(dest::AbstractArray, src::AbstractArray) =
copy!(linearindexing(dest), dest, linearindexing(src), src)

function copy!(::LinearIndexing, dest::AbstractArray, ::LinearIndexing, src::AbstractArray)
n = length(src)
n > length(dest) && throw(BoundsError(dest, n))
@inbounds for i = 1:n
Expand All @@ -309,6 +312,16 @@ function copy!(dest::AbstractArray, src::AbstractArray)
return dest
end

function copy!(::LinearIndexing, dest::AbstractArray, ::LinearSlow, src::AbstractArray)
n = length(src)
n > length(dest) && throw(BoundsError(dest, n))
i = 0
@inbounds for a in src
dest[i+=1] = a
end
return dest
end

function copy!(dest::AbstractArray, doffs::Integer, src::AbstractArray)
copy!(dest, doffs, src, 1, length(src))
end
Expand Down Expand Up @@ -395,6 +408,8 @@ start(A::AbstractArray) = (@_inline_meta(); itr = eachindex(A); (itr, start(itr)
next(A::AbstractArray,i) = (@_inline_meta(); (idx, s) = next(i[1], i[2]); (A[idx], (i[1], s)))
done(A::AbstractArray,i) = done(i[1], i[2])

iterstate(i) = i

# eachindex iterates over all indices. LinearSlow definitions are later.
eachindex(A::AbstractArray) = (@_inline_meta(); eachindex(linearindexing(A), A))
eachindex(::LinearFast, A::AbstractArray) = 1:length(A)
Expand Down Expand Up @@ -1013,7 +1028,7 @@ function isequal(A::AbstractArray, B::AbstractArray)
if isa(A,Range) != isa(B,Range)
return false
end
for i in eachindex(A)
for i in eachindex(A,B)
if !isequal(A[i], B[i])
return false
end
Expand All @@ -1037,7 +1052,7 @@ function (==)(A::AbstractArray, B::AbstractArray)
if isa(A,Range) != isa(B,Range)
return false
end
for i in eachindex(A)
for i in eachindex(A,B)
if !(A[i]==B[i])
return false
end
Expand Down
28 changes: 16 additions & 12 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -827,32 +827,36 @@ function findmax(a)
if isempty(a)
throw(ArgumentError("collection must be non-empty"))
end
m = a[1]
mi = 1
for i=2:length(a)
ai = a[i]
i = start(a)
mi = i
m, i = next(a, i)
while !done(a, i)
iold = i
ai, i = next(a, i)
if ai > m || m!=m
m = ai
mi = i
mi = iold
end
end
return (m, mi)
return (m, iterstate(mi))
end

function findmin(a)
if isempty(a)
throw(ArgumentError("collection must be non-empty"))
end
m = a[1]
mi = 1
for i=2:length(a)
ai = a[i]
i = start(a)
mi = i
m, i = next(a, i)
while !done(a, i)
iold = i
ai, i = next(a, i)
if ai < m || m!=m
m = ai
mi = i
mi = iold
end
end
return (m, mi)
return (m, iterstate(mi))
end

indmax(a) = findmax(a)[2]
Expand Down
4 changes: 3 additions & 1 deletion base/multidimensional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
### Multidimensional iterators
module IteratorsMD

import Base: eltype, length, start, done, next, last, getindex, setindex!, linearindexing, min, max, eachindex, ndims
import Base: eltype, length, start, done, next, last, getindex, setindex!, linearindexing, min, max, eachindex, ndims, iterstate
importall ..Base.Operators
import Base: simd_outer_range, simd_inner_length, simd_index, @generated
import Base: @nref, @ncall, @nif, @nexprs, LinearFast, LinearSlow, to_index
Expand Down Expand Up @@ -59,6 +59,8 @@ immutable CartesianRange{I<:CartesianIndex}
stop::I
end

iterstate{CR<:CartesianRange,CI<:CartesianIndex}(i::Tuple{CR,CI}) = i[2]

@generated function CartesianRange{N}(I::CartesianIndex{N})
startargs = fill(1, N)
:(CartesianRange($I($(startargs...)), I))
Expand Down
16 changes: 9 additions & 7 deletions base/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,9 @@ mr_empty(::Abs2Fun, op::MaxFun, T) = abs2(zero(T)::T)
mr_empty(f, op::AndFun, T) = true
mr_empty(f, op::OrFun, T) = false

function _mapreduce{T}(f, op, A::AbstractArray{T})
_mapreduce(f, op, A::AbstractArray) = _mapreduce(f, op, linearindexing(A), A)

function _mapreduce{T}(f, op, ::LinearFast, A::AbstractArray{T})
n = Int(length(A))
if n == 0
return mr_empty(f, op, T)
Expand All @@ -152,7 +154,9 @@ function _mapreduce{T}(f, op, A::AbstractArray{T})
end
end

mapreduce(f, op, A::AbstractArray) = _mapreduce(f, op, A)
_mapreduce{T}(f, op, ::LinearSlow, A::AbstractArray{T}) = mapfoldl(f, op, A)

mapreduce(f, op, A::AbstractArray) = _mapreduce(f, op, linearindexing(A), A)
mapreduce(f, op, a::Number) = f(a)

mapreduce(f, op::Function, A::AbstractArray) = mapreduce(f, specialized_binary(op), A)
Expand Down Expand Up @@ -395,12 +399,10 @@ function count(pred, itr)
return n
end

function count(pred, a::AbstractArray)
function count(pred, A::AbstractArray)
n = 0
for i = 1:length(a)
@inbounds if pred(a[i])
n += 1
end
@inbounds for a in A
pred(a) && (n += 1)
end
return n
end
Expand Down
2 changes: 1 addition & 1 deletion base/sparse/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1017,7 +1017,7 @@ function _mapreducezeros(f, op, T::Type, nzeros::Int, v0)
v
end

function Base._mapreduce{T}(f, op, A::SparseMatrixCSC{T})
function Base._mapreduce{T}(f, op, ::Base.LinearSlow, A::SparseMatrixCSC{T})
z = nnz(A)
n = length(A)
if z == 0
Expand Down

0 comments on commit 7bcc947

Please sign in to comment.