Fix issue in horizontal concat implementation in cudf-polars (#16271)
When concatenating horizontally, shorter tables must be extended with nulls to the length of the longest table.
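For context, a minimal sketch of the behavior being fixed, mirroring the new `test_hconcat_different_heights` test further down; the column values are illustrative, and the expected padding follows polars' horizontal-concat semantics:

```python
import polars as pl

# Two lazy frames of different heights; values are illustrative only.
left = pl.LazyFrame({"a": [1, 2, 3, 4]})
right = pl.LazyFrame({"b": ["x", "y"]})

# Polars pads the shorter frame with nulls, so the cudf-polars GPU engine
# must do the same for results to match.
q = pl.concat([left, right], how="horizontal")
print(q.collect())  # 4 rows; the last two entries of "b" are null
```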

Authors:
  - Lawrence Mitchell (https://github.com/wence-)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: #16271
wence- authored Jul 22, 2024
1 parent e6537de commit 852b151
Showing 9 changed files with 147 additions and 47 deletions.
22 changes: 22 additions & 0 deletions python/cudf/cudf/_lib/pylibcudf/column.pyx
@@ -255,6 +255,28 @@ cdef class Column:
            c_result = move(make_column_from_scalar(dereference(c_scalar), size))
        return Column.from_libcudf(move(c_result))

    @staticmethod
    def all_null_like(Column like, size_type size):
        """Create an all null column from a template.

        Parameters
        ----------
        like : Column
            Column whose type we should mimic
        size : int
            Number of rows in the resulting column.

        Returns
        -------
        Column
            An all-null column of `size` rows and type matching `like`.
        """
        cdef Scalar slr = Scalar.empty_like(like)
        cdef unique_ptr[column] c_result
        with nogil:
            c_result = move(make_column_from_scalar(dereference(slr.get()), size))
        return Column.from_libcudf(move(c_result))
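A hedged usage sketch of the new factory; the pyarrow round-trip through `plc.interop.from_arrow` and the `size()`/`null_count()` accessors are assumptions about the surrounding pylibcudf API rather than part of this diff:

```python
import pyarrow as pa

import cudf._lib.pylibcudf as plc

# Build a small table via interop (assumed API) and grab a template column.
tbl = plc.interop.from_arrow(pa.table({"a": [1.5, 2.5, 3.5]}))
template = tbl.columns()[0]

# Five rows, same type as `template`, every row null.
nulls = plc.Column.all_null_like(template, 5)
assert nulls.size() == 5
assert nulls.null_count() == 5
```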

    @staticmethod
    def from_cuda_array_interface_obj(object obj):
        """Create a Column from an object with a CUDA array interface.
python/cudf/cudf/_lib/pylibcudf/libcudf/scalar/scalar_factories.pxd
@@ -3,9 +3,12 @@
from libcpp.memory cimport unique_ptr
from libcpp.string cimport string

from cudf._lib.pylibcudf.libcudf.column.column_view cimport column_view
from cudf._lib.pylibcudf.libcudf.scalar.scalar cimport scalar


cdef extern from "cudf/scalar/scalar_factories.hpp" namespace "cudf" nogil:
    cdef unique_ptr[scalar] make_string_scalar(const string & _string) except +
    cdef unique_ptr[scalar] make_fixed_width_scalar[T](T value) except +

    cdef unique_ptr[scalar] make_empty_scalar_like(const column_view &) except +
4 changes: 4 additions & 0 deletions python/cudf/cudf/_lib/pylibcudf/scalar.pxd
@@ -7,6 +7,7 @@ from rmm._lib.memory_resource cimport DeviceMemoryResource

from cudf._lib.pylibcudf.libcudf.scalar.scalar cimport scalar

from .column cimport Column
from .types cimport DataType


@@ -24,5 +25,8 @@ cdef class Scalar:
    cpdef DataType type(self)
    cpdef bool is_valid(self)

    @staticmethod
    cdef Scalar empty_like(Column column)

    @staticmethod
    cdef Scalar from_libcudf(unique_ptr[scalar] libcudf_scalar, dtype=*)
20 changes: 20 additions & 0 deletions python/cudf/cudf/_lib/pylibcudf/scalar.pyx
@@ -2,11 +2,16 @@

from cython cimport no_gc_clear
from libcpp.memory cimport unique_ptr
from libcpp.utility cimport move

from rmm._lib.memory_resource cimport get_current_device_resource

from cudf._lib.pylibcudf.libcudf.scalar.scalar cimport scalar
from cudf._lib.pylibcudf.libcudf.scalar.scalar_factories cimport (
    make_empty_scalar_like,
)

from .column cimport Column
from .types cimport DataType


@@ -46,6 +51,21 @@ cdef class Scalar:
        """True if the scalar is valid, false if not"""
        return self.get().is_valid()

    @staticmethod
    cdef Scalar empty_like(Column column):
        """Construct a null scalar with the same type as column.

        Parameters
        ----------
        column
            Column to take type from

        Returns
        -------
        New empty (null) scalar of the given type.
        """
        return Scalar.from_libcudf(move(make_empty_scalar_like(column.view())))

    @staticmethod
    cdef Scalar from_libcudf(unique_ptr[scalar] libcudf_scalar, dtype=None):
        """Construct a Scalar object from a libcudf scalar.
39 changes: 39 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/ir.py
@@ -1101,9 +1101,48 @@ class HConcat(IR):
    dfs: list[IR]
    """List of inputs."""

    @staticmethod
    def _extend_with_nulls(table: plc.Table, *, nrows: int) -> plc.Table:
        """
        Extend a table with nulls.

        Parameters
        ----------
        table
            Table to extend
        nrows
            Number of additional rows

        Returns
        -------
        New pylibcudf table.
        """
        return plc.concatenate.concatenate(
            [
                table,
                plc.Table(
                    [
                        plc.Column.all_null_like(column, nrows)
                        for column in table.columns()
                    ]
                ),
            ]
        )

    def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
        """Evaluate and return a dataframe."""
        dfs = [df.evaluate(cache=cache) for df in self.dfs]
        max_rows = max(df.num_rows for df in dfs)
        # Horizontal concatenation extends shorter tables with nulls
        dfs = [
            df
            if df.num_rows == max_rows
            else DataFrame.from_table(
                self._extend_with_nulls(df.table, nrows=max_rows - df.num_rows),
                df.column_names,
            )
            for df in dfs
        ]
        return DataFrame(
            list(itertools.chain.from_iterable(df.columns for df in dfs)),
        )
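The padding helper can also be exercised in isolation; a sketch assuming the same pyarrow interop path and a `num_rows()` accessor on `plc.Table`, neither of which is part of this diff:

```python
import pyarrow as pa

import cudf._lib.pylibcudf as plc
from cudf_polars.dsl.ir import HConcat

# A two-row table (values illustrative), padded with two all-null rows.
short = plc.interop.from_arrow(pa.table({"b": [10, 20]}))
padded = HConcat._extend_with_nulls(short, nrows=2)
assert padded.num_rows() == 4  # original rows plus the null padding
```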
3 changes: 2 additions & 1 deletion python/cudf_polars/cudf_polars/utils/dtypes.py
@@ -153,7 +153,8 @@ def from_polars(dtype: pl.DataType) -> plc.DataType:
        # TODO: Hopefully
        return plc.DataType(plc.TypeId.EMPTY)
    elif isinstance(dtype, pl.List):
        # TODO: This doesn't consider the value type.
        # Recurse to catch unsupported inner types
        _ = from_polars(dtype.inner)
        return plc.DataType(plc.TypeId.LIST)
    else:
        raise NotImplementedError(f"{dtype=} conversion not supported")
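A sketch of what the recursion buys: a list whose inner type the translator rejects now fails at translation time instead of silently mapping to a bare LIST. The timezone-aware inner type is taken from the test case added below; treating it as unsupported is an assumption based on that test:

```python
import polars as pl

import cudf._lib.pylibcudf as plc
from cudf_polars.utils.dtypes import from_polars

# Supported inner type: translates to a LIST data type.
assert from_polars(pl.List(pl.Int8())).id() == plc.TypeId.LIST

# Unsupported inner type (timezone-aware datetime, per the new test case):
# the recursion into `dtype.inner` now raises instead of returning LIST.
try:
    from_polars(pl.List(pl.Datetime("ms", time_zone="US/Pacific")))
except NotImplementedError as err:
    print(f"rejected as expected: {err}")
```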
9 changes: 9 additions & 0 deletions python/cudf_polars/tests/test_hconcat.py
@@ -17,3 +17,12 @@ def test_hconcat():
    ldf2 = ldf.select((pl.col("a") + pl.col("b")).alias("c"))
    query = pl.concat([ldf, ldf2], how="horizontal")
    assert_gpu_result_equal(query)


def test_hconcat_different_heights():
    left = pl.LazyFrame({"a": [1, 2, 3, 4]})

    right = pl.LazyFrame({"b": [[1], [2]], "c": ["a", "bcde"]})

    q = pl.concat([left, right], how="horizontal")
    assert_gpu_result_equal(q)
93 changes: 47 additions & 46 deletions python/cudf_polars/tests/test_join.py
@@ -12,65 +12,68 @@
)


@pytest.mark.parametrize(
    "how",
    [
        "inner",
        "left",
        "semi",
        "anti",
        "full",
    ],
)
@pytest.mark.parametrize("coalesce", [False, True])
@pytest.mark.parametrize(
    "join_nulls", [False, True], ids=["nulls_not_equal", "nulls_equal"]
)
@pytest.mark.parametrize(
    "join_expr",
    [
        pl.col("a"),
        pl.col("a") * 2,
        [pl.col("a"), pl.col("c") + 1],
        ["c", "a"],
    ],
)
def test_join(how, coalesce, join_nulls, join_expr):
    left = pl.DataFrame(
@pytest.fixture(params=[False, True], ids=["nulls_not_equal", "nulls_equal"])
def join_nulls(request):
    return request.param


@pytest.fixture(params=["inner", "left", "semi", "anti", "full"])
def how(request):
    return request.param


@pytest.fixture
def left():
    return pl.LazyFrame(
        {
            "a": [1, 2, 3, 1, None],
            "b": [1, 2, 3, 4, 5],
            "c": [2, 3, 4, 5, 6],
        }
    ).lazy()
    right = pl.DataFrame(
    )


@pytest.fixture
def right():
    return pl.LazyFrame(
        {
            "a": [1, 4, 3, 7, None, None],
            "c": [2, 3, 4, 5, 6, 7],
        }
    ).lazy()
    )


@pytest.mark.parametrize(
    "join_expr",
    [
        pl.col("a"),
        pl.col("a") * 2,
        [pl.col("a"), pl.col("c") + 1],
        ["c", "a"],
    ],
)
def test_non_coalesce_join(left, right, how, join_nulls, join_expr):
    query = left.join(
        right, on=join_expr, how=how, join_nulls=join_nulls, coalesce=coalesce
        right, on=join_expr, how=how, join_nulls=join_nulls, coalesce=False
    )
    assert_gpu_result_equal(query, check_row_order=how == "left")


def test_cross_join():
    left = pl.DataFrame(
        {
            "a": [1, 2, 3, 1, None],
            "b": [1, 2, 3, 4, 5],
            "c": [2, 3, 4, 5, 6],
        }
    ).lazy()
    right = pl.DataFrame(
        {
            "a": [1, 4, 3, 7, None, None],
            "c": [2, 3, 4, 5, 6, 7],
        }
    ).lazy()
@pytest.mark.parametrize(
    "join_expr",
    [
        pl.col("a"),
        ["c", "a"],
    ],
)
def test_coalesce_join(left, right, how, join_nulls, join_expr):
    query = left.join(
        right, on=join_expr, how=how, join_nulls=join_nulls, coalesce=True
    )
    assert_gpu_result_equal(query, check_row_order=False)


def test_cross_join(left, right):
    q = left.join(right, how="cross")

    assert_gpu_result_equal(q)
@@ -79,9 +82,7 @@ def test_cross_join():
@pytest.mark.parametrize(
"left_on,right_on", [(pl.col("a"), pl.lit(2)), (pl.lit(2), pl.col("a"))]
)
def test_join_literal_key_unsupported(left_on, right_on):
    left = pl.LazyFrame({"a": [1, 2, 3], "b": [3, 4, 5]})
    right = pl.LazyFrame({"a": [1, 2, 3], "b": [5, 6, 7]})
def test_join_literal_key_unsupported(left, right, left_on, right_on):
    q = left.join(right, left_on=left_on, right_on=right_on, how="inner")

    assert_ir_translation_raises(q, NotImplementedError)
1 change: 1 addition & 0 deletions python/cudf_polars/tests/utils/test_dtypes.py
@@ -16,6 +16,7 @@
        pl.Time(),
        pl.Struct({"a": pl.Int8, "b": pl.Float32}),
        pl.Datetime("ms", time_zone="US/Pacific"),
        pl.List(pl.Datetime("ms", time_zone="US/Pacific")),
        pl.Array(pl.Int8, 2),
        pl.Binary(),
        pl.Categorical(),
