Skip to content

Commit

Permalink
Simplify logic for broadcasting
Browse files Browse the repository at this point in the history
  • Loading branch information
wence- committed Jun 10, 2024
1 parent 7cbe60a commit 5709719
Showing 1 changed file with 9 additions and 15 deletions.
24 changes: 9 additions & 15 deletions python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from cudf_polars.utils import sorting

if TYPE_CHECKING:
from collections.abc import MutableMapping
from collections.abc import MutableMapping, Set
from typing import Literal

from cudf_polars.typing import Schema
Expand Down Expand Up @@ -96,31 +96,25 @@ def broadcast(
``target_length`` is provided and not all columns are length-1
(i.e. ``n != 1``), then ``target_length`` must be equal to ``n``.
"""
lengths = {column.obj.size() for column in columns}
if len(lengths - {1}) > 1:
raise RuntimeError("Mismatching column lengths")
lengths: Set[int] = {column.obj.size() for column in columns}
if lengths == {1}:
if target_length is None:
return list(columns)
nrows = target_length
elif len(lengths) == 1:
if target_length is not None and target_length not in lengths:
raise RuntimeError(
"Cannot broadcast columns of length "
f"{lengths.pop()} to {target_length=}"
)
return list(columns)
else:
(nrows,) = lengths - {1}
if target_length is not None and target_length != nrows:
try:
(nrows,) = lengths - {1}
except ValueError as e:
raise RuntimeError("Mismatching column lengths") from e
if target_length is not None and nrows != target_length:
raise RuntimeError(
f"Cannot broadcast columns of length {nrows} to {target_length=}"
f"Cannot broadcast columns of length {nrows=} to {target_length=}"
)
return [
column
if column.obj.size() != 1
else NamedColumn(
plc.Column.from_scalar(plc.copying.get_element(column.obj, 0), nrows),
plc.Column.from_scalar(column.obj_scalar, nrows),
column.name,
is_sorted=plc.types.Sorted.YES,
order=plc.types.Order.ASCENDING,
Expand Down

0 comments on commit 5709719

Please sign in to comment.