Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add findmin, findmax, argmin, and argmax #53

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name = "NaNMath"
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
repo = "https://github.com/mlubin/NaNMath.jl.git"
authors = ["Miles Lubin"]
version = "1.0.1"
version = "1.1.0"

[deps]
OpenLibm_jll = "05823500-19ac-5b8b-9628-191a04bc5112"
Expand Down
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ sum
maximum
minimum
extrema
argmax
argmin
findmax
findmin
mean
median
var
Expand Down
168 changes: 168 additions & 0 deletions src/NaNMath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -355,4 +355,172 @@ for f in (:min, :max)
@eval ($f)(a, b, c, xs...) = Base.afoldl($f, ($f)(($f)(a, b), c), xs...)
end

"""
NaNMath.findmin([f,] domain) -> (f(x), index)

##### Args:
* `f`: a function applied to the values in `domain` (defaulting to `identity`)
* `domain`: A non-empty iterable such that the codomain (outputs of `f` applied to `domain`)
are floating point numbers or `missing`

##### Returns:
* Returns a pair of a value in the codomain and the index of the corresponding value in the
`domain` (inputs to `f`) such that `f(x)` is minimized. If there are multiple minimal
points, then the first one will be returned. `NaN`s are treated as greater than all other
values, while `missing` is treated as less than all other values.

##### Examples:
```julia
julia> NaNMath.findmin([1., 1., 2., 2., NaN])
(1.0, 1)

julia> NaNMath.findmin(-, [1., 1., 2., 2., NaN])
(-2.0, 3)
```
"""
function findmin end
findmin(f, x) = _findminmax(Base.isgreater, f, x)
findmin(x) = findmin(identity, x)

"""
NaNMath.findmax([f,] domain) -> (f(x), index)

##### Args:
* `f`: a function applied to the values in `domain` (defaulting to `identity`)
* `domain`: A non-empty iterable such that the codomain (outputs of `f` applied to `domain`)
are floating point numbers or `missing`

##### Returns:
* Returns a pair of a value in the codomain and the index of the corresponding value in the
`domain` (inputs to `f`) such that `f(x)` is maximized. If there are multiple maximal
points, then the first one will be returned. `NaN`s are treated as less than all other
values, while `missing` is treated as greater than all other values.

##### Examples:
```julia
julia> NaNMath.findmax([1., 1., 2., 2., NaN])
(2.0, 3)

julia> NaNMath.findmax(-, [1., 1., 2., 2., NaN])
(-1.0, 1)
```
"""
function findmax end
findmax(f, x) = _findminmax(Base.isless, f, x)
findmax(x) = findmax(identity, x)

function _findminmax_op(cmp)
return (xleft_and_index, xright_and_index) -> begin
xleft = first(xleft_and_index)
xleft === missing && return xleft_and_index
xright = first(xright_and_index)
xright === missing && return xright_and_index
return ifelse(
(xleft isa Number && isnan(xright)) || !cmp(xleft, xright),
xleft_and_index,
xright_and_index,
)
end
end

function _findminmax(cmp, f, x)
return mapfoldl(_findminmax_op(cmp), pairs(x)) do (k, xk)
return f(xk), k
end
end

"""
NaNMath.argmin(f, domain) -> x

##### Args:
* `f`: A function applied to the values of `domain`
* `domain`: A non-empty iterable such that the codomain (outputs of `f` applied to `domain`)
are floating point numbers or `missing`

##### Returns:
* Returns a value `x` in the domain of `f` for which `f(x)` is minimized. If there are
multiple minimal values for `f(x)`, then the first one will be found. `NaN`s are treated as
greater than all other values, while `missing` is treated as less than all other values.

##### Examples:
```julia
julia> NaNMath.argmin(abs, [1., -1., -2., 2., NaN])
1.0

julia> NaNMath.argmin(identity, [7, 1, 1, NaN])
1.0
```

NaNMath.argmin(itr) -> key

##### Args:
* `itr`: A non-empty iterable of floating point numbers or `missing`.

##### Returns:
* Returns the index or key of the minimal element in `itr`. If there are multiple
minimal elements, then the first one will be returned

##### Examples:
```julia
julia> NaNMath.argmin([7, 1, 1, NaN])
2

julia> NaNMath.argmin([1.0 2; 3 NaN])
CartesianIndex(1, 1)

julia> NaNMath.argmin(Dict("x" => 1.0, "y" => -1, "z" => NaN))
"y"
```
"""
function argmin end
argmin(x) = findmin(identity, x)[2]
argmin(f, x) = mapfoldl(x -> (f(x), x), _findminmax_op(Base.isgreater), x)[2]

