Skip to content

Commit

Permalink
split the random generalized normal test and skip its K-S half
Browse files Browse the repository at this point in the history
It is key-sensitive and sometimes slow.

PiperOrigin-RevId: 590756597
  • Loading branch information
froystig authored and jax authors committed Dec 14, 2023
1 parent 9198174 commit 3380b9f
Showing 1 changed file with 18 additions and 7 deletions.
25 changes: 18 additions & 7 deletions tests/random_lax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,21 +651,32 @@ def testOrthogonal(self, n, shape, dtype):
)
def testGeneralizedNormal(self, p, shape, dtype):
key = self.make_key(2)
rand = lambda key, p, shape: random.generalized_normal(key, p, shape, dtype)
crand = jax.jit(rand, static_argnums=2)
rand = lambda key, p: random.generalized_normal(key, p, shape, dtype)
crand = jax.jit(rand)

uncompiled_samples = rand(key, p, shape)
compiled_samples = crand(key, p, shape)
uncompiled_samples = rand(key, p)
compiled_samples = crand(key, p)
for samples in [uncompiled_samples, compiled_samples]:
self.assertEqual(samples.shape, shape)
self.assertEqual(samples.dtype, dtype)

uncompiled_samples = rand(key, p, (300, *shape))
compiled_samples = crand(key, p, (300, *shape))
@jtu.sample_product(
p=[.5, 1., 1.5, 2., 2.5],
shape=[(), (5,), (10, 5)],
dtype=jtu.dtypes.floating,
)
def testGeneralizedNormalKS(self, p, shape, dtype):
self.skipTest( # test is also sometimes slow, with (300, ...)-shape draws
"sensitive to random key - https://github.com/google/jax/issues/18941")
key = self.make_key(2)
rand = lambda key, p: random.generalized_normal(key, p, (300, *shape), dtype)
crand = jax.jit(rand)

uncompiled_samples = rand(key, p)
compiled_samples = crand(key, p)
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples.ravel(), scipy.stats.gennorm(p).cdf)


@jtu.sample_product(
d=range(1, 5),
p=[.5, 1., 1.5, 2., 2.5],
Expand Down

0 comments on commit 3380b9f

Please sign in to comment.