Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Use Tullio for pairwise distances #386

Draft
wants to merge 23 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
TensorCore = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
Expand Down
7 changes: 7 additions & 0 deletions src/KernelFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ using IrrationalConstants: logtwo, twoπ, invsqrt2
using LogExpFunctions: softplus
using StatsBase
using TensorCore
using Tullio
using ZygoteRules: ZygoteRules, AContext, literal_getproperty, literal_getfield

# Hack to work around Zygote type inference problems.
Expand All @@ -66,8 +67,14 @@ const Distances_pairwise = Distances.pairwise
abstract type Kernel end
abstract type SimpleKernel <: Kernel end

# A general binary op type not respecting Distances metric rules
abstract type AbstractBinaryOp end
const BinaryOp = Union{AbstractBinaryOp,Distances.PreMetric}

include("utils.jl")

include("distances/pairwise.jl")
include("distances/euclidean.jl")
include("distances/dotproduct.jl")
include("distances/delta.jl")
include("distances/sinus.jl")
Expand Down
6 changes: 2 additions & 4 deletions src/distances/delta.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Delta is not following the PreMetric rules since d(x, x) == 1
struct Delta <: Distances.UnionPreMetric end
struct Delta <: AbstractBinaryOp end

# Basic definitions
(dist::Delta)(a::Number, b::Number) = a == b
Base.@propagate_inbounds function (dist::Delta)(
a::AbstractArray{<:Number}, b::AbstractArray{<:Number}
Expand All @@ -14,5 +14,3 @@ Base.@propagate_inbounds function (dist::Delta)(
end
return a == b
end

Distances.result_type(::Delta, Ta::Type, Tb::Type) = Bool
31 changes: 16 additions & 15 deletions src/distances/dotproduct.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
## DotProduct is not following the PreMetric rules since d(x, x) != 0 and d(x, y) >= 0 for all x, y
struct DotProduct <: Distances.UnionPreMetric end
struct DotProduct <: AbstractBinaryOp end

@inline function Distances._evaluate(::DotProduct, a::AbstractVector, b::AbstractVector)
@boundscheck if length(a) != length(b)
throw(
DimensionMismatch(
"first array has length $(length(a)) which does not match the length of the second, $(length(b)).",
),
)
end
return dot(a, b)
(::DotProduct)(a::AbstractVector, b::AbstractVector) = dot(a, b)

(::DotProduct)(a::Number, b::Number) = a * b

function pairwise(::DotProduct, x::ColVecs, y::ColVecs)
return @tullio out[i, j] := x.X[k, i] * y.X[k, j]
end

Distances.result_type(::DotProduct, Ta::Type, Tb::Type) = promote_type(Ta, Tb)
function pairwise(::DotProduct, x::RowVecs, y::RowVecs)
return @tullio out[i, j] := x.X[i, k] * y.X[j, k]
end

function colwise(::DotProduct, x::RowVecs, y::RowVecs=x)
theogf marked this conversation as resolved.
Show resolved Hide resolved
return @tullio out[i] := x.X[i, k] * y.X[i, k]
end

@inline Distances.eval_op(::DotProduct, a::Real, b::Real) = a * b
@inline function (dist::DotProduct)(a::AbstractArray, b::AbstractArray)
return Distances._evaluate(dist, a, b)
function colwise(::DotProduct, x::ColVecs, y::ColVecs=x)
theogf marked this conversation as resolved.
Show resolved Hide resolved
return @tullio out[i] := x.X[k, i] * y.X[k, i]
end
@inline (dist::DotProduct)(a::Number, b::Number) = a * b
19 changes: 19 additions & 0 deletions src/distances/euclidean.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Tullio specialization for Euclidean and SqEuclidean metrics

function pairwise(::Euclidean, x::ColVecs, y::ColVecs)
return @tullio out[i, j] :=
sqrt <| (x.X[k, i] - y.X[k, j])^2
theogf marked this conversation as resolved.
Show resolved Hide resolved
end

function pairwise(::Euclidean, x::RowVecs, y::RowVecs)
return @tullio out[i, j] :=
sqrt <| x.X[i, k]^2 - 2 * x.X[i, k] * y.X[j, k] + y.X[j, k]^2
theogf marked this conversation as resolved.
Show resolved Hide resolved
end

function pairwise(::SqEuclidean, x::ColVecs, y::ColVecs)
return @tullio out[i, j] := x.X[k, i]^2 - 2 * x.X[k, i] * y.X[k, j] + y.X[k, j]^2
theogf marked this conversation as resolved.
Show resolved Hide resolved
end

function pairwise(::SqEuclidean, x::RowVecs, y::RowVecs)
return @tullio out[i, j] := x.X[i, k]^2 - 2 * x.X[i, k] * y.X[j, k] + y.X[j, k]^2
theogf marked this conversation as resolved.
Show resolved Hide resolved
end
66 changes: 12 additions & 54 deletions src/distances/pairwise.jl
Original file line number Diff line number Diff line change
@@ -1,70 +1,28 @@
# Add our own pairwise function to be able to apply it on vectors

function pairwise(d::PreMetric, X::AbstractVector, Y::AbstractVector)
return broadcast(d, X, permutedims(Y))
function pairwise(d::BinaryOp, X::AbstractVector, Y::AbstractVector=X)
return @tullio out[i, j] := d(X[i], Y[j])
end

pairwise(d::PreMetric, X::AbstractVector) = pairwise(d, X, X)

function pairwise!(out::AbstractMatrix, d::PreMetric, X::AbstractVector, Y::AbstractVector)
return broadcast!(d, out, X, permutedims(Y))
end

pairwise!(out::AbstractMatrix, d::PreMetric, X::AbstractVector) = pairwise!(out, d, X, X)

function pairwise(d::PreMetric, x::AbstractVector{<:Real})
return Distances_pairwise(d, reshape(x, :, 1); dims=1)
end

function pairwise(d::PreMetric, x::AbstractVector{<:Real}, y::AbstractVector{<:Real})
return Distances_pairwise(d, reshape(x, :, 1), reshape(y, :, 1); dims=1)
end

function pairwise!(out::AbstractMatrix, d::PreMetric, x::AbstractVector{<:Real})
return Distances.pairwise!(out, d, reshape(x, :, 1); dims=1)
end

function pairwise!(
out::AbstractMatrix, d::PreMetric, x::AbstractVector{<:Real}, y::AbstractVector{<:Real}
)
return Distances.pairwise!(out, d, reshape(x, :, 1), reshape(y, :, 1); dims=1)
function pairwise!(out::AbstractMatrix, d::BinaryOp, X::AbstractVector, Y::AbstractVector=X)
return @tullio out[i, j] = d(X[i], Y[j])
end

# Also defines the colwise method for abstractvectors

function colwise(d::PreMetric, x::AbstractVector)
# We have different methods for PreMetric and AbstractBinaryOp
# Since colwise on AbstractBinaryOp is not guaranteed to be equal to 0
function colwise(d::Distances.PreMetric, x::AbstractVector)
return zeros(Distances.result_type(d, x, x), length(x)) # Valid since d(x,x) == 0 by definition
end

function colwise(d::PreMetric, x::ColVecs)
function colwise(d::Distances.PreMetric, x::Union{ColVecs,RowVecs})
return zeros(Distances.result_type(d, x.X, x.X), length(x)) # Valid since d(x,x) == 0 by definition
end

function colwise(d::PreMetric, x::RowVecs)
return zeros(Distances.result_type(d, x.X, x.X), length(x)) # Valid since d(x,x) == 0 by definition
end

## The following is a hack for DotProduct and Delta to still work
function colwise(d::Distances.UnionPreMetric, x::ColVecs)
return Distances.colwise(d, x.X, x.X)
end

function colwise(d::Distances.UnionPreMetric, x::RowVecs)
return Distances.colwise(d, x.X', x.X')
end

function colwise(d::Distances.UnionPreMetric, x::AbstractVector)
return map(d, x, x)
end

function colwise(d::PreMetric, x::ColVecs, y::ColVecs)
return Distances.colwise(d, x.X, y.X)
end

function colwise(d::PreMetric, x::RowVecs, y::RowVecs)
return Distances.colwise(d, x.X', y.X')
function colwise(d::AbstractBinaryOp, x::AbstractVector)
return @tullio out[i] := d(x[i], x[i])
end

function colwise(d::PreMetric, x::AbstractVector, y::AbstractVector)
return map(d, x, y)
function colwise(d::BinaryOp, x::AbstractVector, y::AbstractVector)
return @tullio out[i] := d(x[i], y[i])
end
4 changes: 2 additions & 2 deletions src/distances/sinus.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
struct Sinus{T} <: Distances.UnionSemiMetric
r::Vector{T}
struct Sinus{T,V<:AbstractVector{T}} <: Distances.SemiMetric
r::V
end

Distances.parameters(d::Sinus) = d.r
Expand Down
38 changes: 7 additions & 31 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,21 +80,6 @@ Base.vcat(a::ColVecs, b::ColVecs) = ColVecs(hcat(a.X, b.X))

dim(x::ColVecs) = size(x.X, 1)

pairwise(d::PreMetric, x::ColVecs) = Distances_pairwise(d, x.X; dims=2)
pairwise(d::PreMetric, x::ColVecs, y::ColVecs) = Distances_pairwise(d, x.X, y.X; dims=2)
function pairwise(d::PreMetric, x::AbstractVector, y::ColVecs)
return Distances_pairwise(d, reduce(hcat, x), y.X; dims=2)
end
function pairwise(d::PreMetric, x::ColVecs, y::AbstractVector)
return Distances_pairwise(d, x.X, reduce(hcat, y); dims=2)
end
function pairwise!(out::AbstractMatrix, d::PreMetric, x::ColVecs)
return Distances.pairwise!(out, d, x.X; dims=2)
end
function pairwise!(out::AbstractMatrix, d::PreMetric, x::ColVecs, y::ColVecs)
return Distances.pairwise!(out, d, x.X, y.X; dims=2)
end

"""
RowVecs(X::AbstractMatrix)

Expand Down Expand Up @@ -150,25 +135,16 @@ Base.vcat(a::RowVecs, b::RowVecs) = RowVecs(vcat(a.X, b.X))

dim(x::RowVecs) = size(x.X, 2)

pairwise(d::PreMetric, x::RowVecs) = Distances_pairwise(d, x.X; dims=1)
pairwise(d::PreMetric, x::RowVecs, y::RowVecs) = Distances_pairwise(d, x.X, y.X; dims=1)
function pairwise(d::PreMetric, x::AbstractVector, y::RowVecs)
return Distances_pairwise(d, permutedims(reduce(hcat, x)), y.X; dims=1)
end
function pairwise(d::PreMetric, x::RowVecs, y::AbstractVector)
return Distances_pairwise(d, x.X, permutedims(reduce(hcat, y)); dims=1)
end
function pairwise!(out::AbstractMatrix, d::PreMetric, x::RowVecs)
return Distances.pairwise!(out, d, x.X; dims=1)
# Resolve ambiguity error for ColVecs vs RowVecs. #346
pairwise(d::BinaryOp, x::ColVecs, y::RowVecs) = pairwise(d, x, ColVecs(permutedims(y.X)))
pairwise(d::BinaryOp, x::RowVecs, y::ColVecs) = pairwise(d, ColVecs(permutedims(x.X)), y)
function pairwise!(out::AbstractMatrix, d::BinaryOp, x::ColVecs, y::RowVecs)
return pairwise!(out, d, x, ColVecs(permutedims(y.X)))
end
function pairwise!(out::AbstractMatrix, d::PreMetric, x::RowVecs, y::RowVecs)
return Distances.pairwise!(out, d, x.X, y.X; dims=1)
function pairwise!(out::AbstractMatrix, d::BinaryOp, x::RowVecs, y::ColVecs)
return pairwise!(out, d, ColVecs(permutedims(x.X)), y)
end

# Resolve ambiguity error for ColVecs vs RowVecs. #346
pairwise(d::PreMetric, x::ColVecs, y::RowVecs) = pairwise(d, x, ColVecs(permutedims(y.X)))
pairwise(d::PreMetric, x::RowVecs, y::ColVecs) = pairwise(d, ColVecs(permutedims(x.X)), y)

dim(x) = 0 # This is the passes-by-default choice. For a proper check, implement `KernelFunctions.dim` for your datatype.
dim(x::AbstractVector) = dim(first(x))
dim(x::AbstractVector{<:AbstractVector{<:Real}}) = length(first(x))
Expand Down