Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tracking progress of gromov-wasserstein #351

Merged
merged 4 commits into from
Apr 24, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/tutorials/notebooks/gromov_wasserstein.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16271,9 +16271,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
"version": "3.10.5"
}
},
"nbformat": 4,
"nbformat_minor": 1
"nbformat_minor": 4
}
267 changes: 207 additions & 60 deletions docs/tutorials/notebooks/tracking_progress.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,15 @@
"tags": []
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
},
"source": [
"# Tracking progress in Sinkhorn and Low-Rank Sinkhorn\n",
"# Tracking progress in OTT solvers.\n",
"\n",
"This tutorial shows how to track progress and errors during iterations of {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` and {class}`~ott.solvers.linear.sinkhorn_lr.LRSinkhorn` algorithms.\n",
"This tutorial shows how to track progress and errors during iterations of the following solvers:\n",
"\n",
"We use the same basic example as in the {doc}`basic_ot_between_datasets` notebook."
]
},
{
"cell_type": "markdown",
"id": "46ebff53-e738-44ea-9cc7-173d560f6a75",
"metadata": {},
"source": [
"## Without tracking (default behavior)\n",
"- {class}`~ott.solvers.linear.sinkhorn.Sinkhorn`\n",
"- {class}`~ott.solvers.linear.sinkhorn_lr.LRSinkhorn`\n",
"- {class}`~ott.solvers.quadratic.gromov_wasserstein.GromovWasserstein`.\n",
"\n",
"Let's recap the basic example we use in this notebook:"
"We'll see that we simply need to provide a callback function to the solvers."
]
},
{
Expand Down Expand Up @@ -56,7 +50,55 @@
"from ott import utils\n",
"from ott.geometry import pointcloud\n",
"from ott.problems.linear import linear_problem\n",
"from ott.solvers.linear import sinkhorn, sinkhorn_lr"
"from ott.problems.quadratic import quadratic_problem\n",
"from ott.solvers.linear import sinkhorn, sinkhorn_lr\n",
"from ott.solvers.quadratic import gromov_wasserstein"
]
},
{
"cell_type": "markdown",
"id": "8b551db9-f308-4540-a57f-0d4c74d48feb",
"metadata": {},
"source": [
"## How to track progress\n",
"\n",
"{mod}`ott` offers a simple and flexible mechanism that works well with {func}`~jax.jit`, and applies to both the functional interface and the class interface.\n",
"\n",
"The solvers {class}`~ott.solvers.linear.sinkhorn.Sinkhorn`, low-rank Sinkhorn {class}`ott.solvers.linear.sinkhorn_lr.LRSinkhorn`, and {class}`~ott.solvers.quadratic.gromov_wasserstein.GromovWasserstein` only report progress if we pass a callback function (with some specific signature) to its initializer. This callback is called at each iteration."
]
},
{
"cell_type": "markdown",
"id": "f74f472b-12a3-49ad-9d81-8431674f512e",
"metadata": {},
"source": [
"### Callback function signature\n",
"\n",
"The required signature of the callback function is: `(status: Tuple[ndarray, ndarray, ndarray, NamedTuple], args: Any) -> None`.\n",
"\n",
"The arguments are:\n",
"\n",
"- status: a tuple of:\n",
" - the current iteration index (0-based)\n",
" - the number of inner iterations after which the error is computed\n",
" - the total number of iterations\n",
" - the current solver state class: {class}`~ott.solvers.linear.sinkhorn.SinkhornState` or {class}`~ott.solvers.linear.sinkhorn_lr.LRSinkhornState`, or {class}`~ott.solvers.quadratic.gromov_wasserstein.GWState`. For technical reasons, the type of this argument in the signature is simply {class}`~typing.NamedTuple` (the common super-type).\n",
"\n",
"- args: unused, see {mod}`jax.experimental.host_callback`.\n",
"\n",
"Note:\n",
"\n",
"- Above, the {class}`~numpy.ndarray` types are passed by the underlying mechanism {mod}`~jax.experimental.host_callback`, but their arguments simply contain one integer value and can be safely cast."
]
},
{
"cell_type": "markdown",
"id": "46ebff53-e738-44ea-9cc7-173d560f6a75",
"metadata": {},
"source": [
"## Linear problem without tracking (default behavior)\n",
"\n",
"Let's start with the {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` solver and setup a basic linear problem:"
]
},
{
Expand Down Expand Up @@ -118,47 +160,11 @@
"id": "e0fb44c5-a0d2-43ce-8b69-a82e84768dab",
"metadata": {},
"source": [
"Obviously, not tracking progress (the default) is fine.However when tackling larger problems, we will probably want to track the various metrics that the Sinkhorn algorithm updates at each iteration.\n",
"For cases as simple as this one, it fine to not track progress (the default behavior). However when tackling larger problems, we will probably want to track the various metrics that the Sinkhorn algorithm updates at each iteration.\n",
"\n",
"In the next sections, we show how to track progress for that same example."
]
},
{
"cell_type": "markdown",
"id": "8b551db9-f308-4540-a57f-0d4c74d48feb",
"metadata": {},
"source": [
"## How to track progress\n",
"\n",
"{mod}`ott` offers a simple and flexible mechanism that works well with {func}`~jax.jit`, and applies to both the functional interface and the class interface.\n",
"\n",
"The {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` and low-rank Sinkhorn {class}`ott.solvers.linear.sinkhorn_lr.LRSinkhorn` solver implementations only report progress if we pass a callback function (with some specific signature) to its initializer. This callback is called at each iteration."
]
},
{
"cell_type": "markdown",
"id": "f74f472b-12a3-49ad-9d81-8431674f512e",
"metadata": {},
"source": [
"### Callback function signature\n",
"\n",
"The required signature of the callback function is: `(status: Tuple[ndarray, ndarray, ndarray, NamedTuple], args: Any) -> None`.\n",
"\n",
"The arguments are:\n",
"\n",
"- status: a tuple of:\n",
" - the current iteration index (0-based)\n",
" - the number of inner iterations after which the error is computed\n",
" - the total number of iterations\n",
" - the current {class}`~ott.solvers.linear.sinkhorn.SinkhornState` or {class}`~ott.solvers.linear.sinkhorn_lr.LRSinkhornState`. For technical reasons, the type of this argument in the signature is simply {class}`~typing.NamedTuple` (the common super-type).\n",
"\n",
"- args: unused, see {mod}`jax.experimental.host_callback`.\n",
"\n",
"Note:\n",
"\n",
"- Above, the {class}`~numpy.ndarray` types are required by the underlying mechanism {mod}`~jax.experimental.host_callback`, but their arguments simply contain one integer value and can be safely cast."
]
},
{
"cell_type": "markdown",
"id": "cfaa93d7-728d-4995-b625-5bdd682af4de",
Expand All @@ -185,9 +191,9 @@
"#### With the basic callback function\n",
"\n",
"\n",
"{mod}`ott` provides a basic callback function: {func}`~ott.utils.default_progress_fn` that we can use directly (it simply prints iteration and error to the console). It can also serve as a basis for customizations.\n",
"{mod}`ott` provides a basic callback function: {func}`~ott.utils.default_progress_fn` that we can use directly: it simply prints iteration and error to the console. It can also serve as a basis for customizations.\n",
"\n",
"Here, we simply pass that basic callback as a static argument:"
"Let's simply pass that basic callback as a static argument:"
]
},
{
Expand Down Expand Up @@ -235,7 +241,7 @@
"source": [
"#### With `tqdm`\n",
"\n",
"Here, we first define a function that updates a `tqdm` progress bar."
"Let's first define a function that updates a `tqdm` progress bar."
]
},
{
Expand Down Expand Up @@ -268,7 +274,7 @@
"id": "0a412dc6-fada-4ba7-a49d-b29d42f0fc8f",
"metadata": {},
"source": [
"and then use it as previously, but in the context of an existing `tqdm` progress bar:"
"and let's use it in the context of an existing `tqdm` progress bar:"
]
},
{
Expand All @@ -281,7 +287,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
" 4%|███▍ | 7/200 [00:00<00:10, 18.58it/s, error: 5.124584e-04]\n"
" 4%|███▍ | 7/200 [00:00<00:12, 15.94it/s, error: 5.124584e-04]\n"
]
}
],
Expand Down Expand Up @@ -324,7 +330,7 @@
"id": "a52db5bb-a5cf-4484-953b-211f82a487a2",
"metadata": {},
"source": [
"Here, we provide the callback function to the {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` class initializer and display progress with `tqdm`:"
"Let's reiterate, but this time we provide the callback function to the {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` class initializer and display progress with `tqdm`:"
]
},
{
Expand All @@ -337,7 +343,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
" 4%|███▍ | 7/200 [00:00<00:10, 18.95it/s, error: 5.124584e-04]\n"
" 4%|███▍ | 7/200 [00:00<00:11, 16.10it/s, error: 5.124584e-04]\n"
]
}
],
Expand Down Expand Up @@ -376,7 +382,7 @@
"source": [
"### Tracking progress of Low-Rank Sinkhorn iterations via the class interface\n",
"\n",
"We can track progress of the Low-rank Sinkhorn solver, however because it currently doesn't have a functional interface, we can only use the class interface {class}`~ott.solvers.linear.sinkhorn_lr.LRSinkhorn`:"
"We can also track progress of the Low-rank Sinkhorn solver, however because it currently doesn't have a functional interface, we can only use the class interface {class}`~ott.solvers.linear.sinkhorn_lr.LRSinkhorn`:"
]
},
{
Expand All @@ -389,7 +395,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
" 8%|███████▊ | 16/200 [00:00<00:07, 23.01it/s, error: 3.191826e-04]\n"
" 8%|███████▊ | 16/200 [00:00<00:09, 19.80it/s, error: 3.191826e-04]\n"
]
}
],
Expand Down Expand Up @@ -420,12 +426,153 @@
"print(f\"has converged: {ot.converged}, cost: {ot.reg_ot_cost}\")"
]
},
{
"cell_type": "markdown",
"id": "1c7a4594-2388-4c07-a670-26cc4f7670fa",
"metadata": {},
"source": [
"## Tracking progress of Gromov-Wasserstein iterations\n",
"\n",
"We can track progress in the same way as with the Sinkhorn solvers. Let's define a simple quadratic problem (the same as in the {doc}`gromov_wasserstein` notebook):"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "942ca588-4c3b-4cc6-b2d8-642992427c51",
"metadata": {},
"outputs": [],
"source": [
"# Samples spiral\n",
"def sample_spiral(\n",
" n, min_radius, max_radius, key, min_angle=0, max_angle=10, noise=1.0\n",
"):\n",
" radius = jnp.linspace(min_radius, max_radius, n)\n",
" angles = jnp.linspace(min_angle, max_angle, n)\n",
" data = []\n",
" noise = jax.random.normal(key, (2, n)) * noise\n",
" for i in range(n):\n",
" x = (radius[i] + noise[0, i]) * jnp.cos(angles[i])\n",
" y = (radius[i] + noise[1, i]) * jnp.sin(angles[i])\n",
" data.append([x, y])\n",
" data = jnp.array(data)\n",
" return data\n",
"\n",
"\n",
"# Samples Swiss roll\n",
"def sample_swiss_roll(\n",
" n, min_radius, max_radius, length, key, min_angle=0, max_angle=10, noise=0.1\n",
"):\n",
" spiral = sample_spiral(\n",
" n, min_radius, max_radius, key[0], min_angle, max_angle, noise\n",
" )\n",
" third_axis = jax.random.uniform(key[1], (n, 1)) * length\n",
" swiss_roll = jnp.hstack((spiral[:, 0:1], third_axis, spiral[:, 1:]))\n",
" return swiss_roll\n",
"\n",
"\n",
"# Data parameters\n",
"n_spiral = 400\n",
"n_swiss_roll = 500\n",
"length = 10\n",
"min_radius = 3\n",
"max_radius = 10\n",
"noise = 0.8\n",
"min_angle = 0\n",
"max_angle = 9\n",
"angle_shift = 3\n",
"\n",
"# Seed\n",
"seed = 14\n",
"key = jax.random.PRNGKey(seed)\n",
"key, *subkey = jax.random.split(key, 4)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "1abae6d0-588b-4ee4-afd9-98ae0ce2d871",
"metadata": {},
"outputs": [],
"source": [
"spiral = sample_spiral(\n",
" n_spiral,\n",
" min_radius,\n",
" max_radius,\n",
" key=subkey[0],\n",
" min_angle=min_angle + angle_shift,\n",
" max_angle=max_angle + angle_shift,\n",
" noise=noise,\n",
")\n",
"swiss_roll = sample_swiss_roll(\n",
" n_swiss_roll,\n",
" min_radius,\n",
" max_radius,\n",
" key=subkey[1:],\n",
" length=length,\n",
" min_angle=min_angle,\n",
" max_angle=max_angle,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "d77e9d5e-3203-4598-ba38-338a83790919",
"metadata": {},
"source": [
"and let's track progress while the solver iterates:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "2918b2b8-a2af-421b-9b73-27652858763d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 / 20 -- -1.0\n",
"2 / 20 -- 0.13043604791164398\n",
"3 / 20 -- 0.08981532603502274\n",
"4 / 20 -- 0.06759563088417053\n",
"5 / 20 -- 0.05465726554393768\n",
"5 outer iterations were needed.\n",
"The last Sinkhorn iteration has converged: True\n",
"The outer loop of Gromov Wasserstein has converged: True\n",
"The final regularized GW cost is: 1183.608\n"
]
}
],
"source": [
"# apply Gromov-Wasserstein\n",
"geom_xx = pointcloud.PointCloud(x=spiral, y=spiral)\n",
"geom_yy = pointcloud.PointCloud(x=swiss_roll, y=swiss_roll)\n",
"prob = quadratic_problem.QuadraticProblem(geom_xx, geom_yy)\n",
"\n",
"solver = gromov_wasserstein.GromovWasserstein(\n",
" epsilon=100.0,\n",
" max_iterations=20,\n",
" store_inner_errors=True, # needed for reporting errors\n",
" progress_fn=utils.default_progress_fn, # callback function\n",
")\n",
"out = solver(prob)\n",
"\n",
"n_outer_iterations = jnp.sum(out.costs != -1)\n",
"has_converged = bool(out.linear_convergence[n_outer_iterations - 1])\n",
"print(f\"{n_outer_iterations} outer iterations were needed.\")\n",
"print(f\"The last Sinkhorn iteration has converged: {has_converged}\")\n",
"print(f\"The outer loop of Gromov Wasserstein has converged: {out.converged}\")\n",
"print(f\"The final regularized GW cost is: {out.reg_gw_cost:.3f}\")"
]
},
{
"cell_type": "markdown",
"id": "41dc6533-a707-4f3d-bfc8-bb446fda382c",
"metadata": {},
"source": [
"That's it, this is how to track progress and errors during Sinkhorn and Low-rank Sinkhorn iterations."
"That's it, this is how to track progress of Sinkhorn, Low-rank Sinkhorn, and Gromov-Wasserstein solvers."
]
},
{
Expand Down
Loading