Histogram Transport Implementation (#444)
* test commit

* added HistogramTransport distance

* added tests

* removed `test_file`

* renamed `ArbitraryTransportInitializer`
to `FixedCouplingInitializer`

* softness -> epsilon_1d

* epsilon_1d=0.0 for hard sorting

* epsilon_1d=0.0 for hard sorting

* + cost_fn for 1d_wasserstein

* extracted `wasserstein_1d`

* removed `p` argument

* removed `match` statement

* removed `HTOutput` and `HTState` classes

* updated `ht_test`

* fixed `ht_test`

* `wasserstein_1d` -> `univariate`

* fixed indentation issues

* removed `FixedCouplingInitializer` class

* changed `QuadraticInitializer` documentation

* added `solvers.univariate` to documentation

* minor edits to `univariate.py`

* fixed `UnivariateSolver` docstring

* many updates to `univariate.py`

* docstring edits to `histogram_transport`

* added missing type of `univariate`'s `__call__`

* added pytree class to HT and Univariate solvers

* doc changes, code refactoring

* added memoli citation

* parametrized `ht_test`

* fixed spelling

* readded min/max iterations to `univariate`

* fixed  underline

* fixed indentations

* added `init_coupling` as a child

* type ascription for `**kwargs`

* fixed `warning::`, I think?

* docstring edits of `univariate.py`

* `self.cost_fn` to oneliner

* fixed `univariate` children

* fixing `.rst` stuff

* editing `univariate.py` docs

* slightly more documentation

* fixed `ht_test` error

* Use `sort_fn`

* Fewer tests

* Add shape checks

* Add diff tests

* Re-scale when subsampling

* Update grad test

* Rename solver

* Fix indentation

* Refer to the definition in the LowerBoundSolver

---------

Co-authored-by: Michal Klein <[email protected]>
Daniel-Packer and michalk8 authored Oct 27, 2023
1 parent 6285372 commit 54d3b63
Showing 14 changed files with 413 additions and 23 deletions.
2 changes: 1 addition & 1 deletion docs/problems/index.rst
@@ -3,7 +3,7 @@ ott.problems
 .. module:: ott.problems

 The :mod:`ott.problems` module describes the low level optimal transport
-problems that are solved by ``OTT`` :mod:`ott.solvers`. These problems are
+problems that are solved by :mod:`ott.solvers`. These problems are
 loosely divided into two categories, first finite-sample based problems, as in
 :mod:`ott.problems.linear` and :mod:`ott.problems.quadratic`, or relying on
 iterators. In that latter category, :mod:`ott.problems.nn` contains synthetic
2 changes: 1 addition & 1 deletion docs/solvers/index.rst
@@ -8,7 +8,7 @@ linear solvers in :mod:`ott.solvers.linear`, designed to solve linear OT
 problems. More advanced solvers, notably quadratic in
 :mod:`ott.solvers.quadratic`, rely on calls to linear solvers as subroutines.
 That property itself is implemented in the more abstract
-:class:`ott.solvers.was_solver.WassersteinSolver` class, which provides a
+:class:`~ott.solvers.was_solver.WassersteinSolver` class, which provides a
 lower-level template at the interface between the two. Neural based solvers in
 :mod:`ott.solvers.nn` live on a different category of their own, since they
 typically solve the Monge formulation of OT.
11 changes: 9 additions & 2 deletions docs/solvers/linear.rst
@@ -5,8 +5,8 @@ ott.solvers.linear

 Linear solvers are the bread-and-butter of OT solvers. They can be called on
 their own, either the Sinkhorn
-:class:`ott.solvers.linear.sinkhorn.Sinkhorn` or Low-Rank
-:class:`ott.solvers.linear.sinkhorn_lr.LRSinkhorn` solvers, to match two
+:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` or Low-Rank
+:class:`~ott.solvers.linear.sinkhorn_lr.LRSinkhorn` solvers, to match two
 datasets. They also appear as subroutines for more advanced solvers in the
 :mod:`ott.solvers` module, notably :mod:`ott.solvers.quadratic` or
 :mod:`ott.solvers.nn`.
@@ -34,6 +34,13 @@ Barycenter Solvers
     discrete_barycenter.FixedBarycenter
     discrete_barycenter.SinkhornBarycenterOutput

+Other Solvers
+-------------
+.. autosummary::
+    :toctree: _autosummary
+
+    univariate.UnivariateSolver
+
 Sinkhorn Acceleration
 ---------------------
 .. autosummary::
1 change: 1 addition & 0 deletions docs/solvers/quadratic.rst
@@ -17,6 +17,7 @@ Gromov-Wasserstein Solvers
     gromov_wasserstein.GWOutput
     gromov_wasserstein_lr.LRGromovWasserstein
     gromov_wasserstein_lr.LRGWOutput
+    lower_bound.LowerBoundSolver


 Barycenter Solvers
1 change: 1 addition & 0 deletions docs/spelling/technical.txt
@@ -134,6 +134,7 @@ subpopulation
 subpopulations
 subsample
 subsampled
+subsampling
 thresholding
 transcriptome
 undirected
39 changes: 30 additions & 9 deletions src/ott/initializers/quadratic/initializers.py
@@ -53,14 +53,16 @@ def __call__(

     n, m = quad_prob.geom_xx.shape[0], quad_prob.geom_yy.shape[0]
     geom = self._create_geometry(quad_prob, **kwargs)
-    assert geom.shape == (n, m), f"Expected geometry of shape `{n, m}`, " \
-                                 f"found `{geom.shape}`."
+    assert geom.shape == (n, m), (
+        f"Expected geometry of shape `{n, m}`, "
+        f"found `{geom.shape}`."
+    )
     return linear_problem.LinearProblem(
         geom,
         a=quad_prob.a,
         b=quad_prob.b,
         tau_a=quad_prob.tau_a,
-        tau_b=quad_prob.tau_b
+        tau_b=quad_prob.tau_b,
     )

   @abc.abstractmethod
@@ -88,7 +90,7 @@ def tree_unflatten(  # noqa: D102


 class QuadraticInitializer(BaseQuadraticInitializer):
-  r"""Initialize a linear problem locally around :math:`ab^T` initializer.
+  r"""Initialize a linear problem locally around a selected coupling.

   If the problem is balanced (``tau_a = 1`` and ``tau_b = 1``),
   the equation of the cost follows eq. 6, p. 1 of :cite:`peyre:16`.
@@ -113,19 +115,29 @@ class QuadraticInitializer(BaseQuadraticInitializer):
   .. math::

     \text{marginal_dep_term} + \text{left}_x(\text{cost_xx}) P
-    \text{right}_y(\text{cost_yy}) + \text{unbalanced_correction}
+    \text{right}_y(\text{cost_yy}) + \text{unbalanced_correction}

   When working with the fused problem, a linear term is added to the cost
   matrix: `cost_matrix` += `fused_penalty` * `geom_xy.cost_matrix`
+
+  Args:
+    init_coupling: The coupling to use for initialization. If :obj:`None`,
+      defaults to the product coupling :math:`ab^T`.
   """

+  def __init__(
+      self, init_coupling: Optional[jnp.ndarray] = None, **kwargs: Any
+  ):
+    super().__init__(**kwargs)
+    self.init_coupling = init_coupling
+
   def _create_geometry(
       self,
       quad_prob: "quadratic_problem.QuadraticProblem",
       *,
       epsilon: float,
       relative_epsilon: Optional[bool] = None,
-      **kwargs: Any
+      **kwargs: Any,
   ) -> geometry.Geometry:
     """Compute initial geometry for linearization.

@@ -139,15 +151,21 @@ def _create_geometry(
       The initial geometry used to initialize the linearized problem.
     """
     from ott.problems.quadratic import quadratic_problem
+
     del kwargs

     marginal_cost = quad_prob.marginal_dependent_cost(quad_prob.a, quad_prob.b)
     geom_xx, geom_yy = quad_prob.geom_xx, quad_prob.geom_yy

     h1, h2 = quad_prob.quad_loss
-    tmp1 = quadratic_problem.apply_cost(geom_xx, quad_prob.a, axis=1, fn=h1)
-    tmp2 = quadratic_problem.apply_cost(geom_yy, quad_prob.b, axis=1, fn=h2)
-    tmp = jnp.outer(tmp1, tmp2)
+    if self.init_coupling is None:
+      tmp1 = quadratic_problem.apply_cost(geom_xx, quad_prob.a, axis=1, fn=h1)
+      tmp2 = quadratic_problem.apply_cost(geom_yy, quad_prob.b, axis=1, fn=h2)
+      tmp = jnp.outer(tmp1, tmp2)
+    else:
+      tmp1 = h1.func(geom_xx.cost_matrix)
+      tmp2 = h2.func(geom_yy.cost_matrix)
+      tmp = tmp1 @ self.init_coupling @ tmp2.T

     if quad_prob.is_balanced:
       cost_matrix = marginal_cost.cost_matrix - tmp
@@ -170,3 +188,6 @@ def _create_geometry(
         epsilon=epsilon,
         relative_epsilon=relative_epsilon
     )
+
+  def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:  # noqa: D102
+    return [self.init_coupling], self._kwargs
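
Note on the `init_coupling` branch above: when the coupling is the product coupling `ab^T`, the new expression `tmp1 @ init_coupling @ tmp2.T` should reduce to the outer product used in the default branch. The following is a minimal NumPy sketch of that identity, not OTT code. It assumes that `apply_cost(..., axis=1)` amounts to a matrix-vector product with the transformed cost matrix; the random matrices stand in for `h1(cost_xx)` and `h2(cost_yy)`, and all variable names are placeholders.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 4, 3

# Placeholder matrices standing in for h1(cost_xx) and h2(cost_yy).
cost_xx = rng.random((n, n))
cost_yy = rng.random((m, m))

# Uniform marginals, chosen only for this illustration.
a = np.full(n, 1.0 / n)
b = np.full(m, 1.0 / m)

# Default branch: outer product of the two marginal-weighted row sums.
default_term = np.outer(cost_xx @ a, cost_yy @ b)

# Explicit-coupling branch, evaluated at the product coupling a b^T.
product_coupling = np.outer(a, b)
coupling_term = cost_xx @ product_coupling @ cost_yy.T

# Both branches give the same (n, m) matrix for this coupling.
assert np.allclose(default_term, coupling_term)
```

For any other coupling the two terms generally differ, which is the point of accepting `init_coupling` as an argument.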
4 changes: 3 additions & 1 deletion src/ott/solvers/linear/__init__.py
@@ -19,10 +19,12 @@
     lr_utils,
     sinkhorn,
     sinkhorn_lr,
+    univariate,
 )
 from ._solve import solve

 __all__ = [
     "acceleration", "continuous_barycenter", "discrete_barycenter",
-    "implicit_differentiation", "lr_utils", "sinkhorn", "sinkhorn_lr", "solve"
+    "implicit_differentiation", "lr_utils", "sinkhorn", "sinkhorn_lr", "solve",
+    "univariate"
 ]
2 changes: 1 addition & 1 deletion src/ott/solvers/linear/implicit_differentiation.py
@@ -27,7 +27,7 @@
 Solver_t = Callable[[LinOp_t, jnp.ndarray, Optional[LinOp_t], bool],
                     jnp.ndarray]

-__all__ = ["ImplicitDiff"]
+__all__ = ["ImplicitDiff", "solve_jax_cg"]


 @utils.register_pytree_node
2 changes: 1 addition & 1 deletion src/ott/solvers/linear/lineax_implicit.py
@@ -23,7 +23,7 @@
 _T = TypeVar("_T")
 _FlatPyTree = tuple[list[_T], jtu.PyTreeDef]

-__all__ = ["CustomTransposeLinearOperator"]
+__all__ = ["CustomTransposeLinearOperator", "solve_lineax"]


 class CustomTransposeLinearOperator(lx.FunctionLinearOperator):
106 changes: 106 additions & 0 deletions src/ott/solvers/linear/univariate.py
@@ -0,0 +1,106 @@
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Literal, Optional

import jax
import jax.numpy as jnp

from ott.geometry import costs

__all__ = ["UnivariateSolver"]


@jax.tree_util.register_pytree_node_class
class UnivariateSolver:
  r"""1-D OT solver.

  .. warning::
    This solver assumes uniform marginals, a non-uniform marginal solver
    is coming soon.

  Computes the 1-Dimensional optimal transport distance between two histograms.

  Args:
    sort_fn: The sorting function. If :obj:`None`,
      use :func:`hard-sorting <jax.numpy.sort>`.
    cost_fn: The cost function for transport. If :obj:`None`, defaults to
      :class:`PNormP(2) <ott.geometry.costs.PNormP>`.
    method: The method used for computing the distance on the line. Options
      currently supported are:

      - `'subsample'` - Take a stratified sub-sample of the distances.
      - `'quantile'` - Take equally spaced quantiles of the distances.
      - `'equal'` - No subsampling is performed, requires distributions to have
        the same number of points.

    n_subsamples: The number of samples to draw for the "quantile" or
      "subsample" methods.
  """

  def __init__(
      self,
      sort_fn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None,
      cost_fn: Optional[costs.CostFn] = None,
      method: Literal["subsample", "quantile", "equal"] = "subsample",
      n_subsamples: int = 100,
  ):
    self.sort_fn = jnp.sort if sort_fn is None else sort_fn
    self.cost_fn = costs.PNormP(2) if cost_fn is None else cost_fn
    self.method = method
    self.n_subsamples = n_subsamples

  def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
    """Computes the Univariate OT Distance between `x` and `y`.

    Args:
      x: The first distribution of shape ``[n,]`` or ``[n, 1]``.
      y: The second distribution of shape ``[m,]`` or ``[m, 1]``.

    Returns:
      The OT distance.
    """
    x = x.squeeze(-1) if x.ndim == 2 else x
    y = y.squeeze(-1) if y.ndim == 2 else y
    assert x.ndim == 1, x.ndim
    assert y.ndim == 1, y.ndim

    n, m = x.shape[0], y.shape[0]

    if self.method == "equal":
      xx, yy = self.sort_fn(x), self.sort_fn(y)
    elif self.method == "subsample":
      assert self.n_subsamples <= n, (self.n_subsamples, x)
      assert self.n_subsamples <= m, (self.n_subsamples, y)

      sorted_x, sorted_y = self.sort_fn(x), self.sort_fn(y)
      xx = sorted_x[jnp.linspace(0, n, num=self.n_subsamples).astype(int)]
      yy = sorted_y[jnp.linspace(0, m, num=self.n_subsamples).astype(int)]
    elif self.method == "quantile":
      sorted_x, sorted_y = self.sort_fn(x), self.sort_fn(y)
      xx = jnp.quantile(sorted_x, q=jnp.linspace(0, 1, self.n_subsamples))
      yy = jnp.quantile(sorted_y, q=jnp.linspace(0, 1, self.n_subsamples))
    else:
      raise NotImplementedError(f"Method `{self.method}` not implemented.")

    # re-scale when subsampling
    return self.cost_fn.pairwise(xx, yy) * (n / xx.shape[0])

  def tree_flatten(self):  # noqa: D102
    aux_data = vars(self).copy()
    return [], aux_data

  @classmethod
  def tree_unflatten(cls, aux_data, children):  # noqa: D102
    return cls(*children, **aux_data)
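
Note on `univariate.py`: for uniform marginals and equally sized samples, transport on the line with a convex cost of the difference pairs sorted values in order, which is what the `'equal'` method relies on. Below is a minimal NumPy sketch under stated assumptions, not the OTT implementation. The names `univariate_cost_equal` and `stratified_subsample` are hypothetical, the cost is the plain sum of `|x_(i) - y_(i)|^p`, and the re-scaling by `n / n_subsamples` from the diff is left out.

```python
import numpy as np


def univariate_cost_equal(x, y, p=2.0):
  """Sum of |x_(i) - y_(i)|^p over sorted, equally sized samples."""
  x, y = np.sort(np.ravel(x)), np.sort(np.ravel(y))
  assert x.shape == y.shape, (x.shape, y.shape)
  return float(np.sum(np.abs(x - y) ** p))


def stratified_subsample(x, k):
  """Pick k order statistics of x at evenly spaced ranks."""
  sorted_x = np.sort(np.ravel(x))
  # Ranks stop at n - 1 here so that every index is valid in NumPy.
  idx = np.linspace(0, sorted_x.shape[0] - 1, num=k).astype(int)
  return sorted_x[idx]


x = np.array([3.0, 1.0, 2.0])
y = np.array([2.5, 0.5, 4.0])
# Sorted pairs are (1, 0.5), (2, 2.5), (3, 4).
cost = univariate_cost_equal(x, y)
```

When the two samples differ in size, one option in the spirit of the `'subsample'` method is to pass both through `stratified_subsample` with the same `k` before calling `univariate_cost_equal`.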
7 changes: 6 additions & 1 deletion src/ott/solvers/quadratic/__init__.py
@@ -11,5 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from . import gromov_wasserstein, gromov_wasserstein_lr, gw_barycenter
+from . import (
+    gromov_wasserstein,
+    gromov_wasserstein_lr,
+    gw_barycenter,
+    lower_bound,
+)
 from ._solve import solve
88 changes: 88 additions & 0 deletions src/ott/solvers/quadratic/lower_bound.py
@@ -0,0 +1,88 @@
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional

import jax

from ott.geometry import geometry
from ott.problems.quadratic import quadratic_problem
from ott.solvers import linear
from ott.solvers.linear import sinkhorn, univariate

__all__ = ["LowerBoundSolver"]


@jax.tree_util.register_pytree_node_class
class LowerBoundSolver:
  """Lower bound OT solver :cite:`memoli:11`.

  .. warning::
    As implemented, this solver assumes uniform marginals,
    non-uniform marginal solver coming soon!

  Computes the first lower bound distance from :cite:`memoli:11`, def. 6.1.
  If there is an uneven number of points in the distributions, then we perform
  a stratified subsample of the distribution of distances to approximate
  the Wasserstein distance between the local distributions of distances.

  Args:
    epsilon: Entropy regularization for the resulting linear problem.
    kwargs: Keyword arguments for
      :class:`~ott.solvers.linear.univariate.UnivariateSolver`.
  """

  def __init__(
      self,
      epsilon: Optional[float] = None,
      **kwargs: Any,
  ):
    self.epsilon = epsilon
    self.univariate_solver = univariate.UnivariateSolver(**kwargs)

  def __call__(
      self,
      prob: quadratic_problem.QuadraticProblem,
      **kwargs: Any,
  ) -> sinkhorn.SinkhornOutput:
    """Run the Histogram transport solver.

    Args:
      prob: Quadratic OT problem.
      kwargs: Keyword arguments for :func:`~ott.solvers.linear.solve`.

    Returns:
      The Histogram transport output.
    """
    dists_xx = prob.geom_xx.cost_matrix
    dists_yy = prob.geom_yy.cost_matrix
    cost_xy = jax.vmap(
        jax.vmap(self.univariate_solver, in_axes=(0, None), out_axes=-1),
        in_axes=(None, 0),
        out_axes=-1,
    )(dists_xx, dists_yy)

    geom_xy = geometry.Geometry(cost_matrix=cost_xy, epsilon=self.epsilon)

    return linear.solve(geom_xy, **kwargs)

  def tree_flatten(self):  # noqa: D102
    return [self.epsilon, self.univariate_solver], {}

  @classmethod
  def tree_unflatten(cls, aux_data, children):  # noqa: D102
    epsilon, solver = children
    obj = cls(epsilon, **aux_data)
    obj.univariate_solver = solver
    return obj
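
Note on `lower_bound.py`: as read from the `in_axes` and `out_axes` arguments, the nested `jax.vmap` call builds a matrix whose entry `[i, j]` is the univariate solver applied to row `i` of `dists_xx` and row `j` of `dists_yy`. The sketch below spells out that structure with explicit loops in NumPy. It is an illustration under assumptions, not the OTT code path: `sorted_row_cost` and `lower_bound_cost_matrix` are hypothetical names, the per-row cost is a plain sum of squared differences of sorted rows, and both inputs have the same number of points.

```python
import numpy as np


def sorted_row_cost(row_x, row_y):
  """Squared distance between two sorted distance profiles of equal length."""
  return float(np.sum((np.sort(row_x) - np.sort(row_y)) ** 2))


def lower_bound_cost_matrix(dists_xx, dists_yy):
  """Entry [i, j] compares row i of dists_xx with row j of dists_yy."""
  n, m = dists_xx.shape[0], dists_yy.shape[0]
  cost = np.empty((n, m))
  for i in range(n):
    for j in range(m):
      cost[i, j] = sorted_row_cost(dists_xx[i], dists_yy[j])
  return cost


# Two tiny point clouds on the line; y is x shifted, so distances coincide.
x = np.array([0.0, 1.0, 3.0])
y = x + 5.0
dists_xx = np.abs(x[:, None] - x[None, :])
dists_yy = np.abs(y[:, None] - y[None, :])
cost_xy = lower_bound_cost_matrix(dists_xx, dists_yy)
```

Because the two clouds differ only by a translation, the diagonal of `cost_xy` is zero, so a linear OT solver run on this cost matrix can match each point to its shifted copy at no cost.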