Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LR Sinkhorn improvements #111

Merged
merged 40 commits into from
Sep 1, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
82e03fb
Merge pull request #1 from ott-jax/main
meyerscetbon Jul 19, 2022
c68545d
test
meyerscetbon Jul 19, 2022
c3b4a46
Merge branch 'ott-jax:main' into master
meyerscetbon Jul 20, 2022
d6aa2c8
update lr-sinkhorn
meyerscetbon Jul 20, 2022
8ebb773
restored_branch
meyerscetbon Jul 21, 2022
d85df4d
check
meyerscetbon Jul 21, 2022
b58c5dd
Merge branch 'ott-jax:main' into master
meyerscetbon Jul 24, 2022
0174446
review
meyerscetbon Jul 25, 2022
b7d9faf
Merge branch 'master' into branch_meyer_recovered
meyerscetbon Jul 26, 2022
c2cd4e2
circular fixed
meyerscetbon Jul 26, 2022
c9eac41
update review
meyerscetbon Jul 27, 2022
c5d037e
Fix bugs in `LRSinkhorn`
michalk8 Aug 29, 2022
3365151
Merge branch 'main' into branch_meyer_recovered
michalk8 Aug 29, 2022
372ae3d
Use new `k-means` implementation
michalk8 Aug 29, 2022
47cd986
Fix linter
michalk8 Aug 29, 2022
1abfa6e
Merge branch 'main' into branch_meyer_recovered
michalk8 Aug 29, 2022
70f7e9f
Refactor `LRSinkhorn` initializers
michalk8 Aug 29, 2022
341ef42
Use `if` for `is_entropic`, remove dead variables
michalk8 Aug 29, 2022
e56be73
Slightly improve types
michalk8 Aug 29, 2022
cd6c630
Do not use stateful `gamma`
michalk8 Aug 29, 2022
a857d91
Fix typo in tests
michalk8 Aug 29, 2022
859091e
Fix using `state.gamma` instead of `self.gamma`
michalk8 Aug 29, 2022
dfe00d5
Fix point cloud size in notebook
michalk8 Aug 29, 2022
869cd07
Add assertion to k-means
michalk8 Aug 29, 2022
ed6af26
Use `jax.lax.cond` instead of `jax.numpy.where`
michalk8 Aug 29, 2022
e1b3987
Change convergence criterion
michalk8 Aug 30, 2022
461e517
Use safe log
michalk8 Aug 30, 2022
c3ec475
Fix more tests
michalk8 Aug 30, 2022
cb2d89c
Fix tests
michalk8 Aug 31, 2022
9393e29
Fix `tree_flatten` in `KMeansInitializer`
michalk8 Aug 31, 2022
527e2b5
Fix defaults, change `rank_2` -> `rank2`
michalk8 Aug 31, 2022
e9026ef
Simplify `apply`
michalk8 Aug 31, 2022
4033b3d
Update TODOs
michalk8 Aug 31, 2022
b53610e
Update docs, make `lr_costs` private
michalk8 Aug 31, 2022
be80591
Increase tolerance in failing test
michalk8 Aug 31, 2022
9d93a6e
Merge branch 'main' into branch_meyer_recovered
Oisin-M Aug 31, 2022
ae97a3d
Update LR notebook
Oisin-M Aug 31, 2022
ffee388
Address comments
michalk8 Sep 1, 2022
8695eab
Merge branch 'branch_meyer_recovered' of https://github.com/meyerscet…
michalk8 Sep 1, 2022
fb1a1df
Remove LR Sinkhorn notebook from testing, too slow
michalk8 Sep 1, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions docs/notebooks/LRSinkhorn.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 5,
"metadata": {
"id": "q9wY2bCeUIB0"
},
Expand All @@ -37,7 +37,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 11,
"metadata": {
"id": "PfiRNdhVW8hT"
},
Expand Down Expand Up @@ -67,16 +67,20 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 12,
"metadata": {
"id": "pN_f36ACALET"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
"ename": "AttributeError",
"evalue": "module 'ott' has no attribute 'geometry'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-12-c9395a034c52>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcreate_points\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrng\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0md\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mgeom\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mott\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgeometry\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpointcloud\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mPointCloud\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepsilon\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0mot_prob\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mott\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear_problems\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLinearProblem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgeom\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mAttributeError\u001b[0m: module 'ott' has no attribute 'geometry'"
]
}
],
Expand Down Expand Up @@ -367,7 +371,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.8.2"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion ott/core/fixed_point_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# Lint as: python3
"""jheek@ backprop-friendly implementation of fixed point loop."""
from typing import Any, Callable

