Always initialise Literal with pyarrow scalar value
wence- committed Jun 4, 2024
1 parent eb46016 · commit 8982e31
Showing 6 changed files with 51 additions and 31 deletions.
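
The whole commit applies one pattern: instead of letting pyarrow infer a type for a Python value (int64 for ints, double for floats), every scalar is now built as pa.scalar(value, type=plc.interop.to_arrow(dtype)) so it carries the expression's libcudf dtype from the start. A minimal sketch of the before/after behaviour, not taken from the diff; plc is cudf._lib.pylibcudf as imported in these files, and TypeId.INT32 is chosen only for illustration:

    import pyarrow as pa

    import cudf._lib.pylibcudf as plc

    target = plc.DataType(plc.TypeId.INT32)  # dtype the expression must produce

    inferred = pa.scalar(1)  # pyarrow infers int64 from a Python int
    typed = pa.scalar(1, type=plc.interop.to_arrow(target))  # int32, assuming
    # to_arrow maps TypeId.INT32 to pa.int32()

    assert inferred.type == pa.int64()
    assert typed.type == pa.int32()

    # from_arrow now receives a scalar that already has the intended type, so
    # no separate data_type= override is needed at the conversion site.
    scalar = plc.interop.from_arrow(typed)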
46 changes: 33 additions & 13 deletions python/cudf_polars/cudf_polars/dsl/expr.py
@@ -484,32 +484,48 @@ def do_evaluate(
             return self._distinct(
                 column,
                 keep=plc.stream_compaction.DuplicateKeepOption.KEEP_FIRST,
-                source_value=plc.interop.from_arrow(pa.scalar(True)),  # noqa: FBT003
-                target_value=plc.interop.from_arrow(pa.scalar(False)),  # noqa: FBT003
+                source_value=plc.interop.from_arrow(
+                    pa.scalar(True, type=plc.interop.to_arrow(self.dtype))  # noqa: FBT003
+                ),
+                target_value=plc.interop.from_arrow(
+                    pa.scalar(False, type=plc.interop.to_arrow(self.dtype))  # noqa: FBT003
+                ),
             )
         elif self.name == pl_expr.BooleanFunction.IsLastDistinct:
             (column,) = columns
             return self._distinct(
                 column,
                 keep=plc.stream_compaction.DuplicateKeepOption.KEEP_LAST,
-                source_value=plc.interop.from_arrow(pa.scalar(True)),  # noqa: FBT003
-                target_value=plc.interop.from_arrow(pa.scalar(False)),  # noqa: FBT003
+                source_value=plc.interop.from_arrow(
+                    pa.scalar(True, type=plc.interop.to_arrow(self.dtype))  # noqa: FBT003
+                ),
+                target_value=plc.interop.from_arrow(
+                    pa.scalar(False, type=plc.interop.to_arrow(self.dtype))  # noqa: FBT003
+                ),
             )
         elif self.name == pl_expr.BooleanFunction.IsUnique:
             (column,) = columns
             return self._distinct(
                 column,
                 keep=plc.stream_compaction.DuplicateKeepOption.KEEP_NONE,
-                source_value=plc.interop.from_arrow(pa.scalar(True)),  # noqa: FBT003
-                target_value=plc.interop.from_arrow(pa.scalar(False)),  # noqa: FBT003
+                source_value=plc.interop.from_arrow(
+                    pa.scalar(True, type=plc.interop.to_arrow(self.dtype))  # noqa: FBT003
+                ),
+                target_value=plc.interop.from_arrow(
+                    pa.scalar(False, type=plc.interop.to_arrow(self.dtype))  # noqa: FBT003
+                ),
             )
         elif self.name == pl_expr.BooleanFunction.IsDuplicated:
             (column,) = columns
             return self._distinct(
                 column,
                 keep=plc.stream_compaction.DuplicateKeepOption.KEEP_NONE,
-                source_value=plc.interop.from_arrow(pa.scalar(False)),  # noqa: FBT003
-                target_value=plc.interop.from_arrow(pa.scalar(True)),  # noqa: FBT003
+                source_value=plc.interop.from_arrow(
+                    pa.scalar(False, type=plc.interop.to_arrow(self.dtype))  # noqa: FBT003
+                ),
+                target_value=plc.interop.from_arrow(
+                    pa.scalar(True, type=plc.interop.to_arrow(self.dtype))  # noqa: FBT003
+                ),
             )
         elif self.name == pl_expr.BooleanFunction.AllHorizontal:
             name = columns[0].name
@@ -717,7 +733,9 @@ def do_evaluate(
             bounds_policy = plc.copying.OutOfBoundsPolicy.NULLIFY
             obj = plc.replace.replace_nulls(
                 indices.obj,
-                plc.interop.from_arrow(pa.scalar(n), data_type=indices.obj.data_type()),
+                plc.interop.from_arrow(
+                    pa.scalar(n, type=plc.interop.to_arrow(indices.obj.data_type()))
+                ),
             )
         else:
             bounds_policy = plc.copying.OutOfBoundsPolicy.DONT_CHECK
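
For context, a hedged sketch of the gather pattern this hunk touches: null indices are first filled with an out-of-bounds sentinel (here n, assumed to be the source length), and the gather then runs with NULLIFY so those rows come back as nulls. The fill scalar must match the index column's dtype, hence the typed pa.scalar. The helper name is hypothetical:

    import pyarrow as pa

    import cudf._lib.pylibcudf as plc

    def gather_nullable(source: plc.Table, indices: plc.Column, n: int) -> plc.Table:
        # Replace null indices with n; with NULLIFY, out-of-bounds rows gather
        # as null in the output.
        fill = plc.interop.from_arrow(
            pa.scalar(n, type=plc.interop.to_arrow(indices.data_type()))
        )
        filled = plc.replace.replace_nulls(indices, fill)
        return plc.copying.gather(
            source, filled, plc.copying.OutOfBoundsPolicy.NULLIFY
        )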
@@ -893,11 +911,13 @@ def _reduce(
         )

     def _count(self, column: Column) -> Column:
-        # TODO: dtype handling
         return Column(
             plc.Column.from_scalar(
                 plc.interop.from_arrow(
-                    pa.scalar(column.obj.size() - column.obj.null_count()),
+                    pa.scalar(
+                        column.obj.size() - column.obj.null_count(),
+                        type=plc.interop.to_arrow(self.dtype),
+                    ),
                 ),
                 1,
             ),
@@ -909,7 +929,7 @@ def _min(self, column: Column, *, propagate_nans: bool) -> Column:
             return Column(
                 plc.Column.from_scalar(
                     plc.interop.from_arrow(
-                        pa.scalar(float("nan")), data_type=self.dtype
+                        pa.scalar(float("nan"), type=plc.interop.to_arrow(self.dtype))
                     ),
                     1,
                 ),
@@ -924,7 +944,7 @@ def _max(self, column: Column, *, propagate_nans: bool) -> Column:
             return Column(
                 plc.Column.from_scalar(
                     plc.interop.from_arrow(
-                        pa.scalar(float("nan")), data_type=self.dtype
+                        pa.scalar(float("nan"), type=plc.interop.to_arrow(self.dtype))
                     ),
                     1,
                 ),
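
The _count/_min/_max hunks all materialise an aggregate as a one-row column by broadcasting a typed arrow scalar. A short sketch of that pattern, with FLOAT64 standing in for self.dtype:

    import pyarrow as pa

    import cudf._lib.pylibcudf as plc

    dtype = plc.DataType(plc.TypeId.FLOAT64)  # self.dtype in the real code
    nan = plc.interop.from_arrow(
        pa.scalar(float("nan"), type=plc.interop.to_arrow(dtype))
    )
    one_row = plc.Column.from_scalar(nan, 1)  # one-row column holding NaN

Typing the arrow scalar itself, rather than passing data_type= to from_arrow as the old code did, keeps scalar construction uniform with every other site in this commit.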
10 changes: 7 additions & 3 deletions python/cudf_polars/cudf_polars/dsl/ir.py
@@ -146,9 +146,13 @@ def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame:
             assert_never(self.typ)
         if row_index is not None:
             name, offset = row_index
-            # TODO: dtype
-            step = plc.interop.from_arrow(pa.scalar(1))
-            init = plc.interop.from_arrow(pa.scalar(offset))
+            dtype = self.schema[name]
+            step = plc.interop.from_arrow(
+                pa.scalar(1, type=plc.interop.to_arrow(dtype))
+            )
+            init = plc.interop.from_arrow(
+                pa.scalar(offset, type=plc.interop.to_arrow(dtype))
+            )
             index = Column(
                 plc.filling.sequence(df.num_rows, init, step), name
             ).set_sorted(
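
This resolves the old "TODO: dtype": the row-index column is now generated with the dtype recorded in the scan's schema instead of inferred int64 scalars. A sketch of the construction (the helper name is hypothetical):

    import pyarrow as pa

    import cudf._lib.pylibcudf as plc

    def make_row_index(num_rows: int, offset: int, dtype: plc.DataType) -> plc.Column:
        arrow_type = plc.interop.to_arrow(dtype)
        step = plc.interop.from_arrow(pa.scalar(1, type=arrow_type))
        init = plc.interop.from_arrow(pa.scalar(offset, type=arrow_type))
        # Produces num_rows values: offset, offset + 1, offset + 2, ...
        return plc.filling.sequence(num_rows, init, step)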
9 changes: 6 additions & 3 deletions python/cudf_polars/cudf_polars/dsl/translate.py
@@ -9,9 +9,11 @@
 from functools import singledispatch
 from typing import Any

+import pyarrow as pa
+
 from polars.polars import _expr_nodes as pl_expr, _ir_nodes as pl_ir

-import cudf._lib.pylibcudf as plc  # noqa: TCH002, singledispatch register needs this name defined.
+import cudf._lib.pylibcudf as plc

 from cudf_polars.dsl import expr, ir
 from cudf_polars.utils import dtypes
@@ -295,7 +297,8 @@ def _(node: pl_expr.Window, visitor: Any, dtype: plc.DataType) -> expr.Expr:

 @_translate_expr.register
 def _(node: pl_expr.Literal, visitor: Any, dtype: plc.DataType) -> expr.Expr:
-    return expr.Literal(dtype, node.value)
+    value = pa.scalar(node.value, type=plc.interop.to_arrow(dtype))
+    return expr.Literal(dtype, value)


 @_translate_expr.register
@@ -337,7 +340,7 @@ def _(node: pl_expr.Cast, visitor: Any, dtype: plc.DataType) -> expr.Expr:
     inner = translate_expr(visitor, n=node.expr)
     # Push casts into literals so we can handle Cast(Literal(Null))
     if isinstance(inner, expr.Literal):
-        return expr.Literal(dtype, inner.value)
+        return expr.Literal(dtype, inner.value.cast(plc.interop.to_arrow(dtype)))
     else:
         return expr.Cast(dtype, inner)

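
These two hunks are the heart of the commit title: a Literal node now always holds a pyarrow scalar of the translated dtype, so pushing a Cast into a Literal becomes a cast on that scalar (pa.Scalar.cast is standard pyarrow API) rather than re-wrapping the raw Python value. A sketch, with the concrete type ids chosen only for illustration:

    import pyarrow as pa

    import cudf._lib.pylibcudf as plc

    dtype = plc.DataType(plc.TypeId.INT16)
    value = pa.scalar(5, type=plc.interop.to_arrow(dtype))  # Literal payload

    # Cast(Literal(...)) folds into a scalar cast at translation time.
    wider = plc.DataType(plc.TypeId.INT64)
    cast_value = value.cast(plc.interop.to_arrow(wider))
    assert cast_value.type == pa.int64()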
3 changes: 3 additions & 0 deletions python/cudf_polars/cudf_polars/utils/dtypes.py
@@ -13,6 +13,8 @@

 import cudf._lib.pylibcudf as plc

+__all__ = ["from_polars"]
+

 @cache
 def from_polars(dtype: pl.DataType) -> plc.DataType:
@@ -84,6 +86,7 @@ def from_polars(dtype: pl.DataType) -> plc.DataType:
         # TODO: Hopefully
         return plc.DataType(plc.TypeId.EMPTY)
     elif isinstance(dtype, pl.List):
+        # TODO: This doesn't consider the value type.
         return plc.DataType(plc.TypeId.LIST)
     else:
         raise NotImplementedError(f"{dtype=} conversion not supported")
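
A hedged usage sketch for the branch the new TODO annotates (assuming plc.DataType supports equality comparison):

    import polars as pl

    import cudf._lib.pylibcudf as plc

    from cudf_polars.utils.dtypes import from_polars

    # Any list dtype maps to a bare LIST type id; the element type (Int64
    # here) is dropped, which is exactly what the TODO records.
    assert from_polars(pl.List(pl.Int64)) == plc.DataType(plc.TypeId.LIST)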
2 changes: 1 addition & 1 deletion python/cudf_polars/tests/expressions/test_agg.py
@@ -56,7 +56,7 @@ def test_agg(df, agg):
     q = df.select(expr)

     # https://github.com/rapidsai/cudf/issues/15852
-    check_dtype = agg not in {"count", "n_unique", "median"}
+    check_dtype = agg not in {"n_unique", "median"}
     if not check_dtype and q.schema["a"] != pl.Float64:
         with pytest.raises(AssertionError):
             assert_gpu_result_equal(q)
12 changes: 1 addition & 11 deletions python/cudf_polars/tests/test_scan.py
@@ -10,17 +10,7 @@


 @pytest.fixture(
-    params=[
-        (None, None),
-        pytest.param(
-            ("row-index", 0),
-            marks=pytest.mark.xfail(reason="Incorrect dtype for row index"),
-        ),
-        pytest.param(
-            ("index", 10),
-            marks=pytest.mark.xfail(reason="Incorrect dtype for row index"),
-        ),
-    ],
+    params=[(None, None), ("row-index", 0), ("index", 10)],
     ids=["no-row-index", "zero-offset-row-index", "offset-row-index"],
 )
 def row_index(request):
