Skip to content

Commit

Permalink
Adds parameterization to test_shuffle and test_permutation (#3320)
Browse files Browse the repository at this point in the history
* Adds parameterization to test_shuffle and test_permutation

* Fixes several mistakes.

---------

Co-authored-by: drculhane <[email protected]>
  • Loading branch information
drculhane and drculhane authored Jul 11, 2024
1 parent 5be20f0 commit 13cdb62
Showing 1 changed file with 49 additions and 39 deletions.
88 changes: 49 additions & 39 deletions PROTO_tests/tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import arkouda as ak
from arkouda.scipy import chisquare as akchisquare

INT_FLOAT = [ak.int64, ak.float64]


class TestRandom:
def test_integers(self):
Expand Down Expand Up @@ -55,62 +57,70 @@ def test_integers(self):
assert all(bounded_arr.to_ndarray() >= -5)
assert all(bounded_arr.to_ndarray() < 5)

def test_shuffle(self):
# verify same seed gives reproducible arrays
rng = ak.random.default_rng(18)
@pytest.mark.parametrize("data_type", INT_FLOAT)
def test_shuffle(self, data_type):

int_pda = rng.integers(-(2**32), 2**32, 10)
pda_copy = int_pda[:]
# shuffle int_pda in place
rng.shuffle(int_pda)
# verify all the same elements are in permutation as the original
assert (ak.sort(int_pda) == ak.sort(pda_copy)).all()
# ints are checked for equality; floats are checked for closeness

float_pda = rng.uniform(-(2**32), 2**32, 10)
pda_copy = float_pda[:]
rng.shuffle(float_pda)
# verify all the same elements are in permutation as the original
assert (ak.sort(float_pda) == ak.sort(pda_copy)).all()
check = lambda a, b, t: (
(a == b).all() if t is ak.int64 else np.allclose(a.to_list(), b.to_list())
)

rng = ak.random.default_rng(18)
# verify all the same elements are in the shuffle as in the original

pda = rng.integers(-(2**32), 2**32, 10)
rng = ak.random.default_rng(18)
rnfunc = rng.integers if data_type is ak.int64 else rng.uniform
pda = rnfunc(-(2**32), 2**32, 10)
pda_copy = pda[:]
rng.shuffle(pda)
assert (pda == int_pda).all()

pda = rng.uniform(-(2**32), 2**32, 10)
rng.shuffle(pda)
assert np.allclose(pda.to_list(), float_pda.to_list())
assert check(ak.sort(pda), ak.sort(pda_copy), data_type)

def test_permutation(self):
# verify same seed gives reproducible arrays

rng = ak.random.default_rng(18)
rnfunc = rng.integers if data_type is ak.int64 else rng.uniform
pda_prime = rnfunc(-(2**32), 2**32, 10)
rng.shuffle(pda_prime)

assert check(pda, pda_prime, data_type)

@pytest.mark.parametrize("data_type", INT_FLOAT)
def test_permutation(self, data_type):

# ints are checked for equality; floats are checked for closeness

check = lambda a, b, t: (
(a == b).all() if t is ak.int64 else np.allclose(a.to_list(), b.to_list())
)

# verify all the same elements are in the permutation as in the original

rng = ak.random.default_rng(18)
# providing just a number permutes the range(num)
range_permute = rng.permutation(20)
assert (ak.arange(20) == ak.sort(range_permute)).all()
assert (ak.arange(20) == ak.sort(range_permute)).all() # range is always int

pda = rng.integers(-(2**32), 2**32, 10)
array_permute = rng.permutation(pda)
# verify all the same elements are in permutation as the original
assert (ak.sort(pda) == ak.sort(array_permute)).all()
# verify same seed gives reproducible arrays

pda = rng.uniform(-(2**32), 2**32, 10)
float_array_permute = rng.permutation(pda)
# verify all the same elements are in permutation as the original
assert np.allclose(ak.sort(pda).to_list(), ak.sort(float_array_permute).to_list())
rng = ak.random.default_rng(18)
rnfunc = rng.integers if data_type is ak.int64 else rng.uniform
pda = rnfunc(-(2**32), 2**32, 10)
permuted = rng.permutation(pda)
assert check(ak.sort(pda), ak.sort(permuted), data_type)

# verify same seed gives reproducible permutations

rng = ak.random.default_rng(18)
same_seed_range_permute = rng.permutation(20)
assert (range_permute == same_seed_range_permute).all()
assert check(range_permute, same_seed_range_permute, data_type)

pda = rng.integers(-(2**32), 2**32, 10)
same_seed_array_permute = rng.permutation(pda)
assert (array_permute == same_seed_array_permute).all()
# verify all the same elements are in permutation as in the original

pda = rng.uniform(-(2**32), 2**32, 10)
same_seed_float_array_permute = rng.permutation(pda)
# verify all the same elements are in permutation as the original
assert np.allclose(float_array_permute.to_list(), same_seed_float_array_permute.to_list())
rng = ak.random.default_rng(18)
rnfunc = rng.integers if data_type is ak.int64 else rng.uniform
pda_p = rnfunc(-(2**32), 2**32, 10)
permuted_p = rng.permutation(pda_p)
assert check(ak.sort(pda_p), ak.sort(permuted_p), data_type)

def test_uniform(self):
# verify same seed gives different but reproducible arrays
Expand Down

0 comments on commit 13cdb62

Please sign in to comment.