import jax
import numpy as np
from jax import numpy as jnp
Expand Down
2 changes: 1 addition & 1 deletion ott/core/linear_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from typing import Callable, Optional, Tuple

import jax
import jax
import jax.numpy as jnp

from ott.geometry import geometry
Expand Down
2 changes: 1 addition & 1 deletion ott/core/sinkhorn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# Lint as: python3
"""A Jax implementation of the Sinkhorn algorithm."""
from typing import Any, Callable, NamedTuple, Optional, Sequence, Tuple

import jax
import jax.numpy as jnp
import numpy as np
Expand Down
166 changes: 147 additions & 19 deletions ott/core/sinkhorn_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
from typing_extensions import Literal

from ott.core import fixed_point_loop, linear_problems, sinkhorn
from ott.geometry import geometry
from ott.geometry import geometry, pointcloud
from ott.tools import k_means


class LRSinkhornState(NamedTuple):
Expand All @@ -30,12 +31,24 @@ class LRSinkhornState(NamedTuple):
q: Optional[jnp.ndarray] = None
r: Optional[jnp.ndarray] = None
g: Optional[jnp.ndarray] = None
q_prev: Optional[jnp.ndarray] = None
meyerscetbon marked this conversation as resolved.
Show resolved Hide resolved
r_prev: Optional[jnp.ndarray] = None
g_prev: Optional[jnp.ndarray] = None
gamma: Optional[float] = None
costs: Optional[jnp.ndarray] = None
criterion: Optional[float] = None
count_escape: Optional[int] = None

def set(self, **kwargs: Any) -> 'LRSinkhornState':
  """Return a copy of this state with the given fields overwritten.

  Args:
    kwargs: field names of the state mapped to their replacement values.

  Returns:
    A new state; ``self`` is left untouched.
  """
  updated = self._replace(**kwargs)
  return updated

def compute_crit(self) -> float:
  """Evaluate the convergence criterion for this state.

  Forwards the current factors (``q``, ``r``, ``g``), the previous ones
  (``q_prev``, ``r_prev``, ``g_prev``) and ``gamma`` stored on the state to
  :func:`compute_criterion`.

  Returns:
    The value returned by :func:`compute_criterion`.
  """
  return compute_criterion(
      q=self.q,
      r=self.r,
      g=self.g,
      q_prev=self.q_prev,
      r_prev=self.r_prev,
      g_prev=self.g_prev,
      gamma=self.gamma,
  )

