Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changing Distances adjoints to ChainRules syntax #923

Closed
wants to merge 17 commits into from
Closed
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.6.4"
version = "0.6.7"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand Down
33 changes: 17 additions & 16 deletions src/lib/distances.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
using .Distances
import .ChainRules: NO_FIELDS, rrule

@adjoint function (::SqEuclidean)(x::AbstractVector, y::AbstractVector)
function rrule(::SqEuclidean, x::AbstractVector, y::AbstractVector)
δ = x .- y
function sqeuclidean(Δ::Real)
x̄ = (2 * Δ) .* δ
return x̄, -x̄
return NO_FIELDS, x̄, -x̄
end
return sum(abs2, δ), sqeuclidean
end

@adjoint function colwise(s::SqEuclidean, x::AbstractMatrix, y::AbstractMatrix)
function rrule(::typeof(colwise), s::SqEuclidean, x::AbstractMatrix, y::AbstractMatrix)
return colwise(s, x, y), function (Δ::AbstractVector)
x̄ = 2 .* Δ' .* (x .- y)
return nothing, x̄, -x̄
return NO_FIELDS, NO_FIELDS, x̄, -x̄
end
end

@adjoint function pairwise(s::SqEuclidean, x::AbstractMatrix, y::AbstractMatrix; dims::Int=2)
function rrule(::typeof(pairwise), s::SqEuclidean, x::AbstractMatrix, y::AbstractMatrix; dims::Int=2)
if dims==1
return pairwise(s, x, y; dims=1), ∇pairwise(s, transpose(x), transpose(y), transpose)
else
Expand All @@ -28,10 +29,10 @@ end
function(Δ)
x̄ = 2 .* (x * Diagonal(vec(sum(Δ; dims=2))) .- y * transpose(Δ))
ȳ = 2 .* (y * Diagonal(vec(sum(Δ; dims=1))) .- x * Δ)
return (nothing, f(x̄), f(ȳ))
return NO_FIELDS, NO_FIELDS, f(x̄), f(ȳ)
end

@adjoint function pairwise(s::SqEuclidean, x::AbstractMatrix; dims::Int=2)
function rrule(::typeof(pairwise), s::SqEuclidean, x::AbstractMatrix; dims::Int=2)
if dims==1
return pairwise(s, x; dims=1), ∇pairwise(s, transpose(x), transpose)
else
Expand All @@ -43,28 +44,28 @@ end
function(Δ)
d1 = Diagonal(vec(sum(Δ; dims=1)))
d2 = Diagonal(vec(sum(Δ; dims=2)))
return (nothing, x * (2 .* (d1 .+ d2 .- Δ .- transpose(Δ))) |> f)
return NO_FIELDS, NO_FIELDS, x * (2 .* (d1 .+ d2 .- Δ .- transpose(Δ))) |> f
end

@adjoint function (::Euclidean)(x::AbstractVector, y::AbstractVector)
function rrule(::Euclidean, x::AbstractVector, y::AbstractVector)
D = x .- y
δ = sqrt(sum(abs2, D))
function euclidean(Δ::Real)
x̄ = ifelse(iszero(δ), D, (Δ / δ) .* D)
return x̄, -x̄
return NO_FIELDS, x̄, -x̄
end
return δ, euclidean
end

@adjoint function colwise(s::Euclidean, x::AbstractMatrix, y::AbstractMatrix)
function rrule(::typeof(colwise), s::Euclidean, x::AbstractMatrix, y::AbstractMatrix)
d = colwise(s, x, y)
return d, function (Δ::AbstractVector)
x̄ = (Δ ./ max.(d, eps(eltype(d))))' .* (x .- y)
return nothing, x̄, -x̄
return NO_FIELDS, NO_FIELDS, x̄, -x̄
end
end

@adjoint function pairwise(::Euclidean, X::AbstractMatrix, Y::AbstractMatrix; dims=2)
function rrule(::typeof(pairwise), ::Euclidean, X::AbstractMatrix, Y::AbstractMatrix; dims=2)

# Modify the forwards-pass slightly to ensure stability on the reverse.
function _pairwise_euclidean(X, Y)
Expand All @@ -74,16 +75,16 @@ end
D, back = pullback(_pairwise_euclidean, X, Y)

return D, function(Δ)
return (nothing, back(Δ)...)
return (NO_FIELDS, NO_FIELDS, back(Δ)...)
end
end

@adjoint function pairwise(::Euclidean, X::AbstractMatrix; dims=2)
function rrule(::typeof(pairwise), ::Euclidean, X::AbstractMatrix; dims=2)
D, back = pullback(X -> pairwise(SqEuclidean(), X; dims = dims), X)
D .= sqrt.(D)
return D, function(Δ)
Δ = Δ ./ (2 .* max.(D, eps(eltype(D))))
Δ[diagind(Δ)] .= 0
return (nothing, first(back(Δ)))
return (NO_FIELDS, NO_FIELDS, first(back(Δ)))
end
end