From 5a04207cd657cb7782ab92dff620309e06d73199 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Mon, 25 Nov 2024 01:20:31 -0800 Subject: [PATCH] Add basic serialization tests --- .../tests/dsl/test_serialization.py | 68 +++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 python/cudf_polars/tests/dsl/test_serialization.py diff --git a/python/cudf_polars/tests/dsl/test_serialization.py b/python/cudf_polars/tests/dsl/test_serialization.py new file mode 100644 index 00000000000..f2b9480880f --- /dev/null +++ b/python/cudf_polars/tests/dsl/test_serialization.py @@ -0,0 +1,68 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pickle + +import pytest + +from polars.polars import _expr_nodes as pl_expr + +from cudf_polars.dsl.expressions.boolean import BooleanFunction +from cudf_polars.dsl.expressions.datetime import TemporalFunction +from cudf_polars.dsl.expressions.string import StringFunction + + +@pytest.mark.parametrize( + "function", [BooleanFunction, TemporalFunction, StringFunction] +) +def test_function_name_serialization_all_values(function): + # Test serialization and deserialization for all values of function.Name + for name in function.Name: + serialized_name = pickle.dumps(name) + deserialized_name = pickle.loads(serialized_name) + assert deserialized_name is name + + +@pytest.mark.parametrize( + "function", [BooleanFunction, TemporalFunction, StringFunction] +) +def test_function_name_invalid(function): + # Test invalid attribute name + with pytest.raises( + AttributeError, match="type object 'Name' has no attribute 'InvalidAttribute'" + ): + assert function.Name.InvalidAttribute is function.Name.InvalidAttribute + + +@pytest.mark.parametrize( + "function", [BooleanFunction, TemporalFunction, StringFunction] +) +def test_from_polars_all_names(function): + # Test that all valid names of polars expressions are correctly converted + for name in function.Name: + polars_function = getattr(pl_expr, function.__name__) + polars_function_attr = getattr(polars_function, name.name) + cudf_function = function.Name.from_polars(polars_function_attr) + assert cudf_function == name + + +@pytest.mark.parametrize( + "function", [BooleanFunction, TemporalFunction, StringFunction] +) +def test_from_polars_invalid_attribute(function): + # Test converting from invalid attribute name + with pytest.raises(ValueError, match=f"{function.__name__} required"): + function.Name.from_polars("InvalidAttribute") + + +@pytest.mark.parametrize( + "function", [BooleanFunction, TemporalFunction, StringFunction] +) +def test_from_polars_invalid_polars_attribute(function): + # Test converting from polars function with invalid attribute name + with pytest.raises( + AttributeError, match="type object 'Name' has no attribute 'InvalidAttribute'" + ): + function.Name.from_polars(f"{function.__name__}.InvalidAttribute")