Skip to content

Commit

Permalink
Add full coverage of utility functions (#15995)
Browse files Browse the repository at this point in the history
The datetime conversion tests just test that we can round-trip correctly for now.

Authors:
  - Lawrence Mitchell (https://github.com/wence-)

Approvers:
  - Thomas Li (https://github.com/lithomas1)
  - Kyle Edwards (https://github.com/KyleFromNVIDIA)

URL: #15995
  • Loading branch information
wence- authored Jun 24, 2024
1 parent 525ca7e commit 4d4cdce
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 4 deletions.
4 changes: 2 additions & 2 deletions python/cudf_polars/cudf_polars/utils/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def from_polars(dtype: pl.DataType) -> plc.DataType:
return plc.DataType(plc.TypeId.TIMESTAMP_MICROSECONDS)
elif dtype.time_unit == "ns":
return plc.DataType(plc.TypeId.TIMESTAMP_NANOSECONDS)
assert dtype.time_unit is not None
assert dtype.time_unit is not None # pragma: no cover
assert_never(dtype.time_unit)
elif isinstance(dtype, pl.Duration):
if dtype.time_unit == "ms":
Expand All @@ -79,7 +79,7 @@ def from_polars(dtype: pl.DataType) -> plc.DataType:
return plc.DataType(plc.TypeId.DURATION_MICROSECONDS)
elif dtype.time_unit == "ns":
return plc.DataType(plc.TypeId.DURATION_NANOSECONDS)
assert dtype.time_unit is not None
assert dtype.time_unit is not None # pragma: no cover
assert_never(dtype.time_unit)
elif isinstance(dtype, pl.String):
return plc.DataType(plc.TypeId.STRING)
Expand Down
4 changes: 2 additions & 2 deletions python/cudf_polars/cudf_polars/utils/sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ def sort_order(
for d in descending
]
null_precedence = []
# TODO: use strict=True when we drop py39
assert len(descending) == len(nulls_last)
if len(descending) != len(nulls_last) or len(descending) != num_keys:
raise ValueError("Mismatching length of arguments in sort_order")
for asc, null_last in zip(column_order, nulls_last):
if (asc == plc.types.Order.ASCENDING) ^ (not null_last):
null_precedence.append(plc.types.NullOrder.AFTER)
Expand Down
7 changes: 7 additions & 0 deletions python/cudf_polars/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ version = {file = "cudf_polars/VERSION"}
[tool.pytest.ini_options]
xfail_strict = true

[tool.coverage.report]
exclude_also = [
"if TYPE_CHECKING:",
"class .*\\bProtocol\\):",
"assert_never\\("
]

[tool.ruff]
line-length = 88
indent-width = 4
Expand Down
34 changes: 34 additions & 0 deletions python/cudf_polars/tests/expressions/test_datetime_basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import pytest

import polars as pl

from cudf_polars.testing.asserts import assert_gpu_result_equal


@pytest.mark.parametrize(
"dtype",
[
pl.Date(),
pl.Datetime("ms"),
pl.Datetime("us"),
pl.Datetime("ns"),
pl.Duration("ms"),
pl.Duration("us"),
pl.Duration("ns"),
],
ids=repr,
)
def test_datetime_dataframe_scan(dtype):
ldf = pl.DataFrame(
{
"a": pl.Series([1, 2, 3, 4, 5, 6, 7], dtype=dtype),
"b": pl.Series([3, 4, 5, 6, 7, 8, 9], dtype=pl.UInt16),
}
).lazy()

query = ldf.select(pl.col("b"), pl.col("a"))
assert_gpu_result_equal(query)
31 changes: 31 additions & 0 deletions python/cudf_polars/tests/utils/test_dtypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import pytest

import polars as pl

from cudf_polars.utils.dtypes import from_polars


@pytest.mark.parametrize(
"pltype",
[
pl.Time(),
pl.Struct({"a": pl.Int8, "b": pl.Float32}),
pl.Datetime("ms", time_zone="US/Pacific"),
pl.Array(pl.Int8, 2),
pl.Binary(),
pl.Categorical(),
pl.Enum(["a", "b"]),
pl.Field("a", pl.Int8),
pl.Object(),
pl.Unknown(),
],
ids=repr,
)
def test_unhandled_dtype_conversion_raises(pltype):
with pytest.raises(NotImplementedError):
_ = from_polars(pltype)
21 changes: 21 additions & 0 deletions python/cudf_polars/tests/utils/test_sorting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import pytest

from cudf_polars.utils.sorting import sort_order


@pytest.mark.parametrize(
"descending,nulls_last,num_keys",
[
([True], [False, True], 3),
([True, True], [False, True, False], 3),
([False, True], [True], 3),
],
)
def test_sort_order_raises_mismatch(descending, nulls_last, num_keys):
with pytest.raises(ValueError):
_ = sort_order(descending, nulls_last=nulls_last, num_keys=num_keys)

0 comments on commit 4d4cdce

Please sign in to comment.