From 37c8788a0ce8f689824bf008b5bef620ef09d570 Mon Sep 17 00:00:00 2001 From: ST John Date: Thu, 13 Jan 2022 10:04:14 +0200 Subject: [PATCH 01/14] test_ADs of MaternKernel again --- test/basekernels/matern.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/basekernels/matern.jl b/test/basekernels/matern.jl index dedbd3847..3ec75b271 100644 --- a/test/basekernels/matern.jl +++ b/test/basekernels/matern.jl @@ -14,8 +14,6 @@ @test metric(MaternKernel()) == Euclidean() @test metric(MaternKernel(; ν=2.0)) == Euclidean() @test repr(k) == "Matern Kernel (ν = $(ν), metric = Euclidean(0.0))" - # test_ADs(x->MaternKernel(nu=first(x)),[ν]) - @test_broken "All fails (because of logabsgamma for ForwardDiff and ReverseDiff and because of nu for Zygote)" k2 = MaternKernel(; ν=ν, metric=WeightedEuclidean(ones(3))) @test metric(k2) isa WeightedEuclidean @@ -23,6 +21,7 @@ # Standardised tests. TestUtils.test_interface(k, Float64) + test_ADs(args -> MaternKernel(nu=only(args)), [ν]) test_params(k, ([ν],)) end @testset "Matern32Kernel" begin From f464abfa2bbd89c53d90f50349949dee834d4a12 Mon Sep 17 00:00:00 2001 From: st-- Date: Thu, 13 Jan 2022 09:28:57 +0100 Subject: [PATCH 02/14] Update test/basekernels/matern.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/basekernels/matern.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/basekernels/matern.jl b/test/basekernels/matern.jl index 3ec75b271..3df815311 100644 --- a/test/basekernels/matern.jl +++ b/test/basekernels/matern.jl @@ -21,7 +21,7 @@ # Standardised tests. TestUtils.test_interface(k, Float64) - test_ADs(args -> MaternKernel(nu=only(args)), [ν]) + test_ADs(args -> MaternKernel(; nu=only(args)), [ν]) test_params(k, ([ν],)) end @testset "Matern32Kernel" begin From cd81d8d7a5941ec8445b4e257b6573c79c28c27d Mon Sep 17 00:00:00 2001 From: ST John Date: Thu, 13 Jan 2022 18:33:20 +0200 Subject: [PATCH 03/14] remove test for differentiation through nu of MaternKernel --- test/basekernels/matern.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/basekernels/matern.jl b/test/basekernels/matern.jl index 3ec75b271..00a6f20a4 100644 --- a/test/basekernels/matern.jl +++ b/test/basekernels/matern.jl @@ -4,13 +4,12 @@ v1 = rand(rng, 3) v2 = rand(rng, 3) @testset "MaternKernel" begin - ν = 2.0 + ν = 2.1 k = MaternKernel(; ν=ν) matern(x, ν) = 2^(1 - ν) / gamma(ν) * (sqrt(2ν) * x)^ν * besselk(ν, sqrt(2ν) * x) @test MaternKernel(; nu=ν).ν == [ν] @test kappa(k, x) ≈ matern(x, ν) @test kappa(k, 0.0) == 1.0 - @test kappa(MaternKernel(; ν=ν), x) == kappa(k, x) @test metric(MaternKernel()) == Euclidean() @test metric(MaternKernel(; ν=2.0)) == Euclidean() @test repr(k) == "Matern Kernel (ν = $(ν), metric = Euclidean(0.0))" @@ -21,7 +20,7 @@ # Standardised tests. TestUtils.test_interface(k, Float64) - test_ADs(args -> MaternKernel(nu=only(args)), [ν]) + test_ADs(() -> MaternKernel(nu=ν)) test_params(k, ([ν],)) end @testset "Matern32Kernel" begin From 364a6cd3c1012c62ab44897928fbb2a114529a60 Mon Sep 17 00:00:00 2001 From: ST John Date: Thu, 13 Jan 2022 23:36:19 +0200 Subject: [PATCH 04/14] fix --- src/basekernels/matern.jl | 4 ++++ src/matrix/kernelkroneckermat.jl | 2 +- test/basekernels/matern.jl | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/basekernels/matern.jl b/src/basekernels/matern.jl index a1ae4dfcc..b54140394 100644 --- a/src/basekernels/matern.jl +++ b/src/basekernels/matern.jl @@ -17,6 +17,10 @@ By default, ``d`` is the Euclidean metric ``d(x, x') = \\|x - x'\\|_2``. A Gaussian process with a Matérn kernel is ``\\lceil \\nu \\rceil - 1``-times differentiable in the mean-square sense. +!!! note + + Differentiation with respect to the ν parameter is not currently supported. + See also: [`Matern12Kernel`](@ref), [`Matern32Kernel`](@ref), [`Matern52Kernel`](@ref) """ struct MaternKernel{Tν<:Real,M} <: SimpleKernel diff --git a/src/matrix/kernelkroneckermat.jl b/src/matrix/kernelkroneckermat.jl index d83463efa..113d0f53d 100644 --- a/src/matrix/kernelkroneckermat.jl +++ b/src/matrix/kernelkroneckermat.jl @@ -14,7 +14,7 @@ where `D` is given by `dims`. !!! warning - Require `Kronecker.jl` and for `iskroncompatible(κ)` to return `true`. + Requires `Kronecker.jl` and for `iskroncompatible(κ)` to return `true`. """ function kernelkronmat(κ::Kernel, X::AbstractVector{<:Real}, dims::Int) checkkroncompatible(κ) diff --git a/test/basekernels/matern.jl b/test/basekernels/matern.jl index 00a6f20a4..e9807a745 100644 --- a/test/basekernels/matern.jl +++ b/test/basekernels/matern.jl @@ -20,7 +20,7 @@ # Standardised tests. TestUtils.test_interface(k, Float64) - test_ADs(() -> MaternKernel(nu=ν)) + test_ADs(() -> MaternKernel(; nu=ν)) test_params(k, ([ν],)) end @testset "Matern32Kernel" begin From 603ddd0a5166de916abb7bf2521c2afd7aff67dc Mon Sep 17 00:00:00 2001 From: st-- Date: Fri, 14 Jan 2022 11:18:21 +0100 Subject: [PATCH 05/14] Update src/basekernels/matern.jl Co-authored-by: David Widmann --- src/basekernels/matern.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/basekernels/matern.jl b/src/basekernels/matern.jl index b54140394..47cd4b404 100644 --- a/src/basekernels/matern.jl +++ b/src/basekernels/matern.jl @@ -19,7 +19,7 @@ differentiable in the mean-square sense. !!! note - Differentiation with respect to the ν parameter is not currently supported. + Differentiation with respect to the order ν is not currently supported. See also: [`Matern12Kernel`](@ref), [`Matern32Kernel`](@ref), [`Matern52Kernel`](@ref) """ From 30385acdee781a48d5960fe190b87c2e6c43aad0 Mon Sep 17 00:00:00 2001 From: ST John Date: Wed, 13 Apr 2022 09:39:11 +0300 Subject: [PATCH 06/14] separate out Zygote test --- test/basekernels/matern.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/test/basekernels/matern.jl b/test/basekernels/matern.jl index af9f3774b..9a0267cc7 100644 --- a/test/basekernels/matern.jl +++ b/test/basekernels/matern.jl @@ -20,7 +20,13 @@ # Standardised tests. TestUtils.test_interface(k, Float64) - test_ADs(() -> MaternKernel(; nu=ν)) + test_ADs(() -> MaternKernel(; nu=ν); ADs=[:ForwardDiff, :ReverseDiff]) + try + test_ADs(() -> MaternKernel(; nu=ν); ADs=[:Zygote]) + catch + @test_broken "MaternKernel <-> Zygote AD test is broken" + end + test_params(k, ([ν],)) end @testset "Matern32Kernel" begin From 77bac8f409631614e5c72a28434a45ac9b42bc8e Mon Sep 17 00:00:00 2001 From: ST John Date: Wed, 13 Apr 2022 12:04:32 +0300 Subject: [PATCH 07/14] revert Zygote test restriction attempt & fix Zygote AD --- src/basekernels/matern.jl | 5 +++-- test/basekernels/matern.jl | 7 +------ 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/basekernels/matern.jl b/src/basekernels/matern.jl index 47cd4b404..5039868ab 100644 --- a/src/basekernels/matern.jl +++ b/src/basekernels/matern.jl @@ -37,8 +37,9 @@ MaternKernel(; nu::Real=1.5, ν::Real=nu, metric=Euclidean()) = MaternKernel(ν, @functor MaternKernel -@inline function kappa(κ::MaternKernel, d::Real) - result = _matern(only(κ.ν), d) +@inline function kappa(k::MaternKernel, d::Real) + nu = ChainRulesCore.@ignore_derivatives only(k.ν) # work-around for Zygote AD + result = _matern(nu, d) return ifelse(iszero(d), one(result), result) end diff --git a/test/basekernels/matern.jl b/test/basekernels/matern.jl index 9a0267cc7..025cb141b 100644 --- a/test/basekernels/matern.jl +++ b/test/basekernels/matern.jl @@ -20,12 +20,7 @@ # Standardised tests. TestUtils.test_interface(k, Float64) - test_ADs(() -> MaternKernel(; nu=ν); ADs=[:ForwardDiff, :ReverseDiff]) - try - test_ADs(() -> MaternKernel(; nu=ν); ADs=[:Zygote]) - catch - @test_broken "MaternKernel <-> Zygote AD test is broken" - end + test_ADs(() -> MaternKernel(; nu=ν)) test_params(k, ([ν],)) end From a08d44eac21afd32f226a926881e50fec10f4eb6 Mon Sep 17 00:00:00 2001 From: st-- Date: Wed, 13 Apr 2022 14:48:53 +0300 Subject: [PATCH 08/14] Update src/basekernels/matern.jl Co-authored-by: willtebbutt --- src/basekernels/matern.jl | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/basekernels/matern.jl b/src/basekernels/matern.jl index 5039868ab..1349668de 100644 --- a/src/basekernels/matern.jl +++ b/src/basekernels/matern.jl @@ -37,8 +37,19 @@ MaternKernel(; nu::Real=1.5, ν::Real=nu, metric=Euclidean()) = MaternKernel(ν, @functor MaternKernel +# Work-around for Zygote -- `NotImplemented` doesn't appear to play nicely with whatever +# rule currently exists for `only`. +_get_ν(k::MaternKernel) = only(k.ν) +function ChainRulesCore.rrule(::typeof(_get_ν), k::T) where {T<:MaternKernel} + function _get_ν_pullback(Δ) + dν = ChainRulesCore.@not_implemented("Derivatives w.r.t. ν are not implemented.") + return Tangent{T}(ν=dν, metric=NoTangent()) + end + return _get_ν(k), _get_ν_pullback +end + @inline function kappa(k::MaternKernel, d::Real) - nu = ChainRulesCore.@ignore_derivatives only(k.ν) # work-around for Zygote AD + nu = _get_ν(k) result = _matern(nu, d) return ifelse(iszero(d), one(result), result) end From 27f9ba445225a10140e23194e7e812ee448c7b42 Mon Sep 17 00:00:00 2001 From: st-- Date: Wed, 13 Apr 2022 14:49:18 +0300 Subject: [PATCH 09/14] Update src/basekernels/matern.jl --- src/basekernels/matern.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/basekernels/matern.jl b/src/basekernels/matern.jl index 1349668de..93cc8521e 100644 --- a/src/basekernels/matern.jl +++ b/src/basekernels/matern.jl @@ -49,8 +49,7 @@ function ChainRulesCore.rrule(::typeof(_get_ν), k::T) where {T<:MaternKernel} end @inline function kappa(k::MaternKernel, d::Real) - nu = _get_ν(k) - result = _matern(nu, d) + result = _matern(_get_ν(k), d) return ifelse(iszero(d), one(result), result) end From ae7e646d1648aa6e103817cec65cf53c83a1935a Mon Sep 17 00:00:00 2001 From: st-- Date: Wed, 13 Apr 2022 15:01:16 +0300 Subject: [PATCH 10/14] Update src/basekernels/matern.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/basekernels/matern.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/basekernels/matern.jl b/src/basekernels/matern.jl index 93cc8521e..e019d5983 100644 --- a/src/basekernels/matern.jl +++ b/src/basekernels/matern.jl @@ -43,7 +43,7 @@ _get_ν(k::MaternKernel) = only(k.ν) function ChainRulesCore.rrule(::typeof(_get_ν), k::T) where {T<:MaternKernel} function _get_ν_pullback(Δ) dν = ChainRulesCore.@not_implemented("Derivatives w.r.t. ν are not implemented.") - return Tangent{T}(ν=dν, metric=NoTangent()) + return Tangent{T}(; ν=dν, metric=NoTangent()) end return _get_ν(k), _get_ν_pullback end From bc0bcf5dc7e24f8f635dd36f21afa15e3b8260df Mon Sep 17 00:00:00 2001 From: ST John Date: Wed, 13 Apr 2022 15:05:02 +0300 Subject: [PATCH 11/14] revert to simpler workaround --- src/basekernels/matern.jl | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/src/basekernels/matern.jl b/src/basekernels/matern.jl index 93cc8521e..c93eeff74 100644 --- a/src/basekernels/matern.jl +++ b/src/basekernels/matern.jl @@ -37,16 +37,7 @@ MaternKernel(; nu::Real=1.5, ν::Real=nu, metric=Euclidean()) = MaternKernel(ν, @functor MaternKernel -# Work-around for Zygote -- `NotImplemented` doesn't appear to play nicely with whatever -# rule currently exists for `only`. -_get_ν(k::MaternKernel) = only(k.ν) -function ChainRulesCore.rrule(::typeof(_get_ν), k::T) where {T<:MaternKernel} - function _get_ν_pullback(Δ) - dν = ChainRulesCore.@not_implemented("Derivatives w.r.t. ν are not implemented.") - return Tangent{T}(ν=dν, metric=NoTangent()) - end - return _get_ν(k), _get_ν_pullback -end +@inline _get_ν(k::MaternKernel) = ChainRulesCore.@ignore_derivatives only(k.ν) # work-around for Zygote AD @inline function kappa(k::MaternKernel, d::Real) result = _matern(_get_ν(k), d) From 4fc2b570c4e3aef0d650e4742fe0ec964873379c Mon Sep 17 00:00:00 2001 From: ST John Date: Wed, 13 Apr 2022 15:10:39 +0300 Subject: [PATCH 12/14] revert accidental commit --- src/zygoterules.jl | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/zygoterules.jl b/src/zygoterules.jl index 679757056..e405a4946 100644 --- a/src/zygoterules.jl +++ b/src/zygoterules.jl @@ -1,13 +1,13 @@ -#ZygoteRules.@adjoint function Base.map(t::Transform, X::ColVecs) -# return ZygoteRules.pullback(_map, t, X) -#end -# -#ZygoteRules.@adjoint function Base.map(t::Transform, X::RowVecs) -# return ZygoteRules.pullback(_map, t, X) -#end -# -#function ZygoteRules._pullback( -# cx::AContext, ::typeof(literal_getproperty), x::ColVecs, ::Val{f} -#) where {f} -# return ZygoteRules._pullback(cx, literal_getfield, x, Val{f}()) -#end +ZygoteRules.@adjoint function Base.map(t::Transform, X::ColVecs) + return ZygoteRules.pullback(_map, t, X) +end + +ZygoteRules.@adjoint function Base.map(t::Transform, X::RowVecs) + return ZygoteRules.pullback(_map, t, X) +end + +function ZygoteRules._pullback( + cx::AContext, ::typeof(literal_getproperty), x::ColVecs, ::Val{f} +) where {f} + return ZygoteRules._pullback(cx, literal_getfield, x, Val{f}()) +end From b7acfa29917dd686e93e060f147fe2a30683ef6e Mon Sep 17 00:00:00 2001 From: ST John Date: Wed, 13 Apr 2022 15:12:53 +0300 Subject: [PATCH 13/14] use non_differentiable instead --- src/basekernels/matern.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/basekernels/matern.jl b/src/basekernels/matern.jl index c93eeff74..dcc35f6be 100644 --- a/src/basekernels/matern.jl +++ b/src/basekernels/matern.jl @@ -37,7 +37,8 @@ MaternKernel(; nu::Real=1.5, ν::Real=nu, metric=Euclidean()) = MaternKernel(ν, @functor MaternKernel -@inline _get_ν(k::MaternKernel) = ChainRulesCore.@ignore_derivatives only(k.ν) # work-around for Zygote AD +@inline _get_ν(k::MaternKernel) = only(k.ν) +ChainRulesCore.@non_differentiable _get_ν(k) # work-around; should be "NotImplemented" rather than NoTangent @inline function kappa(k::MaternKernel, d::Real) result = _matern(_get_ν(k), d) From 8aee1a477010ee47929f40b721fcba20f7438b46 Mon Sep 17 00:00:00 2001 From: ST John Date: Wed, 13 Apr 2022 15:17:06 +0300 Subject: [PATCH 14/14] patch bump --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index f3d633beb..47ee0d086 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "KernelFunctions" uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392" -version = "0.10.36" +version = "0.10.37" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"