-
-
Notifications
You must be signed in to change notification settings - Fork 122
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #255 from yuehhua/scatter
Add scatter operations
- Loading branch information
Showing
6 changed files
with
363 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
export scatter!, scatter | ||
|
||
## Scatter API | ||
# - Scatter: | ||
# - scatter(op, src, idx) | ||
# - scatter!(op, dst, src, idx) | ||
# - Scatter destination backpropagation | ||
# - ∇scatter_dst! | ||
# - Scatter source backpropagation | ||
# - ∇scatter_src | ||
# - ∇scatter_src! | ||
# | ||
|
||
# Validate the ranks of `dst`, `src` and `idx` and return the number of leading dimensions
# that `dst` and `src` share (the ones not addressed through `idx`).
#
# - `Ndst`, `Nsrc`, `Nidx`: number of dimensions of `dst`, `src` and `idx`.
# - `N`: number of destination dimensions addressed by a single element of `idx`.
#
# Throws `AssertionError` if the ranks are inconsistent, and `ArgumentError` if the
# resulting number of leading dimensions would be negative.
function _check_dims(Ndst, Nsrc, N, Nidx)
    # Thrown explicitly instead of via `@assert`, because assertions may be disabled at some
    # optimization levels. The exception type and message are kept for existing callers.
    if Ndst - N != Nsrc - Nidx
        throw(AssertionError("Incompatible input shapes of (dst, src, idx) = ($Ndst, $Nsrc, $Nidx)."))
    end
    dims = Ndst - N
    if dims < 0
        throw(ArgumentError("dims must be non-negative but got dims=$dims."))
    end
    return dims
end
|
||
# Number of index components carried by one element of an index array:
# a plain number addresses one dimension, an `NTuple{M}` addresses `M` dimensions.
function typelength(::Type{<:Number})
    return 1
end

function typelength(::Type{<:NTuple{M}}) where {M}
    return M
end
|
||
""" | ||
scatter!(op, dst, src, idx) | ||
Scatter operation, which scatters data in `src` into `dst` according to `idx`. | ||
When several source entries map to the same destination, they are combined with the | ||
aggregation operation `op`. For each index `k` in `idx`, accumulate values in `dst` according to | ||
dst[:, ..., idx[k]...] = (op).(dst[:, ..., idx[k]...], src[:, ..., k...]) | ||
# Arguments | ||
- `op`: operations to be applied on `dst` and `src`, e.g. `+`, `-`, `*`, `/`, `max`, `min` | ||
and `mean`. | ||
- `dst`: the destination for `src` to aggregate to. This argument will be mutated. | ||
- `src`: the source data for aggregating. | ||
- `idx`: the mapping for aggregation from source (index) to destination (value). | ||
The index of `idx` corresponds to the index of `src`, and the dimensions of `idx` must | ||
align with the last few dimensions of `src`. The value of `idx` corresponds to the | ||
index of `dst`, and the value of `idx` must indicate the last few dimensions of `dst`. | ||
Once the dimensions match, arrays are aligned automatically. The value of `idx` can be | ||
`Int` or `Tuple` type. | ||
""" | ||
function scatter!(op,
                  dst::AbstractArray{Tdst,Ndst},
                  src::AbstractArray{Tsrc,Nsrc},
                  idx::AbstractArray{Tidx,Nidx}) where {Tdst,Tsrc,Tidx<:IntOrTuple,Ndst,Nsrc,Nidx}
    # Check that the ranks of (dst, src, idx) fit together and obtain the number of leading
    # dimensions that are copied through unchanged.
    nleading = _check_dims(Ndst, Nsrc, typelength(Tidx), Nidx)
    # Lift the count into the type domain and hand over to the kernel method.
    return scatter!(op, dst, src, idx, Val(nleading))
end
|
||
# Kernel of `scatter!`: for every position of `idx`, combine the matching slice of `src`
# into the slice of `dst` selected by the index value, using `op` elementwise.
# `dims` carries the number of leading (untouched) dimensions as a `Val`.
# Returns the mutated `dst`.
function scatter!(op, dst::AbstractArray{Tdst}, src::AbstractArray{Tsrc}, idx::AbstractArray{<:IntOrTuple},
                  dims::Val{N}) where {Tdst,Tsrc,N}
    # One `:` per leading dimension, so whole slices are selected.
    leading = Base.ntuple(_ -> Colon(), dims)
    for pos in CartesianIndices(idx)
        target = view(dst, leading..., idx[pos]...)
        source = view(src, leading..., pos)
        target .= op.(target, source)
    end
    return dst
