From 87f372ca912e5da94cbd7531d2fa5724832df98a Mon Sep 17 00:00:00 2001 From: Alexander Plavin Date: Mon, 17 Jun 2024 14:11:16 -0400 Subject: [PATCH] functions that only support finite weights now throw errors for non-finites (#914) * throw errors when only finite weights are supported * remove extra calls to sum(wv) * typo * typo * add a minimal test for custom weights implementations * fix new test on 1.0 --- src/sampling.jl | 11 +++++++++-- src/scalarstats.jl | 2 ++ src/weights.jl | 1 + test/weights.jl | 16 ++++++++++++++++ 4 files changed, 28 insertions(+), 2 deletions(-) diff --git a/src/sampling.jl b/src/sampling.jl index 609c7d48b..4551476ae 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -591,7 +591,9 @@ Optionally specify a random number generator `rng` as the first argument function sample(rng::AbstractRNG, wv::AbstractWeights) 1 == firstindex(wv) || throw(ArgumentError("non 1-based arrays are not supported")) - t = rand(rng) * sum(wv) + wsum = sum(wv) + isfinite(wsum) || throw(ArgumentError("only finite weights are supported")) + t = rand(rng) * wsum n = length(wv) i = 1 cw = wv[1] @@ -654,6 +656,7 @@ function alias_sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, throw(ArgumentError("output array x must not share memory with input array a")) 1 == firstindex(a) == firstindex(wv) || throw(ArgumentError("non 1-based arrays are not supported")) + isfinite(sum(wv)) || throw(ArgumentError("only finite weights are supported")) length(wv) == length(a) || throw(DimensionMismatch("Inconsistent lengths.")) # create alias table @@ -688,13 +691,14 @@ function naive_wsample_norep!(rng::AbstractRNG, a::AbstractArray, throw(ArgumentError("output array x must not share memory with weights array wv")) 1 == firstindex(a) == firstindex(wv) == firstindex(x) || throw(ArgumentError("non 1-based arrays are not supported")) + wsum = sum(wv) + isfinite(wsum) || throw(ArgumentError("only finite weights are supported")) n = length(a) length(wv) == n || throw(DimensionMismatch("Inconsistent lengths.")) k = length(x) w = Vector{Float64}(undef, n) copyto!(w, wv) - wsum = sum(wv) for i = 1:k u = rand(rng) * wsum @@ -734,6 +738,7 @@ function efraimidis_a_wsample_norep!(rng::AbstractRNG, a::AbstractArray, throw(ArgumentError("output array x must not share memory with weights array wv")) 1 == firstindex(a) == firstindex(wv) == firstindex(x) || throw(ArgumentError("non 1-based arrays are not supported")) + isfinite(sum(wv)) || throw(ArgumentError("only finite weights are supported")) n = length(a) length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv))).")) k = length(x) @@ -775,6 +780,7 @@ function efraimidis_ares_wsample_norep!(rng::AbstractRNG, a::AbstractArray, throw(ArgumentError("output array x must not share memory with weights array wv")) 1 == firstindex(a) == firstindex(wv) == firstindex(x) || throw(ArgumentError("non 1-based arrays are not supported")) + isfinite(sum(wv)) || throw(ArgumentError("only finite weights are supported")) n = length(a) length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv))).")) k = length(x) @@ -848,6 +854,7 @@ function efraimidis_aexpj_wsample_norep!(rng::AbstractRNG, a::AbstractArray, throw(ArgumentError("output array x must not share memory with weights array wv")) 1 == firstindex(a) == firstindex(wv) == firstindex(x) || throw(ArgumentError("non 1-based arrays are not supported")) + isfinite(sum(wv)) || throw(ArgumentError("only finite weights are supported")) n = length(a) length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv))).")) k = length(x) diff --git a/src/scalarstats.jl b/src/scalarstats.jl index 1671bd01c..2936f9276 100644 --- a/src/scalarstats.jl +++ b/src/scalarstats.jl @@ -163,6 +163,7 @@ end # Weighted mode of arbitrary vectors of values function mode(a::AbstractVector, wv::AbstractWeights{T}) where T <: Real isempty(a) && throw(ArgumentError("mode is not defined for empty collections")) + isfinite(sum(wv)) || throw(ArgumentError("only finite weights are supported")) length(a) == length(wv) || throw(ArgumentError("data and weight vectors must be the same size, got $(length(a)) and $(length(wv))")) @@ -184,6 +185,7 @@ end function modes(a::AbstractVector, wv::AbstractWeights{T}) where T <: Real isempty(a) && throw(ArgumentError("mode is not defined for empty collections")) + isfinite(sum(wv)) || throw(ArgumentError("only finite weights are supported")) length(a) == length(wv) || throw(ArgumentError("data and weight vectors must be the same size, got $(length(a)) and $(length(wv))")) diff --git a/src/weights.jl b/src/weights.jl index f5f515104..ec415d3d9 100644 --- a/src/weights.jl +++ b/src/weights.jl @@ -716,6 +716,7 @@ function quantile(v::AbstractVector{<:Real}{V}, w::AbstractWeights{W}, p::Abstra # checks isempty(v) && throw(ArgumentError("quantile of an empty array is undefined")) isempty(p) && throw(ArgumentError("empty quantile array")) + isfinite(sum(w)) || throw(ArgumentError("only finite weights are supported")) all(x -> 0 <= x <= 1, p) || throw(ArgumentError("input probability out of [0,1] range")) w.sum == 0 && throw(ArgumentError("weight vector cannot sum to zero")) diff --git a/test/weights.jl b/test/weights.jl index 2180c88a4..76277e02f 100644 --- a/test/weights.jl +++ b/test/weights.jl @@ -1,6 +1,15 @@ using StatsBase using LinearAlgebra, Random, SparseArrays, Test + +# minimal custom weights type for tests below +struct MyWeights <: AbstractWeights{Float64, Float64, Vector{Float64}} + values::Vector{Float64} + sum::Float64 +end +MyWeights(values) = MyWeights(values, sum(values)) + + @testset "StatsBase.Weights" begin weight_funcs = (weights, aweights, fweights, pweights) @@ -610,4 +619,11 @@ end end end +@testset "custom weight types" begin + @test mean([1, 2, 3], MyWeights([1, 4, 10])) ≈ 2.6 + @test mean([1, 2, 3], MyWeights([NaN, 4, 10])) |> isnan + @test mode([1, 2, 3], MyWeights([1, 4, 10])) == 3 + @test_throws ArgumentError mode([1, 2, 3], MyWeights([NaN, 4, 10])) +end + end # @testset StatsBase.Weights