From 1f22fdfbb96b452b25d8c16a502997d39a8be5eb Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Mon, 10 Jun 2024 10:57:15 +0000 Subject: [PATCH] Tests of broadcasting --- python/cudf_polars/tests/utils/__init__.py | 6 ++ .../cudf_polars/tests/utils/test_broadcast.py | 74 +++++++++++++++++++ 2 files changed, 80 insertions(+) create mode 100644 python/cudf_polars/tests/utils/__init__.py create mode 100644 python/cudf_polars/tests/utils/test_broadcast.py diff --git a/python/cudf_polars/tests/utils/__init__.py b/python/cudf_polars/tests/utils/__init__.py new file mode 100644 index 00000000000..4611d642f14 --- /dev/null +++ b/python/cudf_polars/tests/utils/__init__.py @@ -0,0 +1,6 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +__all__: list[str] = [] diff --git a/python/cudf_polars/tests/utils/test_broadcast.py b/python/cudf_polars/tests/utils/test_broadcast.py new file mode 100644 index 00000000000..69ad1e519e2 --- /dev/null +++ b/python/cudf_polars/tests/utils/test_broadcast.py @@ -0,0 +1,74 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. 
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import pytest

import cudf._lib.pylibcudf as plc

from cudf_polars.containers import NamedColumn
from cudf_polars.dsl.ir import broadcast


def _make_columns(lengths):
    """Build one INT8 ``NamedColumn`` per entry of ``lengths``.

    Column ``k`` has ``lengths[k]`` rows, is created with
    ``MaskState.ALL_VALID`` and is named ``col{k}``.
    """
    return [
        NamedColumn(
            plc.column_factories.make_numeric_column(
                plc.DataType(plc.TypeId.INT8), size, plc.MaskState.ALL_VALID
            ),
            f"col{index}",
        )
        for index, size in enumerate(lengths)
    ]


@pytest.mark.parametrize("target", [4, None])
def test_broadcast_all_scalar(target):
    """Three length-1 columns come back with ``target`` rows (1 if None)."""
    columns = _make_columns([1, 1, 1])

    result = broadcast(*columns, target_length=target)

    expected = target if target is not None else 1
    assert all(col.obj.size() == expected for col in result)


def test_invalid_target_length():
    """Requesting 8 rows from three 4-row columns raises RuntimeError."""
    columns = _make_columns([4, 4, 4])

    with pytest.raises(RuntimeError):
        broadcast(*columns, target_length=8)


def test_broadcast_mismatching_column_lengths():
    """Columns of lengths 1, 2 and 3 with no target raise RuntimeError."""
    columns = _make_columns([1, 2, 3])

    with pytest.raises(RuntimeError):
        broadcast(*columns)


@pytest.mark.parametrize("nrows", [0, 5])
def test_broadcast_with_scalars(nrows):
    """One ``nrows``-long column plus two length-1 columns all end at ``nrows``."""
    # Only the first column carries the full length; the others have one row.
    columns = _make_columns([nrows, 1, 1])

    result = broadcast(*columns)

    assert all(col.obj.size() == nrows for col in result)