From e9ebdea49d24f645a6ca5ff6d79e0525a114f5fc Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Mon, 17 Jun 2024 12:29:54 +0100 Subject: [PATCH] Delete unused code from stringfunction evaluator (#16032) When introducing the handling of regex contains, we replicated the handlers for some other supported string functions. This means we can delete some code. Additionally, migrate the contains tests to live with the other string function tests, and add coverage of exceptional cases. Authors: - Lawrence Mitchell (https://github.com/wence-) Approvers: - https://github.com/brandon-b-miller URL: https://github.com/rapidsai/cudf/pull/16032 --- python/cudf_polars/cudf_polars/dsl/expr.py | 36 ++----- python/cudf_polars/tests/conftest.py | 10 ++ .../cudf_polars/tests/expressions/test_agg.py | 5 - .../tests/expressions/test_distinct.py | 9 +- .../tests/expressions/test_numeric_binops.py | 5 - .../tests/expressions/test_stringfunction.py | 97 ++++++++++++++++--- python/cudf_polars/tests/test_string.py | 61 ------------ 7 files changed, 102 insertions(+), 121 deletions(-) create mode 100644 python/cudf_polars/tests/conftest.py delete mode 100644 python/cudf_polars/tests/test_string.py diff --git a/python/cudf_polars/cudf_polars/dsl/expr.py b/python/cudf_polars/cudf_polars/dsl/expr.py index 03c1db68dbd..0605bba6642 100644 --- a/python/cudf_polars/cudf_polars/dsl/expr.py +++ b/python/cudf_polars/cudf_polars/dsl/expr.py @@ -688,13 +688,12 @@ def do_evaluate( else pat.obj ) return Column(plc.strings.find.contains(column.obj, pattern)) - else: - assert isinstance(arg, Literal) - prog = plc.strings.regex_program.RegexProgram.create( - arg.value.as_py(), - flags=plc.strings.regex_flags.RegexFlags.DEFAULT, - ) - return Column(plc.strings.contains.contains_re(column.obj, prog)) + assert isinstance(arg, Literal) + prog = plc.strings.regex_program.RegexProgram.create( + arg.value.as_py(), + flags=plc.strings.regex_flags.RegexFlags.DEFAULT, + ) + return Column(plc.strings.contains.contains_re(column.obj, prog)) columns = [ child.evaluate(df, context=context, mapping=mapping) for child in self.children @@ -725,26 +724,9 @@ def do_evaluate( else prefix.obj, ) ) - else: - columns = [ - child.evaluate(df, context=context, mapping=mapping) - for child in self.children - ] - if self.name == pl_expr.StringFunction.Lowercase: - (column,) = columns - return Column(plc.strings.case.to_lower(column.obj)) - elif self.name == pl_expr.StringFunction.Uppercase: - (column,) = columns - return Column(plc.strings.case.to_upper(column.obj)) - elif self.name == pl_expr.StringFunction.EndsWith: - column, suffix = columns - return Column(plc.strings.find.ends_with(column.obj, suffix.obj)) - elif self.name == pl_expr.StringFunction.StartsWith: - column, suffix = columns - return Column(plc.strings.find.starts_with(column.obj, suffix.obj)) - raise NotImplementedError( - f"StringFunction {self.name}" - ) # pragma: no cover; handled by init raising + raise NotImplementedError( + f"StringFunction {self.name}" + ) # pragma: no cover; handled by init raising class Sort(Expr): diff --git a/python/cudf_polars/tests/conftest.py b/python/cudf_polars/tests/conftest.py new file mode 100644 index 00000000000..9bbce6bc080 --- /dev/null +++ b/python/cudf_polars/tests/conftest.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import pytest + + +@pytest.fixture(params=[False, True], ids=["no_nulls", "nulls"], scope="session") +def with_nulls(request): + return request.param diff --git a/python/cudf_polars/tests/expressions/test_agg.py b/python/cudf_polars/tests/expressions/test_agg.py index 79018c80bf3..b044bbb2885 100644 --- a/python/cudf_polars/tests/expressions/test_agg.py +++ b/python/cudf_polars/tests/expressions/test_agg.py @@ -20,11 +20,6 @@ def dtype(request): return request.param -@pytest.fixture(params=[False, True], ids=["no-nulls", "with-nulls"]) -def with_nulls(request): - return request.param - - @pytest.fixture( params=[ False, diff --git a/python/cudf_polars/tests/expressions/test_distinct.py b/python/cudf_polars/tests/expressions/test_distinct.py index 22865a7ce22..143dd7e9f0f 100644 --- a/python/cudf_polars/tests/expressions/test_distinct.py +++ b/python/cudf_polars/tests/expressions/test_distinct.py @@ -9,11 +9,6 @@ from cudf_polars.testing.asserts import assert_gpu_result_equal -@pytest.fixture(params=[False, True], ids=["no-nulls", "nulls"]) -def nullable(request): - return request.param - - @pytest.fixture( params=["is_first_distinct", "is_last_distinct", "is_unique", "is_duplicated"] ) @@ -22,9 +17,9 @@ def op(request): @pytest.fixture -def df(nullable): +def df(with_nulls): values: list[int | None] = [1, 2, 3, 1, 1, 7, 3, 2, 7, 8, 1] - if nullable: + if with_nulls: values[1] = None values[4] = None return pl.LazyFrame({"a": values}) diff --git a/python/cudf_polars/tests/expressions/test_numeric_binops.py b/python/cudf_polars/tests/expressions/test_numeric_binops.py index 548aebf0875..7eefc59d927 100644 --- a/python/cudf_polars/tests/expressions/test_numeric_binops.py +++ b/python/cudf_polars/tests/expressions/test_numeric_binops.py @@ -29,11 +29,6 @@ def rtype(request): return request.param -@pytest.fixture(params=[False, True], ids=["no_nulls", "nulls"]) -def with_nulls(request): - return request.param - - @pytest.fixture( params=[ pl.Expr.eq, diff --git a/python/cudf_polars/tests/expressions/test_stringfunction.py b/python/cudf_polars/tests/expressions/test_stringfunction.py index 198f35d376b..3c498fe7286 100644 --- a/python/cudf_polars/tests/expressions/test_stringfunction.py +++ b/python/cudf_polars/tests/expressions/test_stringfunction.py @@ -2,22 +2,39 @@ # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations +from functools import partial + import pytest import polars as pl -from cudf_polars import translate_ir +from cudf_polars import execute_with_cudf, translate_ir from cudf_polars.testing.asserts import assert_gpu_result_equal -def test_supported_stringfunction_expression(): - ldf = pl.LazyFrame( - { - "a": ["a", "b", "cdefg", "h", "Wıth ünιcοde"], # noqa: RUF001 - "b": [0, 3, 1, -1, None], - } - ) +@pytest.fixture +def ldf(with_nulls): + a = [ + "AbC", + "de", + "FGHI", + "j", + "kLm", + "nOPq", + "", + "RsT", + "sada", + "uVw", + "h", + "Wıth ünιcοde", # noqa: RUF001 + ] + if with_nulls: + a[4] = None + a[-3] = None + return pl.LazyFrame({"a": a, "b": range(len(a))}) + +def test_supported_stringfunction_expression(ldf): query = ldf.select( pl.col("a").str.starts_with("Z"), pl.col("a").str.ends_with("h").alias("endswith_h"), @@ -27,15 +44,63 @@ def test_supported_stringfunction_expression(): assert_gpu_result_equal(query) -def test_unsupported_stringfunction(): - ldf = pl.LazyFrame( - { - "a": ["a", "b", "cdefg", "h", "Wıth ünιcοde"], # noqa: RUF001 - "b": [0, 3, 1, -1, None], - } - ) - +def test_unsupported_stringfunction(ldf): q = ldf.select(pl.col("a").str.count_matches("e", literal=True)) with pytest.raises(NotImplementedError): _ = translate_ir(q._ldf.visit()) + + +def test_contains_re_non_strict_raises(ldf): + q = ldf.select(pl.col("a").str.contains(".", strict=False)) + + with pytest.raises(NotImplementedError): + _ = translate_ir(q._ldf.visit()) + + +def test_contains_re_non_literal_raises(ldf): + q = ldf.select(pl.col("a").str.contains(pl.col("b"), literal=False)) + + with pytest.raises(NotImplementedError): + _ = translate_ir(q._ldf.visit()) + + +@pytest.mark.parametrize( + "substr", + [ + "A", + "de", + ".*", + "^a", + "^A", + "[^a-z]", + "[a-z]{3,}", + "^[A-Z]{2,}", + "j|u", + ], +) +def test_contains_regex(ldf, substr): + query = ldf.select(pl.col("a").str.contains(substr)) + assert_gpu_result_equal(query) + + +@pytest.mark.parametrize( + "literal", ["A", "de", "FGHI", "j", "kLm", "nOPq", "RsT", "uVw"] +) +def test_contains_literal(ldf, literal): + query = ldf.select(pl.col("a").str.contains(pl.lit(literal), literal=True)) + assert_gpu_result_equal(query) + + +def test_contains_column(ldf): + query = ldf.select(pl.col("a").str.contains(pl.col("a"), literal=True)) + assert_gpu_result_equal(query) + + +def test_contains_invalid(ldf): + query = ldf.select(pl.col("a").str.contains("[")) + + with pytest.raises(pl.exceptions.ComputeError): + query.collect() + with pytest.raises(pl.exceptions.ComputeError): + query.collect(post_opt_callback=partial(execute_with_cudf, raise_on_fail=True)) diff --git a/python/cudf_polars/tests/test_string.py b/python/cudf_polars/tests/test_string.py deleted file mode 100644 index f1a080d040f..00000000000 --- a/python/cudf_polars/tests/test_string.py +++ /dev/null @@ -1,61 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-License-Identifier: Apache-2.0 -from __future__ import annotations - -from functools import partial - -import pytest - -import polars as pl - -from cudf_polars.callback import execute_with_cudf -from cudf_polars.testing.asserts import assert_gpu_result_equal - - -@pytest.fixture -def ldf(): - return pl.DataFrame( - {"a": ["AbC", "de", "FGHI", "j", "kLm", "nOPq", None, "RsT", None, "uVw"]} - ).lazy() - - -@pytest.mark.parametrize( - "substr", - [ - "A", - "de", - ".*", - "^a", - "^A", - "[^a-z]", - "[a-z]{3,}", - "^[A-Z]{2,}", - "j|u", - ], -) -def test_contains_regex(ldf, substr): - query = ldf.select(pl.col("a").str.contains(substr)) - assert_gpu_result_equal(query) - - -@pytest.mark.parametrize( - "literal", ["A", "de", "FGHI", "j", "kLm", "nOPq", "RsT", "uVw"] -) -def test_contains_literal(ldf, literal): - query = ldf.select(pl.col("a").str.contains(pl.lit(literal), literal=True)) - assert_gpu_result_equal(query) - - -def test_contains_column(ldf): - query = ldf.select(pl.col("a").str.contains(pl.col("a"), literal=True)) - assert_gpu_result_equal(query) - - -@pytest.mark.parametrize("pat", ["["]) -def test_contains_invalid(ldf, pat): - query = ldf.select(pl.col("a").str.contains(pat)) - - with pytest.raises(pl.exceptions.ComputeError): - query.collect() - with pytest.raises(pl.exceptions.ComputeError): - query.collect(post_opt_callback=partial(execute_with_cudf, raise_on_fail=True))