Skip to content

Commit

Permalink
Fix n_iters (#437)
Browse files Browse the repository at this point in the history
* Start cleaning `n_iters`

* Remove dead code

* Update SD

* Update GW tutorial

* Remove unnecessary hierarchy in tutorials

* Remove `n_iters` property from SD

* Fix not passing `inner_iterations`

* Skip testing `flax` on 3.8

* Fix SD test

* Use `.toarray()`

* Update `FenchelConjugateLBFGS` for `jaxopt>=0.8`

* Fix missing `importorskip`

* Split `MetaInitializer` test

* [ci skip] Fix typos in bary docs

* Skip tests that need `optax`

* Remove relative pip install from MG notebook
  • Loading branch information
michalk8 authored Sep 13, 2023
1 parent ceb8320 commit df96dae
Show file tree
Hide file tree
Showing 42 changed files with 16,498 additions and 16,428 deletions.
2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ Packages
:maxdepth: 1
:caption: Examples

Getting Started <tutorials/notebooks/basic_ot_between_datasets>
Getting Started <tutorials/basic_ot_between_datasets>
tutorials/index

.. toctree::
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand All @@ -22,7 +21,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -38,7 +37,6 @@
"\n",
"from matplotlib import pyplot as plt\n",
"\n",
"%pip install -e ../../../\n",
"from ott.geometry import costs, pointcloud\n",
"from ott.problems.nn import dataset\n",
"from ott.solvers.linear import acceleration\n",
Expand All @@ -47,17 +45,6 @@
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand All @@ -83,7 +70,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand Down Expand Up @@ -199,7 +185,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand Down Expand Up @@ -293,7 +278,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand Down Expand Up @@ -340,7 +324,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand Down Expand Up @@ -464,7 +447,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand Down Expand Up @@ -531,7 +513,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand Down Expand Up @@ -575,7 +556,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand Down Expand Up @@ -649,9 +629,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "ott",
"language": "python",
"name": "python3"
"name": "ott"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -663,7 +643,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.10.6"
}
},
"nbformat": 4,
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
16,285 changes: 16,285 additions & 0 deletions docs/tutorials/gromov_wasserstein.ipynb

Large diffs are not rendered by default.

File renamed without changes.
42 changes: 21 additions & 21 deletions docs/tutorials/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,54 +6,54 @@ Geometry
.. toctree::
:maxdepth: 1

notebooks/introduction_grid
introduction_grid

Linear Optimal Transport
------------------------
.. toctree::
:maxdepth: 1

notebooks/point_clouds
notebooks/One_Sinkhorn
notebooks/OTT_&_POT
notebooks/Hessians
notebooks/LRSinkhorn
notebooks/sinkhorn_divergence_gradient_flow
notebooks/sparse_monge_displacements
point_clouds
One_Sinkhorn
OTT_&_POT
Hessians
LRSinkhorn
sinkhorn_divergence_gradient_flow
sparse_monge_displacements

Barycenters
^^^^^^^^^^^
.. toctree::
:maxdepth: 1

notebooks/Sinkhorn_Barycenters
notebooks/gmm_pair_demo
notebooks/wasserstein_barycenters_gmms
Sinkhorn_Barycenters
gmm_pair_demo
wasserstein_barycenters_gmms

Miscellaneous
^^^^^^^^^^^^^
.. toctree::
:maxdepth: 1

notebooks/tracking_progress
notebooks/soft_sort
notebooks/application_biology
tracking_progress
soft_sort
application_biology

Quadratic Optimal Transport
---------------------------
.. toctree::
:maxdepth: 1

notebooks/gromov_wasserstein
notebooks/GWLRSinkhorn
notebooks/gromov_wasserstein_multiomics
gromov_wasserstein
GWLRSinkhorn
gromov_wasserstein_multiomics

Neural Optimal Transport
------------------------
.. toctree::
:maxdepth: 1

notebooks/neural_dual
notebooks/icnn_inits
notebooks/MetaOT
notebooks/Monge_Gap
neural_dual
icnn_inits
MetaOT
Monge_Gap
File renamed without changes.
File renamed without changes.
16,279 changes: 0 additions & 16,279 deletions docs/tutorials/notebooks/gromov_wasserstein.ipynb

This file was deleted.

File renamed without changes.
File renamed without changes.
File renamed without changes.
8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,10 @@ legacy_tox_ini = """
skip_missing_interpreters = true
[testenv]
extras = test,neural
extras =
test
# https://github.com/google/flax/issues/3329
py{3.9,3.10,3.11},py3.9-jax-default: neural
pass_env = CUDA_*,PYTEST_*,CI
commands_pre =
gpu: python -I -m pip install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Expand Down Expand Up @@ -299,7 +302,8 @@ select = [
unfixable = ["B", "UP", "C4", "BLE", "T20", "RET"]
target-version = "py38"
[tool.ruff.per-file-ignores]
"tests/*" = ["D", "PT004"] # TODO(michalk8): remove `self.initialize` in `tests/`
# TODO(michalk8): PO004 - remove `self.initialize`
"tests/*" = ["D", "PT004", "E402"]
"*/__init__.py" = ["F401"]
"docs/*" = ["D"]
"src/ott/types.py" = ["D102"]
Expand Down
8 changes: 4 additions & 4 deletions src/ott/problems/linear/barycenter_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from ott.geometry import costs, geometry, segment

__all__ = ["FreeBarycenterProblem"]
__all__ = ["FreeBarycenterProblem", "FixedBarycenterProblem"]


@jax.tree_util.register_pytree_node_class
Expand All @@ -44,8 +44,8 @@ class FreeBarycenterProblem:
Only used when ``y`` is not already segmented. When passing
``segment_ids``, 2 arguments must be specified for jitting to work:
- ``num_segments`` - the total number of measures.
- ``max_measure_size`` - maximum of support sizes of these measures.
- ``num_segments`` - the total number of measures.
- ``max_measure_size`` - maximum of support sizes of these measures.
"""

