Skip to content

Commit

Permalink
ngenerate/nsplat: multidimensional algorithms on AbstractArrays
Browse files Browse the repository at this point in the history
  • Loading branch information
timholy committed Nov 21, 2014
1 parent 23f54d1 commit 2726e3c
Showing 1 changed file with 126 additions and 103 deletions.
229 changes: 126 additions & 103 deletions base/multidimensional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,12 @@ using .IteratorsMD

### From array.jl

@ngenerate N Void function checksize(A::AbstractArray, I::NTuple{N, Any}...)
@nexprs N d->(size(A, d) == length(I_d) || throw(DimensionMismatch("index $d has length $(length(I_d)), but size(A, $d) = $(size(A,d))")))
nothing
stagedfunction checksize(A::AbstractArray, I...)
N = length(I)
quote
@nexprs $N d->(size(A, d) == length(I[d]) || throw(DimensionMismatch("index $d has length $(length(I[d])), but size(A, $d) = $(size(A,d))")))
nothing
end
end

unsafe_getindex(v::BitArray, ind::Int) = Base.unsafe_bitgetindex(v.chunks, ind)
Expand Down Expand Up @@ -207,17 +210,19 @@ end
end


@ngenerate N NTuple{N,Vector{Int}} function findn{T,N}(A::AbstractArray{T,N})
nnzA = countnz(A)
@nexprs N d->(I_d = Array(Int, nnzA))
k = 1
@nloops N i A begin
@inbounds if (@nref N A i) != zero(T)
@nexprs N d->(I_d[k] = i_d)
k += 1
stagedfunction findn{T,N}(A::AbstractArray{T,N})
quote
nnzA = countnz(A)
@nexprs $N d->(I_d = Array(Int, nnzA))
k = 1
@nloops $N i A begin
@inbounds if (@nref $N A i) != zero(T)
@nexprs $N d->(I_d[k] = i_d)
k += 1
end
end
@ntuple $N I
end
@ntuple N I
end


Expand Down Expand Up @@ -334,56 +339,70 @@ end


cumsum(A::AbstractArray, axis::Integer=1) = cumsum!(similar(A, Base._cumsum_type(A)), A, axis)
cumsum!(B, A::AbstractArray) = cumsum!(B, A, 1)
cumprod(A::AbstractArray, axis::Integer=1) = cumprod!(similar(A), A, axis)
cumprod!(B, A) = cumprod!(B, A, 1)

for (f, op) in ((:cumsum!, :+),
(:cumprod!, :*))
@eval begin
@ngenerate N typeof(B) function ($f){T,N}(B, A::AbstractArray{T,N}, axis::Integer=1)
if size(B, axis) < 1
return B
end
size(B) == size(A) || throw(DimensionMismatch("Size of B must match A"))
if axis == 1
# We can accumulate to a temporary variable, which allows register usage and will be slightly faster
@inbounds @nloops N i d->(d > 1 ? (1:size(A,d)) : (1:1)) begin
tmp = convert(eltype(B), @nref(N, A, i))
@nref(N, B, i) = tmp
for i_1 = 2:size(A,1)
tmp = ($op)(tmp, @nref(N, A, i))
@nref(N, B, i) = tmp
end
stagedfunction ($f){T,N}(B, A::AbstractArray{T,N}, axis::Integer)
quote
if size(B, axis) < 1
return B
end
else
@nexprs N d->(isaxis_d = axis == d)
# Copy the initial element in each 1d vector along dimension `axis`
@inbounds @nloops N i d->(d == axis ? (1:1) : (1:size(A,d))) @nref(N, B, i) = @nref(N, A, i)
# Accumulate
@inbounds @nloops N i d->((1+isaxis_d):size(A, d)) d->(j_d = i_d - isaxis_d) begin
@nref(N, B, i) = ($op)(@nref(N, B, j), @nref(N, A, i))
size(B) == size(A) || throw(DimensionMismatch("Size of B must match A"))
if axis == 1
# We can accumulate to a temporary variable, which allows register usage and will be slightly faster
@inbounds @nloops $N i d->(d > 1 ? (1:size(A,d)) : (1:1)) begin
tmp = convert(eltype(B), @nref($N, A, i))
@nref($N, B, i) = tmp
for i_1 = 2:size(A,1)
tmp = ($($op))(tmp, @nref($N, A, i))
@nref($N, B, i) = tmp
end
end
else
@nexprs $N d->(isaxis_d = axis == d)
# Copy the initial element in each 1d vector along dimension `axis`
@inbounds @nloops $N i d->(d == axis ? (1:1) : (1:size(A,d))) @nref($N, B, i) = @nref($N, A, i)
# Accumulate
@inbounds @nloops $N i d->((1+isaxis_d):size(A, d)) d->(j_d = i_d - isaxis_d) begin
@nref($N, B, i) = ($($op))(@nref($N, B, j), @nref($N, A, i))
end
end
B
end
B
end
end
end

### from abstractarray.jl

@ngenerate N typeof(A) function fill!{T,N}(A::AbstractArray{T,N}, x)
@nloops N i A begin
@inbounds (@nref N A i) = x
function fill!(A::AbstractArray, x)
for I in eachindex(A)
@inbounds A[I] = x
end
A
end

