Skip to content

Commit

Permalink
Fix bug in HConcat
Browse files Browse the repository at this point in the history
Shorter tables must be extended with nulls before concatenation.
  • Loading branch information
wence- committed Jul 16, 2024
1 parent 2e39750 commit a8e3c1a
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 4 deletions.
28 changes: 27 additions & 1 deletion python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

import cudf_polars.dsl.expr as expr
from cudf_polars.containers import DataFrame, NamedColumn
from cudf_polars.utils import sorting
from cudf_polars.utils import dtypes, sorting

if TYPE_CHECKING:
from collections.abc import MutableMapping
Expand Down Expand Up @@ -1044,6 +1044,32 @@ class HConcat(IR):
def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
    """
    Evaluate every input and join the results side by side.

    Each child in ``self.dfs`` is evaluated with the supplied ``cache``.
    Any result with fewer rows than the tallest one is extended with
    all-null rows (one null column per existing column, typed via
    ``dtypes.arrow_type``) so that all inputs have the same height
    before their columns are gathered into a single dataframe.

    Parameters
    ----------
    cache
        Mapping forwarded unchanged to each child's ``evaluate`` call.

    Returns
    -------
    DataFrame containing the columns of every input, in input order.
    """
    frames = [child.evaluate(cache=cache) for child in self.dfs]
    target_rows = max(frame.num_rows for frame in frames)

    # Horizontal concatenation extends shorter tables with nulls
    padded = []
    for frame in frames:
        if frame.num_rows == target_rows:
            # Already the full height: reuse the frame untouched.
            padded.append(frame)
            continue

        existing = frame.table
        missing = target_rows - frame.num_rows
        null_columns = []
        for column in frame.columns:
            # A typed null scalar, repeated once per missing row.
            null_scalar = plc.interop.from_arrow(
                pa.scalar(None, type=dtypes.arrow_type(column.obj))
            )
            null_columns.append(plc.Column.from_scalar(null_scalar, missing))

        extended = plc.concatenate.concatenate([existing, plc.Table(null_columns)])
        padded.append(DataFrame.from_table(extended, frame.column_names))

    all_columns = itertools.chain.from_iterable(frame.columns for frame in padded)
    return DataFrame(
        list(all_columns),
    )
11 changes: 8 additions & 3 deletions python/cudf_polars/cudf_polars/utils/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,19 @@ def arrow_type(column: plc.Column) -> pa.DataType:
For unsupported conversions.
"""
dtype = column.type()
if dtype.id() == plc.TypeId.LIST:
tid = dtype.id()
if tid == plc.TypeId.LIST:
_, inner = column.children()
return plc.interop.to_arrow(dtype, value_type=arrow_type(inner))
elif plc.traits.is_fixed_width(dtype) and not plc.traits.is_fixed_point(dtype):
elif (
(plc.traits.is_fixed_width(dtype) and not plc.traits.is_fixed_point(dtype))
or tid == plc.TypeId.STRING
or tid == plc.TypeId.EMPTY
):
return plc.interop.to_arrow(dtype)
else:
raise NotImplementedError(
"No conversion for struct columns"
f"No conversion for columns with type {tid}"
) # pragma: no cover; unreachable since we raise earlier


Expand Down

0 comments on commit a8e3c1a

Please sign in to comment.