def reg_ot_cost(
self,
ot_prob: linear_problems.LinearProblem,
Expand Down Expand Up @@ -101,6 +114,22 @@ def solution_error(
return err


def compute_criterion(
    q: jnp.ndarray, r: jnp.ndarray, g: jnp.ndarray, q_prev: jnp.ndarray,
    r_prev: jnp.ndarray, g_prev: jnp.ndarray, gamma: float
):
  """Measure how much the factors ``(q, r, g)`` moved since the last step.

  For each factor, the symmetrized KL term
  ``kl(new, old) + kl(old, new)`` is scaled by ``(1 / gamma) ** 2``; the
  three scaled terms are then added up.

  Args:
    q: current ``q`` factor.
    r: current ``r`` factor.
    g: current ``g`` factor.
    q_prev: ``q`` factor from the previous step.
    r_prev: ``r`` factor from the previous step.
    g_prev: ``g`` factor from the previous step.
    gamma: value whose inverse square scales every term.

  Returns:
    The sum of the three scaled symmetrized KL terms.
  """
  scale = (1 / gamma) ** 2

  def symmetric_kl(current, previous):
    return kl(current, previous) + kl(previous, current)

  # Scale each term separately and accumulate left to right.
  total = scale * symmetric_kl(q, q_prev)
  total = total + scale * symmetric_kl(r, r_prev)
  return total + scale * symmetric_kl(g, g_prev)


def kl(q1, q2):
  """Return ``sum(q1 * (log(q1) - log(q2)))`` with ``0 * log(0) = 0``.

  A plain ``jnp.log`` turns every zero entry of ``q1`` into
  ``0 * (-inf) = nan`` and poisons the whole sum. Here, entries where ``q1``
  is exactly zero contribute ``0`` instead.

  Args:
    q1: array of weights.
    q2: array broadcastable against ``q1``.

  Returns:
    The scalar sum. It is ``inf`` if some entry has ``q1 > 0`` while
    ``q2 == 0``; invalid (negative or ``nan``) entries still yield ``nan``.
  """
  q1 = jnp.asarray(q1)
  q2 = jnp.asarray(q2)
  support = q1 != 0
  # Off the support, feed 1 to both logs so the term is exactly 0 * (0 - 0);
  # masking the inputs (not the output) also keeps gradients free of nan.
  safe_q1 = jnp.where(support, q1, 1.0)
  safe_q2 = jnp.where(support, q2, 1.0)
  return jnp.sum(q1 * (jnp.log(safe_q1) - jnp.log(safe_q2)))


class LRSinkhornOutput(NamedTuple):
"""Implement the problems.Transport interface, for a LR Sinkhorn solution."""

Expand Down Expand Up @@ -209,6 +238,7 @@ class LRSinkhorn(sinkhorn.Sinkhorn):
Args:
rank: the rank constraint on the coupling to minimize the linear OT problem
gamma: the (inverse of) gradient stepsize used by mirror descent.
gamma_init: TODO.
epsilon: entropic regularization added on top of low-rank problem.
init_type: TODO.
lse_mode: whether to run computations in lse or kernel mode. At this moment,
Expand All @@ -232,9 +262,10 @@ class LRSinkhorn(sinkhorn.Sinkhorn):
def __init__(
self,
rank: int = 10,
gamma: float = 1.0,
epsilon: float = 1e-4,
init_type: Literal['random', 'rank_2'] = 'random',
gamma: float = 10.0,
gamma_init: Literal['rescale', 'not_recale'] = 'rescale',
meyerscetbon marked this conversation as resolved.
Show resolved Hide resolved
epsilon: float = 0.0,
init_type: Literal['random', 'rank_2', 'kmeans'] = 'kmeans',
lse_mode: bool = True,
threshold: float = 1e-3,
norm_error: int = 1,
Expand All @@ -250,6 +281,7 @@ def __init__(
# TODO(michalk8): this should call super
self.rank = rank
self.gamma = gamma
self.gamma_init = gamma_init
self.epsilon = epsilon
self.init_type = init_type
self.lse_mode = lse_mode
Expand All @@ -275,7 +307,7 @@ def __call__(
"""Main interface to run LR sinkhorn.""" # noqa: D401
init_q, init_r, init_g = (init if init is not None else (None, None, None))
# Random initialization for q, r, g using rng_key
rng = jax.random.split(jax.random.PRNGKey(self.rng_key), 3)
rng = jax.random.split(jax.random.PRNGKey(self.rng_key), 5)
a, b = ot_prob.a, ot_prob.b
if self.init_type == 'random':
if init_g is None:
Expand Down Expand Up @@ -306,6 +338,49 @@ def __call__(
if init_r is None:
init_r = lambda_1 * jnp.dot(b1[:, None], g1.reshape(1, -1))
init_r += (1 - lambda_1) * jnp.dot(b2[:, None], g2.reshape(1, -1))
elif self.init_type == 'kmeans':
x = ot_prob.geom.x
y = ot_prob.geom.y
if init_g is None:
init_g = jnp.ones((self.rank,)) / self.rank
if init_q is None:
kmeans_x = jax.jit(
k_means.kmeans, static_argnums=(2, 3, 4, 5)
) if self.jit else k_means.kmeans
kmeans_x = kmeans_x(rng[3], x, self.rank)
z_x = kmeans_x[0]
geom_x = pointcloud.PointCloud(
x, z_x, epsilon=0.1, scale_cost='max_cost'
)
ot_prob_x = linear_problems.LinearProblem(geom_x, a, init_g)
solver_x = sinkhorn.Sinkhorn(
norm_error=self.norm_error,
lse_mode=self.lse_mode,
jit=self.jit,
implicit_diff=self.implicit_diff,
use_danskin=self.use_danskin
)
ot_sink_x = solver_x(ot_prob_x)
init_q = ot_sink_x.matrix
if init_r is None:
kmeans_y = jax.jit(
k_means.kmeans, static_argnums=(2, 3, 4, 5)
) if self.jit else k_means.kmeans
kmeans_y = kmeans_y(rng[4], y, self.rank)
z_y = kmeans_y[0]
geom_y = pointcloud.PointCloud(
y, z_y, epsilon=0.1, scale_cost='max_cost'
)
ot_prob_y = linear_problems.LinearProblem(geom_y, b, init_g)
solver_y = sinkhorn.Sinkhorn(
norm_error=self.norm_error,
lse_mode=self.lse_mode,
jit=self.jit,
implicit_diff=self.implicit_diff,
use_danskin=self.use_danskin
)
ot_sink_y = solver_y(ot_prob_y)
init_r = ot_sink_y.matrix
else:
raise NotImplementedError(self.init_type)
run_fn = jax.jit(run) if self.jit else run
Expand All @@ -316,13 +391,26 @@ def norm_error(self) -> Tuple[int]:
return (self._norm_error,)

def _converged(self, state: LRSinkhornState, iteration: int) -> bool:
costs, i, tol = state.costs, iteration, self.threshold
return jnp.logical_and(
i >= 2, jnp.isclose(costs[i - 2], costs[i - 1], rtol=tol)
)
criterion, count_escape, i, tol = state.criterion, state.count_escape, iteration, self.threshold
meyerscetbon marked this conversation as resolved.
Show resolved Hide resolved
if i >= 2:
meyerscetbon marked this conversation as resolved.
Show resolved Hide resolved
if criterion > tol / 1e-1:
err = criterion
else:
count_escape = count_escape + 1
state.set(count_escape=count_escape)
meyerscetbon marked this conversation as resolved.
Show resolved Hide resolved
if count_escape != iteration:
err = criterion
else:
err = jnp.inf
else:
err = jnp.inf
return jnp.logical_and(i >= 2, err < tol)

def _diverged(self, state: LRSinkhornState, iteration: int) -> bool:
return jnp.logical_not(jnp.isfinite(state.costs[iteration - 1]))
return jnp.logical_or(
jnp.logical_not(jnp.isfinite(state.criterion)),
jnp.logical_not(jnp.isfinite(state.costs[iteration - 1]))
)

def _continue(self, state: LRSinkhornState, iteration: int) -> bool:
"""Continue while not(converged) and not(diverged)."""
Expand All @@ -338,15 +426,34 @@ def lr_costs(
self, ot_prob: linear_problems.LinearProblem, state: LRSinkhornState,
iteration: int
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
c_q = ot_prob.geom.apply_cost(state.r, axis=1) / state.g[None, :]
c_q += (self.epsilon - 1 / self.gamma) * jnp.log(state.q)
c_r = ot_prob.geom.apply_cost(state.q) / state.g[None, :]
c_r += (self.epsilon - 1 / self.gamma) * jnp.log(state.r)
grad_q = ot_prob.geom.apply_cost(state.r, axis=1) / state.g[None, :]
grad_q = jnp.where(
self.epsilon != 0., grad_q + self.epsilon * jnp.log(state.q), grad_q
meyerscetbon marked this conversation as resolved.
Show resolved Hide resolved
)
if self.gamma_init == "rescale":
norm_q = jnp.max(jnp.abs(grad_q)) ** 2
grad_r = ot_prob.geom.apply_cost(state.q) / state.g[None, :]
grad_r = jnp.where(
self.epsilon != 0., grad_r + self.epsilon * jnp.log(state.r), grad_r
)
if self.gamma_init == "rescale":
norm_r = jnp.max(jnp.abs(grad_r)) ** 2
diag_qcr = jnp.sum(
state.q * ot_prob.geom.apply_cost(state.r, axis=1), axis=0
)
grad_g = -diag_qcr / state.g ** 2
meyerscetbon marked this conversation as resolved.
Show resolved Hide resolved
grad_g = jnp.where(
self.epsilon != 0., grad_g + self.epsilon * jnp.log(state.g), grad_g
)
if self.gamma_init == "rescale":
norm_g = jnp.max(jnp.abs(grad_g)) ** 2
if self.gamma_init == "rescale":
self.gamma = self.gamma / max(norm_q, norm_r, norm_g)
meyerscetbon marked this conversation as resolved.
Show resolved Hide resolved
h = diag_qcr / state.g ** 2 - (self.epsilon -
meyerscetbon marked this conversation as resolved.
Show resolved Hide resolved
1 / self.gamma) * jnp.log(state.g)
c_q = grad_q - (1 / self.gamma) * jnp.log(state.q)
meyerscetbon marked this conversation as resolved.
Show resolved Hide resolved
c_r = grad_r - (1 / self.gamma) * jnp.log(state.r)
h = -grad_g + (1 / self.gamma) * jnp.log(state.g)
return c_q, c_r, h

def dysktra_update(
Expand All @@ -358,10 +465,10 @@ def dysktra_update(
state: LRSinkhornState,
iteration: int,
min_entry_value: float = 1e-6,
tolerance: float = 1e-4,
tolerance: float = 1e-3,
min_iter: int = 0,
inner_iter: int = 10,
max_iter: int = 200
max_iter: int = 10000
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
# shortcuts for problem's definition.
r = self.rank
Expand Down Expand Up @@ -458,11 +565,15 @@ def lse_step(
iteration: int
) -> LRSinkhornState:
"""LR Sinkhorn LSE update."""
q_prev, r_prev, g_prev = state.q, state.r, state.g
c_q, c_r, h = self.lr_costs(ot_prob, state, iteration)
gamma = self.gamma
q, r, g = self.dysktra_update(
c_q, c_r, h, ot_prob, state, iteration, **self.kwargs_dys
)
return state.set(q=q, g=g, r=r)
return state.set(
q=q, g=g, r=r, q_prev=q_prev, g_prev=g_prev, r_prev=r_prev, gamma=gamma
)

def kernel_step(
self, ot_prob: linear_problems.LinearProblem, state: LRSinkhornState,
Expand Down Expand Up @@ -496,22 +607,39 @@ def one_iteration(
else:
state = self.kernel_step(ot_prob, state, iteration)

# compute the criterion
criterion = state.compute_crit()
meyerscetbon marked this conversation as resolved.
Show resolved Hide resolved

# re-computes error if compute_error is True, else set it to inf.
cost = jnp.where(
jnp.logical_and(compute_error, iteration >= self.min_iterations),
state.reg_ot_cost(ot_prob), jnp.inf
)
costs = state.costs.at[iteration // self.inner_iterations].set(cost)
return state.set(costs=costs)
return state.set(costs=costs, criterion=criterion)

def init_state(
self, ot_prob: linear_problems.LinearProblem,
init: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]
) -> LRSinkhornState:
"""Return the initial state of the loop."""
gamma = self.gamma
q, r, g = init
costs = -jnp.ones(self.outer_iterations)
return LRSinkhornState(q=q, r=r, g=g, costs=costs)
criterion = 0.0
Copy link
Collaborator

@michalk8 michalk8 Jul 21, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to assign `criterion`/`count_escape` to variables; I would just hardcode them in the init call.

count_escape = 1
return LRSinkhornState(
q=q,
r=r,
g=g,
q_prev=q,
r_prev=r,
g_prev=g,
gamma=gamma,
costs=costs,
criterion=criterion,
count_escape=count_escape
)

def output_from_state(
self, ot_prob: linear_problems.LinearProblem, state: LRSinkhornState
Expand Down
Loading