Fix use of row UDFs at intermediate query stages #409

Merged
Changes from 4 commits

3 changes: 1 addition & 2 deletions dask_sql/context.py
@@ -964,8 +964,7 @@ def _register_callable(
        schema = self.schema[schema_name]

        if not aggregation:
-            f = UDF(f, row_udf, return_type)
-
+            f = UDF(f, row_udf, parameters, return_type)
        lower_name = name.lower()
        if lower_name in schema.functions:
            if replace:
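For context, `parameters` is the list of `(name, type)` pairs that `Context.register_function` receives for the UDF's arguments; this change forwards it into `UDF` so a row UDF knows the parameter names it was registered with. A minimal sketch of that call path, with a made-up table, column, and function name:

```python
import dask.dataframe as dd
import numpy as np
import pandas as pd

from dask_sql import Context

def add_one(row):
    # Row UDF: fields are addressed by the registered parameter names
    return row["x"] + 1

c = Context()
c.create_table("my_table", dd.from_pandas(pd.DataFrame({"a": [1, 2, 3]}), npartitions=1))

# parameters=[("x", np.int64)] is what flows into UDF(f, row_udf, parameters, return_type)
c.register_function(add_one, "add_one", [("x", np.int64)], np.int64, row_udf=True)

print(c.sql("SELECT ADD_ONE(a) FROM my_table").compute())
```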
10 changes: 7 additions & 3 deletions dask_sql/datacontainer.py
@@ -183,7 +183,7 @@ def assign(self) -> dd.DataFrame:


class UDF:
-    def __init__(self, func, row_udf: bool, return_type=None):
+    def __init__(self, func, row_udf: bool, params, return_type=None):
        """
        Helper class that handles different types of UDFs and manages
        how they should be mapped to dask operations. Two versions of
@@ -196,6 +196,8 @@ def __init__(self, func, row_udf: bool, return_type=None):
        self.row_udf = row_udf
        self.func = func

+        self.names = [param[0] for param in params]
+
        if return_type is None:
            # These UDFs go through apply and without providing
            # a return type, dask will attempt to guess it, and
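The truncated comment above refers to dask's metadata inference: when a row UDF is registered without a return type, `apply` has to guess the output dtype by running the function on sample data. A standalone sketch of the difference (plain dask, not dask-sql code; names are illustrative):

```python
import dask.dataframe as dd
import pandas as pd

ddf = dd.from_pandas(pd.DataFrame({"a": [1, 2, 3]}), npartitions=1)

# Without meta, dask evaluates the function on dummy data to infer the output dtype
guessed = ddf.apply(lambda row: row["a"] * 2, axis=1)

# With an explicit meta (what UDF derives from return_type), no guessing is needed
explicit = ddf.apply(lambda row: row["a"] * 2, axis=1, meta=("a", "int64"))

print(explicit.compute())
```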
@@ -212,9 +214,11 @@ def __call__(self, *args, **kwargs):
                    column_args.append(operand)
                else:
                    scalar_args.append(operand)
+
            df = column_args[0].to_frame()
-            for col in column_args[1:]:
-                df[col.name] = col
+            for name, col in zip(self.names, column_args):

Review comment from a collaborator (marked as resolved):

If we pass the first parameter column name to to_frame, we lose a layer off the resulting HLG and don't have to deal with a superfluous column:

            df = column_args[0].to_frame(self.names[0])
            for name, col in zip(self.names[1:], column_args[1:]):

+                df[name] = col
+
            result = df.apply(
                self.func, axis=1, args=tuple(scalar_args), meta=self.meta
            ).astype(self.meta[1])
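Why the rename matters, as a standalone sketch (plain dask, not dask-sql internals; the "$0"/"$1" names only stand in for whatever an intermediate query stage produces): at intermediate stages the incoming Series no longer carry the user's column names, so the old `df[col.name] = col` pattern builds a frame the UDF cannot index into, while mapping the positional arguments onto `self.names` keeps `row["x"]` and `row["y"]` resolvable.

```python
import dask.dataframe as dd
import pandas as pd

def f(row):
    return row["x"] + row["y"]

# Series handed to the UDF at an intermediate stage may carry generated names
s0 = dd.from_pandas(pd.Series([0, 1, 2], name="$0"), npartitions=1)
s1 = dd.from_pandas(pd.Series([3, 4, 5], name="$1"), npartitions=1)

names = ["x", "y"]  # parameter names the UDF was registered with

df = s0.to_frame()
for name, col in zip(names, [s0, s1]):
    df[name] = col  # "x" and "y" now exist regardless of the incoming names

result = df.apply(f, axis=1, meta=("result", "int64"))
print(result.compute())  # 3, 5, 7
```

The reviewer's suggestion amounts to `df = column_args[0].to_frame(self.names[0])` followed by zipping from the second argument onward, which drops the leftover "$0" column and one extra graph layer.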
15 changes: 15 additions & 0 deletions tests/integration/fixtures.py
@@ -37,6 +37,19 @@ def df_simple():
    return pd.DataFrame({"a": [1, 2, 3], "b": [1.1, 2.2, 3.3]})


+@pytest.fixture()
+def df_wide():
+    return pd.DataFrame(
+        {
+            "a": [0, 1, 2],
+            "b": [3, 4, 5],
+            "c": [6, 7, 8],
+            "d": [9, 10, 11],
+            "e": [12, 13, 14],
+        }
+    )
+
+
@pytest.fixture()
def df():
    np.random.seed(42)
@@ -126,6 +139,7 @@ def gpu_datetime_table(datetime_table):
@pytest.fixture()
def c(
    df_simple,
+    df_wide,
    df,
    user_table_1,
    user_table_2,
@@ -142,6 +156,7 @@
):
    dfs = {
        "df_simple": df_simple,
+        "df_wide": df_wide,
        "df": df,
        "user_table_1": user_table_1,
        "user_table_2": user_table_2,
24 changes: 23 additions & 1 deletion tests/integration/test_function.py
@@ -1,9 +1,10 @@
+import itertools
import operator

import dask.dataframe as dd
import numpy as np
import pytest
-from pandas.testing import assert_frame_equal
+from pandas.testing import assert_frame_equal, assert_series_equal


def test_custom_function(c, df):
@@ -40,6 +41,27 @@ def f(row):
    assert_frame_equal(return_df.reset_index(drop=True), df[["a"]] ** 2)


@pytest.mark.parametrize("colnames", list(itertools.combinations(["a", "b", "c"], 2)))
def test_custom_function_any_colnames(colnames, df_wide, c):
# a third column is needed

def f(row):
return row["x"] + row["y"]

colname_x, colname_y = colnames
c.register_function(
f, "f", [("x", np.int64), ("y", np.int64)], np.int64, row_udf=True
)

return_df = c.sql(f"SELECT F({colname_x},{colname_y}) FROM df_wide")

return_df = return_df.compute()
expect = df_wide[colname_x] + df_wide[colname_y]
got = return_df[return_df.columns[0]]

assert_series_equal(expect, got, check_names=False)
charlesbluca marked this conversation as resolved.
Show resolved Hide resolved


@pytest.mark.parametrize(
"retty",
[None, np.float64, np.float32, np.int64, np.int32, np.int16, np.int8, np.bool_],
Expand Down
2 changes: 2 additions & 0 deletions tests/integration/test_show.py
@@ -35,6 +35,7 @@ def test_tables(c):
            "Table": [
                "df",
                "df_simple",
+                "df_wide",
                "user_table_1",
                "user_table_2",
                "long_table",
@@ -47,6 +48,7 @@
            else [
                "df",
                "df_simple",
+                "df_wide",
                "user_table_1",
                "user_table_2",
                "long_table",