Skip to content

Commit

Permalink
Tests of broadcasting
Browse files Browse the repository at this point in the history
  • Loading branch information
wence- committed Jun 10, 2024
1 parent 5709719 commit 1f22fdf
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 0 deletions.
6 changes: 6 additions & 0 deletions python/cudf_polars/tests/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

# This package exports no public names; it only groups the tests in this directory.
__all__: list[str] = []
74 changes: 74 additions & 0 deletions python/cudf_polars/tests/utils/test_broadcast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import pytest

import cudf._lib.pylibcudf as plc

from cudf_polars.containers import NamedColumn
from cudf_polars.dsl.ir import broadcast


@pytest.mark.parametrize("target", [4, None])
def test_broadcast_all_scalar(target):
    """Broadcast three length-1 INT8 columns, with and without a target length.

    When ``target`` is ``None`` every output column is expected to keep
    length 1; otherwise every output column is expected to have ``target``
    rows.
    """
    columns = [
        NamedColumn(
            plc.column_factories.make_numeric_column(
                plc.DataType(plc.TypeId.INT8), 1, plc.MaskState.ALL_VALID
            ),
            f"col{i}",
        )
        for i in range(3)
    ]
    result = broadcast(*columns, target_length=target)
    expected = 1 if target is None else target

    # Compare the full list of sizes rather than using all(): all() over an
    # empty result is vacuously true, so a broadcast that dropped columns
    # would otherwise go unnoticed.
    sizes = [column.obj.size() for column in result]
    assert sizes == [expected] * len(columns)


def test_invalid_target_length():
    """Broadcasting length-4 columns to ``target_length=8`` must raise RuntimeError."""
    columns = []
    for i in range(3):
        data = plc.column_factories.make_numeric_column(
            plc.DataType(plc.TypeId.INT8), 4, plc.MaskState.ALL_VALID
        )
        columns.append(NamedColumn(data, f"col{i}"))

    # The call is expected to fail, so its return value is irrelevant.
    with pytest.raises(RuntimeError):
        broadcast(*columns, target_length=8)


def test_broadcast_mismatching_column_lengths():
    """Broadcasting columns of lengths 1, 2 and 3 together must raise RuntimeError."""
    columns = []
    for i in range(3):
        # Column i has i + 1 rows, so no two columns share a length.
        data = plc.column_factories.make_numeric_column(
            plc.DataType(plc.TypeId.INT8), i + 1, plc.MaskState.ALL_VALID
        )
        columns.append(NamedColumn(data, f"col{i}"))

    # The call is expected to fail, so its return value is irrelevant.
    with pytest.raises(RuntimeError):
        broadcast(*columns)


@pytest.mark.parametrize("nrows", [0, 5])
def test_broadcast_with_scalars(nrows):
    """Broadcast one ``nrows``-long column together with two length-1 columns.

    Every output column is expected to have ``nrows`` rows, including the
    ``nrows == 0`` case.
    """
    columns = [
        NamedColumn(
            plc.column_factories.make_numeric_column(
                plc.DataType(plc.TypeId.INT8),
                # Only the first column carries the full length; the others
                # are length-1 columns to be broadcast.
                nrows if i == 0 else 1,
                plc.MaskState.ALL_VALID,
            ),
            f"col{i}",
        )
        for i in range(3)
    ]

    result = broadcast(*columns)

    # Compare the full list of sizes rather than using all(): all() over an
    # empty result is vacuously true, so a broadcast that dropped columns
    # would otherwise go unnoticed.
    sizes = [column.obj.size() for column in result]
    assert sizes == [nrows] * len(columns)

0 comments on commit 1f22fdf

Please sign in to comment.