Skip to content

Commit

Permalink
add tests for null scalar binaryops
Browse files Browse the repository at this point in the history
  • Loading branch information
brandon-b-miller committed May 10, 2022
1 parent 19c5bad commit 248e5d9
Showing 1 changed file with 37 additions and 2 deletions.
39 changes: 37 additions & 2 deletions python/cudf/cudf/tests/test_binops.py
Original file line number Diff line number Diff line change
Expand Up @@ -902,7 +902,7 @@ def dtype_scalar(val, dtype):
return dtype.type(val)


def make_valid_scalar_add_data():
def make_scalar_add_data():
valid = set()

# to any int, we may add any kind of
Expand Down Expand Up @@ -968,7 +968,7 @@ def make_invalid_scalar_add_data():
return sorted(list(invalid))


@pytest.mark.parametrize("dtype_l,dtype_r", make_valid_scalar_add_data())
@pytest.mark.parametrize("dtype_l,dtype_r", make_scalar_add_data())
def test_scalar_add(dtype_l, dtype_r):
test_value = 1

Expand Down Expand Up @@ -1481,6 +1481,41 @@ def test_scalar_power_invalid(dtype_l, dtype_r):
lval_gpu**rval_gpu


def make_scalar_null_binops_data():
return (
[(operator.add, *dtypes) for dtypes in make_scalar_add_data()]
+ [(operator.sub, *dtypes) for dtypes in make_scalar_difference_data()]
+ [(operator.mul, *dtypes) for dtypes in make_scalar_product_data()]
+ [(operator.add, *dtypes) for dtypes in make_scalar_add_data()]
+ [
(operator.floordiv, *dtypes)
for dtypes in make_scalar_floordiv_data()
]
+ [
(operator.truediv, *dtypes)
for dtypes in make_scalar_truediv_data()
]
+ [(operator.mod, *dtypes) for dtypes in make_scalar_remainder_data()]
+ [(operator.pow, *dtypes) for dtypes in make_scalar_power_data()]
)


@pytest.mark.parametrize("op,dtype_l,dtype_r", make_scalar_null_binops_data())
def test_scalar_null_binops(op, dtype_l, dtype_r):
lhs = cudf.Scalar(cudf.NA, dtype=dtype_l)
rhs = cudf.Scalar(cudf.NA, dtype=dtype_r)

result = op(lhs, rhs)
assert result.value is cudf.NA

# make sure dtype is the same as had there been a valid scalar
valid_lhs = cudf.Scalar(0, dtype=dtype_l)
valid_rhs = cudf.Scalar(0, dtype=dtype_r)

valid_result = op(valid_lhs, valid_rhs)
assert result.dtype == valid_result.dtype


@pytest.mark.parametrize(
"date_col",
[
Expand Down

0 comments on commit 248e5d9

Please sign in to comment.