-
Notifications
You must be signed in to change notification settings - Fork 915
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
68 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from __future__ import annotations | ||
|
||
import pickle | ||
|
||
import pytest | ||
|
||
from polars.polars import _expr_nodes as pl_expr | ||
|
||
from cudf_polars.dsl.expressions.boolean import BooleanFunction | ||
from cudf_polars.dsl.expressions.datetime import TemporalFunction | ||
from cudf_polars.dsl.expressions.string import StringFunction | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"function", [BooleanFunction, TemporalFunction, StringFunction] | ||
) | ||
def test_function_name_serialization_all_values(function): | ||
# Test serialization and deserialization for all values of function.Name | ||
for name in function.Name: | ||
serialized_name = pickle.dumps(name) | ||
deserialized_name = pickle.loads(serialized_name) | ||
assert deserialized_name is name | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"function", [BooleanFunction, TemporalFunction, StringFunction] | ||
) | ||
def test_function_name_invalid(function): | ||
# Test invalid attribute name | ||
with pytest.raises( | ||
AttributeError, match="type object 'Name' has no attribute 'InvalidAttribute'" | ||
): | ||
assert function.Name.InvalidAttribute is function.Name.InvalidAttribute | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"function", [BooleanFunction, TemporalFunction, StringFunction] | ||
) | ||
def test_from_polars_all_names(function): | ||
# Test that all valid names of polars expressions are correctly converted | ||
for name in function.Name: | ||
polars_function = getattr(pl_expr, function.__name__) | ||
polars_function_attr = getattr(polars_function, name.name) | ||
cudf_function = function.Name.from_polars(polars_function_attr) | ||
assert cudf_function == name | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"function", [BooleanFunction, TemporalFunction, StringFunction] | ||
) | ||
def test_from_polars_invalid_attribute(function): | ||
# Test converting from invalid attribute name | ||
with pytest.raises(ValueError, match=f"{function.__name__} required"): | ||
function.Name.from_polars("InvalidAttribute") | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"function", [BooleanFunction, TemporalFunction, StringFunction] | ||
) | ||
def test_from_polars_invalid_polars_attribute(function): | ||
# Test converting from polars function with invalid attribute name | ||
with pytest.raises( | ||
AttributeError, match="type object 'Name' has no attribute 'InvalidAttribute'" | ||
): | ||
function.Name.from_polars(f"{function.__name__}.InvalidAttribute") |