Skip to content

Commit

Permalink
Make random seeding consistent with PRNG keys everywhere instead of s…
Browse files Browse the repository at this point in the history
…eeds (#290)

* add ENS

* removed ENS

* Test

* Test

* Change gw_barycenter and neural dual

* fix random number generators in initializers

* k-means and gromov

* fix rng on low rank geometry and datasets

* fix remaining seed changes to rng

* fixed rng reference and increased epsilon to make algo converge, should open issue

* fixed keys to rngs

* extended rng instead of keys to inside functions and added jax.random.PRNGKeyArray as default type

* fixed more keys to rng with correct default type

* ran pre-commit

* Fix typo

---------

Co-authored-by: Anastasiia <[email protected]>
Co-authored-by: michalk8 <[email protected]>
  • Loading branch information
3 people authored Feb 24, 2023
1 parent edf98d1 commit d4857ea
Show file tree
Hide file tree
Showing 52 changed files with 743 additions and 694 deletions.
17 changes: 8 additions & 9 deletions src/ott/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,7 @@ def to_LRCGeometry(
self,
rank: int = 0,
tol: float = 1e-2,
seed: int = 0,
rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0),
scale: float = 1.
) -> 'low_rank.LRCGeometry':
r"""Factorize the cost matrix using either SVD (full) or :cite:`indyk:19`.
Expand All @@ -642,7 +642,7 @@ def to_LRCGeometry(
rank: Target rank of the :attr:`cost_matrix`.
tol: Tolerance of the error. The total number of sampled points is
:math:`min(n, m,\frac{rank}{tol})`.
seed: Random seed.
rng: The PRNG key to use for initializing the model.
scale: Value used to rescale the factors of the low-rank geometry.
Useful when this geometry is used in the linear term of fused GW.
Expand All @@ -664,27 +664,26 @@ def to_LRCGeometry(
cost_1 = u
cost_2 = (s[:, None] * vh).T
else:
rng = jax.random.PRNGKey(seed)
key1, key2, key3, key4, key5 = jax.random.split(rng, 5)
rng1, rng2, rng3, rng4, rng5 = jax.random.split(rng, 5)
n_subset = min(int(rank / tol), n, m)

i_star = jax.random.randint(key1, shape=(), minval=0, maxval=n)
j_star = jax.random.randint(key2, shape=(), minval=0, maxval=m)
i_star = jax.random.randint(rng1, shape=(), minval=0, maxval=n)
j_star = jax.random.randint(rng2, shape=(), minval=0, maxval=m)

ci_star = self.subset(i_star, None).cost_matrix.ravel() ** 2 # (m,)
cj_star = self.subset(None, j_star).cost_matrix.ravel() ** 2 # (n,)

p_row = cj_star + ci_star[j_star] + jnp.mean(ci_star) # (n,)
p_row /= jnp.sum(p_row)
row_ixs = jax.random.choice(key3, n, shape=(n_subset,), p=p_row)
row_ixs = jax.random.choice(rng3, n, shape=(n_subset,), p=p_row)
# (n_subset, m)
s = self.subset(row_ixs, None).cost_matrix
s /= jnp.sqrt(n_subset * p_row[row_ixs][:, None])

p_col = jnp.sum(s ** 2, axis=0) # (m,)
p_col /= jnp.sum(p_col)
# (n_subset,)
col_ixs = jax.random.choice(key4, m, shape=(n_subset,), p=p_col)
col_ixs = jax.random.choice(rng4, m, shape=(n_subset,), p=p_col)
# (n_subset, n_subset)
w = s[:, col_ixs] / jnp.sqrt(n_subset * p_col[col_ixs][None, :])

Expand All @@ -696,7 +695,7 @@ def to_LRCGeometry(
v = v.T / jnp.sqrt(d)[None, :]

inv_scale = (1. / jnp.sqrt(n_subset))
col_ixs = jax.random.choice(key5, m, shape=(n_subset,)) # (n_subset,)
col_ixs = jax.random.choice(rng5, m, shape=(n_subset,)) # (n_subset,)

# (n, n_subset)
A_trans = self.subset(None, col_ixs).cost_matrix * inv_scale
Expand Down
6 changes: 5 additions & 1 deletion src/ott/geometry/low_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,9 +231,13 @@ def finalize(carry):
return max_value + self._bias

def to_LRCGeometry(
self, rank: int = 0, tol: float = 1e-2, seed: int = 0
self,
rank: int = 0,
tol: float = 1e-2,
rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0),
) -> 'LRCGeometry':
"""Return self."""
del rank, tol, rng
return self

@property
Expand Down
70 changes: 34 additions & 36 deletions src/ott/initializers/linear/initializers_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(self, rank: int, **kwargs: Any):
def init_q(
self,
ot_prob: Problem_t,
key: jnp.ndarray,
rng: jax.random.PRNGKeyArray,
*,
init_g: jnp.ndarray,
**kwargs: Any,
Expand All @@ -75,7 +75,7 @@ def init_q(
Args:
ot_prob: OT problem.
key: Random key for seeding.
rng: Random key for seeding.
init_g: Initial value for :math:`g` factor.
kwargs: Additional keyword arguments.
Expand All @@ -87,7 +87,7 @@ def init_q(
def init_r(
self,
ot_prob: Problem_t,
key: jnp.ndarray,
rng: jax.random.PRNGKeyArray,
*,
init_g: jnp.ndarray,
**kwargs: Any,
Expand All @@ -96,7 +96,7 @@ def init_r(
Args:
ot_prob: Linear OT problem.
key: Random key for seeding.
rng: Random key for seeding.
init_g: Initial value for :math:`g` factor.
kwargs: Additional keyword arguments.
Expand All @@ -108,14 +108,14 @@ def init_r(
def init_g(
self,
ot_prob: Problem_t,
key: jnp.ndarray,
rng: jax.random.PRNGKeyArray,
**kwargs: Any,
) -> jnp.ndarray:
"""Initialize the low-rank factor :math:`g`.
Args:
ot_prob: OT problem.
key: Random key for seeding.
rng: Random key for seeding.
kwargs: Additional keyword arguments.
Returns:
Expand Down Expand Up @@ -176,7 +176,7 @@ def __call__(
r: Optional[jnp.ndarray] = None,
g: Optional[jnp.ndarray] = None,
*,
key: Optional[jnp.ndarray] = None,
rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0),
**kwargs: Any
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Initialize the factors :math:`Q`, :math:`R` and :math:`g`.
Expand All @@ -189,23 +189,21 @@ def __call__(
using :meth:`init_r`.
g: Factor of shape ``[rank,]``. If `None`, it will be initialized
using :meth:`init_g`.
key: Random key for seeding.
rng: Random key for seeding.
kwargs: Additional keyword arguments for :meth:`init_q`, :meth:`init_r`
and :meth:`init_g`.
Returns:
The factors :math:`Q`, :math:`R` and :math:`g`, respectively.
"""
if key is None:
key = jax.random.PRNGKey(0)
key1, key2, key3 = jax.random.split(key, 3)
rng1, rng2, rng3 = jax.random.split(rng, 3)

if g is None:
g = self.init_g(ot_prob, key1, **kwargs)
g = self.init_g(ot_prob, rng1, **kwargs)
if q is None:
q = self.init_q(ot_prob, key2, init_g=g, **kwargs)
q = self.init_q(ot_prob, rng2, init_g=g, **kwargs)
if r is None:
r = self.init_r(ot_prob, key3, init_g=g, **kwargs)
r = self.init_r(ot_prob, rng3, init_g=g, **kwargs)

assert g.shape == (self.rank,)
assert q.shape == (ot_prob.a.shape[0], self.rank)
Expand Down Expand Up @@ -240,37 +238,37 @@ class RandomInitializer(LRInitializer):
def init_q( # noqa: D102
self,
ot_prob: Problem_t,
key: jnp.ndarray,
rng: jax.random.PRNGKeyArray,
*,
init_g: jnp.ndarray,
**kwargs: Any,
) -> jnp.ndarray:
del kwargs, init_g
a = ot_prob.a
init_q = jnp.abs(jax.random.normal(key, (a.shape[0], self.rank)))
init_q = jnp.abs(jax.random.normal(rng, (a.shape[0], self.rank)))
return a[:, None] * (init_q / jnp.sum(init_q, axis=1, keepdims=True))

def init_r( # noqa: D102
self,
ot_prob: Problem_t,
key: jnp.ndarray,
rng: jax.random.PRNGKeyArray,
*,
init_g: jnp.ndarray,
**kwargs: Any,
) -> jnp.ndarray:
del kwargs, init_g
b = ot_prob.b
init_r = jnp.abs(jax.random.normal(key, (b.shape[0], self.rank)))
init_r = jnp.abs(jax.random.normal(rng, (b.shape[0], self.rank)))
return b[:, None] * (init_r / jnp.sum(init_r, axis=1, keepdims=True))

def init_g( # noqa: D102
self,
ot_prob: Problem_t,
key: jnp.ndarray,
rng: jax.random.PRNGKeyArray,
**kwargs: Any,
) -> jnp.ndarray:
del kwargs
init_g = jnp.abs(jax.random.uniform(key, (self.rank,))) + 1.
init_g = jnp.abs(jax.random.uniform(rng, (self.rank,))) + 1.
return init_g / jnp.sum(init_g)


Expand Down Expand Up @@ -314,32 +312,32 @@ def _compute_factor(
def init_q( # noqa: D102
self,
ot_prob: Problem_t,
key: jnp.ndarray,
rng: jax.random.PRNGKeyArray,
*,
init_g: jnp.ndarray,
**kwargs: Any,
) -> jnp.ndarray:
del key, kwargs
del rng, kwargs
return self._compute_factor(ot_prob, init_g, which="q")

def init_r( # noqa: D102
self,
ot_prob: Problem_t,
key: jnp.ndarray,
rng: jax.random.PRNGKeyArray,
*,
init_g: jnp.ndarray,
**kwargs: Any,
) -> jnp.ndarray:
del key, kwargs
del rng, kwargs
return self._compute_factor(ot_prob, init_g, which="r")

def init_g( # noqa: D102
self,
ot_prob: Problem_t,
key: jnp.ndarray,
rng: jax.random.PRNGKeyArray,
**kwargs: Any,
) -> jnp.ndarray:
del key, kwargs
del rng, kwargs
return jnp.ones((self.rank,)) / self.rank


Expand Down Expand Up @@ -387,7 +385,7 @@ def _extract_array(
def _compute_factor(
self,
ot_prob: Problem_t,
key: jnp.ndarray,
rng: jax.random.PRNGKeyArray,
*,
init_g: jnp.ndarray,
which: Literal["q", "r"],
Expand All @@ -413,7 +411,7 @@ def _compute_factor(
arr = self._extract_array(geom, first=which == "q")
marginals = ot_prob.a if which == "q" else ot_prob.b

centroids = fn(arr, self.rank, key=key).centroids
centroids = fn(arr, self.rank, rng=rng).centroids
geom = pointcloud.PointCloud(
arr, centroids, epsilon=0.1, scale_cost="max_cost"
)
Expand All @@ -425,34 +423,34 @@ def _compute_factor(
def init_q( # noqa: D102
self,
ot_prob: Problem_t,
key: jnp.ndarray,
rng: jax.random.PRNGKeyArray,
*,
init_g: jnp.ndarray,
**kwargs: Any,
) -> jnp.ndarray:
return self._compute_factor(
ot_prob, key, init_g=init_g, which="q", **kwargs
ot_prob, rng, init_g=init_g, which="q", **kwargs
)

def init_r( # noqa: D102
self,
ot_prob: Problem_t,
key: jnp.ndarray,
rng: jax.random.PRNGKeyArray,
*,
init_g: jnp.ndarray,
**kwargs: Any,
) -> jnp.ndarray:
return self._compute_factor(
ot_prob, key, init_g=init_g, which="r", **kwargs
ot_prob, rng, init_g=init_g, which="r", **kwargs
)

def init_g( # noqa: D102
self,
ot_prob: Problem_t,
key: jnp.ndarray,
rng: jax.random.PRNGKeyArray,
**kwargs: Any,
) -> jnp.ndarray:
del key, kwargs
del rng, kwargs
return jnp.ones((self.rank,)) / self.rank

def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102
Expand Down Expand Up @@ -518,7 +516,7 @@ class State(NamedTuple): # noqa: D106
def _compute_factor(
self,
ot_prob: Problem_t,
key: jnp.ndarray,
rng: jax.random.PRNGKeyArray,
*,
init_g: jnp.ndarray,
which: Literal["q", "r"],
Expand All @@ -530,7 +528,7 @@ def _compute_factor(

def init_fn() -> GeneralizedKMeansInitializer.State:
n = geom.shape[0]
factor = jnp.abs(jax.random.normal(key, (n, self.rank))) + 1. # (n, r)
factor = jnp.abs(jax.random.normal(rng, (n, self.rank))) + 1. # (n, r)
factor *= consts.marginal[:, None] / jnp.sum(
factor, axis=1, keepdims=True
)
Expand Down
Loading

0 comments on commit d4857ea

Please sign in to comment.