Skip to content

Commit

Permalink
refactor: change compiler to required keyword arg
Browse files Browse the repository at this point in the history
  • Loading branch information
gforsyth committed Apr 24, 2023
1 parent 99649dc commit 15737f3
Showing 1 changed file with 31 additions and 52 deletions.
83 changes: 31 additions & 52 deletions ibis_substrait/compiler/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def _schema(schema: sch.Schema) -> stt.NamedStruct:
def _expr(
expr: ir.Expr,
*,
compiler: SubstraitCompiler | None = None,
compiler: SubstraitCompiler,
**kwargs: Any,
) -> stalg.Expression:
return translate(expr.op(), compiler=compiler, **kwargs)
Expand All @@ -252,7 +252,7 @@ def _expr(
def _literal(
op: ops.Literal,
*,
compiler: SubstraitCompiler | None = None,
compiler: SubstraitCompiler,
**kwargs: Any,
) -> stalg.Expression:
dtype = op.output_dtype
Expand Down Expand Up @@ -509,7 +509,7 @@ def _translate_window_bounds(
def alias_op(
op: ops.Alias,
*,
compiler: SubstraitCompiler | None = None,
compiler: SubstraitCompiler,
**kwargs: Any,
) -> stalg.Expression:
# For an alias, dispatch on the underlying argument
Expand All @@ -520,11 +520,9 @@ def alias_op(
def value_op(
op: ops.ValueOp,
*,
compiler: SubstraitCompiler | None = None,
compiler: SubstraitCompiler,
**kwargs: Any,
) -> stalg.Expression:
if compiler is None:
raise ValueError
# Check if scalar function is valid for input dtype(s) and insert casts as needed to
# make sure inputs are correct.
op = _check_and_upcast(op)
Expand All @@ -548,11 +546,9 @@ def value_op(
def window_op(
op: ops.WindowOp,
*,
compiler: SubstraitCompiler | None = None,
compiler: SubstraitCompiler,
**kwargs: Any,
) -> stalg.Expression:
if compiler is None:
raise ValueError
lower_bound, upper_bound = _translate_window_bounds(
op.window.preceding, op.window.following
)
Expand Down Expand Up @@ -584,11 +580,9 @@ def window_op(
def _reduction(
op: ops.Reduction,
*,
compiler: SubstraitCompiler | None = None,
compiler: SubstraitCompiler,
**kwargs: Any,
) -> stalg.AggregateFunction:
if compiler is None:
raise ValueError
return stalg.AggregateFunction(
function_reference=compiler.function_id(op),
arguments=[
Expand All @@ -605,11 +599,9 @@ def _reduction(
def _count(
op: ops.Count,
*,
compiler: SubstraitCompiler | None = None,
compiler: SubstraitCompiler,
**kwargs: Any,
) -> stalg.AggregateFunction:
if compiler is None:
raise ValueError
translated_args = []
# TODO: remove this expr
arg = op.arg.op().to_expr()
Expand All @@ -631,11 +623,9 @@ def _count(
def _variance_base(
op: ops.StandardDev | ops.Variance,
*,
compiler: SubstraitCompiler | None = None,
compiler: SubstraitCompiler,
**kwargs: Any,
) -> stalg.AggregateFunction:
if compiler is None:
raise ValueError
translated_arg = stalg.FunctionArgument(
value=translate(op.arg.op(), compiler=compiler, **kwargs)
)
Expand All @@ -656,7 +646,7 @@ def _variance_base(
def sort_key(
op: ops.SortKey,
*,
compiler: SubstraitCompiler | None = None,
compiler: SubstraitCompiler,
**kwargs: Any,
) -> stalg.SortField:
ordering = "ASC" if op.ascending else "DESC"
Expand All @@ -673,7 +663,7 @@ def sort_key(
def table_column(
op: ops.TableColumn,
*,
compiler: SubstraitCompiler | None = None,
compiler: SubstraitCompiler,
child_rel_field_offsets: MutableMapping[ops.TableNode, int] | None = None,
**kwargs: Any,
) -> stalg.Expression:
Expand All @@ -697,7 +687,7 @@ def table_column(
def struct_field(
op: ops.StructField,
*,
compiler: SubstraitCompiler | None = None,
compiler: SubstraitCompiler,
**kwargs: Any,
) -> stalg.Expression:
child = translate(op.arg, compiler=compiler, **kwargs)
Expand Down Expand Up @@ -728,7 +718,7 @@ def struct_field(
def unbound_table(
op: ops.UnboundTable,
*,
compiler: SubstraitCompiler | None = None,
compiler: SubstraitCompiler,
**kwargs: Any,
) -> stalg.Rel:
return stalg.Rel(
Expand Down Expand Up @@ -778,7 +768,7 @@ def _get_child_relation_field_offsets(table: ir.TableExpr) -> dict[ops.TableNode
def selection(
op: ops.Selection,
*,
compiler: SubstraitCompiler | None = None,
compiler: SubstraitCompiler,
child_rel_field_offsets: Mapping[ops.TableNode, int] | None = None,
**kwargs: Any,
) -> stalg.Rel:
Expand Down Expand Up @@ -968,7 +958,7 @@ def _translate_anti_join(_: ops.LeftAntiJoin) -> stalg.JoinRel.JoinType.V:
def join(
op: ops.Join,
*,
compiler: SubstraitCompiler | None = None,
compiler: SubstraitCompiler,
**kwargs: Any,
) -> stalg.Rel:
child_rel_field_offsets = kwargs.pop("child_rel_field_offsets", None)
Expand All @@ -995,7 +985,7 @@ def join(
def limit(
op: ops.Limit,
*,
compiler: SubstraitCompiler | None = None,
compiler: SubstraitCompiler,
**kwargs: Any,
) -> stalg.Rel:
return stalg.Rel(
Expand Down Expand Up @@ -1035,7 +1025,7 @@ def set_op_type_difference(op: ops.Difference) -> stalg.SetRel.SetOp.V:
def set_op(
op: ops.SetOp,
*,
compiler: SubstraitCompiler | None = None,
compiler: SubstraitCompiler,
**kwargs: Any,
) -> stalg.Rel:
return stalg.Rel(
Expand All @@ -1053,7 +1043,7 @@ def set_op(
def aggregation(
op: ops.Aggregation,
*,
compiler: SubstraitCompiler | None = None,
compiler: SubstraitCompiler,
**kwargs: Any,
) -> stalg.Rel:
if op.having:
Expand Down Expand Up @@ -1101,7 +1091,7 @@ def aggregation(
def _simple_searched_case(
op: ops.SimpleCase | ops.SearchedCase,
*,
compiler: SubstraitCompiler | None = None,
compiler: SubstraitCompiler,
**kwargs: Any,
) -> stalg.Expression:
# the field names for an `if_then` are `if` and `else` which means we need
Expand All @@ -1124,7 +1114,7 @@ def _simple_searched_case(
def _where(
op: ops.Where,
*,
compiler: SubstraitCompiler | None = None,
compiler: SubstraitCompiler,
**kwargs: Any,
) -> stalg.Expression:
# the field names for an `if_then` are `if` and `else` which means we need
Expand All @@ -1144,7 +1134,7 @@ def _where(
def _contains(
op: ops.Contains,
*,
compiler: SubstraitCompiler | None = None,
compiler: SubstraitCompiler,
**kwargs: Any,
) -> stalg.Expression:
options = (
Expand All @@ -1166,7 +1156,7 @@ def _contains(
def _cast(
op: ops.Cast,
*,
compiler: SubstraitCompiler | None = None,
compiler: SubstraitCompiler,
**kwargs: Any,
) -> stalg.Expression:
return stalg.Expression(
Expand All @@ -1182,7 +1172,7 @@ def _cast(
def _extractdatefield(
op: ops.ExtractDateField,
*,
compiler: SubstraitCompiler | None = None,
compiler: SubstraitCompiler,
**kwargs: Any,
) -> stalg.Expression:
# e.g. "ExtractYear" -> "YEAR"
Expand All @@ -1192,9 +1182,6 @@ def _extractdatefield(
for arg in op.args
if isinstance(arg, (ir.Expr, ops.Value))
)
if compiler is None:
raise ValueError

scalar_func = stalg.Expression.ScalarFunction(
function_reference=compiler.function_id(op),
output_type=translate(op.output_dtype),
Expand All @@ -1208,12 +1195,9 @@ def _extractdatefield(
def _log(
op: ops.Log,
*,
compiler: SubstraitCompiler | None = None,
compiler: SubstraitCompiler,
**kwargs: Any,
) -> stalg.Expression:
if compiler is None:
raise ValueError

arg = stalg.FunctionArgument(value=translate(op.arg, compiler=compiler, **kwargs))
base = stalg.FunctionArgument(
value=translate(
Expand All @@ -1235,7 +1219,7 @@ def _log(
def _floordivide(
op: ops.FloorDivide,
*,
compiler: SubstraitCompiler | None = None,
compiler: SubstraitCompiler,
**kwargs: Any,
) -> stalg.Expression:
left, right = op.left, op.right
Expand All @@ -1246,7 +1230,7 @@ def _floordivide(
def _clip(
op: ops.Clip,
*,
compiler: SubstraitCompiler | None = None,
compiler: SubstraitCompiler,
**kwargs: Any,
) -> stalg.Expression:
arg, lower, upper = op.arg, op.lower, op.upper
Expand Down Expand Up @@ -1276,7 +1260,7 @@ def _clip(
def _table_array_view(
op: ops.TableArrayView,
*,
compiler: SubstraitCompiler | None = None,
compiler: SubstraitCompiler,
**kwargs: Any,
) -> stalg.Expression:
return translate(op.table, compiler=compiler, **kwargs)
Expand All @@ -1286,7 +1270,7 @@ def _table_array_view(
def _self_reference(
op: ops.SelfReference,
*,
compiler: SubstraitCompiler | None = None,
compiler: SubstraitCompiler,
**kwargs: Any,
) -> stalg.Expression:
return translate(op.table, compiler=compiler, **kwargs)
Expand All @@ -1296,7 +1280,7 @@ def _self_reference(
def _exists_subquery(
op: ops.ExistsSubquery,
*,
compiler: SubstraitCompiler | None = None,
compiler: SubstraitCompiler,
**kwargs: Any,
) -> stalg.Expression:
predicates = [pred.op().to_expr() for pred in op.predicates]
Expand Down Expand Up @@ -1325,7 +1309,7 @@ def _exists_subquery(
def _not_exists_subquery(
op: ops.NotExistsSubquery,
*,
compiler: SubstraitCompiler | None = None,
compiler: SubstraitCompiler,
**kwargs: Any,
) -> stalg.Expression:
assert compiler is not None
Expand Down Expand Up @@ -1366,11 +1350,9 @@ def _not_exists_subquery(
def _floor_ceil_cast(
op: ops.Floor,
*,
compiler: SubstraitCompiler | None = None,
compiler: SubstraitCompiler,
**kwargs: Any,
) -> stalg.Expression:
if compiler is None:
raise ValueError
output_type = translate(op.output_dtype)
input = stalg.Expression(
scalar_function=stalg.Expression.ScalarFunction(
Expand Down Expand Up @@ -1398,12 +1380,9 @@ def _floor_ceil_cast(
def _elementwise_udf(
op: ops.ElementWiseVectorizedUDF,
*,
compiler: SubstraitCompiler | None = None,
compiler: SubstraitCompiler,
**kwargs: Any,
) -> stalg.Expression:
if compiler is None:
raise ValueError

if compiler.udf_uri is None:
raise ValueError(
"""
Expand Down

0 comments on commit 15737f3

Please sign in to comment.