diff --git a/PROTO_tests/tests/random_test.py b/PROTO_tests/tests/random_test.py index c8d85dc70c..d1e3dc4c47 100644 --- a/PROTO_tests/tests/random_test.py +++ b/PROTO_tests/tests/random_test.py @@ -50,6 +50,33 @@ def test_integers(self): assert all(bounded_arr.to_ndarray() >= -5) assert all(bounded_arr.to_ndarray() < 5) + def test_shuffle(self): + # verify same seed gives reproducible arrays + rng = ak.random.default_rng(18) + + int_pda = rng.integers(-(2**32), 2**32, 10) + pda_copy = int_pda[:] + # shuffle int_pda in place + rng.shuffle(int_pda) + # verify all the same elements are in permutation as the original + assert (ak.sort(int_pda) == ak.sort(pda_copy)).all() + + float_pda = rng.uniform(-(2**32), 2**32, 10) + pda_copy = float_pda[:] + rng.shuffle(float_pda) + # verify all the same elements are in permutation as the original + assert (ak.sort(float_pda) == ak.sort(pda_copy)).all() + + rng = ak.random.default_rng(18) + + pda = rng.integers(-(2**32), 2**32, 10) + rng.shuffle(pda) + assert (pda == int_pda).all() + + pda = rng.uniform(-(2**32), 2**32, 10) + rng.shuffle(pda) + assert np.allclose(pda.to_list(), float_pda.to_list()) + def test_permutation(self): # verify same seed gives reproducible arrays rng = ak.random.default_rng(18) @@ -204,10 +231,14 @@ def test_legacy_uniform(self): assert ak.float64 == testArray.dtype uArray = ak.random.uniform(size=3, low=0, high=5, seed=0) - assert np.allclose([0.30013431967121934, 0.47383036230759112, 1.0441791878997098], uArray.to_list()) + assert np.allclose( + [0.30013431967121934, 0.47383036230759112, 1.0441791878997098], uArray.to_list() + ) uArray = ak.random.uniform(size=np.int64(3), low=np.int64(0), high=np.int64(5), seed=np.int64(0)) - assert np.allclose([0.30013431967121934, 0.47383036230759112, 1.0441791878997098], uArray.to_list()) + assert np.allclose( + [0.30013431967121934, 0.47383036230759112, 1.0441791878997098], uArray.to_list() + ) with pytest.raises(TypeError): ak.random.uniform(low="0", high=5, size=100) diff --git a/arkouda/random/_generator.py b/arkouda/random/_generator.py index f6d0197a21..ed40c6e101 100644 --- a/arkouda/random/_generator.py +++ b/arkouda/random/_generator.py @@ -214,6 +214,35 @@ def standard_normal(self, size=None): return self._np_generator.standard_normal() return standard_normal(size=size, seed=self._seed) + def shuffle(self, x): + """ + Randomly shuffle a pdarray in place. + + Parameters + ---------- + x: pdarray + shuffle the elements of x randomly in place + + Returns + ------- + None + """ + if not isinstance(x, pdarray): + raise TypeError("shuffle only accepts a pdarray.") + dtype = to_numpy_dtype(x.dtype) + name = self._name_dict[to_numpy_dtype(akint64)] + generic_msg( + cmd="shuffle", + args={ + "name": name, + "x": x, + "size": x.size, + "dtype": dtype, + "state": self._state, + }, + ) + self._state += x.size + def permutation(self, x): """ Randomly permute a sequence, or return a permuted range. diff --git a/src/RandMsg.chpl b/src/RandMsg.chpl index c29e1cd798..60cd7ce3a8 100644 --- a/src/RandMsg.chpl +++ b/src/RandMsg.chpl @@ -234,7 +234,6 @@ module RandMsg return new MsgTuple(repMsg, MsgType.NORMAL); } - proc uniformGeneratorMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws { const pn = Reflection.getRoutineName(); var rname = st.nextName(); @@ -369,9 +368,61 @@ module RandMsg return new MsgTuple(repMsg, MsgType.NORMAL); } + proc shuffleMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws { + const pn = Reflection.getRoutineName(); + const name = msgArgs.getValueOf("name"); + const xName = msgArgs.getValueOf("x"); + const size = msgArgs.get("size").getIntValue(); + const dtypeStr = msgArgs.getValueOf("dtype"); + const dtype = str2dtype(dtypeStr); + const state = msgArgs.get("state").getIntValue(); + + randLogger.debug(getModuleName(),getRoutineName(),getLineNumber(), + "name: %? size %i dtype: %? state %i".doFormat(name, size, dtypeStr, state)); + + st.checkTable(name); + + proc shuffleHelper(type t) throws { + var generatorEntry: borrowed GeneratorSymEntry(int) = toGeneratorSymEntry(st.lookup(name), int); + ref rng = generatorEntry.generator; + + if state != 1 { + // you have to skip to one before where you want to be + rng.skipTo(state-1); + } + + ref myArr = toSymEntry(getGenericTypedArrayEntry(xName, st),t).a; + rng.shuffle(myArr); + } + + select dtype { + when DType.Int64 { + shuffleHelper(int); + } + when DType.UInt64 { + shuffleHelper(uint); + } + when DType.Float64 { + shuffleHelper(real); + } + when DType.Bool { + shuffleHelper(bool); + } + otherwise { + var errorMsg = "Unhandled data type %s".doFormat(dtypeStr); + randLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); + return new MsgTuple(notImplementedError(pn, errorMsg), MsgType.ERROR); + } + } + var repMsg = "created " + st.attrib(xName); + randLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg); + return new MsgTuple(repMsg, MsgType.NORMAL); + } + use CommandMap; registerFunction("randomNormal", randomNormalMsg, getModuleName()); registerFunction("createGenerator", createGeneratorMsg, getModuleName()); registerFunction("uniformGenerator", uniformGeneratorMsg, getModuleName()); registerFunction("permutation", permutationMsg, getModuleName()); + registerFunction("shuffle", shuffleMsg, getModuleName()); } diff --git a/tests/random_test.py b/tests/random_test.py index c992e4071b..aebb74d088 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -49,6 +49,33 @@ def test_integers(self): self.assertTrue(all(bounded_arr.to_ndarray() >= -5)) self.assertTrue(all(bounded_arr.to_ndarray() < 5)) + def test_shuffle(self): + # verify same seed gives reproducible arrays + rng = ak.random.default_rng(18) + + int_pda = rng.integers(-(2**32), 2**32, 10) + pda_copy = int_pda[:] + # shuffle int_pda in place + rng.shuffle(int_pda) + # verify all the same elements are in permutation as the original + self.assertEqual(ak.sort(int_pda).to_list(), ak.sort(pda_copy).to_list()) + + float_pda = rng.uniform(-(2**32), 2**32, 10) + pda_copy = float_pda[:] + rng.shuffle(float_pda) + # verify all the same elements are in permutation as the original + self.assertEqual(ak.sort(float_pda).to_list(), ak.sort(pda_copy).to_list()) + + rng = ak.random.default_rng(18) + + pda = rng.integers(-(2**32), 2**32, 10) + rng.shuffle(pda) + self.assertEqual(pda.to_list(), int_pda.to_list()) + + pda = rng.uniform(-(2**32), 2**32, 10) + rng.shuffle(pda) + self.assertTrue(np.allclose(pda.to_list(), float_pda.to_list())) + def test_permutation(self): # verify same seed gives reproducible arrays rng = ak.random.default_rng(18) @@ -77,7 +104,9 @@ def test_permutation(self): pda = rng.uniform(-(2**32), 2**32, 10) same_seed_float_array_permute = rng.permutation(pda) # verify all the same elements are in permutation as the original - self.assertTrue(np.allclose(float_array_permute.to_list(), same_seed_float_array_permute.to_list())) + self.assertTrue( + np.allclose(float_array_permute.to_list(), same_seed_float_array_permute.to_list()) + ) def test_uniform(self): # verify same seed gives different but reproducible arrays