From db536c92176c6ed41adf3262e2d5351093d88e3f Mon Sep 17 00:00:00 2001
From: michalk8 <46717574+michalk8@users.noreply.github.com>
Date: Fri, 18 Aug 2023 12:26:19 +0200
Subject: [PATCH] Feature/optional dependencies (#412)
* Add `neural` extra requirements
* Add import wrappers for neural modules
* Update `tox`
* Fix linter
* Update docs
* Pin numpy from above, fix RTD
---
.readthedocs.yaml | 2 +-
docs/index.rst | 10 ++++++++--
docs/spelling/misc.txt | 1 +
docs/spelling/technical.txt | 4 ++++
pyproject.toml | 18 ++++++++++--------
src/ott/initializers/__init__.py | 8 +++++++-
src/ott/solvers/__init__.py | 8 +++++++-
src/ott/solvers/linear/sinkhorn.py | 2 +-
src/ott/tools/__init__.py | 7 ++++++-
src/ott/tools/gaussian_mixture/fit_gmm_pair.py | 2 +-
src/ott/tools/map_estimator.py | 7 +++----
11 files changed, 49 insertions(+), 20 deletions(-)
diff --git a/.readthedocs.yaml b/.readthedocs.yaml
index d06b13894..bb16868e5 100644
--- a/.readthedocs.yaml
+++ b/.readthedocs.yaml
@@ -14,4 +14,4 @@ python:
install:
- method: pip
path: .
- extra_requirements: [docs]
+ extra_requirements: [docs, neural]
diff --git a/docs/index.rst b/docs/index.rst
index 8e9728256..170b1b6ab 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -32,7 +32,13 @@ Install ``OTT`` from `PyPI `_ as:
pip install ott-jax
-or with ``conda`` via `conda-forge`_ as:
+or with the :mod:`neural OT solvers ` dependencies:
+
+.. code-block:: bash
+
+ pip install 'ott-jax[neural]'
+
+or using `conda`_ as:
.. code-block:: bash
@@ -139,4 +145,4 @@ Packages
.. _auto-vectorization (VMAP): https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap
.. _automatic: https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation
.. _implicit: https://jax.readthedocs.io/en/latest/_autosummary/jax.custom_jvp.html#jax.custom_jvp
-.. _conda-forge: https://anaconda.org/conda-forge/ott-jax
+.. _conda: https://anaconda.org/conda-forge/ott-jax
diff --git a/docs/spelling/misc.txt b/docs/spelling/misc.txt
index 695a18c72..e9ddc55d0 100644
--- a/docs/spelling/misc.txt
+++ b/docs/spelling/misc.txt
@@ -1,3 +1,4 @@
+alg
arg
args
cond
diff --git a/docs/spelling/technical.txt b/docs/spelling/technical.txt
index cf621cab6..008dac7b4 100644
--- a/docs/spelling/technical.txt
+++ b/docs/spelling/technical.txt
@@ -19,6 +19,7 @@ Kullback
Leibler
Mahalanobis
Monge
+Moreau
SGD
Schur
Seidel
@@ -73,6 +74,7 @@ jax
jit
jitting
linearization
+linearized
logit
macOS
methylation
@@ -95,6 +97,7 @@ preconditioner
preprocess
preprocessing
proteome
+prox
quantile
quantiles
quantizes
@@ -105,6 +108,7 @@ renormalize
reproducibility
rescale
rescaled
+rescaling
reweighted
reweighting
reweightings
diff --git a/pyproject.toml b/pyproject.toml
index 3bb7dd45c..bb7a86614 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,12 +15,10 @@ authors = [
dependencies = [
"jax>=0.1.67",
"jaxlib>=0.1.47",
- "jaxopt>=0.5.5",
- # numba/numpy compatibility issue in JAXOPT.
+ "jaxopt>=0.8",
+ "lineax>=0.0.1; python_version >= '3.9'",
+ # https://github.com/numba/numba/issues/9130
"numpy>=1.18.4, <1.25.0",
- "flax>=0.6.6",
- "optax>=0.1.1",
- "lineax>=0.0.1; python_version >= '3.9'"
]
keywords = [
"optimal transport",
@@ -59,6 +57,10 @@ Documentation = "https://ott-jax.readthedocs.io"
Changelog = "https://github.com/ott-jax/ott/releases"
[project.optional-dependencies]
+neural = [
+ "flax>=0.6.6",
+ "optax>=0.1.1",
+]
dev = [
"pre-commit>=2.16.0",
"tox>=4",
@@ -181,7 +183,7 @@ legacy_tox_ini = """
skip_missing_interpreters = true
[testenv]
- extras = test
+ extras = test,neural
pass_env = CUDA_*,PYTEST_*,CI
commands_pre =
gpu: python -I -m pip install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
@@ -201,7 +203,7 @@ legacy_tox_ini = """
[testenv:lint-docs]
description = Lint the documentation.
deps =
- extras = docs
+ extras = docs,neural
ignore_errors = true
allowlist_externals = make
pass_env = PYENCHANT_LIBRARY_PATH
@@ -215,7 +217,7 @@ legacy_tox_ini = """
description = Build the documentation.
use_develop = true
deps =
- extras = docs
+ extras = docs,neural
allowlist_externals = make
changedir = {tox_root}{/}docs
commands =
diff --git a/src/ott/initializers/__init__.py b/src/ott/initializers/__init__.py
index f72551448..c3e4a78a1 100644
--- a/src/ott/initializers/__init__.py
+++ b/src/ott/initializers/__init__.py
@@ -11,4 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from . import linear, nn, quadratic
+import contextlib
+
+from . import linear, quadratic
+
+with contextlib.suppress(ImportError):
+ from . import nn
+del contextlib
diff --git a/src/ott/solvers/__init__.py b/src/ott/solvers/__init__.py
index f72551448..834fc4d22 100644
--- a/src/ott/solvers/__init__.py
+++ b/src/ott/solvers/__init__.py
@@ -11,4 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from . import linear, nn, quadratic
+import contextlib
+
+from . import linear, quadratic, was_solver
+
+with contextlib.suppress(ImportError):
+ from . import nn
+del contextlib
diff --git a/src/ott/solvers/linear/sinkhorn.py b/src/ott/solvers/linear/sinkhorn.py
index 8521e200e..59ef5a5a9 100644
--- a/src/ott/solvers/linear/sinkhorn.py
+++ b/src/ott/solvers/linear/sinkhorn.py
@@ -298,7 +298,7 @@ class SinkhornOutput(NamedTuple):
``max_iterations // inner_iterations`` where those were the parameters
passed on to the :class:`ott.solvers.linear.sinkhorn.Sinkhorn` solver.
For each entry indexed at ``i``, ``errors[i]`` can be either a real
- nonnegative value (meaning the algorithm recorded that error at the
+ non-negative value (meaning the algorithm recorded that error at the
``i * inner_iterations`` iteration), a ``jnp.inf`` value (meaning the
algorithm computed that iteration but did not compute its error, because,
for instance, ``i < min_iterations // inner_iterations``), or a ``-1``,
diff --git a/src/ott/tools/__init__.py b/src/ott/tools/__init__.py
index 0e7747a9f..edc722614 100644
--- a/src/ott/tools/__init__.py
+++ b/src/ott/tools/__init__.py
@@ -11,12 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import contextlib
+
from . import (
gaussian_mixture,
k_means,
- map_estimator,
plot,
segment_sinkhorn,
sinkhorn_divergence,
soft_sort,
)
+
+with contextlib.suppress(ImportError):
+ from . import map_estimator
+del contextlib
diff --git a/src/ott/tools/gaussian_mixture/fit_gmm_pair.py b/src/ott/tools/gaussian_mixture/fit_gmm_pair.py
index ccf02fbab..7ecde263c 100644
--- a/src/ott/tools/gaussian_mixture/fit_gmm_pair.py
+++ b/src/ott/tools/gaussian_mixture/fit_gmm_pair.py
@@ -83,7 +83,6 @@
import jax
import jax.numpy as jnp
-import optax
from ott.tools.gaussian_mixture import (
fit_gmm,
@@ -235,6 +234,7 @@ def get_m_step_fn(learning_rate: float, objective_fn, jit: bool):
Returns:
A function that performs the M-step of EM.
"""
+ import optax
def _m_step_fn(
pair: gaussian_mixture_pair.GaussianMixturePair,
diff --git a/src/ott/tools/map_estimator.py b/src/ott/tools/map_estimator.py
index 1101a561b..f0aba78a4 100644
--- a/src/ott/tools/map_estimator.py
+++ b/src/ott/tools/map_estimator.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import collections
import functools
from typing import Any, Callable, Dict, Iterator, Optional, Tuple
@@ -22,7 +21,7 @@
from flax.core import frozen_dict
from flax.training import train_state
-from ott.solvers.nn.models import ModelBase
+from ott.solvers.nn import models
__all__ = ["MapEstimator"]
@@ -65,7 +64,7 @@ class MapEstimator:
def __init__(
self,
dim_data: int,
- model: ModelBase,
+ model: models.ModelBase,
optimizer: Optional[optax.OptState] = None,
fitting_loss: Optional[Callable[[jnp.ndarray, jnp.ndarray],
float]] = None,
@@ -94,7 +93,7 @@ def __init__(
def setup(
self,
dim_data: int,
- neural_net: ModelBase,
+ neural_net: models.ModelBase,
optimizer: optax.OptState,
):
"""Setup all components required to train the network."""