Skip to content

Commit

Permalink
CLN: Clean test moments for expanding (#30566)
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesdong1991 authored and jreback committed Jan 1, 2020
1 parent ac3715b commit 7ecd9af
Showing 1 changed file with 45 additions and 38 deletions.
83 changes: 45 additions & 38 deletions pandas/tests/window/moments/test_moments_expanding.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,19 +173,24 @@ def test_expanding_corr_pairwise_diff_length(self):
tm.assert_frame_equal(result3, expected)
tm.assert_frame_equal(result4, expected)

@pytest.mark.parametrize("has_min_periods", [True, False])
@pytest.mark.parametrize(
"func,static_comp",
[("sum", np.sum), ("mean", np.mean), ("max", np.max), ("min", np.min)],
ids=["sum", "mean", "max", "min"],
)
def test_expanding_func(self, func, static_comp):
def test_expanding_func(self, func, static_comp, has_min_periods):
def expanding_func(x, min_periods=1, center=False, axis=0):
exp = x.expanding(min_periods=min_periods, center=center, axis=axis)
return getattr(exp, func)()

self._check_expanding(expanding_func, static_comp, preserve_nan=False)
self._check_expanding_has_min_periods(
expanding_func, static_comp, has_min_periods
)

def test_expanding_apply(self, raw):
@pytest.mark.parametrize("has_min_periods", [True, False])
def test_expanding_apply(self, raw, has_min_periods):
def expanding_mean(x, min_periods=1):

exp = x.expanding(min_periods=min_periods)
Expand All @@ -195,19 +200,20 @@ def expanding_mean(x, min_periods=1):
# TODO(jreback), needed to add preserve_nan=False
# here to make this pass
self._check_expanding(expanding_mean, np.mean, preserve_nan=False)
self._check_expanding_has_min_periods(expanding_mean, np.mean, has_min_periods)

def test_expanding_apply_empty_series(self, raw):
ser = Series([], dtype=np.float64)
tm.assert_series_equal(ser, ser.expanding().apply(lambda x: x.mean(), raw=raw))

def test_expanding_apply_min_periods_0(self, raw):
# GH 8080
s = Series([None, None, None])
result = s.expanding(min_periods=0).apply(lambda x: len(x), raw=raw)
expected = Series([1.0, 2.0, 3.0])
tm.assert_series_equal(result, expected)

def _check_expanding(
self, func, static_comp, has_min_periods=True, preserve_nan=True
):
def _check_expanding(self, func, static_comp, preserve_nan=True):

series_result = func(self.series)
assert isinstance(series_result, Series)
Expand All @@ -220,6 +226,7 @@ def _check_expanding(
if preserve_nan:
assert result.iloc[self._nan_locs].isna().all()

def _check_expanding_has_min_periods(self, func, static_comp, has_min_periods):
ser = Series(randn(50))

if has_min_periods:
Expand All @@ -245,17 +252,9 @@ def _check_expanding(
result = func(ser)
tm.assert_almost_equal(result.iloc[-1], static_comp(ser[:50]))

def test_moment_functions_zero_length(self):
# GH 8056
s = Series(dtype=np.float64)
s_expected = s
df1 = DataFrame()
df1_expected = df1
df2 = DataFrame(columns=["a"])
df2["a"] = df2["a"].astype("float64")
df2_expected = df2

functions = [
@pytest.mark.parametrize(
"f",
[
lambda x: x.expanding().count(),
lambda x: x.expanding(min_periods=5).cov(x, pairwise=False),
lambda x: x.expanding(min_periods=5).corr(x, pairwise=False),
Expand All @@ -271,23 +270,35 @@ def test_moment_functions_zero_length(self):
lambda x: x.expanding(min_periods=5).median(),
lambda x: x.expanding(min_periods=5).apply(sum, raw=False),
lambda x: x.expanding(min_periods=5).apply(sum, raw=True),
]
for f in functions:
try:
s_result = f(s)
tm.assert_series_equal(s_result, s_expected)
],
)
def test_moment_functions_zero_length(self, f):
# GH 8056
s = Series(dtype=np.float64)
s_expected = s
df1 = DataFrame()
df1_expected = df1
df2 = DataFrame(columns=["a"])
df2["a"] = df2["a"].astype("float64")
df2_expected = df2

df1_result = f(df1)
tm.assert_frame_equal(df1_result, df1_expected)
s_result = f(s)
tm.assert_series_equal(s_result, s_expected)

df2_result = f(df2)
tm.assert_frame_equal(df2_result, df2_expected)
except (ImportError):
df1_result = f(df1)
tm.assert_frame_equal(df1_result, df1_expected)

# scipy needed for rolling_window
continue
df2_result = f(df2)
tm.assert_frame_equal(df2_result, df2_expected)

def test_moment_functions_zero_length_pairwise(self):
@pytest.mark.parametrize(
"f",
[
lambda x: (x.expanding(min_periods=5).cov(x, pairwise=True)),
lambda x: (x.expanding(min_periods=5).corr(x, pairwise=True)),
],
)
def test_moment_functions_zero_length_pairwise(self, f):

df1 = DataFrame()
df2 = DataFrame(columns=Index(["a"], name="foo"), index=Index([], name="bar"))
Expand All @@ -303,16 +314,12 @@ def test_moment_functions_zero_length_pairwise(self):
columns=Index(["a"], name="foo"),
dtype="float64",
)
functions = [
lambda x: (x.expanding(min_periods=5).cov(x, pairwise=True)),
lambda x: (x.expanding(min_periods=5).corr(x, pairwise=True)),
]
for f in functions:
df1_result = f(df1)
tm.assert_frame_equal(df1_result, df1_expected)

df2_result = f(df2)
tm.assert_frame_equal(df2_result, df2_expected)
df1_result = f(df1)
tm.assert_frame_equal(df1_result, df1_expected)

df2_result = f(df2)
tm.assert_frame_equal(df2_result, df2_expected)

@pytest.mark.slow
@pytest.mark.parametrize("min_periods", [0, 1, 2, 3, 4])
Expand Down

0 comments on commit 7ecd9af

Please sign in to comment.