Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] Add support for numpy RandomState handling in sample #6128

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
- PR #6100 Fix issue in `Series.factorize` to correctly pick `na_sentinel` value
- PR #6110 Handle `format` for other input types in `to_datetime`
- PR #6118 Fix Java build for ORC read args change and update package version
- PR #6128 Add support for numpy RandomState handling in `sample`


# cuDF 0.15.0 (26 Aug 2020)
Expand Down
17 changes: 11 additions & 6 deletions python/cudf/cudf/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1684,9 +1684,10 @@ def sample(
replace == True is not yet supported for axis = 1/"columns"
weights : str or ndarray-like, optional
Only supported for axis=1/"columns"
random_state : int or None, default None
random_state : int, numpy RandomState or None, default None
Seed for the random number generator (if int), or None.
If None, a random seed will be chosen.
If RandomState, seed will be extracted from current state.
axis : {0 or ‘index’, 1 or ‘columns’, None}, default None
Axis to sample. Accepts axis number or name.
Default is stat axis for given data type
Expand Down Expand Up @@ -1761,11 +1762,15 @@ def sample(
"weights is not yet supported for axis=0/index"
)

seed = (
np.random.randint(np.iinfo(np.int64).max, dtype=np.int64)
if random_state is None
else np.int64(random_state)
)
if random_state is None:
seed = np.random.randint(
np.iinfo(np.int64).max, dtype=np.int64
)
elif isinstance(random_state, np.random.mtrand.RandomState):
_, keys, pos, _, _ = random_state.get_state()
seed = 0 if pos >= len(keys) else pos
else:
seed = np.int64(random_state)

result = self._from_table(
libcudf.copying.sample(
Expand Down
7 changes: 4 additions & 3 deletions python/cudf/cudf/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6604,11 +6604,12 @@ def test_dataframe_sample_basic(n, frac, replace, axis):


@pytest.mark.parametrize("replace", [True, False])
def test_dataframe_reproducibility(replace):
@pytest.mark.parametrize("random_state", [1, np.random.mtrand.RandomState(10)])
def test_dataframe_reproducibility(replace, random_state):
df = DataFrame({"a": cupy.arange(0, 1024)})

expected = df.sample(1024, replace=replace, random_state=1)
out = df.sample(1024, replace=replace, random_state=1)
expected = df.sample(1024, replace=replace, random_state=random_state)
out = df.sample(1024, replace=replace, random_state=random_state)

assert_eq(expected, out)

Expand Down