Skip to content

Commit

Permalink
Make Frame.astype return Self instead of a ColumnAccessor (#15861)
Browse files Browse the repository at this point in the history
Allows simplification for it's subclasses (`IndexFrame.astype`, `Index.astype`)

Also minor cleanups in the `equals` method

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)

Approvers:
  - Bradley Dice (https://github.com/bdice)

URL: #15861
  • Loading branch information
mroeschke authored Jun 4, 2024
1 parent 54d49fc commit faf3929
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 36 deletions.
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/_base_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def ndim(self) -> int: # noqa: D401
"""Number of dimensions of the underlying data, by definition 1."""
return 1

def equals(self, other):
def equals(self, other) -> bool:
"""
Determine if two Index objects contain the same elements.
Expand Down
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2590,7 +2590,7 @@ def items(self):
yield (k, self[k])

@_cudf_nvtx_annotate
def equals(self, other):
def equals(self, other) -> bool:
ret = super().equals(other)
# If all other checks matched, validate names.
if ret:
Expand Down
23 changes: 6 additions & 17 deletions python/cudf/cudf/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,20 +273,13 @@ def __len__(self) -> int:
return self._num_rows

@_cudf_nvtx_annotate
def astype(self, dtype, copy: bool = False):
result_data = {
col_name: col.astype(dtype.get(col_name, col.dtype), copy=copy)
def astype(self, dtype: dict[Any, Dtype], copy: bool = False) -> Self:
casted = (
col.astype(dtype.get(col_name, col.dtype), copy=copy)
for col_name, col in self._data.items()
}

return ColumnAccessor(
data=result_data,
multiindex=self._data.multiindex,
level_names=self._data.level_names,
rangeindex=self._data.rangeindex,
label_dtype=self._data.label_dtype,
verify=False,
)
ca = self._data._from_columns_like_self(casted, verify=False)
return self._from_data_like_self(ca)

@_cudf_nvtx_annotate
def equals(self, other) -> bool:
Expand Down Expand Up @@ -349,11 +342,7 @@ def equals(self, other) -> bool:
"""
if self is other:
return True
if (
other is None
or not isinstance(other, type(self))
or len(self) != len(other)
):
if not isinstance(other, type(self)) or len(self) != len(other):
return False

return all(
Expand Down
22 changes: 14 additions & 8 deletions python/cudf/cudf/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def __getitem__(self, index):
return self._as_int_index()[index]

@_cudf_nvtx_annotate
def equals(self, other):
def equals(self, other) -> bool:
if isinstance(other, RangeIndex):
return self._range == other._range
return self._as_int_index().equals(other)
Expand Down Expand Up @@ -1058,6 +1058,16 @@ def _from_data(cls, data: MutableMapping, name: Any = no_default) -> Self:
out.name = name
return out

@classmethod
@_cudf_nvtx_annotate
def _from_data_like_self(
cls, data: MutableMapping, name: Any = no_default
) -> Self:
out = _index_from_data(data, name)
if name is not no_default:
out.name = name
return out

@classmethod
@_cudf_nvtx_annotate
def from_arrow(cls, obj):
Expand Down Expand Up @@ -1180,12 +1190,8 @@ def is_unique(self):
return self._column.is_unique

@_cudf_nvtx_annotate
def equals(self, other):
if (
other is None
or not isinstance(other, BaseIndex)
or len(self) != len(other)
):
def equals(self, other) -> bool:
if not isinstance(other, BaseIndex) or len(self) != len(other):
return False

check_dtypes = False
Expand Down Expand Up @@ -1231,7 +1237,7 @@ def copy(self, name=None, deep=False):

@_cudf_nvtx_annotate
def astype(self, dtype, copy: bool = True):
return _index_from_data(super().astype({self.name: dtype}, copy))
return super().astype({self.name: dtype}, copy)

@_cudf_nvtx_annotate
def get_indexer(self, target, method=None, limit=None, tolerance=None):
Expand Down
14 changes: 5 additions & 9 deletions python/cudf/cudf/core/indexed_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,10 +625,8 @@ def copy(self, deep: bool = True) -> Self:
)

@_cudf_nvtx_annotate
def equals(self, other): # noqa: D102
if not super().equals(other):
return False
return self.index.equals(other.index)
def equals(self, other) -> bool: # noqa: D102
return super().equals(other) and self.index.equals(other.index)

@property
def index(self):
Expand Down Expand Up @@ -4896,10 +4894,10 @@ def repeat(self, repeats, axis=None):

def astype(
self,
dtype,
dtype: dict[Any, Dtype],
copy: bool = False,
errors: Literal["raise", "ignore"] = "raise",
):
) -> Self:
"""Cast the object to the given dtype.
Parameters
Expand Down Expand Up @@ -5010,14 +5008,12 @@ def astype(
raise ValueError("invalid error value specified")

try:
data = super().astype(dtype, copy)
return super().astype(dtype, copy)
except Exception as e:
if errors == "raise":
raise e
return self

return self._from_data(data, index=self.index)

@_cudf_nvtx_annotate
def drop(
self,
Expand Down

0 comments on commit faf3929

Please sign in to comment.