From 4fad710d4a6fef382c3d1809d7319f3c5527ef75 Mon Sep 17 00:00:00 2001 From: michalk8 <46717574+michalk8@users.noreply.github.com> Date: Fri, 3 Feb 2023 13:39:19 +0100 Subject: [PATCH] Re-add `jit` argument (#221) * Re-add `jit` argument * [ci skip] Add `environment.yml` for `binder` * Fix missing `static_argnums` * Adjust test to have `jit=False` * Fix tests * Fix `typing_extensions` in tests * Fix linter --- .pre-commit-config.yaml | 4 +- src/ott/problems/linear/barycenter_problem.py | 6 +- .../problems/quadratic/quadratic_problem.py | 1 - .../solvers/linear/continuous_barycenter.py | 3 +- src/ott/solvers/linear/sinkhorn.py | 6 +- src/ott/solvers/linear/sinkhorn_lr.py | 3 +- .../solvers/quadratic/gromov_wasserstein.py | 3 +- src/ott/solvers/quadratic/gw_barycenter.py | 11 ++- src/ott/solvers/was_solver.py | 3 + tests/geometry/graph_test.py | 3 +- tests/solvers/linear/sinkhorn_diff_test.py | 7 +- tests/solvers/nn/neuraldual_test.py | 3 +- .../solvers/quadratic/fgw_barycenter_test.py | 79 ------------------- tests/solvers/quadratic/gw_barycenter_test.py | 64 +++++++++++++++ tests/tools/k_means_test.py | 3 +- 15 files changed, 99 insertions(+), 100 deletions(-) delete mode 100644 tests/solvers/quadratic/fgw_barycenter_test.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e33d4a6fb..d4078928e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -48,7 +48,7 @@ repos: - id: trailing-whitespace - id: check-case-conflict - repo: https://github.com/myint/autoflake - rev: v2.0.0 + rev: v2.0.1 hooks: - id: autoflake args: @@ -71,4 +71,4 @@ repos: rev: v3.3.1 hooks: - id: pyupgrade - args: [--py3-plus, --py37-plus, --keep-runtime-typing] + args: [--py38-plus, --keep-runtime-typing] diff --git a/src/ott/problems/linear/barycenter_problem.py b/src/ott/problems/linear/barycenter_problem.py index c2f28860f..bcfbc30cd 100644 --- a/src/ott/problems/linear/barycenter_problem.py +++ b/src/ott/problems/linear/barycenter_problem.py @@ -109,10 +109,10 @@ def segmented_y_b(self) -> Tuple[jnp.ndarray, jnp.ndarray]: return self._add_slice_for_debiased(y, b) return y, b + @staticmethod def _add_slice_for_debiased( - self, y: jnp.ndarray, b: jnp.ndarray - ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: - y, b = self._y, self._b + y: jnp.ndarray, b: jnp.ndarray + ) -> Tuple[jnp.ndarray, jnp.ndarray]: _, n, ndim = y.shape # (num_measures, max_measure_size, ndim) # yapf: disable y = jnp.concatenate((y, jnp.zeros((1, n, ndim))), axis=0) diff --git a/src/ott/problems/quadratic/quadratic_problem.py b/src/ott/problems/quadratic/quadratic_problem.py index 67200893f..ea2a7eb56 100644 --- a/src/ott/problems/quadratic/quadratic_problem.py +++ b/src/ott/problems/quadratic/quadratic_problem.py @@ -108,7 +108,6 @@ def __init__( ranks: Union[int, Tuple[int, ...]] = -1, tolerances: Union[float, Tuple[float, ...]] = 1e-2, ): - assert fused_penalty > 0, fused_penalty self._geom_xx = geom_xx._set_scale_cost(scale_cost) self._geom_yy = geom_yy._set_scale_cost(scale_cost) self._geom_xy = ( diff --git a/src/ott/solvers/linear/continuous_barycenter.py b/src/ott/solvers/linear/continuous_barycenter.py index 45553fcce..556192ad5 100644 --- a/src/ott/solvers/linear/continuous_barycenter.py +++ b/src/ott/solvers/linear/continuous_barycenter.py @@ -128,7 +128,8 @@ def __call__( rng: int = 0 ) -> BarycenterState: # TODO(michalk8): no reason for iterations to be outside this class - return iterations(self, bar_size, bar_prob, x_init, rng) + run_fn = jax.jit(iterations, static_argnums=1) if self.jit else iterations + return run_fn(self, bar_size, bar_prob, x_init, rng) def init_state( self, diff --git a/src/ott/solvers/linear/sinkhorn.py b/src/ott/solvers/linear/sinkhorn.py index 44e880f68..844ae4c9f 100644 --- a/src/ott/solvers/linear/sinkhorn.py +++ b/src/ott/solvers/linear/sinkhorn.py @@ -680,6 +680,7 @@ class Sinkhorn: gradients have been stopped. This is useful when carrying out first order differentiation, and is only valid (as with ``implicit_differentiation``) when the algorithm has converged with a low tolerance. + jit: Whether to jit the iteration loop. initializer: how to compute the initial potentials/scalings. kwargs_init: keyword arguments when creating the initializer. """ @@ -697,6 +698,7 @@ def __init__( parallel_dual_updates: bool = False, recenter_potentials: bool = False, use_danskin: Optional[bool] = None, + jit: bool = True, implicit_diff: Optional[implicit_lib.ImplicitDiff ] = implicit_lib.ImplicitDiff(), # noqa: E124 initializer: Union[Literal["default", "gaussian", "sorting"], @@ -711,6 +713,7 @@ def __init__( self._norm_error = norm_error self.anderson = anderson self.implicit_diff = implicit_diff + self.jit = jit if momentum is not None: self.momentum = acceleration.Momentum( @@ -767,7 +770,8 @@ def __call__( init_dual_a, init_dual_b = initializer( ot_prob, *init, lse_mode=self.lse_mode ) - return run(ot_prob, self, (init_dual_a, init_dual_b)) + run_fn = jax.jit(run) if self.jit else run + return run_fn(ot_prob, self, (init_dual_a, init_dual_b)) def lse_step( self, ot_prob: linear_problem.LinearProblem, state: SinkhornState, diff --git a/src/ott/solvers/linear/sinkhorn_lr.py b/src/ott/solvers/linear/sinkhorn_lr.py index 12ca108b7..22e542b06 100644 --- a/src/ott/solvers/linear/sinkhorn_lr.py +++ b/src/ott/solvers/linear/sinkhorn_lr.py @@ -320,7 +320,8 @@ def __call__( assert ot_prob.is_balanced, "Unbalanced case is not implemented." initializer = self.create_initializer(ot_prob) init = initializer(ot_prob, *init, key=key, **kwargs) - return run(ot_prob, self, init) + run_fn = jax.jit(run) if self.jit else run + return run_fn(ot_prob, self, init) def _lr_costs( self, diff --git a/src/ott/solvers/quadratic/gromov_wasserstein.py b/src/ott/solvers/quadratic/gromov_wasserstein.py index d859b0b49..652d2fed3 100644 --- a/src/ott/solvers/quadratic/gromov_wasserstein.py +++ b/src/ott/solvers/quadratic/gromov_wasserstein.py @@ -223,7 +223,8 @@ def __call__( initializer = self.create_initializer(prob) init = initializer(prob, epsilon=self.epsilon, key=key1, **kwargs) - out = iterations(self, prob, init, key2) + run_fn = jax.jit(iterations) if self.jit else iterations + out = run_fn(self, prob, init, key2) # TODO(lpapaxanthos): remove stop_gradient when using backprop if self.is_low_rank: linearization = prob.update_lr_linearization( diff --git a/src/ott/solvers/quadratic/gw_barycenter.py b/src/ott/solvers/quadratic/gw_barycenter.py index 4360b3f6f..d382d0ebc 100644 --- a/src/ott/solvers/quadratic/gw_barycenter.py +++ b/src/ott/solvers/quadratic/gw_barycenter.py @@ -52,6 +52,7 @@ class GromovWassersteinBarycenter(was_solver.WassersteinSolver): min_iterations: Minimum number of iterations. max_iterations: Maximum number of outermost iterations. threshold: Convergence threshold. + jit: Whether to jit the iteration loop. store_inner_errors: Whether to store the errors of the GW solver, as well as its linear solver, at each iteration for each measure. quad_solver: The GW solver. @@ -66,6 +67,7 @@ def __init__( min_iterations: int = 5, max_iterations: int = 50, threshold: float = 1e-3, + jit: bool = True, store_inner_errors: bool = False, quad_solver: Optional[gromov_wasserstein.GromovWasserstein] = None, # TODO(michalk8): maintain the API compatibility with `was_solver` @@ -79,14 +81,16 @@ def __init__( min_iterations=min_iterations, max_iterations=max_iterations, threshold=threshold, - store_inner_errors=store_inner_errors + store_inner_errors=store_inner_errors, + jit=jit, ) - self._quad_solver = quad_solver if quad_solver is None: kwargs["epsilon"] = epsilon # TODO(michalk8): store only GW errors? kwargs["store_inner_errors"] = store_inner_errors self._quad_solver = gromov_wasserstein.GromovWasserstein(**kwargs) + else: + self._quad_solver = quad_solver def __call__( self, problem: gw_barycenter.GWBarycenterProblem, bar_size: int, @@ -103,7 +107,8 @@ def __call__( The solution. """ state = self.init_state(problem, bar_size, **kwargs) - state = iterations(solver=self, problem=problem, init_state=state) + run_fn = jax.jit(iterations) if self.jit else iterations + state = run_fn(self, problem, state) return self.output_from_state(state) def init_state( diff --git a/src/ott/solvers/was_solver.py b/src/ott/solvers/was_solver.py index 38b5fa7a7..e392f7bd5 100644 --- a/src/ott/solvers/was_solver.py +++ b/src/ott/solvers/was_solver.py @@ -40,6 +40,7 @@ def __init__( min_iterations: int = 5, max_iterations: int = 50, threshold: float = 1e-3, + jit: bool = True, store_inner_errors: bool = False, **kwargs: Any, ): @@ -73,6 +74,7 @@ def __init__( self.min_iterations = min_iterations self.max_iterations = max_iterations self.threshold = threshold + self.jit = jit self.store_inner_errors = store_inner_errors self._kwargs = kwargs @@ -86,6 +88,7 @@ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: "min_iterations": self.min_iterations, "max_iterations": self.max_iterations, "rank": self.rank, + "jit": self.jit, "store_inner_errors": self.store_inner_errors, **self._kwargs }) diff --git a/tests/geometry/graph_test.py b/tests/geometry/graph_test.py index 268f7bd74..4de0ebdb2 100644 --- a/tests/geometry/graph_test.py +++ b/tests/geometry/graph_test.py @@ -1,11 +1,10 @@ import time -from typing import Any, Callable, Optional, Tuple, Union +from typing import Any, Callable, Literal, Optional, Tuple, Union import networkx as nx import pytest from networkx.algorithms import shortest_paths from networkx.generators import balanced_tree, random_graphs -from typing_extensions import Literal import jax import jax.experimental.sparse as jesp diff --git a/tests/solvers/linear/sinkhorn_diff_test.py b/tests/solvers/linear/sinkhorn_diff_test.py index 78ca522cc..91c9038b9 100644 --- a/tests/solvers/linear/sinkhorn_diff_test.py +++ b/tests/solvers/linear/sinkhorn_diff_test.py @@ -160,7 +160,8 @@ def test_autograd_sinkhorn( def reg_ot(a: jnp.ndarray, b: jnp.ndarray) -> float: geom = pointcloud.PointCloud(x, y, epsilon=1e-1) prob = linear_problem.LinearProblem(geom, a=a, b=b) - solver = sinkhorn.Sinkhorn(lse_mode=lse_mode) + # TODO: fails with `jit=True`, investigate + solver = sinkhorn.Sinkhorn(lse_mode=lse_mode, jit=False) return solver(prob).reg_ot_cost reg_ot_and_grad = jax.jit(jax.grad(reg_ot)) @@ -277,6 +278,8 @@ def loss_fn(x: jnp.ndarray, lse_mode=lse_mode, min_iterations=min_iter, max_iterations=max_iter, + # TODO(cuturi): figure out why implicit diff breaks when `jit=True` + jit=False, implicit_diff=implicit_diff, ) out = solver(prob) @@ -287,7 +290,7 @@ def loss_fn(x: jnp.ndarray, eps = 1e-5 # perturbation magnitude # first calculation of gradient - loss_and_grad = jax.value_and_grad(loss_fn, has_aux=True) + loss_and_grad = jax.jit(jax.value_and_grad(loss_fn, has_aux=True)) (loss_value, out), grad_loss = loss_and_grad(x, y) custom_grad = jnp.sum(delta * grad_loss) diff --git a/tests/solvers/nn/neuraldual_test.py b/tests/solvers/nn/neuraldual_test.py index 2ef0e9d57..84e1c7dc8 100644 --- a/tests/solvers/nn/neuraldual_test.py +++ b/tests/solvers/nn/neuraldual_test.py @@ -11,10 +11,9 @@ # See the License for the specific language governing permissions and # limitations under the License. """Tests for implementation of ICNN-based Kantorovich dual by Makkuva+(2020).""" -from typing import Iterator, Sequence, Tuple +from typing import Iterator, Literal, Sequence, Tuple import pytest -from typing_extensions import Literal import jax import jax.numpy as jnp diff --git a/tests/solvers/quadratic/fgw_barycenter_test.py b/tests/solvers/quadratic/fgw_barycenter_test.py deleted file mode 100644 index 63d6e5513..000000000 --- a/tests/solvers/quadratic/fgw_barycenter_test.py +++ /dev/null @@ -1,79 +0,0 @@ -"""Tests for Fused Gromov-Wasserstein barycenter.""" -from typing import Tuple - -import pytest - -import jax -import jax.numpy as jnp -import numpy as np - -from ott.geometry import pointcloud -from ott.problems.quadratic import gw_barycenter as gwb -from ott.solvers.quadratic import gw_barycenter as gwb_solver - - -class FGWBarycenterTest: - - @pytest.mark.fast( - "jit,fused_penalty,scale_cost", [(False, 1.5, "mean"), - (True, 3.1, "max_cost")], - only_fast=0 - ) - def test_fgw_barycenter( - self, - rng: jnp.ndarray, - jit: bool, - fused_penalty: float, - scale_cost: str, - ): - - def barycenter( - y: jnp.ndim, y_fused: jnp.ndarray, num_per_segment: Tuple[int, ...] - ) -> gwb_solver.GWBarycenterState: - prob = gwb.GWBarycenterProblem( - y=y, - y_fused=y_fused, - num_per_segment=num_per_segment, - fused_penalty=fused_penalty, - scale_cost=scale_cost, - ) - assert prob.is_fused - assert prob.fused_penalty == fused_penalty - assert not prob._y_as_costs - assert prob.max_measure_size == max(num_per_segment) - assert prob.num_measures == len(num_per_segment) - assert prob.ndim == self.ndim - assert prob.ndim_fused == self.ndim_f - - solver = gwb_solver.GromovWassersteinBarycenter( - store_inner_errors=True, epsilon=epsilon - ) - - x_init = jax.random.normal(rng, (bar_size, self.ndim_f)) - cost_init = pointcloud.PointCloud(x_init).cost_matrix - - return solver(prob, bar_size=bar_size, bar_init=(cost_init, x_init)) - - bar_size, epsilon, = 10, 1e-1 - num_per_segment = (7, 12) - - key1, *rngs = jax.random.split(rng, len(num_per_segment) + 1) - y = jnp.concatenate([ - self.random_pc(n, d=self.ndim, rng=rng).x - for n, rng in zip(num_per_segment, rngs) - ]) - rngs = jax.random.split(key1, len(num_per_segment)) - y_fused = jnp.concatenate([ - self.random_pc(n, d=self.ndim_f, rng=rng).x - for n, rng in zip(num_per_segment, rngs) - ]) - - fn = jax.jit(barycenter, static_argnums=2) if jit else barycenter - out = fn(y, y_fused, num_per_segment) - - assert out.cost.shape == (bar_size, bar_size) - assert out.x.shape == (bar_size, self.ndim_f) - np.testing.assert_array_equal(jnp.isfinite(out.cost), True) - np.testing.assert_array_equal(jnp.isfinite(out.x), True) - np.testing.assert_array_equal(jnp.isfinite(out.costs), True) - np.testing.assert_array_equal(jnp.isfinite(out.errors), True) diff --git a/tests/solvers/quadratic/gw_barycenter_test.py b/tests/solvers/quadratic/gw_barycenter_test.py index dc6bf7242..5609bf41f 100644 --- a/tests/solvers/quadratic/gw_barycenter_test.py +++ b/tests/solvers/quadratic/gw_barycenter_test.py @@ -112,3 +112,67 @@ def test_gw_barycenter( assert out_pc.cost.shape == (bar_size, bar_size) np.testing.assert_allclose(out_pc.cost, out_cost.cost, rtol=tol, atol=tol) np.testing.assert_allclose(out_pc.costs, out_cost.costs, rtol=tol, atol=tol) + + @pytest.mark.fast( + "jit,fused_penalty,scale_cost", [(False, 1.5, "mean"), + (True, 3.1, "max_cost")], + only_fast=0 + ) + def test_fgw_barycenter( + self, + rng: jnp.ndarray, + jit: bool, + fused_penalty: float, + scale_cost: str, + ): + + def barycenter( + y: jnp.ndim, y_fused: jnp.ndarray, num_per_segment: Tuple[int, ...] + ) -> gwb_solver.GWBarycenterState: + prob = gwb.GWBarycenterProblem( + y=y, + y_fused=y_fused, + num_per_segment=num_per_segment, + fused_penalty=fused_penalty, + scale_cost=scale_cost, + ) + assert prob.is_fused + assert prob.fused_penalty == fused_penalty + assert not prob._y_as_costs + assert prob.max_measure_size == max(num_per_segment) + assert prob.num_measures == len(num_per_segment) + assert prob.ndim == self.ndim + assert prob.ndim_fused == self.ndim_f + + solver = gwb_solver.GromovWassersteinBarycenter( + store_inner_errors=True, epsilon=epsilon + ) + + x_init = jax.random.normal(rng, (bar_size, self.ndim_f)) + cost_init = pointcloud.PointCloud(x_init).cost_matrix + + return solver(prob, bar_size=bar_size, bar_init=(cost_init, x_init)) + + bar_size, epsilon, = 10, 1e-1 + num_per_segment = (7, 12) + + key1, *rngs = jax.random.split(rng, len(num_per_segment) + 1) + y = jnp.concatenate([ + self.random_pc(n, d=self.ndim, rng=rng).x + for n, rng in zip(num_per_segment, rngs) + ]) + rngs = jax.random.split(key1, len(num_per_segment)) + y_fused = jnp.concatenate([ + self.random_pc(n, d=self.ndim_f, rng=rng).x + for n, rng in zip(num_per_segment, rngs) + ]) + + fn = jax.jit(barycenter, static_argnums=2) if jit else barycenter + out = fn(y, y_fused, num_per_segment) + + assert out.cost.shape == (bar_size, bar_size) + assert out.x.shape == (bar_size, self.ndim_f) + np.testing.assert_array_equal(jnp.isfinite(out.cost), True) + np.testing.assert_array_equal(jnp.isfinite(out.x), True) + np.testing.assert_array_equal(jnp.isfinite(out.costs), True) + np.testing.assert_array_equal(jnp.isfinite(out.errors), True) diff --git a/tests/tools/k_means_test.py b/tests/tools/k_means_test.py index 62c18a857..3352a3f33 100644 --- a/tests/tools/k_means_test.py +++ b/tests/tools/k_means_test.py @@ -1,9 +1,8 @@ import os import sys -from typing import Any, Optional, Tuple, Union +from typing import Any, Literal, Optional, Tuple, Union import pytest -from typing_extensions import Literal import jax import jax.numpy as jnp