Skip to content

Commit

Permalink
support gather for CartesianIndex
Browse files Browse the repository at this point in the history
  • Loading branch information
yuehhua committed Apr 15, 2021
1 parent 731c899 commit e499d3a
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/gather.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ See [`gather!`](@ref) for an in-place version.
"""
function gather(src::AbstractArray{Tsrc, Nsrc},
idx::AbstractArray{Tidx, Nidx}) where
{Tsrc, Nsrc, Nidx, Tidx<:IntOrIntTuple}
{Tsrc, Nsrc, Nidx, Tidx}

M = typelength(Tidx)
dstsize = (size(src)[1:Nsrc-M]..., size(idx)...)
Expand Down
1 change: 1 addition & 0 deletions src/scatter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

typelength(::Type{<:Number}) = 1
typelength(::Type{<:NTuple{M}}) where M = M
typelength(::Type{CartesianIndex{M}}) where M = M

function _check_dims(X::AbstractArray{Tx,Nx},
Y::AbstractArray{Ty,Ny},
Expand Down
30 changes: 30 additions & 0 deletions test/gather.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,33 @@ end
@test y isa Array{T,3}
@test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)
end

@testset "gather cartesian index" begin
T = Float32

## 2d src, 1d index of 2-tuples -> 1d output
src = T[3 5 7
4 6 8]

index = CartesianIndex.([(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)])

output = T[3, 5, 7, 4, 6, 8]

y = gather(src, index)
M = NNlib.typelength(eltype(index))
Nsrc = ndims(src)
@test y isa Array{T,1}
@test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)
@test y == output

## 3d src, 2d index of 2-tuples -> 3d output
n1, nsrc, nidx = 2, 3, 6
src = rand(Float32, n1, nsrc, nsrc)
index = [CartesianIndex((rand(1:nsrc), rand(1:nsrc))) for i=1:nidx, j=1:nidx]

y = gather(src, index)
M = NNlib.typelength(eltype(index))
Nsrc = ndims(src)
@test y isa Array{T,3}
@test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)
end

0 comments on commit e499d3a

Please sign in to comment.