[WIP] Fix some AD issues with various kernels (#154)
- Defines kernelmatrix function for NeuralNetworkKernel.
- Defines Zygote adjoints for Mahalanobis distance metric.
- Zygote tests pass for Exponential, FBM, NN and Gabor kernels.
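As a rough illustration of what these changes target (not part of the diff; the data and the check below are made up, only `NeuralNetworkKernel`, `ColVecs`, and `kernelmatrix` come from the package), taking a Zygote gradient through a kernel matrix would look something like:

```julia
using KernelFunctions, Zygote

# Hypothetical smoke test: differentiate the sum of a NeuralNetworkKernel
# kernel matrix with respect to the data matrix (columns are observations).
M = rand(3, 5)
g, = Zygote.gradient(m -> sum(kernelmatrix(NeuralNetworkKernel(), ColVecs(m))), M)
size(g) == size(M)   # the pullback should return a matrix-shaped gradient
```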

* Zygote passes for Exponential and FBM kernel

* Zygote passes NN kernel

* Zygote passes Gabor kernel

* Address code review

* Fix mutating arrays problem for maha kernel

* Add adjoint for maha distance metric

* Fix zygote adjoint

* Fix adjoint typo

* Fix buggy version of pairwise adjoint

* Fix typo

* Forgot to add adjoint macro

* Add pairwise sqmahalanobis adjoint and test of sqmahalanobis

* Maha kernel tests

* Fix zygote adjoint for mahalanobis

* Fix docs for matern

* Make maha tests more readable

* Address style issues

* Fix bugs in tests and adjoints

* Fix maha tests

* Remove pairwise maha adjoints for now.

* Fix style issues

* Update maha.jl

* Fix style in zygote_adjoints.jl
sharanry authored Sep 8, 2020
1 parent 8c99314 commit 5c24f1c
Showing 10 changed files with 105 additions and 33 deletions.
2 changes: 1 addition & 1 deletion src/basekernels/maha.jl
@@ -3,7 +3,7 @@
Mahalanobis distance-based kernel given by
```math
κ(x,y) = exp(-r^2), r^2 = maha(x,P,y) = (x-y)'*inv(P)*(x-y)
κ(x,y) = exp(-r^2), r^2 = maha(x,P,y) = (x-y)'* P *(x-y)
```
where the matrix P is the metric.
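Not part of the diff, but for context, a minimal usage sketch of the kernel this docstring describes, with an arbitrary positive-definite `P` (constructor name as used in the tests below):

```julia
using KernelFunctions, LinearAlgebra

A = rand(3, 3)
P = A' * A + I                   # any symmetric positive-definite metric
k = MahalanobisKernel(P=P)
x, y = rand(3), rand(3)
k(x, y)                          # should equal exp(-(x - y)' * P * (x - y))
```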
4 changes: 3 additions & 1 deletion src/basekernels/matern.jl
@@ -5,7 +5,9 @@ The matern kernel is a Mercer kernel given by the formula:
```
κ(x,y) = 2^{1-ν}/Γ(ν)*(√(2ν)‖x-y‖)^ν K_ν(√(2ν)‖x-y‖)
```
For `ν=n+1/2, n=0,1,2,...` it can be simplified and you should instead use [`ExponentialKernel`](@ref) for `n=0`, [`Matern32Kernel`](@ref), for `n=1`, [`Matern52Kernel`](@ref) for `n=2` and [`SqExponentialKernel`](@ref) for `n=∞`.
For `ν=n+1/2, n=0,1,2,...` it can be simplified and you should instead use
[`ExponentialKernel`](@ref) for `n=0`, [`Matern32Kernel`](@ref), for `n=1`,
[`Matern52Kernel`](@ref) for `n=2` and [`SqExponentialKernel`](@ref) for `n=∞`.
"""
struct MaternKernel{Tν<:Real} <: SimpleKernel
ν::Vector{Tν}
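Again not in the diff, but a quick sketch of the special cases the docstring points to, assuming the `ν` keyword constructor from KernelFunctions.jl:

```julia
using KernelFunctions

x, y = rand(3), rand(3)
# For ν = 1/2 and ν = 3/2 the general Matérn kernel should reduce to the
# specialised kernels, up to floating-point error in the Bessel evaluation.
MaternKernel(ν=0.5)(x, y) ≈ ExponentialKernel()(x, y)
MaternKernel(ν=1.5)(x, y) ≈ Matern32Kernel()(x, y)
```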
28 changes: 28 additions & 0 deletions src/basekernels/nn.jl
@@ -23,4 +23,32 @@ function (κ::NeuralNetworkKernel)(x, y)
return asin(dot(x, y) / sqrt((1 + sum(abs2, x)) * (1 + sum(abs2, y))))
end

function kernelmatrix(::NeuralNetworkKernel, x::ColVecs, y::ColVecs)
validate_inputs(x, y)
X_2 = sum(x.X .* x.X; dims=1)
Y_2 = sum(y.X .* y.X; dims=1)
XY = x.X' * y.X
return asin.(XY ./ sqrt.((X_2 .+ 1)' * (Y_2 .+ 1)))
end

function kernelmatrix(::NeuralNetworkKernel, x::ColVecs)
X_2_1 = sum(x.X .* x.X; dims=1) .+ 1
XX = x.X' * x.X
return asin.(XX ./ sqrt.(X_2_1' * X_2_1))
end

function kernelmatrix(::NeuralNetworkKernel, x::RowVecs, y::RowVecs)
validate_inputs(x, y)
X_2 = sum(x.X .* x.X; dims=2)
Y_2 = sum(y.X .* y.X; dims=2)
XY = x.X * y.X'
return asin.(XY ./ sqrt.((X_2 .+ 1)' * (Y_2 .+ 1)))
end

function kernelmatrix(::NeuralNetworkKernel, x::RowVecs)
X_2_1 = sum(x.X .* x.X; dims=2) .+ 1
XX = x.X * x.X'
return asin.(XX ./ sqrt.(X_2_1' * X_2_1))
end

Base.show(io::IO, κ::NeuralNetworkKernel) = print(io, "Neural Network Kernel")
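As a side check (not part of the commit), the specialised `ColVecs` method above should agree with evaluating the kernel pairwise; a sketch of what that comparison could look like:

```julia
using KernelFunctions

X, Y = rand(3, 4), rand(3, 6)    # columns are observations
k = NeuralNetworkKernel()
K = kernelmatrix(k, ColVecs(X), ColVecs(Y))
# Element-wise evaluation via the kernel's call operator should match the
# vectorised asin/dot formula used in kernelmatrix.
K ≈ [k(x, y) for x in ColVecs(X), y in ColVecs(Y)]
```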
26 changes: 18 additions & 8 deletions src/zygote_adjoints.jl
@@ -60,21 +60,21 @@ end
end

@adjoint function ColVecs(X::AbstractMatrix)
back(Δ::NamedTuple) = (Δ.X,)
back(Δ::AbstractMatrix) = (Δ,)
function back(Δ::AbstractVector{<:AbstractVector{<:Real}})
ColVecs_pullback(Δ::NamedTuple) = (Δ.X,)
ColVecs_pullback(Δ::AbstractMatrix) = (Δ,)
function ColVecs_pullback(Δ::AbstractVector{<:AbstractVector{<:Real}})
throw(error("In slow method"))
end
return ColVecs(X), back
return ColVecs(X), ColVecs_pullback
end

@adjoint function RowVecs(X::AbstractMatrix)
back(Δ::NamedTuple) = (Δ.X,)
back(Δ::AbstractMatrix) = (Δ,)
function back(Δ::AbstractVector{<:AbstractVector{<:Real}})
RowVecs_pullback(Δ::NamedTuple) = (Δ.X,)
RowVecs_pullback(Δ::AbstractMatrix) = (Δ,)
function RowVecs_pullback(Δ::AbstractVector{<:AbstractVector{<:Real}})
throw(error("In slow method"))
end
return RowVecs(X), back
return RowVecs(X), RowVecs_pullback
end

@adjoint function Base.map(t::Transform, X::ColVecs)
@@ -84,3 +84,13 @@ end
@adjoint function Base.map(t::Transform, X::RowVecs)
pullback(_map, t, X)
end

@adjoint function (dist::Distances.SqMahalanobis)(a, b)
function SqMahalanobis_pullback(Δ::Real)
B_Bᵀ = dist.qmat + transpose(dist.qmat)
a_b = a - b
δa = (B_Bᵀ * a_b) * Δ
return (qmat = (a_b * a_b') * Δ,), δa, -δa
end
return evaluate(dist, a, b), SqMahalanobis_pullback
end
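For context (my own gloss, not text from the commit), the pullback above matches the standard gradients of the squared Mahalanobis form `d(a, b) = (a - b)' Q (a - b)`:

```math
\frac{\partial d}{\partial Q} = (a - b)(a - b)^\top, \qquad
\frac{\partial d}{\partial a} = (Q + Q^\top)(a - b), \qquad
\frac{\partial d}{\partial b} = -(Q + Q^\top)(a - b)
```

Scaling each by the incoming cotangent `Δ` gives exactly the `qmat`, `δa`, and `-δa` returned by `SqMahalanobis_pullback`.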
3 changes: 1 addition & 2 deletions test/basekernels/exponential.jl
@@ -38,8 +38,7 @@
@test metric(GammaExponentialKernel(γ=2.0)) == SqEuclidean()
@test repr(k) == "Gamma Exponential Kernel (γ = $(γ))"
@test KernelFunctions.iskroncompatible(k) == true
test_ADs(γ -> GammaExponentialKernel(gamma=first(γ)), [γ], ADs = [:ForwardDiff, :ReverseDiff])
@test_broken "Zygote gradient given γ"
test_ADs(γ -> GammaExponentialKernel(gamma=first(γ)), [γ])
test_params(k, ([γ],))
#Coherence :
@test GammaExponentialKernel(γ=1.0)(v1,v2) ≈ SqExponentialKernel()(v1,v2)
5 changes: 2 additions & 3 deletions test/basekernels/fbm.jl
@@ -22,8 +22,7 @@
@test kernelmatrix(k, x1*ones(1,1), x2*ones(1,1))[1] ≈ k(x1, x2) atol=1e-5

@test repr(k) == "Fractional Brownian Motion Kernel (h = $(h))"
test_ADs(FBMKernel, ADs = [:ReverseDiff])
@test_broken "Tests failing for kernelmatrix(k, x) for ForwardDiff and Zygote"

test_ADs(FBMKernel, ADs = [:ReverseDiff, :Zygote])
@test_broken "Tests failing for kernelmatrix(k, x) for ForwardDiff"
test_params(k, ([h],))
end
3 changes: 1 addition & 2 deletions test/basekernels/gabor.jl
@@ -17,7 +17,6 @@
@test k.ell ≈ 1.0 atol=1e-5
@test k.p ≈ 1.0 atol=1e-5
@test repr(k) == "Gabor Kernel (ell = 1.0, p = 1.0)"
#test_ADs(x -> GaborKernel(ell = x[1], p = x[2]), [ell, p])#, ADs = [:ForwardDiff, :ReverseDiff])
@test_broken "Tests failing for Zygote on differentiating through ell and p"
test_ADs(x -> GaborKernel(ell = x[1], p = x[2]), [ell, p], ADs = [:Zygote])
# Tests are also failing randomly for ForwardDiff and ReverseDiff but randomly
end
30 changes: 28 additions & 2 deletions test/basekernels/maha.jl
@@ -4,14 +4,40 @@
v1 = rand(rng, 3)
v2 = rand(rng, 3)

P = rand(rng, 3, 3)
U = UpperTriangular(rand(rng, 3,3))
P = Matrix(Cholesky(U, 'U', 0))
@assert isposdef(P)
k = MahalanobisKernel(P=P)

@test kappa(k, x) == exp(-x)
@test k(v1, v2) ≈ exp(-sqmahalanobis(v1, v2, P))
@test kappa(ExponentialKernel(), x) == kappa(k, x)
@test repr(k) == "Mahalanobis Kernel (size(P) = $(size(P)))"
# test_ADs(P -> MahalanobisKernel(P=P), P)

M1, M2 = rand(rng,3,2), rand(rng,3,2)
fdm = FiniteDifferences.Central(5, 1);


function FiniteDifferences.to_vec(dist::SqMahalanobis{Float64})
return vec(dist.qmat), x -> SqMahalanobis(reshape(x, size(dist.qmat)...))
end
a = rand()

function test_mahakernel(U::UpperTriangular, v1::AbstractVector, v2::AbstractVector)
return MahalanobisKernel(P=Array(U'*U))(v1, v2)
end

@test all(FiniteDifferences.j′vp(fdm, test_mahakernel, a, U, v1, v2)[1] .≈
UpperTriangular(Zygote.pullback(test_mahakernel, U, v1, v2)[2](a)[1]))

function test_sqmaha(U::UpperTriangular, v1::AbstractVector, v2::AbstractVector)
return SqMahalanobis(Array(U'*U))(v1, v2)
end

@test all(FiniteDifferences.j′vp(fdm, test_sqmaha, a, U, v1, v2)[1] .≈
UpperTriangular(Zygote.pullback(test_sqmaha, U, v1, v2)[2](a)[1]))

# test_ADs(U -> MahalanobisKernel(P=Array(U' * U)), U, ADs=[:Zygote])
@test_broken "Nothing passes (problem with Mahalanobis distance in Distances)"

test_params(k, (P,))
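One note on the test above (my reading, not stated in the commit): `FiniteDifferences.j′vp` has to flatten every argument into a vector of reals before perturbing it, which is why the test defines `to_vec` for `SqMahalanobis`. A small sketch of the round trip, assuming that overload is in scope:

```julia
using Distances, FiniteDifferences

A = rand(3, 3)
dist = SqMahalanobis(A' * A)                  # symmetric PSD metric
v, from_vec = FiniteDifferences.to_vec(dist)  # flatten qmat into a Vector
from_vec(v).qmat == dist.qmat                 # reconstructs the same metric
```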
7 changes: 3 additions & 4 deletions test/basekernels/nn.jl
@@ -38,11 +38,10 @@
@test kerneldiagmatrix(k, m1) ≈ A4 atol=1e-5

A5 = ones(4,4)
@test_throws AssertionError kernelmatrix!(A5, k, m1, m2, obsdim=3)
@test_throws AssertionError kernelmatrix!(A5, k, m1, obsdim=3)
@test_throws AssertionError kernelmatrix!(A5, k, m1, m2; obsdim=3)
@test_throws AssertionError kernelmatrix!(A5, k, m1; obsdim=3)
@test_throws DimensionMismatch kernelmatrix!(A5, k, ones(4,3), ones(3,4))

@test k([x1], [x2]) ≈ k(x1, x2) atol=1e-5
test_ADs(NeuralNetworkKernel, ADs = [:ForwardDiff, :ReverseDiff])
@test_broken "Zygote uncompatible with BaseKernel"
test_ADs(NeuralNetworkKernel)
end
30 changes: 20 additions & 10 deletions test/zygote_adjoints.jl
@@ -4,43 +4,53 @@
x = rand(rng, 5)
y = rand(rng, 5)
r = rand(rng, 5)
Q = Matrix(Cholesky(rand(rng, 5, 5), 'U', 0))
@assert isposdef(Q)

gzeucl = gradient(:Zygote, [x,y]) do xy

gzeucl = gradient(:Zygote, [x, y]) do xy
evaluate(Euclidean(), xy[1], xy[2])
end
gzsqeucl = gradient(:Zygote, [x,y]) do xy
gzsqeucl = gradient(:Zygote, [x, y]) do xy
evaluate(SqEuclidean(), xy[1], xy[2])
end
gzdotprod = gradient(:Zygote, [x,y]) do xy
gzdotprod = gradient(:Zygote, [x, y]) do xy
evaluate(KernelFunctions.DotProduct(), xy[1], xy[2])
end
gzdelta = gradient(:Zygote, [x,y]) do xy
gzdelta = gradient(:Zygote, [x, y]) do xy
evaluate(KernelFunctions.Delta(), xy[1], xy[2])
end
gzsinus = gradient(:Zygote, [x,y]) do xy
gzsinus = gradient(:Zygote, [x, y]) do xy
evaluate(KernelFunctions.Sinus(r), xy[1], xy[2])
end
gzsqmaha = gradient(:Zygote, [Q, x, y]) do xy
evaluate(SqMahalanobis(xy[1]), xy[2], xy[3])
end

gfeucl = gradient(:FiniteDiff, [x,y]) do xy
gfeucl = gradient(:FiniteDiff, [x, y]) do xy
evaluate(Euclidean(), xy[1], xy[2])
end
gfsqeucl = gradient(:FiniteDiff, [x,y]) do xy
gfsqeucl = gradient(:FiniteDiff, [x, y]) do xy
evaluate(SqEuclidean(), xy[1], xy[2])
end
gfdotprod = gradient(:FiniteDiff, [x,y]) do xy
gfdotprod = gradient(:FiniteDiff, [x, y]) do xy
evaluate(KernelFunctions.DotProduct(), xy[1], xy[2])
end
gfdelta = gradient(:FiniteDiff, [x,y]) do xy
gfdelta = gradient(:FiniteDiff, [x, y]) do xy
evaluate(KernelFunctions.Delta(), xy[1], xy[2])
end
gfsinus = gradient(:FiniteDiff, [x,y]) do xy
gfsinus = gradient(:FiniteDiff, [x, y]) do xy
evaluate(KernelFunctions.Sinus(r), xy[1], xy[2])
end
gfsqmaha = gradient(:FiniteDiff, [Q, x, y]) do xy
evaluate(SqMahalanobis(xy[1]), xy[2], xy[3])
end


@test all(gzeucl .≈ gfeucl)
@test all(gzsqeucl .≈ gfsqeucl)
@test all(gzdotprod .≈ gfdotprod)
@test all(gzdelta .≈ gfdelta)
@test all(gzsinus .≈ gfsinus)
@test all(gzsqmaha .≈ gfsqmaha)
end
