Skip to content

Commit

Permalink
Add basic serialization tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pentschev committed Nov 25, 2024
1 parent 9b54437 commit 5a04207
Showing 1 changed file with 68 additions and 0 deletions.
68 changes: 68 additions & 0 deletions python/cudf_polars/tests/dsl/test_serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import pickle

import pytest

from polars.polars import _expr_nodes as pl_expr

from cudf_polars.dsl.expressions.boolean import BooleanFunction
from cudf_polars.dsl.expressions.datetime import TemporalFunction
from cudf_polars.dsl.expressions.string import StringFunction


@pytest.mark.parametrize(
"function", [BooleanFunction, TemporalFunction, StringFunction]
)
def test_function_name_serialization_all_values(function):
# Test serialization and deserialization for all values of function.Name
for name in function.Name:
serialized_name = pickle.dumps(name)
deserialized_name = pickle.loads(serialized_name)
assert deserialized_name is name


@pytest.mark.parametrize(
"function", [BooleanFunction, TemporalFunction, StringFunction]
)
def test_function_name_invalid(function):
# Test invalid attribute name
with pytest.raises(
AttributeError, match="type object 'Name' has no attribute 'InvalidAttribute'"
):
assert function.Name.InvalidAttribute is function.Name.InvalidAttribute


@pytest.mark.parametrize(
"function", [BooleanFunction, TemporalFunction, StringFunction]
)
def test_from_polars_all_names(function):
# Test that all valid names of polars expressions are correctly converted
for name in function.Name:
polars_function = getattr(pl_expr, function.__name__)
polars_function_attr = getattr(polars_function, name.name)
cudf_function = function.Name.from_polars(polars_function_attr)
assert cudf_function == name


@pytest.mark.parametrize(
"function", [BooleanFunction, TemporalFunction, StringFunction]
)
def test_from_polars_invalid_attribute(function):
# Test converting from invalid attribute name
with pytest.raises(ValueError, match=f"{function.__name__} required"):
function.Name.from_polars("InvalidAttribute")


@pytest.mark.parametrize(
"function", [BooleanFunction, TemporalFunction, StringFunction]
)
def test_from_polars_invalid_polars_attribute(function):
# Test converting from polars function with invalid attribute name
with pytest.raises(
AttributeError, match="type object 'Name' has no attribute 'InvalidAttribute'"
):
function.Name.from_polars(f"{function.__name__}.InvalidAttribute")

0 comments on commit 5a04207

Please sign in to comment.