diff --git a/dask_sql/context.py b/dask_sql/context.py index 98cc46e21..6dc1850f1 100644 --- a/dask_sql/context.py +++ b/dask_sql/context.py @@ -748,7 +748,7 @@ def _prepare_schemas(self): logger.debug("No custom functions defined.") for function_description in schema.function_lists: name = function_description.name - sql_return_type = python_to_sql_type(function_description.return_type) + sql_return_type = function_description.return_type if function_description.aggregation: logger.debug(f"Adding function '{name}' to schema as aggregation.") dask_function = DaskAggregateFunction(name, sql_return_type) @@ -771,10 +771,7 @@ def _prepare_schemas(self): @staticmethod def _add_parameters_from_description(function_description, dask_function): for parameter in function_description.parameters: - param_name, param_type = parameter - sql_param_type = python_to_sql_type(param_type) - - dask_function.addParameter(param_name, sql_param_type, False) + dask_function.addParameter(*parameter, False) return dask_function @@ -898,9 +895,16 @@ def _register_callable( row_udf: bool = False, ): """Helper function to do the function or aggregation registration""" + schema_name = schema_name or self.schema_name schema = self.schema[schema_name] + # validate and cache UDF metadata + sql_parameters = [ + (name, python_to_sql_type(param_type)) for name, param_type in parameters + ] + sql_return_type = python_to_sql_type(return_type) + if not aggregation: f = UDF(f, row_udf, parameters, return_type) lower_name = name.lower() @@ -920,9 +924,13 @@ def _register_callable( ) schema.function_lists.append( - FunctionDescription(name.upper(), parameters, return_type, aggregation) + FunctionDescription( + name.upper(), sql_parameters, sql_return_type, aggregation + ) ) schema.function_lists.append( - FunctionDescription(name.lower(), parameters, return_type, aggregation) + FunctionDescription( + name.lower(), sql_parameters, sql_return_type, aggregation + ) ) schema.functions[lower_name] = f diff --git a/dask_sql/datacontainer.py b/dask_sql/datacontainer.py index db77c9dfc..f81952e68 100644 --- a/dask_sql/datacontainer.py +++ b/dask_sql/datacontainer.py @@ -198,11 +198,6 @@ def __init__(self, func, row_udf: bool, params, return_type=None): self.names = [param[0] for param in params] - if return_type is None: - # These UDFs go through apply and without providing - # a return type, dask will attempt to guess it, and - # dask might be wrong. - raise ValueError("Return type must be provided") self.meta = (None, return_type) def __call__(self, *args, **kwargs): @@ -218,7 +213,6 @@ def __call__(self, *args, **kwargs): df = column_args[0].to_frame(self.names[0]) for name, col in zip(self.names[1:], column_args[1:]): df[name] = col - result = df.apply( self.func, axis=1, args=tuple(scalar_args), meta=self.meta ).astype(self.meta[1]) diff --git a/dask_sql/mappings.py b/dask_sql/mappings.py index e59025918..623c38a37 100644 --- a/dask_sql/mappings.py +++ b/dask_sql/mappings.py @@ -88,6 +88,11 @@ def python_to_sql_type(python_type): """Mapping between python and SQL types.""" + if python_type in (int, float): + python_type = np.dtype(python_type) + elif python_type is str: + python_type = np.dtype("object") + if isinstance(python_type, np.dtype): python_type = python_type.type diff --git a/tests/integration/test_function.py b/tests/integration/test_function.py index 92fc58b14..d8ba40c0f 100644 --- a/tests/integration/test_function.py +++ b/tests/integration/test_function.py @@ -52,17 +52,12 @@ def f(row): @pytest.mark.parametrize( "retty", - [None, np.float64, np.float32, np.int64, np.int32, np.int16, np.int8, np.bool_], + [np.float64, np.float32, np.int64, np.int32, np.int16, np.int8, np.bool_], ) def test_custom_function_row_return_types(c, df, retty): def f(row): return row["x"] ** 2 - if retty is None: - with pytest.raises(ValueError): - c.register_function(f, "f", [("x", np.float64)], retty, row_udf=True) - return - c.register_function(f, "f", [("x", np.float64)], retty, row_udf=True) return_df = c.sql("SELECT F(a) AS a FROM df") @@ -199,3 +194,17 @@ def f(x): c.register_aggregation(fagg, "fagg", [("x", np.float64)], np.float64) c.register_aggregation(fagg, "fagg", [("x", np.float64)], np.float64, replace=True) + + +@pytest.mark.parametrize("dtype", [np.timedelta64, None, "a string"]) +def test_unsupported_dtype(c, dtype): + def f(x): + return x**2 + + # test that an invalid return type raises + with pytest.raises(NotImplementedError): + c.register_function(f, "f", [("x", np.int64)], dtype) + + # test that an invalid param type raises + with pytest.raises(NotImplementedError): + c.register_function(f, "f", [("x", dtype)], np.int64) diff --git a/tests/unit/test_context.py b/tests/unit/test_context.py index b8cfa6504..fca5c7454 100644 --- a/tests/unit/test_context.py +++ b/tests/unit/test_context.py @@ -6,6 +6,7 @@ from dask_sql import Context from dask_sql.datacontainer import Statistics +from dask_sql.mappings import python_to_sql_type from tests.utils import assert_eq try: @@ -198,6 +199,11 @@ def g(gpu=gpu): g(gpu=gpu) +int_sql_type = python_to_sql_type(int) +float_sql_type = python_to_sql_type(float) +str_sql_type = python_to_sql_type(str) + + def test_function_adding(): c = Context() @@ -211,12 +217,12 @@ def test_function_adding(): assert c.schema[c.schema_name].functions["f"].func == f assert len(c.schema[c.schema_name].function_lists) == 2 assert c.schema[c.schema_name].function_lists[0].name == "F" - assert c.schema[c.schema_name].function_lists[0].parameters == [("x", int)] - assert c.schema[c.schema_name].function_lists[0].return_type == float + assert c.schema[c.schema_name].function_lists[0].parameters == [("x", int_sql_type)] + assert c.schema[c.schema_name].function_lists[0].return_type == float_sql_type assert not c.schema[c.schema_name].function_lists[0].aggregation assert c.schema[c.schema_name].function_lists[1].name == "f" - assert c.schema[c.schema_name].function_lists[1].parameters == [("x", int)] - assert c.schema[c.schema_name].function_lists[1].return_type == float + assert c.schema[c.schema_name].function_lists[1].parameters == [("x", int_sql_type)] + assert c.schema[c.schema_name].function_lists[1].return_type == float_sql_type assert not c.schema[c.schema_name].function_lists[1].aggregation # Without replacement @@ -226,12 +232,16 @@ def test_function_adding(): assert c.schema[c.schema_name].functions["f"].func == f assert len(c.schema[c.schema_name].function_lists) == 4 assert c.schema[c.schema_name].function_lists[2].name == "F" - assert c.schema[c.schema_name].function_lists[2].parameters == [("x", float)] - assert c.schema[c.schema_name].function_lists[2].return_type == int + assert c.schema[c.schema_name].function_lists[2].parameters == [ + ("x", float_sql_type) + ] + assert c.schema[c.schema_name].function_lists[2].return_type == int_sql_type assert not c.schema[c.schema_name].function_lists[2].aggregation assert c.schema[c.schema_name].function_lists[3].name == "f" - assert c.schema[c.schema_name].function_lists[3].parameters == [("x", float)] - assert c.schema[c.schema_name].function_lists[3].return_type == int + assert c.schema[c.schema_name].function_lists[3].parameters == [ + ("x", float_sql_type) + ] + assert c.schema[c.schema_name].function_lists[3].return_type == int_sql_type assert not c.schema[c.schema_name].function_lists[3].aggregation # With replacement @@ -242,12 +252,12 @@ def test_function_adding(): assert c.schema[c.schema_name].functions["f"].func == f assert len(c.schema[c.schema_name].function_lists) == 2 assert c.schema[c.schema_name].function_lists[0].name == "F" - assert c.schema[c.schema_name].function_lists[0].parameters == [("x", str)] - assert c.schema[c.schema_name].function_lists[0].return_type == str + assert c.schema[c.schema_name].function_lists[0].parameters == [("x", str_sql_type)] + assert c.schema[c.schema_name].function_lists[0].return_type == str_sql_type assert not c.schema[c.schema_name].function_lists[0].aggregation assert c.schema[c.schema_name].function_lists[1].name == "f" - assert c.schema[c.schema_name].function_lists[1].parameters == [("x", str)] - assert c.schema[c.schema_name].function_lists[1].return_type == str + assert c.schema[c.schema_name].function_lists[1].parameters == [("x", str_sql_type)] + assert c.schema[c.schema_name].function_lists[1].return_type == str_sql_type assert not c.schema[c.schema_name].function_lists[1].aggregation @@ -264,12 +274,12 @@ def test_aggregation_adding(): assert c.schema[c.schema_name].functions["f"] == f assert len(c.schema[c.schema_name].function_lists) == 2 assert c.schema[c.schema_name].function_lists[0].name == "F" - assert c.schema[c.schema_name].function_lists[0].parameters == [("x", int)] - assert c.schema[c.schema_name].function_lists[0].return_type == float + assert c.schema[c.schema_name].function_lists[0].parameters == [("x", int_sql_type)] + assert c.schema[c.schema_name].function_lists[0].return_type == float_sql_type assert c.schema[c.schema_name].function_lists[0].aggregation assert c.schema[c.schema_name].function_lists[1].name == "f" - assert c.schema[c.schema_name].function_lists[1].parameters == [("x", int)] - assert c.schema[c.schema_name].function_lists[1].return_type == float + assert c.schema[c.schema_name].function_lists[1].parameters == [("x", int_sql_type)] + assert c.schema[c.schema_name].function_lists[1].return_type == float_sql_type assert c.schema[c.schema_name].function_lists[1].aggregation # Without replacement @@ -279,12 +289,16 @@ def test_aggregation_adding(): assert c.schema[c.schema_name].functions["f"] == f assert len(c.schema[c.schema_name].function_lists) == 4 assert c.schema[c.schema_name].function_lists[2].name == "F" - assert c.schema[c.schema_name].function_lists[2].parameters == [("x", float)] - assert c.schema[c.schema_name].function_lists[2].return_type == int + assert c.schema[c.schema_name].function_lists[2].parameters == [ + ("x", float_sql_type) + ] + assert c.schema[c.schema_name].function_lists[2].return_type == int_sql_type assert c.schema[c.schema_name].function_lists[2].aggregation assert c.schema[c.schema_name].function_lists[3].name == "f" - assert c.schema[c.schema_name].function_lists[3].parameters == [("x", float)] - assert c.schema[c.schema_name].function_lists[3].return_type == int + assert c.schema[c.schema_name].function_lists[3].parameters == [ + ("x", float_sql_type) + ] + assert c.schema[c.schema_name].function_lists[3].return_type == int_sql_type assert c.schema[c.schema_name].function_lists[3].aggregation # With replacement @@ -295,12 +309,12 @@ def test_aggregation_adding(): assert c.schema[c.schema_name].functions["f"] == f assert len(c.schema[c.schema_name].function_lists) == 2 assert c.schema[c.schema_name].function_lists[0].name == "F" - assert c.schema[c.schema_name].function_lists[0].parameters == [("x", str)] - assert c.schema[c.schema_name].function_lists[0].return_type == str + assert c.schema[c.schema_name].function_lists[0].parameters == [("x", str_sql_type)] + assert c.schema[c.schema_name].function_lists[0].return_type == str_sql_type assert c.schema[c.schema_name].function_lists[0].aggregation assert c.schema[c.schema_name].function_lists[1].name == "f" - assert c.schema[c.schema_name].function_lists[1].parameters == [("x", str)] - assert c.schema[c.schema_name].function_lists[1].return_type == str + assert c.schema[c.schema_name].function_lists[1].parameters == [("x", str_sql_type)] + assert c.schema[c.schema_name].function_lists[1].return_type == str_sql_type assert c.schema[c.schema_name].function_lists[1].aggregation