Skip to content

Commit

Permalink
Handle constructing a cudf.Scalar from a cudf.Scalar (#7639)
Browse files Browse the repository at this point in the history
...also fix `DeviceScalar.__repr__` to print `"DeviceScalar"` instead of `"Scalar"`.

Authors:
  - Ashwin Srinath (@shwina)

Approvers:
  - Keith Kraus (@kkraus14)
  - @brandon-b-miller

URL: #7639
  • Loading branch information
shwina authored Mar 19, 2021
1 parent a568432 commit c13351d
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 5 deletions.
10 changes: 7 additions & 3 deletions python/cudf/cudf/_lib/scalar.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ cdef class DeviceScalar:

def __init__(self, value, dtype):
"""
cudf.Scalar: Type representing a scalar value on the device
Type representing an *immutable* scalar value on the device
Parameters
----------
Expand All @@ -63,6 +63,7 @@ cdef class DeviceScalar:
self._set_value(value, dtype)

def _set_value(self, value, dtype):
# IMPORTANT: this should only ever be called from __init__
valid = not _is_null_host_scalar(value)

if pd.api.types.is_string_dtype(dtype):
Expand Down Expand Up @@ -128,9 +129,12 @@ cdef class DeviceScalar:

def __repr__(self):
if self.value is cudf.NA:
return f"Scalar({self.value}, {self.dtype.__repr__()})"
return (
f"{self.__class__.__name__}"
f"({self.value}, {self.dtype.__repr__()})"
)
else:
return f"Scalar({self.value.__repr__()})"
return f"{self.__class__.__name__}({self.value.__repr__()})"

@staticmethod
cdef DeviceScalar from_unique_ptr(unique_ptr[scalar] ptr):
Expand Down
14 changes: 12 additions & 2 deletions python/cudf/cudf/core/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,14 @@ def __init__(self, value, dtype=None):
self._host_value = None
self._host_dtype = None
self._device_value = None
if isinstance(value, DeviceScalar):

if isinstance(value, Scalar):
if value._is_host_value_current:
self._host_value = value._host_value
self._host_dtype = value._host_dtype
else:
self._device_value = value._device_value
elif isinstance(value, DeviceScalar):
self._device_value = value
else:
self._host_value, self._host_dtype = self._preprocess_host_value(
Expand Down Expand Up @@ -248,7 +255,10 @@ def __neg__(self):
def __repr__(self):
# str() fixes a numpy bug with NaT
# https://github.com/numpy/numpy/issues/17552
return f"Scalar({str(self.value)}, dtype={self.dtype})"
return (
f"{self.__class__.__name__}"
f"({str(self.value)}, dtype={self.dtype})"
)

def _binop_result_dtype_or_error(self, other, op):
if op in {"__eq__", "__ne__", "__lt__", "__gt__", "__le__", "__ge__"}:
Expand Down
14 changes: 14 additions & 0 deletions python/cudf/cudf/tests/test_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,3 +289,17 @@ def test_device_scalar_direct_construction(value):
assert s.dtype == "object"
else:
assert s.dtype == dtype


@pytest.mark.parametrize("value", SCALAR_VALUES)
def test_construct_from_scalar(value):
value = cudf.utils.utils.to_cudf_compatible_scalar(value)
x = cudf.Scalar(value, value.dtype)
y = cudf.Scalar(x)
assert x.value == y.value or np.isnan(x.value) and np.isnan(y.value)

# check that this works:
y.device_value

x._is_host_value_current == y._is_host_value_current
x._is_device_value_current == y._is_device_value_current

0 comments on commit c13351d

Please sign in to comment.