From 248e5d98a693e199fd2dc7f61b6824910e9adc69 Mon Sep 17 00:00:00 2001 From: brandon-b-miller Date: Tue, 10 May 2022 12:44:00 -0700 Subject: [PATCH] add tests for null scalar binaryops --- python/cudf/cudf/tests/test_binops.py | 39 +++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/python/cudf/cudf/tests/test_binops.py b/python/cudf/cudf/tests/test_binops.py index 742a3d7cd06..0d1bac6aead 100644 --- a/python/cudf/cudf/tests/test_binops.py +++ b/python/cudf/cudf/tests/test_binops.py @@ -902,7 +902,7 @@ def dtype_scalar(val, dtype): return dtype.type(val) -def make_valid_scalar_add_data(): +def make_scalar_add_data(): valid = set() # to any int, we may add any kind of @@ -968,7 +968,7 @@ def make_invalid_scalar_add_data(): return sorted(list(invalid)) -@pytest.mark.parametrize("dtype_l,dtype_r", make_valid_scalar_add_data()) +@pytest.mark.parametrize("dtype_l,dtype_r", make_scalar_add_data()) def test_scalar_add(dtype_l, dtype_r): test_value = 1 @@ -1481,6 +1481,41 @@ def test_scalar_power_invalid(dtype_l, dtype_r): lval_gpu**rval_gpu +def make_scalar_null_binops_data(): + return ( + [(operator.add, *dtypes) for dtypes in make_scalar_add_data()] + + [(operator.sub, *dtypes) for dtypes in make_scalar_difference_data()] + + [(operator.mul, *dtypes) for dtypes in make_scalar_product_data()] + + [(operator.add, *dtypes) for dtypes in make_scalar_add_data()] + + [ + (operator.floordiv, *dtypes) + for dtypes in make_scalar_floordiv_data() + ] + + [ + (operator.truediv, *dtypes) + for dtypes in make_scalar_truediv_data() + ] + + [(operator.mod, *dtypes) for dtypes in make_scalar_remainder_data()] + + [(operator.pow, *dtypes) for dtypes in make_scalar_power_data()] + ) + + +@pytest.mark.parametrize("op,dtype_l,dtype_r", make_scalar_null_binops_data()) +def test_scalar_null_binops(op, dtype_l, dtype_r): + lhs = cudf.Scalar(cudf.NA, dtype=dtype_l) + rhs = cudf.Scalar(cudf.NA, dtype=dtype_r) + + result = op(lhs, rhs) + assert result.value is cudf.NA + + # make sure dtype is the same as had there been a valid scalar + valid_lhs = cudf.Scalar(0, dtype=dtype_l) + valid_rhs = cudf.Scalar(0, dtype=dtype_r) + + valid_result = op(valid_lhs, valid_rhs) + assert result.dtype == valid_result.dtype + + @pytest.mark.parametrize( "date_col", [