Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make mapslices accept arbitrary number of inputs an allow broadcasting. #10928

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 50 additions & 20 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1217,41 +1217,63 @@ end
## transform any set of dimensions
## dims specifies which dimensions will be transformed. for example
## dims==1:2 will call f on all slices A[:,:,...]
mapslices(f::Function, A::AbstractArray, dims) = mapslices(f, A, [dims...])
function mapslices(f::Function, A::AbstractArray, dims::AbstractVector)
mapslices(f, A...; dims = ()) = _mapslices(f, A, [dims...])
function _mapslices(f, A, dims)
if isempty(dims)
return map(f,A)
return map(f, A...)
end

dimsA = [size(A)...]
ndimsA = ndims(A)
alldims = [1:ndimsA;]
ndimsA = [ndims(a)::Int for a in A]
ndmaxind = indmax(ndimsA)
ndmax = ndimsA[ndmaxind]
for i = 1:length(A)
if length(dims) < ndimsA[i] < ndmax
throw(DimensionMismatch("argument $i had wrong number of dimensions. Dimension of broadcast argument cannot be larger than $(length(dims))"))
end
end

dimsA = [[size(a)...] for a in A]
alldims = [1:ndimsA[ndmaxind];]

otherdims = setdiff(alldims, dims)

idx = cell(ndimsA)
fill!(idx, 1)
Asliceshape = tuple(dimsA[dims]...)
itershape = tuple(dimsA[otherdims]...)
for d in dims
idx[d] = 1:size(A,d)
idx = [cell(n) for n in ndimsA]
Asliceshape = cell(length(A))

for i = 1:length(A)
fill!(idx[i], 1)
if ndimsA[i] == ndmax
for d in dims
idx[i][d] = 1:size(A[i], d)
end
Asliceshape[i] = tuple(dimsA[i][dims]...)
end
end
itershape = tuple(dimsA[ndmaxind][otherdims]...)

r1 = f(reshape(A[idx...], Asliceshape))
args = []
for i = 1:length(A)
if ndimsA[i] < ndmax
push!(args, A[i])
else
push!(args, reshape(A[i][idx[i]...], Asliceshape[i]))
end
end
r1 = f(args...)

# determine result size and allocate
Rsize = copy(dimsA)
Rsize = copy(dimsA[ndmaxind])
# TODO: maybe support removing dimensions
if !isa(r1, AbstractArray) || ndims(r1) == 0
r1 = [r1]
end
Rsize[dims] = [size(r1)...; ones(Int,max(0,length(dims)-ndims(r1)))]
Rsize[dims] = [size(r1)...; ones(Int, max(0, length(dims) - ndims(r1)))]
R = similar(r1, tuple(Rsize...))

ridx = cell(ndims(R))
fill!(ridx, 1)
for d in dims
ridx[d] = 1:size(R,d)
ridx[d] = 1:size(R, d)
end

R[ridx...] = r1
Expand All @@ -1261,10 +1283,18 @@ function mapslices(f::Function, A::AbstractArray, dims::AbstractVector)
if first
first = false
else
ia = [idxs...]
idx[otherdims] = ia
ridx[otherdims] = ia
R[ridx...] = f(reshape(A[idx...], Asliceshape))
args = Any[]
for i = 1:length(A)
ia = [idxs...]
ridx[otherdims] = ia
if ndimsA[i] < ndmax
push!(args, A[i])
else
idx[i][otherdims] = ia
push!(args, reshape(A[i][idx[i]...], Asliceshape[i]))
end
end
R[ridx...] = f(args...)
end
end

Expand Down
2 changes: 1 addition & 1 deletion base/sort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ end

## sorting multi-dimensional arrays ##

sort(A::AbstractArray, dim::Integer; kws...) = mapslices(a->sort(a; kws...), A, [dim])
sort(A::AbstractArray, dim::Integer; kws...) = mapslices(a->sort(a; kws...), A, dims = [dim])

function sortrows(A::AbstractMatrix; kws...)
c = 1:size(A,2)
Expand Down
2 changes: 1 addition & 1 deletion base/statistics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ function median!{T}(v::AbstractVector{T})
end

median{T}(v::AbstractArray{T}) = median!(vec(copy(v)))
median{T}(v::AbstractArray{T}, region) = mapslices(median, v, region)
median{T}(v::AbstractArray{T}, region) = mapslices(median, v, dims = region)

