Skip to content

Commit

Permalink
make nystrom work with AbstractVector (#427)
Browse files Browse the repository at this point in the history
* make nystrom work with AbstractVector

* add test

* Update test/approximations/nystrom.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* patch bump

* Update test/approximations/nystrom.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

* Apply suggestions from code review

* deprecate

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

* Apply suggestions from code review

Co-authored-by: Théo Galy-Fajou <[email protected]>

* Update src/approximations/nystrom.jl

Co-authored-by: Théo Galy-Fajou <[email protected]>

* Update src/approximations/nystrom.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: David Widmann <[email protected]>
Co-authored-by: Théo Galy-Fajou <[email protected]>
  • Loading branch information
4 people authored Jan 28, 2022
1 parent d1c68a9 commit f9bbd84
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 24 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "KernelFunctions"
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
version = "0.10.27"
version = "0.10.28"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
56 changes: 35 additions & 21 deletions src/approximations/nystrom.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,29 @@
# Following the algorithm by William and Seeger, 2001
# Cs is equivalent to X_mm and C to X_mn

function sampleindex(X::AbstractMatrix, r::Real; obsdim::Integer=defaultobs)
function sampleindex(X::AbstractVector, r::Real)
0 < r <= 1 || throw(ArgumentError("Sample rate `r` must be in range (0,1]"))
n = size(X, obsdim)
n = length(X)
m = ceil(Int, n * r)
S = StatsBase.sample(1:n, m; replace=false, ordered=true)
return S
end

function nystrom_sample(
k::Kernel, X::AbstractMatrix, S::Vector{<:Integer}; obsdim::Integer=defaultobs
)
obsdim ∈ [1, 2] ||
throw(ArgumentError("`obsdim` should be 1 or 2 (see docs of kernelmatrix))"))
Xₘ = obsdim == 1 ? X[S, :] : X[:, S]
C = kernelmatrix(k, Xₘ, X; obsdim=obsdim)
# Deprecated matrix method of `sampleindex`: it forwards to the `AbstractVector`
# method after converting `X` with `vec_of_vecs` along `obsdim`.
# NOTE(review): the trailing `false` is `@deprecate`'s export flag (the old
# signature is not exported) — confirm against the Julia `@deprecate` docs.
@deprecate sampleindex(X::AbstractMatrix, r::Real; obsdim::Integer=defaultobs) sampleindex(
vec_of_vecs(X; obsdim=obsdim), r
) false

function nystrom_sample(k::Kernel, X::AbstractVector, S::AbstractVector{<:Integer})
    # Kernel matrix between the sampled observations (indices `S`) and all of `X`;
    # the sampled observations are taken as a view.
    sampled = view(X, S)
    C = kernelmatrix(k, sampled, X)
    # The columns of `C` at `S` give the sampled-vs-sampled block.
    return (C, C[:, S])
end

# Deprecated matrix method of `nystrom_sample`: it forwards to the `AbstractVector`
# method after converting `X` with `vec_of_vecs` along `obsdim`.
# NOTE(review): the trailing `false` is `@deprecate`'s export flag (the old
# signature is not exported) — confirm against the Julia `@deprecate` docs.
@deprecate nystrom_sample(
k::Kernel, X::AbstractMatrix, S::Vector{<:Integer}; obsdim::Integer=defaultobs
) nystrom_sample(k, vec_of_vecs(X; obsdim=obsdim), S) false

function nystrom_pinv!(Cs::Matrix{T}, tol::T=eps(T) * size(Cs, 1)) where {T<:Real}
# Compute eigendecomposition of sampled component of K
QΛQᵀ = LinearAlgebra.eigen!(LinearAlgebra.Symmetric(Cs))
Expand Down Expand Up @@ -63,38 +67,48 @@ function NystromFact(W::Matrix{<:Real}, C::Matrix{<:Real})
end

@doc raw"""
nystrom(k::Kernel, X::Matrix, S::Vector; obsdim::Int=defaultobs)
nystrom(k::Kernel, X::AbstractVector, S::AbstractVector{<:Integer})
Computes a factorization of Nystrom approximation of the square kernel matrix of data
matrix `X` with respect to kernel `k`. Returns a `NystromFact` struct which stores a
Nystrom factorization satisfying:
Compute a factorization of a Nystrom approximation of the square kernel matrix
of data vector `X` with respect to kernel `k`, using indices `S`.
Returns a `NystromFact` struct which stores a Nystrom factorization satisfying:
```math
\mathbf{K} \approx \mathbf{C}^{\intercal}\mathbf{W}\mathbf{C}
```
"""
function nystrom(k::Kernel, X::AbstractMatrix, S::Vector{<:Integer}; obsdim::Int=defaultobs)
C, Cs = nystrom_sample(k, X, S; obsdim=obsdim)
function nystrom(k::Kernel, X::AbstractVector, S::AbstractVector{<:Integer})
C, Cs = nystrom_sample(k, X, S)
W = nystrom_pinv!(Cs)
return NystromFact(W, C)
end

@doc raw"""
nystrom(k::Kernel, X::Matrix, r::Real; obsdim::Int=defaultobs)
nystrom(k::Kernel, X::AbstractVector, r::Real)
Computes a factorization of Nystrom approximation of the square kernel matrix of data
matrix `X` with respect to kernel `k` using a sample ratio of `r`.
Compute a factorization of a Nystrom approximation of the square kernel matrix
of data vector `X` with respect to kernel `k` using a sample ratio of `r`.
Returns a `NystromFact` struct which stores a Nystrom factorization satisfying:
```math
\mathbf{K} \approx \mathbf{C}^{\intercal}\mathbf{W}\mathbf{C}
```
"""
function nystrom(k::Kernel, X::AbstractVector, r::Real)
    # Pick the sample indices for ratio `r` with `sampleindex`, then defer to the
    # index-based `nystrom` method.
    return nystrom(k, X, sampleindex(X, r))
end

function nystrom(
    k::Kernel, X::AbstractMatrix, S::AbstractVector{<:Integer}; obsdim::Int=defaultobs
)
    # Convert the matrix with `vec_of_vecs` along `obsdim`, then hand it to the
    # `AbstractVector` method together with the given indices `S`.
    observations = vec_of_vecs(X; obsdim=obsdim)
    return nystrom(k, observations, S)
end

function nystrom(k::Kernel, X::AbstractMatrix, r::Real; obsdim::Int=defaultobs)
S = sampleindex(X, r; obsdim=obsdim)
return nystrom(k, X, S; obsdim=obsdim)
return nystrom(k, vec_of_vecs(X; obsdim=obsdim), r)
end

"""
nystrom(CᵀWC::NystromFact)
kernelmatrix(CᵀWC::NystromFact)
Compute the approximate kernel matrix based on the Nystrom factorization.
"""
Expand Down
10 changes: 8 additions & 2 deletions test/approximations/nystrom.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,16 @@
dims = [10, 5]
X = rand(dims...)
k = SqExponentialKernel()
for obsdim in [1, 2]
Xv = vec_of_vecs(X; obsdim=obsdim)
@assert Xv isa Union{ColVecs,RowVecs}
@test kernelmatrix(k, Xv) ≈ kernelmatrix(nystrom(k, Xv, 1.0))
@test kernelmatrix(k, Xv) ≈ kernelmatrix(nystrom(k, Xv, collect(1:dims[obsdim])))
end
for obsdim in [1, 2]
@test kernelmatrix(k, X; obsdim=obsdim) ≈
kernelmatrix(nystrom(k, X, 1.0; obsdim=obsdim))
kernelmatrix(nystrom(k, X, 1.0; obsdim=obsdim))
@test kernelmatrix(k, X; obsdim=obsdim) ≈
kernelmatrix(nystrom(k, X, collect(1:dims[obsdim]); obsdim=obsdim))
kernelmatrix(nystrom(k, X, collect(1:dims[obsdim]); obsdim=obsdim))
end
end

4 comments on commit f9bbd84

@devmotion
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/53379

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.10.28 -m "<description of version>" f9bbd84beb487c8335adf55034cd78a1279681c1
git push origin v0.10.28

Also, note the warning: Version 0.10.28 skips over 0.10.27
This can be safely ignored. However, if you want to fix this you can do so. Call register() again after making the fix. This will update the Pull request.

@devmotion
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request updated: JuliaRegistries/General/53379

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.10.28 -m "<description of version>" f9bbd84beb487c8335adf55034cd78a1279681c1
git push origin v0.10.28

Please sign in to comment.