Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make nystrom work with AbstractVector #427

Merged
merged 14 commits into from
Jan 28, 2022
54 changes: 34 additions & 20 deletions src/approximations/nystrom.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,29 @@
# Following the algorithm by William and Seeger, 2001
# Cs is equivalent to X_mm and C to X_mn

function sampleindex(X::AbstractMatrix, r::Real; obsdim::Integer=defaultobs)
function sampleindex(X::AbstractVector, r::Real)
0 < r <= 1 || throw(ArgumentError("Sample rate `r` must be in range (0,1]"))
n = size(X, obsdim)
n = length(X)
m = ceil(Int, n * r)
S = StatsBase.sample(1:n, m; replace=false, ordered=true)
return S
end

function sampleindex(X::AbstractMatrix, r::Real; obsdim::Integer=defaultobs)
return sampleindex(vec_of_vecs(X; obsdim), r)
st-- marked this conversation as resolved.
Show resolved Hide resolved
end

function nystrom_sample(k::Kernel, X::AbstractVector, S::Vector{<:Integer})
st-- marked this conversation as resolved.
Show resolved Hide resolved
Xₘ = X[S]
st-- marked this conversation as resolved.
Show resolved Hide resolved
C = kernelmatrix(k, Xₘ, X)
Cs = C[:, S]
return (C, Cs)
end

function nystrom_sample(
k::Kernel, X::AbstractMatrix, S::Vector{<:Integer}; obsdim::Integer=defaultobs
)
obsdim ∈ [1, 2] ||
throw(ArgumentError("`obsdim` should be 1 or 2 (see docs of kernelmatrix))"))
Xₘ = obsdim == 1 ? X[S, :] : X[:, S]
C = kernelmatrix(k, Xₘ, X; obsdim=obsdim)
Cs = C[:, S]
return (C, Cs)
return nystrom_sample(k, vec_of_vecs(X; obsdim), S)
st-- marked this conversation as resolved.
Show resolved Hide resolved
end

function nystrom_pinv!(Cs::Matrix{T}, tol::T=eps(T) * size(Cs, 1)) where {T<:Real}
Expand Down Expand Up @@ -63,38 +69,46 @@ function NystromFact(W::Matrix{<:Real}, C::Matrix{<:Real})
end
theogf marked this conversation as resolved.
Show resolved Hide resolved

@doc raw"""
nystrom(k::Kernel, X::Matrix, S::Vector; obsdim::Int=defaultobs)
nystrom(k::Kernel, X::Vector, S::Vector)
st-- marked this conversation as resolved.
Show resolved Hide resolved

Computes a factorization of Nystrom approximation of the square kernel matrix of data
matrix `X` with respect to kernel `k`. Returns a `NystromFact` struct which stores a
Nystrom factorization satisfying:
Computes a factorization of Nystrom approximation of the square kernel matrix
st-- marked this conversation as resolved.
Show resolved Hide resolved
of data vector `X` with respect to kernel `k`, using indices `S`.
Returns a `NystromFact` struct which stores a Nystrom factorization satisfying:
```math
\mathbf{K} \approx \mathbf{C}^{\intercal}\mathbf{W}\mathbf{C}
```
"""
function nystrom(k::Kernel, X::AbstractMatrix, S::Vector{<:Integer}; obsdim::Int=defaultobs)
C, Cs = nystrom_sample(k, X, S; obsdim=obsdim)
function nystrom(k::Kernel, X::AbstractVector, S::Vector{<:Integer})
st-- marked this conversation as resolved.
Show resolved Hide resolved
C, Cs = nystrom_sample(k, X, S)
W = nystrom_pinv!(Cs)
return NystromFact(W, C)
end

@doc raw"""
nystrom(k::Kernel, X::Matrix, r::Real; obsdim::Int=defaultobs)
nystrom(k::Kernel, X::Vector, r::Real)
st-- marked this conversation as resolved.
Show resolved Hide resolved

Computes a factorization of Nystrom approximation of the square kernel matrix of data
matrix `X` with respect to kernel `k` using a sample ratio of `r`.
Computes a factorization of Nystrom approximation of the square kernel matrix
st-- marked this conversation as resolved.
Show resolved Hide resolved
of data vector `X` with respect to kernel `k` using a sample ratio of `r`.
Returns a `NystromFact` struct which stores a Nystrom factorization satisfying:
```math
\mathbf{K} \approx \mathbf{C}^{\intercal}\mathbf{W}\mathbf{C}
```
"""
function nystrom(k::Kernel, X::AbstractVector, r::Real)
S = sampleindex(X, r)
return nystrom(k, X, S)
end

function nystrom(k::Kernel, X::AbstractMatrix, S::Vector{<:Integer}; obsdim::Int=defaultobs)
st-- marked this conversation as resolved.
Show resolved Hide resolved
return nystrom(k, vec_of_vecs(X; obsdim), S)
st-- marked this conversation as resolved.
Show resolved Hide resolved
end

function nystrom(k::Kernel, X::AbstractMatrix, r::Real; obsdim::Int=defaultobs)
S = sampleindex(X, r; obsdim=obsdim)
return nystrom(k, X, S; obsdim=obsdim)
return nystrom(k, vec_of_vecs(X; obsdim), r)
st-- marked this conversation as resolved.
Show resolved Hide resolved
end

"""
nystrom(CᵀWC::NystromFact)
kernelmatrix(CᵀWC::NystromFact)

Compute the approximate kernel matrix based on the Nystrom factorization.
"""
Expand Down
8 changes: 7 additions & 1 deletion test/approximations/nystrom.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,15 @@
dims = [10, 5]
X = rand(dims...)
k = SqExponentialKernel()
for obsdim in [1, 2]
Xv = vec_of_vecs(X; obsdim)
st-- marked this conversation as resolved.
Show resolved Hide resolved
@assert Xv isa Union{ColVecs,RowVecs}
@test kernelmatrix(k, Xv) ≈ kernelmatrix(nystrom(k, Xv, 1.0))
@test kernelmatrix(k, Xv) ≈ kernelmatrix(nystrom(k, Xv, collect(1:dims[obsdim])))
end
for obsdim in [1, 2]
@test kernelmatrix(k, X; obsdim=obsdim) ≈
kernelmatrix(nystrom(k, X, 1.0; obsdim=obsdim))
kernelmatrix(nystrom(k, X, 1.0; obsdim=obsdim))
st-- marked this conversation as resolved.
Show resolved Hide resolved
@test kernelmatrix(k, X; obsdim=obsdim) ≈
kernelmatrix(nystrom(k, X, collect(1:dims[obsdim]); obsdim=obsdim))
st-- marked this conversation as resolved.
Show resolved Hide resolved
end
Expand Down