From 246d017669cbeca3570106b4bb52a92f931ea2c1 Mon Sep 17 00:00:00 2001 From: brandon-b-miller <53796099+brandon-b-miller@users.noreply.github.com> Date: Thu, 13 Jun 2024 09:33:43 -0500 Subject: [PATCH] Plumb pylibcudf strings `contains_re` through cudf_polars (#15918) This PR adds cudf-polars code for evaluating the `StringFunction.Contains` expression node. Depends on https://github.com/rapidsai/cudf/pull/15880/ Authors: - https://github.com/brandon-b-miller - Lawrence Mitchell (https://github.com/wence-) Approvers: - Lawrence Mitchell (https://github.com/wence-) URL: https://github.com/rapidsai/cudf/pull/15918 --- python/cudf_polars/cudf_polars/dsl/expr.py | 51 ++++++++++++++++++ python/cudf_polars/tests/test_string.py | 61 ++++++++++++++++++++++ 2 files changed, 112 insertions(+) create mode 100644 python/cudf_polars/tests/test_string.py diff --git a/python/cudf_polars/cudf_polars/dsl/expr.py b/python/cudf_polars/cudf_polars/dsl/expr.py index 298ef5ab070..03c1db68dbd 100644 --- a/python/cudf_polars/cudf_polars/dsl/expr.py +++ b/python/cudf_polars/cudf_polars/dsl/expr.py @@ -644,13 +644,28 @@ def __init__( self.options = options self.name = name self.children = children + self._validate_input() + + def _validate_input(self): if self.name not in ( pl_expr.StringFunction.Lowercase, pl_expr.StringFunction.Uppercase, pl_expr.StringFunction.EndsWith, pl_expr.StringFunction.StartsWith, + pl_expr.StringFunction.Contains, ): raise NotImplementedError(f"String function {self.name}") + if self.name == pl_expr.StringFunction.Contains: + literal, strict = self.options + if not literal: + if not strict: + raise NotImplementedError( + "f{strict=} is not supported for regex contains" + ) + if not isinstance(self.children[1], Literal): + raise NotImplementedError( + "Regex contains only supports a scalar pattern" + ) def do_evaluate( self, @@ -660,6 +675,26 @@ def do_evaluate( mapping: Mapping[Expr, Column] | None = None, ) -> Column: """Evaluate this expression given a dataframe for context.""" + if self.name == pl_expr.StringFunction.Contains: + child, arg = self.children + column = child.evaluate(df, context=context, mapping=mapping) + + literal, _ = self.options + if literal: + pat = arg.evaluate(df, context=context, mapping=mapping) + pattern = ( + pat.obj_scalar + if pat.is_scalar and pat.obj.size() != column.obj.size() + else pat.obj + ) + return Column(plc.strings.find.contains(column.obj, pattern)) + else: + assert isinstance(arg, Literal) + prog = plc.strings.regex_program.RegexProgram.create( + arg.value.as_py(), + flags=plc.strings.regex_flags.RegexFlags.DEFAULT, + ) + return Column(plc.strings.contains.contains_re(column.obj, prog)) columns = [ child.evaluate(df, context=context, mapping=mapping) for child in self.children @@ -691,6 +726,22 @@ def do_evaluate( ) ) else: + columns = [ + child.evaluate(df, context=context, mapping=mapping) + for child in self.children + ] + if self.name == pl_expr.StringFunction.Lowercase: + (column,) = columns + return Column(plc.strings.case.to_lower(column.obj)) + elif self.name == pl_expr.StringFunction.Uppercase: + (column,) = columns + return Column(plc.strings.case.to_upper(column.obj)) + elif self.name == pl_expr.StringFunction.EndsWith: + column, suffix = columns + return Column(plc.strings.find.ends_with(column.obj, suffix.obj)) + elif self.name == pl_expr.StringFunction.StartsWith: + column, suffix = columns + return Column(plc.strings.find.starts_with(column.obj, suffix.obj)) raise NotImplementedError( f"StringFunction {self.name}" ) # pragma: no cover; handled by init raising diff --git a/python/cudf_polars/tests/test_string.py b/python/cudf_polars/tests/test_string.py new file mode 100644 index 00000000000..f1a080d040f --- /dev/null +++ b/python/cudf_polars/tests/test_string.py @@ -0,0 +1,61 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +from functools import partial + +import pytest + +import polars as pl + +from cudf_polars.callback import execute_with_cudf +from cudf_polars.testing.asserts import assert_gpu_result_equal + + +@pytest.fixture +def ldf(): + return pl.DataFrame( + {"a": ["AbC", "de", "FGHI", "j", "kLm", "nOPq", None, "RsT", None, "uVw"]} + ).lazy() + + +@pytest.mark.parametrize( + "substr", + [ + "A", + "de", + ".*", + "^a", + "^A", + "[^a-z]", + "[a-z]{3,}", + "^[A-Z]{2,}", + "j|u", + ], +) +def test_contains_regex(ldf, substr): + query = ldf.select(pl.col("a").str.contains(substr)) + assert_gpu_result_equal(query) + + +@pytest.mark.parametrize( + "literal", ["A", "de", "FGHI", "j", "kLm", "nOPq", "RsT", "uVw"] +) +def test_contains_literal(ldf, literal): + query = ldf.select(pl.col("a").str.contains(pl.lit(literal), literal=True)) + assert_gpu_result_equal(query) + + +def test_contains_column(ldf): + query = ldf.select(pl.col("a").str.contains(pl.col("a"), literal=True)) + assert_gpu_result_equal(query) + + +@pytest.mark.parametrize("pat", ["["]) +def test_contains_invalid(ldf, pat): + query = ldf.select(pl.col("a").str.contains(pat)) + + with pytest.raises(pl.exceptions.ComputeError): + query.collect() + with pytest.raises(pl.exceptions.ComputeError): + query.collect(post_opt_callback=partial(execute_with_cudf, raise_on_fail=True))