Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/online gw #80

Merged
merged 3 commits into from
Jun 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions ott/core/gromov_wasserstein.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def gromov_wasserstein(
geom_xx: geometry.Geometry,
geom_yy: geometry.Geometry,
geom_xy: Optional[geometry.Geometry] = None,
fused_penalty: Optional[float] = None,
fused_penalty: float = 1.0,
scale_cost: Optional[Union[bool, float, str]] = False,
a: Optional[jnp.ndarray] = None,
b: Optional[jnp.ndarray] = None,
Expand All @@ -350,8 +350,8 @@ def gromov_wasserstein(
geom_yy: a second Geometry object for the second view.
geom_xy: a Geometry object representing the linear cost in FGW.
fused_penalty: multiplier of the linear term in Fused Gromov Wasserstein,
i.e. loss = quadratic_loss + fused_penalty * linear_loss. If geom_xy is
None fused_penalty will be ignored, i.e. fused_penalty = 0.
i.e. loss = quadratic_loss + fused_penalty * linear_loss.
Ignored if ``geom_xy`` is not specified.
scale_cost: option to rescale the cost matrices:

- if `True`, use the default for each geometry.
Expand Down
42 changes: 19 additions & 23 deletions ott/core/quad_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,8 @@ class QuadraticProblem:
for Fused Gromov Wasserstein. If None, the problem reduces to a plain
Gromov Wasserstein problem.
fused_penalty: multiplier of the linear term in Fused Gromov Wasserstein,
i.e. problem = purely quadratic + fused_penalty * linear problem. If
fused_penalty is None but geom_xy is passed, fused_penalty is set by
default to 1.0, equal to 0.0 otherwise.
i.e. problem = purely quadratic + fused_penalty * linear problem.
Ignored if ``geom_xy`` is not specified.
scale_cost: option to rescale the cost matrices:

- if `True`, use the default for each geometry.
Expand Down Expand Up @@ -119,7 +118,7 @@ def __init__(
geom_xx: geometry.Geometry,
geom_yy: geometry.Geometry,
geom_xy: Optional[geometry.Geometry] = None,
fused_penalty: Optional[float] = None,
fused_penalty: float = 1.0,
scale_cost: Optional[Union[bool, float, str]] = False,
a: Optional[jnp.ndarray] = None,
b: Optional[jnp.ndarray] = None,
Expand All @@ -128,14 +127,12 @@ def __init__(
tau_b: Optional[float] = 1.0,
gw_unbalanced_correction: Optional[bool] = True
):

assert fused_penalty > 0, fused_penalty
self.geom_xx = geom_xx._set_scale_cost(scale_cost)
self.geom_yy = geom_yy._set_scale_cost(scale_cost)
self.geom_xy = (
None if geom_xy is None else geom_xy._set_scale_cost(scale_cost)
)
if fused_penalty is None:
fused_penalty = jnp.where(self.geom_xy is None, 0.0, 1.0)
self.fused_penalty = fused_penalty
self.scale_cost = scale_cost
self._a = a
Expand All @@ -152,14 +149,14 @@ def __init__(

@property
def is_fused(self) -> bool:
return self.geom_xy is not None and self.fused_penalty > 0.0
return self.geom_xy is not None

@property
def is_all_geoms_lr(self) -> bool:
return (
isinstance(self.geom_xx, low_rank.LRCGeometry) and
isinstance(self.geom_yy, low_rank.LRCGeometry) and
(not self.is_fused or isinstance(self.geom_xy, low_rank.LRCGeometry))
isinstance(self.geom_xy, (low_rank.LRCGeometry, type(None)))
)

@property
Expand Down Expand Up @@ -380,10 +377,7 @@ def init_linearization(
transport_mass = marginal_1.sum()
epsilon = update_epsilon_unbalanced(epsilon, transport_mass)

cost_matrix += self.fused_penalty * jnp.where(
self.is_fused,
0.0 if self.geom_xy is None else self.geom_xy.cost_matrix, 0.0
)
cost_matrix += self.fused_penalty * self._fused_cost_matrix

geom = geometry.Geometry(cost_matrix=cost_matrix, epsilon=epsilon)
return problems.LinearProblem(
Expand Down Expand Up @@ -433,10 +427,7 @@ def update_lr_geom(
geom = low_rank.add_lrc_geom(geom, self.geom_xy)
else:
cost_matrix = marginal_cost.cost_matrix - jnp.dot(tmp1, tmp2.T)
cost_matrix += self.fused_penalty * jnp.where(
self.is_fused,
0.0 if self.geom_xy is None else self.geom_xy.cost_matrix, 0.0
)
cost_matrix += self.fused_penalty * self._fused_cost_matrix
geom = geometry.Geometry(cost_matrix=cost_matrix)
return geom

Expand Down Expand Up @@ -488,12 +479,7 @@ def update_linearization(
tmp = self.geom_yy.apply_cost(tmp.T, axis=1, fn=self.quad_loss[1]).T

cost_matrix = marginal_cost.cost_matrix - tmp + unbalanced_correction

cost_matrix += self.fused_penalty * jnp.where(
self.is_fused,
0.0 if self.geom_xy is None else self.geom_xy.cost_matrix, 0.0
)

cost_matrix += self.fused_penalty * self._fused_cost_matrix
cost_matrix *= rescale_factor

geom = geometry.Geometry(cost_matrix=cost_matrix, epsilon=epsilon)
Expand All @@ -513,6 +499,16 @@ def update_lr_linearization(
tau_b=self.tau_b
)

@property
def _fused_cost_matrix(self) -> Union[float, jnp.ndarray]:
    """Linear-term cost used by the fused problem, before any penalty scaling.

    Returns:
      ``0`` when the problem is not fused. Otherwise the cost matrix of
      ``geom_xy``; when ``geom_xy`` is a point cloud flagged as online, the
      matrix is built with ``compute_cost_matrix()`` and multiplied by
      ``inv_scale_cost`` instead of reading ``cost_matrix``.
    """
    if self.is_fused:
        geom = self.geom_xy
        # NOTE(review): presumably an online point cloud does not expose a
        # materialised ``cost_matrix``, hence the explicit computation and
        # manual rescaling below -- confirm against the geometry class.
        uses_online_pc = (
            isinstance(geom, pointcloud.PointCloud) and geom.is_online
        )
        if uses_online_pc:
            return geom.compute_cost_matrix() * geom.inv_scale_cost
        return geom.cost_matrix
    # No linear term: a scalar zero keeps the caller's addition a no-op.
    return 0


def update_epsilon_unbalanced(epsilon, transport_mass):
updated_epsilon = epsilon_scheduler.Epsilon.make(epsilon)
Expand Down
8 changes: 5 additions & 3 deletions ott/geometry/pointcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,10 +373,10 @@ def apply_cost(self, arr: jnp.ndarray, axis: int = 0, fn=None) -> jnp.ndarray:
if fn is None:
return self._apply_cost(arr, axis, fn=fn)
# Switch to efficient computation for the squared euclidean case.
return jnp.where(
return jax.lax.cond(
jnp.logical_and(self.is_squared_euclidean, geometry.is_affine(fn)),
self.vec_apply_cost(arr, axis, fn=fn),
self._apply_cost(arr, axis, fn=fn)
lambda: self.vec_apply_cost(arr, axis, fn=fn),
lambda: self._apply_cost(arr, axis, fn=fn)
)

def _apply_cost(
Expand All @@ -390,6 +390,8 @@ def _apply_cost(
None, 0, None, self._axis_norm, None, None, None, None, None
]
)
if arr.ndim == 1:
arr = arr.reshape(-1, 1)
if axis == 0:
return app(
self.x, self.y, self._norm_x, self._norm_y, arr, self._cost_fn,
Expand Down