diff --git a/python/cudf/cudf/pandas/_wrappers/pandas.py b/python/cudf/cudf/pandas/_wrappers/pandas.py index 29aaaac245d..2e3880e14f6 100644 --- a/python/cudf/cudf/pandas/_wrappers/pandas.py +++ b/python/cudf/cudf/pandas/_wrappers/pandas.py @@ -386,17 +386,35 @@ def Index__new__(cls, *args, **kwargs): }, ) -NumpyExtensionArray = make_final_proxy_type( - "NumpyExtensionArray", - _Unusable, - pd.arrays.NumpyExtensionArray, - fast_to_slow=_Unusable(), - slow_to_fast=_Unusable(), - additional_attributes={ - "_ndarray": _FastSlowAttribute("_ndarray"), - "_dtype": _FastSlowAttribute("_dtype"), - }, -) +try: + from pandas.arrays import NumpyExtensionArray as pd_NumpyExtensionArray + + NumpyExtensionArray = make_final_proxy_type( + "NumpyExtensionArray", + _Unusable, + pd_NumpyExtensionArray, + fast_to_slow=_Unusable(), + slow_to_fast=_Unusable(), + additional_attributes={ + "_ndarray": _FastSlowAttribute("_ndarray"), + "_dtype": _FastSlowAttribute("_dtype"), + }, + ) + +except ImportError: + from pandas.arrays import PandasArray as pd_PandasArray + + PandasArray = make_final_proxy_type( + "PandasArray", + _Unusable, + pd_PandasArray, + fast_to_slow=_Unusable(), + slow_to_fast=_Unusable(), + additional_attributes={ + "_ndarray": _FastSlowAttribute("_ndarray"), + "_dtype": _FastSlowAttribute("_dtype"), + }, + ) TimedeltaArray = make_final_proxy_type( "TimedeltaArray", diff --git a/python/cudf/cudf_pandas_tests/test_cudf_pandas.py b/python/cudf/cudf_pandas_tests/test_cudf_pandas.py index e3d4f878ad5..75bceea3034 100644 --- a/python/cudf/cudf_pandas_tests/test_cudf_pandas.py +++ b/python/cudf/cudf_pandas_tests/test_cudf_pandas.py @@ -1241,8 +1241,12 @@ def test_pickle_groupby(dataframe): def test_numpy_extension_array(): np_array = np.array([0, 1, 2, 3]) - xarray = xpd.arrays.NumpyExtensionArray(np_array) - array = pd.arrays.NumpyExtensionArray(np_array) + try: + xarray = xpd.arrays.NumpyExtensionArray(np_array) + array = pd.arrays.NumpyExtensionArray(np_array) + except AttributeError: + xarray = xpd.arrays.PandasArray(np_array) + array = pd.arrays.PandasArray(np_array) tm.assert_equal(xarray, array)