Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make random seeding consistent with PRNG keys everywhere instead of seeds #290

Merged
merged 20 commits into from
Feb 24, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/ott/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ def to_LRCGeometry(
self,
rank: int = 0,
tol: float = 1e-2,
seed: int = 0,
rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0),
scale: float = 1.
) -> 'low_rank.LRCGeometry':
r"""Factorize the cost matrix using either SVD (full) or :cite:`indyk:19`.
Expand All @@ -648,7 +648,7 @@ def to_LRCGeometry(
rank: Target rank of the :attr:`cost_matrix`.
tol: Tolerance of the error. The total number of sampled points is
:math:`min(n, m,\frac{rank}{tol})`.
seed: Random seed.
rng: Random key for seeding the low-rank factorization.
scale: Value used to rescale the factors of the low-rank geometry.
Useful when this geometry is used in the linear term of fused GW.

Expand All @@ -670,7 +670,6 @@ def to_LRCGeometry(
cost_1 = u
cost_2 = (s[:, None] * vh).T
else:
rng = jax.random.PRNGKey(seed)
key1, key2, key3, key4, key5 = jax.random.split(rng, 5)
n_subset = min(int(rank / tol), n, m)

Expand Down
3 changes: 2 additions & 1 deletion src/ott/geometry/low_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,10 @@ def finalize(carry):
return max_value + self._bias

def to_LRCGeometry(
self, rank: int = 0, tol: float = 1e-2, seed: int = 0
self, rank: int = 0, tol: float = 1e-2, rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0),
) -> 'LRCGeometry':
"""Return self."""
del rank, tol, rng
return self

@property
Expand Down
64 changes: 31 additions & 33 deletions src/ott/initializers/linear/initializers_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(self, rank: int, **kwargs: Any):
def init_q(
self,
ot_prob: Problem_t,
key: jnp.ndarray,
rng: jnp.ndarray,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the type hint should rather be jax.random.PRNGKeyArray

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same everywhere, it should always be rng: jax.random.PRNGKeyArray instead of rng: jnp.ndarray

*,
init_g: jnp.ndarray,
**kwargs: Any,
Expand All @@ -75,7 +75,7 @@ def init_q(

Args:
ot_prob: OT problem.
key: Random key for seeding.
rng: Random key for seeding.
init_g: Initial value for :math:`g` factor.
kwargs: Additional keyword arguments.

Expand All @@ -87,7 +87,7 @@ def init_q(
def init_r(
self,
ot_prob: Problem_t,
key: jnp.ndarray,
rng: jnp.ndarray,
*,
init_g: jnp.ndarray,
**kwargs: Any,
Expand All @@ -96,7 +96,7 @@ def init_r(

Args:
ot_prob: Linear OT problem.
key: Random key for seeding.
rng: Random key for seeding.
init_g: Initial value for :math:`g` factor.
kwargs: Additional keyword arguments.

Expand All @@ -108,14 +108,14 @@ def init_r(
def init_g(
self,
ot_prob: Problem_t,
key: jnp.ndarray,
rng: jnp.ndarray,
**kwargs: Any,
) -> jnp.ndarray:
"""Initialize the low-rank factor :math:`g`.

Args:
ot_prob: OT problem.
key: Random key for seeding.
rng: Random key for seeding.
kwargs: Additional keyword arguments.

