sink div for LR #568
Conversation
Checking if this seems ok; if it does, I will write some tests to check non-negativity, in accordance with the paper.
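For reference, a minimal sketch of what such a non-negativity test could look like; routing `rank` through `sinkhorn_kwargs` (and the exact kwargs) is an assumption about the API this patch exposes, not a confirmed signature:

```python
import jax

from ott.geometry import pointcloud
from ott.tools import sinkhorn_divergence


def test_lr_divergence_nonnegative():
  key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
  x = jax.random.normal(key_x, (13, 2))
  y = jax.random.normal(key_y, (15, 2)) + 0.5

  for rank in (2, 3):
    out = sinkhorn_divergence.sinkhorn_divergence(
        pointcloud.PointCloud, x, y,
        epsilon=1e-1,
        sinkhorn_kwargs={"rank": rank},  # assumed way of selecting the LR solver
    )
    # Non-negativity of the (LR) Sinkhorn divergence, up to numerical tolerance.
    assert out.divergence >= -1e-5
```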
Codecov Report
Attention: Patch coverage is …

Additional details and impacted files:

@@            Coverage Diff             @@
##             main     #568      +/-   ##
==========================================
+ Coverage   88.36%   88.41%   +0.05%
==========================================
  Files          72       72
  Lines        7699     7716      +17
  Branches     1102     1107       +5
==========================================
+ Hits         6803     6822      +19
+ Misses        745      743       -2
  Partials      151      151
@marcocuturi I've modified one test, but differentiating w.r.t. the divergence is not yet working (problems when unrolling):

src/ott/math/fixed_point_loop.py:226: in fixpoint_iter_bwd
    _, g_state, g_constants = jax.lax.while_loop(
src/ott/math/fixed_point_loop.py:209: in unrolled_body_fn
    _, pullback = jax.vjp(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

iteration = Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=3/0)>
constants = (<ott.problems.linear.linear_problem.LinearProblem object at 0x3436e2d10>, <ott.solvers.linear.sinkhorn_lr.LRSinkhorn object at 0x3436e2d70>)
state = LRSinkhornState(q=Traced<ShapedArray(float32[13,2])>with<JVPTrace(level=5/0)> with
    primal = Traced<ShapedArray(float...), None)
    recipe = LambdaBinding(), crossed_threshold=Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=3/0)>)

    def unrolled_body_fn_no_errors(iteration, constants, state):
      compute_error_flags = jnp.zeros((inner_iterations,), dtype=bool)

      def one_iteration(iteration_state, compute_error):
        iteration, state = iteration_state
        state = body_fn(iteration, constants, state, compute_error)
        iteration += 1
        return (iteration, state), None

>     iteration_state, _ = jax.lax.scan(
          one_iteration, (iteration, state), compute_error_flags
      )
E   jax._src.interpreters.ad.CustomVJPException: Detected differentiation of a custom_vjp function with respect to a closed-over value. That isn't supported because the custom VJP rule only specifies how to differentiate the custom_vjp function with respect to explicit input parameters. Try pa…
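For context, the exception above is generic JAX behaviour and can be reproduced without any ott code: a `custom_vjp` function cannot be differentiated with respect to a value it closes over, only with respect to its explicit arguments. A minimal sketch (the names `make_f` and `g` are illustrative only):

```python
import jax
import jax.numpy as jnp


def make_f(c):
  # `c` is closed over by the custom_vjp function instead of being an input.
  @jax.custom_vjp
  def f(x):
    return c * jnp.sum(x ** 2)

  def f_fwd(x):
    return c * jnp.sum(x ** 2), x  # primal output, residual

  def f_bwd(x, ct):
    return (2.0 * c * ct * x,)  # cotangent only for the explicit input x

  f.defvjp(f_fwd, f_bwd)
  return f


c = jnp.float32(3.0)
x = jnp.ones(4)

jax.grad(lambda x: make_f(c)(x))(x)    # fine: grad w.r.t. an explicit input
# jax.grad(lambda c: make_f(c)(x))(c)  # raises CustomVJPException, as above


# The usual fix: make the value an explicit argument, so the bwd rule can
# return a cotangent for it.
@jax.custom_vjp
def g(c, x):
  return c * jnp.sum(x ** 2)


def g_fwd(c, x):
  return c * jnp.sum(x ** 2), (c, x)


def g_bwd(res, ct):
  c, x = res
  return ct * jnp.sum(x ** 2), 2.0 * c * ct * x


g.defvjp(g_fwd, g_bwd)
jax.grad(g, argnums=0)(c, x)  # now differentiable w.r.t. c
```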
Thanks! The bug came from computing […]. I think the syntax above computes the entire differentiation graph for all elements outputted in a […]. Computing the gradient of a […]. As a result, I have reinstated the very primitive differentiability test.
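For reference, a minimal sketch of what such a primitive differentiability test could look like, taking the gradient of the scalar `divergence` field only rather than pulling back through the whole output; as above, selecting the LR solver via `rank` in `sinkhorn_kwargs` is an assumption about this patch's API:

```python
import jax
import jax.numpy as jnp

from ott.geometry import pointcloud
from ott.tools import sinkhorn_divergence


def lr_divergence(x, y):
  out = sinkhorn_divergence.sinkhorn_divergence(
      pointcloud.PointCloud, x, y,
      epsilon=1e-1,
      sinkhorn_kwargs={"rank": 2},  # assumed way of selecting the LR solver
  )
  return out.divergence  # scalar: differentiate only this


key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (13, 2))
y = jax.random.normal(key_y, (15, 2)) + 0.5

# Primitive check: the gradient w.r.t. the input points exists and is finite.
grad_x = jax.grad(lr_divergence)(x, y)
assert grad_x.shape == x.shape
assert jnp.all(jnp.isfinite(grad_x))
```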
A simple patch for the issue described in #485, for when a user wishes to compute a Sinkhorn divergence using LR Sinkhorn as a primitive.
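For context, this is roughly what using LR Sinkhorn as the primitive amounts to when the divergence is composed by hand from three solver runs; a minimal sketch, where the helper names are illustrative and exposing the cost as `reg_ot_cost` on the LR output is an assumption:

```python
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn_lr


def lr_ot_cost(x, y, rank=2, epsilon=1e-1):
  # Regularized OT cost between two point clouds, solved with LR Sinkhorn.
  geom = pointcloud.PointCloud(x, y, epsilon=epsilon)
  out = sinkhorn_lr.LRSinkhorn(rank=rank)(linear_problem.LinearProblem(geom))
  return out.reg_ot_cost  # assumed field name on the LR output


def lr_sinkhorn_divergence(x, y, **kwargs):
  # Sinkhorn divergence composed from three transport problems:
  # OT(x, y) - 0.5 * OT(x, x) - 0.5 * OT(y, y).
  return (lr_ot_cost(x, y, **kwargs)
          - 0.5 * lr_ot_cost(x, x, **kwargs)
          - 0.5 * lr_ot_cost(y, y, **kwargs))
```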