diff --git a/Project.toml b/Project.toml index 5b0923ad7..2985a7e02 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,7 @@ uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" version = "0.7.14" [deps] +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/gather.jl b/src/gather.jl index e72d7905a..2229b455f 100644 --- a/src/gather.jl +++ b/src/gather.jl @@ -6,9 +6,9 @@ export gather, gather! Reverse operation of scatter. Gathers data from source `src` and writes it in destination `dst` according to the index array `idx`. -For each position `k` in `idx`, assign values to `dst` according to +For each `k` in `CartesianIndices(idx)`, assign values to `dst` according to - dst[:, ... , k...] .= src[:, ... , idx[k]...] + dst[:, ... , k] .= src[:, ... , idx[k]...] Notice that if `idx` is a vector containing integers, and both `dst` and `src` are matrices, previous @@ -18,7 +18,7 @@ expression simplifies to and `k` will range over `1:length(idx)`. -Notice that elements of `idx` may be repeated. A single `src` column +The elements of `idx` may be repeated. A single `src` column can end up being copied into zero, one, or multiple `dst` columns. # Arguments @@ -29,51 +29,34 @@ can end up being copied into zero, one, or multiple `dst` columns. """ function gather!(dst::AbstractArray{Tdst,Ndst}, src::AbstractArray{Tsrc,Nsrc}, - idx::AbstractArray{NTuple{M,Int}, Nidx}) where {Tdst,Tsrc,Ndst,Nsrc,Nidx,M} - - # TODO: use M = typelength(eltype(idx)) to merge the integer method into this? - Ndst - Nidx == Nsrc - M || throw(ArgumentError("")) - size(dst)[1:Ndst-Nidx] == size(src)[1:Ndst-Nidx] || throw(ArgumentError("")) - size(dst)[Ndst-Nidx+1:end] == size(idx) || throw(ArgumentError("")) - - coldst = ntuple(i -> Colon(), Ndst - Nidx) - colsrc = ntuple(i -> Colon(), Nsrc - M) - for k in CartesianIndices(idx) - view(dst, coldst..., Tuple(k)...) .= view(src, colsrc..., idx[k]...) - end - return dst -end + idx::AbstractArray{Tidx, Nidx}) where + {Tdst, Tsrc, Ndst, Nsrc, Nidx, Tidx <: IntOrIntTuple} + + M = typelength(Tidx) + Ndst - Nidx == Nsrc - M || throw(ArgumentError("Incompatible input shapes.")) + size(dst)[1:Ndst-Nidx] == size(src)[1:Ndst-Nidx] || throw(ArgumentError("Incompatible input shapes.")) + size(dst)[Ndst-Nidx+1:end] == size(idx) || throw(ArgumentError("Incompatible input shapes.")) -function gather!(dst::AbstractArray{Tdst,Ndst}, - src::AbstractArray{Tsrc,Nsrc}, - idx::AbstractArray{<:Integer, Nidx}) where {Tdst,Tsrc,Ndst,Nsrc,Nidx} - - Ndst - Nidx == Nsrc - 1 || throw(ArgumentError("")) - size(dst)[1:Ndst-Nidx] == size(src)[1:Ndst-Nidx] || throw(ArgumentError("")) - size(dst)[Ndst-Nidx+1:end] == size(idx) || throw(ArgumentError("")) coldst = ntuple(i -> Colon(), Ndst - Nidx) - colsrc = ntuple(i -> Colon(), Nsrc - 1) + colsrc = ntuple(i -> Colon(), Nsrc - M) for k in CartesianIndices(idx) - view(dst, coldst..., k) .= view(src, colsrc..., idx[k]) + view(dst, coldst..., k) .= view(src, colsrc..., idx[k]...) end return dst end - """ gather(src, idx) Non-mutating version of [`gather!`](@ref). """ function gather(src::AbstractArray{Tsrc, Nsrc}, - idx::AbstractArray{<:IntOrIntTuple, Nidx}) where {Tsrc, Nsrc, Nidx} - # Ndst - Nidx == Nsrc - M || throw(ArgumentError("")) - # size(dst)[1:Ndst-Nidx] == size(src)[1:Ndst-Nidx] || throw(ArgumentError("")) - # size(dst)[Ndst-Nidx+1:end] == size(idx) || throw(ArgumentError("")) - - M = typelength(eltype(idx)) + idx::AbstractArray{Tidx, Nidx}) where + {Tsrc, Nsrc, Nidx, Tidx<:IntOrIntTuple} + + M = typelength(Tidx) dstsize = (size(src)[1:Nsrc-M]..., size(idx)...) - dst = similar(src, eltype(src), dstsize) + dst = similar(src, Tsrc, dstsize) return gather!(dst, src, idx) end diff --git a/test/gather.jl b/test/gather.jl index 5635078e4..81d5cdeed 100644 --- a/test/gather.jl +++ b/test/gather.jl @@ -80,4 +80,15 @@ end @test y isa Array{T,1} @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) @test y == output + + ## 3d src, 2d index of 2-tuples -> 3d output + n1, nsrc, nidx = 2, 3, 6 + src = rand(Float32, n1, nsrc, nsrc) + index = [(rand(1:nsrc), rand(1:nsrc)) for i=1:nidx, j=1:nidx] + + y = gather(src, index) + M = NNlib.typelength(eltype(index)) + Nsrc = ndims(src) + @test y isa Array{T,3} + @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) end