diff --git a/Project.toml b/Project.toml index 6f622134d..805cde43e 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,7 @@ uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392" version = "0.6.0" [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" @@ -14,6 +15,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] +ChainRulesCore = "0.9" Compat = "2.2, 3" Distances = "0.9" Requires = "1.0.1" diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl index d1906428b..d74e441b9 100644 --- a/src/KernelFunctions.jl +++ b/src/KernelFunctions.jl @@ -39,10 +39,12 @@ using Requires using Distances, LinearAlgebra using SpecialFunctions: loggamma, besselk, polygamma using ZygoteRules: @adjoint, pullback +using ChainRulesCore using StatsFuns: logtwo using InteractiveUtils: subtypes using StatsBase + """ Abstract type defining a slice-wise transformation on an input matrix """ @@ -74,7 +76,7 @@ include("generic.jl") include("mokernels/moinput.jl") include("mokernels/independent.jl") -include("zygote_adjoints.jl") +include("chainrules.jl") function __init__() @require Kronecker="2c470bb0-bcc8-11e8-3dad-c9649493f05e" include("matrix/kernelkroneckermat.jl") diff --git a/src/chainrules.jl b/src/chainrules.jl new file mode 100644 index 000000000..166996918 --- /dev/null +++ b/src/chainrules.jl @@ -0,0 +1,91 @@ +## rules for Delta +function ChainRulesCore.rrule(::typeof(evaluate), s::Delta, x::AbstractVector, y::AbstractVector) + function back(Δ) + return (NO_FIELDS, DoesNotExist(), DoesNotExist(), DoesNotExist()) + end + evaluate(s, x, y), back +end + +function ChainRulesCore.rrule(::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix; dims=2) + D = Distances.pairwise(d, X, Y; dims = dims) + function back(Δ) + return (NO_FIELDS, DoesNotExist(), DoesNotExist(), DoesNotExist()) + end + return D, back +end + +function ChainRulesCore.rrule(::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix; dims=2) + D = Distances.pairwise(d, X; dims = dims) + back(Δ) = (NO_FIELDS, DoesNotExist(), DoesNotExist()) + return D, back +end + +## rules for DotProduct +function ChainRulesCore.rrule(::typeof(evaluate), s::DotProduct, x::AbstractVector, y::AbstractVector) + back(Δ) = (NO_FIELDS, nothing, @thunk(Δ .* y), @thunk(Δ .* x)) + return dot(x, y), back +end + +function ChainRulesCore.rrule(::typeof(Distances.pairwise), d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix; dims=2) + D = Distances.pairwise(d, X, Y; dims = dims) + + function back(Δ) + if dims == 1 + return (NO_FIELDS, nothing, @thunk(Δ * Y), @thunk((X' * Δ)')) + else + return (NO_FIELDS, nothing, @thunk((Δ * Y')'), @thunk(X * Δ)) + end + end + return D, back +end + +function ChainRulesCore.rrule(::typeof(Distances.pairwise), d::DotProduct, X::AbstractMatrix; dims=2) + D = Distances.pairwise(d, X; dims = dims) + + function back(Δ) + if dims == 1 + return (NO_FIELDS, nothing, @thunk(2 * Δ * X)) + else + return (NO_FIELDS, nothing, @thunk(2 * X * Δ)) + end + end + return D, back +end + +## rules for Sinus +function ChainRulesCore.rrule(::typeof(evaluate), s::Sinus, x::AbstractVector, y::AbstractVector) + d = @thunk((x - y)) + sind = @thunk(sinpi.(d)) + val = @thunk(sum(abs2, sind ./ s.r)) + gradx = @thunk(2π .* cospi.(d) .* sind ./ (s.r .^ 2)) + function back(Δ) + return (NO_FIELDS, (r = @thunk(-2Δ .* abs2.(sind) ./ s.r),), @thunk(Δ * gradx), @thunk(- Δ * gradx)) + end + val, back +end + + +# rules for ColVecs and RowVecs +vecs_pullback(Δ::NamedTuple) = (NO_FIELDS, Δ.X,) +vecs_pullback(Δ::AbstractMatrix) = (NO_FIELDS, Δ,) +function vecs_pullback(Δ::AbstractVector{<:AbstractVector{<:Real}}) + throw(error("In slow method")) +end + +function ChainRulesCore.rrule(::Type{ColVecs}, X::AbstractMatrix) + return ColVecs(X), vecs_pullback +end + +function ChainRulesCore.rrule(::Type{RowVecs}, X::AbstractMatrix) + return RowVecs(X), vecs_pullback +end + + +# rules for transforms +@adjoint function Base.map(t::Transform, X::ColVecs) + return pullback(_map, t, X) +end + +@adjoint function Base.map(t::Transform, X::RowVecs) + return pullback(_map, t, X) +end diff --git a/src/zygote_adjoints.jl b/src/zygote_adjoints.jl deleted file mode 100644 index f51466fb6..000000000 --- a/src/zygote_adjoints.jl +++ /dev/null @@ -1,86 +0,0 @@ -## Adjoints Delta -@adjoint function evaluate(s::Delta, x::AbstractVector, y::AbstractVector) - evaluate(s, x, y), Δ -> begin - (nothing, nothing, nothing) - end -end - -@adjoint function Distances.pairwise(d::Delta, X::AbstractMatrix, Y::AbstractMatrix; dims=2) - D = Distances.pairwise(d, X, Y; dims = dims) - if dims == 1 - return D, Δ -> (nothing, nothing, nothing) - else - return D, Δ -> (nothing, nothing, nothing) - end -end - -@adjoint function Distances.pairwise(d::Delta, X::AbstractMatrix; dims=2) - D = Distances.pairwise(d, X; dims = dims) - if dims == 1 - return D, Δ -> (nothing, nothing) - else - return D, Δ -> (nothing, nothing) - end -end - -## Adjoints DotProduct -@adjoint function evaluate(s::DotProduct, x::AbstractVector, y::AbstractVector) - dot(x, y), Δ -> begin - (nothing, Δ .* y, Δ .* x) - end -end - -@adjoint function Distances.pairwise(d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix; dims=2) - D = Distances.pairwise(d, X, Y; dims = dims) - if dims == 1 - return D, Δ -> (nothing, Δ * Y, (X' * Δ)') - else - return D, Δ -> (nothing, (Δ * Y')', X * Δ) - end -end - -@adjoint function Distances.pairwise(d::DotProduct, X::AbstractMatrix; dims=2) - D = Distances.pairwise(d, X; dims = dims) - if dims == 1 - return D, Δ -> (nothing, 2 * Δ * X) - else - return D, Δ -> (nothing, 2 * X * Δ) - end -end - -## Adjoints Sinus -@adjoint function evaluate(s::Sinus, x::AbstractVector, y::AbstractVector) - d = (x - y) - sind = sinpi.(d) - val = sum(abs2, sind ./ s.r) - gradx = 2π .* cospi.(d) .* sind ./ (s.r .^ 2) - val, Δ -> begin - ((r = -2Δ .* abs2.(sind) ./ s.r,), Δ * gradx, - Δ * gradx) - end -end - -@adjoint function ColVecs(X::AbstractMatrix) - back(Δ::NamedTuple) = (Δ.X,) - back(Δ::AbstractMatrix) = (Δ,) - function back(Δ::AbstractVector{<:AbstractVector{<:Real}}) - throw(error("In slow method")) - end - return ColVecs(X), back -end - -@adjoint function RowVecs(X::AbstractMatrix) - back(Δ::NamedTuple) = (Δ.X,) - back(Δ::AbstractMatrix) = (Δ,) - function back(Δ::AbstractVector{<:AbstractVector{<:Real}}) - throw(error("In slow method")) - end - return RowVecs(X), back -end - -@adjoint function Base.map(t::Transform, X::ColVecs) - pullback(_map, t, X) -end - -@adjoint function Base.map(t::Transform, X::RowVecs) - pullback(_map, t, X) -end diff --git a/test/Project.toml b/test/Project.toml index fdff36ab0..e3ed24758 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,6 @@ [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" @@ -13,6 +15,8 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] +ChainRulesCore = "0.9" +ChainRulesTestUtils = "0.5" Distances = "0.9" FiniteDifferences = "0.10.8" Flux = "0.10, 0.11" diff --git a/test/runtests.jl b/test/runtests.jl index b33115fc0..04c88d6b4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,6 +7,7 @@ using Random using SpecialFunctions using Test using Flux +using ChainRulesTestUtils import Zygote, ForwardDiff, ReverseDiff, FiniteDifferences using KernelFunctions: SimpleKernel, metric, kappa, ColVecs, RowVecs