Skip to content

Commit

Permalink
Support both polars 1.0 and 1.1
Browse files Browse the repository at this point in the history
Convert internal PySeries to public Series in translation.
  • Loading branch information
wence- committed Jul 9, 2024
1 parent f6b355d commit dfd644a
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
4 changes: 2 additions & 2 deletions python/cudf_polars/cudf_polars/dsl/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
if TYPE_CHECKING:
from collections.abc import Mapping, Sequence

import polars.polars as plrs
import polars as pl
import polars.type_aliases as pl_types

from cudf_polars.containers import DataFrame
Expand Down Expand Up @@ -377,7 +377,7 @@ class LiteralColumn(Expr):
value: pa.Array[Any, Any]
children: tuple[()]

def __init__(self, dtype: plc.DataType, value: plrs.PySeries) -> None:
def __init__(self, dtype: plc.DataType, value: pl.Series) -> None:
super().__init__(dtype)
data = value.to_arrow()
self.value = data.cast(dtypes.downcast_arrow_lists(data.type))
Expand Down
3 changes: 2 additions & 1 deletion python/cudf_polars/cudf_polars/dsl/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import pyarrow as pa
from typing_extensions import assert_never

import polars as pl
import polars.polars as plrs
from polars.polars import _expr_nodes as pl_expr, _ir_nodes as pl_ir

Expand Down Expand Up @@ -402,7 +403,7 @@ def _(node: pl_expr.Window, visitor: NodeTraverser, dtype: plc.DataType) -> expr
@_translate_expr.register
def _(node: pl_expr.Literal, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr:
if isinstance(node.value, plrs.PySeries):
return expr.LiteralColumn(dtype, node.value)
return expr.LiteralColumn(dtype, pl.Series._from_pyseries(node.value))
value = pa.scalar(node.value, type=plc.interop.to_arrow(dtype))
return expr.Literal(dtype, value)

Expand Down

0 comments on commit dfd644a

Please sign in to comment.