Skip to content

Commit

Permalink
Handle constructing Scalar from Scalar
Browse files Browse the repository at this point in the history
  • Loading branch information
shwina committed Mar 18, 2021
1 parent 423e599 commit 97c819d
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
9 changes: 8 additions & 1 deletion python/cudf/cudf/core/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,14 @@ def __init__(self, value, dtype=None):
self._host_value = None
self._host_dtype = None
self._device_value = None
if isinstance(value, DeviceScalar):

if isinstance(value, Scalar):
if value._is_host_value_current:
self._host_value = value._host_value
self._host_dtype = value._host_dtype
else:
self._device_value = value._device_value
elif isinstance(value, DeviceScalar):
self._device_value = value
else:
self._host_value, self._host_dtype = self._preprocess_host_value(
Expand Down
11 changes: 11 additions & 0 deletions python/cudf/cudf/tests/test_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,3 +289,14 @@ def test_device_scalar_direct_construction(value):
assert s.dtype == "object"
else:
assert s.dtype == dtype


@pytest.mark.parametrize("value", SCALAR_VALUES)
def test_construct_from_scalar(value):
value = cudf.utils.utils.to_cudf_compatible_scalar(value)
x = cudf.Scalar(1, value.dtype)
y = cudf.Scalar(x)
assert x.value == y.value

# check that this works:
y.device_value

0 comments on commit 97c819d

Please sign in to comment.