end
|
||
# `mean` aggregation: scatter-add the source values and a matching array of ones, divide the
# sums by the counts (via `safe_div`), and add the quotient onto `dst`.
# Returns the mutated `dst`.
function scatter!(op::typeof(mean),
                  dst::AbstractArray{Tdst,Ndst},
                  src::AbstractArray{Tsrc,Nsrc},
                  idx::AbstractArray{<:IntOrTuple,Nidx}) where {Tdst,Tsrc,Ndst,Nsrc,Nidx}
    # Per-destination sum of the scattered source values.
    sums = scatter!(+, zero(dst), src, idx)
    # Per-destination number of contributing source entries.
    counts = scatter!(+, zero(dst), one.(src), idx)
    dst .= dst .+ safe_div.(sums, counts)
    return dst
end
|
||
|
||
""" | ||
scatter(op, src, idx) | ||
Scatter operation, which applies the specified operation on `src` according to `idx` | ||
and returns a new array `dst`. | ||
For each index `k` in `idx`, accumulate values in `dst` according to | ||
dst[:, ..., idx[k]...] = (op).(src[:, ..., k...]) | ||
# Arguments | ||
- `op`: operations to be applied on `dst` and `src`, e.g. `+`, `-`, `*`, `/`, `max`, `min` and `mean`. | ||
- `src`: the source data for aggregating. | ||
- `idx`: the mapping for aggregation from source (index) to destination (value). | ||
The index of `idx` corresponds to the index of `src` and the value of `idx` | ||
corresponds to the index of `dst`. The value of `idx` can be `Int` or `Tuple` type. | ||
""" | ||
function scatter end

# Additive operations: allocate `dst` with the leading dimensions of `src` followed by the
# per-dimension maxima of `idx`, fill it with the additive identity and scatter into it.
for op in (+, -)
    @eval function scatter(op::typeof($op),
                           src::AbstractArray{T,Nsrc},
                           idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx}
        nleading = Nsrc - Nidx
        outsize = (size(src)[1:nleading]..., maximum_dims(idx)...)
        dst = fill!(similar(src, T, outsize), Base.reduce_empty(+, T))
        return scatter!(op, dst, src, idx)
    end
end
|
||
# Multiplicative operations: same allocation as the additive case, but `dst` starts from the
# multiplicative identity.
for op in (*, /)
    @eval function scatter(op::typeof($op),
                           src::AbstractArray{T,Nsrc},
                           idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx}
        nleading = Nsrc - Nidx
        outsize = (size(src)[1:nleading]..., maximum_dims(idx)...)
        dst = fill!(similar(src, T, outsize), Base.reduce_empty(*, T))
        return scatter!(op, dst, src, idx)
    end
end
|
||
# `max` aggregation: `dst` starts from `typemin(T)` so every scattered value can replace it.
function scatter(op::typeof(max),
                 src::AbstractArray{T,Nsrc},
                 idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx}
    nleading = Nsrc - Nidx
    outsize = (size(src)[1:nleading]..., maximum_dims(idx)...)
    dst = fill!(similar(src, T, outsize), typemin(T))
    return scatter!(op, dst, src, idx)
end
|
||
# `min` aggregation: `dst` starts from `typemax(T)` so every scattered value can replace it.
function scatter(op::typeof(min),
                 src::AbstractArray{T,Nsrc},
                 idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx}
    nleading = Nsrc - Nidx
    outsize = (size(src)[1:nleading]..., maximum_dims(idx)...)
    dst = fill!(similar(src, T, outsize), typemax(T))
    return scatter!(op, dst, src, idx)
end
|
||
# `mean` aggregation into a freshly allocated array.
#
# The result has the leading dimensions of `src` followed by the per-dimension maxima of
# `idx`. Its element type is the type produced by dividing two values of type `T`.
function scatter(op::typeof(mean),
                 src::AbstractArray{T,Nsrc},
                 idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx}
    # A mean is a quotient, so allocate `dst` with the quotient type (`Float64` for `Int`,
    # unchanged for floats and rationals). Allocating with `T` itself made integer sources
    # fail with `InexactError` whenever a mean was fractional.
    DT = typeof(zero(T) / one(T))
    dims = Nsrc - Nidx
    dstsize = (size(src)[1:dims]..., maximum_dims(idx)...)
    dst = similar(src, DT, dstsize)
    fill!(dst, zero(DT))
    return scatter!(op, dst, src, idx)
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
""" | ||
safe_div(x, y) | ||
Safely divide `x` by `y`. If `y` is zero, return `x` directly. | ||
""" | ||
function safe_div(x, y)
    # Branch lazily: `ifelse` evaluates both arguments, so `x / y` was computed even for a
    # zero `y`. That throws for types whose division by zero errors (e.g. `0//1 / 0//1`)
    # instead of returning `x`.
    return iszero(y) ? x : x / y
