Skip to content

Commit

Permalink
Remove mapfunction nodes that don't exist/aren't supported
Browse files Browse the repository at this point in the history
We can't correctly implement merge_sorted to match polars because
libcudf's implementation is not stable with respect to input order. drop_nulls
is no longer implemented as a MapFunction, but instead as a boolean filter.

Finally, add coverage of the mapfunctions we do handle.
  • Loading branch information
wence- committed Jun 12, 2024
1 parent f7ba6ab commit 02c233e
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 41 deletions.
55 changes: 14 additions & 41 deletions python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,13 +282,18 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
pdf = pl.DataFrame._from_pydf(self.df)
if self.projection is not None:
pdf = pdf.select(self.projection)
# TODO: goes away when libcudf supports large strings
table = pdf.to_arrow()
schema = table.schema
for i, field in enumerate(schema):
# TODO: Nested types
if field.type == pa.large_string():
# TODO: Nested types
# TODO: goes away when libcudf supports large strings
schema = schema.set(i, pa.field(field.name, pa.string()))
elif isinstance(field.type, pa.LargeListType):
# TODO: goes away when libcudf supports large lists
schema = schema.set(
i, pa.field(field.name, pa.list_(field.type.field(0)))
)
table = table.cast(schema)
df = DataFrame.from_table(
plc.interop.from_arrow(table), list(self.schema.keys())
Expand Down Expand Up @@ -846,9 +851,10 @@ class MapFunction(IR):

_NAMES: ClassVar[frozenset[str]] = frozenset(
[
"drop_nulls",
"rechunk",
"merge_sorted",
# libcudf merge is not stable wrt order of inputs, since
# it uses a priority queue to manage the tables it produces.
# "merge_sorted",
"rename",
"explode",
]
Expand All @@ -865,46 +871,13 @@ def __post_init__(self) -> None:
# polars requires that all to-explode columns have the
# same sub-shapes
raise NotImplementedError("Explode with more than one column")
elif self.name == "merge_sorted":
assert isinstance(self.df, Union)
(key_column,) = self.options
if key_column not in self.df.dfs[0].schema:
raise ValueError(f"Key column {key_column} not found")

def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
"""Evaluate and return a dataframe."""
if self.name == "merge_sorted":
# merge_sorted operates on Union inputs
# but if we evaluate the Union then we can't unpick the
# pieces, so we dive inside and evaluate the pieces by hand
assert isinstance(self.df, Union)
first, *rest = (c.evaluate(cache=cache) for c in self.df.dfs)
(key_column,) = self.options
if not all(first.column_names == r.column_names for r in rest):
raise ValueError("DataFrame shapes/column names don't match")
# Already validated that key_column is in column names
index = first.column_names.index(key_column)
return DataFrame.from_table(
plc.merge.merge_sorted(
[first.table, *(df.table for df in rest)],
[index],
[plc.types.Order.ASCENDING],
[plc.types.NullOrder.BEFORE],
),
first.column_names,
).sorted_like(first, subset={key_column})
elif self.name == "rechunk":
if self.name == "rechunk":
# No-op in our data model
return self.df.evaluate(cache=cache)
elif self.name == "drop_nulls":
df = self.df.evaluate(cache=cache)
(subset,) = self.options
subset = set(subset)
indices = [i for i, name in enumerate(df.column_names) if name in subset]
return DataFrame.from_table(
plc.stream_compaction.drop_nulls(df.table, indices, len(indices)),
df.column_names,
).sorted_like(df)
# Don't think this appears in a plan tree from python
return self.df.evaluate(cache=cache) # pragma: no cover
elif self.name == "rename":
df = self.df.evaluate(cache=cache)
# final tag is "swapping" which is useful for the
Expand All @@ -920,7 +893,7 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
plc.lists.explode_outer(df.table, index), df.column_names
).sorted_like(df, subset=subset)
else:
raise AssertionError("Should never be reached")
raise AssertionError("Should never be reached") # pragma: no cover


@dataclasses.dataclass(slots=True)
Expand Down
43 changes: 43 additions & 0 deletions python/cudf_polars/tests/test_mapfunction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import pytest

import polars as pl

from cudf_polars import translate_ir
from cudf_polars.testing.asserts import assert_gpu_result_equal


def test_merge_sorted_raises():
    """Translating a chained ``merge_sorted`` query raises NotImplementedError."""
    first = pl.LazyFrame({"a": [1, 6, 9], "b": [1, -10, 4]})
    second = pl.LazyFrame({"a": [-1, 5, 11, 20], "b": [2, 7, -4, None]})
    third = pl.LazyFrame({"a": [-10, 20, 21], "b": [1, 2, 3]})

    # Merge the three frames pairwise on the shared key column "a".
    merged = first.merge_sorted(second, key="a")
    query = merged.merge_sorted(third, key="a")

    with pytest.raises(NotImplementedError):
        translate_ir(query._ldf.visit())


def test_explode_multiple_raises():
    """Translating an explode over two columns raises NotImplementedError."""
    frame = pl.LazyFrame(
        {
            "a": [[1, 2], [3, 4]],
            "b": [[5, 6], [7, 8]],
        }
    )
    query = frame.explode("a", "b")

    with pytest.raises(NotImplementedError):
        translate_ir(query._ldf.visit())


@pytest.mark.parametrize("column", ["a", "b"])
def test_explode_single(column):
    """Check that exploding a single list column passes the GPU result comparison."""
    # "a" contains a null list entry; "c" is a non-list column with a null.
    data = {
        "a": [[1, 2], [3, 4], None],
        "b": [[5, 6], [7, 8], [9, 10]],
        "c": [None, 11, 12],
    }
    query = pl.LazyFrame(data).explode(column)

    assert_gpu_result_equal(query)

0 comments on commit 02c233e

Please sign in to comment.