Skip to content

Commit

Permalink
Fix select_dtypes to work when non-class dtypes present in dataframe (
Browse files Browse the repository at this point in the history
#8849)

Addresses bug #6919. `issubclass` expects a class, but a list is not a class, so we must add a check  before calling `issubclass` on it.

Authors:
  - Sarah Yurick (https://github.com/sarahyurick)

Approvers:
  - Ashwin Srinath (https://github.com/shwina)
  - GALI PREM SAGAR (https://github.com/galipremsagar)

URL: #8849
  • Loading branch information
sarahyurick authored Jul 27, 2021
1 parent d942100 commit 5d5bb2c
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
10 changes: 6 additions & 4 deletions python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -7359,8 +7359,9 @@ def select_dtypes(self, include=None, exclude=None):
# category handling
if is_categorical_dtype(i_dtype):
include_subtypes.add(i_dtype)
elif issubclass(dtype.type, i_dtype):
include_subtypes.add(dtype.type)
elif inspect.isclass(dtype.type):
if issubclass(dtype.type, i_dtype):
include_subtypes.add(dtype.type)

# exclude all subtypes
exclude_subtypes = set()
Expand All @@ -7369,8 +7370,9 @@ def select_dtypes(self, include=None, exclude=None):
# category handling
if is_categorical_dtype(e_dtype):
exclude_subtypes.add(e_dtype)
elif issubclass(dtype.type, e_dtype):
exclude_subtypes.add(dtype.type)
elif inspect.isclass(dtype.type):
if issubclass(dtype.type, e_dtype):
exclude_subtypes.add(dtype.type)

include_all = set(
[cudf_dtype_from_pydata_dtype(d) for d in self.dtypes]
Expand Down
8 changes: 8 additions & 0 deletions python/cudf/cudf/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3079,6 +3079,14 @@ def test_select_dtype():
gdf.select_dtypes(include=["int"], exclude=["object"]),
)

gdf = cudf.DataFrame(
{"int_col": [0, 1, 2], "list_col": [[1, 2], [3, 4], [5, 6]]}
)
pdf = gdf.to_pandas()
assert_eq(
pdf.select_dtypes("int64"), gdf.select_dtypes("int64"),
)


def test_select_dtype_datetime():
gdf = cudf.datasets.timeseries(
Expand Down

0 comments on commit 5d5bb2c

Please sign in to comment.