From fc2eeee8e71afd6e15105d3bcd902e7a30cc97d8 Mon Sep 17 00:00:00 2001 From: Matt Bauman Date: Tue, 29 Aug 2017 15:11:16 -0400 Subject: [PATCH] Add squeeze(f, A, dims) for reductions to drop dims This simple definition makes it easier to write reductions that drops the dimensions over which they reduce. Fixes #16606, addresses part of the root issue in #22000. --- base/abstractarraymath.jl | 6 ++++++ test/arrayops.jl | 11 +++++++++++ 2 files changed, 17 insertions(+) diff --git a/base/abstractarraymath.jl b/base/abstractarraymath.jl index 437fa54204943..913b46d6710ee 100644 --- a/base/abstractarraymath.jl +++ b/base/abstractarraymath.jl @@ -83,6 +83,12 @@ end squeeze(A::AbstractArray, dim::Integer) = squeeze(A, (Int(dim),)) +""" + squeeze(f, A, dims) + +Compute reduction `f` over dimensions `dims` in array `A` and drop those dimensions from the result. +""" +squeeze(f, A::AbstractArray, dims::Union{Dims, Integer}) = squeeze(f(A, dims), dims) ## Unary operators ## diff --git a/test/arrayops.jl b/test/arrayops.jl index 98c1908301fa5..a6dbdea2543ed 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -219,6 +219,17 @@ end @test_throws ArgumentError squeeze(a, 4) @test_throws ArgumentError squeeze(a, 6) + @test @inferred(squeeze(sum, a, 1)) == @inferred(squeeze(sum, a, (1,))) == reshape(sum(a, 1), (1, 8, 8, 1)) + @test @inferred(squeeze(sum, a, 3)) == @inferred(squeeze(sum, a, (3,))) == reshape(sum(a, 3), (1, 1, 8, 1)) + @test @inferred(squeeze(sum, a, 4)) == @inferred(squeeze(sum, a, (4,))) == reshape(sum(a, 4), (1, 1, 8, 1)) + @test @inferred(squeeze(sum, a, (1, 5))) == squeeze(sum, a, (5, 1)) == reshape(sum(a, (5, 1)), (1, 8, 8)) + @test @inferred(squeeze(sum, a, (1, 2, 5))) == squeeze(sum, a, (5, 2, 1)) == reshape(sum(a, (5, 2, 1)), (8, 8)) + @test_throws ArgumentError squeeze(sum, a, 0) + @test_throws ArgumentError squeeze(sum, a, (1, 1)) + @test_throws ArgumentError squeeze(sum, a, (1, 2, 1)) + @test_throws ArgumentError squeeze(sum, a, (1, 1, 2)) + @test_throws ArgumentError squeeze(sum, a, 6) + sz = (5,8,7) A = reshape(1:prod(sz),sz...) @test A[2:6] == [2:6;]