Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changing Distances adjoints to ChainRules syntax #923

Closed
wants to merge 17 commits into from
41 changes: 23 additions & 18 deletions src/lib/distances.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
using .Distances
import .ChainRules: NoTangent, rrule, rrule_via_ad

# Reverse-mode rule for calling a `SqEuclidean` metric on two vectors.
# NOTE(review): this span appears to be a diff rendering with the +/- markers
# stripped, so both the `@adjoint` header/return and the `rrule` header/return
# are present side by side — confirm which lines are current before reuse.
@adjoint function (::SqEuclidean)(x::AbstractVector, y::AbstractVector)
function rrule(::ZygoteRuleConfig, ::SqEuclidean, x::AbstractVector, y::AbstractVector)
# Elementwise difference, shared by the forward value and the pullback.
δ = x .- y
# Pullback: receives the scalar cotangent Δ of the distance.
function sqeuclidean(Δ::Real)
# Cotangent w.r.t. x is 2Δ .* (x - y); the cotangent w.r.t. y is its negation.
x̄ = (2 * Δ) .* δ
return x̄, -x̄
# `rrule` variant: the extra leading NoTangent() lines up with the
# `::SqEuclidean` callable argument in the `rrule` signature.
return NoTangent(), x̄, -x̄
end
# Forward value (sum of squared differences) paired with the pullback.
return sum(abs2, δ), sqeuclidean
end

# Reverse-mode rule for `colwise(s::SqEuclidean, x, y)` on two matrices.
# NOTE(review): this span appears to be a diff rendering with the +/- markers
# stripped, so both the `@adjoint` and the `rrule` header/return lines are
# present — confirm which lines are current before reuse.
@adjoint function colwise(s::SqEuclidean, x::AbstractMatrix, y::AbstractMatrix)
function rrule(::ZygoteRuleConfig, ::typeof(colwise), s::SqEuclidean, x::AbstractMatrix, y::AbstractMatrix)
# Forward value comes straight from `colwise`; the pullback is an anonymous closure.
return colwise(s, x, y), function (Δ::AbstractVector)
# Δ' is a row, so broadcasting scales column j of (x .- y) by 2Δ[j].
x̄ = 2 .* Δ' .* (x .- y)
return nothing, x̄, -x̄
# `rrule` variant: two leading NoTangent() entries precede the x and y
# cotangents, matching the `colwise` and `s` positions in the signature.
return NoTangent(), NoTangent(), x̄, -x̄
end
end

@adjoint function pairwise(s::SqEuclidean, x::AbstractMatrix, y::AbstractMatrix; dims::Int=2)
function rrule(::ZygoteRuleConfig, ::typeof(pairwise), s::SqEuclidean, x::AbstractMatrix, y::AbstractMatrix; dims::Int=2)
if dims==1
return pairwise(s, x, y; dims=1), ∇pairwise(s, transpose(x), transpose(y), transpose)
else
Expand All @@ -28,10 +29,10 @@ end
function(Δ)
x̄ = 2 .* (x * Diagonal(vec(sum(Δ; dims=2))) .- y * transpose(Δ))
ȳ = 2 .* (y * Diagonal(vec(sum(Δ; dims=1))) .- x * Δ)
return (nothing, f(x̄), f(ȳ))
return NoTangent(), NoTangent(), f(x̄), f(ȳ)
end

@adjoint function pairwise(s::SqEuclidean, x::AbstractMatrix; dims::Int=2)
function rrule(::ZygoteRuleConfig, ::typeof(pairwise), s::SqEuclidean, x::AbstractMatrix; dims::Int=2)
if dims==1
return pairwise(s, x; dims=1), ∇pairwise(s, transpose(x), transpose)
else
Expand All @@ -43,45 +44,49 @@ end
function(Δ)
d1 = Diagonal(vec(sum(Δ; dims=1)))
d2 = Diagonal(vec(sum(Δ; dims=2)))
return (nothing, x * (2 .* (d1 .+ d2 .- Δ .- transpose(Δ))) |> f)
return NoTangent(), NoTangent(), x * (2 .* (d1 .+ d2 .- Δ .- transpose(Δ))) |> f
end

# Reverse-mode rule for calling a `Euclidean` metric on two vectors.
# NOTE(review): this span appears to be a diff rendering with the +/- markers
# stripped, so both the `@adjoint` and the `rrule` header/return lines are
# present — confirm which lines are current before reuse.
@adjoint function (::Euclidean)(x::AbstractVector, y::AbstractVector)
function rrule(::ZygoteRuleConfig, ::Euclidean, x::AbstractVector, y::AbstractVector)
# Elementwise difference and its 2-norm (the forward value).
D = x .- y
δ = sqrt(sum(abs2, D))
# Pullback: receives the scalar cotangent Δ of the distance.
function euclidean(Δ::Real)
# When the distance is exactly zero, return D itself rather than the divided
# form, so the result does not contain the NaNs produced by dividing by zero.
# `ifelse` evaluates both arguments; only the selection is conditional.
x̄ = ifelse(iszero(δ), D, (Δ / δ) .* D)
return x̄, -x̄
# `rrule` variant: the extra leading NoTangent() lines up with the
# `::Euclidean` callable argument in the `rrule` signature.
return NoTangent(), x̄, -x̄
end
return δ, euclidean
end

