Skip to content

Commit

Permalink
cudf-polars string slicing (#16082)
Browse files Browse the repository at this point in the history
This PR plumbs the libcudf/pylibcudf `slice_strings` function through to cudf-polars. Depends on #15988

Authors:
  - https://github.com/brandon-b-miller
  - GALI PREM SAGAR (https://github.com/galipremsagar)

Approvers:
  - Lawrence Mitchell (https://github.com/wence-)

URL: #16082
  • Loading branch information
brandon-b-miller authored Jul 3, 2024
1 parent 25febbc commit 3aedeea
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 0 deletions.
36 changes: 36 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,7 @@ def _validate_input(self):
pl_expr.StringFunction.EndsWith,
pl_expr.StringFunction.StartsWith,
pl_expr.StringFunction.Contains,
pl_expr.StringFunction.Slice,
):
raise NotImplementedError(f"String function {self.name}")
if self.name == pl_expr.StringFunction.Contains:
Expand All @@ -716,6 +717,11 @@ def _validate_input(self):
raise NotImplementedError(
"Regex contains only supports a scalar pattern"
)
elif self.name == pl_expr.StringFunction.Slice:
if not all(isinstance(child, Literal) for child in self.children[1:]):
raise NotImplementedError(
"Slice only supports literal start and stop values"
)

def do_evaluate(
self,
Expand Down Expand Up @@ -744,6 +750,36 @@ def do_evaluate(
flags=plc.strings.regex_flags.RegexFlags.DEFAULT,
)
return Column(plc.strings.contains.contains_re(column.obj, prog))
elif self.name == pl_expr.StringFunction.Slice:
child, expr_offset, expr_length = self.children
assert isinstance(expr_offset, Literal)
assert isinstance(expr_length, Literal)

column = child.evaluate(df, context=context, mapping=mapping)
# libcudf slices via [start,stop).
# polars slices with offset + length where start == offset
# stop = start + length. Negative values for start look backward
# from the last element of the string. If the end index would be
# below zero, an empty string is returned.
# Do this maths on the host
start = expr_offset.value.as_py()
length = expr_length.value.as_py()

if length == 0:
stop = start
else:
# No length indicates a scan to the end
# The libcudf equivalent is a null stop
stop = start + length if length else None
if length and start < 0 and length >= -start:
stop = None
return Column(
plc.strings.slice.slice_strings(
column.obj,
plc.interop.from_arrow(pa.scalar(start, type=pa.int32())),
plc.interop.from_arrow(pa.scalar(stop, type=pa.int32())),
)
)
columns = [
child.evaluate(df, context=context, mapping=mapping)
for child in self.children
Expand Down
46 changes: 46 additions & 0 deletions python/cudf_polars/tests/expressions/test_stringfunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,30 @@ def ldf(with_nulls):
return pl.LazyFrame({"a": a, "b": range(len(a))})


slice_cases = [
(1, 3),
(0, 3),
(0, 0),
(-3, 1),
(-100, 5),
(1, 1),
(100, 100),
(-3, 4),
(-3, 3),
]


@pytest.fixture(params=slice_cases)
def slice_column_data(ldf, request):
start, length = request.param
if length:
return ldf.with_columns(
pl.lit(start).alias("start"), pl.lit(length).alias("length")
)
else:
return ldf.with_columns(pl.lit(start).alias("start"))


def test_supported_stringfunction_expression(ldf):
query = ldf.select(
pl.col("a").str.starts_with("Z"),
Expand Down Expand Up @@ -104,3 +128,25 @@ def test_contains_invalid(ldf):
query.collect()
with pytest.raises(pl.exceptions.ComputeError):
query.collect(post_opt_callback=partial(execute_with_cudf, raise_on_fail=True))


@pytest.mark.parametrize("offset", [1, -1, 0, 100, -100])
def test_slice_scalars_offset(ldf, offset):
query = ldf.select(pl.col("a").str.slice(offset))
assert_gpu_result_equal(query)


@pytest.mark.parametrize("offset,length", slice_cases)
def test_slice_scalars_length_and_offset(ldf, offset, length):
query = ldf.select(pl.col("a").str.slice(offset, length))
assert_gpu_result_equal(query)


def test_slice_column(slice_column_data):
if "length" in slice_column_data.collect_schema():
query = slice_column_data.select(
pl.col("a").str.slice(pl.col("start"), pl.col("length"))
)
else:
query = slice_column_data.select(pl.col("a").str.slice(pl.col("start")))
assert_ir_translation_raises(query, NotImplementedError)

0 comments on commit 3aedeea

Please sign in to comment.