def __init__(
Expand Down Expand Up @@ -158,7 +158,7 @@ class FixedBarycenterProblem:
a: batch of histograms of shape ``[batch, num_a]`` where ``num_a`` matches
the first value of the :attr:`~ott.geometry.Geometry.shape` attribute of
``geom``.
weights: ``[batch,]`` positive weights summing to :math`1`. Uniform by
weights: ``[batch,]`` positive weights summing to :math:`1`. Uniform by
default.
"""

Expand Down
14 changes: 1 addition & 13 deletions src/ott/solvers/linear/sinkhorn.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,6 @@ class SinkhornOutput(NamedTuple):
below the convergence threshold.
inner_iterations: number of iterations that were run between two
computations of errors.
"""

f: Optional[jnp.ndarray] = None
Expand Down Expand Up @@ -424,10 +423,6 @@ def transport_cost_at_geom(
return jnp.sum(self.apply(geom.cost_1.T) * geom.cost_2.T)
return jnp.sum(self.matrix * other_geom.cost_matrix)

@property
def linear(self) -> bool: # noqa: D102
return isinstance(self.ot_prob, linear_problem.LinearProblem)

@property
def geom(self) -> geometry.Geometry: # noqa: D102
return self.ot_prob.geom
Expand All @@ -440,17 +435,10 @@ def a(self) -> jnp.ndarray: # noqa: D102
def b(self) -> jnp.ndarray: # noqa: D102
return self.ot_prob.b

@property
def linear_output(self) -> bool: # noqa: D102
return True

# TODO(michalk8): this should be always present
@property
def n_iters(self) -> int: # noqa: D102
"""Returns the total number of iterations that were needed to terminate."""
if self.errors is None:
return -1
return jnp.sum(self.errors > -1) * self.inner_iterations
return jnp.sum(self.errors != -1) * self.inner_iterations

@property
def scalings(self) -> Tuple[jnp.ndarray, jnp.ndarray]: # noqa: D102
Expand Down
10 changes: 4 additions & 6 deletions src/ott/solvers/linear/sinkhorn_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ class LRSinkhornOutput(NamedTuple):
errors: jnp.ndarray
ot_prob: linear_problem.LinearProblem
epsilon: float
inner_iterations: int
# TODO(michalk8): Optional is an artifact of the current impl., refactor
reg_ot_cost: Optional[float] = None

Expand Down Expand Up @@ -205,10 +206,6 @@ def compute_reg_ot_cost( # noqa: D102
use_danskin=use_danskin
)

@property
def linear(self) -> bool: # noqa: D102
return isinstance(self.ot_prob, linear_problem.LinearProblem)

@property
def geom(self) -> geometry.Geometry: # noqa: D102
return self.ot_prob.geom
Expand All @@ -222,8 +219,8 @@ def b(self) -> jnp.ndarray: # noqa: D102
return self.ot_prob.b

@property
def linear_output(self) -> bool: # noqa: D102
return True
def n_iters(self) -> int: # noqa: D102
return jnp.sum(self.errors != -1) * self.inner_iterations

@property
def converged(self) -> bool: # noqa: D102
Expand Down Expand Up @@ -773,6 +770,7 @@ def output_from_state(
costs=state.costs,
errors=state.errors,
epsilon=self.epsilon,
inner_iterations=self.inner_iterations,
)

def _converged(self, state: LRSinkhornState, iteration: int) -> bool:
Expand Down
16 changes: 9 additions & 7 deletions src/ott/solvers/nn/conjugate_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,16 +75,18 @@ class FenchelConjugateLBFGS(FenchelConjugateSolver):
max_iter: maximum number of iterations
max_linesearch_iter: maximum number of line search iterations
linesearch_type: type of line search
decrease_factor: decrease factor for a backtracking line search
ls_method: the line search method
linesearch_init: strategy for line search initialization
increase_factor: factor by which to increase the step size during
the line search
"""

gtol: float = 1e-3
max_iter: int = 10
max_linesearch_iter: int = 10
linesearch_type: Literal["zoom", "backtracking"] = "backtracking"
decrease_factor: float = 0.66
ls_method: Literal["wolf", "strong-wolfe"] = "strong-wolfe"
linesearch_type: Literal["zoom", "backtracking",
"hager-zhang"] = "backtracking"
linesearch_init: Literal["increase", "max", "current"] = "increase"
increase_factor: float = 1.5

def solve( # noqa: D102
self,
Expand All @@ -98,9 +100,9 @@ def solve( # noqa: D102
fun=lambda x: f(x) - x.dot(y),
tol=self.gtol,
maxiter=self.max_iter,
decrease_factor=self.decrease_factor,
linesearch=self.linesearch_type,
condition=self.ls_method,
linesearch_init=self.linesearch_init,
increase_factor=self.increase_factor,
implicit_diff=False,
unroll=False
)
Expand Down
2 changes: 1 addition & 1 deletion src/ott/solvers/quadratic/gromov_wasserstein.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def primal_cost(self) -> float:
def n_iters(self) -> int: # noqa: D102
if self.errors is None:
return -1
return jnp.sum(self.errors > -1)
return jnp.sum(self.errors[:, 0] != -1)


class GWState(NamedTuple):
Expand Down
Loading

0 comments on commit df96dae

Please sign in to comment.