@ngenerate N typeof(dest) function copy!{T,N}(dest::AbstractArray{T,N}, src::AbstractArray{T,N})
if @nall N d->(size(dest,d) == size(src,d))
@nloops N i dest begin
@inbounds (@nref N dest i) = (@nref N src i)
function copy!{T,N}(dest::AbstractArray{T,N}, src::AbstractArray{T,N})
samesize = true
for d = 1:N
if size(dest,d) != size(src,d)
samesize = false
break
end
end
if samesize
for I in eachindex(dest)
@inbounds dest[I] = src[I]
end
else
invoke(copy!, (typeof(dest), Any), dest, src)
length(dest) == length(src) || throw(DimensionMismatch("Inconsistent lengths"))
for (Idest, Isrc) in zip(eachindex(dest),eachindex(src))
@inbounds dest[Idest] = src[Isrc]
end
end
dest
end
Expand Down Expand Up @@ -643,19 +662,21 @@ end

## findn

@ngenerate N NTuple{N,Vector{Int}} function findn{N}(B::BitArray{N})
nnzB = countnz(B)
I = ntuple(N, x->Array(Int, nnzB))
if nnzB > 0
count = 1
@nloops N i B begin
if (@nref N B i) # TODO: should avoid bounds checking
@nexprs N d->(I[d][count] = i_d)
count += 1
stagedfunction findn{N}(B::BitArray{N})
quote
nnzB = countnz(B)
I = ntuple($N, x->Array(Int, nnzB))
if nnzB > 0
count = 1
@nloops $N i B begin
if (@nref $N B i) # TODO: should avoid bounds checking
@nexprs $N d->(I[d][count] = i_d)
count += 1
end
end
end
return I
end
return I
end

## isassigned
Expand Down Expand Up @@ -720,70 +741,72 @@ immutable Prehashed
end
hash(x::Prehashed) = x.hash

@ngenerate N typeof(A) function unique{T,N}(A::AbstractArray{T,N}, dim::Int)
1 <= dim <= N || return copy(A)
hashes = zeros(UInt, size(A, dim))
stagedfunction unique{T,N}(A::AbstractArray{T,N}, dim::Int)
quote
1 <= dim <= $N || return copy(A)
hashes = zeros(UInt, size(A, dim))

# Compute hash for each row
k = 0
@nloops N i A d->(if d == dim; k = i_d; end) begin
@inbounds hashes[k] = hash(hashes[k], hash((@nref N A i)))
end
# Compute hash for each row
k = 0
@nloops $N i A d->(if d == dim; k = i_d; end) begin
@inbounds hashes[k] = hash(hashes[k], hash((@nref $N A i)))
end

# Collect index of first row for each hash
uniquerow = Array(Int, size(A, dim))
firstrow = Dict{Prehashed,Int}()
for k = 1:size(A, dim)
uniquerow[k] = get!(firstrow, Prehashed(hashes[k]), k)
end
uniquerows = collect(values(firstrow))
# Collect index of first row for each hash
uniquerow = Array(Int, size(A, dim))
firstrow = Dict{Prehashed,Int}()
for k = 1:size(A, dim)
uniquerow[k] = get!(firstrow, Prehashed(hashes[k]), k)
end
uniquerows = collect(values(firstrow))

# Check for collisions
collided = falses(size(A, dim))
@inbounds begin
@nloops N i A d->(if d == dim
# Check for collisions
collided = falses(size(A, dim))
@inbounds begin
@nloops $N i A d->(if d == dim
k = i_d
j_d = uniquerow[k]
else
j_d = i_d
end) begin
if (@nref N A j) != (@nref N A i)
collided[k] = true
end
if (@nref $N A j) != (@nref $N A i)
collided[k] = true
end
end
end
end

if any(collided)
nowcollided = BitArray(size(A, dim))
while any(collided)
# Collect index of first row for each collided hash
empty!(firstrow)
for j = 1:size(A, dim)
collided[j] || continue
uniquerow[j] = get!(firstrow, Prehashed(hashes[j]), j)
end
for v in values(firstrow)
push!(uniquerows, v)
end
if any(collided)
nowcollided = BitArray(size(A, dim))
while any(collided)
# Collect index of first row for each collided hash
empty!(firstrow)
for j = 1:size(A, dim)
collided[j] || continue
uniquerow[j] = get!(firstrow, Prehashed(hashes[j]), j)
end
for v in values(firstrow)
push!(uniquerows, v)
end

# Check for collisions
fill!(nowcollided, false)
@nloops N i A d->begin
if d == dim
k = i_d
j_d = uniquerow[k]
(!collided[k] || j_d == k) && continue
else
j_d = i_d
end
end begin
if (@nref N A j) != (@nref N A i)
nowcollided[k] = true
# Check for collisions
fill!(nowcollided, false)
@nloops $N i A d->begin
if d == dim
k = i_d
j_d = uniquerow[k]
(!collided[k] || j_d == k) && continue
else
j_d = i_d
end
end begin
if (@nref $N A j) != (@nref $N A i)
nowcollided[k] = true
end
end
(collided, nowcollided) = (nowcollided, collided)
end
(collided, nowcollided) = (nowcollided, collided)
end
end

@nref N A d->d == dim ? sort!(uniquerows) : (1:size(A, d))
@nref $N A d->d == dim ? sort!(uniquerows) : (1:size(A, d))
end
end

0 comments on commit 2726e3c

Please sign in to comment.