Feature/optional dependencies (#412)
* Add `neural` extra requirements

* Add import wrappers for neural modules

* Update `tox`

* Fix linter

* Update docs

* Pin numpy from above, fix RTD
michalk8 authored Aug 18, 2023
1 parent 2f453df commit db536c9
Showing 11 changed files with 49 additions and 20 deletions.
2 changes: 1 addition & 1 deletion .readthedocs.yaml
@@ -14,4 +14,4 @@ python:
install:
- method: pip
path: .
-extra_requirements: [docs]
+extra_requirements: [docs, neural]
10 changes: 8 additions & 2 deletions docs/index.rst
@@ -32,7 +32,13 @@ Install ``OTT`` from `PyPI <https://pypi.org/project/ott-jax/>`_ as:
pip install ott-jax
-or with ``conda`` via `conda-forge`_ as:
+or with the :mod:`neural OT solvers <ott.solvers.nn>` dependencies:
+
+.. code-block:: bash
+
+    pip install 'ott-jax[neural]'
+
+or using `conda`_ as:

.. code-block:: bash
@@ -139,4 +145,4 @@ Packages
.. _auto-vectorization (VMAP): https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap
.. _automatic: https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation
.. _implicit: https://jax.readthedocs.io/en/latest/_autosummary/jax.custom_jvp.html#jax.custom_jvp
-.. _conda-forge: https://anaconda.org/conda-forge/ott-jax
+.. _conda: https://anaconda.org/conda-forge/ott-jax
1 change: 1 addition & 0 deletions docs/spelling/misc.txt
@@ -1,3 +1,4 @@
+alg
arg
args
cond
4 changes: 4 additions & 0 deletions docs/spelling/technical.txt
@@ -19,6 +19,7 @@ Kullback
Leibler
Mahalanobis
Monge
+Moreau
SGD
Schur
Seidel
@@ -73,6 +74,7 @@ jax
jit
jitting
linearization
+linearized
logit
macOS
methylation
@@ -95,6 +97,7 @@ preconditioner
preprocess
preprocessing
proteome
+prox
quantile
quantiles
quantizes
@@ -105,6 +108,7 @@ renormalize
reproducibility
rescale
rescaled
+rescaling
reweighted
reweighting
reweightings
18 changes: 10 additions & 8 deletions pyproject.toml
@@ -15,12 +15,10 @@ authors = [
dependencies = [
"jax>=0.1.67",
"jaxlib>=0.1.47",
"jaxopt>=0.5.5",
# numba/numpy compatibility issue in JAXOPT.
"jaxopt>=0.8",
"lineax>=0.0.1; python_version >= '3.9'",
# https://github.com/numba/numba/issues/9130
"numpy>=1.18.4, <1.25.0",
"flax>=0.6.6",
"optax>=0.1.1",
"lineax>=0.0.1; python_version >= '3.9'"
]
keywords = [
"optimal transport",
@@ -59,6 +57,10 @@ Documentation = "https://ott-jax.readthedocs.io"
Changelog = "https://github.com/ott-jax/ott/releases"

[project.optional-dependencies]
+neural = [
+"flax>=0.6.6",
+"optax>=0.1.1",
+]
dev = [
"pre-commit>=2.16.0",
"tox>=4",
@@ -181,7 +183,7 @@ legacy_tox_ini = """
skip_missing_interpreters = true
[testenv]
-extras = test
+extras = test,neural
pass_env = CUDA_*,PYTEST_*,CI
commands_pre =
gpu: python -I -m pip install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
@@ -201,7 +203,7 @@ legacy_tox_ini = """
[testenv:lint-docs]
description = Lint the documentation.
deps =
-extras = docs
+extras = docs,neural
ignore_errors = true
allowlist_externals = make
pass_env = PYENCHANT_LIBRARY_PATH
@@ -215,7 +217,7 @@ legacy_tox_ini = """
description = Build the documentation.
use_develop = true
deps =
-extras = docs
+extras = docs,neural
allowlist_externals = make
changedir = {tox_root}{/}docs
commands =
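The `neural` extra now owns `flax` and `optax`, and the tox environments install it explicitly (`extras = test,neural`, `extras = docs,neural`). As a purely illustrative sketch (not part of this commit), downstream code could probe for those optional dependencies like this:

```python
# Hypothetical availability check for the packages shipped by the `neural`
# extra; True only if both flax and optax can be imported.
import importlib.util

HAS_NEURAL = all(
    importlib.util.find_spec(pkg) is not None for pkg in ("flax", "optax")
)
```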
8 changes: 7 additions & 1 deletion src/ott/initializers/__init__.py
@@ -11,4 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from . import linear, nn, quadratic
+import contextlib
+
+from . import linear, quadratic
+
+with contextlib.suppress(ImportError):
+  from . import nn
+del contextlib
8 changes: 7 additions & 1 deletion src/ott/solvers/__init__.py
@@ -11,4 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from . import linear, nn, quadratic
+import contextlib
+
+from . import linear, quadratic, was_solver
+
+with contextlib.suppress(ImportError):
+  from . import nn
+del contextlib
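The same `contextlib.suppress(ImportError)` guard is used in `src/ott/initializers/__init__.py` above: when `flax`/`optax` are missing, the neural submodule is simply not attached to the package instead of breaking `import ott`. A minimal sketch of the resulting user-facing behaviour, assuming the layout introduced in this commit:

```python
# Core solvers import with the base install, with or without the `neural` extra.
from ott.solvers import linear, quadratic

# The neural solvers are exposed only when flax/optax are installed, e.g. via
# `pip install 'ott-jax[neural]'`; otherwise the submodule import fails and the
# guard in __init__.py leaves it out of the package namespace.
try:
    from ott.solvers import nn
except ImportError:
    nn = None  # neural OT solvers unavailable; the core API still works
```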
2 changes: 1 addition & 1 deletion src/ott/solvers/linear/sinkhorn.py
@@ -298,7 +298,7 @@ class SinkhornOutput(NamedTuple):
``max_iterations // inner_iterations`` where those were the parameters
passed on to the :class:`ott.solvers.linear.sinkhorn.Sinkhorn` solver.
For each entry indexed at ``i``, ``errors[i]`` can be either a real
-nonnegative value (meaning the algorithm recorded that error at the
+non-negative value (meaning the algorithm recorded that error at the
``i * inner_iterations`` iteration), a ``jnp.inf`` value (meaning the
algorithm computed that iteration but did not compute its error, because,
for instance, ``i < min_iterations // inner_iterations``), or a ``-1``,
7 changes: 6 additions & 1 deletion src/ott/tools/__init__.py
@@ -11,12 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import contextlib
+
from . import (
gaussian_mixture,
k_means,
-map_estimator,
plot,
segment_sinkhorn,
sinkhorn_divergence,
soft_sort,
)
+
+with contextlib.suppress(ImportError):
+  from . import map_estimator
+del contextlib
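Likewise for `ott.tools`: `map_estimator` needs `flax`/`optax`, so it is only imported when the `neural` extra is present. A small sketch (an assumption about downstream usage, not part of the diff) of how calling code could check for it:

```python
# The attribute exists only if the optional import above succeeded.
import ott.tools

map_estimator = getattr(ott.tools, "map_estimator", None)
if map_estimator is None:
    raise RuntimeError("Install 'ott-jax[neural]' to use ott.tools.map_estimator.")
```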
2 changes: 1 addition & 1 deletion src/ott/tools/gaussian_mixture/fit_gmm_pair.py
@@ -83,7 +83,6 @@

import jax
import jax.numpy as jnp
-import optax

from ott.tools.gaussian_mixture import (
fit_gmm,
@@ -235,6 +234,7 @@ def get_m_step_fn(learning_rate: float, objective_fn, jit: bool):
Returns:
A function that performs the M-step of EM.
"""
+  import optax

def _m_step_fn(
pair: gaussian_mixture_pair.GaussianMixturePair,
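The module-level `import optax` moves inside `get_m_step_fn`, so `fit_gmm_pair` imports without the `neural` extra and `optax` is only required when the optimizer-based M-step is actually built. A hypothetical sketch of the same deferred-import idea with a friendlier error (the helper below is illustrative, not part of the commit):

```python
def _require_optax():
    """Import optax lazily, pointing users at the optional extra if missing."""
    try:
        import optax
    except ImportError as err:
        raise ImportError(
            "optax is required here; install it via `pip install 'ott-jax[neural]'`."
        ) from err
    return optax


def make_sgd_update(learning_rate: float):
    optax = _require_optax()  # deferred: only imported when called
    return optax.sgd(learning_rate)
```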
7 changes: 3 additions & 4 deletions src/ott/tools/map_estimator.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import functools
from typing import Any, Callable, Dict, Iterator, Optional, Tuple
@@ -22,7 +21,7 @@
from flax.core import frozen_dict
from flax.training import train_state

-from ott.solvers.nn.models import ModelBase
+from ott.solvers.nn import models

__all__ = ["MapEstimator"]

@@ -65,7 +64,7 @@ class MapEstimator:
def __init__(
self,
dim_data: int,
-model: ModelBase,
+model: models.ModelBase,
optimizer: Optional[optax.OptState] = None,
fitting_loss: Optional[Callable[[jnp.ndarray, jnp.ndarray],
float]] = None,
@@ -94,7 +93,7 @@ def __init__(
def setup(
self,
dim_data: int,
-neural_net: ModelBase,
+neural_net: models.ModelBase,
optimizer: optax.OptState,
):
"""Setup all components required to train the network."""
