Skip to content

Commit

Permalink
ngenerate/nsplat: getindex methods for BitArrays
Browse files Browse the repository at this point in the history
  • Loading branch information
timholy committed Nov 21, 2014
1 parent b054272 commit a030d08
Showing 1 changed file with 76 additions and 49 deletions.
125 changes: 76 additions & 49 deletions base/multidimensional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -439,26 +439,33 @@ end
# bounds check and is defined in bitarray.jl)
# (code is duplicated for safe and unsafe versions for performance reasons)

@ngenerate N Bool function unsafe_getindex(B::BitArray, I_0::Int, I::NTuple{N,Int}...)
stride = 1
index = I_0
@nexprs N d->begin
stride *= size(B,d)
index += (I_d - 1) * stride
stagedfunction unsafe_getindex(B::BitArray, I_0::Int, I::Int...)
N = length(I)
quote
stride = 1
index = I_0
@nexprs $N d->begin
stride *= size(B,d)
index += (I[d] - 1) * stride
end
return unsafe_getindex(B, index)
end
return unsafe_getindex(B, index)
end

@ngenerate N Bool function getindex(B::BitArray, I_0::Int, I::NTuple{N,Int}...)
stride = 1
index = I_0
@nexprs N d->begin
l = size(B,d)
stride *= l
1 <= I_{d-1} <= l || throw(BoundsError())
index += (I_d - 1) * stride
stagedfunction getindex(B::BitArray, I_0::Int, I::Int...)
N = length(I)
quote
stride = 1
index = I_0
@nexprs $N d->(I_d = I[d])
@nexprs $N d->begin
l = size(B,d)
stride *= l
1 <= I_{d-1} <= l || throw(BoundsError())
index += (I_d - 1) * stride
end
return B[index]
end
return B[index]
end

# contiguous multidimensional indexing: if the first dimension is a range,
Expand All @@ -477,53 +484,73 @@ end

getindex{T<:Real}(B::BitArray, I0::UnitRange{T}) = getindex(B, to_index(I0))

@ngenerate N BitArray{length(index_shape(I0, I...))} function unsafe_getindex(B::BitArray, I0::UnitRange{Int}, I::NTuple{N,Union(Int,UnitRange{Int})}...)
X = BitArray(index_shape(I0, I...))

f0 = first(I0)
l0 = length(I0)

gap_lst_1 = 0
@nexprs N d->(gap_lst_{d+1} = length(I_d))
stride = 1
ind = f0
@nexprs N d->begin
stride *= size(B, d)
stride_lst_d = stride
ind += stride * (first(I_d) - 1)
gap_lst_{d+1} *= stride
end
stagedfunction unsafe_getindex(B::BitArray, I0::UnitRange{Int}, I::Union(Int,UnitRange{Int})...)
N = length(I)
Isplat = Expr[:(I[$d]) for d = 1:N]
quote
@nexprs $N d->(I_d = I[d])
X = BitArray(index_shape(I0, $(Isplat...)))

f0 = first(I0)
l0 = length(I0)

gap_lst_1 = 0
@nexprs $N d->(gap_lst_{d+1} = length(I_d))
stride = 1
ind = f0
@nexprs $N d->begin
stride *= size(B, d)
stride_lst_d = stride
ind += stride * (first(I_d) - 1)
gap_lst_{d+1} *= stride
end

storeind = 1
@nloops(N, i, d->I_d,
d->nothing, # PRE
d->(ind += stride_lst_d - gap_lst_d), # POST
storeind = 1
@nloops($N, i, d->I_d,
d->nothing, # PRE
d->(ind += stride_lst_d - gap_lst_d), # POST
begin # BODY
copy_chunks!(X.chunks, storeind, B.chunks, ind, l0)
storeind += l0
end)
return X
end)
return X
end
end

# general multidimensional non-scalar indexing

@ngenerate N BitArray{length(index_shape(I...))} function unsafe_getindex(B::BitArray, I::NTuple{N,Union(Int,AbstractVector{Int})}...)
X = BitArray(index_shape(I...))
Xc = X.chunks
stagedfunction unsafe_getindex(B::BitArray, I::Union(Int,AbstractVector{Int})...)
N = length(I)
Isplat = Expr[:(I[$d]) for d = 1:N]
quote
@nexprs $N d->(I_d = I[d])
X = BitArray(index_shape($(Isplat...)))
Xc = X.chunks

ind = 1
@nloops N i d->I_d begin
unsafe_bitsetindex!(Xc, (@ncall N unsafe_getindex B i), ind)
ind += 1
stride_1 = 1
@nexprs $N d->(stride_{d+1} = stride_d * size(B, d))
@nexprs 1 d->(offset_{$N} = 1)
ind = 1
@nloops($N, i, d->I_d,
d->(offset_{d-1} = offset_d + (i_d-1)*stride_d), # PRE
begin
unsafe_bitsetindex!(Xc, B[offset_0], ind)
ind += 1
end)
return X
end
return X
end

# general version with Real (or logical) indexing which dispatches on the appropriate method

@ngenerate N BitArray{length(index_shape(I...))} function getindex(B::BitArray, I::NTuple{N,Union(Real,AbstractVector)}...)
checkbounds(B, I...)
return unsafe_getindex(B, to_index(I...)...)
stagedfunction getindex(B::BitArray, I::Union(Real,AbstractVector)...)
N = length(I)
Isplat = Expr[:(I[$d]) for d = 1:N]
Jsplat = Expr[:(to_index(I[$d])) for d = 1:N]
quote
checkbounds(B, $(Isplat...))
return unsafe_getindex(B, $(Jsplat...))
end
end

## setindex!
Expand Down

0 comments on commit a030d08

Please sign in to comment.