Skip to content

Commit

Permalink
Closes #3084: Add shuffle to rng (#3085)
Browse files Browse the repository at this point in the history
This PR (closes #3084) adds a `shuffle` method to our random number generators.

Co-authored-by: Tess Hayes <[email protected]>
  • Loading branch information
stress-tess and stress-tess authored Apr 11, 2024
1 parent 2cd6162 commit a308ec0
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 4 deletions.
35 changes: 33 additions & 2 deletions PROTO_tests/tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,33 @@ def test_integers(self):
assert all(bounded_arr.to_ndarray() >= -5)
assert all(bounded_arr.to_ndarray() < 5)

def test_shuffle(self):
    # a shuffle must keep exactly the original elements, and re-running with
    # the same seed must reproduce the same shuffled arrays
    low, high, size = -(2**32), 2**32, 10

    rng = ak.random.default_rng(18)

    int_pda = rng.integers(low, high, size)
    # snapshot taken before the in-place shuffle
    int_before = int_pda[:]
    rng.shuffle(int_pda)
    # sorted contents must match: same elements, possibly in a new order
    int_elements_match = ak.sort(int_pda) == ak.sort(int_before)
    assert int_elements_match.all()

    float_pda = rng.uniform(low, high, size)
    float_before = float_pda[:]
    rng.shuffle(float_pda)
    float_elements_match = ak.sort(float_pda) == ak.sort(float_before)
    assert float_elements_match.all()

    # replay the exact same call sequence on a generator with the same seed
    rng = ak.random.default_rng(18)

    replayed_int = rng.integers(low, high, size)
    rng.shuffle(replayed_int)
    assert (replayed_int == int_pda).all()

    replayed_float = rng.uniform(low, high, size)
    rng.shuffle(replayed_float)
    assert np.allclose(replayed_float.to_list(), float_pda.to_list())

def test_permutation(self):
# verify same seed gives reproducible arrays
rng = ak.random.default_rng(18)
Expand Down Expand Up @@ -204,10 +231,14 @@ def test_legacy_uniform(self):
assert ak.float64 == testArray.dtype

uArray = ak.random.uniform(size=3, low=0, high=5, seed=0)
assert np.allclose([0.30013431967121934, 0.47383036230759112, 1.0441791878997098], uArray.to_list())
assert np.allclose(
[0.30013431967121934, 0.47383036230759112, 1.0441791878997098], uArray.to_list()
)

uArray = ak.random.uniform(size=np.int64(3), low=np.int64(0), high=np.int64(5), seed=np.int64(0))
assert np.allclose([0.30013431967121934, 0.47383036230759112, 1.0441791878997098], uArray.to_list())
assert np.allclose(
[0.30013431967121934, 0.47383036230759112, 1.0441791878997098], uArray.to_list()
)

with pytest.raises(TypeError):
ak.random.uniform(low="0", high=5, size=100)
Expand Down
29 changes: 29 additions & 0 deletions arkouda/random/_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,35 @@ def standard_normal(self, size=None):
return self._np_generator.standard_normal()
return standard_normal(size=size, seed=self._seed)

def shuffle(self, x):
    """
    Shuffle the elements of a pdarray randomly, modifying it in place.

    Parameters
    ----------
    x: pdarray
        The array whose elements are shuffled in place.

    Returns
    -------
    None

    Raises
    ------
    TypeError
        Raised if x is not a pdarray.
    """
    if not isinstance(x, pdarray):
        raise TypeError("shuffle only accepts a pdarray.")

    x_dtype = to_numpy_dtype(x.dtype)
    # the generator is always looked up under the int64 entry, regardless
    # of the dtype of the array being shuffled
    generator_name = self._name_dict[to_numpy_dtype(akint64)]
    shuffle_args = {
        "name": generator_name,
        "x": x,
        "size": x.size,
        "dtype": x_dtype,
        "state": self._state,
    }
    generic_msg(cmd="shuffle", args=shuffle_args)
    # advance the client-side state counter by the number of elements in x
    self._state += x.size

def permutation(self, x):
"""
Randomly permute a sequence, or return a permuted range.
Expand Down
53 changes: 52 additions & 1 deletion src/RandMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,6 @@ module RandMsg
return new MsgTuple(repMsg, MsgType.NORMAL);
}


proc uniformGeneratorMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws {
const pn = Reflection.getRoutineName();
var rname = st.nextName();
Expand Down Expand Up @@ -369,9 +368,61 @@ module RandMsg
return new MsgTuple(repMsg, MsgType.NORMAL);
}

// Handler for the "shuffle" command: shuffles the array stored under the
// "x" argument in place, using the generator stored under the "name" argument.
// The "state" argument is used to position the generator before shuffling.
// Replies with an ERROR MsgTuple when the dtype is not int64, uint64,
// float64 or bool.
proc shuffleMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws {
    const pn = Reflection.getRoutineName();

    // unpack the message arguments
    const name = msgArgs.getValueOf("name");
    const xName = msgArgs.getValueOf("x");
    const size = msgArgs.get("size").getIntValue();
    const dtypeStr = msgArgs.getValueOf("dtype");
    const dtype = str2dtype(dtypeStr);
    const state = msgArgs.get("state").getIntValue();

    randLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),
                     "name: %? size %i dtype: %? state %i".doFormat(name, size, dtypeStr, state));

    st.checkTable(name);

    // shuffle the entry named xName, treating it as an array of eltType
    proc shuffleHelper(type eltType) throws {
        var genEntry: borrowed GeneratorSymEntry(int) = toGeneratorSymEntry(st.lookup(name), int);
        ref gen = genEntry.generator;

        // you have to skip to one before where you want to be
        if state != 1 then
            gen.skipTo(state-1);

        ref arr = toSymEntry(getGenericTypedArrayEntry(xName, st), eltType).a;
        gen.shuffle(arr);
    }

    // dispatch on the element type of the array being shuffled
    select dtype {
        when DType.Int64 do shuffleHelper(int);
        when DType.UInt64 do shuffleHelper(uint);
        when DType.Float64 do shuffleHelper(real);
        when DType.Bool do shuffleHelper(bool);
        otherwise {
            const errorMsg = "Unhandled data type %s".doFormat(dtypeStr);
            randLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg);
            return new MsgTuple(notImplementedError(pn, errorMsg), MsgType.ERROR);
        }
    }

    const repMsg = "created " + st.attrib(xName);
    randLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg);
    return new MsgTuple(repMsg, MsgType.NORMAL);
}

// Register each message handler in this module under its command string.
// "shuffle" is the cmd value the Python client sends from Generator.shuffle.
use CommandMap;
registerFunction("randomNormal", randomNormalMsg, getModuleName());
registerFunction("createGenerator", createGeneratorMsg, getModuleName());
registerFunction("uniformGenerator", uniformGeneratorMsg, getModuleName());
registerFunction("permutation", permutationMsg, getModuleName());
registerFunction("shuffle", shuffleMsg, getModuleName());
}
31 changes: 30 additions & 1 deletion tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,33 @@ def test_integers(self):
self.assertTrue(all(bounded_arr.to_ndarray() >= -5))
self.assertTrue(all(bounded_arr.to_ndarray() < 5))

def test_shuffle(self):
    # a shuffle must keep exactly the original elements, and re-running with
    # the same seed must reproduce the same shuffled arrays
    low, high, size = -(2**32), 2**32, 10

    rng = ak.random.default_rng(18)

    int_pda = rng.integers(low, high, size)
    # snapshot taken before the in-place shuffle
    int_before = int_pda[:]
    rng.shuffle(int_pda)
    # sorted contents must match: same elements, possibly in a new order
    shuffled_sorted = ak.sort(int_pda).to_list()
    original_sorted = ak.sort(int_before).to_list()
    self.assertEqual(shuffled_sorted, original_sorted)

    float_pda = rng.uniform(low, high, size)
    float_before = float_pda[:]
    rng.shuffle(float_pda)
    shuffled_sorted = ak.sort(float_pda).to_list()
    original_sorted = ak.sort(float_before).to_list()
    self.assertEqual(shuffled_sorted, original_sorted)

    # replay the exact same call sequence on a generator with the same seed
    rng = ak.random.default_rng(18)

    replayed_int = rng.integers(low, high, size)
    rng.shuffle(replayed_int)
    self.assertEqual(replayed_int.to_list(), int_pda.to_list())

    replayed_float = rng.uniform(low, high, size)
    rng.shuffle(replayed_float)
    self.assertTrue(np.allclose(replayed_float.to_list(), float_pda.to_list()))

def test_permutation(self):
# verify same seed gives reproducible arrays
rng = ak.random.default_rng(18)
Expand Down Expand Up @@ -77,7 +104,9 @@ def test_permutation(self):
pda = rng.uniform(-(2**32), 2**32, 10)
same_seed_float_array_permute = rng.permutation(pda)
# verify all the same elements are in permutation as the original
self.assertTrue(np.allclose(float_array_permute.to_list(), same_seed_float_array_permute.to_list()))
self.assertTrue(
np.allclose(float_array_permute.to_list(), same_seed_float_array_permute.to_list())
)

def test_uniform(self):
# verify same seed gives different but reproducible arrays
Expand Down

0 comments on commit a308ec0

Please sign in to comment.