Feature/ulrgw #410

Merged
merged 49 commits into from
Sep 8, 2023
Changes from 40 commits
49 commits
19756c1
Remove low-rank from GromovWasserstein solver
michalk8 Aug 2, 2023
a079761
First skeleton loop
michalk8 Aug 2, 2023
0d0c85f
Add LRGW implementation
michalk8 Aug 3, 2023
1f7f10d
Add ULFGW
michalk8 Aug 3, 2023
07aeedf
Revert change
michalk8 Aug 3, 2023
6d257ff
Add a TODO
michalk8 Aug 3, 2023
bf0a50d
Fix `grad_g` in the fused case
michalk8 Aug 4, 2023
7b8a0f4
Update docs
michalk8 Aug 17, 2023
5855a32
Remove duplicate citation
michalk8 Aug 17, 2023
01bd8e6
Merge branch 'main' into feature/ulrgw
michalk8 Aug 18, 2023
e4a70c3
Fix cost for the fused case
michalk8 Aug 18, 2023
4ecdca0
Fix bugs in TI
michalk8 Aug 18, 2023
04a3e8b
Remove unused import
michalk8 Aug 18, 2023
4fb9ee4
Change way array extraction in LR init works
michalk8 Aug 31, 2023
43d39eb
Disallow LR in the old GW solver
michalk8 Aug 31, 2023
0747c56
Disallow LR in old GW class
michalk8 Aug 31, 2023
83a1c36
Remove `is_entropic` property
michalk8 Aug 31, 2023
b783f3a
Use `jnp.linalg.norm`
michalk8 Aug 31, 2023
0f661cc
Simplify initializers in GW
michalk8 Aug 31, 2023
d79ce31
Simplify initializer creation for low-rank
michalk8 Aug 31, 2023
1be699f
Remove temporary name
michalk8 Aug 31, 2023
87284ae
Fix norms
michalk8 Aug 31, 2023
30f275e
Fix linkcheck
michalk8 Aug 31, 2023
f11e19f
Remove old initializers test
michalk8 Sep 1, 2023
6d4775e
Fix more initializer tests
michalk8 Sep 1, 2023
fdbafe9
Remove `LRQuadraticInitializer`, `reg_ot_cost -> reg_gw_cost`
michalk8 Sep 1, 2023
17b447a
`host_callback` -> `io_callback`
michalk8 Sep 1, 2023
3714057
Fix more initializers tests
michalk8 Sep 1, 2023
0fa6be2
Fix more tests
michalk8 Sep 1, 2023
3843609
Remove initializer mention from the docs
michalk8 Sep 1, 2023
62aad99
Remove mention of LR initializer
michalk8 Sep 1, 2023
4e57972
Start incorporating GWLoss
michalk8 Sep 1, 2023
b312126
Simplify reg GW cost computation
michalk8 Sep 1, 2023
a499102
Finish `primal_cost`
michalk8 Sep 1, 2023
4bcb71b
Don't calculate unbal. grads in balanced case
michalk8 Sep 1, 2023
eea362f
Fix `primal_cost` in balanced case
michalk8 Sep 1, 2023
3f14488
Update GW LR notebook
michalk8 Sep 1, 2023
8207696
Convert quad problem to LR if possible
michalk8 Sep 4, 2023
ba272f0
Convert quad problem to LR if possible
michalk8 Sep 4, 2023
c565ee3
Regenerate GWLR Sinkhorn
michalk8 Sep 5, 2023
f508d5d
Regenerate `LRSinkhorn`
michalk8 Sep 5, 2023
f7f440f
Merge remote-tracking branch 'upstream/main' into feature/ulrgw
michalk8 Sep 5, 2023
f703a55
Merge branch 'feature/ulrgw' of ssh://github.com/michalk8/ott into fe…
michalk8 Sep 5, 2023
ec45825
[ci skip] Fix linter
michalk8 Sep 5, 2023
68d3ce2
Fix convergence metric
michalk8 Sep 7, 2023
afd8ea6
Undo TODO
michalk8 Sep 7, 2023
e006eae
Fix factor
michalk8 Sep 7, 2023
e06f1b0
Regenerate notebooks
michalk8 Sep 7, 2023
f04757f
Add tests
michalk8 Sep 7, 2023
2 changes: 2 additions & 0 deletions docs/conf.py
Expand Up @@ -125,6 +125,8 @@
"https://doi.org/10.1137/17M1140431",
"https://doi.org/10.1137/141000439",
"https://doi.org/10.1002/mana.19901470121",
"https://doi.org/10.1145/2516971.2516977",
"https://doi.org/10.1145/2766963",
]

# List of patterns, relative to source directory, that match files and
Expand Down
4 changes: 1 addition & 3 deletions docs/initializers/quadratic.rst
Expand Up @@ -5,13 +5,11 @@ ott.initializers.quadratic

Two families of initializers are described in the following to provide the first
iteration of Gromov-Wasserstein solvers. They apply respectively to the simpler
GW entropic solver :cite:`peyre:16` and its low-rank formulation
:cite:`scetbon:22`.
GW entropic solver :cite:`peyre:16`.

Gromov-Wasserstein Initializers
-------------------------------
.. autosummary::
:toctree: _autosummary

initializers.QuadraticInitializer
initializers.LRQuadraticInitializer
3 changes: 3 additions & 0 deletions docs/solvers/quadratic.rst
Expand Up @@ -15,6 +15,9 @@ Gromov-Wasserstein Solvers
gromov_wasserstein.solve
gromov_wasserstein.GromovWasserstein
gromov_wasserstein.GWOutput
gromov_wasserstein_lr.LRGromovWasserstein
gromov_wasserstein_lr.LRGWOutput


Barycenter Solvers
------------------
Expand Down
50 changes: 15 additions & 35 deletions docs/tutorials/notebooks/GWLRSinkhorn.ipynb

Large diffs are not rendered by default.

34 changes: 14 additions & 20 deletions src/ott/initializers/linear/initializers_lr.py
Expand Up @@ -39,7 +39,7 @@
from ott.problems.linear import linear_problem
from ott.problems.quadratic import quadratic_problem
from ott.solvers.linear import sinkhorn, sinkhorn_lr
from ott.solvers.quadratic import gromov_wasserstein
from ott.solvers.quadratic import gromov_wasserstein_lr

Problem_t = Union["linear_problem.LinearProblem",
"quadratic_problem.QuadraticProblem"]
Expand Down Expand Up @@ -127,7 +127,7 @@ def init_g(
def from_solver(
cls,
solver: Union["sinkhorn_lr.LRSinkhorn",
"gromov_wasserstein.GromovWasserstein"],
"gromov_wasserstein_lr.LRGromovWasserstein"],
*,
kind: Literal["random", "rank2", "k-means", "generalized-k-means"],
**kwargs: Any,
Expand All @@ -140,22 +140,14 @@ def from_solver(
kwargs: Keyword arguments when creating the initializer.

Returns:
The low-rank initializer.
Low-rank initializer.
"""
from ott.solvers.quadratic import gromov_wasserstein

if isinstance(solver, gromov_wasserstein.GromovWasserstein):
assert solver.is_low_rank, "GW solver is not low-rank."
lin_sol = solver.linear_ot_solver
else:
lin_sol = solver

rank = solver.rank
sinkhorn_kwargs = {
"norm_error": lin_sol._norm_error,
"lse_mode": lin_sol.lse_mode,
"implicit_diff": lin_sol.implicit_diff,
"use_danskin": lin_sol.use_danskin
"norm_error": solver._norm_error,
"lse_mode": solver.lse_mode,
"implicit_diff": solver.implicit_diff,
"use_danskin": solver.use_danskin
}

if kind == "random":
Expand Down Expand Up @@ -373,9 +365,7 @@ def __init__(
self._sinkhorn_kwargs = {} if sinkhorn_kwargs is None else sinkhorn_kwargs

@staticmethod
def _extract_array(
geom: Union[pointcloud.PointCloud, low_rank.LRCGeometry], *, first: bool
) -> jnp.ndarray:
def _extract_array(geom: geometry.Geometry, *, first: bool) -> jnp.ndarray:
if isinstance(geom, pointcloud.PointCloud):
return geom.x if first else geom.y
if isinstance(geom, low_rank.LRCGeometry):
Expand Down Expand Up @@ -407,15 +397,19 @@ def _compute_factor(
)

if isinstance(ot_prob, quadratic_problem.QuadraticProblem):
geom = ot_prob.geom_xx if which == "q" else ot_prob.geom_yy
if ot_prob.geom_xy is not None and ot_prob.fused_penalty >= 1.0:
# prefer the linear term if it has a higher weight
geom = ot_prob.geom_xy
else:
geom = ot_prob.geom_xx if which == "q" else ot_prob.geom_yy
else:
geom = ot_prob.geom
arr = self._extract_array(geom, first=which == "q")
marginals = ot_prob.a if which == "q" else ot_prob.b

centroids = fn(arr, self.rank, rng=rng).centroids
geom = pointcloud.PointCloud(
arr, centroids, epsilon=0.1, scale_cost="max_cost"
arr, centroids, epsilon=1e-1, scale_cost="max_cost"
)

prob = linear_problem.LinearProblem(geom, marginals, init_g)
Expand Down
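
The `_compute_factor` hunk above changes which geometry feeds the k-means initializer for a quadratic problem. The selection rule can be sketched in isolation as follows. This is a minimal illustration, not the library code: `ToyQuadProblem` and `pick_geometry` are hypothetical names, and the geometries are replaced by plain string labels.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyQuadProblem:
    # Hypothetical stand-in for QuadraticProblem; geometries are just labels.
    geom_xx: str
    geom_yy: str
    geom_xy: Optional[str] = None
    fused_penalty: float = 0.0


def pick_geometry(prob: ToyQuadProblem, which: str) -> str:
    """Mirror the branch in the hunk: `which` is "q" or "r"."""
    if prob.geom_xy is not None and prob.fused_penalty >= 1.0:
        # The linear (fused) term has weight >= 1, so it is preferred.
        return prob.geom_xy
    # Otherwise fall back to the quadratic geometry of the matching side.
    return prob.geom_xx if which == "q" else prob.geom_yy


print(pick_geometry(ToyQuadProblem("xx", "yy"), "q"))
print(pick_geometry(ToyQuadProblem("xx", "yy", "xy", fused_penalty=2.0), "r"))
```

In the diff, the array is then taken from the chosen geometry with `first=which == "q"`, so when `geom_xy` is selected the `q` factor uses its first point set and the `r` factor its second.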
56 changes: 1 addition & 55 deletions src/ott/initializers/quadratic/initializers.py
Expand Up @@ -20,11 +20,10 @@
from ott.geometry import geometry

if TYPE_CHECKING:
from ott.initializers.linear import initializers_lr
from ott.problems.linear import linear_problem
from ott.problems.quadratic import quadratic_problem

__all__ = ["QuadraticInitializer", "LRQuadraticInitializer"]
__all__ = ["BaseQuadraticInitializer", "QuadraticInitializer"]


@jax.tree_util.register_pytree_node_class
Expand Down Expand Up @@ -171,56 +170,3 @@ def _create_geometry(
epsilon=epsilon,
relative_epsilon=relative_epsilon
)


class LRQuadraticInitializer(BaseQuadraticInitializer):
"""Wrapper that wraps low-rank Sinkhorn initializers.

Args:
lr_linear_initializer: Low-rank linear initializer.
"""

def __init__(self, lr_linear_initializer: "initializers_lr.LRInitializer"):
super().__init__()
self._linear_lr_initializer = lr_linear_initializer

def _create_geometry(
self,
quad_prob: "quadratic_problem.QuadraticProblem",
relative_epsilon: Optional[bool] = False,
**kwargs: Any
) -> geometry.Geometry:
"""Compute initial geometry for linearization.

Args:
quad_prob: Quadratic OT problem.
relative_epsilon: Whether to use relative epsilon in the geometry.
kwargs: Keyword arguments for
:meth:`~ott.initializers.linear.initializers_lr.LRInitializer.__call__`.

Returns:
The initial geometry used to initialize a linear problem.
"""
from ott.solvers.linear import sinkhorn_lr

q, r, g = self._linear_lr_initializer(quad_prob, **kwargs)
tmp_out = sinkhorn_lr.LRSinkhornOutput(
q=q,
r=r,
g=g,
costs=None,
errors=None,
ot_prob=None,
epsilon=None,
)

return quad_prob.update_lr_geom(tmp_out, relative_epsilon=relative_epsilon)

@property
def rank(self) -> int:
"""Rank of the transport matrix factorization."""
return self._linear_lr_initializer.rank

def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102
children, aux_data = super().tree_flatten()
return children + [self._linear_lr_initializer], aux_data
70 changes: 37 additions & 33 deletions src/ott/solvers/linear/lr_utils.py
Expand Up @@ -18,7 +18,6 @@
import jax.scipy as jsp

from ott.math import fixed_point_loop
from ott.math import unbalanced_functions as uf
from ott.problems.linear import linear_problem

__all__ = ["unbalanced_dykstra_lse", "unbalanced_dykstra_kernel"]
Expand All @@ -36,8 +35,8 @@ class State(NamedTuple): # noqa: D101
class Constants(NamedTuple): # noqa: D101
a: jnp.ndarray
b: jnp.ndarray
tau_a: float
tau_b: float
rho_a: float
rho_b: float
supp_a: Optional[jnp.ndarray] = None
supp_b: Optional[jnp.ndarray] = None

Expand Down Expand Up @@ -105,16 +104,17 @@ def body_fn(
iteration: int, const: Constants, state: State, compute_error: bool
) -> State:
log_a, log_b = jnp.log(const.a), jnp.log(const.b)
rho_a, rho_b = const.rho_a, const.rho_b

if translation_invariant:
rho_a = uf.rho(1.0 / gamma, const.tau_a)
rho_b = uf.rho(1.0 / gamma, const.tau_b)
c_a = _get_ratio(const.rho_a, gamma)
c_b = _get_ratio(const.rho_b, gamma)

if translation_invariant:
lam_a, lam_b = compute_lambdas(const, state, gamma, g=c_g, lse_mode=True)

u1 = const.tau_a * (log_a - _softm(state.v1, c_q, axis=1))
u1 = c_a * (log_a - _softm(state.v1, c_q, axis=1))
u1 = u1 - lam_a / ((1.0 / gamma) + rho_a)
u2 = const.tau_b * (log_b - _softm(state.v2, c_r, axis=1))
u2 = c_b * (log_b - _softm(state.v2, c_r, axis=1))
u2 = u2 - lam_b / ((1.0 / gamma) + rho_b)

state_lam = State(
Expand All @@ -129,8 +129,8 @@ def body_fn(

g_trans = gamma * (lam_a + lam_b) + c_g
else:
u1 = const.tau_a * (log_a - _softm(state.v1, c_q, axis=1))
u2 = const.tau_b * (log_b - _softm(state.v2, c_r, axis=1))
u1 = c_a * (log_a - _softm(state.v1, c_q, axis=1))
u2 = c_b * (log_b - _softm(state.v2, c_r, axis=1))

v1_trans = _softm(u1, c_q, axis=0)
v2_trans = _softm(u2, c_r, axis=0)
Expand All @@ -155,8 +155,8 @@ def body_fn(
constants = Constants(
a=ot_prob.a,
b=ot_prob.b,
tau_a=ot_prob.tau_a,
tau_b=ot_prob.tau_b,
rho_a=_rho(ot_prob.tau_a),
rho_b=_rho(ot_prob.tau_b),
supp_a=ot_prob.a > 0,
supp_b=ot_prob.b > 0,
)
Expand Down Expand Up @@ -242,18 +242,16 @@ def cond_fn(
def body_fn(
iteration: int, const: Constants, state: State, compute_error: bool
) -> State:
if translation_invariant:
rho_a = uf.rho(1.0 / gamma, const.tau_a)
rho_b = uf.rho(1.0 / gamma, const.tau_b)
c_a = const.tau_a
c_b = const.tau_b
c_a = _get_ratio(const.rho_a, gamma)
c_b = _get_ratio(const.rho_b, gamma)

if translation_invariant:
lam_a, lam_b = compute_lambdas(const, state, gamma, g=k_g, lse_mode=False)

u1 = jnp.where(const.supp_a, (const.a / (k_q @ state.v1)) ** c_a, 0.0)
u1 = u1 * jnp.exp(-lam_a / ((1.0 / gamma) + rho_a))
u1 = u1 * jnp.exp(-lam_a / ((1.0 / gamma) + const.rho_a))
u2 = jnp.where(const.supp_b, (const.b / (k_r @ state.v2)) ** c_b, 0.0)
u2 = u2 * jnp.exp(-lam_b / ((1.0 / gamma) + rho_b))
u2 = u2 * jnp.exp(-lam_b / ((1.0 / gamma) + const.rho_b))

state_lam = State(
v1=state.v1, v2=state.v2, u1=u1, u2=u2, g=state.g, err=state.err
Expand All @@ -268,12 +266,8 @@ def body_fn(
k_trans = jnp.exp(gamma * (lam_a + lam_b)) * k_g
g = (k_trans * v1_trans * v2_trans) ** (1.0 / 3.0)
else:
u1 = jnp.where(
const.supp_a, (const.a / (k_q @ state.v1)) ** const.tau_a, 0.0
)
u2 = jnp.where(
const.supp_b, (const.b / (k_r @ state.v2)) ** const.tau_b, 0.0
)
u1 = jnp.where(const.supp_a, (const.a / (k_q @ state.v1)) ** c_a, 0.0)
u2 = jnp.where(const.supp_b, (const.b / (k_r @ state.v2)) ** c_b, 0.0)

v1_trans = k_q.T @ u1
v2_trans = k_r.T @ u2
Expand All @@ -298,8 +292,8 @@ def body_fn(
constants = Constants(
a=ot_prob.a,
b=ot_prob.b,
tau_a=ot_prob.tau_a,
tau_b=ot_prob.tau_b,
rho_a=_rho(ot_prob.tau_a),
rho_b=_rho(ot_prob.tau_b),
supp_a=ot_prob.a > 0.0,
supp_b=ot_prob.b > 0.0,
)
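
Reading the two `body_fn` hunks together with the `_rho` and `_get_ratio` helpers in this file, the exponent applied in the `u1`/`u2` updates appears to be derived from `tau` as written out below. This is a transcription of the code as shown in the diff, not a statement from the PR description; the symbols $\tau$, $\rho$, $\gamma$ and $c$ correspond to `tau_a`/`tau_b`, `rho_a`/`rho_b`, `gamma` and `c_a`/`c_b`.

```latex
\rho = \frac{\tau}{1 - \tau},
\qquad
c = \frac{\rho}{\rho + 1/\gamma}
  = \frac{\gamma \tau}{\gamma \tau + 1 - \tau},
\qquad
c = 1 \ \text{when } \tau = 1 \ (\rho \text{ not finite}).
```

As far as the hunks show, the non-translation-invariant branch previously used `const.tau_a` and `const.tau_b` directly as exponents, so the two versions agree when $\tau = 1$ or $\gamma = 1$ and differ otherwise.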
Expand Down Expand Up @@ -328,8 +322,8 @@ def compute_lambdas(
) -> Tuple[float, float]:
"""TODO."""
gamma_inv = 1.0 / gamma
rho_a = uf.rho(gamma_inv, const.tau_a)
rho_b = uf.rho(gamma_inv, const.tau_b)
rho_a = const.rho_a
rho_b = const.rho_b

if lse_mode:
num_1 = jsp.special.logsumexp((-gamma_inv / rho_a) * state.u1, b=const.a)
Expand All @@ -338,8 +332,8 @@ def compute_lambdas(
const_1 = num_1 - den
const_2 = num_2 - den

ratio_1 = const.tau_a # rho_a / (rho_a + gamma_inv)
ratio_2 = const.tau_b # rho_b / (rho_b + gamma_inv)
ratio_1 = _get_ratio(rho_a, gamma)
ratio_2 = _get_ratio(rho_b, gamma)
harmonic = 1.0 / (1.0 - (ratio_1 * ratio_2))
lam_1 = harmonic * gamma_inv * ratio_1 * (const_1 - ratio_2 * const_2)
lam_2 = harmonic * gamma_inv * ratio_2 * (const_2 - ratio_1 * const_1)
Expand All @@ -359,9 +353,19 @@ def compute_lambdas(
const_1 = jnp.log(num_1 / den)
const_2 = jnp.log(num_2 / den)

ratio_1 = const.tau_a # rho_a / (rho_a + gamma_inv)
ratio_2 = const.tau_b # rho_b / (rho_b + gamma_inv)
ratio_1 = _get_ratio(rho_a, gamma)
ratio_2 = _get_ratio(rho_b, gamma)
harmonic = 1.0 / (1.0 - (ratio_1 * ratio_2))
lam_1 = harmonic * gamma_inv * ratio_1 * (const_1 - ratio_2 * const_2)
lam_2 = harmonic * gamma_inv * ratio_2 * (const_2 - ratio_1 * const_1)
return lam_1, lam_2


def _rho(tau: float) -> float:
tau = jnp.asarray(tau) # avoid division by 0 in Python, get NaN instead
return tau / (1.0 - tau)


def _get_ratio(rho: float, gamma: float) -> float:
gamma_inv = 1.0 / gamma
return jnp.where(jnp.isfinite(rho), rho / (rho + gamma_inv), 1.0)
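
The behaviour of the two new helpers at the balanced limit can be checked with a small pure-Python sketch. It is an assumption-laden stand-in rather than the code from the PR: `rho_from_tau` and `ratio` are hypothetical names, and because plain Python floats raise `ZeroDivisionError` at `tau == 1`, that case is mapped to `math.inf` by hand (the `jnp.asarray` call in `_rho` is what lets the array division return a non-finite value instead of raising).

```python
import math


def rho_from_tau(tau: float) -> float:
    # Hypothetical stand-in for `_rho`: tau / (1 - tau).
    # tau == 1 (balanced case) is mapped to +inf explicitly.
    return math.inf if tau == 1.0 else tau / (1.0 - tau)


def ratio(rho: float, gamma: float) -> float:
    # Hypothetical stand-in for `_get_ratio`:
    # rho / (rho + 1 / gamma), or 1.0 when rho is not finite.
    if not math.isfinite(rho):
        return 1.0
    return rho / (rho + 1.0 / gamma)


for tau in (1.0, 0.9, 0.5):
    # Print tau next to the exponent it induces for gamma = 10.
    print(tau, ratio(rho_from_tau(tau), gamma=10.0))
```

The `isfinite` guard is what keeps the balanced case well defined: without it, a non-finite `rho` divided by itself would produce NaN.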