From 124ea6971d0795ae6c72e712d2f9e2439cc1a1b4 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Thu, 6 Jun 2024 09:38:54 +0000 Subject: [PATCH 01/13] Fix typo --- python/cudf_polars/cudf_polars/utils/sorting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/cudf_polars/cudf_polars/utils/sorting.py b/python/cudf_polars/cudf_polars/utils/sorting.py index d35459db20d..24fd449dd88 100644 --- a/python/cudf_polars/cudf_polars/utils/sorting.py +++ b/python/cudf_polars/cudf_polars/utils/sorting.py @@ -30,7 +30,7 @@ def sort_order( Returns ------- - tuple of column_order and null_precendence + tuple of column_order and null_precedence suitable for passing to sort routines """ # Mimicking polars broadcast handling of descending From a838c78dec0976578589b9848381ce9a8e273c1d Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Thu, 6 Jun 2024 14:13:51 +0000 Subject: [PATCH 02/13] No more name rewording --- python/cudf_polars/cudf_polars/dsl/expr.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/cudf_polars/cudf_polars/dsl/expr.py b/python/cudf_polars/cudf_polars/dsl/expr.py index 6d9435ce373..f9bd22f8b53 100644 --- a/python/cudf_polars/cudf_polars/dsl/expr.py +++ b/python/cudf_polars/cudf_polars/dsl/expr.py @@ -875,9 +875,6 @@ def __init__( self, dtype: plc.DataType, name: str, options: Any, value: Expr ) -> None: super().__init__(dtype) - # TODO: fix polars name - if name == "nunique": - name = "n_unique" self.name = name self.options = options self.children = (value,) From 32a78eaafba784076b6136b8a90d4ea76d86a39e Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Thu, 6 Jun 2024 14:37:13 +0000 Subject: [PATCH 03/13] Use new should_broadcast option for Select and HStack --- python/cudf_polars/cudf_polars/dsl/ir.py | 21 +++++++++++++++---- .../cudf_polars/cudf_polars/dsl/translate.py | 4 ++-- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py 
b/python/cudf_polars/cudf_polars/dsl/ir.py index 665bbe5be41..4a954ea89d5 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -279,12 +279,16 @@ class Select(IR): """Input dataframe.""" expr: list[expr.NamedExpr] """List of expressions to evaluate to form the new dataframe.""" + should_broadcast: bool + """Should columns be broadcast?""" def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" df = self.df.evaluate(cache=cache) # Handle any broadcasting - columns = broadcast(*(e.evaluate(df) for e in self.expr)) + columns = [e.evaluate(df) for e in self.expr] + if self.should_broadcast: + columns = broadcast(*columns) return DataFrame(columns) @@ -587,15 +591,24 @@ class HStack(IR): """Input dataframe.""" columns: list[expr.NamedExpr] """List of expressions to produce new columns.""" + should_broadcast: bool + """Should columns be broadcast?""" def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" df = self.df.evaluate(cache=cache) columns = [c.evaluate(df) for c in self.columns] - # TODO: a bit of a hack, should inherit the should_broadcast - # property of polars' ProjectionOptions on the hstack node. - if not any(e.name.startswith("__POLARS_CSER_0x") for e in self.columns): + if self.should_broadcast: columns = broadcast(*columns, target_length=df.num_rows) + else: + # Polars ensures this is true, but let's make sure nothing + # went wrong. In this case, the parent node is + # guaranteed to be a Select which will take care of making + # sure that everything is the same length. The result + # table that might have mismatching column lengths will + # never be turned into a pylibcudf Table with all columns + # by the Select, which is why this is safe.
+ assert all(e.name.startswith("__POLARS_CSER_0x") for e in self.columns) return df.with_columns(columns) diff --git a/python/cudf_polars/cudf_polars/dsl/translate.py b/python/cudf_polars/cudf_polars/dsl/translate.py index 38107023365..adde3b1a9dc 100644 --- a/python/cudf_polars/cudf_polars/dsl/translate.py +++ b/python/cudf_polars/cudf_polars/dsl/translate.py @@ -122,7 +122,7 @@ def _( with set_node(visitor, node.input): inp = translate_ir(visitor, n=None) exprs = [translate_named_expr(visitor, n=e) for e in node.expr] - return ir.Select(schema, inp, exprs) + return ir.Select(schema, inp, exprs, node.should_broadcast) @_translate_ir.register @@ -166,7 +166,7 @@ def _( with set_node(visitor, node.input): inp = translate_ir(visitor, n=None) exprs = [translate_named_expr(visitor, n=e) for e in node.exprs] - return ir.HStack(schema, inp, exprs) + return ir.HStack(schema, inp, exprs, node.should_broadcast) @_translate_ir.register From 9a87eff4b5dd3d41f687ff7ec3b07c417304cbb0 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Fri, 7 Jun 2024 11:37:26 +0000 Subject: [PATCH 04/13] Expose __version__ in top-level __init__.py --- python/cudf_polars/cudf_polars/__init__.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/cudf_polars/cudf_polars/__init__.py b/python/cudf_polars/cudf_polars/__init__.py index b19a282129a..41d06f8631b 100644 --- a/python/cudf_polars/cudf_polars/__init__.py +++ b/python/cudf_polars/cudf_polars/__init__.py @@ -10,7 +10,13 @@ from __future__ import annotations +from cudf_polars._version import __git_commit__, __version__ from cudf_polars.callback import execute_with_cudf from cudf_polars.dsl.translate import translate_ir -__all__: list[str] = ["execute_with_cudf", "translate_ir"] +__all__: list[str] = [ + "execute_with_cudf", + "translate_ir", + "__git_commit__", + "__version__", +] From c23a0ace13f0cb3d90067daefbc32d67e8e2c1ca Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Fri, 7 Jun 2024 11:35:42 +0000 
Subject: [PATCH 05/13] No need for setuptools.packages.find config --- python/cudf_polars/pyproject.toml | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/cudf_polars/pyproject.toml b/python/cudf_polars/pyproject.toml index 2faf8c3193f..11178a3be74 100644 --- a/python/cudf_polars/pyproject.toml +++ b/python/cudf_polars/pyproject.toml @@ -49,9 +49,6 @@ license-files = ["LICENSE"] [tool.setuptools.dynamic] version = {file = "cudf_polars/VERSION"} -[tool.setuptools.packages.find] -exclude = ["*tests*"] - [tool.pytest.ini_options] xfail_strict = true From 783094bf03d107cbbf18b75d75acedcca8e04059 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Fri, 7 Jun 2024 10:08:19 +0000 Subject: [PATCH 06/13] Better docstring in NamedExpr.evaluate --- python/cudf_polars/cudf_polars/dsl/expr.py | 24 ++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/python/cudf_polars/cudf_polars/dsl/expr.py b/python/cudf_polars/cudf_polars/dsl/expr.py index f9bd22f8b53..ac20e36b151 100644 --- a/python/cudf_polars/cudf_polars/dsl/expr.py +++ b/python/cudf_polars/cudf_polars/dsl/expr.py @@ -222,7 +222,7 @@ def evaluate( Notes ----- - Individual subclasses should implement :meth:`do_allocate`, + Individual subclasses should implement :meth:`do_evaluate`, this method provides logic to handle lookups in the substitution mapping. @@ -319,7 +319,27 @@ def evaluate( context: ExecutionContext = ExecutionContext.FRAME, mapping: Mapping[Expr, Column] | None = None, ) -> NamedColumn: - """Evaluate this expression given a dataframe for context.""" + """ + Evaluate this expression given a dataframe for context. + + Parameters + ---------- + df + DataFrame providing context + context + Execution context + mapping + Substitution mapping + + Returns + ------- + NamedColumn attaching a name to an evaluated Column + + See Also + -------- + :meth:`Expr.evaluate` for details, this function just adds the + name to a column produced from an expression. 
+ """ obj = self.value.evaluate(df, context=context, mapping=mapping) if isinstance(obj, Scalar): return NamedColumn( From 71b5acc093088e513d3509bb45e294f1a7fceff5 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Wed, 5 Jun 2024 11:05:48 +0000 Subject: [PATCH 07/13] Cull usage of Scalar containers Now we always return columns and, where usage of a scalar might be correct (for example broadcasting in binops), we check if the column is "actually" a scalar and extract it. This is slightly annoying because we have to introspect things in various places. But without changing libcudf to treat length-1 columns as always broadcastable like scalars this is, I think, the best we can do. --- .../cudf_polars/containers/__init__.py | 3 +- .../cudf_polars/containers/column.py | 25 ++++++ .../cudf_polars/containers/dataframe.py | 6 +- .../cudf_polars/containers/scalar.py | 23 ----- python/cudf_polars/cudf_polars/dsl/expr.py | 87 ++++++++++--------- python/cudf_polars/cudf_polars/dsl/ir.py | 46 +++++++++- 6 files changed, 117 insertions(+), 73 deletions(-) delete mode 100644 python/cudf_polars/cudf_polars/containers/scalar.py diff --git a/python/cudf_polars/cudf_polars/containers/__init__.py b/python/cudf_polars/cudf_polars/containers/__init__.py index ee69e748eb5..06bb08953f1 100644 --- a/python/cudf_polars/cudf_polars/containers/__init__.py +++ b/python/cudf_polars/cudf_polars/containers/__init__.py @@ -5,8 +5,7 @@ from __future__ import annotations -__all__: list[str] = ["DataFrame", "Column", "NamedColumn", "Scalar"] +__all__: list[str] = ["DataFrame", "Column", "NamedColumn"] from cudf_polars.containers.column import Column, NamedColumn from cudf_polars.containers.dataframe import DataFrame -from cudf_polars.containers.scalar import Scalar diff --git a/python/cudf_polars/cudf_polars/containers/column.py b/python/cudf_polars/cudf_polars/containers/column.py index 575d15d3ece..10b0ed78b07 100644 --- a/python/cudf_polars/cudf_polars/containers/column.py +++ 
b/python/cudf_polars/cudf_polars/containers/column.py @@ -23,6 +23,7 @@ class Column: is_sorted: plc.types.Sorted order: plc.types.Order null_order: plc.types.NullOrder + is_scalar: bool def __init__( self, @@ -33,10 +34,32 @@ def __init__( null_order: plc.types.NullOrder = plc.types.NullOrder.BEFORE, ): self.obj = column + self.is_scalar = self.obj.size() == 1 + if self.is_scalar: + is_sorted = plc.types.Sorted.YES self.is_sorted = is_sorted self.order = order self.null_order = null_order + @property + def obj_scalar(self) -> plc.Scalar: + """ + View the column object as a pylibcudf Scalar. + + Returns + ------- + pylibcudf Scalar object. + + Raises + ------ + RuntimeError if the column is not length-1. + """ + if not self.is_scalar: + raise RuntimeError( + f"Cannot convert a column of length {self.obj.size()} to scalar" + ) + return plc.copying.get_element(self.obj, 0) + def sorted_like(self, like: Column, /) -> Self: """ Copy sortedness properties from a column onto self. @@ -81,6 +104,8 @@ def set_sorted( ------- Self with metadata set. 
""" + if self.obj.size() == 1: + is_sorted = plc.types.Sorted.YES self.is_sorted = is_sorted self.order = order self.null_order = null_order diff --git a/python/cudf_polars/cudf_polars/containers/dataframe.py b/python/cudf_polars/cudf_polars/containers/dataframe.py index ac7e748095e..7039fcaf077 100644 --- a/python/cudf_polars/cudf_polars/containers/dataframe.py +++ b/python/cudf_polars/cudf_polars/containers/dataframe.py @@ -32,7 +32,7 @@ class DataFrame: """A representation of a dataframe.""" columns: list[NamedColumn] - table: plc.Table | None + table: plc.Table def __init__(self, columns: Sequence[NamedColumn]) -> None: self.columns = list(columns) @@ -41,7 +41,7 @@ def __init__(self, columns: Sequence[NamedColumn]) -> None: def copy(self) -> Self: """Return a shallow copy of self.""" - return type(self)(self.columns) + return type(self)([c.copy() for c in self.columns]) def to_polars(self) -> pl.DataFrame: """Convert to a polars DataFrame.""" @@ -70,8 +70,6 @@ def num_columns(self) -> int: @cached_property def num_rows(self) -> int: """Number of rows.""" - if self.table is None: - raise ValueError("Number of rows of frame with scalars makes no sense") return self.table.num_rows() @classmethod diff --git a/python/cudf_polars/cudf_polars/containers/scalar.py b/python/cudf_polars/cudf_polars/containers/scalar.py deleted file mode 100644 index fc97d0fd9c2..00000000000 --- a/python/cudf_polars/cudf_polars/containers/scalar.py +++ /dev/null @@ -1,23 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. 
-# SPDX-License-Identifier: Apache-2.0 - -"""A scalar, with some properties.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - import cudf._lib.pylibcudf as plc - -__all__: list[str] = ["Scalar"] - - -class Scalar: - """A scalar, and a name.""" - - __slots__ = ("obj", "name") - obj: plc.Scalar - - def __init__(self, scalar: plc.Scalar): - self.obj = scalar diff --git a/python/cudf_polars/cudf_polars/dsl/expr.py b/python/cudf_polars/cudf_polars/dsl/expr.py index ac20e36b151..a81cdcbf0c3 100644 --- a/python/cudf_polars/cudf_polars/dsl/expr.py +++ b/python/cudf_polars/cudf_polars/dsl/expr.py @@ -5,7 +5,7 @@ """ DSL nodes for the polars expression language. -An expression node is a function, `DataFrame -> Column` or `DataFrame -> Scalar`. +An expression node is a function, `DataFrame -> Column`. The evaluation context is provided by a LogicalPlan node, and can affect the evaluation rule as well as providing the dataframe input. @@ -26,7 +26,7 @@ import cudf._lib.pylibcudf as plc -from cudf_polars.containers import Column, NamedColumn, Scalar +from cudf_polars.containers import Column, NamedColumn from cudf_polars.utils import sorting if TYPE_CHECKING: @@ -165,7 +165,7 @@ def do_evaluate( *, context: ExecutionContext = ExecutionContext.FRAME, mapping: Mapping[Expr, Column] | None = None, - ) -> Column: # TODO: return type is a lie for Literal + ) -> Column: """ Evaluate this expression given a dataframe for context. @@ -187,8 +187,7 @@ def do_evaluate( Returns ------- - Column representing the evaluation of the expression (or maybe - a scalar). + Column representing the evaluation of the expression. Raises ------ @@ -205,7 +204,7 @@ def evaluate( *, context: ExecutionContext = ExecutionContext.FRAME, mapping: Mapping[Expr, Column] | None = None, - ) -> Column: # TODO: return type is a lie for Literal + ) -> Column: """ Evaluate this expression given a dataframe for context. 
@@ -226,19 +225,9 @@ def evaluate( this method provides logic to handle lookups in the substitution mapping. - The typed return value of :class:`Column` is not true when - evaluating :class:`Literal` nodes (which instead produce - :class:`Scalar` objects). However, these duck-type to having a - pylibcudf container object inside them, and usually they end - up appearing in binary expressions which pylibcudf handles - appropriately since there are overloads for (column, scalar) - pairs. We don't have to handle (scalar, scalar) in binops - since the polars optimizer has a constant-folding pass. - Returns ------- - Column representing the evaluation of the expression (or maybe - a scalar). + Column representing the evaluation of the expression. Raises ------ @@ -341,22 +330,13 @@ def evaluate( name to a column produced from an expression. """ obj = self.value.evaluate(df, context=context, mapping=mapping) - if isinstance(obj, Scalar): - return NamedColumn( - plc.Column.from_scalar(obj.obj, 1), - self.name, - is_sorted=plc.types.Sorted.YES, - order=plc.types.Order.ASCENDING, - null_order=plc.types.NullOrder.BEFORE, - ) - else: - return NamedColumn( - obj.obj, - self.name, - is_sorted=obj.is_sorted, - order=obj.order, - null_order=obj.null_order, - ) + return NamedColumn( + obj.obj, + self.name, + is_sorted=obj.is_sorted, + order=obj.order, + null_order=obj.null_order, + ) def collect_agg(self, *, depth: int) -> AggInfo: """Collect information about aggregations in groupbys.""" @@ -383,7 +363,7 @@ def do_evaluate( ) -> Column: """Evaluate this expression given a dataframe for context.""" # datatype of pyarrow scalar is correct by construction. 
- return Scalar(plc.interop.from_arrow(self.value)) # type: ignore + return Column(plc.Column.from_scalar(plc.interop.from_arrow(self.value), 1)) class Col(Expr): @@ -422,8 +402,14 @@ def do_evaluate( mapping: Mapping[Expr, Column] | None = None, ) -> Column: """Evaluate this expression given a dataframe for context.""" - # TODO: type is wrong, and dtype - return df.num_rows # type: ignore + return Column( + plc.Column.from_scalar( + plc.interop.from_arrow( + pa.scalar(df.num_rows, type=plc.interop.to_arrow(self.dtype)) + ), + 1, + ) + ) def collect_agg(self, *, depth: int) -> AggInfo: """Collect information about aggregations in groupbys.""" @@ -684,10 +670,24 @@ def do_evaluate( return Column(plc.strings.case.to_upper(column.obj)) elif self.name == pl_expr.StringFunction.EndsWith: column, suffix = columns - return Column(plc.strings.find.ends_with(column.obj, suffix.obj)) + return Column( + plc.strings.find.ends_with( + column.obj, + suffix.obj_scalar + if column.obj.size() != suffix.obj.size() and suffix.is_scalar + else suffix.obj, + ) + ) elif self.name == pl_expr.StringFunction.StartsWith: - column, suffix = columns - return Column(plc.strings.find.starts_with(column.obj, suffix.obj)) + column, prefix = columns + return Column( + plc.strings.find.starts_with( + column.obj, + prefix.obj_scalar + if column.obj.size() != prefix.obj.size() and prefix.is_scalar + else prefix.obj, + ) + ) else: raise NotImplementedError(f"StringFunction {self.name}") @@ -1109,8 +1109,15 @@ def do_evaluate( child.evaluate(df, context=context, mapping=mapping) for child in self.children ) + lop = left.obj + rop = right.obj + if left.obj.size() != right.obj.size(): + if left.is_scalar: + lop = left.obj_scalar + elif right.is_scalar: + rop = right.obj_scalar return Column( - plc.binaryop.binary_operation(left.obj, right.obj, self.op, self.dtype), + plc.binaryop.binary_operation(lop, rop, self.op, self.dtype), ) def collect_agg(self, *, depth: int) -> AggInfo: diff --git 
a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py index 4a954ea89d5..ec23f3be2e5 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -63,6 +63,39 @@ def broadcast( *columns: NamedColumn, target_length: int | None = None ) -> list[NamedColumn]: + """ + Broadcast a sequence of columns to a common length. + + Parameters + ---------- + columns + Columns to broadcast. + target_length + Optional length to broadcast to. If not provided, uses the + non-unit length of existing columns. + + Returns + ------- + List of broadcasted columns all of the same length. + + Raises + ------ + RuntimeError + If broadcasting is not possible. + + Notes + ----- + In evaluation of a set of expressions, polars type-puns length-1 + columns with scalars. When we insert these into a DataFrame + object, we need to ensure they are of equal length. This function + takes some columns, some of which may be length-1 and ensures that + all length-1 columns are broadcast to the length of the others. + + Broadcasting is only possible if the set of lengths of the input + columns is a subset of ``{1, n}`` for some (fixed) ``n``. If + ``target_length`` is provided and not all columns are length-1 + (i.e. ``n != 1``), then ``target_length`` must be equal to ``n``. 
+ """ lengths = {column.obj.size() for column in columns} if len(lengths - {1}) > 1: raise RuntimeError("Mismatching column lengths") @@ -71,13 +104,18 @@ def broadcast( return list(columns) nrows = target_length elif len(lengths) == 1: - if target_length is not None: - assert target_length in lengths + if target_length is not None and target_length not in lengths: + raise RuntimeError( + "Cannot broadcast columns of length " + f"{lengths.pop()} to {target_length=}" + ) return list(columns) else: (nrows,) = lengths - {1} - if target_length is not None: - assert target_length == nrows + if target_length is not None and target_length != nrows: + raise RuntimeError( + f"Cannot broadcast columns of length {nrows} to {target_length=}" + ) return [ column if column.obj.size() != 1 From 7cbe60ada7cd67e2b8b76737a0ecfa01ad4ff903 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Fri, 7 Jun 2024 16:04:31 +0000 Subject: [PATCH 08/13] Note that columns are immutable in our data model --- python/cudf_polars/cudf_polars/containers/column.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/cudf_polars/cudf_polars/containers/column.py b/python/cudf_polars/cudf_polars/containers/column.py index 10b0ed78b07..c5aed4e429e 100644 --- a/python/cudf_polars/cudf_polars/containers/column.py +++ b/python/cudf_polars/cudf_polars/containers/column.py @@ -17,7 +17,7 @@ class Column: - """A column with sortedness metadata.""" + """An immutable column with sortedness metadata.""" obj: plc.Column is_sorted: plc.types.Sorted @@ -41,10 +41,10 @@ def __init__( self.order = order self.null_order = null_order - @property + @functools.cached_property def obj_scalar(self) -> plc.Scalar: """ - View the column object as a pylibcudf Scalar. + A copy of the column object as a pylibcudf Scalar. 
Returns ------- From 57097194da64111c17f067b4a2fa4197749b7ded Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Mon, 10 Jun 2024 10:36:59 +0000 Subject: [PATCH 09/13] Simplify logic for broadcasting --- python/cudf_polars/cudf_polars/dsl/ir.py | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py index ec23f3be2e5..edd6124e570 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -32,7 +32,7 @@ from cudf_polars.utils import sorting if TYPE_CHECKING: - from collections.abc import MutableMapping + from collections.abc import MutableMapping, Set from typing import Literal from cudf_polars.typing import Schema @@ -96,31 +96,25 @@ def broadcast( ``target_length`` is provided and not all columns are length-1 (i.e. ``n != 1``), then ``target_length`` must be equal to ``n``. """ - lengths = {column.obj.size() for column in columns} - if len(lengths - {1}) > 1: - raise RuntimeError("Mismatching column lengths") + lengths: Set[int] = {column.obj.size() for column in columns} if lengths == {1}: if target_length is None: return list(columns) nrows = target_length - elif len(lengths) == 1: - if target_length is not None and target_length not in lengths: - raise RuntimeError( - "Cannot broadcast columns of length " - f"{lengths.pop()} to {target_length=}" - ) - return list(columns) else: - (nrows,) = lengths - {1} - if target_length is not None and target_length != nrows: + try: + (nrows,) = lengths - {1} + except ValueError as e: + raise RuntimeError("Mismatching column lengths") from e + if target_length is not None and nrows != target_length: raise RuntimeError( - f"Cannot broadcast columns of length {nrows} to {target_length=}" + f"Cannot broadcast columns of length {nrows=} to {target_length=}" ) return [ column if column.obj.size() != 1 else NamedColumn( - 
plc.Column.from_scalar(plc.copying.get_element(column.obj, 0), nrows), + plc.Column.from_scalar(column.obj_scalar, nrows), column.name, is_sorted=plc.types.Sorted.YES, order=plc.types.Order.ASCENDING, From 1f22fdfbb96b452b25d8c16a502997d39a8be5eb Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Mon, 10 Jun 2024 10:57:15 +0000 Subject: [PATCH 10/13] Tests of broadcasting --- python/cudf_polars/tests/utils/__init__.py | 6 ++ .../cudf_polars/tests/utils/test_broadcast.py | 74 +++++++++++++++++++ 2 files changed, 80 insertions(+) create mode 100644 python/cudf_polars/tests/utils/__init__.py create mode 100644 python/cudf_polars/tests/utils/test_broadcast.py diff --git a/python/cudf_polars/tests/utils/__init__.py b/python/cudf_polars/tests/utils/__init__.py new file mode 100644 index 00000000000..4611d642f14 --- /dev/null +++ b/python/cudf_polars/tests/utils/__init__.py @@ -0,0 +1,6 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +__all__: list[str] = [] diff --git a/python/cudf_polars/tests/utils/test_broadcast.py b/python/cudf_polars/tests/utils/test_broadcast.py new file mode 100644 index 00000000000..69ad1e519e2 --- /dev/null +++ b/python/cudf_polars/tests/utils/test_broadcast.py @@ -0,0 +1,74 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pytest + +import cudf._lib.pylibcudf as plc + +from cudf_polars.containers import NamedColumn +from cudf_polars.dsl.ir import broadcast + + +@pytest.mark.parametrize("target", [4, None]) +def test_broadcast_all_scalar(target): + columns = [ + NamedColumn( + plc.column_factories.make_numeric_column( + plc.DataType(plc.TypeId.INT8), 1, plc.MaskState.ALL_VALID + ), + f"col{i}", + ) + for i in range(3) + ] + result = broadcast(*columns, target_length=target) + expected = 1 if target is None else target + + assert all(column.obj.size() == expected for column in result) + + +def test_invalid_target_length(): + columns = [ + NamedColumn( + plc.column_factories.make_numeric_column( + plc.DataType(plc.TypeId.INT8), 4, plc.MaskState.ALL_VALID + ), + f"col{i}", + ) + for i in range(3) + ] + with pytest.raises(RuntimeError): + _ = broadcast(*columns, target_length=8) + + +def test_broadcast_mismatching_column_lengths(): + columns = [ + NamedColumn( + plc.column_factories.make_numeric_column( + plc.DataType(plc.TypeId.INT8), i + 1, plc.MaskState.ALL_VALID + ), + f"col{i}", + ) + for i in range(3) + ] + with pytest.raises(RuntimeError): + _ = broadcast(*columns) + + +@pytest.mark.parametrize("nrows", [0, 5]) +def test_broadcast_with_scalars(nrows): + columns = [ + NamedColumn( + plc.column_factories.make_numeric_column( + plc.DataType(plc.TypeId.INT8), + nrows if i == 0 else 1, + plc.MaskState.ALL_VALID, + ), + f"col{i}", + ) + for i in range(3) + ] + + result = broadcast(*columns) + assert all(column.obj.size() == nrows for column in result) From 59c6163167ef30afe9349a1934e888d747aba328 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Mon, 10 Jun 2024 11:00:11 +0000 Subject: [PATCH 11/13] Length-0 columns are also always sorted --- python/cudf_polars/cudf_polars/containers/column.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/python/cudf_polars/cudf_polars/containers/column.py b/python/cudf_polars/cudf_polars/containers/column.py index c5aed4e429e..776d71bb8c9 100644 --- a/python/cudf_polars/cudf_polars/containers/column.py +++ b/python/cudf_polars/cudf_polars/containers/column.py @@ -35,7 +35,7 @@ def __init__( ): self.obj = column self.is_scalar = self.obj.size() == 1 - if self.is_scalar: + if self.obj.size() <= 1: is_sorted = plc.types.Sorted.YES self.is_sorted = is_sorted self.order = order @@ -104,7 +104,7 @@ def set_sorted( ------- Self with metadata set. """ - if self.obj.size() == 1: + if self.obj.size() <= 1: is_sorted = plc.types.Sorted.YES self.is_sorted = is_sorted self.order = order From 5a7f92c585fdbd15b21da3f97cb2b9cdde370c5e Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Tue, 11 Jun 2024 09:13:04 +0000 Subject: [PATCH 12/13] Raise ValueError not RuntimeError --- python/cudf_polars/cudf_polars/containers/column.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/cudf_polars/cudf_polars/containers/column.py b/python/cudf_polars/cudf_polars/containers/column.py index 776d71bb8c9..156dd395d64 100644 --- a/python/cudf_polars/cudf_polars/containers/column.py +++ b/python/cudf_polars/cudf_polars/containers/column.py @@ -52,10 +52,11 @@ def obj_scalar(self) -> plc.Scalar: Raises ------ - RuntimeError if the column is not length-1. + ValueError + If the column is not length-1. 
""" if not self.is_scalar: - raise RuntimeError( + raise ValueError( f"Cannot convert a column of length {self.obj.size()} to scalar" ) return plc.copying.get_element(self.obj, 0) From 7fc75e9b9b30e6ae35512b137a35e33a62926cfe Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Tue, 11 Jun 2024 09:24:57 +0000 Subject: [PATCH 13/13] Use set.difference --- python/cudf_polars/cudf_polars/dsl/ir.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py index edd6124e570..0a6deb5698c 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -32,7 +32,7 @@ from cudf_polars.utils import sorting if TYPE_CHECKING: - from collections.abc import MutableMapping, Set + from collections.abc import MutableMapping from typing import Literal from cudf_polars.typing import Schema @@ -96,14 +96,14 @@ def broadcast( ``target_length`` is provided and not all columns are length-1 (i.e. ``n != 1``), then ``target_length`` must be equal to ``n``. """ - lengths: Set[int] = {column.obj.size() for column in columns} + lengths: set[int] = {column.obj.size() for column in columns} if lengths == {1}: if target_length is None: return list(columns) nrows = target_length else: try: - (nrows,) = lengths - {1} + (nrows,) = lengths.difference([1]) except ValueError as e: raise RuntimeError("Mismatching column lengths") from e if target_length is not None and nrows != target_length: