Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for gpu kwarg in Context.sql and explain #368

Merged
merged 3 commits into from
Jan 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions dask_sql/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ def sql(
sql: str,
return_futures: bool = True,
dataframes: Dict[str, Union[dd.DataFrame, pd.DataFrame]] = None,
gpu: bool = False,
) -> Union[dd.DataFrame, pd.DataFrame]:
"""
Query the registered tables with the given SQL.
Expand All @@ -443,14 +444,16 @@ def sql(
Defaults to returning the dask dataframe.
dataframes (:obj:`Dict[str, dask.dataframe.DataFrame]`): additional Dask or pandas dataframes
to register before executing this query
gpu (:obj:`bool`): Whether or not to load the additional Dask or pandas dataframes (if any) on GPU;
requires cuDF / dask-cuDF if enabled. Defaults to False.

Returns:
:obj:`dask.dataframe.DataFrame`: the created data frame of this query.

"""
if dataframes is not None:
for df_name, df in dataframes.items():
self.create_table(df_name, df)
self.create_table(df_name, df, gpu=gpu)

rel, select_names, _ = self._get_ral(sql)

Expand All @@ -477,7 +480,10 @@ def sql(
return df

def explain(
self, sql: str, dataframes: Dict[str, Union[dd.DataFrame, pd.DataFrame]] = None
self,
sql: str,
dataframes: Dict[str, Union[dd.DataFrame, pd.DataFrame]] = None,
gpu: bool = False,
) -> str:
"""
Return the stringified relational algebra that this query will produce
Expand All @@ -492,14 +498,16 @@ def explain(
sql (:obj:`str`): The query string to use
dataframes (:obj:`Dict[str, dask.dataframe.DataFrame]`): additional Dask or pandas dataframes
to register before executing this query
gpu (:obj:`bool`): Whether or not to load the additional Dask or pandas dataframes (if any) on GPU;
requires cuDF / dask-cuDF if enabled. Defaults to False.

Returns:
:obj:`str`: a description of the created relational algebra.

"""
if dataframes is not None:
for df_name, df in dataframes.items():
self.create_table(df_name, df)
self.create_table(df_name, df, gpu=gpu)

_, _, rel_string = self._get_ral(sql)
return rel_string
Expand Down
11 changes: 4 additions & 7 deletions tests/unit/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,8 @@ def test_explain(gpu):

data_frame = dd.from_pandas(pd.DataFrame({"a": [1, 2, 3]}), npartitions=1)

if gpu:
data_frame = dask_cudf.from_dask_dataframe(data_frame)

sql_string = c.explain(
"SELECT * FROM other_df", dataframes={"other_df": data_frame}
"SELECT * FROM other_df", dataframes={"other_df": data_frame}, gpu=gpu
)

assert sql_string.startswith(
Expand All @@ -107,9 +104,9 @@ def test_sql(gpu):
assert isinstance(result, pd.DataFrame if not gpu else cudf.DataFrame)
dd.assert_eq(result, data_frame)

if gpu:
data_frame = dask_cudf.from_dask_dataframe(data_frame)
result = c.sql("SELECT * FROM other_df", dataframes={"other_df": data_frame})
result = c.sql(
"SELECT * FROM other_df", dataframes={"other_df": data_frame}, gpu=gpu
)
assert isinstance(result, dd.DataFrame if not gpu else dask_cudf.DataFrame)
dd.assert_eq(result, data_frame)

Expand Down