Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

docs/progress bar #347

Merged
merged 10 commits into from
Apr 6, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/tutorials/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ Miscellaneous
.. toctree::
:maxdepth: 1

notebooks/tracking_progress
notebooks/soft_sort
notebooks/application_biology

Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/notebooks/basic_ot_between_datasets.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "251c4917",
"metadata": {},
Expand All @@ -344,6 +343,7 @@
"This tutorial gave you a glimpse of the most basic features of {mod}`ott` and how they integrate with `JAX`.\n",
"{mod}`ott` implements many other functionalities, including improved execution and extensions of the basic optimal transport problem such as:\n",
"- More general cost functions in {doc}`point_clouds`,\n",
"- How to use a progress bar in {doc}`tracking_progress`,\n",
"- Regularization and acceleration of {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` solvers in {doc}`One_Sinkhorn`,\n",
"- {doc}`gromov_wasserstein`, to compare distributions defined on incomparable spaces.\n",
"- {doc}`LRSinkhorn` for faster solvers that exploit a low-rank factorization of coupling matrices,\n",
Expand Down
4 changes: 2 additions & 2 deletions docs/tutorials/notebooks/gmm_pair_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -817,9 +817,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
"version": "3.10.5"
}
},
"nbformat": 4,
"nbformat_minor": 1
"nbformat_minor": 4
}
351 changes: 351 additions & 0 deletions docs/tutorials/notebooks/tracking_progress.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,351 @@
{
bosr marked this conversation as resolved.
Show resolved Hide resolved
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
bosr marked this conversation as resolved.
Show resolved Hide resolved
bosr marked this conversation as resolved.
Show resolved Hide resolved
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
"cells": [
{
"cell_type": "markdown",
"id": "d9ee0ff0-e502-4809-a553-caf3d17ddb7e",
"metadata": {},
"source": [
"# Tracking progress and metrics in Sinkhorn and Low-rank Sinkhorn\n",
"\n",
"This tutorial shows how to track progress and errors during iterations of Sinkhorn and Low-rank Sinkhorn algorithms.\n",
"\n",
"We use a subset of the {doc}`basic_ot_between_datasets` notebook, and use the same example."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "d7ea0e9d-f518-4798-90d0-062776b4ac5d",
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"if \"google.colab\" in sys.modules:\n",
" %pip install -q git+https://github.com/ott-jax/ott@main"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "9dca5e59-5c90-4cf2-ae19-7043f68ebc66",
"metadata": {},
"outputs": [],
"source": [
"from tqdm import tqdm\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from ott.geometry import pointcloud\n",
"from ott.problems.linear import linear_problem\n",
"from ott.solvers.linear import sinkhorn, sinkhorn_lr\n",
"from ott.utils import default_progress_fn"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "334d6131-e52c-489b-9aa5-4a4c0d8c462c",
"metadata": {},
"outputs": [],
"source": [
"rngs = jax.random.split(jax.random.PRNGKey(0), 2)\n",
"d, n_x, n_y = 2, 7, 11\n",
"x = jax.random.normal(rngs[0], (n_x, d))\n",
"y = jax.random.normal(rngs[1], (n_y, d)) + 0.5"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "d64f5e0c-9053-4364-bed3-aff2ca7c6018",
"metadata": {},
"outputs": [],
"source": [
"geom = pointcloud.PointCloud(x, y)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "03fce59f-1435-400f-a463-6576f5979260",
"metadata": {},
"outputs": [],
"source": [
"solve_fn = jax.jit(sinkhorn.solve)\n",
"ot = solve_fn(geom, a=None, b=None)"
]
},
{
"cell_type": "markdown",
"id": "94bd8c71-607c-450a-a3d9-98b91035291c",
"metadata": {},
"source": [
"By default, no progress is reported:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "7e964a4c-e5b2-49ff-9ae7-75775e35d269",
"metadata": {},
"outputs": [],
"source": [
"ot = solve_fn(geom)"
]
},
{
"cell_type": "markdown",
"id": "e0fb44c5-a0d2-43ce-8b69-a82e84768dab",
"metadata": {},
"source": [
"While the Sinkhorn algorithm iterates, various metrics are updated, and you will probably want to track them or simply track progress when tackling larger problems. \n",
"\n",
"By default, the {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` implementation will not report progress, but if we pass a callback function with some specific signature, Sinkhorn will call this function each time it updates its internal metrics. \n",
"\n",
"The signature of the callback functions is: `(status: Tuple[ndarray, ndarray, ndarray, NamedTuple], args: Any) -> None`.\n",
"\n",
"The arguments are:\n",
"\n",
"- status: status consisting of:\n",
" - the current iteration number\n",
" - the number of inner iterations after which the error is computed\n",
" - the total number of iterations\n",
" - the current {class}`~ott.solvers.linear.sinkhorn.SinkhornState` or {class}`~ott.solvers.linear.sinkhorn_lr.LRSinkhornState`. For technical reasons the type of this argument in the signature is simply `NamedTuple` (the common super-type).\n",
"\n",
"- args: unused, see {mod}`jax.experimental.host_callback`.\n",
"\n",
"As we show later, the same discussion applies for {class}`ott.solvers.linear.sinkhorn_lr.LRSinkhorn`."
]
},
{
"cell_type": "markdown",
"id": "cfaa93d7-728d-4995-b625-5bdd682af4de",
"metadata": {},
"source": [
"## Tracking progress of Sinkhorn iterations\n",
"\n",
"We show two alternative ways:\n",
"\n",
"- using the functional interface: `sinkhorn.solve`\n",
"- using the class interface: `sinkhorn.Sinkhorn`"
]
},
{
"cell_type": "markdown",
"id": "dd179a46-7772-4656-a837-411667e8c26c",
"metadata": {},
"source": [
"### 1. Using the functional interface\n",
"\n",
"Here, as an example, we use a basic callback function, {func}`~ott.utils.default_progress_fn`, which simply prints iteration and error to the console."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "b031225b-5c6a-4a00-a567-d336a80a66c9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10 / 2000 -- 0.049124784767627716\n",
"20 / 2000 -- 0.019962385296821594\n",
"30 / 2000 -- 0.00910455733537674\n",
"40 / 2000 -- 0.004339158535003662\n",
"50 / 2000 -- 0.002111591398715973\n",
"60 / 2000 -- 0.001037590205669403\n",
"70 / 2000 -- 0.0005124583840370178\n"
]
}
],
"source": [
"solve_fn = jax.jit(sinkhorn.solve, static_argnames=[\"progress_fn\"])\n",
"ot = solve_fn(geom, a=None, b=None, progress_fn=default_progress_fn)"
]
},
{
"cell_type": "markdown",
"id": "e7b51753-15c5-47d0-b7cb-f9f4cca84f63",
"metadata": {},
"source": [
"In the above case, the functional interface leverages the class {class}`~ott.solvers.linear.sinkhorn.Sinkhorn`. By default, this object will stop after 2000 iterations if convergence hasn't been reached before, and it reports its metrics each 10 inner iterations (default value). For this basic example, convergence is reached after 7 iterations."
]
},
{
"cell_type": "markdown",
"id": "e1957d0f-7e53-4dcd-85f1-a61ddb7e99bf",
"metadata": {},
"source": [
"We can also provide any function with the signature specified above:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "f7ac907f-2327-4c79-88f4-b7fdae72df75",
"metadata": {},
"outputs": [],
"source": [
"def progress_fn(status, *args):\n",
" iteration, inner_iterations, total_iter, state = status\n",
" iteration = int(iteration) + 1\n",
" inner_iterations = int(inner_iterations)\n",
" total_iter = int(total_iter)\n",
" errors = np.asarray(state.errors).ravel()\n",
"\n",
" # Avoid reporting error on each iteration,\n",
" # because errors are only computed every `inner_iterations`.\n",
" if iteration % inner_iterations == 0:\n",
" error_idx = max(0, iteration // inner_iterations - 1)\n",
" error = errors[error_idx]\n",
"\n",
" pbar.set_postfix_str(f\"error: {error:0.6e}\")\n",
" pbar.total = total_iter // inner_iterations\n",
" pbar.update()"
]
},
{
"cell_type": "markdown",
"id": "0a412dc6-fada-4ba7-a49d-b29d42f0fc8f",
"metadata": {},
"source": [
"In this case, still with the functional interface, a `tqdm` progress bar is instantiated and the iteration and errors are displayed.\n",
"\n",
"Of course, as previously, Sinkhorn will converge after only a few iterations because the problem is simple."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "d3ab5bce-bb3f-49ae-827c-39548adf2f48",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 4%|███▍ | 7/200 [00:00<00:12, 15.80it/s, error: 5.124584e-04]\n"
]
}
],
"source": [
"with tqdm() as pbar:\n",
" solve_fn = jax.jit(sinkhorn.solve, static_argnames=[\"progress_fn\"])\n",
" ot = solve_fn(geom, a=None, b=None, progress_fn=progress_fn)"
]
},
{
"cell_type": "markdown",
"id": "e3842473-76b6-401f-8c05-54e7244d8c95",
"metadata": {},
"source": [
"### 2. Using the class interface"
]
},
{
"cell_type": "markdown",
"id": "a52db5bb-a5cf-4484-953b-211f82a487a2",
"metadata": {},
"source": [
"Here, we adapt the previous example, but we provide the callback function to the class initializer."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "d2d70a82-8203-45e6-917b-dd7fc59dafb8",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 4%|███▍ | 7/200 [00:00<00:12, 15.69it/s, error: 5.124584e-04]\n"
]
}
],
"source": [
"prob = linear_problem.LinearProblem(geom)\n",
"\n",
"with tqdm() as pbar:\n",
" solver = sinkhorn.Sinkhorn(progress_fn=progress_fn)\n",
" ot = jax.jit(solver)(prob)"
]
},
{
"cell_type": "markdown",
"id": "b375412e-e374-445f-9330-0cf785d1965e",
"metadata": {},
"source": [
"## Tracking progress of Low-rank Sinkhorn iterations"
]
},
{
"cell_type": "markdown",
"id": "e87c1bf9-59ba-4813-90ae-28bed17676f6",
"metadata": {},
"source": [
"Low-rank Sinkhorn currently doesn't have a functional interface, so we use only the class interface."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "a78ab13c-d9a2-4201-bb10-cb9bd4ac5f57",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 8%|███████▊ | 16/200 [00:00<00:09, 18.65it/s, error: 3.191826e-04]\n"
]
}
],
"source": [
"prob = linear_problem.LinearProblem(geom)\n",
"rank = 2\n",
"\n",
"with tqdm() as pbar:\n",
" solver = sinkhorn_lr.LRSinkhorn(rank, progress_fn=progress_fn)\n",
" ot = jax.jit(solver)(prob)"
]
},
{
"cell_type": "markdown",
"id": "41dc6533-a707-4f3d-bfc8-bb446fda382c",
"metadata": {},
"source": [
"That's it, we know how to track progress and errors during Sinkhorn and Low-rank Sinkhorn iterations."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,10 @@ test = [
docs = [
"sphinx>=4.0",
"ipython>=7.20.0",
"sphinx-book-theme>=1",
# https://github.com/executablebooks/sphinx-book-theme/issues/711
"pydata-sphinx-theme<0.13.2",
"sphinx_autodoc_typehints>=1.12.0",
"sphinx-book-theme>=0.3.3",
"sphinx-copybutton>=0.5.1",
"sphinxcontrib-bibtex>=2.5.0",
"sphinxcontrib-spelling>=7.7.0",
Expand Down
Loading