From 0f68434898881cb6a0abb157a9233d468c1032ef Mon Sep 17 00:00:00 2001 From: Sheilah Date: Wed, 16 Feb 2022 16:17:27 -0800 Subject: [PATCH] avoid check_dtype, reduce test cases for periods --- python/cudf/cudf/tests/test_dataframe.py | 28 +++++++++++++++--------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/python/cudf/cudf/tests/test_dataframe.py b/python/cudf/cudf/tests/test_dataframe.py index 35cc3ba74d0..5b9c73fd827 100644 --- a/python/cudf/cudf/tests/test_dataframe.py +++ b/python/cudf/cudf/tests/test_dataframe.py @@ -3442,29 +3442,37 @@ def test_get_numeric_data(): @pytest.mark.parametrize("dtype", NUMERIC_TYPES) -@pytest.mark.parametrize("period", [-1, -5, -10, -20, 0, 1, 5, 10, 20]) +@pytest.mark.parametrize("period", [-15, -1, 0, 1, 15]) @pytest.mark.parametrize("data_empty", [False, True]) def test_shift(dtype, period, data_empty): - + # TODO : this function currently tests for series.shift() + # but should instead test for dataframe.shift() if data_empty: data = None else: if dtype == np.int8: # to keep data in range - data = gen_rand(dtype, 100000, low=-2, high=2) + data = gen_rand(dtype, 10, low=-2, high=2) else: - data = gen_rand(dtype, 100000) + data = gen_rand(dtype, 10) - gdf = cudf.DataFrame({"a": cudf.Series(data, dtype=dtype)}) - pdf = pd.DataFrame({"a": pd.Series(data, dtype=dtype)}) + gs = cudf.DataFrame({"a": cudf.Series(data, dtype=dtype)}) + ps = pd.DataFrame({"a": pd.Series(data, dtype=dtype)}) - shifted_outcome = gdf.a.shift(period).fillna(0) - expected_outcome = pdf.a.shift(period).fillna(0).astype(dtype) + shifted_outcome = gs.a.shift(period) + expected_outcome = ps.a.shift(period) + # pandas uses NaNs to signal missing value and force converts the + # results columns to float types if data_empty: - assert_eq(shifted_outcome, expected_outcome, check_index_type=False) + assert_eq( + shifted_outcome, + expected_outcome, + check_index_type=False, + check_dtype=False, + ) else: - assert_eq(shifted_outcome, expected_outcome) + assert_eq(shifted_outcome, expected_outcome, check_dtype=False) @pytest.mark.parametrize("dtype", NUMERIC_TYPES)