Skip to content

Commit

Permalink
Feature/ulrgw (#410)
Browse files Browse the repository at this point in the history
* Remove low-rank from GromovWasserstein solver

* First skeleton loop

* Add LRGW implementation

* Add ULFGW

* Revert change

* Add a TODO

* Fix `grad_g` in the fused case

* Update docs

* Remove duplicate citation

* Fix cost for the fused case

* Fix bugs in TI

* Remove unused import

* Change the way array extraction in LR init works

* Disallow LR in the old GW solver

* Disallow LR in old GW class

* Remove `is_entropic` property

* Use `jnp.linalg.norm`

* Simplify initializers in GW

* Simplify initializer creation for low-rank

* Remove temporary name

* Fix norms

* Fix linkcheck

* Remove old initializers test

* Fix more initializer tests

* Remove `LRQuadraticInitializer`, `reg_ot_cost -> reg_gw_cost`

* `host_callback` -> `io_callback`

* Fix more initializers tests

* Fix more tests

* Remove initializer mention from the docs

* Remove mention of LR initializer

* Start incorporating GWLoss

* Simplify reg GW cost computation

* Finish `primal_cost`

* Don't calculate unbal. grads in balanced case

* Fix `primal_cost` in balanced case

* Update GW LR notebook

* Convert quad problem to LR if possible

* Convert quad problem to LR if possible

* Regenerate GWLR Sinkhorn

* Regenerate `LRSinkhorn`

* [ci skip] Fix linter

* Fix convergence metric

* Undo TODO

* Fix factor

* Regenerate notebooks

* Add tests
  • Loading branch information
michalk8 authored Sep 8, 2023
1 parent 21d3627 commit a18c16c
Show file tree
Hide file tree
Showing 16 changed files with 1,143 additions and 394 deletions.
4 changes: 1 addition & 3 deletions docs/initializers/quadratic.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,11 @@ ott.initializers.quadratic

Two families of initializers are described in the following to provide the first
iteration of Gromov-Wasserstein solvers. They apply respectively to the simpler
GW entropic solver :cite:`peyre:16` and its low-rank formulation
:cite:`scetbon:22`.
GW entropic solver :cite:`peyre:16`.

Gromov-Wasserstein Initializers
-------------------------------
.. autosummary::
:toctree: _autosummary

initializers.QuadraticInitializer
initializers.LRQuadraticInitializer
3 changes: 3 additions & 0 deletions docs/solvers/quadratic.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ Gromov-Wasserstein Solvers
gromov_wasserstein.solve
gromov_wasserstein.GromovWasserstein
gromov_wasserstein.GWOutput
gromov_wasserstein_lr.LRGromovWasserstein
gromov_wasserstein_lr.LRGWOutput


Barycenter Solvers
------------------
Expand Down
52 changes: 16 additions & 36 deletions docs/tutorials/notebooks/GWLRSinkhorn.ipynb

Large diffs are not rendered by default.

29 changes: 13 additions & 16 deletions docs/tutorials/notebooks/LRSinkhorn.ipynb

Large diffs are not rendered by default.

34 changes: 14 additions & 20 deletions src/ott/initializers/linear/initializers_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from ott.problems.linear import linear_problem
from ott.problems.quadratic import quadratic_problem
from ott.solvers.linear import sinkhorn, sinkhorn_lr
from ott.solvers.quadratic import gromov_wasserstein
from ott.solvers.quadratic import gromov_wasserstein_lr

Problem_t = Union["linear_problem.LinearProblem",
"quadratic_problem.QuadraticProblem"]
Expand Down Expand Up @@ -127,7 +127,7 @@ def init_g(
def from_solver(
cls,
solver: Union["sinkhorn_lr.LRSinkhorn",
"gromov_wasserstein.GromovWasserstein"],
"gromov_wasserstein_lr.LRGromovWasserstein"],
*,
kind: Literal["random", "rank2", "k-means", "generalized-k-means"],
**kwargs: Any,
Expand All @@ -140,22 +140,14 @@ def from_solver(
kwargs: Keyword arguments when creating the initializer.
Returns:
The low-rank initializer.
Low-rank initializer.
"""
from ott.solvers.quadratic import gromov_wasserstein

if isinstance(solver, gromov_wasserstein.GromovWasserstein):
assert solver.is_low_rank, "GW solver is not low-rank."
lin_sol = solver.linear_ot_solver
else:
lin_sol = solver

rank = solver.rank
sinkhorn_kwargs = {
"norm_error": lin_sol._norm_error,
"lse_mode": lin_sol.lse_mode,
"implicit_diff": lin_sol.implicit_diff,
"use_danskin": lin_sol.use_danskin
"norm_error": solver._norm_error,
"lse_mode": solver.lse_mode,
"implicit_diff": solver.implicit_diff,
"use_danskin": solver.use_danskin
}

if kind == "random":
Expand Down Expand Up @@ -373,9 +365,7 @@ def __init__(
self._sinkhorn_kwargs = {} if sinkhorn_kwargs is None else sinkhorn_kwargs

@staticmethod
def _extract_array(
geom: Union[pointcloud.PointCloud, low_rank.LRCGeometry], *, first: bool
) -> jnp.ndarray:
def _extract_array(geom: geometry.Geometry, *, first: bool) -> jnp.ndarray:
if isinstance(geom, pointcloud.PointCloud):
return geom.x if first else geom.y
if isinstance(geom, low_rank.LRCGeometry):
Expand Down Expand Up @@ -407,15 +397,19 @@ def _compute_factor(
)

if isinstance(ot_prob, quadratic_problem.QuadraticProblem):
geom = ot_prob.geom_xx if which == "q" else ot_prob.geom_yy
if ot_prob.geom_xy is not None and ot_prob.fused_penalty >= 1.0:
# prefer the linear term if it has a higher weight
geom = ot_prob.geom_xy
else:
geom = ot_prob.geom_xx if which == "q" else ot_prob.geom_yy
else:
geom = ot_prob.geom
arr = self._extract_array(geom, first=which == "q")
marginals = ot_prob.a if which == "q" else ot_prob.b

centroids = fn(arr, self.rank, rng=rng).centroids
geom = pointcloud.PointCloud(
arr, centroids, epsilon=0.1, scale_cost="max_cost"
arr, centroids, epsilon=1e-1, scale_cost="max_cost"
)

prob = linear_problem.LinearProblem(geom, marginals, init_g)
Expand Down
56 changes: 1 addition & 55 deletions src/ott/initializers/quadratic/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@
from ott.geometry import geometry

if TYPE_CHECKING:
from ott.initializers.linear import initializers_lr
from ott.problems.linear import linear_problem
from ott.problems.quadratic import quadratic_problem

__all__ = ["QuadraticInitializer", "LRQuadraticInitializer"]
__all__ = ["BaseQuadraticInitializer", "QuadraticInitializer"]


@jax.tree_util.register_pytree_node_class
Expand Down Expand Up @@ -171,56 +170,3 @@ def _create_geometry(
epsilon=epsilon,
relative_epsilon=relative_epsilon
)


class LRQuadraticInitializer(BaseQuadraticInitializer):
"""Wrapper that wraps low-rank Sinkhorn initializers.
Args:
lr_linear_initializer: Low-rank linear initializer.
"""

def __init__(self, lr_linear_initializer: "initializers_lr.LRInitializer"):
super().__init__()
self._linear_lr_initializer = lr_linear_initializer

def _create_geometry(
self,
quad_prob: "quadratic_problem.QuadraticProblem",
relative_epsilon: Optional[bool] = False,
**kwargs: Any
) -> geometry.Geometry:
"""Compute initial geometry for linearization.
Args:
quad_prob: Quadratic OT problem.
relative_epsilon: Whether to use relative epsilon in the geometry.
kwargs: Keyword arguments for
:meth:`~ott.initializers.linear.initializers_lr.LRInitializer.__call__`.
Returns:
The initial geometry used to initialize a linear problem.
"""
from ott.solvers.linear import sinkhorn_lr

q, r, g = self._linear_lr_initializer(quad_prob, **kwargs)
tmp_out = sinkhorn_lr.LRSinkhornOutput(
q=q,
r=r,
g=g,
costs=None,
errors=None,
ot_prob=None,
epsilon=None,
)

return quad_prob.update_lr_geom(tmp_out, relative_epsilon=relative_epsilon)

@property
def rank(self) -> int:
"""Rank of the transport matrix factorization."""
return self._linear_lr_initializer.rank

def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102
children, aux_data = super().tree_flatten()
return children + [self._linear_lr_initializer], aux_data
6 changes: 3 additions & 3 deletions src/ott/math/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"norm",
"kl",
"gen_kl",
"js",
"gen_js",
"logsumexp",
"softmin",
"barycentric_projection",
Expand Down Expand Up @@ -116,9 +116,9 @@ def gen_kl(p: jnp.ndarray, q: jnp.ndarray) -> float:


# TODO(michalk8): add axis argument
def js(p: jnp.ndarray, q: jnp.ndarray, c: float = 0.5) -> float:
def gen_js(p: jnp.ndarray, q: jnp.ndarray, c: float = 0.5) -> float:
"""Jensen-Shannon divergence."""
return c * (kl(p, q) + kl(q, p))
return c * (gen_kl(p, q) + gen_kl(q, p))


@functools.partial(jax.custom_jvp, nondiff_argnums=(1, 2, 4))
Expand Down
70 changes: 37 additions & 33 deletions src/ott/solvers/linear/lr_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import jax.scipy as jsp

from ott.math import fixed_point_loop
from ott.math import unbalanced_functions as uf
from ott.problems.linear import linear_problem

__all__ = ["unbalanced_dykstra_lse", "unbalanced_dykstra_kernel"]
Expand All @@ -36,8 +35,8 @@ class State(NamedTuple): # noqa: D101
class Constants(NamedTuple): # noqa: D101
a: jnp.ndarray
b: jnp.ndarray
tau_a: float
tau_b: float
rho_a: float
rho_b: float
supp_a: Optional[jnp.ndarray] = None
supp_b: Optional[jnp.ndarray] = None

Expand Down Expand Up @@ -105,16 +104,17 @@ def body_fn(
iteration: int, const: Constants, state: State, compute_error: bool
) -> State:
log_a, log_b = jnp.log(const.a), jnp.log(const.b)
rho_a, rho_b = const.rho_a, const.rho_b

if translation_invariant:
rho_a = uf.rho(1.0 / gamma, const.tau_a)
rho_b = uf.rho(1.0 / gamma, const.tau_b)
c_a = _get_ratio(const.rho_a, gamma)
c_b = _get_ratio(const.rho_b, gamma)

if translation_invariant:
lam_a, lam_b = compute_lambdas(const, state, gamma, g=c_g, lse_mode=True)

u1 = const.tau_a * (log_a - _softm(state.v1, c_q, axis=1))
u1 = c_a * (log_a - _softm(state.v1, c_q, axis=1))
u1 = u1 - lam_a / ((1.0 / gamma) + rho_a)
u2 = const.tau_b * (log_b - _softm(state.v2, c_r, axis=1))
u2 = c_b * (log_b - _softm(state.v2, c_r, axis=1))
u2 = u2 - lam_b / ((1.0 / gamma) + rho_b)

state_lam = State(
Expand All @@ -129,8 +129,8 @@ def body_fn(

g_trans = gamma * (lam_a + lam_b) + c_g
else:
u1 = const.tau_a * (log_a - _softm(state.v1, c_q, axis=1))
u2 = const.tau_b * (log_b - _softm(state.v2, c_r, axis=1))
u1 = c_a * (log_a - _softm(state.v1, c_q, axis=1))
u2 = c_b * (log_b - _softm(state.v2, c_r, axis=1))

v1_trans = _softm(u1, c_q, axis=0)
v2_trans = _softm(u2, c_r, axis=0)
Expand All @@ -155,8 +155,8 @@ def body_fn(
constants = Constants(
a=ot_prob.a,
b=ot_prob.b,
tau_a=ot_prob.tau_a,
tau_b=ot_prob.tau_b,
rho_a=_rho(ot_prob.tau_a),
rho_b=_rho(ot_prob.tau_b),
supp_a=ot_prob.a > 0,
supp_b=ot_prob.b > 0,
)
Expand Down Expand Up @@ -242,18 +242,16 @@ def cond_fn(
def body_fn(
iteration: int, const: Constants, state: State, compute_error: bool
) -> State:
if translation_invariant:
rho_a = uf.rho(1.0 / gamma, const.tau_a)
rho_b = uf.rho(1.0 / gamma, const.tau_b)
c_a = const.tau_a
c_b = const.tau_b
c_a = _get_ratio(const.rho_a, gamma)
c_b = _get_ratio(const.rho_b, gamma)

if translation_invariant:
lam_a, lam_b = compute_lambdas(const, state, gamma, g=k_g, lse_mode=False)

u1 = jnp.where(const.supp_a, (const.a / (k_q @ state.v1)) ** c_a, 0.0)
u1 = u1 * jnp.exp(-lam_a / ((1.0 / gamma) + rho_a))
u1 = u1 * jnp.exp(-lam_a / ((1.0 / gamma) + const.rho_a))
u2 = jnp.where(const.supp_b, (const.b / (k_r @ state.v2)) ** c_b, 0.0)
u2 = u2 * jnp.exp(-lam_b / ((1.0 / gamma) + rho_b))
u2 = u2 * jnp.exp(-lam_b / ((1.0 / gamma) + const.rho_b))

state_lam = State(
v1=state.v1, v2=state.v2, u1=u1, u2=u2, g=state.g, err=state.err
Expand All @@ -268,12 +266,8 @@ def body_fn(
k_trans = jnp.exp(gamma * (lam_a + lam_b)) * k_g
g = (k_trans * v1_trans * v2_trans) ** (1.0 / 3.0)
else:
u1 = jnp.where(
const.supp_a, (const.a / (k_q @ state.v1)) ** const.tau_a, 0.0
)
u2 = jnp.where(
const.supp_b, (const.b / (k_r @ state.v2)) ** const.tau_b, 0.0
)
u1 = jnp.where(const.supp_a, (const.a / (k_q @ state.v1)) ** c_a, 0.0)
u2 = jnp.where(const.supp_b, (const.b / (k_r @ state.v2)) ** c_b, 0.0)

v1_trans = k_q.T @ u1
v2_trans = k_r.T @ u2
Expand All @@ -298,8 +292,8 @@ def body_fn(
constants = Constants(
a=ot_prob.a,
b=ot_prob.b,
tau_a=ot_prob.tau_a,
tau_b=ot_prob.tau_b,
rho_a=_rho(ot_prob.tau_a),
rho_b=_rho(ot_prob.tau_b),
supp_a=ot_prob.a > 0.0,
supp_b=ot_prob.b > 0.0,
)
Expand Down Expand Up @@ -328,8 +322,8 @@ def compute_lambdas(
) -> Tuple[float, float]:
"""TODO."""
gamma_inv = 1.0 / gamma
rho_a = uf.rho(gamma_inv, const.tau_a)
rho_b = uf.rho(gamma_inv, const.tau_b)
rho_a = const.rho_a
rho_b = const.rho_b

if lse_mode:
num_1 = jsp.special.logsumexp((-gamma_inv / rho_a) * state.u1, b=const.a)
Expand All @@ -338,8 +332,8 @@ def compute_lambdas(
const_1 = num_1 - den
const_2 = num_2 - den

ratio_1 = const.tau_a # rho_a / (rho_a + gamma_inv)
ratio_2 = const.tau_b # rho_b / (rho_b + gamma_inv)
ratio_1 = _get_ratio(rho_a, gamma)
ratio_2 = _get_ratio(rho_b, gamma)
harmonic = 1.0 / (1.0 - (ratio_1 * ratio_2))
lam_1 = harmonic * gamma_inv * ratio_1 * (const_1 - ratio_2 * const_2)
lam_2 = harmonic * gamma_inv * ratio_2 * (const_2 - ratio_1 * const_1)
Expand All @@ -359,9 +353,19 @@ def compute_lambdas(
const_1 = jnp.log(num_1 / den)
const_2 = jnp.log(num_2 / den)

ratio_1 = const.tau_a # rho_a / (rho_a + gamma_inv)
ratio_2 = const.tau_b # rho_b / (rho_b + gamma_inv)
ratio_1 = _get_ratio(rho_a, gamma)
ratio_2 = _get_ratio(rho_b, gamma)
harmonic = 1.0 / (1.0 - (ratio_1 * ratio_2))
lam_1 = harmonic * gamma_inv * ratio_1 * (const_1 - ratio_2 * const_2)
lam_2 = harmonic * gamma_inv * ratio_2 * (const_2 - ratio_1 * const_1)
return lam_1, lam_2


def _rho(tau: float) -> float:
tau = jnp.asarray(tau) # avoid division by 0 in Python, get NaN instead
return tau / (1.0 - tau)


def _get_ratio(rho: float, gamma: float) -> float:
gamma_inv = 1.0 / gamma
return jnp.where(jnp.isfinite(rho), rho / (rho + gamma_inv), 1.0)
Loading

0 comments on commit a18c16c

Please sign in to comment.