Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

REF: remove axis keyword from Manager/Block.shift #53845

Merged
merged 6 commits into from
Jun 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pandas/core/arrays/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,9 @@ def searchsorted(
return self._ndarray.searchsorted(npvalue, side=side, sorter=sorter)

@doc(ExtensionArray.shift)
def shift(self, periods: int = 1, fill_value=None, axis: AxisInt = 0):
def shift(self, periods: int = 1, fill_value=None):
# NB: shift is always along axis=0
axis = 0
fill_value = self._validate_scalar(fill_value)
new_values = shift(self._ndarray, periods, axis, fill_value)

Expand Down
2 changes: 2 additions & 0 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,8 @@ def shift(self, periods: int = 1, fill_value: object = None) -> ExtensionArray:
If ``periods > len(self)``, then an array of size
len(self) is returned, with all values filled with
``self.dtype.na_value``.

For 2-dimensional ExtensionArrays, we are always shifting along axis=0.
"""
# Note: this implementation assumes that `self.dtype.na_value` can be
# stored in an instance of your ExtensionArray with `self.dtype`.
Expand Down
5 changes: 2 additions & 3 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10590,9 +10590,8 @@ def shift(
if freq is None:
# when freq is None, data is shifted, index is not
axis = self._get_axis_number(axis)
new_data = self._mgr.shift(
periods=periods, axis=axis, fill_value=fill_value
)
assert axis == 0 # axis == 1 cases handled in DataFrame.shift
new_data = self._mgr.shift(periods=periods, fill_value=fill_value)
return self._constructor_from_mgr(
new_data, axes=new_data.axes
).__finalize__(self, method="shift")
Expand Down
10 changes: 2 additions & 8 deletions pandas/core/internals/array_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,17 +324,11 @@ def diff(self, n: int) -> Self:
assert self.ndim == 2 # caller ensures
return self.apply(algos.diff, n=n)

def shift(self, periods: int, axis: AxisInt, fill_value) -> Self:
def shift(self, periods: int, fill_value) -> Self:
if fill_value is lib.no_default:
fill_value = None

if axis == 1 and self.ndim == 2:
# TODO column-wise shift
raise NotImplementedError

return self.apply_with_block(
"shift", periods=periods, axis=axis, fill_value=fill_value
)
return self.apply_with_block("shift", periods=periods, fill_value=fill_value)

def astype(self, dtype, copy: bool | None = False, errors: str = "raise") -> Self:
if copy is None:
Expand Down
40 changes: 16 additions & 24 deletions pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1478,12 +1478,11 @@ def diff(self, n: int) -> list[Block]:
new_values = algos.diff(self.values.T, n, axis=0).T
return [self.make_block(values=new_values)]

def shift(
self, periods: int, axis: AxisInt = 0, fill_value: Any = None
) -> list[Block]:
def shift(self, periods: int, fill_value: Any = None) -> list[Block]:
"""shift the block by periods, possibly upcast"""
# convert integer to float if necessary. need to do a lot more than
# that, handle boolean etc also
axis = self.ndim - 1

# Note: periods is never 0 here, as that is handled at the top of
# NDFrame.shift. If that ever changes, we can do a check for periods=0
Expand All @@ -1505,12 +1504,12 @@ def shift(
)
except LossySetitemError:
nb = self.coerce_to_target_dtype(fill_value)
return nb.shift(periods, axis=axis, fill_value=fill_value)
return nb.shift(periods, fill_value=fill_value)

else:
values = cast(np.ndarray, self.values)
new_values = shift(values, periods, axis, casted)
return [self.make_block(new_values)]
return [self.make_block_same_class(new_values)]

@final
def quantile(
Expand Down Expand Up @@ -1661,6 +1660,18 @@ class EABackedBlock(Block):

values: ExtensionArray

def shift(self, periods: int, fill_value: Any = None) -> list[Block]:
"""
Shift the block by `periods`.

Dispatches to underlying ExtensionArray and re-boxes in an
ExtensionBlock.
"""
# Transpose since EA.shift is always along axis=0, while we want to shift
# along rows.
new_values = self.values.T.shift(periods=periods, fill_value=fill_value).T
return [self.make_block_same_class(new_values)]

def setitem(self, indexer, value, using_cow: bool = False):
"""
Attempt self.values[indexer] = value, possibly creating a new array.
Expand Down Expand Up @@ -2113,18 +2124,6 @@ def slice_block_rows(self, slicer: slice) -> Self:
new_values = self.values[slicer]
return type(self)(new_values, self._mgr_locs, ndim=self.ndim, refs=self.refs)

def shift(
self, periods: int, axis: AxisInt = 0, fill_value: Any = None
) -> list[Block]:
"""
Shift the block by `periods`.

Dispatches to underlying ExtensionArray and re-boxes in an
ExtensionBlock.
"""
new_values = self.values.shift(periods=periods, fill_value=fill_value)
return [self.make_block_same_class(new_values)]

def _unstack(
self,
unstacker,
Expand Down Expand Up @@ -2231,13 +2230,6 @@ def is_view(self) -> bool:
# check the ndarray values of the DatetimeIndex values
return self.values._ndarray.base is not None

def shift(
self, periods: int, axis: AxisInt = 0, fill_value: Any = None
) -> list[Block]:
values = self.values
new_values = values.shift(periods, fill_value=fill_value, axis=axis)
return [self.make_block_same_class(new_values)]


def _catch_deprecated_value_error(err: Exception) -> None:
"""
Expand Down
5 changes: 2 additions & 3 deletions pandas/core/internals/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,12 +378,11 @@ def diff(self, n: int) -> Self:
# only reached with self.ndim == 2
return self.apply("diff", n=n)

def shift(self, periods: int, axis: AxisInt, fill_value) -> Self:
axis = self._normalize_axis(axis)
def shift(self, periods: int, fill_value) -> Self:
if fill_value is lib.no_default:
fill_value = None

return self.apply("shift", periods=periods, axis=axis, fill_value=fill_value)
return self.apply("shift", periods=periods, fill_value=fill_value)

def astype(self, dtype, copy: bool | None = False, errors: str = "raise") -> Self:
if copy is None:
Expand Down