From a8e3c1a3fd232e1d8a95ffcd99571f7b49e13327 Mon Sep 17 00:00:00 2001
From: Lawrence Mitchell
Date: Mon, 1 Jul 2024 17:26:21 +0000
Subject: [PATCH] Fix bug in HConcat

Shorter tables must be extended with nulls before concatenation.
---
 python/cudf_polars/cudf_polars/dsl/ir.py      | 28 ++++++++++++++++++-
 .../cudf_polars/cudf_polars/utils/dtypes.py   | 11 ++++++--
 2 files changed, 35 insertions(+), 4 deletions(-)

diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py
index cce0c4a3d94..488de7aaa02 100644
--- a/python/cudf_polars/cudf_polars/dsl/ir.py
+++ b/python/cudf_polars/cudf_polars/dsl/ir.py
@@ -30,7 +30,7 @@
 
 import cudf_polars.dsl.expr as expr
 from cudf_polars.containers import DataFrame, NamedColumn
-from cudf_polars.utils import sorting
+from cudf_polars.utils import dtypes, sorting
 
 if TYPE_CHECKING:
     from collections.abc import MutableMapping
@@ -1044,6 +1044,32 @@ class HConcat(IR):
     def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
         """Evaluate and return a dataframe."""
         dfs = [df.evaluate(cache=cache) for df in self.dfs]
+        max_rows = max(df.num_rows for df in dfs)
+        # Horizontal concatenation extends shorter tables with nulls
+        dfs = [
+            df
+            if df.num_rows == max_rows
+            else DataFrame.from_table(
+                plc.concatenate.concatenate(
+                    [
+                        df.table,
+                        plc.Table(
+                            [
+                                plc.Column.from_scalar(
+                                    plc.interop.from_arrow(
+                                        pa.scalar(None, type=dtypes.arrow_type(c.obj))
+                                    ),
+                                    max_rows - df.num_rows,
+                                )
+                                for c in df.columns
+                            ]
+                        ),
+                    ]
+                ),
+                df.column_names,
+            )
+            for df in dfs
+        ]
         return DataFrame(
             list(itertools.chain.from_iterable(df.columns for df in dfs)),
         )
diff --git a/python/cudf_polars/cudf_polars/utils/dtypes.py b/python/cudf_polars/cudf_polars/utils/dtypes.py
index 9db41d641bf..81a98508d6e 100644
--- a/python/cudf_polars/cudf_polars/utils/dtypes.py
+++ b/python/cudf_polars/cudf_polars/utils/dtypes.py
@@ -106,14 +106,19 @@ def arrow_type(column: plc.Column) -> pa.DataType:
         For unsupported conversions.
     """
     dtype = column.type()
-    if dtype.id() == plc.TypeId.LIST:
+    tid = dtype.id()
+    if tid == plc.TypeId.LIST:
         _, inner = column.children()
         return plc.interop.to_arrow(dtype, value_type=arrow_type(inner))
-    elif plc.traits.is_fixed_width(dtype) and not plc.traits.is_fixed_point(dtype):
+    elif (
+        (plc.traits.is_fixed_width(dtype) and not plc.traits.is_fixed_point(dtype))
+        or tid == plc.TypeId.STRING
+        or tid == plc.TypeId.EMPTY
+    ):
         return plc.interop.to_arrow(dtype)
     else:
         raise NotImplementedError(
-            "No conversion for struct columns"
+            f"No conversion for columns with type {tid}"
         )  # pragma: no cover; unreachable since we raise earlier
 
 