[WIP] Fix AD issues with various kernels #154

Merged Sep 8, 2020 · 24 commits
Changes from 1 commit
24 changes: 24 additions & 0 deletions src/zygote_adjoints.jl
@@ -94,3 +94,27 @@ end
end
return evaluate(dist, a, b), back
end


# FIXME
function Distances.pairwise(
dist::SqMahalanobis,
a::AbstractMatrix,
b::AbstractMatrix;
dims::Union{Nothing,Integer}=nothing
)
    function back(Δ::AbstractMatrix)
        B_B_t = dist.qmat + transpose(dist.qmat)
        a_b = map(
            x -> (first(last(x)) - last(last(x))) * first(x),
            zip(
                Δ,
                Iterators.product(eachslice(a, dims=dims), eachslice(b, dims=dims))
            )
        )
        δa = reduce(hcat, sum(map(x -> B_B_t * x, a_b), dims=1))
        δB = sum(map(x -> x * transpose(x), a_b))
Member:
I would assume it should be possible to vectorize this code? What's the mathematical formula that you use here?

@sharanry (Contributor Author), Aug 21, 2020:
These are the same equations you mentioned earlier:

d((x-y)'*Q*(x-y))/dx = (Q + Q') * (x - y)
d((x-y)'*Q*(x-y))/dy = -(Q + Q') * (x - y)
d((x-y)'*Q*(x-y))/dQ = (x - y) * (x - y)'

The code applies them to all pairwise combinations at once using map, and then sums the per-pair differentials to obtain δB and the other adjoints.
Please note that the current implementation is not correct; I am still debugging it (it only partially matches the intended result). If you happen to find any obvious mistakes, please let me know. The trouble I am having is in reducing the results of the individual pairwise pullbacks to the final pullback; the way I am summing them is probably wrong.
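The three per-pair identities above can be sanity-checked numerically with central differences. A minimal sketch with made-up 2-element inputs (not part of the PR diff; all names are illustrative):

```julia
using LinearAlgebra

# f(x, y, Q) = (x - y)' * Q * (x - y), the quantity SqMahalanobis evaluates per pair
f(x, y, Q) = dot(x - y, Q * (x - y))

x = [1.0, 2.0]; y = [0.5, -1.0]
Q = [2.0 0.3; 0.1 1.5]      # deliberately non-symmetric to exercise Q + Q'
d = x - y
h = 1e-6

# central differences for ∂f/∂x; should match (Q + Q') * (x - y)
e(i) = [j == i ? 1.0 : 0.0 for j in 1:2]
num_dx = [(f(x + h * e(i), y, Q) - f(x - h * e(i), y, Q)) / (2h) for i in 1:2]
@assert isapprox(num_dx, (Q + Q') * d; atol=1e-5)

# central differences for ∂f/∂Q; should match the outer product (x - y) * (x - y)'
E(k, l) = [i == k && j == l ? 1.0 : 0.0 for i in 1:2, j in 1:2]
num_dQ = [(f(x, y, Q + h * E(k, l)) - f(x, y, Q - h * E(k, l))) / (2h) for k in 1:2, l in 1:2]
@assert isapprox(num_dQ, d * d'; atol=1e-5)
```

Note that ∂f/∂Q is the outer product (x - y) * (x - y)', which is consistent with the `x * transpose(x)` term in the diff above.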

@sharanry (Contributor Author), Aug 21, 2020:
```julia
julia> using Distances, Random;

julia> rng = MersenneTwister(123);

julia> M1, M2 = rand(rng, 2,3), rand(rng, 2,3);

julia> dist = SqMahalanobis(rand(rng, 2,2))
SqMahalanobis{Float64}([0.8654121434083455 0.2856979003853177; 0.617491887982287 0.46384720826189474])

julia> pairwise(dist, M1, M2; dims=2)
3×3 Array{Float64,2}:
  0.371673   0.856348  0.742803
  0.0233992  0.274278  0.276694
 -0.036568   0.118487  0.0748149

julia> map(x -> evaluate(dist, first(x), last(x)), Iterators.product(eachslice(M1, dims=2), eachslice(M2, dims=2)))
3×3 Array{Float64,2}:
 0.541253   0.912421  0.673273
 0.0886328  0.285181  0.192394
 0.0868399  0.166227  0.0616321
```

@devmotion isn't this wrong, or have I done something silly? The two agree in the Euclidean case. I suspect this is the root of the problem.

Member:
It should work if dist.qmat is positive definite: JuliaStats/Distances.jl#174
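One way to see why positive definiteness (which implies symmetry) matters here: the pairwise path computes the squared Mahalanobis distance via the expansion x'Qx + y'Qy - 2x'Qy, which agrees with the direct quadratic form only when Q is symmetric. A self-contained sketch with toy numbers (my own reading of the linked issue, not code from Distances.jl):

```julia
using LinearAlgebra

x = [1.0, 2.0]; y = [0.3, -0.7]
direct(Q)   = dot(x - y, Q * (x - y))                            # per-pair evaluate
expanded(Q) = dot(x, Q * x) + dot(y, Q * y) - 2 * dot(x, Q * y)  # pairwise-style expansion

Qasym = [1.0 0.9; 0.1 1.0]     # random matrices are almost never symmetric
@assert !isapprox(direct(Qasym), expanded(Qasym))

Qsym = Qasym + Qasym'          # symmetrize; the two formulas now agree
@assert isapprox(direct(Qsym), expanded(Qsym))
```

This matches the REPL session above, where `dist.qmat = rand(rng, 2, 2)` is not symmetric and pairwise and evaluate disagree.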

@sharanry (Contributor Author), Aug 22, 2020:
This still does not solve the differences in the computed adjoints for the covariance matrix Q. My current implementation matches the second adjoint.

```julia
julia> using Distances, LinearAlgebra, FiniteDifferences, Random

julia> FiniteDifferences.to_vec(dist::SqMahalanobis{Float64}) = vec(dist.qmat), x -> SqMahalanobis(reshape(x, size(dist.qmat)...))

julia> rng = MersenneTwister(123);

julia> M1, M2 = rand(rng, 3, 1), rand(rng, 3, 1)
([0.7684476751965699; 0.940515000715187; 0.6739586945680673], [0.3954531123351086; 0.3132439558075186; 0.6625548164736534])

julia> Q = Matrix(Cholesky(rand(rng, 3, 3), 'U', 0))
3×3 Array{Float64,2}:
 0.343422   0.0638007  0.507151
 0.0638007  0.0386393  0.19528
 0.507151   0.19528    1.21186

julia> isposdef(Q)
true

julia> dist = SqMahalanobis(Q);

julia> fdm = FiniteDifferences.Central(5, 1);

julia> j′vp(fdm, pairwise, ones(1,1), dist, M1, M2)[1].qmat #A
3×3 Array{Float64,2}:
 0.139125  0.365187  -0.238366
 0.102751  0.393469  -0.404876
 0.246873  0.419183   0.000130048

julia> j′vp(fdm, evaluate, 1, dist, M1[:, 1], M2[:, 1])[1].qmat #B
3×3 Array{Float64,2}:
 0.139125    0.233969    0.00425358
 0.233969    0.393469    0.00715332
 0.00425358  0.00715332  0.000130048
```

Member:
> IMO it is best if (Sq)Mahalanobis distance is actually parameterized by the decomposition of Q, i.e., the upper or lower triangular matrix, which is not constrained.

Yes, that would be the most natural way to ensure that it is always positive semi-definite (if the diagonal is non-negative) and that optimization is performed in the correct space. So I guess users would want to use this parameterization even if it is not enforced by KernelFunctions and not directly supported by SqMahalanobis, by using something like

```julia
function mykernel(L)
    idxs = diagind(L)
    @inbounds for i in idxs
        L[i] = softplus(L[i])
    end
    return MahalanobisKernel(Array(L * L'))
end
```

Of course, it would be nice if (Sq)Mahalanobis supported specifying e.g. a Cholesky decomposition or a PDMat directly (it could even be used to simplify the computations, since x'*Q*x = (L'*x)'*(L'*x) in this case), but can't we work around this by checking gradients of the mykernel setup instead of computing Q -> MahalanobisKernel(Q) directly? That's at least how we do it in DistributionsAD, e.g. in https://github.com/TuringLang/DistributionsAD.jl/blob/a96b159ab25aab67d1a2076726e8b9c392eb6fc7/test/ad/distributions.jl#L18-L34.
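For reference, the identity x'*Q*x = (L'*x)'*(L'*x) for Q = L*L' is easy to check directly (toy numbers, just to illustrate the algebra):

```julia
using LinearAlgebra

L = [1.0 0.0; 0.4 0.8]   # lower-triangular factor, so Q = L * L' is positive semi-definite
Q = L * L'
x = [0.7, -1.2]

# x' * Q * x = x' * L * L' * x = (L' * x)' * (L' * x)
@assert isapprox(dot(x, Q * x), dot(L' * x, L' * x))
```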

@sharanry (Contributor Author):

> but can't we work around this by checking gradients of the mykernel setup instead of computing Q -> MahalanobisKernel(Q) directly?

Yeah, that should work. Will try that out.

Regarding the issue with the pairwise implementation, which messes up the FiniteDifferences results: do you suggest I override the implementation for the time being?

@devmotion (Member), Aug 24, 2020:
If you test the suggested parameterization, the implementation of pairwise shouldn't matter (since we do not test the intermediate step that might be affected by it).

@sharanry (Contributor Author):
True. Could we also change the parametrization on our side, i.e., the way it is stored in the struct? We could continue to allow initialization with a full matrix. This should allow for seamless AD regardless of how the user decides to initialize it.

Member:
I'm not sure if we want to do that; I think this deserves some discussion first (and then possibly a separate PR). Ideally, Distances would just support arbitrary matrices and contain optimized implementations for specific array types. We just forward P to SqMahalanobis, so ideally we wouldn't perform any transformations or computations. I'm also a bit worried that focusing on a specific parameterization might make it difficult for users who would like to use a different one (but still not a dense matrix) or might lead to confusing behaviour.

        return (qmat=δB,), δa, -δa
Member:
There is some discrepancy between the simple case above and this pullback: intuitively, from the simple case above I would assume that δB = sum_{i,j} (a_i - b_j) * (a_i - b_j)^T * Δ_{i,j}. However, here you compute δB = sum_{i,j} (a_i - b_j) * (a_i - b_j)^T * Δ_{i,j}^2. Probably one of them is incorrect (table 7 in https://notendur.hi.is/jonasson/greinar/blas-rmd.pdf indicates that the pairwise one is incorrect). Can we add the derivation of the adjoints according to https://www.juliadiff.org/ChainRulesCore.jl/dev/arrays.html as docstrings or comments, or maybe even have a separate PR for the Mahalanobis fixes?
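The two candidate formulas can be compared against finite differences of the scalar loss f(Q) = sum_{i,j} Δ_{i,j} * d_{i,j}' * Q * d_{i,j} with d_{i,j} = a_i - b_j. A sketch with made-up 2×2 inputs (my own toy numbers, not test code from the PR) supports the Δ-linear version:

```julia
using LinearAlgebra

a = [1.0 0.2; 0.5 1.3]; b = [0.1 0.9; 0.4 0.2]   # columns are the points a_i, b_j
Δ = [1.0 2.0; 3.0 4.0]                            # incoming cotangent
Q = [2.0 0.3; 0.3 1.5]
d(i, j) = a[:, i] - b[:, j]

f(Q) = sum(Δ[i, j] * dot(d(i, j), Q * d(i, j)) for i in 1:2, j in 1:2)

δQ_linear  = sum(Δ[i, j]   * d(i, j) * d(i, j)' for i in 1:2, j in 1:2)
δQ_squared = sum(Δ[i, j]^2 * d(i, j) * d(i, j)' for i in 1:2, j in 1:2)

# numeric ∂f/∂Q via central differences, entry by entry
h = 1e-6
E(k, l) = [p == k && q == l ? 1.0 : 0.0 for p in 1:2, q in 1:2]
num = [(f(Q + h * E(k, l)) - f(Q - h * E(k, l))) / (2h) for k in 1:2, l in 1:2]
@assert isapprox(num, δQ_linear; atol=1e-5)      # matches the Δ-linear formula
@assert !isapprox(num, δQ_squared; atol=1e-5)    # not the Δ² one
```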

@sharanry (Contributor Author):

Thanks for pointing this out. I think a separate PR for the Mahalanobis fixes makes more sense.

    end
    return Distances.pairwise(dist, a, b, dims=dims), back
end