diff --git a/CHANGELOG.md b/CHANGELOG.md index 9dc82ee832d..df0b63895a5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,7 @@ - PR #6100 Fix issue in `Series.factorize` to correctly pick `na_sentinel` value - PR #6110 Handle `format` for other input types in `to_datetime` - PR #6118 Fix Java build for ORC read args change and update package version +- PR #6128 Add support for numpy RandomState handling in `sample` # cuDF 0.15.0 (26 Aug 2020) diff --git a/python/cudf/cudf/core/frame.py b/python/cudf/cudf/core/frame.py index 93ae900f853..6911cdc761f 100644 --- a/python/cudf/cudf/core/frame.py +++ b/python/cudf/cudf/core/frame.py @@ -1684,9 +1684,10 @@ def sample( replace == True is not yet supported for axis = 1/"columns" weights : str or ndarray-like, optional Only supported for axis=1/"columns" - random_state : int or None, default None + random_state : int, numpy RandomState or None, default None Seed for the random number generator (if int), or None. If None, a random seed will be chosen. + if RandomState, seed will be extracted from current state. axis : {0 or ‘index’, 1 or ‘columns’, None}, default None Axis to sample. Accepts axis number or name. Default is stat axis for given data type @@ -1761,11 +1762,15 @@ def sample( "weights is not yet supported for axis=0/index" ) - seed = ( - np.random.randint(np.iinfo(np.int64).max, dtype=np.int64) - if random_state is None - else np.int64(random_state) - ) + if random_state is None: + seed = np.random.randint( + np.iinfo(np.int64).max, dtype=np.int64 + ) + elif isinstance(random_state, np.random.mtrand.RandomState): + _, keys, pos, _, _ = random_state.get_state() + seed = 0 if pos >= len(keys) else pos + else: + seed = np.int64(random_state) result = self._from_table( libcudf.copying.sample( diff --git a/python/cudf/cudf/tests/test_dataframe.py b/python/cudf/cudf/tests/test_dataframe.py index 1b1d1cf10aa..df5782ad5e7 100644 --- a/python/cudf/cudf/tests/test_dataframe.py +++ b/python/cudf/cudf/tests/test_dataframe.py @@ -6604,11 +6604,12 @@ def test_dataframe_sample_basic(n, frac, replace, axis): @pytest.mark.parametrize("replace", [True, False]) -def test_dataframe_reproducibility(replace): +@pytest.mark.parametrize("random_state", [1, np.random.mtrand.RandomState(10)]) +def test_dataframe_reproducibility(replace, random_state): df = DataFrame({"a": cupy.arange(0, 1024)}) - expected = df.sample(1024, replace=replace, random_state=1) - out = df.sample(1024, replace=replace, random_state=1) + expected = df.sample(1024, replace=replace, random_state=random_state) + out = df.sample(1024, replace=replace, random_state=random_state) assert_eq(expected, out)