Skip to content

Commit

Permalink
Added argmin and argmax over given dimensions
Browse files Browse the repository at this point in the history
  • Loading branch information
Pontus Stenetorp committed May 31, 2018
1 parent 9ed7978 commit 607b1f3
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 0 deletions.
50 changes: 50 additions & 0 deletions base/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -805,3 +805,53 @@ function _findmax(A, region)
end

reducedim1(R, A) = _length(indices1(R)) == 1

"""
argmin(A; dims) -> indices
For an array input, return the indices of the minimum elements over the given dimensions.
`NaN` is treated as less than all other values.
# Examples
```jldoctest
julia> A = [1.0 2; 3 4]
2×2 Array{Float64,2}:
1.0 2.0
3.0 4.0
julia> argmin(A, dims=1)
1×2 Array{CartesianIndex{2},2}:
CartesianIndex(1, 1) CartesianIndex(1, 2)
julia> argmin(A, dims=2)
2×1 Array{CartesianIndex{2},2}:
CartesianIndex(1, 1)
CartesianIndex(2, 1)
```
"""
argmin(A::AbstractArray; dims=:) = findmin(A; dims=dims)[2]

"""
argmax(A; dims) -> indices
For an array input, return the indices of the maximum elements over the given dimensions.
`NaN` is treated as greater than all other values.
# Examples
```jldoctest
julia> A = [1.0 2; 3 4]
2×2 Array{Float64,2}:
1.0 2.0
3.0 4.0
julia> argmax(A, dims=1)
1×2 Array{CartesianIndex{2},2}:
CartesianIndex(2, 1) CartesianIndex(2, 2)
julia> argmax(A, dims=2)
2×1 Array{CartesianIndex{2},2}:
CartesianIndex(1, 2)
CartesianIndex(2, 2)
```
"""
argmax(A::AbstractArray; dims=:) = findmax(A; dims=dims)[2]
6 changes: 6 additions & 0 deletions test/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -352,3 +352,9 @@ end
T <: Base.SmallUnsigned ? UInt :
T)
end

@testset "argmin/argmax" begin
B = reshape(3^3:-1:1, (3, 3, 3))
@test B[argmax(B, dims=[2, 3])] == maximum(B, dims=[2, 3])
@test B[argmin(B, dims=[2, 3])] == minimum(B, dims=[2, 3])
end

0 comments on commit 607b1f3

Please sign in to comment.