From db536c92176c6ed41adf3262e2d5351093d88e3f Mon Sep 17 00:00:00 2001 From: michalk8 <46717574+michalk8@users.noreply.github.com> Date: Fri, 18 Aug 2023 12:26:19 +0200 Subject: [PATCH] Feature/optional dependencies (#412) * Add `neural` extra requirements * Add import wrappers for neural modules * Update `tox` * Fix linter * Update docs * Pin numpy from above, fix RTD --- .readthedocs.yaml | 2 +- docs/index.rst | 10 ++++++++-- docs/spelling/misc.txt | 1 + docs/spelling/technical.txt | 4 ++++ pyproject.toml | 18 ++++++++++-------- src/ott/initializers/__init__.py | 8 +++++++- src/ott/solvers/__init__.py | 8 +++++++- src/ott/solvers/linear/sinkhorn.py | 2 +- src/ott/tools/__init__.py | 7 ++++++- src/ott/tools/gaussian_mixture/fit_gmm_pair.py | 2 +- src/ott/tools/map_estimator.py | 7 +++---- 11 files changed, 49 insertions(+), 20 deletions(-) diff --git a/.readthedocs.yaml b/.readthedocs.yaml index d06b13894..bb16868e5 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -14,4 +14,4 @@ python: install: - method: pip path: . - extra_requirements: [docs] + extra_requirements: [docs, neural] diff --git a/docs/index.rst b/docs/index.rst index 8e9728256..170b1b6ab 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -32,7 +32,13 @@ Install ``OTT`` from `PyPI `_ as: pip install ott-jax -or with ``conda`` via `conda-forge`_ as: +or with the :mod:`neural OT solvers ` dependencies: + +.. code-block:: bash + + pip install 'ott-jax[neural]' + +or using `conda`_ as: .. code-block:: bash @@ -139,4 +145,4 @@ Packages .. _auto-vectorization (VMAP): https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap .. _automatic: https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation .. _implicit: https://jax.readthedocs.io/en/latest/_autosummary/jax.custom_jvp.html#jax.custom_jvp -.. _conda-forge: https://anaconda.org/conda-forge/ott-jax +.. _conda: https://anaconda.org/conda-forge/ott-jax diff --git a/docs/spelling/misc.txt b/docs/spelling/misc.txt index 695a18c72..e9ddc55d0 100644 --- a/docs/spelling/misc.txt +++ b/docs/spelling/misc.txt @@ -1,3 +1,4 @@ +alg arg args cond diff --git a/docs/spelling/technical.txt b/docs/spelling/technical.txt index cf621cab6..008dac7b4 100644 --- a/docs/spelling/technical.txt +++ b/docs/spelling/technical.txt @@ -19,6 +19,7 @@ Kullback Leibler Mahalanobis Monge +Moreau SGD Schur Seidel @@ -73,6 +74,7 @@ jax jit jitting linearization +linearized logit macOS methylation @@ -95,6 +97,7 @@ preconditioner preprocess preprocessing proteome +prox quantile quantiles quantizes @@ -105,6 +108,7 @@ renormalize reproducibility rescale rescaled +rescaling reweighted reweighting reweightings diff --git a/pyproject.toml b/pyproject.toml index 3bb7dd45c..bb7a86614 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,12 +15,10 @@ authors = [ dependencies = [ "jax>=0.1.67", "jaxlib>=0.1.47", - "jaxopt>=0.5.5", - # numba/numpy compatibility issue in JAXOPT. + "jaxopt>=0.8", + "lineax>=0.0.1; python_version >= '3.9'", + # https://github.com/numba/numba/issues/9130 "numpy>=1.18.4, <1.25.0", - "flax>=0.6.6", - "optax>=0.1.1", - "lineax>=0.0.1; python_version >= '3.9'" ] keywords = [ "optimal transport", @@ -59,6 +57,10 @@ Documentation = "https://ott-jax.readthedocs.io" Changelog = "https://github.com/ott-jax/ott/releases" [project.optional-dependencies] +neural = [ + "flax>=0.6.6", + "optax>=0.1.1", +] dev = [ "pre-commit>=2.16.0", "tox>=4", @@ -181,7 +183,7 @@ legacy_tox_ini = """ skip_missing_interpreters = true [testenv] - extras = test + extras = test,neural pass_env = CUDA_*,PYTEST_*,CI commands_pre = gpu: python -I -m pip install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html @@ -201,7 +203,7 @@ legacy_tox_ini = """ [testenv:lint-docs] description = Lint the documentation. deps = - extras = docs + extras = docs,neural ignore_errors = true allowlist_externals = make pass_env = PYENCHANT_LIBRARY_PATH @@ -215,7 +217,7 @@ legacy_tox_ini = """ description = Build the documentation. use_develop = true deps = - extras = docs + extras = docs,neural allowlist_externals = make changedir = {tox_root}{/}docs commands = diff --git a/src/ott/initializers/__init__.py b/src/ott/initializers/__init__.py index f72551448..c3e4a78a1 100644 --- a/src/ott/initializers/__init__.py +++ b/src/ott/initializers/__init__.py @@ -11,4 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from . import linear, nn, quadratic +import contextlib + +from . import linear, quadratic + +with contextlib.suppress(ImportError): + from . import nn +del contextlib diff --git a/src/ott/solvers/__init__.py b/src/ott/solvers/__init__.py index f72551448..834fc4d22 100644 --- a/src/ott/solvers/__init__.py +++ b/src/ott/solvers/__init__.py @@ -11,4 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from . import linear, nn, quadratic +import contextlib + +from . import linear, quadratic, was_solver + +with contextlib.suppress(ImportError): + from . import nn +del contextlib diff --git a/src/ott/solvers/linear/sinkhorn.py b/src/ott/solvers/linear/sinkhorn.py index 8521e200e..59ef5a5a9 100644 --- a/src/ott/solvers/linear/sinkhorn.py +++ b/src/ott/solvers/linear/sinkhorn.py @@ -298,7 +298,7 @@ class SinkhornOutput(NamedTuple): ``max_iterations // inner_iterations`` where those were the parameters passed on to the :class:`ott.solvers.linear.sinkhorn.Sinkhorn` solver. For each entry indexed at ``i``, ``errors[i]`` can be either a real - nonnegative value (meaning the algorithm recorded that error at the + non-negative value (meaning the algorithm recorded that error at the ``i * inner_iterations`` iteration), a ``jnp.inf`` value (meaning the algorithm computed that iteration but did not compute its error, because, for instance, ``i < min_iterations // inner_iterations``), or a ``-1``, diff --git a/src/ott/tools/__init__.py b/src/ott/tools/__init__.py index 0e7747a9f..edc722614 100644 --- a/src/ott/tools/__init__.py +++ b/src/ott/tools/__init__.py @@ -11,12 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import contextlib + from . import ( gaussian_mixture, k_means, - map_estimator, plot, segment_sinkhorn, sinkhorn_divergence, soft_sort, ) + +with contextlib.suppress(ImportError): + from . import map_estimator +del contextlib diff --git a/src/ott/tools/gaussian_mixture/fit_gmm_pair.py b/src/ott/tools/gaussian_mixture/fit_gmm_pair.py index ccf02fbab..7ecde263c 100644 --- a/src/ott/tools/gaussian_mixture/fit_gmm_pair.py +++ b/src/ott/tools/gaussian_mixture/fit_gmm_pair.py @@ -83,7 +83,6 @@ import jax import jax.numpy as jnp -import optax from ott.tools.gaussian_mixture import ( fit_gmm, @@ -235,6 +234,7 @@ def get_m_step_fn(learning_rate: float, objective_fn, jit: bool): Returns: A function that performs the M-step of EM. """ + import optax def _m_step_fn( pair: gaussian_mixture_pair.GaussianMixturePair, diff --git a/src/ott/tools/map_estimator.py b/src/ott/tools/map_estimator.py index 1101a561b..f0aba78a4 100644 --- a/src/ott/tools/map_estimator.py +++ b/src/ott/tools/map_estimator.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import collections import functools from typing import Any, Callable, Dict, Iterator, Optional, Tuple @@ -22,7 +21,7 @@ from flax.core import frozen_dict from flax.training import train_state -from ott.solvers.nn.models import ModelBase +from ott.solvers.nn import models __all__ = ["MapEstimator"] @@ -65,7 +64,7 @@ class MapEstimator: def __init__( self, dim_data: int, - model: ModelBase, + model: models.ModelBase, optimizer: Optional[optax.OptState] = None, fitting_loss: Optional[Callable[[jnp.ndarray, jnp.ndarray], float]] = None, @@ -94,7 +93,7 @@ def __init__( def setup( self, dim_data: int, - neural_net: ModelBase, + neural_net: models.ModelBase, optimizer: optax.OptState, ): """Setup all components required to train the network."""