From 4d84d00076752c904f81cf30b4f06fc2a6962e87 Mon Sep 17 00:00:00 2001 From: Nick Robinson Date: Sat, 31 Aug 2019 13:31:20 +0100 Subject: [PATCH] Allow `dropdims` with reduction to take mutliple args and kwargs --- base/abstractarraymath.jl | 26 ++++++++++++++++++++++--- test/arrayops.jl | 40 +++++++++++++++++++++++++++++---------- 2 files changed, 53 insertions(+), 13 deletions(-) diff --git a/base/abstractarraymath.jl b/base/abstractarraymath.jl index b7acf141bc4e32..6885fdf52a08ee 100644 --- a/base/abstractarraymath.jl +++ b/base/abstractarraymath.jl @@ -88,11 +88,31 @@ end _dropdims(A::AbstractArray, dim::Integer) = _dropdims(A, (Int(dim),)) """ - squeeze(f, A, dims) + dropdims(f, args...; dims, kwargs...) -Compute reduction `f` over dimensions `dims` in array `A` and drop those dimensions from the result. +Compute reduction `f` over dimensions `dims` and drop those dimensions from the result. + +# Examples +```jldoctest +julia> a = [3.0 2.0 6.0 8.0 + 6.0 1.0 4.0 2.0 + 3.0 0.0 7.0 6.0]; + +julia> dropdims(sum, a, dims=1) +4-element Array{Float64,1}: + 12.0 + 3.0 + 17.0 + 16.0 + +julia> dropdims(sum, abs2, a, dims=2) +3-element Array{Float64,1}: + 113.0 + 57.0 + 94.0 +``` """ -squeeze(f, A::AbstractArray, dims::Union{Dims, Integer}) = squeeze(f(A, dims), dims) +dropdims(f, args...; dims, kwargs...) = dropdims(f(args...; kwargs..., dims=dims), dims=dims) ## Unary operators ## diff --git a/test/arrayops.jl b/test/arrayops.jl index 65e31028b644ed..c4cebe2c316598 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -303,16 +303,36 @@ end @test_throws ArgumentError dropdims(a, dims=4) @test_throws ArgumentError dropdims(a, dims=6) - @test @inferred(squeeze(sum, a, 1)) == @inferred(squeeze(sum, a, (1,))) == reshape(sum(a, 1), (1, 8, 8, 1)) - @test @inferred(squeeze(sum, a, 3)) == @inferred(squeeze(sum, a, (3,))) == reshape(sum(a, 3), (1, 1, 8, 1)) - @test @inferred(squeeze(sum, a, 4)) == @inferred(squeeze(sum, a, (4,))) == reshape(sum(a, 4), (1, 1, 8, 1)) - @test @inferred(squeeze(sum, a, (1, 5))) == squeeze(sum, a, (5, 1)) == reshape(sum(a, (5, 1)), (1, 8, 8)) - @test @inferred(squeeze(sum, a, (1, 2, 5))) == squeeze(sum, a, (5, 2, 1)) == reshape(sum(a, (5, 2, 1)), (8, 8)) - @test_throws ArgumentError squeeze(sum, a, 0) - @test_throws ArgumentError squeeze(sum, a, (1, 1)) - @test_throws ArgumentError squeeze(sum, a, (1, 2, 1)) - @test_throws ArgumentError squeeze(sum, a, (1, 1, 2)) - @test_throws ArgumentError squeeze(sum, a, 6) + # dropdims with reductions. issue #16606 + @test (@inferred(dropdims(sum, a, dims=1)) == + @inferred(dropdims(sum, a, dims=(1,))) == + reshape(sum(a, dims=1), (1, 8, 8, 1))) + @test (@inferred(dropdims(sum, a, dims=3)) == + @inferred(dropdims(sum, a, dims=(3,))) == + reshape(sum(a, dims=3), (1, 1, 8, 1))) + @test (@inferred(dropdims(sum, a, dims=4)) == + @inferred(dropdims(sum, a, dims=(4,))) == + reshape(sum(a, dims=4), (1, 1, 8, 1))) + @test (@inferred(dropdims(sum, a, dims=(1, 5))) == + dropdims(sum, a, dims=(5, 1)) == + reshape(sum(a, (5, 1)), (1, 8, 8))) + @test (@inferred(dropdims(sum, a, dims=(1, 2, 5))) == + dropdims(sum, a, dims=(5, 2, 1)) == + reshape(sum(a, dims=(5, 2, 1)), (8, 8))) + @test (@inferred(dropdims(sum, abs2, a, dims=1)) == + @inferred(dropdims(sum, abs2, a, dims=(1,))) == + reshape(sum(a, dims=1), (1, 8, 8, 1))) + _sumplus(x; dims, plus) = sum(x; dims=dims) .+ plus # reduction with keywords + @test (@inferred(dropdims(_sumplus, a, dims=4, plus=1)) == + @inferred(dropdims(_sumplus, a, dims=(4,), plus=1)) == + reshape(sum(a, dims=4) .+ 1, (1, 1, 8, 1))) + @test_throws UndefKeywordError dropdims(sum, a) + @test_throws UndefKeywordError dropdims(sum, a, 1) + @test_throws ArgumentError dropdims(sum, a, dims=0) + @test_throws ArgumentError dropdims(sum, a, dims=(1, 1)) + @test_throws ArgumentError dropdims(sum, a, dims=(1, 2, 1)) + @test_throws ArgumentError dropdims(sum, a, dims=(1, 1, 2)) + @test_throws ArgumentError dropdims(sum, a, dims=6) sz = (5,8,7) A = reshape(1:prod(sz),sz...)