Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Plumb pylibcudf strings contains_re through cudf_polars #15918

Merged
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
05d1eb4
initial
brandon-b-miller May 28, 2024
004ae1d
Merge branch 'branch-24.08' into pylibcudf-strings-contains
brandon-b-miller May 28, 2024
9a3f19d
merge/resolve
brandon-b-miller May 29, 2024
87341d4
one test
brandon-b-miller May 29, 2024
6d50191
tests, fixes
brandon-b-miller May 29, 2024
e35cd9a
declaration
brandon-b-miller May 29, 2024
69ad703
Merge branch 'branch-24.08' into pylibcudf-strings-contains
brandon-b-miller May 30, 2024
83178c9
docs, style
brandon-b-miller May 31, 2024
758755c
type create more strongly
brandon-b-miller May 31, 2024
98aeefa
add more tests
brandon-b-miller May 31, 2024
936e412
style
brandon-b-miller May 31, 2024
b15588a
regex program tests
brandon-b-miller May 31, 2024
b5a68c5
Merge branch 'branch-24.08' into pylibcudf-strings-contains
brandon-b-miller Jun 3, 2024
4b6a393
polars contains_re plumbing
brandon-b-miller Jun 4, 2024
6c125cb
refactor expr
brandon-b-miller Jun 5, 2024
9fb3a2b
add tests for invalid regex
brandon-b-miller Jun 5, 2024
0463688
merge latest/resolve conflicts
brandon-b-miller Jun 6, 2024
7543726
cleanup
brandon-b-miller Jun 6, 2024
42b158f
Address reviews
brandon-b-miller Jun 6, 2024
39b57ca
merge latest/resolve conflicts
brandon-b-miller Jun 10, 2024
e3fb170
refactor logic
brandon-b-miller Jun 10, 2024
da08309
merge latest/resolve
brandon-b-miller Jun 12, 2024
e45fbed
add literal column tests, support it, refactor logic
brandon-b-miller Jun 12, 2024
22e1031
add tests, refactor
brandon-b-miller Jun 12, 2024
4b643a7
pacify mypy
brandon-b-miller Jun 12, 2024
5533e5b
Make type-narrowing a no-op if run with `-O`
wence- Jun 13, 2024
ee42757
Merge branch 'branch-24.08' into cudf-polars-str-contains
wence- Jun 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 55 additions & 1 deletion python/cudf_polars/cudf_polars/dsl/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,13 +642,28 @@ def __init__(
self.options = options
self.name = name
self.children = children
self._validate_input()

def _validate_input(self):
if self.name not in (
pl_expr.StringFunction.Lowercase,
pl_expr.StringFunction.Uppercase,
pl_expr.StringFunction.EndsWith,
pl_expr.StringFunction.StartsWith,
pl_expr.StringFunction.Contains,
):
raise NotImplementedError(f"String function {self.name}")
if self.name == pl_expr.StringFunction.Contains:
literal, strict = self.options
if not literal:
if not strict:
raise NotImplementedError(
"f{strict=} is not supported for regex contains"
)
if not isinstance(self.children[1], Literal):
raise NotImplementedError(
"Regex contains only supports a scalar pattern"
)

def do_evaluate(
self,
Expand All @@ -658,6 +673,28 @@ def do_evaluate(
mapping: Mapping[Expr, Column] | None = None,
) -> Column:
"""Evaluate this expression given a dataframe for context."""
if self.name == pl_expr.StringFunction.Contains:
child, arg = self.children
column = child.evaluate(df, context=context, mapping=mapping)

literal, _ = self.options
if literal:
pat = arg.evaluate(df, context=context, mapping=mapping)
pattern = (
pat.obj_scalar
if pat.is_scalar and pat.obj.size() != column.obj.size()
else pat.obj
)
return Column(plc.strings.find.contains(column.obj, pattern))
else:
if isinstance(arg, Literal):
prog = plc.strings.regex_program.RegexProgram.create(
arg.value.as_py(),
flags=plc.strings.regex_flags.RegexFlags.DEFAULT,
)
return Column(plc.strings.contains.contains_re(column.obj, prog))
else:
raise NotImplementedError
wence- marked this conversation as resolved.
Show resolved Hide resolved
columns = [
child.evaluate(df, context=context, mapping=mapping)
for child in self.children
Expand Down Expand Up @@ -689,7 +726,24 @@ def do_evaluate(
)
)
else:
raise NotImplementedError(f"StringFunction {self.name}")
columns = [
child.evaluate(df, context=context, mapping=mapping)
for child in self.children
]
if self.name == pl_expr.StringFunction.Lowercase:
(column,) = columns
return Column(plc.strings.case.to_lower(column.obj))
elif self.name == pl_expr.StringFunction.Uppercase:
(column,) = columns
return Column(plc.strings.case.to_upper(column.obj))
elif self.name == pl_expr.StringFunction.EndsWith:
column, suffix = columns
return Column(plc.strings.find.ends_with(column.obj, suffix.obj))
elif self.name == pl_expr.StringFunction.StartsWith:
column, suffix = columns
return Column(plc.strings.find.starts_with(column.obj, suffix.obj))
else:
raise NotImplementedError(f"StringFunction {self.name}")


class Sort(Expr):
Expand Down
61 changes: 61 additions & 0 deletions python/cudf_polars/tests/test_string.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

from functools import partial

import pytest

import polars as pl

from cudf_polars.callback import execute_with_cudf
from cudf_polars.testing.asserts import assert_gpu_result_equal


@pytest.fixture
def ldf():
return pl.DataFrame(
{"a": ["AbC", "de", "FGHI", "j", "kLm", "nOPq", None, "RsT", None, "uVw"]}
).lazy()


@pytest.mark.parametrize(
"substr",
[
"A",
"de",
".*",
"^a",
"^A",
"[^a-z]",
"[a-z]{3,}",
"^[A-Z]{2,}",
"j|u",
],
brandon-b-miller marked this conversation as resolved.
Show resolved Hide resolved
)
def test_contains_regex(ldf, substr):
query = ldf.select(pl.col("a").str.contains(substr))
assert_gpu_result_equal(query)


@pytest.mark.parametrize(
"literal", ["A", "de", "FGHI", "j", "kLm", "nOPq", "RsT", "uVw"]
)
def test_contains_literal(ldf, literal):
query = ldf.select(pl.col("a").str.contains(pl.lit(literal), literal=True))
assert_gpu_result_equal(query)


def test_contains_column(ldf):
query = ldf.select(pl.col("a").str.contains(pl.col("a"), literal=True))
assert_gpu_result_equal(query)
wence- marked this conversation as resolved.
Show resolved Hide resolved


@pytest.mark.parametrize("pat", ["["])
def test_contains_invalid(ldf, pat):
query = ldf.select(pl.col("a").str.contains(pat))

with pytest.raises(pl.exceptions.ComputeError):
query.collect()
with pytest.raises(pl.exceptions.ComputeError):
query.collect(post_opt_callback=partial(execute_with_cudf, raise_on_fail=True))
Loading