From 76952a88ea650ae8c6b6b1d010ef695ed9c8244d Mon Sep 17 00:00:00 2001 From: Colin Caine Date: Fri, 18 Dec 2020 17:32:19 +0000 Subject: [PATCH] Add 2-arg versions of findmax/min, argmax/min (#35316) Defines a descending total order, `isgreater` (not exported), where unordered values like NaNs and missing are last. This makes defining min, argmin, etc, simpler and more consistent. Also adds 2-arg versions of findmax/min, argmax/min. Defines and exports the `isunordered` predicate for testing whether a value is unordered like NaN and missing. Fixes #27613. Related: #27639, #27612, #34674. Thanks to @tkf, @StefanKarpinski and @drewrobson for their assistance with this PR. Co-authored-by: Jameson Nash Co-authored-by: Takafumi Arakaki --- NEWS.md | 2 + base/array.jl | 134 ----------------------------- base/exports.jl | 1 + base/operators.jl | 59 ++++++++++++- base/reduce.jl | 215 ++++++++++++++++++++++++++++++++++++++++++++++ base/reducedim.jl | 46 ++++++---- test/operators.jl | 22 +++++ test/reduce.jl | 32 +++++++ test/reducedim.jl | 34 ++++++++ 9 files changed, 392 insertions(+), 153 deletions(-) diff --git a/NEWS.md b/NEWS.md index ea50e8c3c9b93..eaebcb398bee7 100644 --- a/NEWS.md +++ b/NEWS.md @@ -28,6 +28,8 @@ Build system changes New library functions --------------------- +* Two argument methods `findmax(f, domain)`, `argmax(f, domain)` and the corresponding `min` versions ([#27613]). +* `isunordered(x)` returns true if `x` is a value that is normally unordered, such as `NaN` or `missing`. New library features -------------------- diff --git a/base/array.jl b/base/array.jl index f23be8acc862e..7a4bd705bae70 100644 --- a/base/array.jl +++ b/base/array.jl @@ -2205,140 +2205,6 @@ findall(x::Bool) = x ? [1] : Vector{Int}() findall(testf::Function, x::Number) = testf(x) ? [1] : Vector{Int}() findall(p::Fix2{typeof(in)}, x::Number) = x in p.x ? 
[1] : Vector{Int}() -""" - findmax(itr) -> (x, index) - -Return the maximum element of the collection `itr` and its index or key. -If there are multiple maximal elements, then the first one will be returned. -If any data element is `NaN`, this element is returned. -The result is in line with `max`. - -The collection must not be empty. - -# Examples -```jldoctest -julia> findmax([8,0.1,-9,pi]) -(8.0, 1) - -julia> findmax([1,7,7,6]) -(7, 2) - -julia> findmax([1,7,7,NaN]) -(NaN, 4) -``` -""" -findmax(a) = _findmax(a, :) - -function _findmax(a, ::Colon) - p = pairs(a) - y = iterate(p) - if y === nothing - throw(ArgumentError("collection must be non-empty")) - end - (mi, m), s = y - i = mi - while true - y = iterate(p, s) - y === nothing && break - m != m && break - (i, ai), s = y - if ai != ai || isless(m, ai) - m = ai - mi = i - end - end - return (m, mi) -end - -""" - findmin(itr) -> (x, index) - -Return the minimum element of the collection `itr` and its index or key. -If there are multiple minimal elements, then the first one will be returned. -If any data element is `NaN`, this element is returned. -The result is in line with `min`. - -The collection must not be empty. - -# Examples -```jldoctest -julia> findmin([8,0.1,-9,pi]) -(-9.0, 3) - -julia> findmin([7,1,1,6]) -(1, 2) - -julia> findmin([7,1,1,NaN]) -(NaN, 4) -``` -""" -findmin(a) = _findmin(a, :) - -function _findmin(a, ::Colon) - p = pairs(a) - y = iterate(p) - if y === nothing - throw(ArgumentError("collection must be non-empty")) - end - (mi, m), s = y - i = mi - while true - y = iterate(p, s) - y === nothing && break - m != m && break - (i, ai), s = y - if ai != ai || isless(ai, m) - m = ai - mi = i - end - end - return (m, mi) -end - -""" - argmax(itr) - -Return the index or key of the maximum element in a collection. -If there are multiple maximal elements, then the first one will be returned. - -The collection must not be empty. 
- -# Examples -```jldoctest -julia> argmax([8,0.1,-9,pi]) -1 - -julia> argmax([1,7,7,6]) -2 - -julia> argmax([1,7,7,NaN]) -4 -``` -""" -argmax(a) = findmax(a)[2] - -""" - argmin(itr) - -Return the index or key of the minimum element in a collection. -If there are multiple minimal elements, then the first one will be returned. - -The collection must not be empty. - -# Examples -```jldoctest -julia> argmin([8,0.1,-9,pi]) -3 - -julia> argmin([7,1,1,6]) -2 - -julia> argmin([7,1,1,NaN]) -4 -``` -""" -argmin(a) = findmin(a)[2] - # similar to Matlab's ismember """ indexin(a, b) diff --git a/base/exports.jl b/base/exports.jl index 0c157c45d2052..702c1bf485c3b 100644 --- a/base/exports.jl +++ b/base/exports.jl @@ -657,6 +657,7 @@ export isequal, ismutable, isless, + isunordered, ifelse, objectid, sizeof, diff --git a/base/operators.jl b/base/operators.jl index c2645910bf039..a35c44276a2e4 100644 --- a/base/operators.jl +++ b/base/operators.jl @@ -141,7 +141,7 @@ is defined, it is expected to satisfy the following: `isless(x, y) && isless(y, z)` implies `isless(x, z)`. Values that are normally unordered, such as `NaN`, -are ordered in an arbitrary but consistent fashion. +are ordered after regular values. [`missing`](@ref) values are ordered last. This is the default comparison used by [`sort`](@ref). @@ -168,6 +168,63 @@ isless(x::AbstractFloat, y::AbstractFloat) = (!isnan(x) & (isnan(y) | signless(x isless(x::Real, y::AbstractFloat) = (!isnan(x) & (isnan(y) | signless(x, y))) | (x < y) isless(x::AbstractFloat, y::Real ) = (!isnan(x) & (isnan(y) | signless(x, y))) | (x < y) +""" + isgreater(x, y) + +Not the inverse of `isless`! Test whether `x` is greater than `y`, according to +a fixed total order compatible with `min`. + +Defined with `isless`, this function is usually `isless(y, x)`, but `NaN` and +[`missing`](@ref) are ordered as smaller than any ordinary value with `missing` +smaller than `NaN`. 
+ +So `isless` defines an ascending total order with `NaN` and `missing` as the +largest values and `isgreater` defines a descending total order with `NaN` and +`missing` as the smallest values. + +!!! note + + Like `min`, `isgreater` orders containers (tuples, vectors, etc) + lexicographically with `isless(y, x)` rather than recursively with itself: + + ```jldoctest + julia> Base.isgreater(1, NaN) # 1 is greater than NaN + true + + julia> Base.isgreater((1,), (NaN,)) # But (1,) is not greater than (NaN,) + false + + julia> sort([1, 2, 3, NaN]; lt=Base.isgreater) + 4-element Vector{Float64}: + 3.0 + 2.0 + 1.0 + NaN + + julia> sort(tuple.([1, 2, 3, NaN]); lt=Base.isgreater) + 4-element Vector{Tuple{Float64}}: + (NaN,) + (3.0,) + (2.0,) + (1.0,) + ``` + +# Implementation +This is unexported. Types should not usually implement this function. Instead, implement `isless`. +""" +isgreater(x, y) = isunordered(x) || isunordered(y) ? isless(x, y) : isless(y, x) + +""" + isunordered(x) + +Return true if `x` is a value that is not normally orderable, such as `NaN` or `missing`. + +!!! compat "Julia 1.7" + This function requires Julia 1.7 or later. +""" +isunordered(x) = false +isunordered(x::AbstractFloat) = isnan(x) +isunordered(x::Missing) = true function ==(T::Type, S::Type) @_pure_meta diff --git a/base/reduce.jl b/base/reduce.jl index 8ea928669ab9e..a3c8099b979a5 100644 --- a/base/reduce.jl +++ b/base/reduce.jl @@ -762,6 +762,221 @@ Inf """ minimum(a; kw...) = mapreduce(identity, min, a; kw...) +## findmax, findmin, argmax & argmin + +""" + findmax(f, domain) -> (f(x), x) + +Returns a pair of a value in the codomain (outputs of `f`) and the corresponding +value in the `domain` (inputs to `f`) such that `f(x)` is maximised. If there +are multiple maximal points, then the first one will be returned. + +`domain` must be a non-empty iterable. + +Values are compared with `isless`. + +!!! compat "Julia 1.7" + This method requires Julia 1.7 or later. 
+ +# Examples + +```jldoctest +julia> findmax(identity, 5:9) +(9, 9) + +julia> findmax(-, 1:10) +(-1, 1) + +julia> findmax(first, [(1, :a), (2, :b), (2, :c)]) +(2, (2, :b)) + +julia> findmax(cos, 0:π/2:2π) +(1.0, 0.0) +``` +""" +findmax(f, domain) = mapfoldl(x -> (f(x), x), _rf_findmax, domain) +_rf_findmax((fm, m), (fx, x)) = isless(fm, fx) ? (fx, x) : (fm, m) + +""" + findmax(itr) -> (x, index) + +Return the maximal element of the collection `itr` and its index or key. +If there are multiple maximal elements, then the first one will be returned. +Values are compared with `isless`. + +# Examples + +```jldoctest +julia> findmax([8, 0.1, -9, pi]) +(8.0, 1) + +julia> findmax([1, 7, 7, 6]) +(7, 2) + +julia> findmax([1, 7, 7, NaN]) +(NaN, 4) +``` +""" +findmax(itr) = _findmax(itr, :) +_findmax(a, ::Colon) = mapfoldl( ((k, v),) -> (v, k), _rf_findmax, pairs(a) ) + +""" + findmin(f, domain) -> (f(x), x) + +Returns a pair of a value in the codomain (outputs of `f`) and the corresponding +value in the `domain` (inputs to `f`) such that `f(x)` is minimised. If there +are multiple minimal points, then the first one will be returned. + +`domain` must be a non-empty iterable. + +`NaN` is treated as less than all other values except `missing`. + +!!! compat "Julia 1.7" + This method requires Julia 1.7 or later. + +# Examples + +```jldoctest +julia> findmin(identity, 5:9) +(5, 5) + +julia> findmin(-, 1:10) +(-10, 10) + +julia> findmin(first, [(1, :a), (1, :b), (2, :c)]) +(1, (1, :a)) + +julia> findmin(cos, 0:π/2:2π) +(-1.0, 3.141592653589793) +``` + +""" +findmin(f, domain) = mapfoldl(x -> (f(x), x), _rf_findmin, domain) +_rf_findmin((fm, m), (fx, x)) = isgreater(fm, fx) ? (fx, x) : (fm, m) + +""" + findmin(itr) -> (x, index) + +Return the minimal element of the collection `itr` and its index or key. +If there are multiple minimal elements, then the first one will be returned. +`NaN` is treated as less than all other values except `missing`. 
+ +# Examples + +```jldoctest +julia> findmin([8, 0.1, -9, pi]) +(-9.0, 3) + +julia> findmin([1, 7, 7, 6]) +(1, 1) + +julia> findmin([1, 7, 7, NaN]) +(NaN, 4) +``` +""" +findmin(itr) = _findmin(itr, :) +_findmin(a, ::Colon) = mapfoldl( ((k, v),) -> (v, k), _rf_findmin, pairs(a) ) + +""" + argmax(f, domain) + +Return a value `x` in the domain of `f` for which `f(x)` is maximised. +If there are multiple maximal values for `f(x)` then the first one will be found. + +`domain` must be a non-empty iterable. + +Values are compared with `isless`. + +!!! compat "Julia 1.7" + This method requires Julia 1.7 or later. + +# Examples +```jldoctest +julia> argmax(abs, -10:5) +-10 + +julia> argmax(cos, 0:π/2:2π) +0.0 +``` +""" +argmax(f, domain) = findmax(f, domain)[2] + +""" + argmax(itr) + +Return the index or key of the maximal element in a collection. +If there are multiple maximal elements, then the first one will be returned. + +The collection must not be empty. + +Values are compared with `isless`. + +# Examples +```jldoctest +julia> argmax([8, 0.1, -9, pi]) +1 + +julia> argmax([1, 7, 7, 6]) +2 + +julia> argmax([1, 7, 7, NaN]) +4 +``` +""" +argmax(itr) = findmax(itr)[2] + +""" + argmin(f, domain) + +Return a value `x` in the domain of `f` for which `f(x)` is minimised. +If there are multiple minimal values for `f(x)` then the first one will be found. + +`domain` must be a non-empty iterable. + +`NaN` is treated as less than all other values except `missing`. + +!!! compat "Julia 1.7" + This method requires Julia 1.7 or later. + +# Examples +```jldoctest +julia> argmin(sign, -10:5) +-10 + +julia> argmin(x -> -x^3 + x^2 - 10, -5:5) +5 + +julia> argmin(acos, 0:0.1:1) +1.0 + +``` +""" +argmin(f, domain) = findmin(f, domain)[2] + +""" + argmin(itr) + +Return the index or key of the minimal element in a collection. +If there are multiple minimal elements, then the first one will be returned. + +The collection must not be empty. 
+ +`NaN` is treated as less than all other values except `missing`. + +# Examples +```jldoctest +julia> argmin([8, 0.1, -9, pi]) +3 + +julia> argmin([7, 1, 1, 6]) +2 + +julia> argmin([7, 1, 1, NaN]) +4 +``` +""" +argmin(itr) = findmin(itr)[2] + ## all & any """ diff --git a/base/reducedim.jl b/base/reducedim.jl index c83c2c71cb3e6..85807851cd23d 100644 --- a/base/reducedim.jl +++ b/base/reducedim.jl @@ -125,7 +125,7 @@ function _reducedim_init(f, op, fv, fop, A, region) end # initialization when computing minima and maxima requires a little care -for (f1, f2, initval) in ((:min, :max, :Inf), (:max, :min, :(-Inf))) +for (f1, f2, initval, typeextreme) in ((:min, :max, :Inf, :typemax), (:max, :min, :(-Inf), :typemin)) @eval function reducedim_init(f, op::typeof($f1), A::AbstractArray, region) # First compute the reduce indices. This will throw an ArgumentError # if any region is invalid @@ -144,11 +144,23 @@ for (f1, f2, initval) in ((:min, :max, :Inf), (:max, :min, :(-Inf))) # otherwise use the min/max of the first slice as initial value v0 = mapreduce(f, $f2, A1) - # but NaNs need to be avoided as initial values - v0 = v0 != v0 ? typeof(v0)($initval) : v0 - T = _realtype(f, promote_union(eltype(A))) Tr = v0 isa T ? T : typeof(v0) + + # but NaNs and missing need to be avoided as initial values + if (v0 == v0) === false + # v0 is NaN + v0 = $initval + elseif isunordered(v0) + # v0 is missing or a third-party unordered value + Tnm = nonmissingtype(Tr) + # TODO: Some types, like BigInt, don't support typemin/typemax. + # So a Matrix{Union{BigInt, Missing}} can still error here. + v0 = $typeextreme(Tnm) + end + # v0 may have changed type. + Tr = v0 isa T ? 
T : typeof(v0) + return reducedim_initarray(A, region, v0, Tr) end end @@ -926,7 +938,7 @@ function findminmax!(f, Rval, Rind, A::AbstractArray{T,N}) where {T,N} for i in axes(A,1) k, kss = y::Tuple tmpAv = A[i,IA] - if tmpRi == zi || (tmpRv == tmpRv && (tmpAv != tmpAv || f(tmpAv, tmpRv))) + if tmpRi == zi || f(tmpRv, tmpAv) tmpRv = tmpAv tmpRi = k end @@ -943,7 +955,7 @@ function findminmax!(f, Rval, Rind, A::AbstractArray{T,N}) where {T,N} tmpAv = A[i,IA] tmpRv = Rval[i,IR] tmpRi = Rind[i,IR] - if tmpRi == zi || (tmpRv == tmpRv && (tmpAv != tmpAv || f(tmpAv, tmpRv))) + if tmpRi == zi || f(tmpRv, tmpAv) Rval[i,IR] = tmpAv Rind[i,IR] = k end @@ -959,18 +971,18 @@ end Find the minimum of `A` and the corresponding linear index along singleton dimensions of `rval` and `rind`, and store the results in `rval` and `rind`. -`NaN` is treated as less than all other values. +`NaN` is treated as less than all other values except `missing`. """ function findmin!(rval::AbstractArray, rind::AbstractArray, A::AbstractArray; init::Bool=true) - findminmax!(isless, init && !isempty(A) ? fill!(rval, first(A)) : rval, fill!(rind,zero(eltype(keys(A)))), A) + findminmax!(isgreater, init && !isempty(A) ? fill!(rval, first(A)) : rval, fill!(rind,zero(eltype(keys(A)))), A) end """ findmin(A; dims) -> (minval, index) For an array input, returns the value and index of the minimum over the given dimensions. -`NaN` is treated as less than all other values. +`NaN` is treated as less than all other values except `missing`. 
# Examples ```jldoctest @@ -996,30 +1008,28 @@ function _findmin(A, region) end (similar(A, ri), zeros(eltype(keys(A)), ri)) else - findminmax!(isless, fill!(similar(A, ri), first(A)), + findminmax!(isgreater, fill!(similar(A, ri), first(A)), zeros(eltype(keys(A)), ri), A) end end -isgreater(a, b) = isless(b,a) - """ findmax!(rval, rind, A) -> (maxval, index) Find the maximum of `A` and the corresponding linear index along singleton dimensions of `rval` and `rind`, and store the results in `rval` and `rind`. -`NaN` is treated as greater than all other values. +`NaN` is treated as greater than all other values except `missing`. """ function findmax!(rval::AbstractArray, rind::AbstractArray, A::AbstractArray; init::Bool=true) - findminmax!(isgreater, init && !isempty(A) ? fill!(rval, first(A)) : rval, fill!(rind,zero(eltype(keys(A)))), A) + findminmax!(isless, init && !isempty(A) ? fill!(rval, first(A)) : rval, fill!(rind,zero(eltype(keys(A)))), A) end """ findmax(A; dims) -> (maxval, index) For an array input, returns the value and index of the maximum over the given dimensions. -`NaN` is treated as greater than all other values. +`NaN` is treated as greater than all other values except `missing`. # Examples ```jldoctest @@ -1045,7 +1055,7 @@ function _findmax(A, region) end similar(A, ri), zeros(eltype(keys(A)), ri) else - findminmax!(isgreater, fill!(similar(A, ri), first(A)), + findminmax!(isless, fill!(similar(A, ri), first(A)), zeros(eltype(keys(A)), ri), A) end end @@ -1056,7 +1066,7 @@ reducedim1(R, A) = length(axes1(R)) == 1 argmin(A; dims) -> indices For an array input, return the indices of the minimum elements over the given dimensions. -`NaN` is treated as less than all other values. +`NaN` is treated as less than all other values except `missing`. 
# Examples ```jldoctest @@ -1081,7 +1091,7 @@ argmin(A::AbstractArray; dims=:) = findmin(A; dims=dims)[2] argmax(A; dims) -> indices For an array input, return the indices of the maximum elements over the given dimensions. -`NaN` is treated as greater than all other values. +`NaN` is treated as greater than all other values except `missing`. # Examples ```jldoctest diff --git a/test/operators.jl b/test/operators.jl index 08f0b0179c81c..c14858657ce3b 100644 --- a/test/operators.jl +++ b/test/operators.jl @@ -83,6 +83,28 @@ import Base.< @test isless('a','b') +@testset "isgreater" begin + # isgreater should be compatible with min. + min1(a, b) = Base.isgreater(a, b) ? b : a + # min promotes numerical arguments to the same type, but our quick min1 + # doesn't, so use float test values instead of ints. + values = (1.0, 5.0, NaN, missing, Inf) + for a in values, b in values + @test min(a, b) === min1(a, b) + @test min((a,), (b,)) === min1((a,), (b,)) + @test all(min([a], [b]) .=== min1([a], [b])) + end +end + +@testset "isunordered" begin + @test isunordered(NaN) + @test isunordered(NaN32) + @test isunordered(missing) + @test !isunordered(1) + @test !isunordered([NaN, 1]) + @test !isunordered([1.0, missing]) +end + @testset "vectorized comparisons between numbers" begin @test 1 .!= 2 @test 1 .== 1 diff --git a/test/reduce.jl b/test/reduce.jl index 4f9fc33403282..6cbe01a5de1cf 100644 --- a/test/reduce.jl +++ b/test/reduce.jl @@ -387,6 +387,38 @@ A = circshift(reshape(1:24,2,3,4), (0,1,1)) end end +# findmin, findmax, argmin, argmax + +@testset "findmin(f, domain)" begin + @test findmin(-, 1:10) == (-10, 10) + @test findmin(identity, [1, 2, 3, missing]) === (missing, missing) + @test findmin(identity, [1, NaN, 3, missing]) === (missing, missing) + @test findmin(identity, [1, missing, NaN, 3]) === (missing, missing) + @test findmin(identity, [1, NaN, 3]) === (NaN, NaN) + @test findmin(identity, [1, 3, NaN]) === (NaN, NaN) + @test all(findmin(cos, 0:π/2:2π) .≈ (-1.0, π)) 
+end + +@testset "findmax(f, domain)" begin + @test findmax(-, 1:10) == (-1, 1) + @test findmax(identity, [1, 2, 3, missing]) === (missing, missing) + @test findmax(identity, [1, NaN, 3, missing]) === (missing, missing) + @test findmax(identity, [1, missing, NaN, 3]) === (missing, missing) + @test findmax(identity, [1, NaN, 3]) === (NaN, NaN) + @test findmax(identity, [1, 3, NaN]) === (NaN, NaN) + @test findmax(cos, 0:π/2:2π) == (1.0, 0.0) +end + +@testset "argmin(f, domain)" begin + @test argmin(-, 1:10) == 10 + @test argmin(sum, Iterators.product(1:5, 1:5)) == (1, 1) +end + +@testset "argmax(f, domain)" begin + @test argmax(-, 1:10) == 1 + @test argmax(sum, Iterators.product(1:5, 1:5)) == (5, 5) +end + # any & all @test @inferred any([]) == false diff --git a/test/reducedim.jl b/test/reducedim.jl index b6229634b6006..cc07cfff1dad3 100644 --- a/test/reducedim.jl +++ b/test/reducedim.jl @@ -195,6 +195,7 @@ end end end + ## findmin/findmax/minimum/maximum A = [1.0 5.0 6.0; @@ -219,6 +220,39 @@ for (tup, rval, rind) in [((1,), [5.0 5.0 6.0], [CartesianIndex(2,1) CartesianIn @test isequal(maximum!(copy(rval), A, init=false), rval) end +@testset "missing in findmin/findmax" begin + B = [1.0 missing NaN; + 5.0 NaN missing] + for (tup, rval, rind) in [(1, [5.0 missing missing], [CartesianIndex(2, 1) CartesianIndex(1, 2) CartesianIndex(2, 3)]), + (2, [missing; missing], [CartesianIndex(1, 2) CartesianIndex(2, 3)] |> permutedims)] + (rval′, rind′) = findmax(B, dims=tup) + @test all(rval′ .=== rval) + @test all(rind′ .== rind) + @test all(maximum(B, dims=tup) .=== rval) + end + + for (tup, rval, rind) in [(1, [1.0 missing missing], [CartesianIndex(1, 1) CartesianIndex(1, 2) CartesianIndex(2, 3)]), + (2, [missing; missing], [CartesianIndex(1, 2) CartesianIndex(2, 3)] |> permutedims)] + (rval′, rind′) = findmin(B, dims=tup) + @test all(rval′ .=== rval) + @test all(rind′ .== rind) + @test all(minimum(B, dims=tup) .=== rval) + end +end + +@testset "reducedim_init min/max 
unorderable handling" begin + x = Any[1.0, NaN] + y = [1, missing] + for (v, rval1, rval2) in [(x, [NaN], x), + (y, [missing], y), + (Any[1. NaN; 1. 1.], Any[1. NaN], Any[NaN, 1.])] + for f in (minimum, maximum) + @test all(f(v, dims=1) .=== rval1) + @test all(f(v, dims=2) .=== rval2) + end + end +end + #issue #23209 A = [1.0 3.0 6.0;