From 8e81de7877bf2ba206b73e80c4e54a9504f86e5f Mon Sep 17 00:00:00 2001
From: Lawrence Mitchell
Date: Wed, 12 Jun 2024 13:32:41 +0000
Subject: [PATCH] Add full coverage of utility functions

The datetime conversion tests just test that we can round-trip
correctly for now.
---
Reviewer notes (commentary only; text between the three-dash marker and
the first file diff is neither commit message nor patch content):
* utils/dtypes.py: the two "assert dtype.time_unit is not None" lines that
  precede assert_never() gain a "# pragma: no cover" marker.
* utils/sorting.py: the length assert in sort_order becomes a ValueError
  that also compares len(descending) against num_keys.
* pyproject.toml: adds a [tool.coverage.report] exclude_also list.
* tests: new parametrized tests for a date/time dataframe scan, for
  from_polars raising NotImplementedError, and for sort_order raising
  ValueError on mismatched argument lengths.

 .../cudf_polars/cudf_polars/utils/dtypes.py   |  4 +--
 .../cudf_polars/cudf_polars/utils/sorting.py  |  4 +--
 python/cudf_polars/pyproject.toml             |  7 ++++
 .../tests/expressions/test_datetime_basic.py  | 34 +++++++++++++++++++
 python/cudf_polars/tests/utils/test_dtypes.py | 31 +++++++++++++++++
 .../cudf_polars/tests/utils/test_sorting.py   | 21 ++++++++++++
 6 files changed, 97 insertions(+), 4 deletions(-)
 create mode 100644 python/cudf_polars/tests/expressions/test_datetime_basic.py
 create mode 100644 python/cudf_polars/tests/utils/test_dtypes.py
 create mode 100644 python/cudf_polars/tests/utils/test_sorting.py

diff --git a/python/cudf_polars/cudf_polars/utils/dtypes.py b/python/cudf_polars/cudf_polars/utils/dtypes.py
index 7b0049daf11..3d4a643e1fc 100644
--- a/python/cudf_polars/cudf_polars/utils/dtypes.py
+++ b/python/cudf_polars/cudf_polars/utils/dtypes.py
@@ -70,7 +70,7 @@ def from_polars(dtype: pl.DataType) -> plc.DataType:
             return plc.DataType(plc.TypeId.TIMESTAMP_MICROSECONDS)
         elif dtype.time_unit == "ns":
             return plc.DataType(plc.TypeId.TIMESTAMP_NANOSECONDS)
-        assert dtype.time_unit is not None
+        assert dtype.time_unit is not None  # pragma: no cover
         assert_never(dtype.time_unit)
     elif isinstance(dtype, pl.Duration):
         if dtype.time_unit == "ms":
@@ -79,7 +79,7 @@ def from_polars(dtype: pl.DataType) -> plc.DataType:
             return plc.DataType(plc.TypeId.DURATION_MICROSECONDS)
         elif dtype.time_unit == "ns":
             return plc.DataType(plc.TypeId.DURATION_NANOSECONDS)
-        assert dtype.time_unit is not None
+        assert dtype.time_unit is not None  # pragma: no cover
         assert_never(dtype.time_unit)
     elif isinstance(dtype, pl.String):
         return plc.DataType(plc.TypeId.STRING)
diff --git a/python/cudf_polars/cudf_polars/utils/sorting.py b/python/cudf_polars/cudf_polars/utils/sorting.py
index 24fd449dd88..57f94c4ec4c 100644
--- a/python/cudf_polars/cudf_polars/utils/sorting.py
+++ b/python/cudf_polars/cudf_polars/utils/sorting.py
@@ -43,8 +43,8 @@ def sort_order(
         for d in descending
     ]
     null_precedence = []
-    # TODO: use strict=True when we drop py39
-    assert len(descending) == len(nulls_last)
+    if len(descending) != len(nulls_last) or len(descending) != num_keys:
+        raise ValueError("Mismatching length of arguments in sort_order")
     for asc, null_last in zip(column_order, nulls_last):
         if (asc == plc.types.Order.ASCENDING) ^ (not null_last):
             null_precedence.append(plc.types.NullOrder.AFTER)
