Skip to content

Commit

Permalink
Feature/faster gw init (#213)
Browse files Browse the repository at this point in the history
* Add faster GW initialization for balanced case

* Use `jax.scipy.special.entr`

* Update the unbalanced case

* Fix wrong `xlog`

* Parenthesize

* Fix imports, change how `reg` is implemented

* Transport mass change
  • Loading branch information
michalk8 authored Dec 17, 2022
1 parent 61f9e80 commit b70f1d0
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 42 deletions.
36 changes: 17 additions & 19 deletions src/ott/initializers/quadratic/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import TYPE_CHECKING, Any, Dict, Sequence, Tuple

import jax
import jax.numpy as jnp

from ott.geometry import geometry

Expand Down Expand Up @@ -123,33 +124,30 @@ def _create_geometry(
from ott.problems.quadratic import quadratic_problem
del kwargs

unbalanced_correction = 0.0
tmp = quad_prob.init_transport()
marginal_1 = tmp.sum(1)
marginal_2 = tmp.sum(0)
marginal_cost = quad_prob.marginal_dependent_cost(quad_prob.a, quad_prob.b)
geom_xx, geom_yy = quad_prob.geom_xx, quad_prob.geom_yy

# Initialises cost.
marginal_cost = quad_prob.marginal_dependent_cost(marginal_1, marginal_2)
h1, h2 = quad_prob.quad_loss
tmp1 = quadratic_problem.apply_cost(geom_xx, quad_prob.a, axis=1, fn=h1)
tmp2 = quadratic_problem.apply_cost(geom_yy, quad_prob.b, axis=1, fn=h2)
tmp = jnp.outer(tmp1, tmp2)

if quad_prob.is_balanced:
cost_matrix = marginal_cost.cost_matrix - tmp
else:
# initialize epsilon for Unbalanced GW according to Sejourne et. al (2021)
init_transport = jnp.outer(quad_prob.a, quad_prob.b)
marginal_1, marginal_2 = init_transport.sum(1), init_transport.sum(0)

if not quad_prob.is_balanced:
transport_mass = marginal_1.sum()
# Initialises epsilon for Unbalanced GW according to Sejourne et al (2021)
epsilon = quadratic_problem.update_epsilon_unbalanced(
epsilon=epsilon, transport_mass=transport_mass
epsilon=epsilon, transport_mass=marginal_1.sum()
)
unbalanced_correction = quad_prob.cost_unbalanced_correction(
tmp, marginal_1, marginal_2, epsilon=epsilon
init_transport, marginal_1, marginal_2, epsilon=epsilon
)

h1, h2 = quad_prob.quad_loss
tmp = quadratic_problem.apply_cost(quad_prob.geom_xx, tmp, axis=1, fn=h1)
tmp = quadratic_problem.apply_cost(
quad_prob.geom_yy, tmp.T, axis=1, fn=h2
).T
cost_matrix = (marginal_cost.cost_matrix - tmp + unbalanced_correction)
cost_matrix = marginal_cost.cost_matrix - tmp + unbalanced_correction

cost_matrix += quad_prob.fused_penalty * quad_prob._fused_cost_matrix

return geometry.Geometry(cost_matrix=cost_matrix, epsilon=epsilon)


Expand Down
42 changes: 19 additions & 23 deletions src/ott/problems/quadratic/quadratic_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import jax
import jax.numpy as jnp
import jax.scipy as jsp
from typing_extensions import Literal

from ott.geometry import epsilon_scheduler, geometry, low_rank, pointcloud
Expand Down Expand Up @@ -173,7 +174,7 @@ def cost_unbalanced_correction(
transport_matrix: jnp.ndarray,
marginal_1: jnp.ndarray,
marginal_2: jnp.ndarray,
epsilon: float,
epsilon: epsilon_scheduler.Epsilon,
) -> float:
r"""Calculate cost term from the quadratic divergence when unbalanced.
Expand All @@ -200,35 +201,30 @@ def cost_unbalanced_correction(
marginal_2: jnp.ndarray<float>[num_b,], marginal of the transport matrix
for samples from :attr:`geom_yy`.
epsilon: regulariser.
delta: small quantity to avoid diverging KLs.
Returns:
The cost term.
"""

def regulariser(tau: float) -> float:
return epsilon._target_init * tau / (1.0 - tau) if tau != 1.0 else 0

marginal_1loga = jax.scipy.special.xlogy(marginal_1, self.a).sum()
marginal_2logb = jax.scipy.special.xlogy(marginal_2, self.b).sum()
cost = regulariser(
self.tau_a
) * (-jax.scipy.special.entr(marginal_1).sum() - marginal_1loga)
cost += regulariser(
self.tau_b
) * (-jax.scipy.special.entr(marginal_2).sum() - marginal_2logb)
cost += epsilon._target_init * jax.scipy.special.xlogy(
transport_matrix, transport_matrix
).sum()
def regularizer(tau: float) -> float:
return eps * tau / (1.0 - tau)

eps = epsilon._target_init
marginal_1loga = jsp.special.xlogy(marginal_1, self.a).sum()
marginal_2logb = jsp.special.xlogy(marginal_2, self.b).sum()

cost = eps * jsp.special.xlogy(transport_matrix, transport_matrix).sum()
if self.tau_a != 1.0:
cost += regularizer(
self.tau_a
) * (-jsp.special.entr(marginal_1).sum() - marginal_1loga)
if self.tau_b != 1.0:
cost += regularizer(
self.tau_b
) * (-jsp.special.entr(marginal_2).sum() - marginal_2logb)
return cost

def init_transport(self) -> jnp.ndarray:
"""Initialise the transport matrix."""
# TODO(oliviert, cuturi): consider passing a custom initialization.
a = jax.lax.stop_gradient(self.a)
b = jax.lax.stop_gradient(self.b)
return a[:, None] * b[None, :]

# TODO(michalk8): highly coupled to the pre-defined initializer, refactor
def init_transport_mass(self) -> float:
"""Initialise the transport mass.
Expand Down

0 comments on commit b70f1d0

Please sign in to comment.