"""
NaNMath.argmax(f, domain) -> x

##### Args:
* `f`: A function applied to the values of `domain`
* `domain`: A non-empty iterable such that the codomain (outputs of `f` applied to `domain`)
are floating point numbers or `missing`

##### Returns:
* Returns a value `x` in the domain of `f` for which `f(x)` is maximized. If there are
multiple maximal values for `f(x)`, then the first one will be found. `NaN`s are treated as
less than all other values, while `missing` is treated as greater than all other values.

##### Examples:
```julia
julia> NaNMath.argmax(abs, [1., -1., -2., NaN])
2.0

julia> NaNMath.argmax(identity, [7, 1, 1, NaN])
7.0
```

NaNMath.argmax(itr) -> key

##### Args:
* `itr`: A non-empty iterable of floating point numbers or `missing`.

##### Returns:
* Returns the index or key of the maximal element in `itr`. If there are multiple
maximal elements, then the first one will be returned

##### Examples:
```julia
julia> NaNMath.argmax([7, 1, 1, NaN])
1

julia> NaNMath.argmax([1.0 2; 3 NaN])
CartesianIndex(2, 1)

julia> NaNMath.argmax(Dict("x" => 1.0, "y" => -1, "z" => NaN))
"x"
```
"""
function argmax end
argmax(x) = findmax(identity, x)[2]
argmax(f, x) = mapfoldl(x -> (f(x), x), _findminmax_op(Base.isless), x)[2]

end
118 changes: 118 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,121 @@ using Test
@test isnan(NaNMath.max(NaN, NaN))
@test isnan(NaNMath.max(NaN))
@test NaNMath.max(NaN, NaN, 0.0, 1.0) == 1.0

@testset "findmin/findmax" begin
if VERSION ≥ v"1.7"
xvals = [
[1., 2., 3., 3., 1.],
[missing, missing],
[missing, 1.0],
[1.0, missing],
(1., 2, 3., 3, 1),
(x=1, y=3, z=-4, w=-2),
Dict(:a => 1.0, :b => 1.0, :d => 3.0, :c => 2.0),
]
@testset for x in xvals
@test NaNMath.findmin(x) === findmin(x)
@test NaNMath.findmax(x) === findmax(x)
@test NaNMath.findmin(identity, x) === findmin(identity, x)
@test NaNMath.findmax(identity, x) === findmax(identity, x)
@test NaNMath.findmin(sin, x) === findmin(sin, x)
@test NaNMath.findmax(sin, x) === findmax(sin, x)
end
end
x = [7, 7, NaN, 1, 1, NaN]
@test NaNMath.findmin(x) === (1.0, 4)
@test NaNMath.findmax(x) === (7.0, 1)
@test NaNMath.findmin(identity, x) === (1.0, 4)
@test NaNMath.findmax(identity, x) === (7.0, 1)
@test NaNMath.findmin(-, x) === (-7.0, 1)
@test NaNMath.findmax(-, x) === (-1.0, 4)

x = [NaN, NaN]
@test NaNMath.findmin(x) === (NaN, 1)
@test NaNMath.findmax(x) === (NaN, 1)
@test NaNMath.findmin(identity, x) === (NaN, 1)
@test NaNMath.findmax(identity, x) === (NaN, 1)
@test NaNMath.findmin(sin, x) === (NaN, 1)
@test NaNMath.findmax(sin, x) === (NaN, 1)

