Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix MaternKernel AD, but remove differentiation wrt \nu #425

Merged
merged 21 commits into from
Apr 13, 2022
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions src/basekernels/matern.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ By default, ``d`` is the Euclidean metric ``d(x, x') = \\|x - x'\\|_2``.
A Gaussian process with a Matérn kernel is ``\\lceil \\nu \\rceil - 1``-times
differentiable in the mean-square sense.

!!! note

Differentiation with respect to the order ν is not currently supported.

See also: [`Matern12Kernel`](@ref), [`Matern32Kernel`](@ref), [`Matern52Kernel`](@ref)
"""
struct MaternKernel{Tν<:Real,M} <: SimpleKernel
Expand All @@ -33,8 +37,19 @@ MaternKernel(; nu::Real=1.5, ν::Real=nu, metric=Euclidean()) = MaternKernel(ν,

@functor MaternKernel

@inline function kappa(κ::MaternKernel, d::Real)
result = _matern(only(κ.ν), d)
# Work-around for Zygote -- `NotImplemented` doesn't appear to play nicely with whatever
# rule currently exists for `only`.
_get_ν(k::MaternKernel) = only(k.ν)
function ChainRulesCore.rrule(::typeof(_get_ν), k::T) where {T<:MaternKernel}
function _get_ν_pullback(Δ)
dν = ChainRulesCore.@not_implemented("Derivatives w.r.t. ν are not implemented.")
return Tangent{T}(ν=dν, metric=NoTangent())
st-- marked this conversation as resolved.
Show resolved Hide resolved
end
return _get_ν(k), _get_ν_pullback
end

@inline function kappa(k::MaternKernel, d::Real)
result = _matern(_get_ν(k), d)
return ifelse(iszero(d), one(result), result)
end

Expand Down
2 changes: 1 addition & 1 deletion src/matrix/kernelkroneckermat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ where `D` is given by `dims`.

!!! warning

Require `Kronecker.jl` and for `iskroncompatible(κ)` to return `true`.
Requires `Kronecker.jl` and for `iskroncompatible(κ)` to return `true`.
st-- marked this conversation as resolved.
Show resolved Hide resolved
"""
function kernelkronmat(κ::Kernel, X::AbstractVector{<:Real}, dims::Int)
checkkroncompatible(κ)
Expand Down
7 changes: 3 additions & 4 deletions test/basekernels/matern.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,24 @@
v1 = rand(rng, 3)
v2 = rand(rng, 3)
@testset "MaternKernel" begin
ν = 2.0
ν = 2.1
st-- marked this conversation as resolved.
Show resolved Hide resolved
k = MaternKernel(; ν=ν)
matern(x, ν) = 2^(1 - ν) / gamma(ν) * (sqrt(2ν) * x)^ν * besselk(ν, sqrt(2ν) * x)
@test MaternKernel(; nu=ν).ν == [ν]
@test kappa(k, x) ≈ matern(x, ν)
@test kappa(k, 0.0) == 1.0
@test kappa(MaternKernel(; ν=ν), x) == kappa(k, x)
@test metric(MaternKernel()) == Euclidean()
@test metric(MaternKernel(; ν=2.0)) == Euclidean()
@test repr(k) == "Matern Kernel (ν = $(ν), metric = Euclidean(0.0))"
# test_ADs(x->MaternKernel(nu=first(x)),[ν])
@test_broken "All fails (because of logabsgamma for ForwardDiff and ReverseDiff and because of nu for Zygote)"

k2 = MaternKernel(; ν=ν, metric=WeightedEuclidean(ones(3)))
@test metric(k2) isa WeightedEuclidean
@test k2(v1, v2) ≈ k(v1, v2)

# Standardised tests.
TestUtils.test_interface(k, Float64)
test_ADs(() -> MaternKernel(; nu=ν))

test_params(k, ([ν],))
end
@testset "Matern32Kernel" begin
Expand Down