From 7ecd9af7cb300d83d55db85e6c559a65f6c940e1 Mon Sep 17 00:00:00 2001 From: Kaiqi Dong Date: Wed, 1 Jan 2020 03:36:19 +0100 Subject: [PATCH] CLN: Clean test moments for expanding (#30566) --- .../window/moments/test_moments_expanding.py | 83 ++++++++++--------- 1 file changed, 45 insertions(+), 38 deletions(-) diff --git a/pandas/tests/window/moments/test_moments_expanding.py b/pandas/tests/window/moments/test_moments_expanding.py index d311937e234d8..507fd2e2fb3ba 100644 --- a/pandas/tests/window/moments/test_moments_expanding.py +++ b/pandas/tests/window/moments/test_moments_expanding.py @@ -173,19 +173,24 @@ def test_expanding_corr_pairwise_diff_length(self): tm.assert_frame_equal(result3, expected) tm.assert_frame_equal(result4, expected) + @pytest.mark.parametrize("has_min_periods", [True, False]) @pytest.mark.parametrize( "func,static_comp", [("sum", np.sum), ("mean", np.mean), ("max", np.max), ("min", np.min)], ids=["sum", "mean", "max", "min"], ) - def test_expanding_func(self, func, static_comp): + def test_expanding_func(self, func, static_comp, has_min_periods): def expanding_func(x, min_periods=1, center=False, axis=0): exp = x.expanding(min_periods=min_periods, center=center, axis=axis) return getattr(exp, func)() self._check_expanding(expanding_func, static_comp, preserve_nan=False) + self._check_expanding_has_min_periods( + expanding_func, static_comp, has_min_periods + ) - def test_expanding_apply(self, raw): + @pytest.mark.parametrize("has_min_periods", [True, False]) + def test_expanding_apply(self, raw, has_min_periods): def expanding_mean(x, min_periods=1): exp = x.expanding(min_periods=min_periods) @@ -195,19 +200,20 @@ def expanding_mean(x, min_periods=1): # TODO(jreback), needed to add preserve_nan=False # here to make this pass self._check_expanding(expanding_mean, np.mean, preserve_nan=False) + self._check_expanding_has_min_periods(expanding_mean, np.mean, has_min_periods) + def test_expanding_apply_empty_series(self, raw): ser = Series([], dtype=np.float64) tm.assert_series_equal(ser, ser.expanding().apply(lambda x: x.mean(), raw=raw)) + def test_expanding_apply_min_periods_0(self, raw): # GH 8080 s = Series([None, None, None]) result = s.expanding(min_periods=0).apply(lambda x: len(x), raw=raw) expected = Series([1.0, 2.0, 3.0]) tm.assert_series_equal(result, expected) - def _check_expanding( - self, func, static_comp, has_min_periods=True, preserve_nan=True - ): + def _check_expanding(self, func, static_comp, preserve_nan=True): series_result = func(self.series) assert isinstance(series_result, Series) @@ -220,6 +226,7 @@ def _check_expanding( if preserve_nan: assert result.iloc[self._nan_locs].isna().all() + def _check_expanding_has_min_periods(self, func, static_comp, has_min_periods): ser = Series(randn(50)) if has_min_periods: @@ -245,17 +252,9 @@ def _check_expanding( result = func(ser) tm.assert_almost_equal(result.iloc[-1], static_comp(ser[:50])) - def test_moment_functions_zero_length(self): - # GH 8056 - s = Series(dtype=np.float64) - s_expected = s - df1 = DataFrame() - df1_expected = df1 - df2 = DataFrame(columns=["a"]) - df2["a"] = df2["a"].astype("float64") - df2_expected = df2 - - functions = [ + @pytest.mark.parametrize( + "f", + [ lambda x: x.expanding().count(), lambda x: x.expanding(min_periods=5).cov(x, pairwise=False), lambda x: x.expanding(min_periods=5).corr(x, pairwise=False), @@ -271,23 +270,35 @@ def test_moment_functions_zero_length(self): lambda x: x.expanding(min_periods=5).median(), lambda x: x.expanding(min_periods=5).apply(sum, raw=False), lambda x: x.expanding(min_periods=5).apply(sum, raw=True), - ] - for f in functions: - try: - s_result = f(s) - tm.assert_series_equal(s_result, s_expected) + ], + ) + def test_moment_functions_zero_length(self, f): + # GH 8056 + s = Series(dtype=np.float64) + s_expected = s + df1 = DataFrame() + df1_expected = df1 + df2 = DataFrame(columns=["a"]) + df2["a"] = df2["a"].astype("float64") + df2_expected = df2 - df1_result = f(df1) - tm.assert_frame_equal(df1_result, df1_expected) + s_result = f(s) + tm.assert_series_equal(s_result, s_expected) - df2_result = f(df2) - tm.assert_frame_equal(df2_result, df2_expected) - except (ImportError): + df1_result = f(df1) + tm.assert_frame_equal(df1_result, df1_expected) - # scipy needed for rolling_window - continue + df2_result = f(df2) + tm.assert_frame_equal(df2_result, df2_expected) - def test_moment_functions_zero_length_pairwise(self): + @pytest.mark.parametrize( + "f", + [ + lambda x: (x.expanding(min_periods=5).cov(x, pairwise=True)), + lambda x: (x.expanding(min_periods=5).corr(x, pairwise=True)), + ], + ) + def test_moment_functions_zero_length_pairwise(self, f): df1 = DataFrame() df2 = DataFrame(columns=Index(["a"], name="foo"), index=Index([], name="bar")) @@ -303,16 +314,12 @@ def test_moment_functions_zero_length_pairwise(self): columns=Index(["a"], name="foo"), dtype="float64", ) - functions = [ - lambda x: (x.expanding(min_periods=5).cov(x, pairwise=True)), - lambda x: (x.expanding(min_periods=5).corr(x, pairwise=True)), - ] - for f in functions: - df1_result = f(df1) - tm.assert_frame_equal(df1_result, df1_expected) - df2_result = f(df2) - tm.assert_frame_equal(df2_result, df2_expected) + df1_result = f(df1) + tm.assert_frame_equal(df1_result, df1_expected) + + df2_result = f(df2) + tm.assert_frame_equal(df2_result, df2_expected) @pytest.mark.slow @pytest.mark.parametrize("min_periods", [0, 1, 2, 3, 4])