Skip to content

Commit

Permalink
Parametrize series test into axis_0
Browse files Browse the repository at this point in the history
  • Loading branch information
isVoid committed Mar 4, 2022
1 parent a7588b3 commit f4e2686
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 47 deletions.
24 changes: 15 additions & 9 deletions python/cudf/cudf/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -7184,21 +7184,27 @@ def test_sample_axis_1(
checker(expected, got)


@pytest.mark.parametrize(
"pdf",
[
pd.DataFrame(
{
"a": [1, 2, 3, 4, 5],
"float": [0.05, 0.2, 0.3, 0.2, 0.25],
"int": [1, 3, 5, 4, 2],
},
),
pd.Series([1, 2, 3, 4, 5]),
],
)
@pytest.mark.parametrize("replace", [True, False])
def test_sample_axis_0(
sample_n_frac, replace, random_state_tuple_axis_0, make_weights_axis_0
pdf, sample_n_frac, replace, random_state_tuple_axis_0, make_weights_axis_0
):
n, frac = sample_n_frac
pd_random_state, gd_random_state, checker = random_state_tuple_axis_0

pdf = pd.DataFrame(
{
"a": [1, 2, 3, 4, 5],
"float": [0.05, 0.2, 0.3, 0.2, 0.25],
"int": [1, 3, 5, 4, 2],
},
)
df = cudf.DataFrame.from_pandas(pdf)
df = cudf.from_pandas(pdf)

pd_weights, gd_weights = make_weights_axis_0(
len(pdf), isinstance(gd_random_state, np.random.RandomState)
Expand Down
38 changes: 0 additions & 38 deletions python/cudf/cudf/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1590,44 +1590,6 @@ def test_fill_new_category():
gs[0:1] = "d"


@pytest.mark.parametrize("replace", [True, False])
def test_sample(
sample_n_frac, replace, random_state_tuple_axis_0, make_weights_axis_0
):
n, frac = sample_n_frac
pd_random_state, gd_random_state, checker = random_state_tuple_axis_0
psr = pd.Series([1, 2, 3, 4, 5])
sr = cudf.Series.from_pandas(psr)

pd_weights, gd_weights = make_weights_axis_0(len(psr))
if (
not replace
and not isinstance(gd_random_state, np.random.RandomState)
and gd_weights is not None
):
pytest.skip(
"`cupy.random.RandomState` doesn't support weighted sampling "
"without replacement."
)

expected = psr.sample(
n=n,
frac=frac,
replace=replace,
weights=pd_weights,
random_state=pd_random_state,
)

got = sr.sample(
n=n,
frac=frac,
replace=replace,
weights=gd_weights,
random_state=gd_random_state,
)
checker(expected, got)


@pytest.mark.parametrize(
"data",
[
Expand Down

0 comments on commit f4e2686

Please sign in to comment.