# for now, use the R/S definition of quantile; may want variants later
# see ?quantile in R -- this is type 7
Expand Down
40 changes: 23 additions & 17 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -613,10 +613,10 @@ B = cat(3, 1, 2, 3)
begin
local a,h,i
a = rand(5,5)
h = mapslices(v -> hist(v,0:0.1:1)[2], a, 1)
H = mapslices(v -> hist(v,0:0.1:1)[2], a, 2)
s = mapslices(sort, a, [1])
S = mapslices(sort, a, [2])
h = mapslices(v -> hist(v,0:0.1:1)[2], a, dims = 1)
H = mapslices(v -> hist(v,0:0.1:1)[2], a, dims = 2)
s = mapslices(sort, a, dims = [1])
S = mapslices(sort, a, dims = [2])
for i = 1:5
@test h[:,i] == hist(a[:,i],0:0.1:1)[2]
@test vec(H[i,:]) == hist(vec(a[i,:]),0:0.1:1)[2]
Expand All @@ -625,35 +625,41 @@ begin
end

# issue #3613
b = mapslices(sum, ones(2,3,4), [1,2])
b = mapslices(sum, ones(2,3,4), dims = [1,2])
@test size(b) === (1,1,4)
@test all(b.==6)

# issue #5141
## Update Removed the version that removes the dimensions when dims==1:ndims(A)
c1 = mapslices(x-> maximum(-x), a, [])
c1 = mapslices(x-> maximum(-x), a, dims = [])
@test c1 == -a

# other types than Number
@test mapslices(prod,["1" "2"; "3" "4"],1) == ["13" "24"]
@test mapslices(prod,["1"],1) == ["1"]
@test mapslices(prod,["1" "2"; "3" "4"], dims = 1) == ["13" "24"]
@test mapslices(prod,["1"], dims = 1) == ["1"]

# issue #5177

c = ones(2,3,4)
m1 = mapslices(x-> ones(2,3), c, [1,2])
m2 = mapslices(x-> ones(2,4), c, [1,3])
m3 = mapslices(x-> ones(3,4), c, [2,3])
m1 = mapslices(x-> ones(2,3), c, dims = [1,2])
m2 = mapslices(x-> ones(2,4), c, dims = [1,3])
m3 = mapslices(x-> ones(3,4), c, dims = [2,3])
@test size(m1) == size(m2) == size(m3) == size(c)

n1 = mapslices(x-> ones(6), c, [1,2])
n2 = mapslices(x-> ones(6), c, [1,3])
n3 = mapslices(x-> ones(6), c, [2,3])
n1a = mapslices(x-> ones(1,6), c, [1,2])
n2a = mapslices(x-> ones(1,6), c, [1,3])
n3a = mapslices(x-> ones(1,6), c, [2,3])
n1 = mapslices(x-> ones(6), c, dims = [1,2])
n2 = mapslices(x-> ones(6), c, dims = [1,3])
n3 = mapslices(x-> ones(6), c, dims = [2,3])
n1a = mapslices(x-> ones(1,6), c, dims = [1,2])
n2a = mapslices(x-> ones(1,6), c, dims = [1,3])
n3a = mapslices(x-> ones(1,6), c, dims = [2,3])
@test size(n1a) == (1,6,4) && size(n2a) == (1,3,6) && size(n3a) == (2,1,6)
@test size(n1) == (6,1,4) && size(n2) == (6,3,1) && size(n3) == (2,6,1)

a = randn(4,3,4)
b = randn(3,3,4)
@test mapslices(svdvals, a, b, dims = (1,2))[:,1,4] == svdvals(a[:,:,4], b[:,:,4])
@test mapslices(svdvals, a, b[:,:,1], dims = (1,2))[:,1,4] == svdvals(a[:,:,4], b[:,:,1])
@test mapslices(norm, a, 1, dims = 1)[1,3,4] == norm(a[:,3,4], 1)
end


Expand Down
2 changes: 1 addition & 1 deletion test/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

function safe_mapslices(op, A, region)
newregion = intersect(region, 1:ndims(A))
return isempty(newregion) ? A : mapslices(op, A, newregion)
return isempty(newregion) ? A : mapslices(op, A, dims = newregion)
end
safe_sum{T}(A::Array{T}, region) = safe_mapslices(sum, A, region)
safe_prod{T}(A::Array{T}, region) = safe_mapslices(prod, A, region)
Expand Down