Skip to content

Commit

Permalink
simplify gather
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Mar 1, 2021
1 parent 28b7b83 commit 6db7c8f
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 34 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.7.14"

[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
51 changes: 17 additions & 34 deletions src/gather.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ export gather, gather!
Reverse operation of scatter. Gathers data from source `src`
and writes it in destination `dst` according to the index
array `idx`.
For each position `k` in `idx`, assign values to `dst` according to
For each `k` in `CartesianIndices(idx)`, assign values to `dst` according to
dst[:, ... , k...] .= src[:, ... , idx[k]...]
dst[:, ... , k] .= src[:, ... , idx[k]...]
Notice that if `idx` is a vector containing integers,
and both `dst` and `src` are matrices, previous
Expand All @@ -18,7 +18,7 @@ expression simplifies to
and `k` will range over `1:length(idx)`.
Notice that elements of `idx` may be repeated. A single `src` column
The elements of `idx` may be repeated. A single `src` column
can end up being copied into zero, one, or multiple `dst` columns.
# Arguments
Expand All @@ -29,51 +29,34 @@ can end up being copied into zero, one, or multiple `dst` columns.
"""
function gather!(dst::AbstractArray{Tdst,Ndst},
src::AbstractArray{Tsrc,Nsrc},
idx::AbstractArray{NTuple{M,Int}, Nidx}) where {Tdst,Tsrc,Ndst,Nsrc,Nidx,M}

# TODO: use M = typelength(eltype(idx)) to merge the integer method into this?
Ndst - Nidx == Nsrc - M || throw(ArgumentError(""))
size(dst)[1:Ndst-Nidx] == size(src)[1:Ndst-Nidx] || throw(ArgumentError(""))
size(dst)[Ndst-Nidx+1:end] == size(idx) || throw(ArgumentError(""))

coldst = ntuple(i -> Colon(), Ndst - Nidx)
colsrc = ntuple(i -> Colon(), Nsrc - M)
for k in CartesianIndices(idx)
view(dst, coldst..., Tuple(k)...) .= view(src, colsrc..., idx[k]...)
end
return dst
end
idx::AbstractArray{Tidx, Nidx}) where
{Tdst, Tsrc, Ndst, Nsrc, Nidx, Tidx <: IntOrIntTuple}

M = typelength(Tidx)
Ndst - Nidx == Nsrc - M || throw(ArgumentError("Incompatible input shapes."))
size(dst)[1:Ndst-Nidx] == size(src)[1:Ndst-Nidx] || throw(ArgumentError("Incompatible input shapes."))
size(dst)[Ndst-Nidx+1:end] == size(idx) || throw(ArgumentError("Incompatible input shapes."))

function gather!(dst::AbstractArray{Tdst,Ndst},
src::AbstractArray{Tsrc,Nsrc},
idx::AbstractArray{<:Integer, Nidx}) where {Tdst,Tsrc,Ndst,Nsrc,Nidx}

Ndst - Nidx == Nsrc - 1 || throw(ArgumentError(""))
size(dst)[1:Ndst-Nidx] == size(src)[1:Ndst-Nidx] || throw(ArgumentError(""))
size(dst)[Ndst-Nidx+1:end] == size(idx) || throw(ArgumentError(""))
coldst = ntuple(i -> Colon(), Ndst - Nidx)
colsrc = ntuple(i -> Colon(), Nsrc - 1)
colsrc = ntuple(i -> Colon(), Nsrc - M)
for k in CartesianIndices(idx)
view(dst, coldst..., k) .= view(src, colsrc..., idx[k])
view(dst, coldst..., k) .= view(src, colsrc..., idx[k]...)
end
return dst
end


"""
gather(src, idx)
Non-mutating version of [`gather!`](@ref).
"""
function gather(src::AbstractArray{Tsrc, Nsrc},
idx::AbstractArray{<:IntOrIntTuple, Nidx}) where {Tsrc, Nsrc, Nidx}
# Ndst - Nidx == Nsrc - M || throw(ArgumentError(""))
# size(dst)[1:Ndst-Nidx] == size(src)[1:Ndst-Nidx] || throw(ArgumentError(""))
# size(dst)[Ndst-Nidx+1:end] == size(idx) || throw(ArgumentError(""))

M = typelength(eltype(idx))
idx::AbstractArray{Tidx, Nidx}) where
{Tsrc, Nsrc, Nidx, Tidx<:IntOrIntTuple}

M = typelength(Tidx)
dstsize = (size(src)[1:Nsrc-M]..., size(idx)...)
dst = similar(src, eltype(src), dstsize)
dst = similar(src, Tsrc, dstsize)
return gather!(dst, src, idx)
end

Expand Down
11 changes: 11 additions & 0 deletions test/gather.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,15 @@ end
@test y isa Array{T,1}
@test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)
@test y == output

## 3d src, 2d index of 2-tuples -> 3d output
n1, nsrc, nidx = 2, 3, 6
src = rand(Float32, n1, nsrc, nsrc)
index = [(rand(1:nsrc), rand(1:nsrc)) for i=1:nidx, j=1:nidx]

y = gather(src, index)
M = NNlib.typelength(eltype(index))
Nsrc = ndims(src)
@test y isa Array{T,3}
@test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)
end

0 comments on commit 6db7c8f

Please sign in to comment.