diff --git a/src/signalcorr.jl b/src/signalcorr.jl index 06c83ba1d..1553a5c37 100644 --- a/src/signalcorr.jl +++ b/src/signalcorr.jl @@ -629,3 +629,77 @@ end function pacf(x::AbstractVector{<:Real}, lags::AbstractVector{<:Integer}; method::Symbol=:regression) vec(pacf(reshape(x, length(x), 1), lags, method=method)) end + + +""" + yule_walker(x::Vector{<:Real}; + order::Int64=1, + method="adjusted", + df::Union{Nothing,Int64}=nothing, + inv=false, + demean=true, + ) + +Estimate AutoRegressive (AR) parameters using the Yule-Walker equations. This function +estimates AR parameters using the Yule-Walker equations. It supports different ACF +estimation methods (adjusted or maximum likelihood) and can optionally return the inverse +of the Toeplitz matrix. + +Example: +```julia +data = [0.1, 0.2, 0.3, 0.4, 0.5] +order = 2 +rho, sigma = yule_walker(data, order=order, method="mle") +``` +""" +function yule_walker( + x::Vector{<:Real}; + order::Int64=1, + method="adjusted", + df::Union{Nothing,Int64}=nothing, + inv=false, + demean=true, +) + method in ("adjusted", "mle") || + throw(ArgumentError("ACF estimation method must be 'adjusted' or 'MLE'")) + + x = copy(x) + if demean + x .-= mean(x) + end + n = isnothing(df) ? length(x) : df + + adj_needed = method == "adjusted" + + if ndims(x) > 1 || size(x, 2) != 1 + throw(ArgumentError("Expecting a vector to estimate AR parameters")) + end + + r = zeros(Float64, order + 1) + r[1] = sum(x .^ 2) / n + for k in 1:order + r[k + 1] = sum(x[1:(end - k)] .* x[(k + 1):end]) / (n - k * adj_needed) + end + R = Toeplitz(r[1:(end - 1)], conj(r[1:(end - 1)])) + + rho = 0 + try + rho = R \ r[2:end] + catch err + if occursin("Singular matrix", string(err)) + @warn "Matrix is singular. Using pinv." + rho = pinv(R) * r[2:end] + else + throw(err) + end + end + + sigmasq = r[1] - dot(r[2:end], rho) + sigma = isnan(sigmasq) || sigmasq <= 0 ? NaN : sqrt(sigmasq) + + if inv + return rho, sigma, inv(R) + else + return rho, sigma + end +end diff --git a/test/signalcorr.jl b/test/signalcorr.jl index bfbe90fed..8660df4c3 100644 --- a/test/signalcorr.jl +++ b/test/signalcorr.jl @@ -144,3 +144,86 @@ rpacfy = [-0.221173011668873, -0.175020669835420] @test pacf(x[:,1], 1:4, method=:yulewalker) ≈ rpacfy + +rho, sigma = yule_walker([1.0, 2, 3]; order=1) +@test rho == [0.0] + +rho, sigma = yule_walker([1.0, 2, 3]; order=2) +@test rho == [0.0, -1.5] + +x = [0.9901178, -0.74795127, 0.44612542, 1.1362954, -0.04040932] +rho, sigma = yule_walker(x; order=3, method="mle") +@test rho ≈ [-0.9418963, -0.90335955, -0.33267884] +@test sigma ≈ 0.44006365345695164 + +rho, sigma = yule_walker(x; order=3) +@test isapprox(rho, [0.10959317, 0.05242324, 1.06587676], atol=tol) +@test isapprox(sigma, 0.15860522671108127, atol=tol) + +rho, sigma = yule_walker(x; order=5, method="mle") +@test isapprox( + rho, [-1.24209771, -1.56893346, -1.16951484, -0.79844781, -0.27598787], atol=tol +) +@test isapprox(sigma, 0.3679474002175471, atol=tol) + +x = [ + 0.9901178, + -0.74795127, + 0.44612542, + 1.1362954, + -0.04040932, + 0.28625813, + 0.88901716, + -0.1079814, + -0.33231995, + 0.4607741, +] + +rho, sigma = yule_walker(x; order=3, method="mle") +@test isapprox( + rho, [-0.4896151627237206, -0.5724647370433921, 0.09083516892540627], atol=tol +) +@test isapprox(sigma, 0.4249693094713215, atol=tol) + +x = [ + 0.9901178, + -0.74795127, + 0.44612542, + 1.1362954, + -0.04040932, + 0.28625813, + 0.88901716, + -0.1079814, + -0.33231995, + 0.4607741, + 0.7729643, + -1.0998684, + 1.098167, + 1.0105597, + -1.3370227, + 1.239718, + -0.01393661, + -0.4790918, + 1.5009186, + -1.1647809, +] + +rho, sigma = yule_walker(x; order=3, method="mle") +@test isapprox(rho, [-0.82245705, -0.57029742, 0.12166898], atol=tol) +@test isapprox(sigma, 0.5203501608988023, atol=tol) + +rho, sigma = yule_walker(x; order=3) +@test isapprox(rho, [-0.93458149, -0.68653741, 0.10161722], atol=tol) +@test isapprox(sigma, 0.4269012058667671, atol=tol) + +rho, sigma = yule_walker(x; order=5, method="mle") +@test isapprox( + rho, [-0.83107755, -0.56407764, 0.20950143, 0.1232321, 0.10249279], atol=tol +) +@test isapprox(sigma, 0.5172269743102993, atol=tol) + +rho, sigma = yule_walker(x; order=5) +@test isapprox( + rho, [-0.96481241, -0.65359486, 0.31587079, 0.28403115, 0.1913565], atol=tol +) +@test isapprox(sigma, 0.41677565377507053, atol=tol)