Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable binary operations between scalars and columns of differing decimal types #13034

Merged
14 changes: 8 additions & 6 deletions python/cudf/cudf/core/column/decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,22 +144,20 @@ def normalize_binop_value(self, other):
elif not isinstance(self.dtype, other.dtype.__class__):
# This branch occurs if we have a DecimalBaseColumn of a
# different size (e.g. 64 instead of 32).
if (
self.dtype.precision == other.dtype.precision
and self.dtype.scale == other.dtype.scale
):
if _same_precision_and_scale(self.dtype, other.dtype):
other = other.astype(self.dtype)

return other
if isinstance(other, cudf.Scalar) and isinstance(
# TODO: Should it be possible to cast scalars of other numerical
# types to decimal?
other.dtype,
cudf.core.dtypes.DecimalDtype,
):
if _same_precision_and_scale(self.dtype, other.dtype):
other = other.astype(self.dtype)
return other
elif is_scalar(other) and isinstance(other, (int, Decimal)):
return cudf.Scalar(Decimal(other))
return cudf.Scalar(Decimal(other), dtype=self.dtype)
return NotImplemented

def _decimal_quantile(
Expand Down Expand Up @@ -404,3 +402,7 @@ def _get_decimal_type(lhs_dtype, rhs_dtype, op):
)
precision = min(cudf.Decimal128Dtype.MAX_PRECISION, max_precision)
return cudf.Decimal128Dtype(precision=precision, scale=scale)


def _same_precision_and_scale(lhs: DecimalDtype, rhs: DecimalDtype) -> bool:
brandon-b-miller marked this conversation as resolved.
Show resolved Hide resolved
return lhs.precision == rhs.precision and lhs.scale == rhs.scale
4 changes: 0 additions & 4 deletions python/cudf/cudf/tests/test_binops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3012,10 +3012,6 @@ def decimal_series(input, dtype):
],
)
@pytest.mark.parametrize("reflected", [True, False])
@pytest_xfail(
reason="binop operations not supported for different bit-width "
"decimal types"
)
def test_binops_decimal_scalar_compare(args, reflected):
"""
Tested compare operations:
Expand Down