Skip to content

Commit

Permalink
Cache apply function
Browse files Browse the repository at this point in the history
  • Loading branch information
Matt Roeschke committed Sep 19, 2019
1 parent 9b9ea7a commit aa9644c
Showing 1 changed file with 36 additions and 6 deletions.
42 changes: 36 additions & 6 deletions pandas/core/window/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def __init__(
self.win_freq = None
self.axis = obj._get_axis_number(axis) if axis is not None else None
self.validate()
self._apply_func_cache = dict()

@property
def _constructor(self):
Expand Down Expand Up @@ -493,7 +494,7 @@ def _apply(
minimum_periods = _check_min_periods(
self.min_periods or 1, self.min_periods, len(values) + offset
)
func = partial( # type: ignore
func_partial = partial( # type: ignore
func, begin=start, end=end, minimum_periods=minimum_periods
)

Expand All @@ -511,7 +512,7 @@ def _apply(
cfunc, check_minp, index_as_array, **kwargs
)

func = partial( # type: ignore
func_partial = partial( # type: ignore
func,
window=window,
min_periods=self.min_periods,
Expand All @@ -521,12 +522,12 @@ def _apply(
if additional_nans is not None:

def calc(x):
return func(np.concatenate((x, additional_nans)))
return func_partial(np.concatenate((x, additional_nans)))

else:

def calc(x):
return func(x)
return func_partial(x)

with np.errstate(all="ignore"):
if values.ndim > 1:
Expand All @@ -535,6 +536,9 @@ def calc(x):
result = calc(values)
result = np.asarray(result)

if use_numba:
self._apply_func_cache[name] = func

if center:
result = self._center_window(result, window)

Expand Down Expand Up @@ -1147,8 +1151,34 @@ def f(arg, window, min_periods, closed):

# Numba doesn't support kwargs in nopython mode
# https://github.com/numba/numba/issues/2916
numba_func = numba.njit(func)
rolling_apply = partial(methods.rolling_apply, numba_func=numba_func, args=args)
if func not in self._apply_func_cache:

def make_rolling_apply(func):
    """
    Build a numba-compiled rolling kernel that applies ``func`` per window.

    ``func`` is compiled with ``numba.njit`` and wrapped in a jitted loop
    that slices one window at a time out of the input values.

    Parameters
    ----------
    func : callable
        User function evaluated on each window slice. It is called as
        ``func(window, *args)``, where ``args`` is taken from the
        enclosing scope rather than passed to this factory.

    Returns
    -------
    callable
        Jitted function with signature
        ``roll_apply(values, begin, end, minimum_periods)``.
    """
    jitted_func = numba.njit(func)

    @numba.njit
    def roll_apply(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        minimum_periods: int,
    ):
        # One output slot per window start; every slot visited by the loop
        # below is written explicitly (either a computed value or NaN).
        out = np.empty(len(begin))
        for idx, (lo, hi) in enumerate(zip(begin, end)):
            chunk = values[lo:hi]
            n_missing = np.sum(np.isnan(chunk))
            n_observed = len(chunk) - n_missing
            if n_observed < minimum_periods:
                # Too few non-NaN observations in this window.
                out[idx] = np.nan
            else:
                # ``args`` is unpacked positionally: kwargs are not
                # available to the jitted call.
                out[idx] = jitted_func(chunk, *args)
        return out

    return roll_apply

rolling_apply = make_rolling_apply(func)
else:
rolling_apply = self._apply_func_cache[func]

return self._apply(
rolling_apply,
Expand Down

0 comments on commit aa9644c

Please sign in to comment.