Skip to content

Commit

Permalink
Merge branch 'branch-24.08' into 16046
Browse files Browse the repository at this point in the history
  • Loading branch information
galipremsagar authored Jun 17, 2024
2 parents 7ef3187 + e9ebdea commit 16b0062
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 121 deletions.
36 changes: 9 additions & 27 deletions python/cudf_polars/cudf_polars/dsl/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,13 +688,12 @@ def do_evaluate(
else pat.obj
)
return Column(plc.strings.find.contains(column.obj, pattern))
else:
assert isinstance(arg, Literal)
prog = plc.strings.regex_program.RegexProgram.create(
arg.value.as_py(),
flags=plc.strings.regex_flags.RegexFlags.DEFAULT,
)
return Column(plc.strings.contains.contains_re(column.obj, prog))
assert isinstance(arg, Literal)
prog = plc.strings.regex_program.RegexProgram.create(
arg.value.as_py(),
flags=plc.strings.regex_flags.RegexFlags.DEFAULT,
)
return Column(plc.strings.contains.contains_re(column.obj, prog))
columns = [
child.evaluate(df, context=context, mapping=mapping)
for child in self.children
Expand Down Expand Up @@ -725,26 +724,9 @@ def do_evaluate(
else prefix.obj,
)
)
else:
columns = [
child.evaluate(df, context=context, mapping=mapping)
for child in self.children
]
if self.name == pl_expr.StringFunction.Lowercase:
(column,) = columns
return Column(plc.strings.case.to_lower(column.obj))
elif self.name == pl_expr.StringFunction.Uppercase:
(column,) = columns
return Column(plc.strings.case.to_upper(column.obj))
elif self.name == pl_expr.StringFunction.EndsWith:
column, suffix = columns
return Column(plc.strings.find.ends_with(column.obj, suffix.obj))
elif self.name == pl_expr.StringFunction.StartsWith:
column, suffix = columns
return Column(plc.strings.find.starts_with(column.obj, suffix.obj))
raise NotImplementedError(
f"StringFunction {self.name}"
) # pragma: no cover; handled by init raising
raise NotImplementedError(
f"StringFunction {self.name}"
) # pragma: no cover; handled by init raising


class Sort(Expr):
Expand Down
10 changes: 10 additions & 0 deletions python/cudf_polars/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import pytest


@pytest.fixture(params=[False, True], ids=["no_nulls", "nulls"], scope="session")
def with_nulls(request):
return request.param
5 changes: 0 additions & 5 deletions python/cudf_polars/tests/expressions/test_agg.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,6 @@ def dtype(request):
return request.param


@pytest.fixture(params=[False, True], ids=["no-nulls", "with-nulls"])
def with_nulls(request):
return request.param


@pytest.fixture(
params=[
False,
Expand Down
9 changes: 2 additions & 7 deletions python/cudf_polars/tests/expressions/test_distinct.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,6 @@
from cudf_polars.testing.asserts import assert_gpu_result_equal


@pytest.fixture(params=[False, True], ids=["no-nulls", "nulls"])
def nullable(request):
return request.param


@pytest.fixture(
params=["is_first_distinct", "is_last_distinct", "is_unique", "is_duplicated"]
)
Expand All @@ -22,9 +17,9 @@ def op(request):


@pytest.fixture
def df(nullable):
def df(with_nulls):
values: list[int | None] = [1, 2, 3, 1, 1, 7, 3, 2, 7, 8, 1]
if nullable:
if with_nulls:
values[1] = None
values[4] = None
return pl.LazyFrame({"a": values})
Expand Down
5 changes: 0 additions & 5 deletions python/cudf_polars/tests/expressions/test_numeric_binops.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,6 @@ def rtype(request):
return request.param


@pytest.fixture(params=[False, True], ids=["no_nulls", "nulls"])
def with_nulls(request):
return request.param


@pytest.fixture(
params=[
pl.Expr.eq,
Expand Down
97 changes: 81 additions & 16 deletions python/cudf_polars/tests/expressions/test_stringfunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,39 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

from functools import partial

import pytest

import polars as pl

from cudf_polars import translate_ir
from cudf_polars import execute_with_cudf, translate_ir
from cudf_polars.testing.asserts import assert_gpu_result_equal


def test_supported_stringfunction_expression():
ldf = pl.LazyFrame(
{
"a": ["a", "b", "cdefg", "h", "Wıth ünιcοde"], # noqa: RUF001
"b": [0, 3, 1, -1, None],
}
)
@pytest.fixture
def ldf(with_nulls):
a = [
"AbC",
"de",
"FGHI",
"j",
"kLm",
"nOPq",
"",
"RsT",
"sada",
"uVw",
"h",
"Wıth ünιcοde", # noqa: RUF001
]
if with_nulls:
a[4] = None
a[-3] = None
return pl.LazyFrame({"a": a, "b": range(len(a))})


def test_supported_stringfunction_expression(ldf):
query = ldf.select(
pl.col("a").str.starts_with("Z"),
pl.col("a").str.ends_with("h").alias("endswith_h"),
Expand All @@ -27,15 +44,63 @@ def test_supported_stringfunction_expression():
assert_gpu_result_equal(query)


def test_unsupported_stringfunction():
ldf = pl.LazyFrame(
{
"a": ["a", "b", "cdefg", "h", "Wıth ünιcοde"], # noqa: RUF001
"b": [0, 3, 1, -1, None],
}
)

def test_unsupported_stringfunction(ldf):
q = ldf.select(pl.col("a").str.count_matches("e", literal=True))

with pytest.raises(NotImplementedError):
_ = translate_ir(q._ldf.visit())


def test_contains_re_non_strict_raises(ldf):
q = ldf.select(pl.col("a").str.contains(".", strict=False))

with pytest.raises(NotImplementedError):
_ = translate_ir(q._ldf.visit())


def test_contains_re_non_literal_raises(ldf):
q = ldf.select(pl.col("a").str.contains(pl.col("b"), literal=False))

with pytest.raises(NotImplementedError):
_ = translate_ir(q._ldf.visit())


@pytest.mark.parametrize(
"substr",
[
"A",
"de",
".*",
"^a",
"^A",
"[^a-z]",
"[a-z]{3,}",
"^[A-Z]{2,}",
"j|u",
],
)
def test_contains_regex(ldf, substr):
query = ldf.select(pl.col("a").str.contains(substr))
assert_gpu_result_equal(query)


@pytest.mark.parametrize(
"literal", ["A", "de", "FGHI", "j", "kLm", "nOPq", "RsT", "uVw"]
)
def test_contains_literal(ldf, literal):
query = ldf.select(pl.col("a").str.contains(pl.lit(literal), literal=True))
assert_gpu_result_equal(query)


def test_contains_column(ldf):
query = ldf.select(pl.col("a").str.contains(pl.col("a"), literal=True))
assert_gpu_result_equal(query)


def test_contains_invalid(ldf):
query = ldf.select(pl.col("a").str.contains("["))

with pytest.raises(pl.exceptions.ComputeError):
query.collect()
with pytest.raises(pl.exceptions.ComputeError):
query.collect(post_opt_callback=partial(execute_with_cudf, raise_on_fail=True))
61 changes: 0 additions & 61 deletions python/cudf_polars/tests/test_string.py

This file was deleted.

0 comments on commit 16b0062

Please sign in to comment.