From e7943839b6de4c86e1040aeee447d6f7637d3ec5 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 26 Sep 2023 01:09:06 +0200 Subject: [PATCH] Implement `Base.allequal` and `Base.allunique` for weight vectors (#894) * Implement `Base.allequal` and `Base.allunique` for weight vectors * Combine definitions --- Project.toml | 2 +- src/weights.jl | 8 ++++++++ test/weights.jl | 36 ++++++++++++++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 6d8d48303..5f51f9c7f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "StatsBase" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" authors = ["JuliaStats"] -version = "0.34.1" +version = "0.34.2" [deps] DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" diff --git a/src/weights.jl b/src/weights.jl index cf535d408..f5f515104 100644 --- a/src/weights.jl +++ b/src/weights.jl @@ -385,6 +385,14 @@ Base.:(==)(x::UnitWeights, y::UnitWeights) = (x.len == y.len) Base.isequal(x::AbstractWeights, y::AbstractWeights) = false Base.:(==)(x::AbstractWeights, y::AbstractWeights) = false +# https://github.com/JuliaLang/julia/pull/43354 +if VERSION >= v"1.8.0-DEV.1494" # 98e60ffb11ee431e462b092b48a31a1204bd263d + Base.allequal(wv::AbstractWeights) = allequal(wv.values) + Base.allequal(::UnitWeights) = true +end +Base.allunique(wv::AbstractWeights) = allunique(wv.values) +Base.allunique(wv::UnitWeights) = length(wv) <= 1 + ##### Weighted sum ##### ## weighted sum over vectors diff --git a/test/weights.jl b/test/weights.jl index 52142efd8..2180c88a4 100644 --- a/test/weights.jl +++ b/test/weights.jl @@ -574,4 +574,40 @@ end end end +@testset "allequal and allunique" begin + # General weights + for f in (weights, aweights, fweights, pweights) + @test allunique(f(Float64[])) + @test allunique(f([0.4])) + @test allunique(f([0.4, 0.3])) + @test !allunique(f([0.4, 0.4])) + @test allunique(f([0.4, 0.3, 0.5])) + @test !allunique(f([0.4, 0.4, 0.5])) + @test allunique(f([0.4, 0.3, 0.5, 0.35])) + @test !allunique(f([0.4, 0.3, 0.5, 0.4])) + + if isdefined(Base, :allequal) + @test allequal(f(Float64[])) + @test allequal(f([0.4])) + @test allequal(f([0.4, 0.4])) + @test !allequal(f([0.4, 0.3])) + @test allequal(f([0.4, 0.4, 0.4, 0.4])) + @test !allunique(f([0.4, 0.4, 0.3, 0.4])) + end + end + + # Uniform weights + @test allunique(uweights(0)) + @test allunique(uweights(1)) + @test !allunique(uweights(2)) + @test !allunique(uweights(5)) + + if isdefined(Base, :allequal) + @test allequal(uweights(0)) + @test allequal(uweights(1)) + @test allequal(uweights(2)) + @test allequal(uweights(5)) + end +end + end # @testset StatsBase.Weights