Skip to content

Commit

Permalink
Merge pull request #255 from yuehhua/scatter
Browse files Browse the repository at this point in the history
Add scatter operations
  • Loading branch information
CarloLucibello authored Mar 5, 2021
2 parents a01a8da + 6542ea9 commit b59cf53
Show file tree
Hide file tree
Showing 6 changed files with 363 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/NNlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ using Requires
using ChainRulesCore
import ChainRulesCore: rrule
using Base.Broadcast: broadcasted
using Statistics: mean

# Element types accepted for scatter indices: a single integer or a tuple.
const IntOrTuple = Union{Integer,Tuple}
# Either a number of type `T<:Number` or an array whose element type is a subtype of `T`.
# NOTE(review): no use of `Numeric` is visible in this diff — confirm where it is needed.
const Numeric = Union{AbstractArray{<:T}, T} where {T<:Number}

# Include APIs
Expand Down Expand Up @@ -33,6 +35,8 @@ include("conv_bias_act.jl")
include("pooling.jl")
include("padding.jl")
include("upsample.jl")
include("utils.jl")
include("scatter.jl")

## Include implementations
include("impl/padding_edges.jl")
Expand Down
149 changes: 149 additions & 0 deletions src/scatter.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
export scatter!, scatter

## Scatter API
# - Scatter:
# - scatter(op, src, idx)
# - scatter!(op, dst, src, idx)
# - Scatter destination backpropagation
# - ∇scatter_dst!
# - Scatter source backpropagation
# - ∇scatter_src
# - ∇scatter_src!
#

# Validate the ranks of (dst, src, idx) and return the number of leading dimensions
# shared by `dst` and `src`.
#
# - `Ndst`, `Nsrc`, `Nidx`: ranks of the destination, source and index arrays.
# - `N`: number of destination dimensions addressed by a single index value.
#
# Raises an `AssertionError` when the ranks are inconsistent and an `ArgumentError`
# when the resulting number of leading dimensions would be negative.
function _check_dims(Ndst, Nsrc, N, Nidx)
    dims = Ndst - N
    @assert dims == Nsrc - Nidx "Incompatible input shapes of (dst, src, idx) = ($Ndst, $Nsrc, $Nidx)."
    dims < 0 && throw(ArgumentError("dims must be non-negative but got dims=$dims."))
    return dims
end

# Number of destination dimensions addressed by one index value of the given type:
# a plain number addresses one dimension, ...
function typelength(::Type{<:Number})
    return 1
end

# ... and an `M`-tuple addresses `M` dimensions.
function typelength(::Type{<:NTuple{M}}) where {M}
    return M
end

"""
    scatter!(op, dst, src, idx)

Scatter operation, which scatters data in `src` and assigns it to `dst` according to `idx`.
When several pieces of data go to the same place, the specified aggregate operation is
applied to reduce them. For each index `k` in `idx`, values are accumulated in `dst`
according to

    dst[:, ..., idx[k]...] = (op).(dst[:, ..., idx[k]...], src[:, ..., k...])

# Arguments
- `op`: operation to be applied on `dst` and `src`, e.g. `+`, `-`, `*`, `/`, `max`, `min`
  and `mean`.
- `dst`: the destination for `src` to aggregate to. This argument will be mutated.
- `src`: the source data for aggregating.
- `idx`: the mapping for aggregation from source (index) to destination (value).
  The indices of `idx` correspond to the indices of `src`, and the dimensions of `idx`
  must be aligned with the last few dimensions of `src`. The values of `idx` correspond
  to indices of `dst`, and they must index the last few dimensions of `dst`.
  Once the dimensions match, arrays are aligned automatically. The values of `idx` can be
  of `Int` or `Tuple` type.
"""
function scatter!(op,
                  dst::AbstractArray{Tdst,Ndst},
                  src::AbstractArray{Tsrc,Nsrc},
                  idx::AbstractArray{Tidx,Nidx}) where {Tdst,Tsrc,Tidx<:IntOrTuple,Ndst,Nsrc,Nidx}
    # Each index value addresses `typelength(Tidx)` dimensions of `dst`; the remaining
    # leading dimensions are the ones `dst` and `src` have in common.
    nleading = _check_dims(Ndst, Nsrc, typelength(Tidx), Nidx)
    # Lift the count into the type domain so the kernel method can build its colon tuple
    # from a `Val`.
    return scatter!(op, dst, src, idx, Val(nleading))
end

# Kernel method: `dims` carries the number of leading dimensions shared by `dst` and
# `src` as a `Val`. Returns the mutated `dst`.
function scatter!(op,
                  dst::AbstractArray{Tdst},
                  src::AbstractArray{Tsrc},
                  idx::AbstractArray{<:IntOrTuple},
                  dims::Val{N}) where {Tdst,Tsrc,N}
    # One `Colon` per leading dimension, so that whole slices are combined at once.
    leading = ntuple(_ -> Colon(), Val(N))
    for k in CartesianIndices(idx)
        # `idx[k]` is an integer or a tuple; splatting it selects the destination slice.
        target = view(dst, leading..., idx[k]...)
        target .= op.(target, view(src, leading..., k))
    end
    return dst
end

# `mean` specialisation: adds to `dst` the per-destination mean of the scattered values.
function scatter!(op::typeof(mean),
                  dst::AbstractArray{Tdst,Ndst},
                  src::AbstractArray{Tsrc,Nsrc},
                  idx::AbstractArray{<:IntOrTuple,Nidx}) where {Tdst,Tsrc,Ndst,Nsrc,Nidx}
    # Sum of the values scattered to each destination slot ...
    sums = scatter!(+, zero(dst), src, idx)
    # ... and how many values each slot received.
    counts = scatter!(+, zero(dst), one.(src), idx)
    # `safe_div` leaves slots that received nothing (count of zero) undivided.
    dst .+= safe_div.(sums, counts)
    return dst
end


"""
    scatter(op, src, idx)

Scatter operation, which applies the specified operation on `src` according to `idx`
and returns a new array `dst`.
For each index `k` in `idx`, values are accumulated in `dst`, starting from an initial
value chosen per `op`, according to

    dst[:, ..., idx[k]...] = (op).(dst[:, ..., idx[k]...], src[:, ..., k...])

# Arguments
- `op`: operation to be applied on `dst` and `src`, e.g. `+`, `-`, `*`, `/`, `max`, `min`
  and `mean`.
- `src`: the source data for aggregating.
- `idx`: the mapping for aggregation from source (index) to destination (value).
  The indices of `idx` correspond to the indices of `src`, and the values of `idx`
  correspond to indices of `dst`. The values of `idx` can be of `Int` or `Tuple` type.
"""
function scatter end  # generic function only; one method per supported `op` follows

# Additive operations: the destination starts from the empty-reduction value of `+`.
for f in (+, -)
    @eval function scatter(op::typeof($f),
                           src::AbstractArray{T,Nsrc},
                           idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx}
        # Leading dimensions are taken from `src`; the trailing ones are sized by the
        # largest index found in `idx`.
        nleading = Nsrc - Nidx
        outsize = (size(src)[1:nleading]..., maximum_dims(idx)...)
        dst = fill!(similar(src, T, outsize), Base.reduce_empty(+, T))
        return scatter!(op, dst, src, idx)
    end
end

# Multiplicative operations: the destination starts from the empty-reduction value of `*`.
for f in (*, /)
    @eval function scatter(op::typeof($f),
                           src::AbstractArray{T,Nsrc},
                           idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx}
        # Leading dimensions are taken from `src`; the trailing ones are sized by the
        # largest index found in `idx`.
        nleading = Nsrc - Nidx
        outsize = (size(src)[1:nleading]..., maximum_dims(idx)...)
        dst = fill!(similar(src, T, outsize), Base.reduce_empty(*, T))
        return scatter!(op, dst, src, idx)
    end
end

# `max`: the destination starts from `typemin(T)`, so any scattered value replaces it.
function scatter(op::typeof(max),
                 src::AbstractArray{T,Nsrc},
                 idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx}
    # Leading dimensions are taken from `src`; the trailing ones are sized by the
    # largest index found in `idx`.
    nleading = Nsrc - Nidx
    outsize = (size(src)[1:nleading]..., maximum_dims(idx)...)
    dst = fill!(similar(src, T, outsize), typemin(T))
    return scatter!(op, dst, src, idx)
end

# `min`: the destination starts from `typemax(T)`, so any scattered value replaces it.
function scatter(op::typeof(min),
                 src::AbstractArray{T,Nsrc},
                 idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx}
    # Leading dimensions are taken from `src`; the trailing ones are sized by the
    # largest index found in `idx`.
    nleading = Nsrc - Nidx
    outsize = (size(src)[1:nleading]..., maximum_dims(idx)...)
    dst = fill!(similar(src, T, outsize), typemax(T))
    return scatter!(op, dst, src, idx)
end

# `mean`: allocate a zero-filled destination and let `scatter!(mean, ...)` add the
# per-destination means to it. Returns the new array.
function scatter(op::typeof(mean),
                 src::AbstractArray{T,Nsrc},
                 idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx}
    # A mean is a quotient, so the destination must hold the type produced by dividing
    # two `T`s (`Float64` for `Int`, `T` itself for floats and rationals). Allocating it
    # with eltype `T` makes integer sources fail with an `InexactError` as soon as a
    # fractional mean is written back.
    Tdst = typeof(one(T) / one(T))
    # Leading dimensions are taken from `src`; the trailing ones are sized by the
    # largest index found in `idx`.
    dims = Nsrc - Nidx
    dstsize = (size(src)[1:dims]..., maximum_dims(idx)...)
    dst = fill!(similar(src, Tdst, dstsize), zero(Tdst))
    return scatter!(op, dst, src, idx)
end
18 changes: 18 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""
    safe_div(x, y)

Safely divide `x` by `y`. If `y` is zero, return `x` unchanged instead of dividing.
"""
function safe_div(x, y)
    # Use a short-circuiting branch rather than `ifelse`: `ifelse` evaluates both
    # arguments, so `x / y` would still run for a zero `y` and throw for types whose
    # division by zero is an error (e.g. `0//1 / 0//1`).
    return iszero(y) ? x : x / y
end

"""
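    maximum_dims(dims)

Return a tuple with the maximum value for each dimension. An array of indices `dims` is
accepted: for integer elements the overall maximum is returned as a 1-tuple; for tuple
elements the maximum is computed separately for each position of the tuples.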
"""
function maximum_dims(dims::AbstractArray{<:Integer})
    # Integer indices address a single dimension, hence the 1-tuple.
    largest = maximum(dims)
    return (largest,)
end

# Tuple indices: return a tuple whose `i`-th entry is the maximum of the `i`-th
# components over all elements of `dims`.
#
# Raises an `ArgumentError` for an empty `dims`, where no maximum exists.
function maximum_dims(dims::AbstractArray{<:Tuple})
    isempty(dims) && throw(ArgumentError("dims must not be empty."))
    # Reduce over the array once per tuple position instead of splatting every element
    # into `zip`, which builds a vararg call with one argument per index.
    # Like `zip`, stop at the shortest tuple if the lengths differ.
    npos = minimum(length, dims)
    return ntuple(i -> maximum(t -> t[i], dims), npos)
end
8 changes: 8 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,11 @@ end
# Runs the tests in upsample.jl.
@testset "Upsampling" begin
include("upsample.jl")
end

# Runs the tests for the scatter operations (`scatter!` / `scatter`).
@testset "Scatter" begin
include("scatter.jl")
end

# Runs the tests for the utility helpers (`maximum_dims`).
@testset "Utilities" begin
include("utils.jl")
end
175 changes: 175 additions & 0 deletions test/scatter.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
# Destination arrays, keyed by `dims` as used in the tests below
# (0 => vector destination, 1 => matrix destination).
dsts = Dict(
0 => [3, 4, 5, 6, 7],
1 => [3 3 4 4 5;
5 5 6 6 7],
)
# Source arrays, keyed by `(dims, mutated)`: the `true` entries are passed to the
# mutating `scatter!`, the `false` entries to the allocating `scatter`.
srcs = Dict(
(0, true) => ones(Int, 3, 4),
(0, false) => ones(Int, 3) * collect(1:4)',
(1, true) => ones(Int, 2, 3, 4),
(1, false) => [1, 2] .* reshape(ones(Int, 3) * collect(1:4)', 1,3,4),
)
# The same index pattern expressed once as integers and once as 1-tuples.
idxs = Dict(
:int => [1 2 3 4;
4 2 1 3;
3 5 5 3],
:tup => [(1,) (2,) (3,) (4,);
(4,) (2,) (1,) (3,);
(3,) (5,) (5,) (3,)],
)
# Expected results, keyed by `(op, dims, mutated)`. The `true` entries are the expected
# outcome of `scatter!` on `dsts[dims]`, the `false` entries that of `scatter`.
# The `/` entries with `mutated == true` correspond to sources multiplied by 2 in the tests.
res = Dict(
(+, 0, true) => [5, 6, 9, 8, 9],
(+, 1, true) => [5 5 8 6 7;
7 7 10 8 9],
(+, 0, false) => [4, 4, 12, 5, 5],
(+, 1, false) => [4 4 12 5 5;
8 8 24 10 10],
(-, 0, true) => [1, 2, 1, 4, 5],
(-, 1, true) => [1 1 0 2 3;
3 3 2 4 5],
(-, 0, false) => [-4, -4, -12, -5, -5],
(-, 1, false) => [-4 -4 -12 -5 -5;
-8 -8 -24 -10 -10],
(max, 0, true) => [3, 4, 5, 6, 7],
(max, 1, true) => [3 3 4 4 5;
5 5 6 6 7],
(max, 0, false) => [3, 2, 4, 4, 3],
(max, 1, false) => [3 2 4 4 3;
6 4 8 8 6],
(min, 0, true) => [1, 1, 1, 1, 1],
(min, 1, true) => [1 1 1 1 1;
1 1 1 1 1],
(min, 0, false) => [1, 2, 1, 1, 2],
(min, 1, false) => [1 2 1 1 2;
2 4 2 2 4],
(*, 0, true) => [3, 4, 5, 6, 7],
(*, 1, true) => [3 3 4 4 5;
5 5 6 6 7],
(*, 0, false) => [3, 4, 48, 4, 6],
(*, 1, false) => [3 4 48 4 6;
12 16 768 16 24],
(/, 0, true) => [0.75, 1., 0.3125, 1.5, 1.75],
(/, 1, true) => [0.75 0.75 0.25 1. 1.25;
1.25 1.25 0.375 1.5 1.75],
(/, 0, false) => [1//3, 1//4, 1//48, 1//4, 1//6],
(/, 1, false) => [1//3 1//4 1//48 1//4 1//6;
1//12 1//16 1//768 1//16 1//24],
(mean, 0, true) => [4., 5., 6., 7., 8.],
(mean, 1, true) => [4. 4. 5. 5. 6.;
6. 6. 7. 7. 8.],
(mean, 0, false) => [2, 2, 3, 2.5, 2.5],
(mean, 1, false) => [2. 2. 3. 2.5 2.5;
4. 4. 6. 5. 5.],
)

# Element types for which the `+`, `-`, `max`, `min` and `*` tests below are run.
types = [UInt8, UInt16, UInt32, UInt64, UInt128,
Int8, Int16, Int32, Int64, Int128, BigInt,
Float16, Float32, Float64, BigFloat, Rational]

@testset "scatter" begin
    # Operations checked for every element type in `types`, each paired with the types
    # for which the allocating `scatter` check is skipped.
    # NOTE(review): presumably skipped because the expected values are not representable
    # (negative results for unsigned types, 768 for 8-bit integers) or because the
    # initial value is unavailable (`typemin`/`typemax` of `BigInt`) — confirm.
    arith_ops = (
        (+, DataType[]),
        (-, [UInt8, UInt16, UInt32, UInt64, UInt128]),
        (max, [BigInt]),
        (min, [BigInt]),
        (*, [UInt8, Int8]),
    )

    for T in types
        @testset "$T" begin
            PT = promote_type(T, Int)
            for (op, skipped) in arith_ops
                @testset "$op" begin
                    for idx in values(idxs), dims in (0, 1)
                        # Mutating API: every call works on a fresh copy of the fixture.
                        dst = dsts[dims]
                        src = srcs[(dims, true)]
                        expected = res[(op, dims, true)]
                        @test scatter!(op, T.(copy(dst)), T.(src), idx) == T.(expected)
                        @test scatter!(op, T.(copy(dst)), src, idx) == PT.(expected)
                        @test scatter!(op, copy(dst), T.(src), idx) == PT.(expected)

                        # Allocating API.
                        if !(T in skipped)
                            fresh_src = srcs[(dims, false)]
                            fresh_expected = res[(op, dims, false)]
                            @test scatter(op, T.(fresh_src), idx) == T.(fresh_expected)
                        end
                    end
                end
            end
        end
    end

    for T in [Float16, Float32, Float64, BigFloat, Rational]
        @testset "$T" begin
            PT = promote_type(T, Float64)
            @testset "/" begin
                for idx in values(idxs), dims in (0, 1)
                    dst = dsts[dims]
                    src = srcs[(dims, true)] .* 2
                    expected = res[(/, dims, true)]
                    @test scatter!(/, T.(dst), T.(src), idx) == T.(expected)
                    @test scatter!(/, T.(dst), src, idx) == PT.(expected)
                    # NOTE(review): same call as the first check, compared against `PT`.
                    @test scatter!(/, T.(dst), T.(src), idx) == PT.(expected)

                    fresh_src = srcs[(dims, false)]
                    fresh_expected = res[(/, dims, false)]
                    @test scatter(/, T.(fresh_src), idx) == T.(fresh_expected)
                end
            end

            @testset "mean" begin
                for idx in values(idxs), dims in (0, 1)
                    dst = dsts[dims]
                    src = srcs[(dims, true)]
                    expected = res[(mean, dims, true)]
                    @test scatter!(mean, T.(dst), T.(src), idx) == T.(expected)
                    @test scatter!(mean, T.(dst), src, idx) == PT.(expected)
                    @test scatter!(mean, copy(dst), T.(src), idx) == PT.(expected)

                    fresh_src = srcs[(dims, false)]
                    fresh_expected = res[(mean, dims, false)]
                    @test scatter(mean, T.(fresh_src), idx) == T.(fresh_expected)
                end
            end
        end
    end

    # Error paths. Copies are passed because `scatter!` writes into its destination
    # before an out-of-range index is reached, which would otherwise modify the shared
    # `dsts` fixtures.
    @test_throws AssertionError scatter!(+, copy(dsts[0]), srcs[(1, true)], idxs[:int])
    out_of_range = [1 2 3 4; 4 2 1 3; 6 7 8 9]
    @test_throws BoundsError scatter!(+, copy(dsts[1]), srcs[(1, true)], out_of_range)
end
9 changes: 9 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
@testset "maximum_dims" begin
    # Integer indices: a single maximum wrapped in a 1-tuple.
    @test NNlib.maximum_dims([1,2,3,4,5,6]) == (6,)

    # Tuple indices: the maximum is taken per tuple position, for a vector ...
    tuple_vec = [(3,4,5), (1,2,3), (2,3,9)]
    @test NNlib.maximum_dims(tuple_vec) == (3,4,9)

    # ... and for a matrix of tuples.
    tuple_mat = [(3,4,5) (1,2,3) (2,3,9);
                 (4,6,2) (5,3,2) (4,4,4)]
    @test NNlib.maximum_dims(tuple_mat) == (5,6,9)
end

0 comments on commit b59cf53

Please sign in to comment.