Skip to content

Commit

Permalink
Make var and std work for Vector{Vector{T}} (#23897)
Browse files Browse the repository at this point in the history
* Make var and std work for Vector{Vector{T}} by removing Number restriction
from some signatures as well as using broadcasting in std. Fixes #23884

* Make cov work for Vector{Vector}
  • Loading branch information
andreasnoack authored Oct 2, 2017
1 parent f2fd1f8 commit b10833b
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 18 deletions.
38 changes: 20 additions & 18 deletions base/statistics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,10 @@ function var(iterable; corrected::Bool=true, mean=nothing)
end
end

centralizedabs2fun(m::Number) = x -> abs2(x - m)
centralize_sumabs2(A::AbstractArray, m::Number) =
centralizedabs2fun(m) = x -> abs2.(x - m)
centralize_sumabs2(A::AbstractArray, m) =
mapreduce(centralizedabs2fun(m), +, A)
centralize_sumabs2(A::AbstractArray, m::Number, ifirst::Int, ilast::Int) =
centralize_sumabs2(A::AbstractArray, m, ifirst::Int, ilast::Int) =
mapreduce_impl(centralizedabs2fun(m), +, A, ifirst, ilast)

function centralize_sumabs2!(R::AbstractArray{S}, A::AbstractArray, means::AbstractArray) where S
Expand Down Expand Up @@ -164,7 +164,7 @@ function centralize_sumabs2!(R::AbstractArray{S}, A::AbstractArray, means::Abstr
return R
end

function varm(A::AbstractArray{T}, m::Number; corrected::Bool=true) where T
function varm(A::AbstractArray{T}, m; corrected::Bool=true) where T
n = _length(A)
n == 0 && return typeof((abs2(zero(T)) + abs2(zero(T)))/2)(NaN)
return centralize_sumabs2(A, m) / (n - Int(corrected))
Expand Down Expand Up @@ -219,12 +219,12 @@ The mean `mean` over the region may be provided.
var(A::AbstractArray, region; corrected::Bool=true, mean=nothing) =
varm(A, mean === nothing ? Base.mean(A, region) : mean, region; corrected=corrected)

varm(iterable, m::Number; corrected::Bool=true) =
varm(iterable, m; corrected::Bool=true) =
var(iterable, corrected=corrected, mean=m)

## variances over ranges

function varm(v::AbstractRange, m::Number)
function varm(v::AbstractRange, m)
f = first(v) - m
s = step(v)
l = length(v)
Expand Down Expand Up @@ -255,11 +255,11 @@ function sqrt!(A::AbstractArray)
A
end

stdm(A::AbstractArray, m::Number; corrected::Bool=true) =
sqrt(varm(A, m; corrected=corrected))
stdm(A::AbstractArray, m; corrected::Bool=true) =
sqrt.(varm(A, m; corrected=corrected))

std(A::AbstractArray; corrected::Bool=true, mean=nothing) =
sqrt(var(A; corrected=corrected, mean=mean))
sqrt.(var(A; corrected=corrected, mean=mean))

"""
std(v[, region]; corrected::Bool=true, mean=nothing)
Expand All @@ -284,7 +284,7 @@ std(iterable; corrected::Bool=true, mean=nothing) =
sqrt(var(iterable, corrected=corrected, mean=mean))

"""
stdm(v, m::Number; corrected::Bool=true)
stdm(v, m; corrected::Bool=true)
Compute the sample standard deviation of a vector `v`
with known mean `m`. If `corrected` is `true`,
Expand All @@ -296,7 +296,7 @@ scaled with `n` if `corrected` is `false` where `n = length(x)`.
applications requiring the handling of missing data, the
`DataArrays.jl` package is recommended.
"""
stdm(iterable, m::Number; corrected::Bool=true) =
stdm(iterable, m; corrected::Bool=true) =
std(iterable, corrected=corrected, mean=m)


Expand All @@ -321,7 +321,8 @@ _vmean(x::AbstractMatrix, vardim::Int) = mean(x, vardim)

# core functions

unscaled_covzm(x::AbstractVector) = sum(abs2, x)
unscaled_covzm(x::AbstractVector{<:Number}) = sum(abs2, x)
unscaled_covzm(x::AbstractVector) = sum(t -> t*t', x)
unscaled_covzm(x::AbstractMatrix, vardim::Int) = (vardim == 1 ? _conj(x'x) : x * x')

unscaled_covzm(x::AbstractVector, y::AbstractVector) = dot(y, x)
Expand Down Expand Up @@ -349,13 +350,14 @@ function covzm(x::AbstractVecOrMat, y::AbstractVecOrMat, vardim::Int=1; correcte
end

# covm (with provided mean)

## Use map(t -> t - xmean, x) instead of x .- xmean to allow for Vector{Vector}
## which can't be handled by broadcast
covm(x::AbstractVector, xmean; corrected::Bool=true) =
covzm(x .- xmean; corrected=corrected)
covzm(map(t -> t - xmean, x); corrected=corrected)
covm(x::AbstractMatrix, xmean, vardim::Int=1; corrected::Bool=true) =
covzm(x .- xmean, vardim; corrected=corrected)
covm(x::AbstractVector, xmean, y::AbstractVector, ymean; corrected::Bool=true) =
covzm(x .- xmean, y .- ymean; corrected=corrected)
covzm(map(t -> t - xmean, x), map(t -> t - ymean, y); corrected=corrected)
covm(x::AbstractVecOrMat, xmean, y::AbstractVecOrMat, ymean, vardim::Int=1; corrected::Bool=true) =
covzm(x .- xmean, y .- ymean, vardim; corrected=corrected)

Expand Down Expand Up @@ -425,7 +427,7 @@ function cov2cor!(C::AbstractMatrix{T}, xsd::AbstractArray) where T
end
return C
end
function cov2cor!(C::AbstractMatrix, xsd::Number, ysd::AbstractArray)
function cov2cor!(C::AbstractMatrix, xsd, ysd::AbstractArray)
nx, ny = size(C)
length(ysd) == ny || throw(DimensionMismatch("inconsistent dimensions"))
for (j, y) in enumerate(ysd) # fixme (iter): here and in all `cov2cor!` we assume that `C` is efficiently indexed by integers
Expand All @@ -435,7 +437,7 @@ function cov2cor!(C::AbstractMatrix, xsd::Number, ysd::AbstractArray)
end
return C
end
function cov2cor!(C::AbstractMatrix, xsd::AbstractArray, ysd::Number)
function cov2cor!(C::AbstractMatrix, xsd::AbstractArray, ysd)
nx, ny = size(C)
length(xsd) == nx || throw(DimensionMismatch("inconsistent dimensions"))
for j in 1:ny
Expand Down Expand Up @@ -475,7 +477,7 @@ corzm(x::AbstractMatrix, y::AbstractMatrix, vardim::Int=1) =

corm(x::AbstractVector{T}, xmean) where {T} = one(real(T))
corm(x::AbstractMatrix, xmean, vardim::Int=1) = corzm(x .- xmean, vardim)
function corm(x::AbstractVector, mx::Number, y::AbstractVector, my::Number)
function corm(x::AbstractVector, mx, y::AbstractVector, my)
n = length(x)
length(y) == n || throw(DimensionMismatch("inconsistent lengths"))
n > 0 || throw(ArgumentError("correlation only defined for non-empty vectors"))
Expand Down
7 changes: 7 additions & 0 deletions test/statistics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -454,3 +454,10 @@ end
@test isequal(mean(a01, 1) , fill(NaN, 1, 1))
@test isequal(mean(a10, 2) , fill(NaN, 1, 1))
end

@testset "cov/var/std of Vector{Vector}" begin
x = [[2,4,6],[4,6,8]]
@test var(x) vec(var([x[1] x[2]], 2))
@test std(x) vec(std([x[1] x[2]], 2))
@test cov(x) cov([x[1] x[2]], 2)
end

0 comments on commit b10833b

Please sign in to comment.