Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/optional dependencies #412

Merged
merged 6 commits into from
Aug 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ python:
install:
- method: pip
path: .
extra_requirements: [docs]
extra_requirements: [docs, neural]
10 changes: 8 additions & 2 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,13 @@ Install ``OTT`` from `PyPI <https://pypi.org/project/ott-jax/>`_ as:

pip install ott-jax

or with ``conda`` via `conda-forge`_ as:
or with the :mod:`neural OT solvers <ott.solvers.nn>` dependencies:

.. code-block:: bash

pip install 'ott-jax[neural]'

or using `conda`_ as:

.. code-block:: bash

Expand Down Expand Up @@ -139,4 +145,4 @@ Packages
.. _auto-vectorization (VMAP): https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap
.. _automatic: https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation
.. _implicit: https://jax.readthedocs.io/en/latest/_autosummary/jax.custom_jvp.html#jax.custom_jvp
.. _conda-forge: https://anaconda.org/conda-forge/ott-jax
.. _conda: https://anaconda.org/conda-forge/ott-jax
1 change: 1 addition & 0 deletions docs/spelling/misc.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
alg
arg
args
cond
Expand Down
4 changes: 4 additions & 0 deletions docs/spelling/technical.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Kullback
Leibler
Mahalanobis
Monge
Moreau
SGD
Schur
Seidel
Expand Down Expand Up @@ -73,6 +74,7 @@ jax
jit
jitting
linearization
linearized
logit
macOS
methylation
Expand All @@ -95,6 +97,7 @@ preconditioner
preprocess
preprocessing
proteome
prox
quantile
quantiles
quantizes
Expand All @@ -105,6 +108,7 @@ renormalize
reproducibility
rescale
rescaled
rescaling
reweighted
reweighting
reweightings
Expand Down
18 changes: 10 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,10 @@ authors = [
dependencies = [
"jax>=0.1.67",
"jaxlib>=0.1.47",
"jaxopt>=0.5.5",
# numba/numpy compatibility issue in JAXOPT.
"jaxopt>=0.8",
"lineax>=0.0.1; python_version >= '3.9'",
# https://github.com/numba/numba/issues/9130
"numpy>=1.18.4, <1.25.0",
"flax>=0.6.6",
"optax>=0.1.1",
"lineax>=0.0.1; python_version >= '3.9'"
]
keywords = [
"optimal transport",
Expand Down Expand Up @@ -59,6 +57,10 @@ Documentation = "https://ott-jax.readthedocs.io"
Changelog = "https://github.com/ott-jax/ott/releases"

[project.optional-dependencies]
neural = [
"flax>=0.6.6",
"optax>=0.1.1",
]
dev = [
"pre-commit>=2.16.0",
"tox>=4",
Expand Down Expand Up @@ -180,7 +182,7 @@ legacy_tox_ini = """
skip_missing_interpreters = true

[testenv]
extras = test
extras = test,neural
pass_env = CUDA_*,PYTEST_*,CI
commands_pre =
gpu: python -I -m pip install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Expand All @@ -200,7 +202,7 @@ legacy_tox_ini = """
[testenv:lint-docs]
description = Lint the documentation.
deps =
extras = docs
extras = docs,neural
ignore_errors = true
allowlist_externals = make
pass_env = PYENCHANT_LIBRARY_PATH
Expand All @@ -214,7 +216,7 @@ legacy_tox_ini = """
description = Build the documentation.
use_develop = true
deps =
extras = docs
extras = docs,neural
allowlist_externals = make
changedir = {tox_root}{/}docs
commands =
Expand Down
8 changes: 7 additions & 1 deletion src/ott/initializers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import linear, nn, quadratic
import contextlib

from . import linear, quadratic

with contextlib.suppress(ImportError):
from . import nn
del contextlib
8 changes: 7 additions & 1 deletion src/ott/solvers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import linear, nn, quadratic
import contextlib

from . import linear, quadratic, was_solver

with contextlib.suppress(ImportError):
from . import nn
del contextlib
2 changes: 1 addition & 1 deletion src/ott/solvers/linear/sinkhorn.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ class SinkhornOutput(NamedTuple):
``max_iterations // inner_iterations`` where those were the parameters
passed on to the :class:`ott.solvers.linear.sinkhorn.Sinkhorn` solver.
For each entry indexed at ``i``, ``errors[i]`` can be either a real
nonnegative value (meaning the algorithm recorded that error at the
non-negative value (meaning the algorithm recorded that error at the
``i * inner_iterations`` iteration), a ``jnp.inf`` value (meaning the
algorithm computed that iteration but did not compute its error, because,
for instance, ``i < min_iterations // inner_iterations``), or a ``-1``,
Expand Down
7 changes: 6 additions & 1 deletion src/ott/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib

from . import (
gaussian_mixture,
k_means,
map_estimator,
plot,
segment_sinkhorn,
sinkhorn_divergence,
soft_sort,
)

with contextlib.suppress(ImportError):
from . import map_estimator
del contextlib
2 changes: 1 addition & 1 deletion src/ott/tools/gaussian_mixture/fit_gmm_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@

import jax
import jax.numpy as jnp
import optax

from ott.tools.gaussian_mixture import (
fit_gmm,
Expand Down Expand Up @@ -235,6 +234,7 @@ def get_m_step_fn(learning_rate: float, objective_fn, jit: bool):
Returns:
A function that performs the M-step of EM.
"""
import optax

def _m_step_fn(
pair: gaussian_mixture_pair.GaussianMixturePair,
Expand Down
7 changes: 3 additions & 4 deletions src/ott/tools/map_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import functools
from typing import Any, Callable, Dict, Iterator, Optional, Tuple
Expand All @@ -22,7 +21,7 @@
from flax.core import frozen_dict
from flax.training import train_state

from ott.solvers.nn.models import ModelBase
from ott.solvers.nn import models

__all__ = ["MapEstimator"]

Expand Down Expand Up @@ -65,7 +64,7 @@ class MapEstimator:
def __init__(
self,
dim_data: int,
model: ModelBase,
model: models.ModelBase,
optimizer: Optional[optax.OptState] = None,
fitting_loss: Optional[Callable[[jnp.ndarray, jnp.ndarray],
float]] = None,
Expand Down Expand Up @@ -94,7 +93,7 @@ def __init__(
def setup(
self,
dim_data: int,
neural_net: ModelBase,
neural_net: models.ModelBase,
optimizer: optax.OptState,
):
"""Setup all components required to train the network."""
Expand Down