Skip to content

Commit

Permalink
Merge pull request #6128 from rgsl888prabhu/6127_add_support_with_RandomState
Browse files Browse the repository at this point in the history

[REVIEW] Add support for numpy RandomState handling in `sample`
  • Loading branch information
rgsl888prabhu authored Sep 1, 2020
2 parents a2afaa1 + 2773167 commit 0d83d64
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 9 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
- PR #6113 Fix to_timestamp to initialize default year to 1970
- PR #6110 Handle `format` for other input types in `to_datetime`
- PR #6118 Fix Java build for ORC read args change and update package version
- PR #6128 Add support for numpy RandomState handling in `sample`


# cuDF 0.15.0 (26 Aug 2020)
Expand Down
17 changes: 11 additions & 6 deletions python/cudf/cudf/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1684,9 +1684,10 @@ def sample(
replace == True is not yet supported for axis = 1/"columns"
weights : str or ndarray-like, optional
Only supported for axis=1/"columns"
random_state : int or None, default None
random_state : int, numpy RandomState or None, default None
Seed for the random number generator (if int), or None.
If None, a random seed will be chosen.
If RandomState, the seed will be extracted from its current state.
axis : {0 or ‘index’, 1 or ‘columns’, None}, default None
Axis to sample. Accepts axis number or name.
Default is stat axis for given data type
Expand Down Expand Up @@ -1761,11 +1762,15 @@ def sample(
"weights is not yet supported for axis=0/index"
)

seed = (
np.random.randint(np.iinfo(np.int64).max, dtype=np.int64)
if random_state is None
else np.int64(random_state)
)
if random_state is None:
seed = np.random.randint(
np.iinfo(np.int64).max, dtype=np.int64
)
elif isinstance(random_state, np.random.mtrand.RandomState):
_, keys, pos, _, _ = random_state.get_state()
seed = 0 if pos >= len(keys) else pos
else:
seed = np.int64(random_state)

result = self._from_table(
libcudf.copying.sample(
Expand Down
7 changes: 4 additions & 3 deletions python/cudf/cudf/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6604,11 +6604,12 @@ def test_dataframe_sample_basic(n, frac, replace, axis):


@pytest.mark.parametrize("replace", [True, False])
def test_dataframe_reproducibility(replace):
@pytest.mark.parametrize("random_state", [1, np.random.mtrand.RandomState(10)])
def test_dataframe_reproducibility(replace, random_state):
df = DataFrame({"a": cupy.arange(0, 1024)})

expected = df.sample(1024, replace=replace, random_state=1)
out = df.sample(1024, replace=replace, random_state=1)
expected = df.sample(1024, replace=replace, random_state=random_state)
out = df.sample(1024, replace=replace, random_state=random_state)

assert_eq(expected, out)

Expand Down

0 comments on commit 0d83d64

Please sign in to comment.