Skip to content

Commit

Permalink
add median tests from Statistics.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
stev47 committed Nov 23, 2024
1 parent 2461b2f commit e4ee876
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 3 deletions.
15 changes: 13 additions & 2 deletions ext/StaticArraysStatisticsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,16 @@ _mean_denom(a, ::Val{D}) where {D} = size(a, D)
@inline mean(a::StaticArray; dims=:) = _reduce(+, a, dims) / _mean_denom(a, dims)
@inline mean(f::Function, a::StaticArray; dims=:) = _mapreduce(f, +, dims, _InitialValue(), Size(a), a) / _mean_denom(a, dims)

@inline function median(a::StaticArray; dims = :)
if dims == Colon()
median(vec(a))
else
# FIXME: Implement `mapslices` correctly on `StaticArray` to remove
# this fallback.
median(Array(a); dims)
end
end

@inline function median(a::StaticVector)
(isimmutable(a) && length(a) <= _bitonic_sort_limit) ||
return median!(Base.copymutable(a))
Expand All @@ -24,8 +34,9 @@ _mean_denom(a, ::Val{D}) where {D} = size(a, D)
throw(ArgumentError("median of empty vector is undefined, $(repr(a))"))
eltype(a) >: Missing && any(ismissing, a) &&
return missing
any(x -> x isa Number && isnan(x), a) &&
return convert(eltype(a), NaN)
nanix = findfirst(x -> x isa Number && isnan(x), a)
isnothing(nanix) ||
return a[nanix]

order = ord(isless, identity, nothing, Forward)
sa = _sort(Tuple(a), BitonicSort, order)
Expand Down
63 changes: 62 additions & 1 deletion test/sort.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using StaticArrays, Test
using Statistics: median
using Statistics: Statistics, median, median!, middle

@testset "sort" begin

Expand Down Expand Up @@ -44,6 +44,67 @@ using Statistics: median
@test @inferred(median(v) == mref)
end
end

# Tests based on upstream `Statistics.jl`.
# https://github.com/JuliaStats/Statistics.jl/blob/d49c2bf4f81e1efb4980a35fe39c815ef8396297/test/runtests.jl#L31-L92
@test median(SA[1.]) === 1.
@test median(SA[1.,3]) === 2.
@test median(SA[1.,3,2]) === 2.

@test median(SA[1,3,2]) === 2.0
@test median(SA[1,3,2,4]) === 2.5

@test median(SA[0.0,Inf]) == Inf
@test median(SA[0.0,-Inf]) == -Inf
@test median(SA[0.,Inf,-Inf]) == 0.0
@test median(SA[1.,-1.,Inf,-Inf]) == 0.0
@test isnan(median(SA[-Inf,Inf]))

X = SA[2 3 1 -1; 7 4 5 -4]
@test all(median(X, dims=2) .== SA[1.5, 4.5])
@test all(median(X, dims=1) .== SA[4.5 3.5 3.0 -2.5])
@test X == SA[2 3 1 -1; 7 4 5 -4] # issue #17153

@test_throws ArgumentError median(SA[])
@test isnan(median(SA[NaN]))
@test isnan(median(SA[0.0,NaN]))
@test isnan(median(SA[NaN,0.0]))
@test isnan(median(SA[NaN,0.0,1.0]))
@test isnan(median(SA{Any}[NaN,0.0,1.0]))
@test isequal(median(SA[NaN 0.0; 1.2 4.5], dims=2), reshape(SA[NaN; 2.85], 2, 1))

# the specific NaN value is propagated from the input
@test median(SA[NaN]) === NaN
@test median(SA[0.0,NaN]) === NaN
@test median(SA[0.0,NaN,NaN]) === NaN
@test median(SA[-NaN]) === -NaN
@test median(SA[0.0,-NaN]) === -NaN
@test median(SA[0.0,-NaN,-NaN]) === -NaN

@test ismissing(median(SA[1, missing]))
@test ismissing(median(SA[1, 2, missing]))
@test ismissing(median(SA[NaN, 2.0, missing]))
@test ismissing(median(SA[NaN, missing]))
@test ismissing(median(SA[missing, NaN]))
@test ismissing(median(SA{Any}[missing, 2.0, 3.0, 4.0, NaN]))
@test median(skipmissing(SA[1, missing, 2])) === 1.5

@test median!(Base.copymutable(SA[1 2 3 4])) == 2.5
@test median!(Base.copymutable(SA[1 2; 3 4])) == 2.5

@test @inferred(median(SA{Float16}[1, 2, NaN])) === Float16(NaN)
@test @inferred(median(SA{Float16}[1, 2, 3])) === Float16(2)
@test @inferred(median(SA{Float32}[1, 2, NaN])) === NaN32
@test @inferred(median(SA{Float32}[1, 2, 3])) === 2.0f0

# custom type implementing minimal interface
struct A
x
end
Statistics.middle(x::A, y::A) = A(middle(x.x, y.x))
Base.isless(x::A, y::A) = isless(x.x, y.x)
@test median(SA[A(1), A(2)]) === A(1.5)
@test median(SA{Any}[A(1), A(2)]) === A(1.5)
end

end

0 comments on commit e4ee876

Please sign in to comment.