From 8329e3c024d13e0cc1b5a8d8d9a739834f095634 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Mon, 16 Oct 2023 11:40:15 +0200 Subject: [PATCH 01/17] Get tests for Zygote to pass --- src/KernelFunctions.jl | 2 +- src/chainrules.jl | 111 +++++++++++++++++++++++++++++++++-- test/basekernels/periodic.jl | 7 ++- 3 files changed, 113 insertions(+), 7 deletions(-) diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl index 63205b5bf..80711b4ec 100644 --- a/src/KernelFunctions.jl +++ b/src/KernelFunctions.jl @@ -48,7 +48,7 @@ export tensor, ⊗, compose using Compat using ChainRulesCore: ChainRulesCore, Tangent, ZeroTangent, NoTangent -using ChainRulesCore: @thunk, InplaceableThunk +using ChainRulesCore: @thunk, InplaceableThunk, ProjectTo, unthunk using CompositionsBase using Distances using FillArrays diff --git a/src/chainrules.jl b/src/chainrules.jl index eebdf95b5..d68d8dc29 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -112,15 +112,118 @@ end function ChainRulesCore.rrule(s::Sinus, x::AbstractVector, y::AbstractVector) d = x - y sind = sinpi.(d) - abs2_sind_r = abs2.(sind) ./ s.r - val = sum(abs2_sind_r) - gradx = twoπ .* cospi.(d) .* sind ./ (s.r .^ 2) + abs2_sind_r = abs2.(sind) ./ s.r .^ 2 + val = sum(abs2_sind_r) + gradx = twoπ .* cospi.(d) .* sind ./ s.r .^ 2 function evaluate_pullback(Δ::Any) - return (r=-2Δ .* abs2_sind_r,), Δ * gradx, -Δ * gradx + return (r=-2Δ .* abs2_sind_r ./ s.r,), Δ * gradx, -Δ * gradx end return val, evaluate_pullback end +function ChainRulesCore.rrule( + ::typeof(Distances.pairwise), + d::Sinus, + x::AbstractMatrix; + dims = 2 +) + project_x = ProjectTo(x) + function pairwise_pullback(z̄) + Δ = unthunk(z̄) + n = size(x, dims) + x̄ = zero(x) + r̄ = zero(d.r) + if dims == 1 + for j in 1:n, i in 1:n + xi = view(x, i, :) + xj = view(x, j, :) + ds = twoπ .* Δ[i, j] .* sinpi.(xi .- xj) .* cospi.(xi .- xj) ./ d.r .^ 2 + r̄ .-= 2 .* Δ[i, j] .* sinpi.(xi .- xj) .^ 2 ./ d.r .^ 3 + x̄[i, :] += ds + x̄[j, :] -= ds + end + elseif dims == 2 + for j in 1:n, i in 1:n + xi = view(x, :, i) + xj = view(x, :, j) + ds = twoπ .* Δ[i, j] .* sinpi.(xi .- xj) .* cospi.(xi .- xj) ./ d.r .^ 2 + r̄ .-= 2 .* Δ[i, j] .* sinpi.(xi .- xj) .^ 2 ./ d.r .^ 3 + x̄[:, i] += ds + x̄[:, j] -= ds + end + end + NoTangent(), (r=r̄,), @thunk(project_x(x̄)) + end + return Distances.pairwise(d, x; dims), pairwise_pullback +end + +function ChainRulesCore.rrule( + ::typeof(Distances.pairwise), + d::Sinus, + x::AbstractMatrix, + y::AbstractMatrix; + dims = 2 +) + project_x = ProjectTo(x) + project_y = ProjectTo(y) + function pairwise_pullback(z̄) + Δ = unthunk(z̄) + n = size(x, dims) + m = size(y, dims) + x̄ = zero(x) + ȳ = zero(y) + r̄ = zero(d.r) + if dims == 1 + for j in 1:m, i in 1:n + xi = view(x, i, :) + yj = view(y, j, :) + ds = twoπ .* Δ[i, j] .* sinpi.(xi .- yj) .* cospi.(xi .- yj) ./ d.r .^ 2 + r̄ .-= 2 .* Δ[i, j] .* sinpi.(xi .- yj) .^ 2 ./ d.r .^ 3 + x̄[i, :] += ds + ȳ[j, :] -= ds + end + elseif dims == 2 + for j in 1:m, i in 1:n + xi = view(x, :, i) + yj = view(y, :, j) + ds = twoπ .* Δ[i, j] .* sinpi.(xi .- yj) .* cospi.(xi .- yj) ./ d.r .^ 2 + r̄ .-= 2 .* Δ[i, j] .* sinpi.(xi .- yj) .^ 2 ./ d.r .^ 3 + x̄[:, i] += ds + ȳ[:, j] -= ds + end + end + NoTangent(), (r=r̄,), @thunk(project_x(x̄)), @thunk(project_y(ȳ)) + end + return Distances.pairwise(d, x, y; dims), pairwise_pullback +end + +function ChainRulesCore.rrule( + ::typeof(Distances.colwise), + d::Sinus, + x::AbstractMatrix, + y::AbstractMatrix +) + project_x = ProjectTo(x) + project_y = ProjectTo(y) + function colwise_pullback(z̄) + Δ = unthunk(z̄) + n = size(x, 2) + x̄ = zero(x) + ȳ = zero(y) + r̄ = zero(d.r) + for i in 1:n + xi = view(x, :, i) + yi = view(y, :, i) + ds = twoπ .* Δ[i] .* sinpi.(xi .- yi) .* cospi.(xi .- yi) ./ d.r .^ 2 + r̄ .-= 2 .* Δ[i] .* sinpi.(xi .- yi) .^ 2 ./ d.r .^ 3 + x̄[:, i] += ds + ȳ[:, i] -= ds + end + NoTangent(), (r=r̄,), @thunk(project_x(x̄)), @thunk(project_y(ȳ)) + end + return Distances.colwise(d, x, y), colwise_pullback +end + ## Reverse Rules SqMahalanobis function ChainRulesCore.rrule( diff --git a/test/basekernels/periodic.jl b/test/basekernels/periodic.jl index fb149dff5..34dbcd36d 100644 --- a/test/basekernels/periodic.jl +++ b/test/basekernels/periodic.jl @@ -15,7 +15,10 @@ TestUtils.test_interface(PeriodicKernel(; r=[0.9, 0.9]), ColVecs{Float64}) TestUtils.test_interface(PeriodicKernel(; r=[0.8, 0.7]), RowVecs{Float64}) - # test_ADs(r->PeriodicKernel(r =exp.(r)), log.(r), ADs = [:ForwardDiff, :ReverseDiff]) - @test_broken "Undefined adjoint for Sinus metric, and failing randomly for ForwardDiff and ReverseDiff" + test_ADs( + r->PeriodicKernel(r = exp.(r)), log.(r), + ADs = [:Zygote] + ) + # @test_broken "Undefined adjoint for Sinus metric, and failing randomly for ForwardDiff and ReverseDiff" test_params(k, (r,)) end From 57a24b455ba18e512857c040d1ff6847e6a4a835 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Mon, 16 Oct 2023 11:40:15 +0200 Subject: [PATCH 02/17] Also test other AD backends --- test/basekernels/periodic.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/test/basekernels/periodic.jl b/test/basekernels/periodic.jl index 34dbcd36d..8c4c16cf9 100644 --- a/test/basekernels/periodic.jl +++ b/test/basekernels/periodic.jl @@ -15,10 +15,6 @@ TestUtils.test_interface(PeriodicKernel(; r=[0.9, 0.9]), ColVecs{Float64}) TestUtils.test_interface(PeriodicKernel(; r=[0.8, 0.7]), RowVecs{Float64}) - test_ADs( - r->PeriodicKernel(r = exp.(r)), log.(r), - ADs = [:Zygote] - ) - # @test_broken "Undefined adjoint for Sinus metric, and failing randomly for ForwardDiff and ReverseDiff" + test_ADs(r -> PeriodicKernel(r = exp.(r)), log.(r)) test_params(k, (r,)) end From 43d1f3a8a3daf3efc868e615d38cd9c24c24279d Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Mon, 16 Oct 2023 11:50:14 +0200 Subject: [PATCH 03/17] Add return keyword Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/chainrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index d68d8dc29..a019012b9 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -192,7 +192,7 @@ function ChainRulesCore.rrule( ȳ[:, j] -= ds end end - NoTangent(), (r=r̄,), @thunk(project_x(x̄)), @thunk(project_y(ȳ)) + return NoTangent(), (r=r̄,), @thunk(project_x(x̄)), @thunk(project_y(ȳ)) end return Distances.pairwise(d, x, y; dims), pairwise_pullback end From ea5156d31c19171e88524ca3ad67f6ac22317e2b Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Mon, 16 Oct 2023 11:50:25 +0200 Subject: [PATCH 04/17] Remove whitespace Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/chainrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index a019012b9..d44a5aa2e 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -113,7 +113,7 @@ function ChainRulesCore.rrule(s::Sinus, x::AbstractVector, y::AbstractVector) d = x - y sind = sinpi.(d) abs2_sind_r = abs2.(sind) ./ s.r .^ 2 - val = sum(abs2_sind_r) + val = sum(abs2_sind_r) gradx = twoπ .* cospi.(d) .* sind ./ s.r .^ 2 function evaluate_pullback(Δ::Any) return (r=-2Δ .* abs2_sind_r ./ s.r,), Δ * gradx, -Δ * gradx From a8baa31144c0222d423c2f884fdedce835b2871a Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Mon, 16 Oct 2023 11:50:52 +0200 Subject: [PATCH 05/17] Improve code formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/chainrules.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index d44a5aa2e..eaa9561b8 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -122,10 +122,7 @@ function ChainRulesCore.rrule(s::Sinus, x::AbstractVector, y::AbstractVector) end function ChainRulesCore.rrule( - ::typeof(Distances.pairwise), - d::Sinus, - x::AbstractMatrix; - dims = 2 + ::typeof(Distances.pairwise), d::Sinus, x::AbstractMatrix; dims=2 ) project_x = ProjectTo(x) function pairwise_pullback(z̄) From 498cdd8ce07c9b582f9565e9f8b39c20f5f90d62 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Mon, 16 Oct 2023 11:51:03 +0200 Subject: [PATCH 06/17] Add return keyword Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/chainrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index eaa9561b8..aa0cd836f 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -149,7 +149,7 @@ function ChainRulesCore.rrule( x̄[:, j] -= ds end end - NoTangent(), (r=r̄,), @thunk(project_x(x̄)) + return NoTangent(), (r=r̄,), @thunk(project_x(x̄)) end return Distances.pairwise(d, x; dims), pairwise_pullback end From 9e598d5903a7b31cf132720f2623b39bdcc28cf3 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Mon, 16 Oct 2023 11:51:14 +0200 Subject: [PATCH 07/17] Improve code formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/chainrules.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index aa0cd836f..4bc4bc4c7 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -155,11 +155,7 @@ function ChainRulesCore.rrule( end function ChainRulesCore.rrule( - ::typeof(Distances.pairwise), - d::Sinus, - x::AbstractMatrix, - y::AbstractMatrix; - dims = 2 + ::typeof(Distances.pairwise), d::Sinus, x::AbstractMatrix, y::AbstractMatrix; dims=2 ) project_x = ProjectTo(x) project_y = ProjectTo(y) From 8f01fc4ad1f5040f26d8f72342c290a48c5b682a Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Mon, 16 Oct 2023 11:51:24 +0200 Subject: [PATCH 08/17] Improve code formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/chainrules.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 4bc4bc4c7..735c3877e 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -191,10 +191,7 @@ function ChainRulesCore.rrule( end function ChainRulesCore.rrule( - ::typeof(Distances.colwise), - d::Sinus, - x::AbstractMatrix, - y::AbstractMatrix + ::typeof(Distances.colwise), d::Sinus, x::AbstractMatrix, y::AbstractMatrix ) project_x = ProjectTo(x) project_y = ProjectTo(y) From 691921e416255d88a5179bbed94c29d661a38ee0 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Mon, 16 Oct 2023 11:51:33 +0200 Subject: [PATCH 09/17] Add return keyword Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/chainrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 735c3877e..3a5a4ad43 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -209,7 +209,7 @@ function ChainRulesCore.rrule( x̄[:, i] += ds ȳ[:, i] -= ds end - NoTangent(), (r=r̄,), @thunk(project_x(x̄)), @thunk(project_y(ȳ)) + return NoTangent(), (r=r̄,), @thunk(project_x(x̄)), @thunk(project_y(ȳ)) end return Distances.colwise(d, x, y), colwise_pullback end From cb3a7f31462e53fde203a7b96cb96b1161860fe7 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Mon, 16 Oct 2023 11:51:46 +0200 Subject: [PATCH 10/17] Improve code formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/basekernels/periodic.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/basekernels/periodic.jl b/test/basekernels/periodic.jl index 8c4c16cf9..540947b1b 100644 --- a/test/basekernels/periodic.jl +++ b/test/basekernels/periodic.jl @@ -15,6 +15,6 @@ TestUtils.test_interface(PeriodicKernel(; r=[0.9, 0.9]), ColVecs{Float64}) TestUtils.test_interface(PeriodicKernel(; r=[0.8, 0.7]), RowVecs{Float64}) - test_ADs(r -> PeriodicKernel(r = exp.(r)), log.(r)) + test_ADs(r -> PeriodicKernel(; r=exp.(r)), log.(r)) test_params(k, (r,)) end From afe209c505402789ab553fde808144e1221455f3 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Thu, 26 Oct 2023 15:48:00 +0200 Subject: [PATCH 11/17] Add basic tests for `rrule`s --- src/chainrules.jl | 13 +++++++++---- test/Project.toml | 2 ++ test/chainrules.jl | 8 ++++++++ test/runtests.jl | 2 ++ 4 files changed, 21 insertions(+), 4 deletions(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 3a5a4ad43..04cfbbd26 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -116,7 +116,9 @@ function ChainRulesCore.rrule(s::Sinus, x::AbstractVector, y::AbstractVector) val = sum(abs2_sind_r) gradx = twoπ .* cospi.(d) .* sind ./ s.r .^ 2 function evaluate_pullback(Δ::Any) - return (r=-2Δ .* abs2_sind_r ./ s.r,), Δ * gradx, -Δ * gradx + r̄ = -2Δ .* abs2_sind_r ./ s.r + s̄ = ChainRulesCore.Tangent{typeof(s)}(; r=r̄) + return s̄, Δ * gradx, -Δ * gradx end return val, evaluate_pullback end @@ -149,7 +151,8 @@ function ChainRulesCore.rrule( x̄[:, j] -= ds end end - return NoTangent(), (r=r̄,), @thunk(project_x(x̄)) + d̄ = ChainRulesCore.Tangent{typeof(d)}(; r=r̄) + return NoTangent(), d̄, @thunk(project_x(x̄)) end return Distances.pairwise(d, x; dims), pairwise_pullback end @@ -185,7 +188,8 @@ function ChainRulesCore.rrule( ȳ[:, j] -= ds end end - return NoTangent(), (r=r̄,), @thunk(project_x(x̄)), @thunk(project_y(ȳ)) + d̄ = ChainRulesCore.Tangent{typeof(d)}(; r=r̄) + return NoTangent(), d̄, @thunk(project_x(x̄)), @thunk(project_y(ȳ)) end return Distances.pairwise(d, x, y; dims), pairwise_pullback end @@ -209,7 +213,8 @@ function ChainRulesCore.rrule( x̄[:, i] += ds ȳ[:, i] -= ds end - return NoTangent(), (r=r̄,), @thunk(project_x(x̄)), @thunk(project_y(ȳ)) + d̄ = ChainRulesCore.Tangent{typeof(d)}(; r=r̄) + return NoTangent(), d̄, @thunk(project_x(x̄)), @thunk(project_y(ȳ)) end return Distances.colwise(d, x, y), colwise_pullback end diff --git a/test/Project.toml b/test/Project.toml index e16f39a6c..ebada5716 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,7 @@ [deps] AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" diff --git a/test/chainrules.jl b/test/chainrules.jl index 03c2c3b1f..dbd4223ad 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -28,4 +28,12 @@ SqMahalanobis(Qxy[1])(Qxy[2], Qxy[3]) end end + + @testset "rrules for Sinus(r=$r)" for r in (rand(3),) + dist = KernelFunctions.Sinus(r) + ddist = (r = ones(length(r)),) + test_rrule(dist, rand(3), rand(3)) + test_rrule(Distances.pairwise, dist, rand(3, 2); fkwargs=(dims=2,)) + test_rrule(Distances.pairwise, dist, rand(3, 2), rand(3, 3); fkwargs=(dims=2,)) + end end diff --git a/test/runtests.jl b/test/runtests.jl index caf43cb91..e054b992a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,7 @@ using KernelFunctions using AxisArrays +using ChainRulesCore +using ChainRulesTestUtils using Distances using Documenter using Functors: functor From fbc1fd6a1a8d0ff02e87ed8690628ef26c944b75 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Mon, 30 Oct 2023 10:56:35 +0100 Subject: [PATCH 12/17] Add test for `colwise` of `Sinus` --- test/chainrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/chainrules.jl b/test/chainrules.jl index dbd4223ad..8163c84ea 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -31,9 +31,9 @@ @testset "rrules for Sinus(r=$r)" for r in (rand(3),) dist = KernelFunctions.Sinus(r) - ddist = (r = ones(length(r)),) test_rrule(dist, rand(3), rand(3)) test_rrule(Distances.pairwise, dist, rand(3, 2); fkwargs=(dims=2,)) test_rrule(Distances.pairwise, dist, rand(3, 2), rand(3, 3); fkwargs=(dims=2,)) + test_rrule(Distances.colwise, dist, rand(3, 2), rand(3, 2)) end end From 73d940d57facc829e3efa6255aa1cbd8a47144e4 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Mon, 30 Oct 2023 12:33:01 +0100 Subject: [PATCH 13/17] Cover StaticArrays --- src/chainrules.jl | 10 +++++----- test/chainrules.jl | 19 +++++++++++++++---- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 04cfbbd26..551f7b0b6 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -130,7 +130,7 @@ function ChainRulesCore.rrule( function pairwise_pullback(z̄) Δ = unthunk(z̄) n = size(x, dims) - x̄ = zero(x) + x̄ = collect(zero(x)) r̄ = zero(d.r) if dims == 1 for j in 1:n, i in 1:n @@ -166,8 +166,8 @@ function ChainRulesCore.rrule( Δ = unthunk(z̄) n = size(x, dims) m = size(y, dims) - x̄ = zero(x) - ȳ = zero(y) + x̄ = collect(zero(x)) + ȳ = collect(zero(y)) r̄ = zero(d.r) if dims == 1 for j in 1:m, i in 1:n @@ -202,8 +202,8 @@ function ChainRulesCore.rrule( function colwise_pullback(z̄) Δ = unthunk(z̄) n = size(x, 2) - x̄ = zero(x) - ȳ = zero(y) + x̄ = collect(zero(x)) + ȳ = collect(zero(y)) r̄ = zero(d.r) for i in 1:n xi = view(x, :, i) diff --git a/test/chainrules.jl b/test/chainrules.jl index 8163c84ea..1cc59852d 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -31,9 +31,20 @@ @testset "rrules for Sinus(r=$r)" for r in (rand(3),) dist = KernelFunctions.Sinus(r) - test_rrule(dist, rand(3), rand(3)) - test_rrule(Distances.pairwise, dist, rand(3, 2); fkwargs=(dims=2,)) - test_rrule(Distances.pairwise, dist, rand(3, 2), rand(3, 3); fkwargs=(dims=2,)) - test_rrule(Distances.colwise, dist, rand(3, 2), rand(3, 2)) + @testset "$type" for type in (Vector, SVector{3}) + test_rrule(dist, type(rand(3)), type(rand(3))) + end + @testset "$type1, $type2" for type1 in (Matrix, SMatrix{3, 2}), + type2 in (Matrix, SMatrix{3, 4}) + test_rrule( + Distances.pairwise, dist, type1(rand(3, 2)); + fkwargs=(dims=2,) + ) + test_rrule( + Distances.pairwise, dist, type1(rand(3, 2)), type2(rand(3, 4)); + fkwargs=(dims=2,) + ) + test_rrule(Distances.colwise, dist, type1(rand(3, 2)), type1(rand(3, 2))) + end end end From d0c889c74dd6965395a4c76f12169086c67cd1ab Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Mon, 30 Oct 2023 12:49:30 +0100 Subject: [PATCH 14/17] Improve formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/chainrules.jl | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/test/chainrules.jl b/test/chainrules.jl index 1cc59852d..f43803c11 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -34,12 +34,10 @@ @testset "$type" for type in (Vector, SVector{3}) test_rrule(dist, type(rand(3)), type(rand(3))) end - @testset "$type1, $type2" for type1 in (Matrix, SMatrix{3, 2}), - type2 in (Matrix, SMatrix{3, 4}) - test_rrule( - Distances.pairwise, dist, type1(rand(3, 2)); - fkwargs=(dims=2,) - ) + @testset "$type1, $type2" for type1 in (Matrix, SMatrix{3,2}), + type2 in (Matrix, SMatrix{3,4}) + + test_rrule(Distances.pairwise, dist, type1(rand(3, 2)); fkwargs=(dims=2,)) test_rrule( Distances.pairwise, dist, type1(rand(3, 2)), type2(rand(3, 4)); fkwargs=(dims=2,) From 1118c90505c37a71935f19d7b0846409e83aedd5 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Mon, 30 Oct 2023 12:49:38 +0100 Subject: [PATCH 15/17] Improve formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/chainrules.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/test/chainrules.jl b/test/chainrules.jl index f43803c11..b1614af9d 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -39,8 +39,11 @@ test_rrule(Distances.pairwise, dist, type1(rand(3, 2)); fkwargs=(dims=2,)) test_rrule( - Distances.pairwise, dist, type1(rand(3, 2)), type2(rand(3, 4)); - fkwargs=(dims=2,) + Distances.pairwise, + dist, + type1(rand(3, 2)), + type2(rand(3, 4)); + fkwargs=(dims=2,), ) test_rrule(Distances.colwise, dist, type1(rand(3, 2)), type1(rand(3, 2))) end From 6add335f99576646e98947e75400ce15b3546a3e Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Tue, 6 Feb 2024 19:42:47 +0100 Subject: [PATCH 16/17] Remove whitespace Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/chainrules.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 82d2875e2..1f1a9a81b 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -219,7 +219,6 @@ function ChainRulesCore.rrule( return Distances.colwise(d, x, y), colwise_pullback end - ## Reverse Rules for matrix wrappers function ChainRulesCore.rrule(::Type{<:ColVecs}, X::AbstractMatrix) From 16a3d029c6a908a8374b302928274ba1ef6fc067 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Tue, 6 Feb 2024 23:33:46 +0100 Subject: [PATCH 17/17] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index c46c5c334..16bd8514d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "KernelFunctions" uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392" -version = "0.10.60" +version = "0.10.61" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"