diff --git a/docs/solvers/linear.rst b/docs/solvers/linear.rst index 0996b328b..a9ecd8c9b 100644 --- a/docs/solvers/linear.rst +++ b/docs/solvers/linear.rst @@ -15,6 +15,7 @@ Sinkhorn Solvers sinkhorn.SinkhornState sinkhorn.SinkhornOutput sinkhorn_lr.LRSinkhorn + sinkhorn_lr.LRSinkhornState sinkhorn_lr.LRSinkhornOutput Barycenter Solvers diff --git a/docs/tutorials/index.rst b/docs/tutorials/index.rst index afc234ac8..a4b0de323 100644 --- a/docs/tutorials/index.rst +++ b/docs/tutorials/index.rst @@ -34,6 +34,7 @@ Miscellaneous .. toctree:: :maxdepth: 1 + notebooks/tracking_progress notebooks/soft_sort notebooks/application_biology diff --git a/docs/tutorials/notebooks/basic_ot_between_datasets.ipynb b/docs/tutorials/notebooks/basic_ot_between_datasets.ipynb index 9298c1dd6..6abcac121 100644 --- a/docs/tutorials/notebooks/basic_ot_between_datasets.ipynb +++ b/docs/tutorials/notebooks/basic_ot_between_datasets.ipynb @@ -334,7 +334,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "251c4917", "metadata": {}, @@ -344,6 +343,7 @@ "This tutorial gave you a glimpse of the most basic features of {mod}`ott` and how they integrate with `JAX`.\n", "{mod}`ott` implements many other functionalities, including improved execution and extensions of the basic optimal transport problem such as:\n", "- More general cost functions in {doc}`point_clouds`,\n", + "- How to use a progress bar in {doc}`tracking_progress`,\n", "- Regularization and acceleration of {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` solvers in {doc}`One_Sinkhorn`,\n", "- {doc}`gromov_wasserstein`, to compare distributions defined on incomparable spaces.\n", "- {doc}`LRSinkhorn` for faster solvers that exploit a low-rank factorization of coupling matrices,\n", diff --git a/docs/tutorials/notebooks/gmm_pair_demo.ipynb b/docs/tutorials/notebooks/gmm_pair_demo.ipynb index 6e47b69ba..0a2d68c54 100644 --- a/docs/tutorials/notebooks/gmm_pair_demo.ipynb +++ b/docs/tutorials/notebooks/gmm_pair_demo.ipynb @@ 
-817,9 +817,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.10.5" } }, "nbformat": 4, - "nbformat_minor": 1 + "nbformat_minor": 4 } diff --git a/docs/tutorials/notebooks/tracking_progress.ipynb b/docs/tutorials/notebooks/tracking_progress.ipynb new file mode 100644 index 000000000..5985e4910 --- /dev/null +++ b/docs/tutorials/notebooks/tracking_progress.ipynb @@ -0,0 +1,461 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "d9ee0ff0-e502-4809-a553-caf3d17ddb7e", + "metadata": { + "tags": [] + }, + "source": [ + "# Tracking progress in Sinkhorn and Low-Rank Sinkhorn\n", + "\n", + "This tutorial shows how to track progress and errors during iterations of {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` and {class}`~ott.solvers.linear.sinkhorn_lr.LRSinkhorn` algorithms.\n", + "\n", + "We use the same basic example as in the {doc}`basic_ot_between_datasets` notebook." + ] + }, + { + "cell_type": "markdown", + "id": "46ebff53-e738-44ea-9cc7-173d560f6a75", + "metadata": {}, + "source": [ + "## Without tracking (default behavior)\n", + "\n", + "Let's recap the basic example we use in this notebook:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "d7ea0e9d-f518-4798-90d0-062776b4ac5d", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "\n", + "if \"google.colab\" in sys.modules:\n", + " %pip install -q git+https://github.com/ott-jax/ott@main\n", + " %pip install -q tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "9dca5e59-5c90-4cf2-ae19-7043f68ebc66", + "metadata": {}, + "outputs": [], + "source": [ + "from tqdm import tqdm\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import numpy as np\n", + "\n", + "import matplotlib.pyplot as plt\n", + "\n", + "from ott import utils\n", + "from ott.geometry import pointcloud\n", + "from ott.problems.linear import linear_problem\n", + "from ott.solvers.linear import 
sinkhorn, sinkhorn_lr" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "334d6131-e52c-489b-9aa5-4a4c0d8c462c", + "metadata": {}, + "outputs": [], + "source": [ + "rngs = jax.random.split(jax.random.PRNGKey(0), 2)\n", + "d, n_x, n_y = 2, 7, 11\n", + "x = jax.random.normal(rngs[0], (n_x, d))\n", + "y = jax.random.normal(rngs[1], (n_y, d)) + 0.5" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "d64f5e0c-9053-4364-bed3-aff2ca7c6018", + "metadata": {}, + "outputs": [], + "source": [ + "geom = pointcloud.PointCloud(x, y)" + ] + }, + { + "cell_type": "markdown", + "id": "a7583315-7340-4d0f-95b5-3a2c0398bce3", + "metadata": {}, + "source": [ + "This problem is very simple, so the {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` solver converges after only 7 iterations. The solver would otherwise keep iterating for 200 steps (default value)." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "03fce59f-1435-400f-a463-6576f5979260", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "has converged: True, #iters: 7, cost: 1.2429015636444092\n" + ] + } + ], + "source": [ + "solve_fn = jax.jit(sinkhorn.solve)\n", + "ot = solve_fn(geom)\n", + "\n", + "print(\n", + " f\"has converged: {ot.converged}, #iters: {ot.n_iters}, cost: {ot.reg_ot_cost}\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "e0fb44c5-a0d2-43ce-8b69-a82e84768dab", + "metadata": {}, + "source": [ + "Obviously, not tracking progress (the default) is fine. However, when tackling larger problems, we will probably want to track the various metrics that the Sinkhorn algorithm updates at each iteration.\n", + "\n", + "In the next sections, we show how to track progress for that same example." 
+ ] + }, + { + "cell_type": "markdown", + "id": "8b551db9-f308-4540-a57f-0d4c74d48feb", + "metadata": {}, + "source": [ + "## How to track progress\n", + "\n", + "{mod}`ott` offers a simple and flexible mechanism that works well with {func}`~jax.jit`, and applies to both the functional interface and the class interface.\n", + "\n", + "The {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` and low-rank Sinkhorn {class}`~ott.solvers.linear.sinkhorn_lr.LRSinkhorn` solver implementations only report progress if we pass a callback function (with some specific signature) to their initializer. This callback is called at each iteration." + ] + }, + { + "cell_type": "markdown", + "id": "f74f472b-12a3-49ad-9d81-8431674f512e", + "metadata": {}, + "source": [ + "### Callback function signature\n", + "\n", + "The required signature of the callback function is: `(status: Tuple[ndarray, ndarray, ndarray, NamedTuple], args: Any) -> None`.\n", + "\n", + "The arguments are:\n", + "\n", + "- status: a tuple of:\n", + " - the current iteration index (0-based)\n", + " - the number of inner iterations after which the error is computed\n", + " - the maximum number of iterations\n", + " - the current {class}`~ott.solvers.linear.sinkhorn.SinkhornState` or {class}`~ott.solvers.linear.sinkhorn_lr.LRSinkhornState`. For technical reasons, the type of this argument in the signature is simply {class}`~typing.NamedTuple` (the common super-type).\n", + "\n", + "- args: unused, see {mod}`jax.experimental.host_callback`.\n", + "\n", + "Note:\n", + "\n", + "- Above, the {class}`~numpy.ndarray` types are required by the underlying mechanism {mod}`~jax.experimental.host_callback`, but each of these arguments simply contains a single integer value and can be safely cast to `int`." 
+ ] + }, + { + "cell_type": "markdown", + "id": "cfaa93d7-728d-4995-b625-5bdd682af4de", + "metadata": {}, + "source": [ + "## Examples\n", + "\n", + "Here are a few examples of how to track progress for Sinkhorn and low-rank Sinkhorn." 
+ ] + }, + { + "cell_type": "markdown", + "id": "dd179a46-7772-4656-a837-411667e8c26c", + "metadata": {}, + "source": [ + "### Tracking progress for Sinkhorn via the functional interface" + ] + }, + { + "cell_type": "markdown", + "id": "6732bf21-a0c3-47eb-bb5d-0111724340aa", + "metadata": {}, + "source": [ + "#### With the basic callback function\n", + "\n", + "\n", + "{mod}`ott` provides a basic callback function, {func}`~ott.utils.default_progress_fn`, that we can use directly (it simply prints iteration and error to the console). It can also serve as a basis for customizations.\n", + "\n", + "Here, we simply pass that basic callback as a static argument:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "b031225b-5c6a-4a00-a567-d336a80a66c9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "10 / 2000 -- 0.049124784767627716\n", + "20 / 2000 -- 0.019962385296821594\n", + "30 / 2000 -- 0.00910455733537674\n", + "40 / 2000 -- 0.004339158535003662\n", + "50 / 2000 -- 0.002111591398715973\n", + "60 / 2000 -- 0.001037590205669403\n", + "70 / 2000 -- 0.0005124583840370178\n", + "has converged: True, #iters: 7, cost: 1.2429015636444092\n" + ] + } + ], + "source": [ + "solve_fn = jax.jit(sinkhorn.solve, static_argnames=[\"progress_fn\"])\n", + "ot = solve_fn(geom, a=None, b=None, progress_fn=utils.default_progress_fn)\n", + "\n", + "print(\n", + " f\"has converged: {ot.converged}, #iters: {ot.n_iters}, cost: {ot.reg_ot_cost}\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "e7b51753-15c5-47d0-b7cb-f9f4cca84f63", + "metadata": {}, + "source": [ + "This reveals that the solver reports its metrics every 10 _inner_ iterations (default value)." + ] + }, + { + "cell_type": "markdown", + "id": "e1957d0f-7e53-4dcd-85f1-a61ddb7e99bf", + "metadata": {}, + "source": [ + "#### With `tqdm`\n", + "\n", + "Here, we first define a function that updates a `tqdm` progress bar." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "f7ac907f-2327-4c79-88f4-b7fdae72df75", + "metadata": {}, + "outputs": [], + "source": [ + "def progress_fn(status, *args):\n", + " iteration, inner_iterations, total_iter, state = status\n", + " iteration = int(iteration) + 1 # from [0;n-1] to [1;n]\n", + " inner_iterations = int(inner_iterations)\n", + " total_iter = int(total_iter)\n", + " errors = np.asarray(state.errors).ravel()\n", + "\n", + " # Avoid reporting error on each iteration,\n", + " # because errors are only computed every `inner_iterations`.\n", + " if iteration % inner_iterations == 0:\n", + " error_idx = max(0, iteration // inner_iterations - 1)\n", + " error = errors[error_idx]\n", + "\n", + " pbar.set_postfix_str(f\"error: {error:0.6e}\")\n", + " pbar.total = total_iter // inner_iterations\n", + " pbar.update()" + ] + }, + { + "cell_type": "markdown", + "id": "0a412dc6-fada-4ba7-a49d-b29d42f0fc8f", + "metadata": {}, + "source": [ + "and then use it as previously, but in the context of an existing `tqdm` progress bar:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "d3ab5bce-bb3f-49ae-827c-39548adf2f48", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 4%|███▍ | 7/200 [00:00<00:10, 18.58it/s, error: 5.124584e-04]\n" + ] + } + ], + "source": [ + "with tqdm() as pbar:\n", + " solve_fn = jax.jit(sinkhorn.solve, static_argnames=[\"progress_fn\"])\n", + " ot = solve_fn(geom, a=None, b=None, progress_fn=progress_fn)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "dee13ce4-344c-47e2-ad5e-599e14e4dce1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "has converged: True, #iters: 7, cost: 1.2429015636444092\n" + ] + } + ], + "source": [ + "print(\n", + " f\"has converged: {ot.converged}, #iters: {ot.n_iters}, cost: {ot.reg_ot_cost}\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": 
"e3842473-76b6-401f-8c05-54e7244d8c95", + "metadata": {}, + "source": [ + "### Tracking progress for Sinkhorn via the class interface" + ] + }, + { + "cell_type": "markdown", + "id": "a52db5bb-a5cf-4484-953b-211f82a487a2", + "metadata": {}, + "source": [ + "Here, we provide the callback function to the {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` class initializer and display progress with `tqdm`:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "d2d70a82-8203-45e6-917b-dd7fc59dafb8", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 4%|███▍ | 7/200 [00:00<00:10, 18.95it/s, error: 5.124584e-04]\n" + ] + } + ], + "source": [ + "prob = linear_problem.LinearProblem(geom)\n", + "\n", + "with tqdm() as pbar:\n", + " solver = sinkhorn.Sinkhorn(progress_fn=progress_fn)\n", + " ot = jax.jit(solver)(prob)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "e020a49f-e9f9-4d3e-a5d5-98b08c7d243d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "has converged: True, #iters: 7, cost: 1.2429015636444092\n" + ] + } + ], + "source": [ + "print(\n", + " f\"has converged: {ot.converged}, #iters: {ot.n_iters}, cost: {ot.reg_ot_cost}\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "b375412e-e374-445f-9330-0cf785d1965e", + "metadata": {}, + "source": [ + "### Tracking progress of Low-Rank Sinkhorn iterations via the class interface\n", + "\n", + "We can track progress of the Low-rank Sinkhorn solver; however, because it currently doesn't have a functional interface, we can only use the class interface {class}`~ott.solvers.linear.sinkhorn_lr.LRSinkhorn`:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "a78ab13c-d9a2-4201-bb10-cb9bd4ac5f57", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 8%|███████▊ | 16/200 [00:00<00:07, 23.01it/s, error: 3.191826e-04]\n" + ] + } 
+ ], + "source": [ + "prob = linear_problem.LinearProblem(geom)\n", + "rank = 2\n", + "\n", + "with tqdm() as pbar:\n", + " solver = sinkhorn_lr.LRSinkhorn(rank, progress_fn=progress_fn)\n", + " ot = jax.jit(solver)(prob)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "10563fb7-e982-4607-9f48-c7ac73516abc", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "has converged: True, cost: 1.7340877056121826\n" + ] + } + ], + "source": [ + "print(f\"has converged: {ot.converged}, cost: {ot.reg_ot_cost}\")" + ] + }, + { + "cell_type": "markdown", + "id": "41dc6533-a707-4f3d-bfc8-bb446fda382c", + "metadata": {}, + "source": [ + "That's it, this is how to track progress and errors during Sinkhorn and Low-rank Sinkhorn iterations." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3a35db13-4672-4185-aa9d-7f4600fa91c7", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/ott/solvers/linear/sinkhorn.py b/src/ott/solvers/linear/sinkhorn.py index 3996f01db..2e0eae5b9 100644 --- a/src/ott/solvers/linear/sinkhorn.py +++ b/src/ott/solvers/linear/sinkhorn.py @@ -905,7 +905,14 @@ def one_iteration( ot_prob, ) errors = state.errors.at[iteration // self.inner_iterations, :].set(err) - return state.set(errors=errors) + state = state.set(errors=errors) + + if self.progress_fn is not None: + host_callback.id_tap( + self.progress_fn, + (iteration, self.inner_iterations, self.max_iterations, state) + ) + return state def _converged(self, state: 
SinkhornState, iteration: int) -> bool: err = state.errors[iteration // self.inner_iterations - 1, 0] @@ -1049,13 +1056,7 @@ def body_fn( state: SinkhornState, compute_error: bool ) -> SinkhornState: ot_prob, solver = const - state = solver.one_iteration(ot_prob, state, iteration, compute_error) - if solver.progress_fn is not None: - host_callback.id_tap( - solver.progress_fn, - (iteration, solver.inner_iterations, solver.max_iterations, state) - ) - return state + return solver.one_iteration(ot_prob, state, iteration, compute_error) # Run the Sinkhorn loop. Choose either a standard fixpoint_iter loop if # differentiation is implicit, otherwise switch to the backprop friendly diff --git a/src/ott/solvers/linear/sinkhorn_lr.py b/src/ott/solvers/linear/sinkhorn_lr.py index c3f13f53d..b763bd3eb 100644 --- a/src/ott/solvers/linear/sinkhorn_lr.py +++ b/src/ott/solvers/linear/sinkhorn_lr.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import ( Any, + Callable, Literal, Mapping, NamedTuple, @@ -25,6 +26,8 @@ import jax import jax.numpy as jnp import jax.scipy as jsp +import numpy as np +from jax.experimental import host_callback from ott.geometry import geometry, low_rank, pointcloud from ott.initializers.linear import initializers_lr as init_lib @@ -35,6 +38,9 @@ __all__ = ["LRSinkhorn", "LRSinkhornOutput"] +ProgressCallbackFn_t = Callable[ + [Tuple[np.ndarray, np.ndarray, np.ndarray, "LRSinkhornState"]], None] + class LRSinkhornState(NamedTuple): """State of the Low Rank Sinkhorn algorithm.""" @@ -268,6 +274,10 @@ class LRSinkhorn(sinkhorn.Sinkhorn): input parameters. Only `True` handled at this moment. implicit_diff: Whether to use implicit differentiation. Currently, only ``implicit_diff = False`` is implemented. + progress_fn: callback function which gets called during the Sinkhorn + iterations, so the user can display the error at each iteration, + e.g., using a progress bar. See :func:`~ott.utils.default_progress_fn` + for a basic implementation. 
kwargs_dys: Keyword arguments passed to :meth:`dykstra_update`. kwargs_init: Keyword arguments for :class:`~ott.initializers.linear.initializers_lr.LRInitializer`. @@ -290,6 +300,7 @@ def __init__( implicit_diff: bool = False, kwargs_dys: Optional[Mapping[str, Any]] = None, kwargs_init: Optional[Mapping[str, Any]] = None, + progress_fn: Optional[ProgressCallbackFn_t] = None, **kwargs: Any, ): assert lse_mode, "Kernel mode not yet implemented." @@ -306,6 +317,7 @@ def __init__( self.gamma_rescale = gamma_rescale self.epsilon = epsilon self.initializer = initializer + self.progress_fn = progress_fn # can be `None` self.kwargs_dys = {} if kwargs_dys is None else kwargs_dys self.kwargs_init = {} if kwargs_init is None else kwargs_init @@ -550,12 +562,20 @@ def one_iteration( ) ) - return state.set( + state = state.set( costs=state.costs.at[it].set(cost), errors=state.errors.at[it].set(error), crossed_threshold=crossed_threshold, ) + if self.progress_fn is not None: + host_callback.id_tap( + self.progress_fn, + (iteration, self.inner_iterations, self.max_iterations, state) + ) + + return state + @property def norm_error(self) -> Tuple[int]: # noqa: D102 return self._norm_error, diff --git a/src/ott/utils.py b/src/ott/utils.py index 4d2a10766..e0dd355be 100644 --- a/src/ott/utils.py +++ b/src/ott/utils.py @@ -70,7 +70,7 @@ def default_progress_fn( """Callback function that reports progress of :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` by printing to the console. - It updates the progress bar only when the error is computed, that is every + It prints the progress only when the error is computed, that is every :attr:`~ott.solvers.linear.sinkhorn.Sinkhorn.inner_iterations`. 
Note: @@ -107,19 +107,19 @@ def default_progress_fn( def progress_fn(status, *args): iteration, inner_iterations, total_iter, state = status - iteration = int(iteration) + iteration = int(iteration) + 1 inner_iterations = int(inner_iterations) total_iter = int(total_iter) errors = np.asarray(state.errors).ravel() # Avoid reporting error on each iteration, # because errors are only computed every `inner_iterations`. - if (iteration + 1) % inner_iterations == 0: - error_idx = max((iteration + 1) // inner_iterations - 1, 0) + if iteration % inner_iterations == 0: + error_idx = max(0, iteration // inner_iterations - 1) error = errors[error_idx] pbar.set_postfix_str(f"error: {error:0.6e}") - pbar.total = total_iter + pbar.total = total_iter // inner_iterations pbar.update() prob = linear_problem.LinearProblem(...) @@ -130,15 +130,15 @@ def progress_fn(status, *args): """ # noqa: D205 # Convert arguments. iteration, inner_iterations, total_iter, state = status - iteration = int(iteration) + iteration = int(iteration) + 1 inner_iterations = int(inner_iterations) total_iter = int(total_iter) errors = np.array(state.errors).ravel() # Avoid reporting error on each iteration, # because errors are only computed every `inner_iterations`. - if (iteration + 1) % inner_iterations == 0: - error_idx = max((iteration + 1) // inner_iterations - 1, 0) + if iteration % inner_iterations == 0: + error_idx = max(0, iteration // inner_iterations - 1) error = errors[error_idx] print(f"{iteration} / {total_iter} -- {error}") # noqa: T201 diff --git a/tests/solvers/linear/sinkhorn_lr_test.py b/tests/solvers/linear/sinkhorn_lr_test.py index 7cda2bb62..d7a093d2d 100644 --- a/tests/solvers/linear/sinkhorn_lr_test.py +++ b/tests/solvers/linear/sinkhorn_lr_test.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Any, Tuple + import jax import jax.numpy as jnp import numpy as np @@ -146,3 +148,55 @@ def test_output_apply_batch_size(self, axis: int): np.testing.assert_allclose( pred, jnp.stack([gt] * n_stack), rtol=1e-6, atol=1e-6 ) + + @pytest.mark.fast.with_args("num_iterations", [30, 60]) + def test_callback_fn(self, num_iterations: int): + """Check that the callback function is actually called.""" + + def progress_fn( + status: Tuple[np.ndarray, np.ndarray, np.ndarray, + sinkhorn_lr.LRSinkhornState], *args: Any + ) -> None: + # Convert arguments. + iteration, inner_iterations, total_iter, state = status + iteration = int(iteration) + inner_iterations = int(inner_iterations) + total_iter = int(total_iter) + errors = np.array(state.errors).ravel() + + # Avoid reporting error on each iteration, + # because errors are only computed every `inner_iterations`. + if (iteration + 1) % inner_iterations == 0: + error_idx = max((iteration + 1) // inner_iterations - 1, 0) + error = errors[error_idx] + + traced_values["iters"].append(iteration) + traced_values["error"].append(error) + traced_values["total"].append(total_iter) + + traced_values = {"iters": [], "error": [], "total": []} + + geom = pointcloud.PointCloud(self.x, self.y, epsilon=1e-3) + lin_prob = linear_problem.LinearProblem(geom, a=self.a, b=self.b) + + rank = 2 + inner_iterations = 10 + + _ = sinkhorn_lr.LRSinkhorn( + rank, + progress_fn=progress_fn, + max_iterations=num_iterations, + inner_iterations=inner_iterations + )( + lin_prob + ) + + # check that the function is called on the 10th iteration (iter #9), the + # 20th iteration (iter #19). + assert traced_values["iters"] == [9, 19] + + # check that error decreases + np.testing.assert_array_equal(np.diff(traced_values["error"]) < 0, True) + + # check that max iterations is provided each time: [30, 30] + assert traced_values["total"] == [num_iterations] * 2