Skip to content

Commit

Permalink
scatter accepts element type of idx to be CartesianIndex
Browse files Browse the repository at this point in the history
  • Loading branch information
yuehhua committed Apr 7, 2021
1 parent ee019d3 commit 62705ae
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 12 deletions.
3 changes: 3 additions & 0 deletions src/gather.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
export gather, gather!

typelength(::Type{<:Number}) = 1
typelength(::Type{<:NTuple{M}}) where M = M

"""
gather!(dst, src, idx)
Expand Down
18 changes: 6 additions & 12 deletions src/scatter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@ function _check_dims(Ndst, Nsrc, N, Nidx)
return dims
end

typelength(::Type{<:Number}) = 1
typelength(::Type{<:NTuple{M}}) where M = M

"""
scatter!(op, dst, src, idx)
Expand All @@ -44,20 +41,17 @@ index of `dst` and the value of `idx` must indicate the last few dimensions of `
Once the dimensions match, arrays are aligned automatically. The value of `idx` can be
`Int` or `Tuple` type.
"""
scatter!(op, dst::AbstractArray, src::AbstractArray, idx::AbstractArray{<:IntOrIntTuple}) =
scatter!(op, dst, src, CartesianIndex.(idx))

function scatter!(op,
dst::AbstractArray{Tdst,Ndst},
src::AbstractArray{Tsrc,Nsrc},
idx::AbstractArray{Tidx,Nidx}) where {Tdst,Tsrc,Tidx<:IntOrIntTuple,Ndst,Nsrc,Nidx}
M = typelength(Tidx)
src::AbstractArray{Tsrc,Nsrc},
idx::AbstractArray{CartesianIndex{M},Nidx}) where {Tdst,Ndst,Tsrc,Nsrc,M,Nidx}
dims = _check_dims(Ndst, Nsrc, M, Nidx)
scatter!(op, dst, src, idx, Val(dims))
end

function scatter!(op, dst::AbstractArray{Tdst}, src::AbstractArray{Tsrc}, idx::AbstractArray{<:IntOrIntTuple},
dims::Val{N}) where {Tdst,Tsrc,N}
colons = Base.ntuple(_->Colon(), dims)
for k in CartesianIndices(idx)
dst_v = view(dst, colons..., idx[k]...)
dst_v = view(dst, colons..., idx[k])
src_v = view(src, colons..., k)
dst_v .= (op).(dst_v, src_v)
end
Expand Down

0 comments on commit 62705ae

Please sign in to comment.