Skip to content

Commit

Permalink
Raise in query if dtype is not supported (#9921)
Browse files Browse the repository at this point in the history
Closes #9894

Authors:
  - https://github.com/brandon-b-miller

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: #9921
  • Loading branch information
brandon-b-miller authored Jan 11, 2022
1 parent 25a7485 commit 3216342
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 7 deletions.
23 changes: 23 additions & 0 deletions python/cudf/cudf/tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,3 +209,26 @@ def test_query_with_index_keyword(query, a_val, b_val, c_val):
expect = pdf.query(query)

assert_eq(out, expect)


@pytest.mark.parametrize(
    "data, query",
    [
        # Only need to test the dtypes that pandas
        # supports but that we do not
        (["a", "b", "c"], "data == 'a'"),
    ],
)
def test_query_unsupported_dtypes(data, query):
    """Check that ``query`` raises ``TypeError`` on an unsupported dtype.

    The same query is first run through pandas to show that the expression
    itself is valid, so the failure in cuDF can only come from the dtype
    check.

    Parameters
    ----------
    data : list
        Values for the single ``"data"`` column of the frame under test.
    query : str
        Query expression evaluated against that frame.
    """
    gdf = cudf.DataFrame({"data": data})

    # make sure the query works in pandas
    pdf = gdf.to_pandas()
    pdf_result = pdf.query(query)

    expect = pd.DataFrame({"data": ["a"]})
    assert_eq(expect, pdf_result)

    # but fails in cuDF; match on the message so that an unrelated
    # TypeError raised elsewhere in the query path cannot satisfy the test
    with pytest.raises(TypeError, match="query only supports"):
        gdf.query(query)
32 changes: 25 additions & 7 deletions python/cudf/cudf/utils/queryutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,20 @@
import cudf
from cudf.core.column import column_empty
from cudf.utils import applyutils
from cudf.utils.dtypes import (
BOOL_TYPES,
DATETIME_TYPES,
NUMERIC_TYPES,
TIMEDELTA_TYPES,
)

# NOTE(review): presumably the prefix given to names that refer to variables
# from the calling environment inside a compiled query -- confirm against
# the code that consumes it.
ENVREF_PREFIX = "__CUDF_ENVREF__"

# Column dtypes accepted by query_execute: the numeric, datetime, timedelta
# and bool type names, each normalised to a numpy dtype so membership can be
# tested directly against a column's ``dtype``.
SUPPORTED_QUERY_TYPES = set(
    map(
        np.dtype,
        NUMERIC_TYPES | DATETIME_TYPES | TIMEDELTA_TYPES | BOOL_TYPES,
    )
)


class QuerySyntaxError(ValueError):
    """``ValueError`` subclass for syntax problems in a query expression."""
Expand Down Expand Up @@ -197,6 +208,20 @@ def query_execute(df, expr, callenv):

# compile
compiled = query_compile(expr)
columns = compiled["colnames"]

# prepare col args
colarrays = [cudf.core.dataframe.extract_col(df, col) for col in columns]

# wait to check the types until we know which cols are used
if any(col.dtype not in SUPPORTED_QUERY_TYPES for col in colarrays):
raise TypeError(
"query only supports numeric, datetime, timedelta, "
"or bool dtypes."
)

colarrays = [col.data_array_view for col in colarrays]

kernel = compiled["kernel"]
# process env args
envargs = []
Expand All @@ -214,13 +239,6 @@ def query_execute(df, expr, callenv):
raise NameError(msg.format(name))
else:
envargs.append(val)
columns = compiled["colnames"]
# prepare col args

colarrays = [
cudf.core.dataframe.extract_col(df, col).data_array_view
for col in columns
]

# allocate output buffer
nrows = len(df)
Expand Down

0 comments on commit 3216342

Please sign in to comment.