Skip to content

Commit

Permalink
Some small fixes in cudf-polars (#16191)
Browse files Browse the repository at this point in the history
These catch a few more edge cases.

Authors:
  - Lawrence Mitchell (https://github.com/wence-)

Approvers:
  - Matthew Roeschke (https://github.com/mroeschke)

URL: #16191
  • Loading branch information
wence- authored Jul 4, 2024
1 parent aa4033c commit 5f57bc9
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 3 deletions.
13 changes: 11 additions & 2 deletions python/cudf_polars/cudf_polars/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,12 @@ def _callback(
return ir.evaluate(cache={}).to_polars()


def execute_with_cudf(nt: NodeTraverser, *, raise_on_fail: bool = False) -> None:
def execute_with_cudf(
nt: NodeTraverser,
*,
raise_on_fail: bool = False,
exception: type[Exception] | tuple[type[Exception], ...] = Exception,
) -> None:
"""
A post optimization callback that attempts to execute the plan with cudf.
Expand All @@ -47,11 +52,15 @@ def execute_with_cudf(nt: NodeTraverser, *, raise_on_fail: bool = False) -> None
Should conversion raise an exception rather than continuing
without setting a callback.
exception
Optional exception, or tuple of exceptions, to catch during
translation. Defaults to ``Exception``.
The NodeTraverser is mutated if the libcudf executor can handle the plan.
"""
try:
with nvtx.annotate(message="ConvertIR", domain="cudf_polars"):
nt.set_udf(partial(_callback, translate_ir(nt)))
except NotImplementedError:
except exception:
if raise_on_fail:
raise
6 changes: 5 additions & 1 deletion python/cudf_polars/cudf_polars/containers/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from __future__ import annotations

import itertools
from functools import cached_property
from typing import TYPE_CHECKING, cast

Expand Down Expand Up @@ -160,7 +161,10 @@ def with_columns(self, columns: Sequence[NamedColumn]) -> Self:
-----
If column names overlap, newer names replace older ones.
"""
return type(self)([*self.columns, *columns])
columns = list(
{c.name: c for c in itertools.chain(self.columns, columns)}.values()
)
return type(self)(columns)

def discard_columns(self, names: Set[str]) -> Self:
"""Drop columns by name."""
Expand Down
2 changes: 2 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ def broadcast(
``target_length`` is provided and not all columns are length-1
(i.e. ``n != 1``), then ``target_length`` must be equal to ``n``.
"""
if len(columns) == 0:
return []
lengths: set[int] = {column.obj.size() for column in columns}
if lengths == {1}:
if target_length is None:
Expand Down
9 changes: 9 additions & 0 deletions python/cudf_polars/tests/test_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,12 @@ def test_concat_vertical():
q = pl.concat([ldf, ldf2], how="vertical")

assert_gpu_result_equal(q)


def test_concat_diagonal_empty():
    """Concatenate an empty LazyFrame with a non-empty one diagonally.

    The query is built with ``how="diagonal_relaxed"`` and handed to
    ``assert_gpu_result_equal`` with ``no_optimization=True`` in the
    collect kwargs.
    """
    empty = pl.LazyFrame()
    populated = pl.LazyFrame({"a": [1, 2]})
    query = pl.concat([empty, populated], how="diagonal_relaxed")
    assert_gpu_result_equal(query, collect_kwargs={"no_optimization": True})

0 comments on commit 5f57bc9

Please sign in to comment.