Skip to content

Commit

Permalink
CLN de-duplicate complex utilities functions
Browse files Browse the repository at this point in the history
  • Loading branch information
pierreglaser committed Mar 27, 2019
1 parent f61c8d4 commit a34c11c
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 189 deletions.
171 changes: 92 additions & 79 deletions cloudpickle/cloudpickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,90 @@ def _factory():
object.__new__: _get_object_new,
}

# Memoizes the result of extract_code_globals, keyed by code object.
# A WeakKeyDictionary is used unless ``sys.pypy_version_info`` exists (i.e.
# when running on PyPy), in which case a plain dict is used instead.
# NOTE(review): presumably PyPy code objects cannot be used as weak keys --
# confirm before changing this.
_extract_code_globals_cache = (
weakref.WeakKeyDictionary()
if not hasattr(sys, "pypy_version_info")
else {})


def extract_code_globals(co):
    """
    Return the set of global names read or written by code object ``co``.

    Names referenced by code objects nested in ``co.co_consts`` (functions
    declared inside ``co`` with ``def``) are included recursively. Results
    are memoized in ``_extract_code_globals_cache``.
    """
    cached = _extract_code_globals_cache.get(co)
    if cached is not None:
        return cached

    try:
        names = co.co_names
    except AttributeError:
        # PyPy "builtin-code" object: nothing to introspect
        out_names = set()
    else:
        out_names = {names[oparg] for _, oparg in _walk_global_ops(co)}

        # A nested "def" is compiled to a code object stored among the
        # constants of the enclosing code object. The nested function may
        # use globals of its own, so each such constant is inspected too and
        # its names are merged into the result.
        for const in co.co_consts or ():
            if isinstance(const, types.CodeType):
                out_names |= extract_code_globals(const)

    _extract_code_globals_cache[co] = out_names
    return out_names


def _find_loaded_submodules(code, top_level_dependencies):
    """
    Find submodules used by a function but not listed in its globals.

    In the example below:
    ```
    import concurrent.futures
    import cloudpickle
    def func():
        x = concurrent.futures.ThreadPoolExecutor
    if __name__ == '__main__':
        cloudpickle.dumps(func)
    ```
    the globals extracted by cloudpickle in the function's state include
    the concurrent module, but not its submodule (here,
    concurrent.futures), which is the module used by func.

    To ensure that calling the depickled function does not raise an
    AttributeError, this function looks for any currently loaded submodule
    that the function uses and whose parent is present in
    ``top_level_dependencies``.

    Parameters
    ----------
    code : code object
        Code of the function being inspected; only ``co_names`` is read.
        An object without ``co_names`` is treated as referencing no names.
    top_level_dependencies : iterable
        Objects reachable from the function; entries that are not package
        modules are ignored.

    Returns
    -------
    list of modules
        The matching entries of ``sys.modules``, in iteration order.
    """
    # The names the function can address do not change while scanning
    # sys.modules, so build the lookup set once instead of once per candidate.
    addressable = set(getattr(code, 'co_names', ()))

    subimports = []
    # check if any known dependency is an imported package
    for x in top_level_dependencies:
        if not (isinstance(x, types.ModuleType)
                and getattr(x, '__package__', None)):
            continue
        # check if the package has any currently loaded sub-imports
        prefix = x.__name__ + '.'
        # A concurrent thread could mutate sys.modules,
        # make sure we iterate over a copy to avoid exceptions
        for name in list(sys.modules):
            # Older versions of pytest will add a "None" module to
            # sys.modules.
            if name is None or not name.startswith(prefix):
                continue
            # check whether the function can address the sub-module:
            # every dotted component below the package must be a name
            # referenced by the code
            tokens = set(name[len(prefix):].split('.'))
            if tokens <= addressable:
                subimports.append(sys.modules[name])
    return subimports


if sys.version_info < (3, 4): # pragma: no branch
def _walk_global_ops(code):
Expand Down Expand Up @@ -413,53 +497,6 @@ def save_function(self, obj, name=None):

dispatch[types.FunctionType] = save_function

