Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/jit inside jit #335

Merged
merged 5 commits into main from fix/jit-inside-jit
Mar 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,16 @@
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]
html_theme_options = {
"announcement":
"In 0.4.1, the jit argument in solvers will be "
"removed. Please jit the solvers explicitly.",
"repository_url": "https://github.com/ott-jax/ott",
"repository_branch": "main",
"path_to_docs": "docs/",
"use_repository_button": True,
"use_fullscreen_button": False,
"logo_only": True,
"pygment_light_style": "tango",
"pygment_dark_style": "monokai",
"launch_buttons": {
"colab_url": "https://colab.research.google.com",
"binderhub_url": "https://mybinder.org",
Expand Down
2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ Packages
between GMMs, or computing differentiable sort and quantile operations
:cite:`cuturi:19`.
- :mod:`ott.math` holds low-level mathematical primitives.
- :mod:`ott.utils` provides misc helper functions
- :mod:`ott.utils` provides misc helper functions.

.. toctree::
:maxdepth: 1
Expand Down
99 changes: 37 additions & 62 deletions docs/tutorials/notebooks/basic_ot_between_datasets.ipynb

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions src/ott/solvers/linear/continuous_barycenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,7 @@ def __call__( # noqa: D102
rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0)
) -> FreeBarycenterState:
# TODO(michalk8): no reason for iterations to be outside this class
run_fn = jax.jit(iterations, static_argnums=1) if self.jit else iterations
return run_fn(self, bar_size, bar_prob, x_init, rng)
return iterations(self, bar_size, bar_prob, x_init, rng)

def init_state(
self,
Expand Down
6 changes: 1 addition & 5 deletions src/ott/solvers/linear/sinkhorn.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,6 @@ class Sinkhorn:
gradients have been stopped. This is useful when carrying out first order
differentiation, and is only valid (as with ``implicit_differentiation``)
when the algorithm has converged with a low tolerance.
jit: Whether to jit the iteration loop.
initializer: how to compute the initial potentials/scalings.
progress_fn: callback function which gets called during the Sinkhorn
iterations, so the user can display the error at each iteration,
Expand All @@ -705,7 +704,6 @@ def __init__(
parallel_dual_updates: bool = False,
recenter_potentials: bool = False,
use_danskin: Optional[bool] = None,
jit: bool = True,
implicit_diff: Optional[implicit_lib.ImplicitDiff
] = implicit_lib.ImplicitDiff(), # noqa: B008
initializer: Union[Literal["default", "gaussian", "sorting", "subsample"],
Expand All @@ -721,7 +719,6 @@ def __init__(
self._norm_error = norm_error
self.anderson = anderson
self.implicit_diff = implicit_diff
self.jit = jit

if momentum is not None:
self.momentum = acceleration.Momentum(
Expand Down Expand Up @@ -781,8 +778,7 @@ def __call__(
init_dual_a, init_dual_b = initializer(
ot_prob, *init, lse_mode=self.lse_mode, rng=rng
)
run_fn = jax.jit(run) if self.jit else run
return run_fn(ot_prob, self, (init_dual_a, init_dual_b))
return run(ot_prob, self, (init_dual_a, init_dual_b))

def lse_step(
self, ot_prob: linear_problem.LinearProblem, state: SinkhornState,
Expand Down
3 changes: 1 addition & 2 deletions src/ott/solvers/linear/sinkhorn_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,8 +338,7 @@ def __call__(
assert ot_prob.is_balanced, "Unbalanced case is not implemented."
initializer = self.create_initializer(ot_prob)
init = initializer(ot_prob, *init, rng=rng, **kwargs)
run_fn = jax.jit(run) if self.jit else run
return run_fn(ot_prob, self, init)
return run(ot_prob, self, init)

def _lr_costs(
self,
Expand Down
3 changes: 1 addition & 2 deletions src/ott/solvers/quadratic/gromov_wasserstein.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,7 @@ def __call__(
initializer = self.create_initializer(prob)
init = initializer(prob, epsilon=self.epsilon, rng=rng1, **kwargs)

run_fn = jax.jit(iterations) if self.jit else iterations
out = run_fn(self, prob, init, rng2)
out = iterations(self, prob, init, rng2)
# TODO(lpapaxanthos): remove stop_gradient when using backprop
if self.is_low_rank:
linearization = prob.update_lr_linearization(
Expand Down
6 changes: 1 addition & 5 deletions src/ott/solvers/quadratic/gw_barycenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ class GromovWassersteinBarycenter(was_solver.WassersteinSolver):
min_iterations: Minimum number of iterations.
max_iterations: Maximum number of outermost iterations.
threshold: Convergence threshold.
jit: Whether to jit the iteration loop.
store_inner_errors: Whether to store the errors of the GW solver, as well
as its linear solver, at each iteration for each measure.
quad_solver: The GW solver.
Expand All @@ -78,7 +77,6 @@ def __init__(
min_iterations: int = 5,
max_iterations: int = 50,
threshold: float = 1e-3,
jit: bool = True,
store_inner_errors: bool = False,
quad_solver: Optional[gromov_wasserstein.GromovWasserstein] = None,
# TODO(michalk8): maintain the API compatibility with `was_solver`
Expand All @@ -93,7 +91,6 @@ def __init__(
max_iterations=max_iterations,
threshold=threshold,
store_inner_errors=store_inner_errors,
jit=jit,
)
if quad_solver is None:
kwargs["epsilon"] = epsilon
Expand All @@ -118,8 +115,7 @@ def __call__(
The solution.
"""
state = self.init_state(problem, bar_size, **kwargs)
run_fn = jax.jit(iterations) if self.jit else iterations
state = run_fn(self, problem, state)
state = iterations(self, problem, state)
return self.output_from_state(state)

def init_state(
Expand Down
3 changes: 0 additions & 3 deletions src/ott/solvers/was_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def __init__(
min_iterations: int = 5,
max_iterations: int = 50,
threshold: float = 1e-3,
jit: bool = True,
store_inner_errors: bool = False,
**kwargs: Any,
):
Expand Down Expand Up @@ -73,7 +72,6 @@ def __init__(
self.min_iterations = min_iterations
self.max_iterations = max_iterations
self.threshold = threshold
self.jit = jit
self.store_inner_errors = store_inner_errors
self._kwargs = kwargs

Expand All @@ -87,7 +85,6 @@ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102
"min_iterations": self.min_iterations,
"max_iterations": self.max_iterations,
"rank": self.rank,
"jit": self.jit,
"store_inner_errors": self.store_inner_errors,
**self._kwargs
})
Expand Down
2 changes: 1 addition & 1 deletion tests/geometry/graph_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def laplacian(G: jnp.ndarray) -> jnp.ndarray:
def test_graph_sinkhorn(self, rng: jax.random.PRNGKeyArray, jit: bool):

def callback(geom: geometry.Geometry) -> sinkhorn.SinkhornOutput:
solver = sinkhorn.Sinkhorn(lse_mode=False, jit=False)
solver = sinkhorn.Sinkhorn(lse_mode=False)
problem = linear_problem.LinearProblem(geom)
return solver(problem)

Expand Down
5 changes: 1 addition & 4 deletions tests/solvers/linear/sinkhorn_diff_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,7 @@ def test_autograd_sinkhorn(
def reg_ot(a: jnp.ndarray, b: jnp.ndarray) -> float:
geom = pointcloud.PointCloud(x, y, epsilon=1e-1)
prob = linear_problem.LinearProblem(geom, a=a, b=b)
# TODO: fails with `jit=True`, investigate
solver = sinkhorn.Sinkhorn(lse_mode=lse_mode, jit=False)
solver = sinkhorn.Sinkhorn(lse_mode=lse_mode)
return solver(prob).reg_ot_cost

reg_ot_and_grad = jax.jit(jax.grad(reg_ot))
Expand Down Expand Up @@ -275,8 +274,6 @@ def loss_fn(x: jnp.ndarray,
lse_mode=lse_mode,
min_iterations=min_iter,
max_iterations=max_iter,
# TODO(cuturi): figure out why implicit diff breaks when `jit=True`
jit=False,
implicit_diff=implicit_diff,
)
out = solver(prob)
Expand Down
3 changes: 1 addition & 2 deletions tests/solvers/linear/sinkhorn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,8 +465,7 @@ def test_sinkhorn_online_memory_jit(self, batch_size: int):
y = jax.random.uniform(rngs[1], (m, 2))
geom = pointcloud.PointCloud(x, y, batch_size=batch_size, epsilon=1)
problem = linear_problem.LinearProblem(geom)
solver = sinkhorn.Sinkhorn(jit=False)
solver = jax.jit(solver)
solver = jax.jit(sinkhorn.Sinkhorn())

out = solver(problem)
assert out.converged
Expand Down