Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Abstract polars function expression nodes to ensure they are serializable #17418

Merged
merged 14 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion python/cudf_polars/cudf_polars/dsl/expressions/boolean.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,11 @@ class Name(IntEnum):
@classmethod
def from_polars(cls, obj: pl_expr.BooleanFunction) -> Self:
"""Convert from polars' `BooleanFunction`."""
function, name = str(obj).split(".", maxsplit=1)
try:
function, name = str(obj).split(".", maxsplit=1)
except ValueError:
# Failed to unpack string
function = None
if function != "BooleanFunction":
raise ValueError("BooleanFunction required")
return getattr(cls, name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,11 @@ class Name(IntEnum):
@classmethod
def from_polars(cls, obj: pl_expr.TemporalFunction) -> Self:
"""Convert from polars' `TemporalFunction`."""
function, name = str(obj).split(".", maxsplit=1)
try:
function, name = str(obj).split(".", maxsplit=1)
except ValueError:
# Failed to unpack string
function = None
if function != "TemporalFunction":
raise ValueError("TemporalFunction required")
return getattr(cls, name)
Expand Down
6 changes: 5 additions & 1 deletion python/cudf_polars/cudf_polars/dsl/expressions/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,11 @@ class Name(IntEnum):
@classmethod
def from_polars(cls, obj: pl_expr.StringFunction) -> Self:
"""Convert from polars' `StringFunction`."""
function, name = str(obj).split(".", maxsplit=1)
try:
function, name = str(obj).split(".", maxsplit=1)
except ValueError:
# Failed to unpack string
function = None
if function != "StringFunction":
raise ValueError("StringFunction required")
return getattr(cls, name)
Expand Down
68 changes: 68 additions & 0 deletions python/cudf_polars/tests/dsl/test_serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import pickle

import pytest

from polars.polars import _expr_nodes as pl_expr

from cudf_polars.dsl.expressions.boolean import BooleanFunction
from cudf_polars.dsl.expressions.datetime import TemporalFunction
from cudf_polars.dsl.expressions.string import StringFunction


@pytest.mark.parametrize(
"function", [BooleanFunction, TemporalFunction, StringFunction]
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's abstract this into a reusable fixture:

@pytest.fixture(params=["BooleanFunction", "StringFunction", "TemporalFunction"])
def function(request):
    return request.param

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, done in fcf820f . I've kept the module references instead of using strings and resolved them in the tests with __name__, let me know if you have a strong preference for strings instead.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, that's fine.

def test_function_name_serialization_all_values(function):
# Test serialization and deserialization for all values of function.Name
for name in function.Name:
serialized_name = pickle.dumps(name)
deserialized_name = pickle.loads(serialized_name)
assert deserialized_name is name


@pytest.mark.parametrize(
"function", [BooleanFunction, TemporalFunction, StringFunction]
)
def test_function_name_invalid(function):
# Test invalid attribute name
with pytest.raises(
AttributeError, match="type object 'Name' has no attribute 'InvalidAttribute'"
):
assert function.Name.InvalidAttribute is function.Name.InvalidAttribute


@pytest.mark.parametrize(
"function", [BooleanFunction, TemporalFunction, StringFunction]
)
def test_from_polars_all_names(function):
# Test that all valid names of polars expressions are correctly converted
for name in function.Name:
polars_function = getattr(pl_expr, function.__name__)
polars_function_attr = getattr(polars_function, name.name)
cudf_function = function.Name.from_polars(polars_function_attr)
assert cudf_function == name
wence- marked this conversation as resolved.
Show resolved Hide resolved


@pytest.mark.parametrize(
"function", [BooleanFunction, TemporalFunction, StringFunction]
)
def test_from_polars_invalid_attribute(function):
# Test converting from invalid attribute name
with pytest.raises(ValueError, match=f"{function.__name__} required"):
function.Name.from_polars("InvalidAttribute")


@pytest.mark.parametrize(
"function", [BooleanFunction, TemporalFunction, StringFunction]
)
def test_from_polars_invalid_polars_attribute(function):
# Test converting from polars function with invalid attribute name
with pytest.raises(
AttributeError, match="type object 'Name' has no attribute 'InvalidAttribute'"
):
function.Name.from_polars(f"{function.__name__}.InvalidAttribute")
wence- marked this conversation as resolved.
Show resolved Hide resolved
Loading