Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove mapfunction nodes that don't exist/aren't supported #15991

Merged
merged 4 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 15 additions & 41 deletions python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,13 +286,18 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
pdf = pl.DataFrame._from_pydf(self.df)
if self.projection is not None:
pdf = pdf.select(self.projection)
# TODO: goes away when libcudf supports large strings
table = pdf.to_arrow()
schema = table.schema
for i, field in enumerate(schema):
# TODO: Nested types
if field.type == pa.large_string():
# TODO: Nested types
# TODO: goes away when libcudf supports large strings
schema = schema.set(i, pa.field(field.name, pa.string()))
elif isinstance(field.type, pa.LargeListType):
# TODO: goes away when libcudf supports large lists
schema = schema.set(
i, pa.field(field.name, pa.list_(field.type.field(0)))
)
table = table.cast(schema)
df = DataFrame.from_table(
plc.interop.from_arrow(table), list(self.schema.keys())
Expand Down Expand Up @@ -850,9 +855,11 @@ class MapFunction(IR):

_NAMES: ClassVar[frozenset[str]] = frozenset(
[
"drop_nulls",
"rechunk",
"merge_sorted",
# libcudf merge is not stable wrt order of inputs, since
# it uses a priority queue to manage the tables it produces.
# See: https://github.com/rapidsai/cudf/issues/16010
# "merge_sorted",
"rename",
"explode",
]
Expand All @@ -869,46 +876,13 @@ def __post_init__(self) -> None:
# polars requires that all to-explode columns have the
# same sub-shapes
raise NotImplementedError("Explode with more than one column")
elif self.name == "merge_sorted":
assert isinstance(self.df, Union)
(key_column,) = self.options
if key_column not in self.df.dfs[0].schema:
raise ValueError(f"Key column {key_column} not found")

def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
"""Evaluate and return a dataframe."""
if self.name == "merge_sorted":
# merge_sorted operates on Union inputs
# but if we evaluate the Union then we can't unpick the
# pieces, so we dive inside and evaluate the pieces by hand
assert isinstance(self.df, Union)
first, *rest = (c.evaluate(cache=cache) for c in self.df.dfs)
(key_column,) = self.options
if not all(first.column_names == r.column_names for r in rest):
raise ValueError("DataFrame shapes/column names don't match")
# Already validated that key_column is in column names
index = first.column_names.index(key_column)
return DataFrame.from_table(
plc.merge.merge_sorted(
[first.table, *(df.table for df in rest)],
[index],
[plc.types.Order.ASCENDING],
[plc.types.NullOrder.BEFORE],
),
first.column_names,
).sorted_like(first, subset={key_column})
elif self.name == "rechunk":
if self.name == "rechunk":
# No-op in our data model
return self.df.evaluate(cache=cache)
elif self.name == "drop_nulls":
df = self.df.evaluate(cache=cache)
(subset,) = self.options
subset = set(subset)
indices = [i for i, name in enumerate(df.column_names) if name in subset]
return DataFrame.from_table(
plc.stream_compaction.drop_nulls(df.table, indices, len(indices)),
df.column_names,
).sorted_like(df)
# Don't think this appears in a plan tree from python
return self.df.evaluate(cache=cache) # pragma: no cover
elif self.name == "rename":
df = self.df.evaluate(cache=cache)
# final tag is "swapping" which is useful for the
Expand All @@ -924,7 +898,7 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
plc.lists.explode_outer(df.table, index), df.column_names
).sorted_like(df, subset=subset)
else:
raise AssertionError("Should never be reached")
raise AssertionError("Should never be reached") # pragma: no cover


@dataclasses.dataclass(slots=True)
Expand Down
43 changes: 43 additions & 0 deletions python/cudf_polars/tests/test_mapfunction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import pytest

import polars as pl

from cudf_polars import translate_ir
from cudf_polars.testing.asserts import assert_gpu_result_equal


def test_merge_sorted_raises():
df1 = pl.LazyFrame({"a": [1, 6, 9], "b": [1, -10, 4]})
df2 = pl.LazyFrame({"a": [-1, 5, 11, 20], "b": [2, 7, -4, None]})
df3 = pl.LazyFrame({"a": [-10, 20, 21], "b": [1, 2, 3]})

q = df1.merge_sorted(df2, key="a").merge_sorted(df3, key="a")

with pytest.raises(NotImplementedError):
_ = translate_ir(q._ldf.visit())


def test_explode_multiple_raises():
df = pl.LazyFrame({"a": [[1, 2], [3, 4]], "b": [[5, 6], [7, 8]]})
q = df.explode("a", "b")

with pytest.raises(NotImplementedError):
_ = translate_ir(q._ldf.visit())


@pytest.mark.parametrize("column", ["a", "b"])
def test_explode_single(column):
df = pl.LazyFrame(
{
"a": [[1, 2], [3, 4], None],
"b": [[5, 6], [7, 8], [9, 10]],
"c": [None, 11, 12],
}
)
q = df.explode(column)

assert_gpu_result_equal(q)
Loading