From d808a9adaf2134d14427bcbbe2c51ce4c5bc841a Mon Sep 17 00:00:00 2001 From: Colin Caine Date: Mon, 30 Mar 2020 18:14:26 +0100 Subject: [PATCH] Add 2-arg versions of findmax/min, argmax/min Fixes #27613. Related: #27639, #27612, #34674. Thanks to @tkf, @StefanKarpinski and @drewrobson for their assistance with this PR. --- NEWS.md | 1 + base/array.jl | 134 --------------------------------- base/reduce.jl | 196 +++++++++++++++++++++++++++++++++++++++++++++++++ test/reduce.jl | 32 ++++++++ 4 files changed, 229 insertions(+), 134 deletions(-) diff --git a/NEWS.md b/NEWS.md index 71e3e6d4ef247c..4677e553fefe0b 100644 --- a/NEWS.md +++ b/NEWS.md @@ -87,6 +87,7 @@ New library functions * New function `bitrotate(x, k)` for rotating the bits in a fixed-width integer ([#33937]). * One argument methods `startswith(x)` and `endswith(x)` have been added, returning partially-applied versions of the functions, similar to existing methods like `isequal(x)` ([#33193]). * New function `isgreater(a, b)` defines a descending total order where unorderable values and missing are ordered smaller than any regular value. +* Two argument methods `findmax(f, domain)`, `argmax(f, domain)` and the corresponding `min` versions ([#27613]). New library features -------------------- diff --git a/base/array.jl b/base/array.jl index 5dd482215d1925..e2952a28a42038 100644 --- a/base/array.jl +++ b/base/array.jl @@ -2096,140 +2096,6 @@ findall(x::Bool) = x ? [1] : Vector{Int}() findall(testf::Function, x::Number) = testf(x) ? [1] : Vector{Int}() findall(p::Fix2{typeof(in)}, x::Number) = x in p.x ? [1] : Vector{Int}() -""" - findmax(itr) -> (x, index) - -Return the maximum element of the collection `itr` and its index. If there are multiple -maximal elements, then the first one will be returned. -If any data element is `NaN`, this element is returned. -The result is in line with `max`. - -The collection must not be empty. - -# Examples -```jldoctest -julia> findmax([8,0.1,-9,pi]) -(8.0, 1) - -julia> findmax([1,7,7,6]) -(7, 2) - -julia> findmax([1,7,7,NaN]) -(NaN, 4) -``` -""" -findmax(a) = _findmax(a, :) - -function _findmax(a, ::Colon) - p = pairs(a) - y = iterate(p) - if y === nothing - throw(ArgumentError("collection must be non-empty")) - end - (mi, m), s = y - i = mi - while true - y = iterate(p, s) - y === nothing && break - m != m && break - (i, ai), s = y - if ai != ai || isless(m, ai) - m = ai - mi = i - end - end - return (m, mi) -end - -""" - findmin(itr) -> (x, index) - -Return the minimum element of the collection `itr` and its index. If there are multiple -minimal elements, then the first one will be returned. -If any data element is `NaN`, this element is returned. -The result is in line with `min`. - -The collection must not be empty. - -# Examples -```jldoctest -julia> findmin([8,0.1,-9,pi]) -(-9.0, 3) - -julia> findmin([7,1,1,6]) -(1, 2) - -julia> findmin([7,1,1,NaN]) -(NaN, 4) -``` -""" -findmin(a) = _findmin(a, :) - -function _findmin(a, ::Colon) - p = pairs(a) - y = iterate(p) - if y === nothing - throw(ArgumentError("collection must be non-empty")) - end - (mi, m), s = y - i = mi - while true - y = iterate(p, s) - y === nothing && break - m != m && break - (i, ai), s = y - if ai != ai || isless(ai, m) - m = ai - mi = i - end - end - return (m, mi) -end - -""" - argmax(itr) -> Integer - -Return the index of the maximum element in a collection. If there are multiple maximal -elements, then the first one will be returned. - -The collection must not be empty. - -# Examples -```jldoctest -julia> argmax([8,0.1,-9,pi]) -1 - -julia> argmax([1,7,7,6]) -2 - -julia> argmax([1,7,7,NaN]) -4 -``` -""" -argmax(a) = findmax(a)[2] - -""" - argmin(itr) -> Integer - -Return the index of the minimum element in a collection. If there are multiple minimal -elements, then the first one will be returned. - -The collection must not be empty. - -# Examples -```jldoctest -julia> argmin([8,0.1,-9,pi]) -3 - -julia> argmin([7,1,1,6]) -2 - -julia> argmin([7,1,1,NaN]) -4 -``` -""" -argmin(a) = findmin(a)[2] - # similar to Matlab's ismember """ indexin(a, b) diff --git a/base/reduce.jl b/base/reduce.jl index 73b200eb772880..3e865b04e9d0ed 100644 --- a/base/reduce.jl +++ b/base/reduce.jl @@ -659,6 +659,202 @@ julia> minimum([1,2,3]) """ minimum(a) = mapreduce(identity, min, a) +## findmax, findmin, argmax & argmin + +""" + findmax(f, domain) -> (f(x), x) + findmax(f) + +Returns a pair of a value in the codomain (outputs of `f`) and the corresponding +value in the `domain` (inputs to `f`) such that `f(x)` is maximised. If there +are multiple maximal points, then the first one will be returned. + +When `domain` is provided it may be any iterable and must not be empty. + +When `domain` is omitted, `f` must have an implicit domain. In particular, if +`f` is an indexable collection, it is interpreted as a function mapping keys +(domain) to values (codomain), i.e. `findmax(itr)` returns the maximal element +of the collection `itr` and its index. + +Values are compared with `isless`. + +# Examples + +```jldoctest +julia> findmax(identity, 5:9) +(9, 9) + +julia> findmax(-, 1:10) +(-1, 1) + +julia> findmax(cos, 0:π/2:2π) +(1.0, 0.0) + +julia> findmax([8,0.1,-9,pi]) +(8.0, 1) + +julia> findmax([1,7,7,6]) +(7, 2) + +julia> findmax([1,7,7,NaN]) +(NaN, 4) +``` + +""" +findmax(f, domain) = mapfoldl(x -> (f(x), x), _rf_findmax, domain) +_rf_findmax((fm, m), (fx, x)) = isless(fm, fx) ? (fx, x) : (fm, m) + +""" + findmin(f, domain) -> (f(x), x) + findmin(f) + +Returns a pair of a value in the codomain (outputs of `f`) and the corresponding +value in the `domain` (inputs to `f`) such that `f(x)` is minimised. If there +are multiple minimal points, then the first one will be returned. + +When `domain` is provided it may be any iterable and must not be empty. + +When `domain` is omitted, `f` must have an implicit domain. In particular, if +`f` is an indexable collection, it is interpreted as a function mapping keys +(domain) to values (codomain), i.e. `findmin(itr)` returns the minimal element +of the collection `itr` and its index. + +Values are compared with `isgreater`. + +# Examples + +```jldoctest +julia> findmin(identity, 5:9) +(5, 5) + +julia> findmin(-, 1:10) +(-10, 10) + +julia> findmin(cos, 0:π/2:2π) +(-1.0, 3.141592653589793) + +julia> findmin([8,0.1,-9,pi]) +(-9, 3) + +julia> findmin([1,7,7,6]) +(1, 1) + +julia> findmin([1,7,7,NaN]) +(NaN, 4) +``` + +""" +findmin(f, domain) = mapfoldl(x -> (f(x), x), _rf_findmin, domain) +_rf_findmin((fm, m), (fx, x)) = isgreater(fm, fx) ? (fx, x) : (fm, m) + +findmax(a) = _findmax(a, :) + +function _findmax(a, ::Colon) + p = pairs(a) + y = iterate(p) + if y === nothing + throw(ArgumentError("collection must be non-empty")) + end + (mi, m), s = y + i = mi + while true + y = iterate(p, s) + y === nothing && break + m != m && break + (i, ai), s = y + if ai != ai || isless(m, ai) + m = ai + mi = i + end + end + return (m, mi) +end + +findmin(a) = _findmin(a, :) + +function _findmin(a, ::Colon) + p = pairs(a) + y = iterate(p) + if y === nothing + throw(ArgumentError("collection must be non-empty")) + end + (mi, m), s = y + i = mi + while true + y = iterate(p, s) + y === nothing && break + m != m && break + (i, ai), s = y + if ai != ai || isless(ai, m) + m = ai + mi = i + end + end + return (m, mi) +end + +""" + argmax(f, domain) + argmax(f) + +Return a value `x` in the domain of `f` for which `f(x)` is maximised. +If there are multiple maximal values for `f(x)` then the first one will be found. + +When `domain` is provided it may be any iterable and must not be empty. + +When `domain` is omitted, `f` must have an implicit domain. In particular, if +`f` is an indexable collection, it is interpreted as a function mapping keys +(domain) to values (codomain), i.e. `argmax(itr)` returns the index of the +maximal element in `itr`. + +Values are compared with `isless`. + +# Examples +```jldoctest +julia> argmax([8,0.1,-9,pi]) +1 + +julia> argmax([1,7,7,6]) +2 + +julia> argmax([1,7,7,NaN]) +4 +``` +""" +argmax(f, domain) = findmax(f, domain)[2] +argmax(f) = findmax(f)[2] + +""" + argmin(f, domain) + argmin(f) + +Return a value `x` in the domain of `f` for which `f(x)` is minimised. +If there are multiple minimal values for `f(x)` then the first one will be found. + +When `domain` is provided it may be any iterable and must not be empty. + +When `domain` is omitted, `f` must have an implicit domain. In particular, if +`f` is an indexable collection, it is interpreted as a function mapping keys +(domain) to values (codomain), i.e. `argmin(itr)` returns the index of the +minimal element in `itr`. + +Values are compared with `isgreater`. + +# Examples +```jldoctest +julia> argmin([8,0.1,-9,pi]) +3 + +julia> argmin([7,1,1,6]) +2 + +julia> argmin([7,1,1,NaN]) +4 +``` +""" +argmin(f, domain) = findmin(f, domain)[2] +argmin(f) = findmin(f)[2] + ## all & any """ diff --git a/test/reduce.jl b/test/reduce.jl index 69b8b1911e7ea1..b7527ca869fdf4 100644 --- a/test/reduce.jl +++ b/test/reduce.jl @@ -338,6 +338,38 @@ A = circshift(reshape(1:24,2,3,4), (0,1,1)) @test size(extrema(A,dims=(1,2,3))) == size(maximum(A,dims=(1,2,3))) @test extrema(x->div(x, 2), A, dims=(2,3)) == reshape([(0,11),(1,12)],2,1,1) +# findmin, findmax, argmin, argmax + +@testset "findmin(f, domain)" begin + @test findmin(-, 1:10) == (-10, 10) + @test findmin(identity, [1, 2, 3, missing]) === (missing, missing) + @test findmin(identity, [1, NaN, 3, missing]) === (missing, missing) + @test findmin(identity, [1, missing, NaN, 3]) === (missing, missing) + @test findmin(identity, [1, NaN, 3]) === (NaN, NaN) + @test findmin(identity, [1, 3, NaN]) === (NaN, NaN) + @test all(findmin(cos, 0:π/2:2π) .≈ (-1.0, π)) +end + +@testset "findmax(f, domain)" begin + @test findmax(-, 1:10) == (-1, 1) + @test findmax(identity, [1, 2, 3, missing]) === (missing, missing) + @test findmax(identity, [1, NaN, 3, missing]) === (missing, missing) + @test findmax(identity, [1, missing, NaN, 3]) === (missing, missing) + @test findmax(identity, [1, NaN, 3]) === (NaN, NaN) + @test findmax(identity, [1, 3, NaN]) === (NaN, NaN) + @test findmax(cos, 0:π/2:2π) == (1.0, 0.0) +end + +@testset "argmin(f, domain)" begin + @test argmin(-, 1:10) == 10 + @test argmin(sum, Iterators.product(1:5, 1:5)) == (1, 1) +end + +@testset "argmax(f, domain)" begin + @test argmax(-, 1:10) == 1 + @test argmax(sum, Iterators.product(1:5, 1:5)) == (5, 5) +end + # any & all @test @inferred any([]) == false