diff --git a/python/cudf_polars/pyproject.toml b/python/cudf_polars/pyproject.toml
index 11178a3be74..d88f5b6a1e3 100644
--- a/python/cudf_polars/pyproject.toml
+++ b/python/cudf_polars/pyproject.toml
@@ -52,6 +52,13 @@ version = {file = "cudf_polars/VERSION"}
 [tool.pytest.ini_options]
 xfail_strict = true
 
+[tool.coverage.report]
+exclude_also = [
+    "if TYPE_CHECKING:",
+    "class .*\\bProtocol\\):",
+    "assert_never\\("
+]
+
 [tool.ruff]
 line-length = 88
 indent-width = 4
diff --git a/python/cudf_polars/tests/expressions/test_datetime_basic.py b/python/cudf_polars/tests/expressions/test_datetime_basic.py
new file mode 100644
index 00000000000..6ba2a1dce1e
--- /dev/null
+++ b/python/cudf_polars/tests/expressions/test_datetime_basic.py
@@ -0,0 +1,34 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
+
+import pytest
+
+import polars as pl
+
+from cudf_polars.testing.asserts import assert_gpu_result_equal
+
+
+@pytest.mark.parametrize(
+    "dtype",
+    [
+        pl.Date(),
+        pl.Datetime("ms"),
+        pl.Datetime("us"),
+        pl.Datetime("ns"),
+        pl.Duration("ms"),
+        pl.Duration("us"),
+        pl.Duration("ns"),
+    ],
+    ids=repr,
+)
+def test_datetime_dataframe_scan(dtype):
+    ldf = pl.DataFrame(
+        {
+            "a": pl.Series([1, 2, 3, 4, 5, 6, 7], dtype=dtype),
+            "b": pl.Series([3, 4, 5, 6, 7, 8, 9], dtype=pl.UInt16),
+        }
+    ).lazy()
+
+    query = ldf.select(pl.col("b"), pl.col("a"))
+    assert_gpu_result_equal(query)
diff --git a/python/cudf_polars/tests/utils/test_dtypes.py b/python/cudf_polars/tests/utils/test_dtypes.py
new file mode 100644
index 00000000000..535fdd846a0
--- /dev/null
+++ b/python/cudf_polars/tests/utils/test_dtypes.py
@@ -0,0 +1,31 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import pytest
+
+import polars as pl
+
+from cudf_polars.utils.dtypes import from_polars
+
+
+@pytest.mark.parametrize(
+    "pltype",
+    [
+        pl.Time(),
+        pl.Struct({"a": pl.Int8, "b": pl.Float32}),
+        pl.Datetime("ms", time_zone="US/Pacific"),
+        pl.Array(pl.Int8, 2),
+        pl.Binary(),
+        pl.Categorical(),
+        pl.Enum(["a", "b"]),
+        pl.Field("a", pl.Int8),
+        pl.Object(),
+        pl.Unknown(),
+    ],
+    ids=repr,
+)
+def test_unhandled_dtype_conversion_raises(pltype):
+    with pytest.raises(NotImplementedError):
+        _ = from_polars(pltype)
diff --git a/python/cudf_polars/tests/utils/test_sorting.py b/python/cudf_polars/tests/utils/test_sorting.py
new file mode 100644
index 00000000000..4e98a3a7ce7
--- /dev/null
+++ b/python/cudf_polars/tests/utils/test_sorting.py
@@ -0,0 +1,21 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import pytest
+
+from cudf_polars.utils.sorting import sort_order
+
+
+@pytest.mark.parametrize(
+    "descending,nulls_last,num_keys",
+    [
+        ([True], [False, True], 3),
+        ([True, True], [False, True, False], 3),
+        ([False, True], [True], 3),
+    ],
+)
+def test_sort_order_raises_mismatch(descending, nulls_last, num_keys):
+    with pytest.raises(ValueError):
+        _ = sort_order(descending, nulls_last=nulls_last, num_keys=num_keys)