diff --git a/src/weights.jl b/src/weights.jl index 11625a723..29c74b2d3 100644 --- a/src/weights.jl +++ b/src/weights.jl @@ -200,21 +200,35 @@ end @weights ExponentialWeights false """ - ExponentialWeights + ExponentialWeights(vs) -# Fields +Construct an `ExponentialWeights` vector with weight values `vs`, which must sum to 1. -* `λ::Float64`: is a smoothing factor or rate paremeter between 0 .. 1. - As this value approaches 0 the resulting weights will be almost equal(), - while values closer to 1 will put higher weight on the end elements of the vector. +Exponential weights are a common form of temporal weights which assign exponentially +greater weight to past observations, which in this case corresponds to the tail end of +the vector. +""" +function ExponentialWeights(vs::V) where {T<:Real, V<:AbstractVector{T}} + s = sum(vs) + s ≈ one(T) || throw(ArgumentError("weight values do not sum to 1 (got $s)")) + ExponentialWeights{T, T, V}(vs, s) +end -When called with a desired length `n` (`Int`) a vector of length `n` will -be returned, where each element is set to `λ * (1 - λ)^(1 - i)`. +""" + eweights(n, λ) -# Usage +Construct an [`ExponentialWeights`](@ref) vector with length `n`, +where each element in position ``i`` is set to ``λ (1 - λ)^{1 - i}``. +The entire set of weights are then normalized to sum to 1. -```julia -w = ExponentialWeights(10, 0.3) +``λ`` is a smoothing factor or rate parameter such that ``0 < λ \\leq 1``. +As this value approaches 0, the resulting weights will be almost equal, +while values closer to 1 will put greater weight on the tail elements of the vector. + +# Examples + +```julia-repl +julia> eweights(10, 0.3) 10-element ExponentialWeights{Float64,Float64,Array{Float64,1}}: 0.012458 0.0177971 @@ -228,41 +242,35 @@ w = ExponentialWeights(10, 0.3) 0.308721 ``` """ -function ExponentialWeights(vs::V) where {T<:Real, V<:AbstractVector{T}} - s = sum(vs) - s ≈ one(T) || throw(ArgumentError("weight values do not sum to 1 (got $s)")) - ExponentialWeights{T, T, V}(vs, s) -end - -function ExponentialWeights(n::Integer, λ::Real) - n > 0 || throw(ArgumentError("cannot construct weights of length < 1")) +function eweights(n::Integer, λ::Real) + n > 0 || throw(ArgumentError("cannot construct exponential weights of length < 1")) 0 < λ <= 1 || throw(ArgumentError("smoothing factor must be between 0 and 1")) w0 = map(i -> λ * (1 - λ)^(1 - i), 1:n) s = sum(w0) - ExponentialWeights(w0 / s) + w0 ./= s + ExponentialWeights{typeof(s), eltype(w0), typeof(w0)}(w0, s) end """ - eweights(n, λ) - -Construct an `ExponentialWeights` vector with length `n`, -where each element in position ``i`` is set to ``λ * (1 - λ)^(1 - i)``. -The entire set of weights are then normalized so that they sum to 1.0 + eweights(vs) -``λ`` is a smoothing factor or rate parameter between 0 and 1. -As this value approaches 0 the resulting weights will be almost equal, -while values closer to 1 will put higher weight on the end elements of the vector. +Construct an [`ExponentialWeights`](@ref) vector using the given array. """ -eweights(n::Integer, λ::Real) = ExponentialWeights(n, λ) +eweights(v::RealVector) = ExponentialWeights(v) +eweights(v::RealArray) = ExponentialWeights(vec(v)) """ varcorrection(w::ExponentialWeights, corrected=false) * `corrected=true`: ``\\frac{1}{1 - \\sum {w^2}}`` -* `corrected=false`: ``1.0`` +* `corrected=false`: ``1`` """ @inline function varcorrection(w::ExponentialWeights, corrected::Bool=false) - corrected ? 1 / (1 - sum(x -> x^2, w)) : 1.0 + if corrected + 1 / (1 - sum(abs2, w.values)) + else + 1 / one(w.sum) # just 1 promoted to the same type as the other branch + end end ##### Equality tests ##### diff --git a/test/weights.jl b/test/weights.jl index a77dff2c3..7a6206784 100644 --- a/test/weights.jl +++ b/test/weights.jl @@ -504,7 +504,8 @@ end θ = 5.25 λ = 1 - exp(-1 / θ) # simple conversion for the more common/readable method - w = ExponentialWeights(4, λ) + v = [λ*(1-λ)^(1-i) for i = 1:4] + w = ExponentialWeights(v ./ sum(v)) @test round.(w, digits=4) == [0.1837, 0.2222, 0.2688, 0.3253] @test eweights(4, λ) ≈ w @@ -513,6 +514,8 @@ end @testset "Failure Conditions" begin @test_throws ArgumentError eweights(0, 0.3) @test_throws ArgumentError eweights(1, 1.1) + @test_throws ArgumentError eweights(rand(4)) + @test_throws ArgumentError eweights(rand(4, 4)) @test_throws ArgumentError ExponentialWeights(rand(4)) end