From 1febf7b0d706be7f6bd7d888b20dcc6963e291ab Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 24 Feb 2021 16:15:45 +0100 Subject: [PATCH] fix test --- src/gather.jl | 10 +++++----- test/gather.jl | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/gather.jl b/src/gather.jl index deaa7e023..6ed31b6e5 100644 --- a/src/gather.jl +++ b/src/gather.jl @@ -16,7 +16,7 @@ For each index `k` in `idx`, assign values to `dst` according to function gather!(dst::AbstractArray{Tdst,Ndst}, src::AbstractArray{Tsrc,Nsrc}, idx::AbstractArray{NTuple{M,Int}, Nidx}) where {Tdst,Tsrc,Ndst,Nsrc,Nidx,M} - # TODO: use M = _length(eltype(idx)) to merge the integer method into this? + # TODO: use M = typelength(eltype(idx)) to merge the integer method into this? @boundscheck _gather_checkbounds(src, idx) Ndst - Nidx == Nsrc - M || throw(ArgumentError("")) size(dst)[1:Ndst-Nidx] == size(src)[1:Ndst-Nidx] || throw(ArgumentError("")) @@ -52,7 +52,7 @@ function _gather_checkbounds(src, idx::AbstractArray{<:Integer}) end function _gather_checkbounds(src, idx::AbstractArray{<:NTuple{M,Int}}) where M - # TODO: use M = _length(eltype(idx)) to merge the integer method into this? + # TODO: use M = typelength(eltype(idx)) to merge the integer method into this? minimaxi = ntuple(M) do d mini = minimum(i -> i[d], idx) maxi = maximum(i -> i[d], idx) @@ -72,11 +72,11 @@ function gather(src::AbstractArray{Tsrc, Nsrc}, # size(dst)[1:Ndst-Nidx] == size(src)[1:Ndst-Nidx] || throw(ArgumentError("")) # size(dst)[Ndst-Nidx+1:end] == size(idx) || throw(ArgumentError("")) - M = _length(eltype(idx)) + M = typelength(eltype(idx)) dstsize = (size(src)[1:Nsrc-M]..., size(idx)...) dst = similar(src, eltype(src), dstsize) return gather!(dst, src, idx) end -_length(::Type{<:Integer}) = 1 -_length(::Type{<:NTuple{M}}) where M = M +typelength(::Type{<:Number}) = 1 +typelength(::Type{<:NTuple{M}}) where M = M diff --git a/test/gather.jl b/test/gather.jl index 442c8f95b..4569f222e 100644 --- a/test/gather.jl +++ b/test/gather.jl @@ -14,7 +14,6 @@ @test y isa Array{T,2} @test size(y) == size(index) @test y == output - @test gather!(T.(zero(index)), src, index) == output @test_throws ArgumentError gather!(zeros(T, 3, 5), src, index) @@ -57,7 +56,7 @@ 8 4 8] y = gather(src, index) - M = NNlib._length(eltype(index)) + M = NNlib.typelength(eltype(index)) Nsrc = ndims(src) @test y isa Array{T,3} @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) @@ -65,17 +64,18 @@ end @testset "gather tuple index" begin + T = Float32 + ## 2d src, 1d index of 2-tuples -> 1d output src = T[3 5 7 4 6 8] - index = [(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)] - output = zeros(T, 2, 3, 3) + index = [(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)] output = T[3, 5, 7, 4, 6, 8] y = gather(src, index) - M = NNlib._length(eltype(index)) + M = NNlib.typelength(eltype(index)) Nsrc = ndims(src) @test y isa Array{T,1} @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)