Skip to content

Commit

Permalink
WIP refactor func globals to external delagation possible
Browse files Browse the repository at this point in the history
  • Loading branch information
ogrisel committed Jul 1, 2021
1 parent 0c62ae0 commit 35f4997
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 14 deletions.
61 changes: 61 additions & 0 deletions cloudpickle/cloudpickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,67 @@ def __reduce__(cls):
return cls.__name__


def _noop(x):
"""Identity function"""
return x


class _FuncMetadataGlobals:
"""Extracts the base metadata of func.__globals__
This wrapper makes it possible to customize the serialization of the
function globals by subclassing the pickler and overriding persistend_id /
persistent_load:
https://docs.python.org/3/library/pickle.html#persistence-of-external-objects
"""

_common_func_keys = ("__package__", "__name__", "__path__", "__file__")

def __init__(self, func, shared_namespace):
self.func = func
func_globals = func.__globals__
if not shared_namespace:
# The shared name space is empty, meaning that it's the first time
# this function is pickled by the CloudPickler instance: let's
# populate it by the base globals of the function.
shared_namespace.update({
k: func_globals[k] for k in self._common_func_keys
if k in func_globals
})
self.func_metadata_globals = shared_namespace

def __reduce__(self):
# By default, only pickle the core meta-data information of the globals
# dict of the function. The actual symbols referenced in func.__code__
# are pickled separately in _FilteredFuncGlobals.
return _noop, (self.func_metadata_globals,)


class _FuncCodeGlobals:
"""Extracts entries of func.__globals__ actually referenced in func.__code__
This wrapper makes it possible to customize the serialization of the
globals by subclassing the pickler and overriding persistend_id /
persistent_load:
https://docs.python.org/3/library/pickle.html#persistence-of-external-objects
"""

def __init__(self, func):
self.func = func
code_global_names = _extract_code_globals(func.__code__)
func_globals = func.__globals__
self.func_code_globals = {
k: func_globals[k] for k in code_global_names if k in func_globals
}

def __reduce__(self):
# By default, only pickle the Python ojects actually referenced by the
# code of the function.
return _noop, (self.func_code_globals,)


def _fill_function(*args):
"""Fills in the rest of function data into the skeleton function object
Expand Down
27 changes: 13 additions & 14 deletions cloudpickle/cloudpickle_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
parametrized_type_hint_getinitargs, _create_parametrized_type_hint,
builtin_code_type,
_make_dict_keys, _make_dict_values, _make_dict_items,
_FuncMetadataGlobals, _FuncCodeGlobals,
)


Expand Down Expand Up @@ -153,11 +154,7 @@ def _function_getstate(func):
"__doc__": func.__doc__,
"__closure__": func.__closure__,
}

f_globals_ref = _extract_code_globals(func.__code__)
f_globals = {k: func.__globals__[k] for k in f_globals_ref if k in
func.__globals__}

f_globals = _FuncCodeGlobals(func)
closure_values = (
list(map(_get_cell_contents, func.__closure__))
if func.__closure__ is not None else ()
Expand All @@ -168,7 +165,12 @@ def _function_getstate(func):
# trigger the side effect of importing these modules at unpickling time
# (which is necessary for func to work correctly once depickled)
slotstate["_cloudpickle_submodules"] = _find_imported_submodules(
func.__code__, itertools.chain(f_globals.values(), closure_values))
func.__code__,
itertools.chain(
f_globals.func_code_globals.values(),
closure_values,
),
)
slotstate["__globals__"] = f_globals

state = func.__dict__
Expand Down Expand Up @@ -577,15 +579,12 @@ def _function_getnewargs(self, func):
# same invocation of cloudpickle.dump/cloudpickle.dumps (for example:
# cloudpickle.dumps([f1, f2])). There is no such limitation when using
# CloudPickler.dump, as long as the multiple invocations are bound to
# the same CloudPickler.
# the same CloudPickler instance.
base_globals = self.globals_ref.setdefault(id(func.__globals__), {})

if base_globals == {}:
# Add module attributes used to resolve relative imports
# instructions inside func.
for k in ["__package__", "__name__", "__path__", "__file__"]:
if k in func.__globals__:
base_globals[k] = func.__globals__[k]
base_globals = _FuncMetadataGlobals(
func,
shared_namespace=base_globals,
)

# Do not bind the free variables before the function is created to
# avoid infinite recursion.
Expand Down
40 changes: 40 additions & 0 deletions tests/cloudpickle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from cloudpickle.cloudpickle import _make_empty_cell, cell_set
from cloudpickle.cloudpickle import _extract_class_dict, _whichmodule
from cloudpickle.cloudpickle import _lookup_module_and_qualname
from cloudpickle.cloudpickle import _FuncMetadataGlobals, _FuncCodeGlobals

from .testutils import subprocess_pickle_echo
from .testutils import subprocess_pickle_string
Expand Down Expand Up @@ -2377,6 +2378,45 @@ def func_with_globals():
"Expected a single deterministic payload, got %d/5" % len(vals)
)

def test_externally_managed_function_globals(self):
common_globals = {"a": "foo"}

class CustomPickler(cloudpickle.CloudPickler):
@staticmethod
def persistent_id(obj):
if (
isinstance(obj, _FuncMetadataGlobals)
and obj.func.__globals__ is common_globals
):
return "common_globals"
elif (
isinstance(obj, _FuncCodeGlobals)
and obj.func.__globals__ is common_globals
):
return "empty_dict"

class CustomUnpickler(pickle.Unpickler):
@staticmethod
def persistent_load(pid):
return {
"common_globals": common_globals,
"empty_dict": {}
}[pid]

lookup_a = eval('lambda: a', common_globals)
assert lookup_a() == "foo"

file = io.BytesIO()
CustomPickler(file).dump(lookup_a)
dumped = file.getvalue()
assert b'foo' not in dumped

lookup_a_cloned = CustomUnpickler(io.BytesIO(dumped)).load()
assert lookup_a_cloned() == "foo"
assert lookup_a_cloned.__globals__ is common_globals
common_globals['a'] = 'bar'
assert lookup_a_cloned() == "bar"


class Protocol2CloudPickleTest(CloudPickleTest):

Expand Down

0 comments on commit 35f4997

Please sign in to comment.