Skip to content

Commit

Permalink
Start adding neural part
Browse files Browse the repository at this point in the history
  • Loading branch information
michalk8 committed Nov 28, 2023
1 parent 564cfc7 commit 1abf43a
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 10 deletions.
3 changes: 2 additions & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Install ``OTT`` from `PyPI <https://pypi.org/project/ott-jax/>`_ as:
pip install ott-jax
or with the :mod:`neural OT solvers <ott.solvers.nn>` dependencies:
or with the :mod:`neural OT <ott.neural>` dependencies:

.. code-block:: bash
Expand Down Expand Up @@ -114,6 +114,7 @@ Packages
problems/index
solvers/index
initializers/index
neural/index
tools
math
utils
Expand Down
4 changes: 4 additions & 0 deletions docs/neural/index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
ott.neural
==========
.. module:: ott.neural
.. currentmodule:: ott.neural
3 changes: 1 addition & 2 deletions docs/solvers/linear.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ their own, either the Sinkhorn
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` or Low-Rank
:class:`~ott.solvers.linear.sinkhorn_lr.LRSinkhorn` solvers, to match two
datasets. They also appear as subroutines for more advanced solvers in the
:mod:`ott.solvers` module, notably :mod:`ott.solvers.quadratic` or
:mod:`ott.solvers.nn`.
:mod:`ott.solvers` module, notably :mod:`ott.solvers.quadratic`.

Sinkhorn Solvers
----------------
Expand Down
2 changes: 1 addition & 1 deletion src/ott/neural/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
from typing import Any, Callable, Dict, Optional, Sequence, Tuple

import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
from flax.core import frozen_dict
from flax.training import train_state
from jax import numpy as jnp
from jax.nn import initializers

from ott import utils
Expand Down
12 changes: 6 additions & 6 deletions src/ott/solvers/quadratic/lower_bound.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Any, Optional

import jax
import jax.tree_util as jtu

from ott.geometry import pointcloud
from ott.problems.quadratic import quadratic_problem
Expand All @@ -27,7 +27,7 @@
__all__ = ["LowerBoundSolver"]


@jax.tree_util.register_pytree_node_class
@jtu.register_pytree_node_class
class LowerBoundSolver:
"""Lower bound OT solver.
Expand Down Expand Up @@ -62,10 +62,10 @@ def __call__(
) -> sinkhorn.SinkhornOutput:
"""Compute a lower-bound for the GW problem using a simple linearization.
This solver handles a quadratic problem by computing first a proxy ``[n,m]``
cost-matrix, inject it into a linear OT solver, to output a first OT matrix
that can be used either to linearize/initialize the resolution of the GW
problem, or more simply as a simple GW solution.
This solver handles a quadratic problem by computing a proxy ``[n, m]``
cost-matrix, injecting it into a linear OT solver to output a first an OT
matrix that can be used either to linearize/initialize the resolution
ot the GW problem, or more simply as a simple GW solution.
Args:
prob: Quadratic OT problem.
Expand Down

0 comments on commit 1abf43a

Please sign in to comment.