-
Notifications
You must be signed in to change notification settings - Fork 84
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Histogram Transport Implementation (#444)
* test commit * added HistogramTransport distance * added tests * removed `test_file` * renamed `ArbitraryTransportInitializer` to `FixedCouplingInitializer` * softness -> epsilon_1d * epsilon_1d=0.0 for hard sorting * epsilon_1d=0.0 for hard sorting * + cost_fn for 1d_wasserstein * extracted `wasserstein_1d` * removed `p` argument * removed `match` statement * removed `HTOutput` and `HTState` classes * updated `ht_test` * fixed `ht_test` * `wasserstein_1d` -> `univariate` * fixed indentation issues * removed `FixedCouplingInitializer` class * changed `QuadraticInitializer` documentation * added `solvers.univariate` to documentation * minor edits to `univariate.py` * fixed `UnivariateSolver` docstring * many updates to `univariate.py` * docstring edits to `histogram_transport` * added missing type of `univariate`'s `__call__` * added pytree class to HT and Univariate solvers * doc changes, code refactoring * added memoli citation * parametrized `ht_test` * fixed spelling * readded min/max iterations to `univariate` * fixed underline * fixed indentations * added `init_coupling` as a child * type ascription for `**kwargs` * fixed `warning::`, I think? * docstring edits of `univariate.py` * `self.cost_fn` to oneliner * fixed `univariate` children * fixing `.rst` stuff * editing `univariate.py` docs * slightly more documentation * fixed `ht_test` error * Use `sort_fn` * Fewer tests * Add shape checks * Add diff tests * Re-scale when subsampling * Update grad test * Rename solver * Fix indentation * Refer to the definition in the LowerBoundSolver --------- Co-authored-by: Michal Klein <[email protected]>
- Loading branch information
1 parent
6285372
commit 54d3b63
Showing
14 changed files
with
413 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
# Copyright OTT-JAX | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Callable, Literal, Optional | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
|
||
from ott.geometry import costs | ||
|
||
__all__ = ["UnivariateSolver"] | ||
|
||
|
||
@jax.tree_util.register_pytree_node_class
class UnivariateSolver:
  r"""1-D OT solver.

  .. warning::
    This solver assumes uniform marginals, a non-uniform marginal solver
    is coming soon.

  Computes the 1-Dimensional optimal transport distance between two histograms.

  Args:
    sort_fn: The sorting function. If :obj:`None`,
      use :func:`hard-sorting <jax.numpy.sort>`.
    cost_fn: The cost function for transport. If :obj:`None`, defaults to
      :class:`PNormP(2) <ott.geometry.costs.PNormP>`.
    method: The method used for computing the distance on the line. Options
      currently supported are:

      - `'subsample'` - Take a stratified sub-sample of the distances.
      - `'quantile'` - Take equally spaced quantiles of the distances.
      - `'equal'` - No subsampling is performed, requires distributions to have
        the same number of points.

    n_subsamples: The number of samples to draw for the "quantile" or
      "subsample" methods.
  """

  def __init__(
      self,
      sort_fn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None,
      cost_fn: Optional[costs.CostFn] = None,
      method: Literal["subsample", "quantile", "equal"] = "subsample",
      n_subsamples: int = 100,
  ):
    self.sort_fn = jnp.sort if sort_fn is None else sort_fn
    self.cost_fn = costs.PNormP(2) if cost_fn is None else cost_fn
    self.method = method
    self.n_subsamples = n_subsamples

  def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
    """Compute the univariate OT distance between `x` and `y`.

    Args:
      x: The first distribution of shape ``[n,]`` or ``[n, 1]``.
      y: The second distribution of shape ``[m,]`` or ``[m, 1]``.

    Returns:
      The OT distance, i.e. ``cost_fn.pairwise`` evaluated on the sorted
      (and possibly subsampled) points, multiplied by ``n / k`` where ``k`` is
      the number of points actually compared.

    Raises:
      AssertionError: If the inputs are not 1-dimensional after squeezing,
        if ``method='subsample'`` and ``n_subsamples`` exceeds ``n`` or ``m``,
        or if ``method='equal'`` and ``n != m``.
      NotImplementedError: If ``method`` is not one of the supported options.
    """
    x = x.squeeze(-1) if x.ndim == 2 else x
    y = y.squeeze(-1) if y.ndim == 2 else y
    assert x.ndim == 1, x.ndim
    assert y.ndim == 1, y.ndim

    n, m = x.shape[0], y.shape[0]

    if self.method == "equal":
      # Without this check a size-1 input would silently broadcast against
      # the other one instead of being rejected.
      assert n == m, (n, m)
      xx, yy = self.sort_fn(x), self.sort_fn(y)
    elif self.method == "subsample":
      assert self.n_subsamples <= n, (self.n_subsamples, x)
      assert self.n_subsamples <= m, (self.n_subsamples, y)

      sorted_x, sorted_y = self.sort_fn(x), self.sort_fn(y)
      xx = sorted_x[self._subsample_indices(n)]
      yy = sorted_y[self._subsample_indices(m)]
    elif self.method == "quantile":
      sorted_x, sorted_y = self.sort_fn(x), self.sort_fn(y)
      xx = jnp.quantile(sorted_x, q=jnp.linspace(0, 1, self.n_subsamples))
      yy = jnp.quantile(sorted_y, q=jnp.linspace(0, 1, self.n_subsamples))
    else:
      raise NotImplementedError(f"Method `{self.method}` not implemented.")

    # re-scale when subsampling
    return self.cost_fn.pairwise(xx, yy) * (n / xx.shape[0])

  def _subsample_indices(self, size: int) -> jnp.ndarray:
    """Return ``n_subsamples`` valid positions into a sorted array of `size`."""
    ixs = jnp.linspace(0, size, num=self.n_subsamples).astype(int)
    # The grid ends at ``size``, one past the last valid position. Clamp it
    # explicitly rather than relying on implicit out-of-bounds handling of the
    # array type returned by ``sort_fn``.
    return jnp.minimum(ixs, size - 1)

  def tree_flatten(self):  # noqa: D102
    # The whole configuration is carried as auxiliary data; there are no
    # children.
    aux_data = vars(self).copy()
    return [], aux_data

  @classmethod
  def tree_unflatten(cls, aux_data, children):  # noqa: D102
    return cls(*children, **aux_data)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
# Copyright OTT-JAX | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Any, Optional | ||
|
||
import jax | ||
|
||
from ott.geometry import geometry | ||
from ott.problems.quadratic import quadratic_problem | ||
from ott.solvers import linear | ||
from ott.solvers.linear import sinkhorn, univariate | ||
|
||
__all__ = ["LowerBoundSolver"] | ||
|
||
|
||
@jax.tree_util.register_pytree_node_class
class LowerBoundSolver:
  """Lower bound OT solver :cite:`memoli:11`.

  .. warning::
    As implemented, this solver assumes uniform marginals,
    non-uniform marginal solver coming soon!

  Computes the first lower bound distance from :cite:`memoli:11`, def. 6.1.
  If there is an uneven number of points in the distributions, then we perform
  a stratified subsample of the distribution of distances to approximate
  the Wasserstein distance between the local distributions of distances.

  Args:
    epsilon: Entropy regularization for the resulting linear problem.
    kwargs: Keyword arguments for
      :class:`~ott.solvers.linear.univariate.UnivariateSolver`.
  """

  def __init__(
      self,
      epsilon: Optional[float] = None,
      **kwargs: Any,
  ):
    self.epsilon = epsilon
    self.univariate_solver = univariate.UnivariateSolver(**kwargs)

  def __call__(
      self,
      prob: quadratic_problem.QuadraticProblem,
      **kwargs: Any,
  ) -> sinkhorn.SinkhornOutput:
    """Run the lower bound solver.

    The univariate solver is applied to every pair made of one row of
    ``prob.geom_xx.cost_matrix`` and one row of ``prob.geom_yy.cost_matrix``;
    the resulting values form the cost matrix of a linear problem, which is
    then solved.

    Args:
      prob: Quadratic OT problem.
      kwargs: Keyword arguments for :func:`~ott.solvers.linear.solve`.

    Returns:
      The output of the linear solver run on the lower-bound cost matrix.
    """
    # Inner map: every row of the first matrix against one fixed row of the
    # second one.
    rows_vs_row = jax.vmap(
        self.univariate_solver, in_axes=(0, None), out_axes=-1
    )
    # Outer map: repeat for every row of the second matrix.
    rows_vs_rows = jax.vmap(rows_vs_row, in_axes=(None, 0), out_axes=-1)
    cost_xy = rows_vs_rows(prob.geom_xx.cost_matrix, prob.geom_yy.cost_matrix)

    return linear.solve(
        geometry.Geometry(cost_matrix=cost_xy, epsilon=self.epsilon),
        **kwargs,
    )

  def tree_flatten(self):  # noqa: D102
    children = [self.epsilon, self.univariate_solver]
    return children, {}

  @classmethod
  def tree_unflatten(cls, aux_data, children):  # noqa: D102
    eps, uni_solver = children
    solver = cls(eps, **aux_data)
    # Swap in the univariate solver that was flattened alongside epsilon.
    solver.univariate_solver = uni_solver
    return solver
Oops, something went wrong.