Skip to content

Commit

Permalink
Enable intermediate proxies to be picklable (#14752)
Browse files Browse the repository at this point in the history
Closes #14738

Enables intermediate proxy types to be pickled in the same way as final proxy types.

Authors:
  - Ashwin Srinath (https://github.com/shwina)
  - Vyas Ramasubramani (https://github.com/vyasr)

Approvers:
  - Bradley Dice (https://github.com/bdice)
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: #14752
  • Loading branch information
shwina authored Jan 20, 2024
1 parent eb850fa commit 1994280
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 22 deletions.
77 changes: 57 additions & 20 deletions python/cudf/cudf/pandas/fast_slow_proxy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

Expand All @@ -25,6 +25,11 @@

from .annotation import nvtx


def call_operator(fn, args, kwargs):
    """Call ``fn`` with the given positional and keyword arguments.

    This is a named, module-level equivalent of
    ``lambda fn, args, kwargs: fn(*args, **kwargs)``. Unlike a lambda, a
    module-level function can be pickled by reference.

    Parameters
    ----------
    fn
        The callable to invoke.
    args
        Iterable of positional arguments, unpacked into the call.
    kwargs
        Mapping of keyword arguments, unpacked into the call.

    Returns
    -------
    Whatever ``fn(*args, **kwargs)`` returns.
    """
    result = fn(*args, **kwargs)
    return result


_CUDF_PANDAS_NVTX_COLORS = {
"COPY_SLOW_TO_FAST": 0xCA0020,
"COPY_FAST_TO_SLOW": 0xF4A582,
Expand Down Expand Up @@ -189,22 +194,6 @@ def _fsproxy_state(self) -> _State:
else _State.SLOW
)

def __reduce__(self):
# Pickle support: serialize the wrapped object instead of the proxy itself.
# Returns the (callable, args, state) triple of the pickle protocol; the
# state is the pickled bytes of the wrapped object, which pickle hands to
# `__setstate__` when loading.
# Need a local import to avoid circular import issues
from .module_accelerator import disable_module_accelerator

# NOTE(review): the accelerator is presumably disabled so that the wrapped
# object is pickled without proxy interception - confirm.
with disable_module_accelerator():
pickled_wrapped_obj = pickle.dumps(self._fsproxy_wrapped)
# Empty args tuple: on load the constructor callable is called with no
# arguments, and the state is applied afterwards.
return (_PickleConstructor(type(self)), (), pickled_wrapped_obj)

def __setstate__(self, state):
# Pickle support: `state` is the pickled bytes of the wrapped object, as
# produced by `__reduce__`. It is unpickled and stored as the object this
# proxy wraps.
# Need a local import to avoid circular import issues
from .module_accelerator import disable_module_accelerator

# NOTE(review): the accelerator is presumably disabled so that the wrapped
# object is restored without proxy interception - confirm.
with disable_module_accelerator():
unpickled_wrapped_obj = pickle.loads(state)
self._fsproxy_wrapped = unpickled_wrapped_obj

slow_dir = dir(slow_type)
cls_dict = {
"__init__": __init__,
Expand All @@ -215,9 +204,8 @@ def __setstate__(self, state):
"_fsproxy_slow_to_fast": _fsproxy_slow_to_fast,
"_fsproxy_fast_to_slow": _fsproxy_fast_to_slow,
"_fsproxy_state": _fsproxy_state,
"__reduce__": __reduce__,
"__setstate__": __setstate__,
}

if additional_attributes is None:
additional_attributes = {}
for method in _SPECIAL_METHODS:
Expand Down Expand Up @@ -716,6 +704,27 @@ def _fsproxy_wrap(cls, value, func):
proxy._fsproxy_wrapped = value
return proxy

def __reduce__(self):
"""
In conjunction with `__setstate__`, this effectively enables
proxy types to be pickled and unpickled by pickling and unpickling
the underlying wrapped types.

Returns the ``(callable, args, state)`` triple of the pickle protocol:
a `_PickleConstructor` for this proxy's type, an empty argument tuple,
and the pickled bytes of the wrapped object as the state.
"""
# Need a local import to avoid circular import issues
from .module_accelerator import disable_module_accelerator

# NOTE(review): the accelerator is presumably disabled so that the wrapped
# object is pickled without proxy interception - confirm.
with disable_module_accelerator():
pickled_wrapped_obj = pickle.dumps(self._fsproxy_wrapped)
return (_PickleConstructor(type(self)), (), pickled_wrapped_obj)

def __setstate__(self, state):
# Counterpart of `__reduce__`: `state` is the pickled bytes of the wrapped
# object. It is unpickled and stored as the object this proxy wraps.
# Need a local import to avoid circular import issues
from .module_accelerator import disable_module_accelerator

# NOTE(review): the accelerator is presumably disabled so that the wrapped
# object is restored without proxy interception - confirm.
with disable_module_accelerator():
unpickled_wrapped_obj = pickle.loads(state)
self._fsproxy_wrapped = unpickled_wrapped_obj


class _IntermediateProxy(_FastSlowProxy):
"""
Expand Down Expand Up @@ -772,6 +781,34 @@ def _fsproxy_fast_to_slow(self) -> Any:
args, kwargs = _slow_arg(args), _slow_arg(kwargs)
return func(*args, **kwargs)

def __reduce__(self):
"""
In conjunction with `__setstate__`, this effectively enables
intermediate proxy types to be pickled and unpickled by pickling
and unpickling the underlying wrapped object together with the
proxy's method chain.

Returns the ``(callable, args, state)`` triple of the pickle protocol:
a `_PickleConstructor` for this proxy's type, an empty argument tuple,
and a 2-tuple state of ``(pickled wrapped object, pickled method chain)``.
"""
# Need a local import to avoid circular import issues
from .module_accelerator import disable_module_accelerator

# NOTE(review): the accelerator is presumably disabled so that the wrapped
# object is pickled without proxy interception - confirm.
with disable_module_accelerator():
pickled_wrapped_obj = pickle.dumps(self._fsproxy_wrapped)
# The method chain is pickled separately from the wrapped object; both
# byte strings travel together as the state.
pickled_method_chain = pickle.dumps(self._method_chain)
return (
_PickleConstructor(type(self)),
(),
(pickled_wrapped_obj, pickled_method_chain),
)

def __setstate__(self, state):
# Counterpart of `__reduce__`: `state` is a 2-tuple of byte strings,
# `state[0]` being the pickled wrapped object and `state[1]` the pickled
# method chain. Both are unpickled and restored onto this proxy.
# Need a local import to avoid circular import issues
from .module_accelerator import disable_module_accelerator

# NOTE(review): the accelerator is presumably disabled so that the wrapped
# object is restored without proxy interception - confirm.
with disable_module_accelerator():
unpickled_wrapped_obj = pickle.loads(state[0])
unpickled_method_chain = pickle.loads(state[1])
self._fsproxy_wrapped = unpickled_wrapped_obj
self._method_chain = unpickled_method_chain


class _CallableProxyMixin:
"""
Expand All @@ -788,7 +825,7 @@ def __call__(self, *args, **kwargs) -> Any:
# _fast_slow_function_call) to avoid infinite recursion.
# TODO: When Python 3.11 is the minimum supported Python version
# this can use operator.call
lambda fn, args, kwargs: fn(*args, **kwargs),
call_operator,
self,
args,
kwargs,
Expand Down
7 changes: 5 additions & 2 deletions python/cudf/cudf/pandas/module_accelerator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

Expand Down Expand Up @@ -551,12 +551,15 @@ def getattr_real_or_wrapped(
# release the lock after reading this value)
use_real = not loader._use_fast_lib
if not use_real:
CUDF_PANDAS_PATH = __file__.rsplit("/", 1)[0]
# Only need to check the denylist if we're not turned off.
frame = sys._getframe()
# We cannot possibly be at the top level.
assert frame.f_back
calling_module = pathlib.PurePath(frame.f_back.f_code.co_filename)
use_real = any(
use_real = not calling_module.is_relative_to(
CUDF_PANDAS_PATH
) and any(
calling_module.is_relative_to(path)
for path in loader._denylist
)
Expand Down
8 changes: 8 additions & 0 deletions python/cudf/cudf_pandas_tests/test_cudf_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -1235,6 +1235,14 @@ def test_func_namespace():
assert xpd.concat is xpd.core.reshape.concat.concat


def test_pickle_groupby(dataframe):
    """Check that a groupby object still aggregates correctly after pickling.

    The ``dataframe`` fixture supplies a pair of frames. Both are grouped by
    column ``"a"``; the second groupby is sent through a pickle round trip
    and its ``sum()`` must equal the ``sum()`` of the first.
    """
    # presumably (pandas frame, cudf.pandas proxy frame) - verify against
    # the fixture definition
    reference_frame, proxy_frame = dataframe
    reference_grouped = reference_frame.groupby("a")
    proxy_grouped = proxy_frame.groupby("a")
    payload = pickle.dumps(proxy_grouped)
    restored_grouped = pickle.loads(payload)
    tm.assert_equal(reference_grouped.sum(), restored_grouped.sum())


def test_isinstance_base_offset():
offset = xpd.tseries.frequencies.to_offset("1s")
assert isinstance(offset, xpd.tseries.offsets.BaseOffset)

0 comments on commit 1994280

Please sign in to comment.