Skip to content

Commit

Permalink
Remove workaround for large strings in arrow schema
Browse files Browse the repository at this point in the history
Now that libcudf supports interop from arrow large strings, there is no
need for the cast. While here, handle casting of nested lists as well.
  • Loading branch information
wence- committed Jun 27, 2024
1 parent fa8284d commit 8c58378
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 11 deletions.
15 changes: 5 additions & 10 deletions python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

import cudf_polars.dsl.expr as expr
from cudf_polars.containers import DataFrame, NamedColumn
from cudf_polars.utils import sorting
from cudf_polars.utils import dtypes, sorting

if TYPE_CHECKING:
from collections.abc import MutableMapping
Expand Down Expand Up @@ -292,15 +292,10 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
table = pdf.to_arrow()
schema = table.schema
for i, field in enumerate(schema):
# TODO: Nested types
if field.type == pa.large_string():
# TODO: goes away when libcudf supports large strings
schema = schema.set(i, pa.field(field.name, pa.string()))
elif isinstance(field.type, pa.LargeListType):
# TODO: goes away when libcudf supports large lists
schema = schema.set(
i, pa.field(field.name, pa.list_(field.type.field(0)))
)
schema = schema.set(
i, pa.field(field.name, dtypes.downcast_arrow_lists(field.type))
)
# No-op if the schema is unchanged.
table = table.cast(schema)
df = DataFrame.from_table(
plc.interop.from_arrow(table), list(self.schema.keys())
Expand Down
31 changes: 30 additions & 1 deletion python/cudf_polars/cudf_polars/utils/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,42 @@

from functools import cache

import pyarrow as pa
from typing_extensions import assert_never

import polars as pl

import cudf._lib.pylibcudf as plc

__all__ = ["from_polars"]
__all__ = ["from_polars", "downcast_arrow_lists"]


def downcast_arrow_lists(typ: pa.DataType) -> pa.DataType:
    """
    Sanitize an arrow datatype from polars.

    Parameters
    ----------
    typ
        Arrow type to sanitize

    Returns
    -------
    Sanitized arrow type

    Notes
    -----
    As well as arrow ``ListType``s, polars can produce
    ``LargeListType``s and ``FixedSizeListType``s, these are not
    currently handled by libcudf. This function rewrites every
    ``LargeListType`` into a normal ``ListType`` at any nesting
    depth, including a ``LargeListType`` that sits inside a normal
    ``ListType``. ``FixedSizeListType``s, and every other type, are
    returned unchanged.
    """
    if isinstance(typ, pa.LargeListType):
        return pa.list_(downcast_arrow_lists(typ.value_type))
    if isinstance(typ, pa.ListType):
        # A normal list can still wrap a large list further down, so
        # the element type has to be sanitized as well.
        inner = downcast_arrow_lists(typ.value_type)
        # Return the input itself when nothing below it changed, so a
        # schema without large lists comes back exactly as it went in.
        return typ if inner == typ.value_type else pa.list_(inner)
    # We don't have to worry about diving into struct types for now
    # since those are always NotImplemented before we get here.
    return typ


@cache
Expand Down
19 changes: 19 additions & 0 deletions python/cudf_polars/tests/test_dataframescan.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,22 @@ def test_scan_drop_nulls(subset, predicate_pushdown):
assert_gpu_result_equal(
q, collect_kwargs={"predicate_pushdown": predicate_pushdown}
)


def test_can_convert_lists():
    """Build a LazyFrame of flat and nested list columns and hand it to ``assert_gpu_result_equal``."""
    # Flat lists over two different integer widths.
    int8_lists = pl.Series([[1, 2], [3]], dtype=pl.List(pl.Int8()))
    uint16_lists = pl.Series([[1], [2]], dtype=pl.List(pl.UInt16()))
    # Lists of lists, once with string leaves and once with integer leaves;
    # both include an empty list.
    nested_strings = pl.Series(
        [
            [["1", "2", "3"], ["4", "567"]],
            [["8", "9"], []],
        ],
        dtype=pl.List(pl.List(pl.String())),
    )
    nested_uint16 = pl.Series(
        [[[1, 2]], []],
        dtype=pl.List(pl.List(pl.UInt16())),
    )

    columns = {
        "a": int8_lists,
        "b": uint16_lists,
        "c": nested_strings,
        "d": nested_uint16,
    }
    df = pl.LazyFrame(columns)

    assert_gpu_result_equal(df)

0 comments on commit 8c58378

Please sign in to comment.