def _save_subimports(self, code, top_level_dependencies):
    """
    Save submodules used by a function but not listed in its globals.

    In the example below:
    ```
    import concurrent.futures
    import cloudpickle
    def func():
        x = concurrent.futures.ThreadPoolExecutor
    if __name__ == '__main__':
        cloudpickle.dumps(func)
    ```
    the globals extracted by cloudpickle in the function's state include
    the concurrent module, but not its submodule (here,
    concurrent.futures), which is the module used by func.

    To ensure that calling the depickled function does not raise an
    AttributeError, this function looks for any currently loaded submodule
    that the function uses and whose parent is present in the function
    globals, and saves it before saving the function.

    ``code`` is the function's code object (only ``co_names`` is read; an
    object without it is treated as referencing no names) and
    ``top_level_dependencies`` is an iterable of objects reachable from the
    function. Each matching module is passed to ``self.save`` and then
    immediately followed by a ``pickle.POP`` written through ``self.write``.
    """
    # The names the function can address do not change while scanning
    # sys.modules, so build the lookup set once instead of once per candidate.
    addressable = set(getattr(code, 'co_names', ()))

    # check if any known dependency is an imported package
    for x in top_level_dependencies:
        if not (isinstance(x, types.ModuleType)
                and getattr(x, '__package__', None)):
            continue
        # check if the package has any currently loaded sub-imports
        prefix = x.__name__ + '.'
        # A concurrent thread could mutate sys.modules,
        # make sure we iterate over a copy to avoid exceptions
        for name in list(sys.modules):
            # Older versions of pytest will add a "None" module to
            # sys.modules.
            if name is None or not name.startswith(prefix):
                continue
            # check whether the function can address the sub-module
            tokens = set(name[len(prefix):].split('.'))
            if tokens <= addressable:
                # ensure unpickler executes this import
                self.save(sys.modules[name])
                # then discards the reference to it
                self.write(pickle.POP)

def save_dynamic_class(self, obj):
"""
Save a class that can't be stored as module global.
Expand Down Expand Up @@ -562,10 +599,16 @@ def save_function_tuple(self, func):
save(_fill_function) # skeleton function updater
write(pickle.MARK) # beginning of tuple that _fill_function expects

self._save_subimports(
subimports = _find_loaded_submodules(
code,
itertools.chain(f_globals.values(), closure_values or ()),
)
for s in subimports:
# ensure that subimport s is loaded at unpickling time
self.save(s)
# then discards the reference to it
self.write(pickle.POP)


# create a skeleton function object and memoize it
save(_make_skel_func)
Expand Down Expand Up @@ -595,36 +638,6 @@ def save_function_tuple(self, func):
write(pickle.TUPLE)
write(pickle.REDUCE) # applies _fill_function on the tuple

# Class-level memo for extract_code_globals, looked up through ``cls`` and
# keyed by code object. A WeakKeyDictionary is used unless
# ``sys.pypy_version_info`` exists (i.e. on PyPy), where a plain dict is used.
# NOTE(review): presumably PyPy code objects cannot be used as weak keys --
# confirm before changing this.
_extract_code_globals_cache = (
weakref.WeakKeyDictionary()
if not hasattr(sys, "pypy_version_info")
else {})

@classmethod
def extract_code_globals(cls, co):
    """
    Find all globals names read or written to by codeblock co.

    Code objects found in ``co.co_consts`` (nested function definitions)
    are inspected recursively and their names merged into the result. The
    resulting set is memoized in ``cls._extract_code_globals_cache``.
    """
    out_names = cls._extract_code_globals_cache.get(co)
    if out_names is None:
        try:
            names = co.co_names
        except AttributeError:
            # PyPy "builtin-code" object
            out_names = set()
        else:
            out_names = {names[oparg] for _, oparg in _walk_global_ops(co)}

            # see if nested function have any global refs
            if co.co_consts:
                for const in co.co_consts:
                    # isinstance is the idiomatic type check and matches
                    # the module-level extract_code_globals
                    if isinstance(const, types.CodeType):
                        out_names |= cls.extract_code_globals(const)

        cls._extract_code_globals_cache[co] = out_names

    return out_names

def extract_func_data(self, func):
"""
Turn the function into a tuple of data necessary to recreate it:
Expand All @@ -633,7 +646,7 @@ def extract_func_data(self, func):
code = func.__code__

# extract all global ref's
func_global_refs = self.extract_code_globals(code)
func_global_refs = extract_code_globals(code)

# process all variables referenced by global environment
f_globals = {}
Expand Down
123 changes: 13 additions & 110 deletions cloudpickle/cloudpickle_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import abc
import dis
import io
import itertools
import logging
import opcode
import _pickle
Expand All @@ -19,8 +20,9 @@
from _pickle import Pickler

from .cloudpickle import (
islambda, _is_dynamic, GLOBAL_OPS, _BUILTIN_TYPE_CONSTRUCTORS,
_BUILTIN_TYPE_NAMES, DEFAULT_PROTOCOL
islambda, _is_dynamic, extract_code_globals, GLOBAL_OPS,
_BUILTIN_TYPE_CONSTRUCTORS, _BUILTIN_TYPE_NAMES, DEFAULT_PROTOCOL,
_find_loaded_submodules, _get_cell_contents
)

load, loads = _pickle.load, _pickle.loads
Expand Down Expand Up @@ -60,111 +62,6 @@ def dumps(obj, protocol=None):
file.close()


