From f9d5ba2a2d63cdf4ff6372e79c5e5e5dfb002d4d Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Fri, 14 Aug 2020 02:21:06 +0530 Subject: [PATCH 1/4] Replace Zygote adjoints with ChainRules' rrules --- Project.toml | 2 + src/KernelFunctions.jl | 8 +++- src/chainrules.jl | 90 ++++++++++++++++++++++++++++++++++++++++++ src/zygote_adjoints.jl | 86 ---------------------------------------- test/Project.toml | 4 ++ test/runtests.jl | 1 + test/utils_AD.jl | 1 + 7 files changed, 104 insertions(+), 88 deletions(-) create mode 100644 src/chainrules.jl delete mode 100644 src/zygote_adjoints.jl diff --git a/Project.toml b/Project.toml index 6f622134d..805cde43e 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,7 @@ uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392" version = "0.6.0" [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" @@ -14,6 +15,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] +ChainRulesCore = "0.9" Compat = "2.2, 3" Distances = "0.9" Requires = "1.0.1" diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl index d1906428b..240b9310f 100644 --- a/src/KernelFunctions.jl +++ b/src/KernelFunctions.jl @@ -34,15 +34,19 @@ export spectral_mixture_kernel, spectral_mixture_product_kernel export MOInput export IndependentMOKernel +export rrule + using Compat using Requires using Distances, LinearAlgebra using SpecialFunctions: loggamma, besselk, polygamma -using ZygoteRules: @adjoint, pullback +# using ZygoteRules: @adjoint, pullback +using ChainRulesCore using StatsFuns: logtwo using InteractiveUtils: subtypes using StatsBase + """ Abstract type defining a slice-wise transformation on an input matrix """ @@ -74,7 +78,7 @@ include("generic.jl") include("mokernels/moinput.jl") include("mokernels/independent.jl") -include("zygote_adjoints.jl") +include("chainrules.jl") function __init__() @require Kronecker="2c470bb0-bcc8-11e8-3dad-c9649493f05e" include("matrix/kernelkroneckermat.jl") diff --git a/src/chainrules.jl b/src/chainrules.jl new file mode 100644 index 000000000..a02449d2b --- /dev/null +++ b/src/chainrules.jl @@ -0,0 +1,90 @@ +## rules for Delta +function ChainRulesCore.rrule(::typeof(evaluate), s::Delta, x::AbstractVector, y::AbstractVector) + evaluate(s, x, y), Δ -> begin + (NO_FIELDS, DoesNotExist(), DoesNotExist(), DoesNotExist()) + end +end + +function ChainRulesCore.rrule(::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix; dims=2) + D = Distances.pairwise(d, X, Y; dims = dims) + if dims == 1 + return D, Δ -> (NO_FIELDS, DoesNotExist(), DoesNotExist(), DoesNotExist()) + else + return D, Δ -> (NO_FIELDS, DoesNotExist(), DoesNotExist(), DoesNotExist()) + end +end + +function ChainRulesCore.rrule(::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix; dims=2) + D = Distances.pairwise(d, X; dims = dims) + if dims == 1 + return D, Δ -> (NO_FIELDS, DoesNotExist(), DoesNotExist()) + else + return D, Δ -> (NO_FIELDS, DoesNotExist(), DoesNotExist()) + end +end + +## rules for DotProduct +function ChainRulesCore.rrule(::typeof(evaluate), s::DotProduct, x::AbstractVector, y::AbstractVector) + dot(x, y), Δ -> begin + (NO_FIELDS, nothing, Δ .* y, Δ .* x) + end +end + +function ChainRulesCore.rrule(::typeof(Distances.pairwise), d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix; dims=2) + D = Distances.pairwise(d, X, Y; dims = dims) + if dims == 1 + return D, Δ -> (NO_FIELDS, nothing, Δ * Y, (X' * Δ)') + else + return D, Δ -> (NO_FIELDS, nothing, (Δ * Y')', X * Δ) + end +end + +function ChainRulesCore.rrule(::typeof(Distances.pairwise), d::DotProduct, X::AbstractMatrix; dims=2) + D = Distances.pairwise(d, X; dims = dims) + if dims == 1 + return D, Δ -> (NO_FIELDS, nothing, 2 * Δ * X) + else + return D, Δ -> (NO_FIELDS, nothing, 2 * X * Δ) + end +end + +## rules for Sinus +function ChainRulesCore.rrule(::typeof(evaluate), s::Sinus, x::AbstractVector, y::AbstractVector) + d = (x - y) + sind = sinpi.(d) + val = sum(abs2, sind ./ s.r) + gradx = 2π .* cospi.(d) .* sind ./ (s.r .^ 2) + val, Δ -> begin + (NO_FIELDS, (r = -2Δ .* abs2.(sind) ./ s.r,), Δ * gradx, - Δ * gradx) + end +end + + +# rules for ColVecs and RowVecs +function ChainRulesCore.rrule(::typeof(ColVecs), X::AbstractMatrix) + back(Δ::NamedTuple) = (NO_FIELDS, Δ.X,) + back(Δ::AbstractMatrix) = (NO_FIELDS, Δ,) + function back(Δ::AbstractVector{<:AbstractVector{<:Real}}) + throw(error("In slow method")) + end + return ColVecs(X), back +end + +function ChainRulesCore.rrule(::typeof(RowVecs), X::AbstractMatrix) + back(Δ::NamedTuple) = (NO_FIELDS, Δ.X,) + back(Δ::AbstractMatrix) = (NO_FIELDS, Δ,) + function back(Δ::AbstractVector{<:AbstractVector{<:Real}}) + throw(error("In slow method")) + end + return RowVecs(X), back +end + + +# rules for transforms +function ChainRulesCore.rrule(::typeof(Base.map), t::Transform, X::ColVecs) + ChainRulesCore.rrule(_map, t, X) +end + +function ChainRulesCore.rrule(::typeof(Base.map), t::Transform, X::RowVecs) + ChainRulesCore.rrule(_map, t, X) +end diff --git a/src/zygote_adjoints.jl b/src/zygote_adjoints.jl deleted file mode 100644 index f51466fb6..000000000 --- a/src/zygote_adjoints.jl +++ /dev/null @@ -1,86 +0,0 @@ -## Adjoints Delta -@adjoint function evaluate(s::Delta, x::AbstractVector, y::AbstractVector) - evaluate(s, x, y), Δ -> begin - (nothing, nothing, nothing) - end -end - -@adjoint function Distances.pairwise(d::Delta, X::AbstractMatrix, Y::AbstractMatrix; dims=2) - D = Distances.pairwise(d, X, Y; dims = dims) - if dims == 1 - return D, Δ -> (nothing, nothing, nothing) - else - return D, Δ -> (nothing, nothing, nothing) - end -end - -@adjoint function Distances.pairwise(d::Delta, X::AbstractMatrix; dims=2) - D = Distances.pairwise(d, X; dims = dims) - if dims == 1 - return D, Δ -> (nothing, nothing) - else - return D, Δ -> (nothing, nothing) - end -end - -## Adjoints DotProduct -@adjoint function evaluate(s::DotProduct, x::AbstractVector, y::AbstractVector) - dot(x, y), Δ -> begin - (nothing, Δ .* y, Δ .* x) - end -end - -@adjoint function Distances.pairwise(d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix; dims=2) - D = Distances.pairwise(d, X, Y; dims = dims) - if dims == 1 - return D, Δ -> (nothing, Δ * Y, (X' * Δ)') - else - return D, Δ -> (nothing, (Δ * Y')', X * Δ) - end -end - -@adjoint function Distances.pairwise(d::DotProduct, X::AbstractMatrix; dims=2) - D = Distances.pairwise(d, X; dims = dims) - if dims == 1 - return D, Δ -> (nothing, 2 * Δ * X) - else - return D, Δ -> (nothing, 2 * X * Δ) - end -end - -## Adjoints Sinus -@adjoint function evaluate(s::Sinus, x::AbstractVector, y::AbstractVector) - d = (x - y) - sind = sinpi.(d) - val = sum(abs2, sind ./ s.r) - gradx = 2π .* cospi.(d) .* sind ./ (s.r .^ 2) - val, Δ -> begin - ((r = -2Δ .* abs2.(sind) ./ s.r,), Δ * gradx, - Δ * gradx) - end -end - -@adjoint function ColVecs(X::AbstractMatrix) - back(Δ::NamedTuple) = (Δ.X,) - back(Δ::AbstractMatrix) = (Δ,) - function back(Δ::AbstractVector{<:AbstractVector{<:Real}}) - throw(error("In slow method")) - end - return ColVecs(X), back -end - -@adjoint function RowVecs(X::AbstractMatrix) - back(Δ::NamedTuple) = (Δ.X,) - back(Δ::AbstractMatrix) = (Δ,) - function back(Δ::AbstractVector{<:AbstractVector{<:Real}}) - throw(error("In slow method")) - end - return RowVecs(X), back -end - -@adjoint function Base.map(t::Transform, X::ColVecs) - pullback(_map, t, X) -end - -@adjoint function Base.map(t::Transform, X::RowVecs) - pullback(_map, t, X) -end diff --git a/test/Project.toml b/test/Project.toml index fdff36ab0..e3ed24758 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,6 @@ [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" @@ -13,6 +15,8 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] +ChainRulesCore = "0.9" +ChainRulesTestUtils = "0.5" Distances = "0.9" FiniteDifferences = "0.10.8" Flux = "0.10, 0.11" diff --git a/test/runtests.jl b/test/runtests.jl index b33115fc0..04c88d6b4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,6 +7,7 @@ using Random using SpecialFunctions using Test using Flux +using ChainRulesTestUtils import Zygote, ForwardDiff, ReverseDiff, FiniteDifferences using KernelFunctions: SimpleKernel, metric, kappa, ColVecs, RowVecs diff --git a/test/utils_AD.jl b/test/utils_AD.jl index 1354485f9..17f86c963 100644 --- a/test/utils_AD.jl +++ b/test/utils_AD.jl @@ -4,6 +4,7 @@ const FDM = FiniteDifferences.central_fdm(5, 1) gradient(f, s::Symbol, args) = gradient(f, Val(s), args) function gradient(f, ::Val{:Zygote}, args) + display(args) g = first(Zygote.gradient(f, args)) if isnothing(g) if args isa AbstractArray{<:Real} From 3567c079f73850a021c6cd2025fd95b858a040f5 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Fri, 14 Aug 2020 02:25:14 +0530 Subject: [PATCH 2/4] Small clean up --- src/KernelFunctions.jl | 3 --- test/utils_AD.jl | 1 - 2 files changed, 4 deletions(-) diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl index 240b9310f..f5c338dab 100644 --- a/src/KernelFunctions.jl +++ b/src/KernelFunctions.jl @@ -34,13 +34,10 @@ export spectral_mixture_kernel, spectral_mixture_product_kernel export MOInput export IndependentMOKernel -export rrule - using Compat using Requires using Distances, LinearAlgebra using SpecialFunctions: loggamma, besselk, polygamma -# using ZygoteRules: @adjoint, pullback using ChainRulesCore using StatsFuns: logtwo using InteractiveUtils: subtypes diff --git a/test/utils_AD.jl b/test/utils_AD.jl index 17f86c963..1354485f9 100644 --- a/test/utils_AD.jl +++ b/test/utils_AD.jl @@ -4,7 +4,6 @@ const FDM = FiniteDifferences.central_fdm(5, 1) gradient(f, s::Symbol, args) = gradient(f, Val(s), args) function gradient(f, ::Val{:Zygote}, args) - display(args) g = first(Zygote.gradient(f, args)) if isnothing(g) if args isa AbstractArray{<:Real} From bb642c07655e2d4ef5f8e66919a42e4a6809fd08 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Fri, 14 Aug 2020 15:01:50 +0530 Subject: [PATCH 3/4] Address code review --- src/KernelFunctions.jl | 1 + src/chainrules.jl | 106 ++++++++++++++++++++--------------------- 2 files changed, 54 insertions(+), 53 deletions(-) diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl index f5c338dab..d74e441b9 100644 --- a/src/KernelFunctions.jl +++ b/src/KernelFunctions.jl @@ -38,6 +38,7 @@ using Compat using Requires using Distances, LinearAlgebra using SpecialFunctions: loggamma, besselk, polygamma +using ZygoteRules: @adjoint, pullback using ChainRulesCore using StatsFuns: logtwo using InteractiveUtils: subtypes diff --git a/src/chainrules.jl b/src/chainrules.jl index a02449d2b..53329b193 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -1,90 +1,90 @@ ## rules for Delta function ChainRulesCore.rrule(::typeof(evaluate), s::Delta, x::AbstractVector, y::AbstractVector) - evaluate(s, x, y), Δ -> begin - (NO_FIELDS, DoesNotExist(), DoesNotExist(), DoesNotExist()) - end + function back(Δ) + return (NO_FIELDS, DoesNotExist(), DoesNotExist(), DoesNotExist()) + end + evaluate(s, x, y), back end function ChainRulesCore.rrule(::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix; dims=2) - D = Distances.pairwise(d, X, Y; dims = dims) - if dims == 1 - return D, Δ -> (NO_FIELDS, DoesNotExist(), DoesNotExist(), DoesNotExist()) - else - return D, Δ -> (NO_FIELDS, DoesNotExist(), DoesNotExist(), DoesNotExist()) - end + D = Distances.pairwise(d, X, Y; dims = dims) + function back(Δ) + return (NO_FIELDS, DoesNotExist(), DoesNotExist(), DoesNotExist()) + end + return D, back end function ChainRulesCore.rrule(::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix; dims=2) - D = Distances.pairwise(d, X; dims = dims) - if dims == 1 - return D, Δ -> (NO_FIELDS, DoesNotExist(), DoesNotExist()) - else - return D, Δ -> (NO_FIELDS, DoesNotExist(), DoesNotExist()) - end + D = Distances.pairwise(d, X; dims = dims) + back(Δ) = (NO_FIELDS, DoesNotExist(), DoesNotExist()) + return D, back end ## rules for DotProduct function ChainRulesCore.rrule(::typeof(evaluate), s::DotProduct, x::AbstractVector, y::AbstractVector) - dot(x, y), Δ -> begin - (NO_FIELDS, nothing, Δ .* y, Δ .* x) - end + back(Δ) = (NO_FIELDS, nothing, @thunk(Δ .* y), @thunk(Δ .* x)) + return dot(x, y), back end function ChainRulesCore.rrule(::typeof(Distances.pairwise), d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix; dims=2) - D = Distances.pairwise(d, X, Y; dims = dims) - if dims == 1 - return D, Δ -> (NO_FIELDS, nothing, Δ * Y, (X' * Δ)') - else - return D, Δ -> (NO_FIELDS, nothing, (Δ * Y')', X * Δ) - end + D = Distances.pairwise(d, X, Y; dims = dims) + + function back(Δ) + if dims == 1 + return (NO_FIELDS, nothing, @thunk(Δ * Y), @thunk((X' * Δ)')) + else + return (NO_FIELDS, nothing, @thunk((Δ * Y')'), @thunk(X * Δ)) + end + end + return D, back end function ChainRulesCore.rrule(::typeof(Distances.pairwise), d::DotProduct, X::AbstractMatrix; dims=2) - D = Distances.pairwise(d, X; dims = dims) - if dims == 1 - return D, Δ -> (NO_FIELDS, nothing, 2 * Δ * X) - else - return D, Δ -> (NO_FIELDS, nothing, 2 * X * Δ) - end + D = Distances.pairwise(d, X; dims = dims) + + function back(Δ) + if dims == 1 + return (NO_FIELDS, nothing, @thunk(2 * Δ * X)) + else + return (NO_FIELDS, nothing, @thunk(2 * X * Δ)) + end + end + return D, back end ## rules for Sinus function ChainRulesCore.rrule(::typeof(evaluate), s::Sinus, x::AbstractVector, y::AbstractVector) - d = (x - y) - sind = sinpi.(d) - val = sum(abs2, sind ./ s.r) - gradx = 2π .* cospi.(d) .* sind ./ (s.r .^ 2) - val, Δ -> begin - (NO_FIELDS, (r = -2Δ .* abs2.(sind) ./ s.r,), Δ * gradx, - Δ * gradx) - end + d = @thunk((x - y)) + sind = @thunk(sinpi.(d)) + val = @thunk(sum(abs2, sind ./ s.r)) + gradx = @thunk(2π .* cospi.(d) .* sind ./ (s.r .^ 2)) + function back(Δ) + return (NO_FIELDS, (r = @thunk(-2Δ .* abs2.(sind) ./ s.r),), @thunk(Δ * gradx), @thunk(- Δ * gradx)) + end + val, back end # rules for ColVecs and RowVecs +vecs_pullback(Δ::NamedTuple) = (NO_FIELDS, Δ.X,) +vecs_pullback(Δ::AbstractMatrix) = (NO_FIELDS, Δ,) +function vecs_pullback(Δ::AbstractVector{<:AbstractVector{<:Real}}) + throw(error("In slow method")) +end function ChainRulesCore.rrule(::typeof(ColVecs), X::AbstractMatrix) - back(Δ::NamedTuple) = (NO_FIELDS, Δ.X,) - back(Δ::AbstractMatrix) = (NO_FIELDS, Δ,) - function back(Δ::AbstractVector{<:AbstractVector{<:Real}}) - throw(error("In slow method")) - end - return ColVecs(X), back + return ColVecs(X), vecs_pullback end function ChainRulesCore.rrule(::typeof(RowVecs), X::AbstractMatrix) - back(Δ::NamedTuple) = (NO_FIELDS, Δ.X,) - back(Δ::AbstractMatrix) = (NO_FIELDS, Δ,) - function back(Δ::AbstractVector{<:AbstractVector{<:Real}}) - throw(error("In slow method")) - end - return RowVecs(X), back + return RowVecs(X), vecs_pullback end # rules for transforms -function ChainRulesCore.rrule(::typeof(Base.map), t::Transform, X::ColVecs) - ChainRulesCore.rrule(_map, t, X) +@adjoint function Base.map(t::Transform, X::ColVecs) + return pullback(_map, t, X) end -function ChainRulesCore.rrule(::typeof(Base.map), t::Transform, X::RowVecs) - ChainRulesCore.rrule(_map, t, X) +@adjoint function Base.map(t::Transform, X::RowVecs) + return pullback(_map, t, X) end From 1d054b0c995d13ac9c729d51779ee9a4d81381be Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Fri, 14 Aug 2020 15:27:19 +0530 Subject: [PATCH 4/4] Fix rrules for ColVecs and RowVecs --- src/chainrules.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 53329b193..166996918 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -71,11 +71,12 @@ vecs_pullback(Δ::AbstractMatrix) = (NO_FIELDS, Δ,) function vecs_pullback(Δ::AbstractVector{<:AbstractVector{<:Real}}) throw(error("In slow method")) end -function ChainRulesCore.rrule(::typeof(ColVecs), X::AbstractMatrix) + +function ChainRulesCore.rrule(::Type{ColVecs}, X::AbstractMatrix) return ColVecs(X), vecs_pullback end -function ChainRulesCore.rrule(::typeof(RowVecs), X::AbstractMatrix) +function ChainRulesCore.rrule(::Type{RowVecs}, X::AbstractMatrix) return RowVecs(X), vecs_pullback end