Skip to content

Commit

Permalink
Merge pull request #449 from FluxML/cl/gather2
Browse files Browse the repository at this point in the history
add `gather(src,  IJK...)`
  • Loading branch information
CarloLucibello authored Dec 26, 2022
2 parents 0b64dc1 + 40b2848 commit 61bda17
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 1 deletion.
36 changes: 35 additions & 1 deletion src/gather.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ or multiple `dst` columns.
See [`gather!`](@ref) for an in-place version.
# Examples
```jldoctest
julia> NNlib.gather([1,20,300,4000], [2,4,2])
3-element Vector{Int64}:
Expand Down Expand Up @@ -83,5 +84,38 @@ function rrule(::typeof(gather!), dst::AbstractArray, src::AbstractArray, idx::A
y = gather!(dst, src, idx)
src_size = size(src)
gather!_pullback(Δ) = (NoTangent(), NoTangent(), ∇gather_src(unthunk(Δ), src_size, idx), NoTangent())
y, gather!_pullback
return y, gather!_pullback
end

"""
gather(src, IJK...)
Convert the tuple of integer vectors `IJK` to a tuple of `CartesianIndex` and
call `gather` on it: `gather(src, CartesianIndex.(IJK...))`.
# Examples
```jldoctest
julia> src = reshape([1:15;], 3, 5)
3×5 Matrix{Int64}:
1 4 7 10 13
2 5 8 11 14
3 6 9 12 15
julia> NNlib.gather(src, [1, 2], [2, 4])
2-element Vector{Int64}:
4
11
```
"""
function gather(src::AbstractArray{Tsrc, Nsrc},
I::AbstractVector{<:Integer},
J::AbstractVector{<:Integer},
Ks::AbstractVector{<:Integer}...) where {Nsrc, Tsrc}

return gather(src, to_cartesian_index(I, J, Ks...))
end

to_cartesian_index(IJK...) = CartesianIndex.(IJK...)

@non_differentiable to_cartesian_index(::Any...)
11 changes: 11 additions & 0 deletions test/gather.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,14 @@ end
gradtest(xs -> gather!(dst, xs, index), src)
gradtest(xs -> gather(xs, index), src)
end

@testset "gather(src, IJK...)" begin
x = reshape([1:15;], 3, 5)

y = gather(x, [1,2], [2,4])
@test y == [4, 11]

@test gather(x, [1, 2]) == [1 4
2 5
3 6]
end

0 comments on commit 61bda17

Please sign in to comment.