[REVIEW] Add support for numpy RandomState handling in sample #6128

Merged
Changes from 2 commits
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -17,6 +17,7 @@
 ## Bug Fixes

 - PR #6081 Fix issue where fsspec thinks it has a protocol string
+- PR #6128 Add support for numpy RandomState handling in `sample`


 # cuDF 0.15.0 (Date TBD)
19 changes: 13 additions & 6 deletions python/cudf/cudf/core/frame.py
@@ -1676,9 +1676,10 @@ def sample(
             replace == True is not yet supported for axis = 1/"columns"
         weights : str or ndarray-like, optional
             Only supported for axis=1/"columns"
-        random_state : int or None, default None
+        random_state : int, numpy RandomState or None, default None
             Seed for the random number generator (if int), or None.
             If None, a random seed will be chosen.
+            if RandomState, seed will be extracted from current state.
         axis : {0 or ‘index’, 1 or ‘columns’, None}, default None
             Axis to sample. Accepts axis number or name.
             Default is stat axis for given data type
@@ -1753,11 +1754,17 @@ def sample(
                     "weights is not yet supported for axis=0/index"
                 )

-            seed = (
-                np.random.randint(np.iinfo(np.int64).max, dtype=np.int64)
-                if random_state is None
-                else np.int64(random_state)
-            )
+            if random_state is None:
+                seed = np.random.randint(
+                    np.iinfo(np.int64).max, dtype=np.int64
+                )
+            elif isinstance(random_state, np.random.mtrand.RandomState):
+                pos = random_state._bit_generator.state["state"]["pos"]
+                seed = random_state._bit_generator.state["state"]["key"][
+                    0 if pos == 624 else pos
+                ]
+            else:
+                seed = np.int64(random_state)

             result = self._from_table(
                 libcudf.copying.sample(
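
The three-way seed selection in this hunk can be tried outside cuDF. The following is a minimal sketch, not the merged implementation: `resolve_seed` is a hypothetical helper name, and it reads the Mersenne Twister state through NumPy's public `get_state()` tuple rather than the private `_bit_generator` attribute used in the diff, on the assumption that both expose the same `key` and `pos`.

```python
import numpy as np


def resolve_seed(random_state=None):
    """Hypothetical helper mirroring the three branches of the hunk."""
    if random_state is None:
        # No seed supplied: draw a non-negative int64 seed from NumPy's
        # global generator.
        return int(np.random.randint(np.iinfo(np.int64).max, dtype=np.int64))
    if isinstance(random_state, np.random.RandomState):
        # Legacy state tuple: (name, key, pos, has_gauss, cached_gaussian).
        # `key` holds 624 uint32 words and `pos` indexes the next word.
        _, key, pos, _, _ = random_state.get_state()
        # pos == 624 is one past the last valid index (the block is
        # exhausted or freshly seeded), so fall back to the first word.
        return int(key[0 if pos == 624 else pos])
    # Anything else is treated as an integer seed.
    return int(np.int64(random_state))


seed_a = resolve_seed(np.random.RandomState(10))
seed_b = resolve_seed(np.random.RandomState(10))
```

Two generators constructed with the same seed resolve to the same value here, and an integer passes through unchanged. Reading the state does not draw from the generator, so calling the helper twice on one `RandomState` object returns the same seed both times.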
7 changes: 4 additions & 3 deletions python/cudf/cudf/tests/test_dataframe.py
@@ -6604,11 +6604,12 @@ def test_dataframe_sample_basic(n, frac, replace, axis):


 @pytest.mark.parametrize("replace", [True, False])
-def test_dataframe_reproducibility(replace):
+@pytest.mark.parametrize("random_state", [1, np.random.mtrand.RandomState(10)])
+def test_dataframe_reproducibility(replace, random_state):
     df = DataFrame({"a": cupy.arange(0, 1024)})

-    expected = df.sample(1024, replace=replace, random_state=1)
-    out = df.sample(1024, replace=replace, random_state=1)
+    expected = df.sample(1024, replace=replace, random_state=random_state)
+    out = df.sample(1024, replace=replace, random_state=random_state)

     assert_eq(expected, out)
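
The test passes one `RandomState` instance to both `sample` calls, so it relies on seed extraction leaving the generator untouched. A small NumPy-only sketch of that property follows. `mt_state` is a hypothetical helper, not part of cuDF or this PR, and it assumes the public `get_state()` tuple carries the same `key` and `pos` the diff reads.

```python
import numpy as np


def mt_state(random_state):
    """Hypothetical helper: return (key, pos) from the legacy state tuple."""
    _, key, pos, _, _ = random_state.get_state()
    return key, int(pos)


rs = np.random.RandomState(10)
key_before, pos_before = mt_state(rs)
key_after, pos_after = mt_state(rs)

# Reading the state twice leaves both the key words and the position as-is.
unchanged = pos_before == pos_after and np.array_equal(key_before, key_after)

# Drawing a value, by contrast, moves the position forward.
rs.random_sample()
_, pos_drawn = mt_state(rs)
advanced = pos_drawn != pos_before
```

Because only a draw advances the state, a seed derived purely from reading it stays the same across repeated calls on the same object, which is what lets `expected` and `out` match.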
