Refactor GW initialization #133

Merged · 7 commits · Sep 5, 2022
34 changes: 24 additions & 10 deletions docs/notebooks/GWLRSinkhorn.ipynb

Large diffs are not rendered by default.
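Since the notebook diff is not rendered, here is a minimal usage sketch of the entry point this PR refactors, based only on the `make` signature and the `QuadraticProblem` call pattern visible in the diffs below; the data sizes and the `epsilon`/`rank` values are illustrative assumptions, not taken from the notebook.

```python
import jax

from ott.core import gromov_wasserstein, quad_problems
from ott.geometry import pointcloud

# Two point clouds living in different spaces (sizes chosen arbitrarily here).
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (30, 3))
y = jax.random.normal(key_y, (40, 5))

prob = quad_problems.QuadraticProblem(
    pointcloud.PointCloud(x), pointcloud.PointCloud(y)
)

# rank > 0 selects the low-rank path; after this PR the rank lives on the
# solver itself and is no longer threaded through `iterations` as an argument.
solver = gromov_wasserstein.make(epsilon=1e-1, rank=6)
out = solver(prob)

print(out.costs)         # GW cost recorded at each outer iteration
print(out.matrix.shape)  # coupling between the two point clouds, (30, 40)
```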

54 changes: 32 additions & 22 deletions ott/core/gromov_wasserstein.py
@@ -14,7 +14,6 @@

# Lint as: python3
"""A Jax version of the regularised GW Solver (Peyre et al. 2016)."""
import functools
from typing import Any, Dict, NamedTuple, Optional, Tuple, Union

import jax
@@ -101,13 +100,12 @@ class GWState(NamedTuple):
old_transport_mass: Intermediary value of the mass of the transport matrix.
"""

costs: Optional[jnp.ndarray] = None
linear_convergence: Optional[jnp.ndarray] = None
costs: jnp.ndarray
linear_convergence: jnp.ndarray
linear_state: LinearOutput
linear_pb: linear_problems.LinearProblem
old_transport_mass: float
errors: Optional[jnp.ndarray] = None
linear_state: Optional[LinearOutput] = None
linear_pb: Optional[linear_problems.LinearProblem] = None
# Intermediate values.
old_transport_mass: float = 1.0

def set(self, **kwargs: Any) -> 'GWState':
"""Return a copy of self, possibly with overwrites."""
@@ -125,6 +123,7 @@ def update(
linear_convergence = self.linear_convergence.at[iteration].set(
linear_sol.converged
)

return self.set(
linear_state=linear_sol,
linear_pb=linear_pb,
@@ -146,8 +145,7 @@ def __call__(self, prob: quad_problems.QuadraticProblem) -> GWOutput:
# Possibly jit iteration functions and run. Closure on rank to
# avoid jitting issues, since rank value will be used to branch between
# a default entropic GW or a low-rank GW.
iterations_fn = functools.partial(iterations, rank=self.rank)
gromov_fn = jax.jit(iterations_fn) if self.jit else iterations_fn
gromov_fn = jax.jit(iterations) if self.jit else iterations
out = gromov_fn(self, prob)
# TODO(lpapaxanthos): remove stop_gradient when using backprop
if self.is_low_rank:
@@ -167,24 +165,31 @@ def __call__(self, prob: quad_problems.QuadraticProblem) -> GWOutput:
return out.set(linear_state=linear_state, convergence=convergence)

def init_state(
self, prob: quad_problems.QuadraticProblem, rank: int
self,
prob: quad_problems.QuadraticProblem,
) -> GWState:
"""Initialize the state of the Gromov-Wasserstein iterations."""
if rank > 0:
linearization = prob.init_lr_linearization(rank)
if self.is_low_rank:
linear_prob = prob.init_lr_linearization(self.linear_ot_solver)
else:
linearization = prob.init_linearization(self.epsilon)
linear_prob = prob.init_linearization(self.epsilon)

linear_state = self.linear_ot_solver(linearization)
linear_state = self.linear_ot_solver(linear_prob)
num_iter = self.max_iterations
transport_mass = prob.init_transport_mass()

if self.store_inner_errors:
errors = -jnp.ones((num_iter, self.linear_ot_solver.outer_iterations))
else:
errors = None

return GWState(
-jnp.ones((num_iter,)), -jnp.ones((num_iter,)), errors, linear_state,
linearization, transport_mass
costs=-jnp.ones((num_iter,)),
linear_convergence=-jnp.ones((num_iter,)),
linear_state=linear_state,
linear_pb=linear_prob,
old_transport_mass=transport_mass,
errors=errors
)

def output_from_state(self, state: GWState) -> GWOutput:
@@ -208,7 +213,8 @@ def output_from_state(self, state: GWState) -> GWOutput:


def iterations(
solver: GromovWasserstein, prob: quad_problems.QuadraticProblem, rank: int
solver: GromovWasserstein,
prob: quad_problems.QuadraticProblem,
) -> GWOutput:
"""Jittable Gromov-Wasserstein outer loop."""

@@ -219,19 +225,21 @@ def cond_fn(
return solver._continue(state, iteration)

def body_fn(
iteration: int, constants: GromovWasserstein, state: GWState,
iteration: int, solver: GromovWasserstein, state: GWState,
compute_error: bool
) -> GWState:
del compute_error # Always assumed True for outer loop of GW.
solver = constants
if rank > 0:

if solver.is_low_rank:
init = state.linear_state.q, state.linear_state.r, state.linear_state.g
linear_pb = prob.update_lr_linearization(state.linear_state)
else:
init = state.linear_state.f, state.linear_state.g
linear_pb = prob.update_linearization(
state.linear_state, solver.epsilon, state.old_transport_mass
)

out = solver.linear_ot_solver(linear_pb)
out = solver.linear_ot_solver(linear_pb, init=init)
old_transport_mass = jax.lax.stop_gradient(
state.linear_state.transport_mass()
)
@@ -246,7 +254,7 @@ def body_fn(
max_iterations=solver.max_iterations,
inner_iterations=1,
constants=solver,
state=solver.init_state(prob, rank)
state=solver.init_state(prob)
)

return solver.output_from_state(state)
@@ -300,6 +308,8 @@ def make(
sink = sinkhorn_lr.make(
rank=rank, epsilon=epsilon, **linear_ot_solver_kwargs
)
else:
raise ValueError(f"Invalid value for `rank={rank}`.")

return GromovWasserstein(
epsilon,
6 changes: 3 additions & 3 deletions ott/core/gw_barycenter.py
@@ -241,17 +241,17 @@ def output_from_state(self, state: GWBarycenterState) -> GWBarycenterState:

def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
children, aux = super().tree_flatten()
aux["quad_solver"] = self._quad_solver
return children, aux
return children + [self._quad_solver], aux

@classmethod
def tree_unflatten(
cls, aux_data: Dict[str, Any], children: Sequence[Any]
) -> "GromovWassersteinBarycenter":
epsilon, _, _, threshold = children
epsilon, _, threshold, quad_solver = children
return cls(
epsilon=epsilon,
threshold=threshold,
quad_solver=quad_solver,
**aux_data,
)

48 changes: 31 additions & 17 deletions ott/core/quad_problems.py
@@ -73,16 +73,18 @@ def make_kl_loss(clipping_value: float = 1e-8) -> GWLoss:

@jax.tree_util.register_pytree_node_class
class QuadraticProblem:
"""Definition of the quadratic regularized OT problem.
r"""Definition of the quadratic regularized OT problem.

The quadratic loss of a single OT matrix is assumed to
have the form given in :cite:`peyre:16`, eq. 4.

The two geometries below parameterize matrices C and bar{C} in that equation.
The function L (of two real values) in that equation is assumed
to match the form given in Eq. 5., with our notations:
The two geometries below parameterize matrices :math:`C` and :math:`\bar{C}`
in that equation. The function :math:`L` (of two real values) in that equation
is assumed to match the form given in eq. 5., with our notations:

L(x, y) = lin1(x) + lin2(y) - quad1(x) * quad2(y)
.. math::

L(x, y) = lin1(x) + lin2(y) - quad1(x) * quad2(y)

Args:
geom_xx: the geometry.Geometry object defining the ground geometry / cost
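As a concrete instance of this decomposition (not part of the diff): the squared loss L(x, y) = (x - y)**2 fits the template with lin1(x) = x**2, lin2(y) = y**2, quad1(x) = x and quad2(y) = 2 * y, which is exactly the (f1, f2) / (h1, h2) split exposed by the `linear_loss` and `quad_loss` properties below. A quick numerical check:

```python
import jax.numpy as jnp

# Squared GW loss written in the lin/quad form of the docstring above.
f1 = lambda x: x ** 2    # lin1
f2 = lambda y: y ** 2    # lin2
h1 = lambda x: x         # quad1
h2 = lambda y: 2.0 * y   # quad2

x, y = jnp.asarray(1.5), jnp.asarray(-0.3)
assert jnp.allclose(f1(x) + f2(y) - h1(x) * h2(y), (x - y) ** 2)
```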
@@ -175,10 +177,12 @@ def __init__(

@property
def is_fused(self) -> bool:
"""Whether the problem is fused."""
return self.geom_xy is not None

@property
def is_low_rank(self) -> bool:
"""Whether all geometries are low-rank."""
return (
isinstance(self.geom_xx, low_rank.LRCGeometry) and
isinstance(self.geom_yy, low_rank.LRCGeometry) and (
@@ -189,14 +193,17 @@ def is_low_rank(self) -> bool:

@property
def linear_loss(self) -> Tuple[Loss, Loss]:
"""Linear part of the GW loss."""
return self.loss.f1, self.loss.f2

@property
def quad_loss(self) -> Tuple[Loss, Loss]:
"""Quadratic part of the GW loss."""
return self.loss.h1, self.loss.h2

@property
def is_balanced(self) -> bool:
"""Whether the problem is balanced."""
return ((not self.gw_unbalanced_correction) or
(self.tau_a == 1.0 and self.tau_b == 1.0))

@@ -219,11 +226,13 @@ def tree_unflatten(cls, aux_data, children):

@property
def a(self) -> jnp.ndarray:
"""Source marginals."""
num_a = self.geom_xx.shape[0]
return jnp.ones((num_a,)) / num_a if self._a is None else self._a

@property
def b(self) -> jnp.ndarray:
"""Target marginals."""
num_b = self.geom_yy.shape[0]
return jnp.ones((num_b,)) / num_b if self._b is None else self._b

@@ -416,24 +425,29 @@ def init_linearization(
)

def init_lr_linearization(
self, rank: int, **kwargs: Any
self,
solver: sinkhorn_lr.LRSinkhorn,
**kwargs: Any,
) -> linear_problems.LinearProblem:
"""Linearizes a Quad problem with a predefined initializer."""
x_ = self.geom_xx.apply_square_cost(self.a)
y_ = self.geom_yy.apply_square_cost(self.b)
geom_ = pointcloud.PointCloud(x_, y_).to_LRCGeometry()
out = sinkhorn_lr.LRSinkhorn(
rank=rank, **kwargs
)(
linear_problems.LinearProblem(geom_, self.a, self.b)
"""Linearize a Quad problem with a predefined initializer."""
x = self.geom_xx.apply_square_cost(self.a)
y = self.geom_yy.apply_square_cost(self.b)
geom = pointcloud.PointCloud(x, y).to_LRCGeometry()

prob = linear_problems.LinearProblem(geom, self.a, self.b)
q, r, g = solver.initializer(prob, **kwargs)
dummy_out = sinkhorn_lr.LRSinkhornOutput(
q=q, r=r, g=g, costs=None, criterions=None, ot_prob=prob
)
return linear_problems.LinearProblem(
self.update_lr_geom(out),

prob = linear_problems.LinearProblem(
self.update_lr_geom(dummy_out),
self.a,
self.b,
tau_a=self.tau_a,
tau_b=self.tau_b
)
return prob

def update_lr_geom(
self, lr_sink: sinkhorn_lr.LRSinkhornOutput
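For orientation, and stated as an assumption rather than something read off this diff: the low-rank initializer returns factors `q` of shape `(n, rank)`, `r` of shape `(m, rank)` and `g` of shape `(rank,)`, and the coupling they encode is `q @ diag(1 / g) @ r.T`. Wrapping them in a dummy `LRSinkhornOutput` therefore gives `update_lr_geom` everything it needs to build the linearized geometry before any Sinkhorn iteration has run. A small shape check:

```python
import jax.numpy as jnp

# Toy (q, r, g) factors; this particular choice encodes the uniform coupling.
n, m, rank = 30, 40, 6
q = jnp.full((n, rank), 1.0 / (n * rank))
r = jnp.full((m, rank), 1.0 / (m * rank))
g = jnp.full((rank,), 1.0 / rank)

coupling = (q * (1.0 / g)) @ r.T  # same as q @ jnp.diag(1.0 / g) @ r.T
print(coupling.shape)  # (30, 40)
print(coupling.sum())  # total mass 1.0 for these factors
```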
@@ -554,7 +568,7 @@ def convertible(geom: geometry.Geometry) -> bool:

geom_xx, geom_yy, geom_xy = self.geom_xx, self.geom_yy, self.geom_xy
# either explicitly via cost factorization or implicitly (e.g., a PC)
return self.ranks != 1 or (
return self.ranks != -1 or (
convertible(geom_xx) and convertible(geom_yy) and
(geom_xy is None or convertible(geom_xy))
)
6 changes: 3 additions & 3 deletions ott/core/was_solver.py
@@ -80,21 +80,21 @@ def is_low_rank(self) -> bool:
return self.rank > 0

def tree_flatten(self):
return ([self.epsilon, self.rank, self.linear_ot_solver, self.threshold],
return ([self.epsilon, self.linear_ot_solver, self.threshold],
dict(
min_iterations=self.min_iterations,
max_iterations=self.max_iterations,
jit=self.jit,
rank=self.rank,
store_inner_errors=self.store_inner_errors,
**self._kwargs
))

@classmethod
def tree_unflatten(cls, aux_data, children):
epsilon, rank, linear_ot_solver, threshold = children
epsilon, linear_ot_solver, threshold = children
return cls(
epsilon=epsilon,
rank=rank,
linear_ot_solver=linear_ot_solver,
threshold=threshold,
**aux_data
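A hedged reading of why `rank` moves out of the flattened children and into `aux_data` (the diff does not spell this out): children are traced when the solver is passed into a jitted function, whereas `aux_data` stays static, and `rank` must remain a static Python int so that `is_low_rank` can branch between the entropic and low-rank code paths inside the jitted outer loop. A toy illustration of that pattern:

```python
from typing import Any, Dict, Sequence, Tuple

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class ToySolver:
  """Toy pytree: `epsilon` is a traced child, `rank` is static aux_data."""

  def __init__(self, epsilon: float, rank: int = -1):
    self.epsilon = epsilon
    self.rank = rank

  @property
  def is_low_rank(self) -> bool:
    return self.rank > 0

  def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
    return [self.epsilon], dict(rank=self.rank)

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    epsilon, = children
    return cls(epsilon=epsilon, **aux_data)


@jax.jit
def run(solver: ToySolver, x: jnp.ndarray) -> jnp.ndarray:
  # Branching on the static `rank` is fine under jit; branching on a traced
  # child (e.g. `epsilon`) would not be.
  return x / solver.rank if solver.is_low_rank else x * solver.epsilon


print(run(ToySolver(epsilon=1e-2, rank=3), jnp.ones((4,))))
```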
4 changes: 1 addition & 3 deletions ott/geometry/geometry.py
@@ -210,9 +210,7 @@ def inv_scale_cost(self) -> float:
return 1.0 / jnp.nanmedian(self._cost_matrix)
raise ValueError(f'Scaling {self._scale_cost} not implemented.')

def _set_scale_cost(
self, scale_cost: Optional[Union[bool, float, str]]
) -> "Geometry":
def _set_scale_cost(self, scale_cost: Union[bool, float, str]) -> "Geometry":
# case when `geom` doesn't have `scale_cost` or doesn't need to be modified
# `False` retains the original scale
if scale_cost is False or scale_cost == self._scale_cost:
10 changes: 5 additions & 5 deletions tests/core/gromov_wasserstein_test.py
@@ -358,18 +358,18 @@ def test_gw_lr_matches_fused(self, rng: jnp.ndarray):
ot_gw = solver(prob)

# Test solutions look alike
assert 0.1 > jnp.linalg.norm(ot_gwlr.matrix - ot_gw.matrix)
assert 0.13 > jnp.linalg.norm(ot_gwlr.matrix - ot_gwlreps.matrix)
assert 0.11 > jnp.linalg.norm(ot_gwlr.matrix - ot_gw.matrix)
assert 0.15 > jnp.linalg.norm(ot_gwlr.matrix - ot_gwlreps.matrix)
# Test at least some difference when adding bigger entropic regularization
assert jnp.linalg.norm(ot_gwlr.matrix - ot_gwlreps.matrix) > 1e-3

@pytest.mark.parametrize("scale_cost", [True, "mean", "max_cost"])
def test_gw_fused_scale_cost(self, scale_cost: Union[bool, str]):
epsilon = 0.1
fused_penalty = 1
geom_x = pointcloud.PointCloud(self.x, scale_cost=None)
geom_y = pointcloud.PointCloud(self.y, scale_cost=None)
geom_xy = pointcloud.PointCloud(self.xx, self.yy, scale_cost=None)
geom_x = pointcloud.PointCloud(self.x, scale_cost=1.)
geom_y = pointcloud.PointCloud(self.y, scale_cost=1.)
geom_xy = pointcloud.PointCloud(self.xx, self.yy, scale_cost=1.)
geom_x_scaled = pointcloud.PointCloud(self.x, scale_cost=scale_cost)
geom_y_scaled = pointcloud.PointCloud(self.y, scale_cost=scale_cost)
geom_xy_scaled = pointcloud.PointCloud(
4 changes: 2 additions & 2 deletions tests/core/sinkhorn_diff_test.py
@@ -734,7 +734,7 @@ class TestSinkhornHessian:
tau_b=[1.0, .91],
shape=[(12, 15)],
arg=[0, 1],
only_fast=[-1]
only_fast=-1
)
def test_hessian_sinkhorn(
self, rng: jnp.ndarray, lse_mode: bool, tau_a: float, tau_b: float,
@@ -764,7 +764,7 @@ def loss(a: jnp.ndarray, x: jnp.ndarray, implicit: bool = True):
lse_mode=lse_mode,
threshold=1e-4,
use_danskin=False,
implicit_diff=implicit_diff
implicit_diff=implicit_diff,
)
return solver(prob).reg_ot_cost
