From ceacfa43b15d81a26678763a70f115db94267e53 Mon Sep 17 00:00:00 2001 From: Ashwin Srinath <3190405+shwina@users.noreply.github.com> Date: Fri, 5 May 2023 06:32:04 -0400 Subject: [PATCH] First check for `BaseDtype` when inferring the data type of an arbitrary object (#13295) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We have an internal utility called `dtype()` that attempts to infer the data type of an arbitrary object. One of the first things that `dtype()` does is attempt to call `np.dtype(obj)`. That can be slow for extremely large cardinality categorical data types, as it copies data to host (in particular, it attempts to call the object's `__repr__`): Before this PR: ```python dtype = cudf.CategoricalDtype(categories=range(100_000_000)) %%time x = cudf.core.dtypes.dtype(dtype) CPU times: user 3.75 s, sys: 885 ms, total: 4.64 s Wall time: 4.63 s ``` This PR ensures we attempt to do far less expensive inference first, before calling `np.dtype(...)`. 
After this PR: ```python %%time x = cudf.core.dtypes.dtype(dtype) CPU times: user 13 µs, sys: 1 µs, total: 14 µs Wall time: 19.1 µs ``` Authors: - Ashwin Srinath (https://github.com/shwina) Approvers: - Bradley Dice (https://github.com/bdice) - Lawrence Mitchell (https://github.com/wence-) URL: https://github.com/rapidsai/cudf/pull/13295 --- python/cudf/cudf/core/dtypes.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/cudf/cudf/core/dtypes.py b/python/cudf/cudf/core/dtypes.py index 49c931a4218..edd557aad1f 100644 --- a/python/cudf/cudf/core/dtypes.py +++ b/python/cudf/cudf/core/dtypes.py @@ -42,7 +42,11 @@ def dtype(arbitrary): ------- dtype: the cuDF-supported dtype that best matches `arbitrary` """ - # first, try interpreting arbitrary as a NumPy dtype that we support: + # first, check if `arbitrary` is one of our extension types: + if isinstance(arbitrary, cudf.core.dtypes._BaseDtype): + return arbitrary + + # next, try interpreting arbitrary as a NumPy dtype that we support: try: np_dtype = np.dtype(arbitrary) if np_dtype.kind in ("OU"): @@ -54,10 +58,6 @@ def dtype(arbitrary): raise TypeError(f"Unsupported type {np_dtype}") return np_dtype - # next, check if `arbitrary` is one of our extension types: - if isinstance(arbitrary, cudf.core.dtypes._BaseDtype): - return arbitrary - # use `pandas_dtype` to try and interpret # `arbitrary` as a Pandas extension type. # Return the corresponding NumPy/cuDF type.