From 607b1f3e738da50b6939673ba3bd38e705d56f9c Mon Sep 17 00:00:00 2001 From: Pontus Stenetorp Date: Fri, 18 May 2018 13:40:33 -0400 Subject: [PATCH] Added argmin and argmax over given dimensions --- base/reducedim.jl | 50 +++++++++++++++++++++++++++++++++++++++++++++++ test/reducedim.jl | 6 ++++++ 2 files changed, 56 insertions(+) diff --git a/base/reducedim.jl b/base/reducedim.jl index 19acea13f8292..09cba3c847a37 100644 --- a/base/reducedim.jl +++ b/base/reducedim.jl @@ -805,3 +805,53 @@ function _findmax(A, region) end reducedim1(R, A) = _length(indices1(R)) == 1 + +""" + argmin(A; dims) -> indices + +For an array input, return the indices of the minimum elements over the given dimensions. +`NaN` is treated as less than all other values. + +# Examples +```jldoctest +julia> A = [1.0 2; 3 4] +2×2 Array{Float64,2}: + 1.0 2.0 + 3.0 4.0 + +julia> argmin(A, dims=1) +1×2 Array{CartesianIndex{2},2}: + CartesianIndex(1, 1) CartesianIndex(1, 2) + +julia> argmin(A, dims=2) +2×1 Array{CartesianIndex{2},2}: + CartesianIndex(1, 1) + CartesianIndex(2, 1) +``` +""" +argmin(A::AbstractArray; dims=:) = findmin(A; dims=dims)[2] + +""" + argmax(A; dims) -> indices + +For an array input, return the indices of the maximum elements over the given dimensions. +`NaN` is treated as greater than all other values. + +# Examples +```jldoctest +julia> A = [1.0 2; 3 4] +2×2 Array{Float64,2}: + 1.0 2.0 + 3.0 4.0 + +julia> argmax(A, dims=1) +1×2 Array{CartesianIndex{2},2}: + CartesianIndex(2, 1) CartesianIndex(2, 2) + +julia> argmax(A, dims=2) +2×1 Array{CartesianIndex{2},2}: + CartesianIndex(1, 2) + CartesianIndex(2, 2) +``` +""" +argmax(A::AbstractArray; dims=:) = findmax(A; dims=dims)[2] diff --git a/test/reducedim.jl b/test/reducedim.jl index 8f167e073821b..2c9f622cb4905 100644 --- a/test/reducedim.jl +++ b/test/reducedim.jl @@ -352,3 +352,9 @@ end T <: Base.SmallUnsigned ? UInt : T) end + +@testset "argmin/argmax" begin + B = reshape(3^3:-1:1, (3, 3, 3)) + @test B[argmax(B, dims=[2, 3])] == maximum(B, dims=[2, 3]) + @test B[argmin(B, dims=[2, 3])] == minimum(B, dims=[2, 3]) +end