From 5e84bf065b211efc4c73cfe85ea33a3ab93ac297 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Tue, 19 Dec 2023 16:21:48 -0800 Subject: [PATCH] Explicitly pass .dtype into is_foo_dtype functions --- python/cudf/cudf/core/_internals/where.py | 2 +- python/cudf/cudf/core/column/numerical.py | 4 ++-- python/cudf/cudf/core/dataframe.py | 6 +++--- python/cudf/cudf/testing/testing.py | 10 +++++++--- 4 files changed, 13 insertions(+), 9 deletions(-) diff --git a/python/cudf/cudf/core/_internals/where.py b/python/cudf/cudf/core/_internals/where.py index f76802c8b7d..ef6b10f66c1 100644 --- a/python/cudf/cudf/core/_internals/where.py +++ b/python/cudf/cudf/core/_internals/where.py @@ -103,7 +103,7 @@ def _check_and_cast_columns_with_other( other = cudf.Scalar(other) if is_mixed_with_object_dtype(other, source_col) or ( - is_bool_dtype(source_col) and not is_bool_dtype(common_dtype) + is_bool_dtype(source_dtype) and not is_bool_dtype(common_dtype) ): raise TypeError(mixed_err) diff --git a/python/cudf/cudf/core/column/numerical.py b/python/cudf/cudf/core/column/numerical.py index f40886bf153..8980f9257fb 100644 --- a/python/cudf/cudf/core/column/numerical.py +++ b/python/cudf/cudf/core/column/numerical.py @@ -271,13 +271,13 @@ def _binaryop(self, other: ColumnBinaryOperand, op: str) -> ColumnBase: out_dtype = "bool" if op in {"__and__", "__or__", "__xor__"}: - if is_float_dtype(self.dtype) or is_float_dtype(other): + if is_float_dtype(self.dtype) or is_float_dtype(other.dtype): raise TypeError( f"Operation 'bitwise {op[2:-2]}' not supported between " f"{self.dtype.type.__name__} and " f"{other.dtype.type.__name__}" ) - if is_bool_dtype(self.dtype) or is_bool_dtype(other): + if is_bool_dtype(self.dtype) or is_bool_dtype(other.dtype): out_dtype = "bool" if ( diff --git a/python/cudf/cudf/core/dataframe.py b/python/cudf/cudf/core/dataframe.py index 01935fec8c3..152b85d37f4 100644 --- a/python/cudf/cudf/core/dataframe.py +++ b/python/cudf/cudf/core/dataframe.py @@ -324,7 +324,7 @@ def _getitem_tuple_arg(self, arg): tmp_arg[1], ) - if is_bool_dtype(tmp_arg[0]): + if is_bool_dtype(tmp_arg[0].dtype): df = columns_df._apply_boolean_mask( BooleanMask(tmp_arg[0], len(columns_df)) ) @@ -6029,7 +6029,7 @@ def _reduce( numeric_cols = ( name for name in self._data.names - if is_numeric_dtype(self._data[name]) + if is_numeric_dtype(self._data[name].dtype) ) source = self._get_columns_by_label(numeric_cols) if source.empty: @@ -6075,7 +6075,7 @@ def _reduce( numeric_cols = ( name for name in self._data.names - if is_numeric_dtype(self._data[name]) + if is_numeric_dtype(self._data[name].dtype) ) source = self._get_columns_by_label(numeric_cols) if source.empty: diff --git a/python/cudf/cudf/testing/testing.py b/python/cudf/cudf/testing/testing.py index a45733a0f83..6c2f073b7ac 100644 --- a/python/cudf/cudf/testing/testing.py +++ b/python/cudf/cudf/testing/testing.py @@ -232,10 +232,10 @@ def assert_column_equal( elif not ( ( not dtype_can_compare_equal_to_other(left.dtype) - and is_numeric_dtype(right) + and is_numeric_dtype(right.dtype) ) or ( - is_numeric_dtype(left) + is_numeric_dtype(left.dtype) and not dtype_can_compare_equal_to_other(right.dtype) ) ): @@ -245,7 +245,11 @@ def assert_column_equal( left.isnull().values == right.isnull().values ) - if columns_equal and not check_exact and is_numeric_dtype(left): + if ( + columns_equal + and not check_exact + and is_numeric_dtype(left.dtype) + ): # non-null values must be the same columns_equal = cp.allclose( left.apply_boolean_mask(