From 3216342f01d198cfbe2ef9e2ac861674414dc493 Mon Sep 17 00:00:00 2001
From: brandon-b-miller <53796099+brandon-b-miller@users.noreply.github.com>
Date: Tue, 11 Jan 2022 17:04:59 -0600
Subject: [PATCH] Raise in `query` if dtype is not supported (#9921)

Closes https://github.com/rapidsai/cudf/issues/9894

Authors:
  - https://github.com/brandon-b-miller

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: https://github.com/rapidsai/cudf/pull/9921
---
Reviewer notes (this region is ignored by `git am` / `git apply`):

* queryutils.py gains SUPPORTED_QUERY_TYPES, a set of np.dtype objects built
  from NUMERIC_TYPES, DATETIME_TYPES, TIMEDELTA_TYPES and BOOL_TYPES.
* query_execute now extracts the columns named by the compiled expression
  first and raises TypeError if any of them has a dtype outside that set.
  Only columns actually referenced by the expression are checked.
* The column extraction and dtype check move ahead of the env-arg loop, so
  an unsupported column dtype is reported before a missing environment
  reference (NameError) would be.
* The new test uses a string column: the query succeeds in pandas and must
  raise TypeError in cuDF.

 python/cudf/cudf/tests/test_query.py | 23 ++++++++++++++++++++
 python/cudf/cudf/utils/queryutils.py | 32 ++++++++++++++++++++++------
 2 files changed, 48 insertions(+), 7 deletions(-)

diff --git a/python/cudf/cudf/tests/test_query.py b/python/cudf/cudf/tests/test_query.py
index 9a02d5145bb..3de38b2cf6f 100644
--- a/python/cudf/cudf/tests/test_query.py
+++ b/python/cudf/cudf/tests/test_query.py
@@ -209,3 +209,26 @@ def test_query_with_index_keyword(query, a_val, b_val, c_val):
     expect = pdf.query(query)
 
     assert_eq(out, expect)
+
+
+@pytest.mark.parametrize(
+    "data, query",
+    [
+        # Only need to test the dtypes that pandas
+        # supports but that we do not
+        (["a", "b", "c"], "data == 'a'"),
+    ],
+)
+def test_query_unsupported_dtypes(data, query):
+    gdf = cudf.DataFrame({"data": data})
+
+    # make sure the query works in pandas
+    pdf = gdf.to_pandas()
+    pdf_result = pdf.query(query)
+
+    expect = pd.DataFrame({"data": ["a"]})
+    assert_eq(expect, pdf_result)
+
+    # but fails in cuDF
+    with pytest.raises(TypeError):
+        gdf.query(query)
diff --git a/python/cudf/cudf/utils/queryutils.py b/python/cudf/cudf/utils/queryutils.py
index 217466a5a1b..d9153c2b1d2 100644
--- a/python/cudf/cudf/utils/queryutils.py
+++ b/python/cudf/cudf/utils/queryutils.py
@@ -10,9 +10,20 @@
 import cudf
 from cudf.core.column import column_empty
 from cudf.utils import applyutils
+from cudf.utils.dtypes import (
+    BOOL_TYPES,
+    DATETIME_TYPES,
+    NUMERIC_TYPES,
+    TIMEDELTA_TYPES,
+)
 
 ENVREF_PREFIX = "__CUDF_ENVREF__"
 
+SUPPORTED_QUERY_TYPES = {
+    np.dtype(dt)
+    for dt in NUMERIC_TYPES | DATETIME_TYPES | TIMEDELTA_TYPES | BOOL_TYPES
+}
+
 
 class QuerySyntaxError(ValueError):
     pass
@@ -197,6 +208,20 @@ def query_execute(df, expr, callenv):
 
     # compile
     compiled = query_compile(expr)
+    columns = compiled["colnames"]
+
+    # prepare col args
+    colarrays = [cudf.core.dataframe.extract_col(df, col) for col in columns]
+
+    # wait to check the types until we know which cols are used
+    if any(col.dtype not in SUPPORTED_QUERY_TYPES for col in colarrays):
+        raise TypeError(
+            "query only supports numeric, datetime, timedelta, "
+            "or bool dtypes."
+        )
+
+    colarrays = [col.data_array_view for col in colarrays]
+
     kernel = compiled["kernel"]
     # process env args
     envargs = []
@@ -214,13 +239,6 @@ def query_execute(df, expr, callenv):
             raise NameError(msg.format(name))
         else:
             envargs.append(val)
-    columns = compiled["colnames"]
-    # prepare col args
-
-    colarrays = [
-        cudf.core.dataframe.extract_col(df, col).data_array_view
-        for col in columns
-    ]
 
     # allocate output buffer
     nrows = len(df)