x = [3, missing, NaN, -1]
@test NaNMath.findmin(x) === (missing, 2)
@test NaNMath.findmax(x) === (missing, 2)
@test NaNMath.findmin(identity, x) === (missing, 2)
@test NaNMath.findmax(identity, x) === (missing, 2)
@test NaNMath.findmin(sin, x) === (missing, 2)
@test NaNMath.findmax(sin, x) === (missing, 2)

x = Dict(:x => 3, :w => 2, :y => -1.0, :z => NaN)
@test NaNMath.findmin(x) === (-1.0, :y)
@test NaNMath.findmax(x) === (3, :x)
@test NaNMath.findmin(identity, x) === (-1.0, :y)
@test NaNMath.findmax(identity, x) === (3, :x)
@test NaNMath.findmin(-, x) === (-3, :x)
@test NaNMath.findmax(-, x) === (1.0, :y)

x = Dict(:x => :a, :w => :b, :y => :c, :z => :d)
y = Dict(:a => 3, :b => 2, :c => -1.0, :d => NaN)
f = k -> y[k]
@test NaNMath.findmin(f, x) === (-1.0, :y)
@test NaNMath.findmax(f, x) === (3, :x)
end

@testset "argmin/argmax" begin
if VERSION ≥ v"1.7"
xvals = [
[1., 2., 3., 3., 1.],
[missing, missing],
[missing, 1.0],
[1.0, missing],
(1., 2, 3., 3, 1),
(x=1, y=3, z=-4, w=-2),
Dict(:a => 1.0, :b => 1.0, :d => 3.0, :c => 2.0),
]
@testset for x in xvals
@test NaNMath.argmin(x) === argmin(x)
@test NaNMath.argmax(x) === argmax(x)
@test NaNMath.argmin(identity, x) === argmin(identity, x)
@test NaNMath.argmax(identity, x) === argmax(identity, x)
x isa Dict || @test NaNMath.argmin(sin, x) === argmin(sin, x)
x isa Dict || @test NaNMath.argmax(sin, x) === argmax(sin, x)
end
end
x = [7, 7, NaN, 1, 1, NaN]
@test NaNMath.argmin(x) === 4
@test NaNMath.argmax(x) === 1
@test NaNMath.argmin(identity, x) === 1.0
@test NaNMath.argmax(identity, x) === 7.0
@test NaNMath.argmin(-, x) === 7.0
@test NaNMath.argmax(-, x) === 1.0

x = [NaN, NaN]
@test NaNMath.argmin(x) === 1
@test NaNMath.argmax(x) === 1
@test NaNMath.argmin(identity, x) === NaN
@test NaNMath.argmax(identity, x) === NaN
@test NaNMath.argmin(-, x) === NaN
@test NaNMath.argmax(-, x) === NaN

x = [3, missing, NaN, -1]
@test NaNMath.argmin(x) === 2
@test NaNMath.argmax(x) === 2
@test NaNMath.argmin(identity, x) === missing
@test NaNMath.argmax(identity, x) === missing
@test NaNMath.argmin(-, x) === missing
@test NaNMath.argmax(-, x) === missing

x = Dict(:x => 3, :w => 2, :z => -1.0, :y => NaN)
@test NaNMath.argmin(x) === :z
@test NaNMath.argmax(x) === :x
if VERSION ≥ v"1.7"
@test NaNMath.argmin(identity, x) === argmin(identity, x)
@test NaNMath.argmax(identity, x) === argmax(identity, x)
end

x = (:a, :b, :c, :d)
y = Dict(:a => 3, :b => 2, :c => -1.0, :d => NaN)
f = k -> y[k]
@test NaNMath.argmin(f, x) === :c
@test NaNMath.argmax(f, x) === :a
end