From a894ca03b18bd0304180f97882ccaaffa18028a0 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Wed, 13 Dec 2023 12:55:54 +0000 Subject: [PATCH] Add (implicit) handling for torch tensors in is_scalar (#14623) PyTorch tensors advertise that they support the number API, and hence answer "True" to the question pd.api.types.is_scalar(torch_tensor). This trips up some of our data ingest, since in as_index we check if the input is a scalar (and raise) before handing off to as_column. To handle this, if we get True back from pandas' is_scalar call, additionally check that the object has an empty shape attribute (if it exists). See also: - https://github.com/pytorch/pytorch/issues/99646 - https://github.com/pandas-dev/pandas/issues/52701 Authors: - Lawrence Mitchell (https://github.com/wence-) Approvers: - Ashwin Srinath (https://github.com/shwina) URL: https://github.com/rapidsai/cudf/pull/14623 --- python/cudf/cudf/api/types.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/python/cudf/cudf/api/types.py b/python/cudf/cudf/api/types.py index 947931d1135..72fc17f0286 100644 --- a/python/cudf/cudf/api/types.py +++ b/python/cudf/cudf/api/types.py @@ -135,7 +135,17 @@ def is_scalar(val): cudf._lib.scalar.DeviceScalar, cudf.core.tools.datetimes.DateOffset, ), - ) or pd_types.is_scalar(val) + ) or ( + pd_types.is_scalar(val) + # Pytorch tensors advertise that they support the number + # protocol, and therefore return True for PyNumber_Check even + # when they have a shape. So, if we get through this, let's + # additionally check that if they have a shape property that + # it is empty. + # See https://github.com/pytorch/pytorch/issues/99646 + # and https://github.com/pandas-dev/pandas/issues/52701 + and len(getattr(val, "shape", ())) == 0 + ) def _is_scalar_or_zero_d_array(val):