diff --git a/cloudpickle/cloudpickle.py b/cloudpickle/cloudpickle.py index 763e9d6f..239c2b51 100644 --- a/cloudpickle/cloudpickle.py +++ b/cloudpickle/cloudpickle.py @@ -637,6 +637,67 @@ def __reduce__(cls): return cls.__name__ +def _noop(x): + """Identity function""" + return x + + +class _FuncMetadataGlobals: + """Extracts the base metadata of func.__globals__ + + This wrapper makes it possible to customize the serialization of the + function globals by subclassing the pickler and overriding persistend_id / + persistent_load: + + https://docs.python.org/3/library/pickle.html#persistence-of-external-objects + """ + + _common_func_keys = ("__package__", "__name__", "__path__", "__file__") + + def __init__(self, func, shared_namespace): + self.func = func + func_globals = func.__globals__ + if not shared_namespace: + # The shared name space is empty, meaning that it's the first time + # this function is pickled by the CloudPickler instance: let's + # populate it by the base globals of the function. + shared_namespace.update({ + k: func_globals[k] for k in self._common_func_keys + if k in func_globals + }) + self.func_metadata_globals = shared_namespace + + def __reduce__(self): + # By default, only pickle the core meta-data information of the globals + # dict of the function. The actual symbols referenced in func.__code__ + # are pickled separately in _FilteredFuncGlobals. + return _noop, (self.func_metadata_globals,) + + +class _FuncCodeGlobals: + """Extracts entries of func.__globals__ actually referenced in func.__code__ + + This wrapper makes it possible to customize the serialization of the + globals by subclassing the pickler and overriding persistend_id / + persistent_load: + + https://docs.python.org/3/library/pickle.html#persistence-of-external-objects + """ + + def __init__(self, func): + self.func = func + code_global_names = _extract_code_globals(func.__code__) + func_globals = func.__globals__ + self.func_code_globals = { + k: func_globals[k] for k in code_global_names if k in func_globals + } + + def __reduce__(self): + # By default, only pickle the Python ojects actually referenced by the + # code of the function. + return _noop, (self.func_code_globals,) + + def _fill_function(*args): """Fills in the rest of function data into the skeleton function object diff --git a/cloudpickle/cloudpickle_fast.py b/cloudpickle/cloudpickle_fast.py index 10ceef1b..b4c7b576 100644 --- a/cloudpickle/cloudpickle_fast.py +++ b/cloudpickle/cloudpickle_fast.py @@ -36,6 +36,7 @@ parametrized_type_hint_getinitargs, _create_parametrized_type_hint, builtin_code_type, _make_dict_keys, _make_dict_values, _make_dict_items, + _FuncMetadataGlobals, _FuncCodeGlobals, ) @@ -153,11 +154,7 @@ def _function_getstate(func): "__doc__": func.__doc__, "__closure__": func.__closure__, } - - f_globals_ref = _extract_code_globals(func.__code__) - f_globals = {k: func.__globals__[k] for k in f_globals_ref if k in - func.__globals__} - + f_globals = _FuncCodeGlobals(func) closure_values = ( list(map(_get_cell_contents, func.__closure__)) if func.__closure__ is not None else () @@ -168,7 +165,12 @@ def _function_getstate(func): # trigger the side effect of importing these modules at unpickling time # (which is necessary for func to work correctly once depickled) slotstate["_cloudpickle_submodules"] = _find_imported_submodules( - func.__code__, itertools.chain(f_globals.values(), closure_values)) + func.__code__, + itertools.chain( + f_globals.func_code_globals.values(), + closure_values, + ), + ) slotstate["__globals__"] = f_globals state = func.__dict__ @@ -577,15 +579,9 @@ def _function_getnewargs(self, func): # same invocation of cloudpickle.dump/cloudpickle.dumps (for example: # cloudpickle.dumps([f1, f2])). There is no such limitation when using # CloudPickler.dump, as long as the multiple invocations are bound to - # the same CloudPickler. - base_globals = self.globals_ref.setdefault(id(func.__globals__), {}) - - if base_globals == {}: - # Add module attributes used to resolve relative imports - # instructions inside func. - for k in ["__package__", "__name__", "__path__", "__file__"]: - if k in func.__globals__: - base_globals[k] = func.__globals__[k] + # the same CloudPickler instance. + shared_ns = self.globals_ref.setdefault(id(func.__globals__), {}) + base_globals = _FuncMetadataGlobals(func, shared_namespace=shared_ns) # Do not bind the free variables before the function is created to # avoid infinite recursion. diff --git a/tests/cloudpickle_test.py b/tests/cloudpickle_test.py index baca23cc..00d529a8 100644 --- a/tests/cloudpickle_test.py +++ b/tests/cloudpickle_test.py @@ -49,6 +49,7 @@ from cloudpickle.cloudpickle import _make_empty_cell, cell_set from cloudpickle.cloudpickle import _extract_class_dict, _whichmodule from cloudpickle.cloudpickle import _lookup_module_and_qualname +from cloudpickle.cloudpickle import _FuncMetadataGlobals, _FuncCodeGlobals from .testutils import subprocess_pickle_echo from .testutils import subprocess_pickle_string @@ -2377,6 +2378,45 @@ def func_with_globals(): "Expected a single deterministic payload, got %d/5" % len(vals) ) + def test_externally_managed_function_globals(self): + common_globals = {"a": "foo"} + + class CustomPickler(cloudpickle.CloudPickler): + @staticmethod + def persistent_id(obj): + if ( + isinstance(obj, _FuncMetadataGlobals) + and obj.func.__globals__ is common_globals + ): + return "common_globals" + elif ( + isinstance(obj, _FuncCodeGlobals) + and obj.func.__globals__ is common_globals + ): + return "empty_dict" + + class CustomUnpickler(pickle.Unpickler): + @staticmethod + def persistent_load(pid): + return { + "common_globals": common_globals, + "empty_dict": {} + }[pid] + + lookup_a = eval('lambda: a', common_globals) + assert lookup_a() == "foo" + + file = io.BytesIO() + CustomPickler(file).dump(lookup_a) + dumped = file.getvalue() + assert b'foo' not in dumped + + lookup_a_cloned = CustomUnpickler(io.BytesIO(dumped)).load() + assert lookup_a_cloned() == "foo" + assert lookup_a_cloned.__globals__ is common_globals + common_globals['a'] = 'bar' + assert lookup_a_cloned() == "bar" + class Protocol2CloudPickleTest(CloudPickleTest):