Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable intermediate proxies to be picklable #14752

Merged
merged 13 commits into from
Jan 20, 2024
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/_internals/where.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021-2023, NVIDIA CORPORATION.
# Copyright (c) 2021-2024, NVIDIA CORPORATION.

shwina marked this conversation as resolved.
Show resolved Hide resolved
import warnings
from typing import Tuple, Union
Expand Down
82 changes: 61 additions & 21 deletions python/cudf/cudf/pandas/fast_slow_proxy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

Expand All @@ -25,6 +25,11 @@

from .annotation import nvtx


def call_operator(fn, args, kwargs):
return fn(*args, **kwargs)


_CUDF_PANDAS_NVTX_COLORS = {
"COPY_SLOW_TO_FAST": 0xCA0020,
"COPY_FAST_TO_SLOW": 0xF4A582,
Expand Down Expand Up @@ -109,6 +114,7 @@ def make_final_proxy_type(
additional_attributes: Mapping[str, Any] | None = None,
postprocess: Callable[[_FinalProxy, Any, Any], Any] | None = None,
bases: Tuple = (),
picklable: bool = True,
) -> Type[_FinalProxy]:
shwina marked this conversation as resolved.
Show resolved Hide resolved
"""
Defines a fast-slow proxy type for a pair of "final" fast and slow
Expand Down Expand Up @@ -139,7 +145,8 @@ def make_final_proxy_type(
construct said unwrapped object. See also `_maybe_wrap_result`.
bases
Optional tuple of base classes to insert into the mro.

picklable: bool
Whether or not the proxy object should be picklable
Notes
-----
As a side-effect, this function adds `fast_type` and `slow_type`
Expand Down Expand Up @@ -189,22 +196,6 @@ def _fsproxy_state(self) -> _State:
else _State.SLOW
)

def __reduce__(self):
# Need a local import to avoid circular import issues
from .module_accelerator import disable_module_accelerator

with disable_module_accelerator():
pickled_wrapped_obj = pickle.dumps(self._fsproxy_wrapped)
return (_PickleConstructor(type(self)), (), pickled_wrapped_obj)

def __setstate__(self, state):
# Need a local import to avoid circular import issues
from .module_accelerator import disable_module_accelerator

with disable_module_accelerator():
unpickled_wrapped_obj = pickle.loads(state)
self._fsproxy_wrapped = unpickled_wrapped_obj

slow_dir = dir(slow_type)
cls_dict = {
"__init__": __init__,
Expand All @@ -215,9 +206,8 @@ def __setstate__(self, state):
"_fsproxy_slow_to_fast": _fsproxy_slow_to_fast,
"_fsproxy_fast_to_slow": _fsproxy_fast_to_slow,
"_fsproxy_state": _fsproxy_state,
"__reduce__": __reduce__,
"__setstate__": __setstate__,
}

if additional_attributes is None:
additional_attributes = {}
for method in _SPECIAL_METHODS:
Expand Down Expand Up @@ -257,6 +247,7 @@ def make_intermediate_proxy_type(
slow_type: type,
*,
module: Optional[str] = None,
picklable: bool = True,
) -> Type[_IntermediateProxy]:
shwina marked this conversation as resolved.
Show resolved Hide resolved
"""
Defines a proxy type for a pair of "intermediate" fast and slow
Expand All @@ -273,6 +264,8 @@ def make_intermediate_proxy_type(
The name of the class returned
fast_type: type
slow_type: type
picklable: bool
Whether or not the proxy object should be picklable
"""

def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -716,6 +709,26 @@ def _fsproxy_wrap(cls, value, func):
proxy._fsproxy_wrapped = value
return proxy

def __reduce__(self):
"""
In conjunction with `__proxy_setstate__`, this effectively enables
proxy types to be pickled and unpickled by pickling and unpickling
the underlying wrapped types.
"""
from .module_accelerator import disable_module_accelerator

with disable_module_accelerator():
pickled_wrapped_obj = pickle.dumps(self._fsproxy_wrapped)
return (_PickleConstructor(type(self)), (), pickled_wrapped_obj)

def __setstate__(self, state):
# Need a local import to avoid circular import issues
from .module_accelerator import disable_module_accelerator
shwina marked this conversation as resolved.
Show resolved Hide resolved

with disable_module_accelerator():
unpickled_wrapped_obj = pickle.loads(state)
self._fsproxy_wrapped = unpickled_wrapped_obj


class _IntermediateProxy(_FastSlowProxy):
"""
Expand Down Expand Up @@ -772,6 +785,33 @@ def _fsproxy_fast_to_slow(self) -> Any:
args, kwargs = _slow_arg(args), _slow_arg(kwargs)
return func(*args, **kwargs)

def __reduce__(self):
"""
In conjunction with `__proxy_setstate__`, this effectively enables
proxy types to be pickled and unpickled by pickling and unpickling
the underlying wrapped types.
"""
from .module_accelerator import disable_module_accelerator

with disable_module_accelerator():
pickled_wrapped_obj = pickle.dumps(self._fsproxy_wrapped)
pickled_method_chain = pickle.dumps(self._method_chain)
return (
_PickleConstructor(type(self)),
(),
(pickled_wrapped_obj, pickled_method_chain),
)

def __setstate__(self, state):
# Need a local import to avoid circular import issues
from .module_accelerator import disable_module_accelerator

with disable_module_accelerator():
unpickled_wrapped_obj = pickle.loads(state[0])
unpickled_method_chain = pickle.loads(state[1])
self._fsproxy_wrapped = unpickled_wrapped_obj
self._method_chain = unpickled_method_chain


class _CallableProxyMixin:
"""
Expand All @@ -788,7 +828,7 @@ def __call__(self, *args, **kwargs) -> Any:
# _fast_slow_function_call) to avoid infinite recursion.
# TODO: When Python 3.11 is the minimum supported Python version
# this can use operator.call
lambda fn, args, kwargs: fn(*args, **kwargs),
call_operator,
self,
vyasr marked this conversation as resolved.
Show resolved Hide resolved
args,
kwargs,
Expand Down
5 changes: 3 additions & 2 deletions python/cudf/cudf/pandas/module_accelerator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

Expand Down Expand Up @@ -551,6 +551,7 @@ def getattr_real_or_wrapped(
# release the lock after reading this value)
use_real = not loader._use_fast_lib
if not use_real:
CUDF_PANDAS_PATH = __file__.rsplit("/", 1)[0]
# Only need to check the denylist if we're not turned off.
vyasr marked this conversation as resolved.
Show resolved Hide resolved
frame = sys._getframe()
# We cannot possibly be at the top level.
Expand All @@ -559,7 +560,7 @@ def getattr_real_or_wrapped(
use_real = any(
calling_module.is_relative_to(path)
for path in loader._denylist
)
) and not calling_module.is_relative_to(CUDF_PANDAS_PATH)
try:
vyasr marked this conversation as resolved.
Show resolved Hide resolved
if use_real:
return real[name]
Expand Down
2 changes: 1 addition & 1 deletion python/cudf/cudf/testing/testing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2023, NVIDIA CORPORATION.
# Copyright (c) 2020-2024, NVIDIA CORPORATION.

shwina marked this conversation as resolved.
Show resolved Hide resolved
from __future__ import annotations

Expand Down
8 changes: 8 additions & 0 deletions python/cudf/cudf_pandas_tests/test_cudf_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -1233,3 +1233,11 @@ def test_concat_fast():
def test_func_namespace():
# note: this test is sensitive to Pandas' internal module layout
assert xpd.concat is xpd.core.reshape.concat.concat


def test_groupby_pickling(dataframe):
pdf, df = dataframe
pgb = pdf.groupby("a")
gb = df.groupby("a")
gb = pickle.loads(pickle.dumps(gb))
tm.assert_equal(pgb.sum(), gb.sum())
Loading