Returns:
Expand Down Expand Up @@ -176,7 +176,7 @@ def __call__(
r: Optional[jnp.ndarray] = None,
g: Optional[jnp.ndarray] = None,
*,
key: Optional[jnp.ndarray] = None,
rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0),
**kwargs: Any
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Initialize the factors :math:`Q`, :math:`R` and :math:`g`.
Expand All @@ -189,16 +189,14 @@ def __call__(
using :meth:`init_r`.
g: Factor of shape ``[rank,]``. If `None`, it will be initialized
using :meth:`init_g`.
key: Random key for seeding.
rng: Random key for seeding.
kwargs: Additional keyword arguments for :meth:`init_q`, :meth:`init_r`
and :meth:`init_g`.

Returns:
The factors :math:`Q`, :math:`R` and :math:`g`, respectively.
"""
if key is None:
key = jax.random.PRNGKey(0)
key1, key2, key3 = jax.random.split(key, 3)
key1, key2, key3 = jax.random.split(rng, 3)

if g is None:
g = self.init_g(ot_prob, key1, **kwargs)
Expand Down Expand Up @@ -240,37 +238,37 @@ class RandomInitializer(LRInitializer):
def init_q( # noqa: D102
self,
ot_prob: Problem_t,
key: jnp.ndarray,
rng: jnp.ndarray,
*,
init_g: jnp.ndarray,
**kwargs: Any,
) -> jnp.ndarray:
del kwargs, init_g
a = ot_prob.a
init_q = jnp.abs(jax.random.normal(key, (a.shape[0], self.rank)))
init_q = jnp.abs(jax.random.normal(rng, (a.shape[0], self.rank)))
return a[:, None] * (init_q / jnp.sum(init_q, axis=1, keepdims=True))

def init_r( # noqa: D102
self,
ot_prob: Problem_t,
key: jnp.ndarray,
rng: jnp.ndarray,
*,
init_g: jnp.ndarray,
**kwargs: Any,
) -> jnp.ndarray:
del kwargs, init_g
b = ot_prob.b
init_r = jnp.abs(jax.random.normal(key, (b.shape[0], self.rank)))
init_r = jnp.abs(jax.random.normal(rng, (b.shape[0], self.rank)))
return b[:, None] * (init_r / jnp.sum(init_r, axis=1, keepdims=True))

def init_g( # noqa: D102
self,
ot_prob: Problem_t,
key: jnp.ndarray,
rng: jnp.ndarray,
**kwargs: Any,
) -> jnp.ndarray:
del kwargs
init_g = jnp.abs(jax.random.uniform(key, (self.rank,))) + 1.
init_g = jnp.abs(jax.random.uniform(rng, (self.rank,))) + 1.
return init_g / jnp.sum(init_g)


Expand Down Expand Up @@ -314,32 +312,32 @@ def _compute_factor(
def init_q( # noqa: D102
self,
ot_prob: Problem_t,
key: jnp.ndarray,
rng: jnp.ndarray,
*,
init_g: jnp.ndarray,
**kwargs: Any,
) -> jnp.ndarray:
del key, kwargs
del rng, kwargs
return self._compute_factor(ot_prob, init_g, which="q")

def init_r( # noqa: D102
self,
ot_prob: Problem_t,
key: jnp.ndarray,
rng: jnp.ndarray,
*,
init_g: jnp.ndarray,
**kwargs: Any,
) -> jnp.ndarray:
del key, kwargs
del rng, kwargs
return self._compute_factor(ot_prob, init_g, which="r")

def init_g( # noqa: D102
self,
ot_prob: Problem_t,
key: jnp.ndarray,
rng: jnp.ndarray,
**kwargs: Any,
) -> jnp.ndarray:
del key, kwargs
del rng, kwargs
return jnp.ones((self.rank,)) / self.rank


Expand Down Expand Up @@ -387,7 +385,7 @@ def _extract_array(
def _compute_factor(
self,
ot_prob: Problem_t,
key: jnp.ndarray,
rng: jnp.ndarray,
*,
init_g: jnp.ndarray,
which: Literal["q", "r"],
Expand All @@ -413,7 +411,7 @@ def _compute_factor(
arr = self._extract_array(geom, first=which == "q")
marginals = ot_prob.a if which == "q" else ot_prob.b

centroids = fn(arr, self.rank, key=key).centroids
centroids = fn(arr, self.rank, rng=rng).centroids
geom = pointcloud.PointCloud(
arr, centroids, epsilon=0.1, scale_cost="max_cost"
)
Expand All @@ -425,34 +423,34 @@ def _compute_factor(
def init_q( # noqa: D102
self,
ot_prob: Problem_t,
key: jnp.ndarray,
rng: jnp.ndarray,
*,
init_g: jnp.ndarray,
**kwargs: Any,
) -> jnp.ndarray:
return self._compute_factor(
ot_prob, key, init_g=init_g, which="q", **kwargs
ot_prob, rng, init_g=init_g, which="q", **kwargs
)

def init_r( # noqa: D102
self,
ot_prob: Problem_t,
key: jnp.ndarray,
rng: jnp.ndarray,
*,
init_g: jnp.ndarray,
**kwargs: Any,
) -> jnp.ndarray:
return self._compute_factor(
ot_prob, key, init_g=init_g, which="r", **kwargs
ot_prob, rng, init_g=init_g, which="r", **kwargs
)

def init_g( # noqa: D102
self,
ot_prob: Problem_t,
key: jnp.ndarray,
rng: jnp.ndarray,
**kwargs: Any,
) -> jnp.ndarray:
del key, kwargs
del rng, kwargs
return jnp.ones((self.rank,)) / self.rank

def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102
Expand Down Expand Up @@ -518,7 +516,7 @@ class State(NamedTuple): # noqa: D106
def _compute_factor(
self,
ot_prob: Problem_t,
key: jnp.ndarray,
rng: jnp.ndarray,
*,
init_g: jnp.ndarray,
which: Literal["q", "r"],
Expand All @@ -530,7 +528,7 @@ def _compute_factor(

def init_fn() -> GeneralizedKMeansInitializer.State:
n = geom.shape[0]
factor = jnp.abs(jax.random.normal(key, (n, self.rank))) + 1. # (n, r)
factor = jnp.abs(jax.random.normal(rng, (n, self.rank))) + 1. # (n, r)
factor *= consts.marginal[:, None] / jnp.sum(
factor, axis=1, keepdims=True
)
Expand Down
20 changes: 10 additions & 10 deletions src/ott/problems/nn/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,13 @@ class GaussianMixture:
- ``square_four`` (two-dimensional Gaussians in the corners of a rectangle)

batch_size: batch size of the samples
init_key: initial PRNG key
init_rng: initial PRNG key
scale: scale of the individual Gaussian samples
variance: the variance of the individual Gaussian samples
"""
name: Name_t
batch_size: int
init_key: jax.random.PRNGKey
init_rng: jax.random.PRNGKey
scale: float = 5.0
variance: float = 0.5

Expand Down Expand Up @@ -95,7 +95,7 @@ def create_sample_generators(self) -> Iterator[jnp.array]:
Returns:
A generator of samples from the Gaussian mixture.
"""
key = self.init_key
key = self.init_rng
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
while True:
k1, k2, key = jax.random.split(key, 3)
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
means = jax.random.choice(k1, self.centers, [self.batch_size])
Expand All @@ -109,7 +109,7 @@ def create_gaussian_mixture_samplers(
name_target: Name_t,
train_batch_size: int = 2048,
valid_batch_size: int = 2048,
key: jax.random.PRNGKey = jax.random.PRNGKey(0),
rng: jax.random.PRNGKey = jax.random.PRNGKey(0),
) -> Tuple[Dataset, Dataset, int]:
"""Creates Gaussian samplers for :class:`~ott.solvers.nn.neuraldual.W2NeuralDual`.

Expand All @@ -118,33 +118,33 @@ def create_gaussian_mixture_samplers(
name_target: name of the target sampler
train_batch_size: the training batch size
valid_batch_size: the validation batch size
key: initial PRNG key
rng: initial PRNG key

Returns:
The dataset and dimension of the data.
"""
k1, k2, k3, k4 = jax.random.split(key, 4)
k1, k2, k3, k4 = jax.random.split(rng, 4)
train_dataset = Dataset(
source_iter=iter(
GaussianMixture(
name_source, batch_size=train_batch_size, init_key=k1
name_source, batch_size=train_batch_size, init_rng=k1
)
),
target_iter=iter(
GaussianMixture(
name_target, batch_size=train_batch_size, init_key=k2
name_target, batch_size=train_batch_size, init_rng=k2
)
)
)
valid_dataset = Dataset(
source_iter=iter(
GaussianMixture(
name_source, batch_size=valid_batch_size, init_key=k3
name_source, batch_size=valid_batch_size, init_rng=k3
)
),
target_iter=iter(
GaussianMixture(
name_target, batch_size=valid_batch_size, init_key=k4
name_target, batch_size=valid_batch_size, init_rng=k4
)
)
)
Expand Down
12 changes: 6 additions & 6 deletions src/ott/problems/quadratic/quadratic_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,11 +360,11 @@ def convertible(geom: geometry.Geometry) -> bool:
(geom_xy is None or convertible(geom_xy))
)

def to_low_rank(self, seed: int = 0) -> "QuadraticProblem":
def to_low_rank(self, rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0)) -> "QuadraticProblem":
"""Convert geometries to low-rank.

Args:
seed: Random seed.
rng: Random key for seeding.

Returns:
Quadratic problem with low-rank geometries.
Expand All @@ -383,19 +383,19 @@ def convert(
return self

(geom_xx, geom_yy, geom_xy, *children), aux_data = self.tree_flatten()
(s1, s2, s3) = jax.random.split(jax.random.PRNGKey(seed), 3)[:, 0]
(k1, k2, k3) = jax.random.split(rng, 3)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
(k1, k2, k3) = jax.random.split(rng, 3)
rng1, rng2, rng3 = jax.random.split(rng, 3)

(No parenthesis to be consistent, and k -> rng

(r1, r2, r3), (t1, t2, t3) = convert(self.ranks), convert(self.tolerances)

geom_xx = geom_xx.to_LRCGeometry(rank=r1, tol=t1, seed=s1)
geom_yy = geom_yy.to_LRCGeometry(rank=r2, tol=t2, seed=s2)
geom_xx = geom_xx.to_LRCGeometry(rank=r1, tol=t1, rng=k1)
geom_yy = geom_yy.to_LRCGeometry(rank=r2, tol=t2, rng=k2)
if self.is_fused:
if isinstance(
geom_xy, pointcloud.PointCloud
) and geom_xy.is_squared_euclidean:
geom_xy = geom_xy.to_LRCGeometry(scale=self.fused_penalty)
else:
geom_xy = geom_xy.to_LRCGeometry(
rank=r3, tol=t3, seed=s3, scale=self.fused_penalty
rank=r3, tol=t3, rng=k3, scale=self.fused_penalty
)

return type(self).tree_unflatten(
Expand Down
Loading