# Reverse-mode rule for `colwise(s::Euclidean, x, y)` on two matrices.
# NOTE(review): this span appears to be a diff rendering with the +/- markers
# stripped, so both the `@adjoint` and the `rrule` header/return lines are
# present — confirm which lines are current before reuse.
@adjoint function colwise(s::Euclidean, x::AbstractMatrix, y::AbstractMatrix)
function rrule(::ZygoteRuleConfig, ::typeof(colwise), s::Euclidean, x::AbstractMatrix, y::AbstractMatrix)
# Forward distances are kept because the pullback divides by them.
d = colwise(s, x, y)
return d, function (Δ::AbstractVector)
# Each distance is clamped from below by eps(eltype(d)) before dividing, so a
# zero distance does not cause a division by zero. The transposed quotient is
# a row, so broadcasting scales column j of (x .- y) by Δ[j] / d[j].
x̄ = (Δ ./ max.(d, eps(eltype(d))))' .* (x .- y)
return nothing, x̄, -x̄
# `rrule` variant: two leading NoTangent() entries precede the x and y
# cotangents, matching the `colwise` and `s` positions in the signature.
return NoTangent(), NoTangent(), x̄, -x̄
end
end

# Return `sqrt(d)` when `d` is strictly greater than the threshold `δ`, and
# `zero(d)` otherwise. The early return keeps the branch lazy, so `sqrt` is never
# evaluated for values at or below `δ`.
function _sqrt_if_positive(d, δ)
    d > δ && return sqrt(d)
    return zero(d)
end

# Reverse-mode rule for `pairwise(dist::Euclidean, X, Y; dims)`.
# NOTE(review): this span appears to be a diff rendering with the +/- markers
# stripped, so the `@adjoint`/`pullback` lines and the `rrule`/`rrule_via_ad`
# lines are both present — confirm which lines are current before reuse.
@adjoint function pairwise(dist::Euclidean, X::AbstractMatrix, Y::AbstractMatrix; dims=2)
function rrule(config::ZygoteRuleConfig, ::typeof(pairwise), dist::Euclidean, X::AbstractMatrix, Y::AbstractMatrix; dims=2)
# Modify the forwards-pass slightly to ensure stability on the reverse.
# The helper closes over `dims` and computes squared distances first.
function _pairwise_euclidean(sqdist::SqEuclidean, X, Y)
D2 = pairwise(sqdist, X, Y; dims=dims)
D2 = pairwise(sqdist, X, Y; dims)
# Entries not exceeding eps of the element type map to zero instead of going
# through `sqrt` (see `_sqrt_if_positive`).
δ = eps(eltype(D2))
return _sqrt_if_positive.(D2, δ)
end
# Differentiate the helper itself, using a SqEuclidean built from `dist.thresh`.
return pullback(_pairwise_euclidean, SqEuclidean(dist.thresh), X, Y)
D, back = rrule_via_ad(config, _pairwise_euclidean, SqEuclidean(dist.thresh), X, Y)
# NOTE(review): `back` is returned unchanged under a new name; confirm its
# tangent layout matches the (pairwise, dist, X, Y) positions of this rule.
pairwise_Euclidean_rrule = back
return D, pairwise_Euclidean_rrule
end

# Reverse-mode rule for the single-matrix form `pairwise(dist::Euclidean, X; dims)`.
# NOTE(review): this span appears to be a diff rendering with the +/- markers
# stripped, so the `@adjoint`/`pullback` lines and the `rrule`/`rrule_via_ad`
# lines are both present — confirm which lines are current before reuse.
@adjoint function pairwise(dist::Euclidean, X::AbstractMatrix; dims=2)
function rrule(config::ZygoteRuleConfig, ::typeof(pairwise), dist::Euclidean, X::AbstractMatrix; dims=2)
# Modify the forwards-pass slightly to ensure stability on the reverse.
# The helper closes over `dims` and computes squared distances first.
function _pairwise_euclidean(sqdist::SqEuclidean, X)
D2 = pairwise(sqdist, X; dims=dims)
D2 = pairwise(sqdist, X; dims)
# Entries not exceeding eps of the element type map to zero instead of going
# through `sqrt` (see `_sqrt_if_positive`).
δ = eps(eltype(D2))
return _sqrt_if_positive.(D2, δ)
end
# Differentiate the helper itself, using a SqEuclidean built from `dist.thresh`.
return pullback(_pairwise_euclidean, SqEuclidean(dist.thresh), X)
D, back = rrule_via_ad(config, _pairwise_euclidean, SqEuclidean(dist.thresh), X)
# NOTE(review): `back` is returned unchanged under a new name; confirm its
# tangent layout matches the (pairwise, dist, X) positions of this rule.
pairwise_Euclidean_rrule = back
return D, pairwise_Euclidean_rrule
end