Skip to content

Commit

Permalink
feat(autocast): add autocast for digits argument to ops.Round
Browse files Browse the repository at this point in the history
Ibis literals default to Int8, which makes this a pain for users to
type. Since substrait only allows Int32 for the `digits` arg, we
autocast to int32 from any int type
  • Loading branch information
gforsyth committed Dec 1, 2023
1 parent 4fd5264 commit f0d4940
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
14 changes: 14 additions & 0 deletions ibis_substrait/compiler/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1502,3 +1502,17 @@ def _upcast_string_op(op: string_op) -> string_op:
for newop in op.args
]
return type(op)(*casted_args)


@_upcast.register(ops.Round)
def _upcast_round_digits(op: ops.Round) -> ops.Round:
# Substrait wants Int32 for decimal place argument to round
if op.digits is None:
raise ValueError(
"Substrait requires that a rounding operation specify the number of digits to round to"
)
elif not isinstance(op.digits.dtype, dt.Int32):
return ops.Round(
op.arg, op.digits.copy(dtype=dt.Int32(nullable=op.digits.dtype.nullable))
)
return op
27 changes: 27 additions & 0 deletions ibis_substrait/tests/compiler/test_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,3 +358,30 @@ def test_extension_arithmetic_multiple_signatures(compiler):
assert "add:fp32_fp32" in scalar_func_names
assert "subtract:i64_i64" in scalar_func_names
assert "subtract:fp32_fp32" in scalar_func_names


_TYPE_MAPPING = {
"int8": "i8",
"int16": "i16",
"int32": "i32",
"int64": "i64",
"float32": "fp32",
"float64": "fp64",
}


@pytest.mark.parametrize(
"col_dtype", ["float32", "float64", "int8", "int16", "int32", "int64"]
)
@pytest.mark.parametrize("digits_dtype", ["int8", "int16", "int32", "int64"])
def test_extension_round_upcast(compiler, col_dtype, digits_dtype):
t = ibis.table([("col", col_dtype)], name="t")

query = t.mutate(col=t.col.round(ibis.literal(8, type=digits_dtype)))
plan = compiler.compile(query)

scalar_func_names = [
extension.extension_function.name for extension in plan.extensions
]

assert f"round:{_TYPE_MAPPING[col_dtype]}_i32" in scalar_func_names

0 comments on commit f0d4940

Please sign in to comment.