
sink div for LR #568

Merged: 6 commits merged into main from sinkdiv on Aug 9, 2024

Conversation

marcocuturi (Contributor)

Simple patch for the issue described in #485, for when a user wishes to compute a Sinkhorn divergence using LR Sinkhorn as a primitive.
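
For context, a minimal sketch of what this enables, written by composing existing OTT components (PointCloud, LinearProblem, LRSinkhorn) directly rather than through the API added in this PR; the lr_sinkhorn_divergence helper, the rank value, and the debiased composition below are illustrative only:

# Hedged sketch (not necessarily the API added in this PR): computing a
# Sinkhorn divergence with LRSinkhorn as the primitive, by composing the
# three transport terms manually. The rank value is illustrative.
import jax.numpy as jnp

from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn_lr


def lr_sinkhorn_divergence(x: jnp.ndarray, y: jnp.ndarray, rank: int = 4):
  solver = sinkhorn_lr.LRSinkhorn(rank=rank)

  def reg_cost(u: jnp.ndarray, v: jnp.ndarray) -> jnp.ndarray:
    # Low-rank Sinkhorn on the point-cloud geometry between u and v.
    geom = pointcloud.PointCloud(u, v)
    prob = linear_problem.LinearProblem(geom)
    return solver(prob).reg_ot_cost

  # Debiased composition: D(x, y) = OT(x, y) - (OT(x, x) + OT(y, y)) / 2
  return reg_cost(x, y) - 0.5 * (reg_cost(x, x) + reg_cost(y, y))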

marcocuturi marked this pull request as draft on July 30, 2024 13:39
marcocuturi (Contributor, Author)

Checking whether this looks OK; if it does, I will write some tests to check non-negativity, in accordance with the paper. A rough idea of such a check is sketched below.
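
A hedged sketch of such a non-negativity check, reusing the illustrative lr_sinkhorn_divergence helper from the previous snippet; shapes, seeds, and tolerances are arbitrary:

import jax

rng = jax.random.PRNGKey(0)
kx, ky = jax.random.split(rng)
x = jax.random.normal(kx, (13, 2))
y = jax.random.normal(ky, (15, 2)) + 1.0

# The divergence should be non-negative (up to numerical tolerance) and
# close to zero when the two point clouds coincide.
assert lr_sinkhorn_divergence(x, y, rank=3) >= -1e-6
assert abs(lr_sinkhorn_divergence(x, x, rank=3)) <= 1e-6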


codecov bot commented Jul 30, 2024

Codecov Report

Attention: Patch coverage is 90.90909% with 3 lines in your changes missing coverage. Please review.

Project coverage is 88.41%. Comparing base (478f034) to head (b66005d).
Report is 26 commits behind head on main.

Files with missing lines              | Patch % | Lines
------------------------------------- | ------- | --------------------------
src/ott/tools/sinkhorn_divergence.py  | 90.00%  | 2 Missing and 1 partial ⚠️
Additional details and impacted files


@@            Coverage Diff             @@
##             main     #568      +/-   ##
==========================================
+ Coverage   88.36%   88.41%   +0.05%     
==========================================
  Files          72       72              
  Lines        7699     7716      +17     
  Branches     1102     1107       +5     
==========================================
+ Hits         6803     6822      +19     
+ Misses        745      743       -2     
  Partials      151      151              
Files with missing lines               | Coverage Δ
-------------------------------------- | ---------------------------
src/ott/math/utils.py                  | 91.30% <100.00%> (+0.19%) ⬆️
src/ott/solvers/linear/sinkhorn_lr.py  | 98.98% <100.00%> (+0.32%) ⬆️
src/ott/tools/progot.py                | 29.07% <ø> (ø)
src/ott/tools/sinkhorn_divergence.py   | 91.66% <90.00%> (+0.62%) ⬆️

... and 1 file with indirect coverage changes

michalk8 added the enhancement (New feature or request) label on Aug 2, 2024
michalk8 marked this pull request as ready for review on August 8, 2024 16:29
michalk8 (Collaborator) commented Aug 8, 2024

@marcocuturi I've modified one test, but differentiating w.r.t. the divergence is not yet working (problems when unrolling):

src/ott/math/fixed_point_loop.py:226: in fixpoint_iter_bwd
    _, g_state, g_constants = jax.lax.while_loop(
src/ott/math/fixed_point_loop.py:209: in unrolled_body_fn
    _, pullback = jax.vjp(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

iteration = Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=3/0)>
constants = (<ott.problems.linear.linear_problem.LinearProblem object at 0x3436e2d10>, <ott.solvers.linear.sinkhorn_lr.LRSinkhorn object at 0x3436e2d70>)
state = LRSinkhornState(q=Traced<ShapedArray(float32[13,2])>with<JVPTrace(level=5/0)> with
  primal = Traced<ShapedArray(float...), None)
    recipe = LambdaBinding(), crossed_threshold=Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=3/0)>)

    def unrolled_body_fn_no_errors(iteration, constants, state):
      compute_error_flags = jnp.zeros((inner_iterations,), dtype=bool)
    
      def one_iteration(iteration_state, compute_error):
        iteration, state = iteration_state
        state = body_fn(iteration, constants, state, compute_error)
        iteration += 1
        return (iteration, state), None
    
>     iteration_state, _ = jax.lax.scan(
          one_iteration, (iteration, state), compute_error_flags
      )
E     jax._src.interpreters.ad.CustomVJPException: Detected differentiation of a custom_vjp function with respect to a closed-over value. That isn't supported because the custom VJP rule only specifies how to differentiate the custom_vjp function with respect to explicit input parameters. Try pa


marcocuturi (Contributor, Author) commented Aug 9, 2024

Thanks!

The bug came from computing jax.grad(lambda x: div(x).divergence) rather than defining a proper closure, i.e. a function that outputs the desired divergence directly, before taking the gradient.

I think the syntax above builds the entire differentiation graph for all elements output in an LRSinkhornOutput (notably q, r, g) and only then gathers the divergence gradient. If that is the case, it was expected that we would run into trouble, since we have not properly defined differentiation rules for the solutions of LRSinkhorn.

Computing the gradient of LRSinkhornOutput.reg_ot_cost w.r.t. input locations x was already working before this PR, so defining a proper function that outputs the divergence works (see the sketch after this comment).

As a result, I have reinstated the very primitive differentiability test.
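
A minimal sketch of the pattern described above, reusing the illustrative lr_sinkhorn_divergence helper and the x, y arrays from the earlier snippets; it differentiates a scalar-valued function rather than a field of the full solver output:

import jax

# Works: the function returns the divergence as a scalar, so JAX only has
# to differentiate through the reg_ot_cost values, whose gradients were
# already supported before this PR (via unrolling of the fixed-point loop).
grad_fn = jax.grad(lambda x_: lr_sinkhorn_divergence(x_, y))
g = grad_fn(x)

# Problematic pattern (schematic): jax.grad(lambda x_: div(x_).divergence),
# where div returns a full output object; this drags every field of that
# output (q, r, g, ...) into the differentiation graph and triggers the
# custom_vjp error shown in the traceback above.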

marcocuturi merged commit 826b67f into main on Aug 9, 2024
12 checks passed
marcocuturi deleted the sinkdiv branch on August 9, 2024 08:21