diff --git a/python/cudf/cudf/_lib/expressions.pxd b/python/cudf/cudf/_lib/expressions.pxd index 85665822174..f93f815c3ec 100644 --- a/python/cudf/cudf/_lib/expressions.pxd +++ b/python/cudf/cudf/_lib/expressions.pxd @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. from libc.stdint cimport int32_t, int64_t from libcpp.memory cimport unique_ptr @@ -9,16 +9,7 @@ from cudf._lib.cpp.expressions cimport ( literal, operation, ) -from cudf._lib.cpp.scalar.scalar cimport numeric_scalar - -ctypedef enum scalar_type_t: - INT - DOUBLE - - -ctypedef union int_or_double_scalar_ptr: - unique_ptr[numeric_scalar[int64_t]] int_ptr - unique_ptr[numeric_scalar[double]] double_ptr +from cudf._lib.cpp.scalar.scalar cimport numeric_scalar, scalar, string_scalar cdef class Expression: @@ -26,8 +17,7 @@ cdef class Expression: cdef class Literal(Expression): - cdef scalar_type_t c_scalar_type - cdef int_or_double_scalar_ptr c_scalar + cdef unique_ptr[scalar] c_scalar cdef class ColumnReference(Expression): diff --git a/python/cudf/cudf/_lib/expressions.pyx b/python/cudf/cudf/_lib/expressions.pyx index 269318240b2..c97aa9e75ee 100644 --- a/python/cudf/cudf/_lib/expressions.pyx +++ b/python/cudf/cudf/_lib/expressions.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. from enum import Enum @@ -77,27 +77,20 @@ class TableReference(Enum): # restrictive at the moment. cdef class Literal(Expression): def __cinit__(self, value): - # TODO: Would love to find a better solution than unions for literals. - cdef int intval - cdef double doubleval - if isinstance(value, int): - self.c_scalar_type = scalar_type_t.INT - intval = value - self.c_scalar.int_ptr = make_unique[numeric_scalar[int64_t]]( - intval, True - ) + self.c_scalar.reset(new numeric_scalar[int64_t](value, True)) self.c_obj = make_unique[libcudf_exp.literal]( - dereference(self.c_scalar.int_ptr) + dereference(self.c_scalar) ) elif isinstance(value, float): - self.c_scalar_type = scalar_type_t.DOUBLE - doubleval = value - self.c_scalar.double_ptr = make_unique[numeric_scalar[double]]( - doubleval, True + self.c_scalar.reset(new numeric_scalar[double](value, True)) + self.c_obj = make_unique[libcudf_exp.literal]( + dereference(self.c_scalar) ) + elif isinstance(value, str): + self.c_scalar.reset(new string_scalar(value.encode(), True)) self.c_obj = make_unique[libcudf_exp.literal]( - dereference(self.c_scalar.double_ptr) + dereference(self.c_scalar) ) diff --git a/python/cudf/cudf/core/_internals/expressions.py b/python/cudf/cudf/core/_internals/expressions.py index bc587d4e1e2..e3c58bd0c8d 100644 --- a/python/cudf/cudf/core/_internals/expressions.py +++ b/python/cudf/cudf/core/_internals/expressions.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. import ast import functools @@ -115,7 +115,7 @@ def visit_Name(self, node): self.stack.append(ColumnReference(col_id)) def visit_Constant(self, node): - if not isinstance(node, ast.Num): + if not isinstance(node, (ast.Num, ast.Str)): raise ValueError( f"Unsupported literal {repr(node.value)} of type " "{type(node.value).__name__}" diff --git a/python/cudf/cudf/core/dataframe.py b/python/cudf/cudf/core/dataframe.py index de324515729..8c8f0119b3f 100644 --- a/python/cudf/cudf/core/dataframe.py +++ b/python/cudf/cudf/core/dataframe.py @@ -7063,7 +7063,8 @@ def eval(self, expr: str, inplace: bool = False, **kwargs): Specifically, `&` must be used for bitwise operators on integers, not `and`, which is specifically for the logical and between booleans. - * Only numerical types are currently supported. + * Only numerical types currently support all operators. + * String types currently support comparison operators. * Operators generally will not cast automatically. Users are responsible for casting columns to suitable types before evaluating a function. diff --git a/python/cudf/cudf/tests/test_dataframe.py b/python/cudf/cudf/tests/test_dataframe.py index 918a10ffd75..b3c8468c119 100644 --- a/python/cudf/cudf/tests/test_dataframe.py +++ b/python/cudf/cudf/tests/test_dataframe.py @@ -9820,6 +9820,9 @@ def df_eval(request): float, ), ("a_b_are_equal = (a == b)", int), + ("a > b", str), + ("a < '1'", str), + ('a == "1"', str), ], ) def test_dataframe_eval(df_eval, expr, dtype):