diff --git a/dask_sql/context.py b/dask_sql/context.py index adce9eaf1..c3a66de6a 100644 --- a/dask_sql/context.py +++ b/dask_sql/context.py @@ -418,6 +418,7 @@ def sql( sql: str, return_futures: bool = True, dataframes: Dict[str, Union[dd.DataFrame, pd.DataFrame]] = None, + gpu: bool = False, ) -> Union[dd.DataFrame, pd.DataFrame]: """ Query the registered tables with the given SQL. @@ -443,6 +444,8 @@ def sql( Defaults to returning the dask dataframe. dataframes (:obj:`Dict[str, dask.dataframe.DataFrame]`): additional Dask or pandas dataframes to register before executing this query + gpu (:obj:`bool`): Whether or not to load the additional Dask or pandas dataframes (if any) on GPU; + requires cuDF / dask-cuDF if enabled. Defaults to False. Returns: :obj:`dask.dataframe.DataFrame`: the created data frame of this query. @@ -450,7 +453,7 @@ def sql( """ if dataframes is not None: for df_name, df in dataframes.items(): - self.create_table(df_name, df) + self.create_table(df_name, df, gpu=gpu) rel, select_names, _ = self._get_ral(sql) @@ -477,7 +480,10 @@ def sql( return df def explain( - self, sql: str, dataframes: Dict[str, Union[dd.DataFrame, pd.DataFrame]] = None + self, + sql: str, + dataframes: Dict[str, Union[dd.DataFrame, pd.DataFrame]] = None, + gpu: bool = False, ) -> str: """ Return the stringified relational algebra that this query will produce @@ -492,6 +498,8 @@ def explain( sql (:obj:`str`): The query string to use dataframes (:obj:`Dict[str, dask.dataframe.DataFrame]`): additional Dask or pandas dataframes to register before executing this query + gpu (:obj:`bool`): Whether or not to load the additional Dask or pandas dataframes (if any) on GPU; + requires cuDF / dask-cuDF if enabled. Defaults to False. Returns: :obj:`str`: a description of the created relational algebra. @@ -499,7 +507,7 @@ def explain( """ if dataframes is not None: for df_name, df in dataframes.items(): - self.create_table(df_name, df) + self.create_table(df_name, df, gpu=gpu) _, _, rel_string = self._get_ral(sql) return rel_string diff --git a/tests/unit/test_context.py b/tests/unit/test_context.py index a9b0d3fe6..b84f9fa11 100644 --- a/tests/unit/test_context.py +++ b/tests/unit/test_context.py @@ -78,11 +78,8 @@ def test_explain(gpu): data_frame = dd.from_pandas(pd.DataFrame({"a": [1, 2, 3]}), npartitions=1) - if gpu: - data_frame = dask_cudf.from_dask_dataframe(data_frame) - sql_string = c.explain( - "SELECT * FROM other_df", dataframes={"other_df": data_frame} + "SELECT * FROM other_df", dataframes={"other_df": data_frame}, gpu=gpu ) assert sql_string.startswith( @@ -107,9 +104,9 @@ def test_sql(gpu): assert isinstance(result, pd.DataFrame if not gpu else cudf.DataFrame) dd.assert_eq(result, data_frame) - if gpu: - data_frame = dask_cudf.from_dask_dataframe(data_frame) - result = c.sql("SELECT * FROM other_df", dataframes={"other_df": data_frame}) + result = c.sql( + "SELECT * FROM other_df", dataframes={"other_df": data_frame}, gpu=gpu + ) assert isinstance(result, dd.DataFrame if not gpu else dask_cudf.DataFrame) dd.assert_eq(result, data_frame)