end
|
||
""" | ||
maximum_dims(dims) | ||
Return the maximum value for each dimension. An array of dimensions `dims` is accepted. | ||
The maximum of each dimension in the element is computed. | ||
""" | ||
maximum_dims(dims::AbstractArray{<:Integer}) = (maximum(dims),)

function maximum_dims(dims::AbstractArray{<:Tuple})
    isempty(dims) &&
        throw(ArgumentError("maximum_dims requires a non-empty array of index tuples."))
    # Reduce each tuple component separately instead of splatting the whole array into
    # `zip`, which creates one argument per element and scales badly with array size.
    # Using the shortest tuple length mirrors the truncation `zip` would apply.
    ncomponents = minimum(length, dims)
    return ntuple(i -> maximum(t -> t[i], dims), ncomponents)
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
# Fixtures for the scatter tests.
#
# `dsts` and `srcs` are keyed by the number of leading (non-indexed) dimensions; `srcs` is
# additionally keyed by whether the fixture is used with the mutating `scatter!` (`true`) or
# the allocating `scatter` (`false`). `res` holds the expected outputs keyed by
# `(op, dims, mutated)`.
dsts = Dict(
    0 => [3, 4, 5, 6, 7],
    1 => [3 3 4 4 5;
          5 5 6 6 7],
)
srcs = Dict(
    (0, true) => ones(Int, 3, 4),
    (0, false) => ones(Int, 3) * collect(1:4)',
    (1, true) => ones(Int, 2, 3, 4),
    (1, false) => [1, 2] .* reshape(ones(Int, 3) * collect(1:4)', 1,3,4),
)
idxs = Dict(
    :int => [1 2 3 4;
             4 2 1 3;
             3 5 5 3],
    :tup => [(1,) (2,) (3,) (4,);
             (4,) (2,) (1,) (3,);
             (3,) (5,) (5,) (3,)],
)
res = Dict(
    (+, 0, true) => [5, 6, 9, 8, 9],
    (+, 1, true) => [5 5 8 6 7;
                     7 7 10 8 9],
    (+, 0, false) => [4, 4, 12, 5, 5],
    (+, 1, false) => [4 4 12 5 5;
                      8 8 24 10 10],
    (-, 0, true) => [1, 2, 1, 4, 5],
    (-, 1, true) => [1 1 0 2 3;
                     3 3 2 4 5],
    (-, 0, false) => [-4, -4, -12, -5, -5],
    (-, 1, false) => [-4 -4 -12 -5 -5;
                      -8 -8 -24 -10 -10],
    (max, 0, true) => [3, 4, 5, 6, 7],
    (max, 1, true) => [3 3 4 4 5;
                       5 5 6 6 7],
    (max, 0, false) => [3, 2, 4, 4, 3],
    (max, 1, false) => [3 2 4 4 3;
                        6 4 8 8 6],
    (min, 0, true) => [1, 1, 1, 1, 1],
    (min, 1, true) => [1 1 1 1 1;
                       1 1 1 1 1],
    (min, 0, false) => [1, 2, 1, 1, 2],
    (min, 1, false) => [1 2 1 1 2;
                        2 4 2 2 4],
    (*, 0, true) => [3, 4, 5, 6, 7],
    (*, 1, true) => [3 3 4 4 5;
                     5 5 6 6 7],
    (*, 0, false) => [3, 4, 48, 4, 6],
    (*, 1, false) => [3 4 48 4 6;
                      12 16 768 16 24],
    (/, 0, true) => [0.75, 1., 0.3125, 1.5, 1.75],
    (/, 1, true) => [0.75 0.75 0.25 1. 1.25;
                     1.25 1.25 0.375 1.5 1.75],
    (/, 0, false) => [1//3, 1//4, 1//48, 1//4, 1//6],
    (/, 1, false) => [1//3 1//4 1//48 1//4 1//6;
                      1//12 1//16 1//768 1//16 1//24],
    (mean, 0, true) => [4., 5., 6., 7., 8.],
    (mean, 1, true) => [4. 4. 5. 5. 6.;
                        6. 6. 7. 7. 8.],
    (mean, 0, false) => [2, 2, 3, 2.5, 2.5],
    (mean, 1, false) => [2. 2. 3. 2.5 2.5;
                         4. 4. 6. 5. 5.],
)

types = [UInt8, UInt16, UInt32, UInt64, UInt128,
         Int8, Int16, Int32, Int64, Int128, BigInt,
         Float16, Float32, Float64, BigFloat, Rational]

# Operators tested for every element type, paired with the element types for which the
# allocating `scatter` check is skipped.
alloc_skips = [
    (+) => (),
    # The expected results are negative, so they cannot be converted to unsigned types.
    (-) => (UInt8, UInt16, UInt32, UInt64, UInt128),
    # NOTE(review): presumably skipped because `typemin`/`typemax` are unavailable for
    # `BigInt` - confirm.
    max => (BigInt,),
    min => (BigInt,),
    # The expected results contain 768, which does not fit into 8-bit integers.
    (*) => (UInt8, Int8),
]

@testset "scatter" begin
    for T = types
        @testset "$T" begin
            PT = promote_type(T, Int)
            for (op, skip_types) in alloc_skips
                @testset "$op" begin
                    for idx = values(idxs), dims = [0, 1]
                        # Mutating variant with every combination of converted dst/src.
                        mutated = true
                        @test scatter!(op, T.(copy(dsts[dims])), T.(srcs[(dims, mutated)]), idx) == T.(res[(op, dims, mutated)])
                        @test scatter!(op, T.(copy(dsts[dims])), srcs[(dims, mutated)], idx) == PT.(res[(op, dims, mutated)])
                        @test scatter!(op, copy(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == PT.(res[(op, dims, mutated)])

                        # Allocating variant.
                        mutated = false
                        if !(T in skip_types)
                            @test scatter(op, T.(srcs[(dims, mutated)]), idx) == T.(res[(op, dims, mutated)])
                        end
                    end
                end
            end
        end
    end

    for T = [Float16, Float32, Float64, BigFloat, Rational]
        @testset "$T" begin
            PT = promote_type(T, Float64)
            @testset "/" begin
                for idx = values(idxs), dims = [0, 1]
                    mutated = true
                    @test scatter!(/, T.(dsts[dims]), T.(srcs[(dims, mutated)].*2), idx) == T.(res[(/, dims, mutated)])
                    @test scatter!(/, T.(dsts[dims]), srcs[(dims, mutated)].*2, idx) == PT.(res[(/, dims, mutated)])
                    # NOTE(review): same call as the first check, only compared against `PT`.
                    @test scatter!(/, T.(dsts[dims]), T.(srcs[(dims, mutated)].*2), idx) == PT.(res[(/, dims, mutated)])

                    mutated = false
                    @test scatter(/, T.(srcs[(dims, mutated)]), idx) == T.(res[(/, dims, mutated)])
                end
            end

            @testset "mean" begin
                for idx = values(idxs), dims = [0, 1]
                    mutated = true
                    @test scatter!(mean, T.(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == T.(res[(mean, dims, mutated)])
                    @test scatter!(mean, T.(dsts[dims]), srcs[(dims, mutated)], idx) == PT.(res[(mean, dims, mutated)])
                    @test scatter!(mean, copy(dsts[dims]), T.(srcs[(dims, mutated)]), idx) == PT.(res[(mean, dims, mutated)])

                    mutated = false
                    @test scatter(mean, T.(srcs[(dims, mutated)]), idx) == T.(res[(mean, dims, mutated)])
                end
            end
        end
    end

    # Error cases. Copies are passed because `scatter!` mutates its destination before an
    # out-of-range index is reached, which would otherwise corrupt the shared fixtures.
    @test_throws AssertionError scatter!(+, copy(dsts[0]), srcs[(1, true)], idxs[:int])
    idx = [1 2 3 4; 4 2 1 3; 6 7 8 9]
    @test_throws BoundsError scatter!(+, copy(dsts[1]), srcs[(1, true)], idx)
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
@testset "maximum_dims" begin
    # Integer indices yield a single maximum wrapped in a 1-tuple.
    @test NNlib.maximum_dims([1,2,3,4,5,6]) == (6,)

    # Tuple indices are reduced component-wise, for vectors ...
    tuple_vector = [(3,4,5), (1,2,3), (2,3,9)]
    @test NNlib.maximum_dims(tuple_vector) == (3,4,9)

    # ... and for matrices.
    tuple_matrix = [(3,4,5) (1,2,3) (2,3,9);
                    (4,6,2) (5,3,2) (4,4,4)]
    @test NNlib.maximum_dims(tuple_matrix) == (5,6,9)
end