Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Feb 24, 2021
1 parent 4e07f73 commit 1febf7b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
10 changes: 5 additions & 5 deletions src/gather.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ For each index `k` in `idx`, assign values to `dst` according to
function gather!(dst::AbstractArray{Tdst,Ndst},
src::AbstractArray{Tsrc,Nsrc},
idx::AbstractArray{NTuple{M,Int}, Nidx}) where {Tdst,Tsrc,Ndst,Nsrc,Nidx,M}
# TODO: use M = _length(eltype(idx)) to merge the integer method into this?
# TODO: use M = typelength(eltype(idx)) to merge the integer method into this?
@boundscheck _gather_checkbounds(src, idx)
Ndst - Nidx == Nsrc - M || throw(ArgumentError(""))
size(dst)[1:Ndst-Nidx] == size(src)[1:Ndst-Nidx] || throw(ArgumentError(""))
Expand Down Expand Up @@ -52,7 +52,7 @@ function _gather_checkbounds(src, idx::AbstractArray{<:Integer})
end

function _gather_checkbounds(src, idx::AbstractArray{<:NTuple{M,Int}}) where M
# TODO: use M = _length(eltype(idx)) to merge the integer method into this?
# TODO: use M = typelength(eltype(idx)) to merge the integer method into this?
minimaxi = ntuple(M) do d
mini = minimum(i -> i[d], idx)
maxi = maximum(i -> i[d], idx)
Expand All @@ -72,11 +72,11 @@ function gather(src::AbstractArray{Tsrc, Nsrc},
# size(dst)[1:Ndst-Nidx] == size(src)[1:Ndst-Nidx] || throw(ArgumentError(""))
# size(dst)[Ndst-Nidx+1:end] == size(idx) || throw(ArgumentError(""))

M = _length(eltype(idx))
M = typelength(eltype(idx))
dstsize = (size(src)[1:Nsrc-M]..., size(idx)...)
dst = similar(src, eltype(src), dstsize)
return gather!(dst, src, idx)
end

_length(::Type{<:Integer}) = 1
_length(::Type{<:NTuple{M}}) where M = M
typelength(::Type{<:Number}) = 1
typelength(::Type{<:NTuple{M}}) where M = M
10 changes: 5 additions & 5 deletions test/gather.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
@test y isa Array{T,2}
@test size(y) == size(index)
@test y == output

@test gather!(T.(zero(index)), src, index) == output
@test_throws ArgumentError gather!(zeros(T, 3, 5), src, index)

Expand Down Expand Up @@ -57,25 +56,26 @@
8 4 8]

y = gather(src, index)
M = NNlib._length(eltype(index))
M = NNlib.typelength(eltype(index))
Nsrc = ndims(src)
@test y isa Array{T,3}
@test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)
@test y == output
end

@testset "gather tuple index" begin
T = Float32

## 2d src, 1d index of 2-tuples -> 1d output
src = T[3 5 7
4 6 8]
index = [(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)]

output = zeros(T, 2, 3, 3)
index = [(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)]

output = T[3, 5, 7, 4, 6, 8]

y = gather(src, index)
M = NNlib._length(eltype(index))
M = NNlib.typelength(eltype(index))
Nsrc = ndims(src)
@test y isa Array{T,1}
@test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)
Expand Down

0 comments on commit 1febf7b

Please sign in to comment.