# Utility functions introspecting objects to extract useful properties about
# them.
def _find_loaded_submodules(globals, closure, co_names):
    """
    Find submodules used by a function but not listed in its globals.

    In the example below:
    ```
    import xml.etree
    import cloudpickle
    def func():
        x = xml.etree.ElementTree
    if __name__ == '__main__':
        cloudpickle.dumps(func)
    ```
    the expression xml.etree.ElementTree generates a LOAD_GLOBAL for xml, but
    simply LOAD_ATTR for etree and ElementTree - cloudpickle cannot detect
    such submodules by bytecode inspection. There is actually no exact way of
    detecting them, the method below is simply "good enough". For instance:

        import xml.etree
        def f():
            def g():
                return xml.etree
            return g

    pickling f and trying to call f()() may still fail, because the
    submodule is only named in the nested function's code.

    Parameters
    ----------
    globals : dict
        Global variables referenced by the function (name -> value).
    closure : iterable of cells
        Closure cells of the function; empty cells are skipped.
    co_names : iterable of str
        Names referenced by the function's code.

    Returns
    -------
    dict
        Maps the dotted name of each matching entry of ``sys.modules`` to the
        module object.
    """
    # The addressable names are loop-invariant: build the lookup set once
    # instead of once per candidate module name.
    addressable = set(co_names)

    top_level_dependencies = list(globals.values())
    for cell in closure:
        try:
            top_level_dependencies.append(cell.cell_contents)
        except ValueError:
            # reading cell_contents of an empty cell raises ValueError
            continue

    referenced_submodules = {}
    # top_level_dependencies are variables that generated a LOAD_GLOBAL or a
    # LOAD_DEREF opcode in code.
    for x in top_level_dependencies:
        if not isinstance(x, types.ModuleType):
            continue
        if getattr(x, "__package__", None) is None:
            continue
        # check if the package has any currently loaded sub-imports
        prefix = x.__name__ + "."
        # A concurrent thread could mutate sys.modules,
        # make sure we iterate over a copy to avoid exceptions
        for name in list(sys.modules):
            # Older versions of pytest will add a "None" module to
            # sys.modules.
            if name is None or not name.startswith(prefix):
                continue
            # check whether the function can address the sub-module
            tokens = set(name[len(prefix):].split("."))
            if tokens <= addressable:
                # ensure unpickler executes this import
                referenced_submodules[name] = sys.modules[name]
    return referenced_submodules


# Cache consulted by extract_code_globals, keyed by code object. A
# WeakKeyDictionary is used unless ``sys.pypy_version_info`` exists (i.e. on
# PyPy), where a plain dict is used instead.
# NOTE(review): presumably PyPy code objects cannot be used as weak keys --
# confirm before changing this.
_extract_code_globals_cache = (
weakref.WeakKeyDictionary()
if not hasattr(sys, "pypy_version_info")
else {}
)


def extract_code_globals(code, globals_):
    """
    Collect the global variables referenced by code object ``code``.

    Returns a dict mapping each name used by a global opcode in ``code`` (or
    in any code object nested in its constants) to its value in ``globals_``.
    Names absent from ``globals_`` are left out.
    """
    # NOTE(review): the cache is consulted here but this function never
    # stores into it -- confirm whether the lookup is still needed.
    found = _extract_code_globals_cache.get(code)
    if found is not None:
        return found

    found = {}
    # PyPy "builtin-code" objects do not have this structure
    if not hasattr(code, "co_names"):
        return found

    # keep every global-opcode argument that resolves in globals_
    for instruction in dis.get_instructions(code):
        name = instruction.argval
        if instruction.opcode in GLOBAL_OPS and name in globals_:
            found[name] = globals_[name]

    # A nested "def" is compiled to a code object stored among the constants
    # of the enclosing code object. The nested function may use globals of
    # its own, so each such constant is inspected too and merged in.
    for const in code.co_consts or ():
        if isinstance(const, types.CodeType):
            found.update(extract_code_globals(const, globals_))

    return found


# COLLECTION OF OBJECTS __getnewargs__-LIKE METHODS
# -------------------------------------------------

Expand Down Expand Up @@ -254,12 +151,18 @@ def function_getstate(func):
"__closure__": func.__closure__,
}

f_globals = extract_code_globals(func.__code__, func.__globals__)
f_globals_ref = extract_code_globals(func.__code__)
f_globals = {k: func.__globals__[k] for k in f_globals_ref if k in
func.__globals__}

closure_values = (
list(map(_get_cell_contents, func.__closure__))
if func.__closure__ is not None else ()
)

# extract submodules referenced by attribute lookup (no global opcode)
f_globals["__submodules__"] = _find_loaded_submodules(
f_globals, slotstate["__closure__"] or (), func.__code__.co_names
)
func.__code__, itertools.chain(f_globals.values(), closure_values))
slotstate["__globals__"] = f_globals

state = func.__dict__
Expand Down

0 comments on commit a34c11c

Please sign in to comment.