Skip to content

Commit

Permalink
Fix pickling dataclasses (#245)
Browse files Browse the repository at this point in the history
  • Loading branch information
pcmoritz authored and ogrisel committed Feb 6, 2019
1 parent 54463b6 commit 7b31ced
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 11 deletions.
7 changes: 7 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
0.8.0
=====

- Add support for pickling interactively defined dataclasses.
([issue #245](https://github.com/cloudpipe/cloudpickle/pull/245))


0.7.0
=====

Expand Down
28 changes: 17 additions & 11 deletions cloudpickle/cloudpickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
DEFAULT_PROTOCOL = pickle.HIGHEST_PROTOCOL


if sys.version < '3':
if sys.version_info[0] < 3: # pragma: no branch
from pickle import Pickler
try:
from cStringIO import StringIO
Expand Down Expand Up @@ -128,7 +128,7 @@ def inner(value):
# NOTE: we are marking the cell variable as a free variable intentionally
# so that we simulate an inner function instead of the outer function. This
# is what gives us the ``nonlocal`` behavior in a Python 2 compatible way.
if not PY3:
if not PY3: # pragma: no branch
return types.CodeType(
co.co_argcount,
co.co_nlocals,
Expand Down Expand Up @@ -229,14 +229,14 @@ def _factory():
}


if sys.version_info < (3, 4):
if sys.version_info < (3, 4): # pragma: no branch
def _walk_global_ops(code):
"""
Yield (opcode, argument number) tuples for all
global-referencing instructions in *code*.
"""
code = getattr(code, 'co_code', b'')
if not PY3:
if not PY3: # pragma: no branch
code = map(ord, code)

n = len(code)
Expand Down Expand Up @@ -293,7 +293,7 @@ def save_memoryview(self, obj):

dispatch[memoryview] = save_memoryview

if not PY3:
if not PY3: # pragma: no branch
def save_buffer(self, obj):
self.save(str(obj))

Expand All @@ -315,7 +315,7 @@ def save_codeobject(self, obj):
"""
Save a code object
"""
if PY3:
if PY3: # pragma: no branch
args = (
obj.co_argcount, obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize,
obj.co_flags, obj.co_code, obj.co_consts, obj.co_names, obj.co_varnames,
Expand Down Expand Up @@ -393,7 +393,7 @@ def save_function(self, obj, name=None):
# So we pickle them here using save_reduce; have to do it differently
# for different python versions.
if not hasattr(obj, '__code__'):
if PY3:
if PY3: # pragma: no branch
rv = obj.__reduce_ex__(self.proto)
else:
if hasattr(obj, '__self__'):
Expand Down Expand Up @@ -730,7 +730,7 @@ def save_instancemethod(self, obj):
if obj.__self__ is None:
self.save_reduce(getattr, (obj.im_class, obj.__name__))
else:
if PY3:
if PY3: # pragma: no branch
self.save_reduce(types.MethodType, (obj.__func__, obj.__self__), obj=obj)
else:
self.save_reduce(types.MethodType, (obj.__func__, obj.__self__, obj.__self__.__class__),
Expand Down Expand Up @@ -783,7 +783,7 @@ def save_inst(self, obj):
save(stuff)
write(pickle.BUILD)

if not PY3:
if not PY3: # pragma: no branch
dispatch[types.InstanceType] = save_inst

def save_property(self, obj):
Expand Down Expand Up @@ -883,7 +883,7 @@ def save_not_implemented(self, obj):

try: # Python 2
dispatch[file] = save_file
except NameError: # Python 3
except NameError: # Python 3 # pragma: no branch
dispatch[io.TextIOWrapper] = save_file

dispatch[type(Ellipsis)] = save_ellipsis
Expand All @@ -904,6 +904,12 @@ def save_root_logger(self, obj):

dispatch[logging.RootLogger] = save_root_logger

if hasattr(types, "MappingProxyType"): # pragma: no branch
def save_mappingproxy(self, obj):
self.save_reduce(types.MappingProxyType, (dict(obj),), obj=obj)

dispatch[types.MappingProxyType] = save_mappingproxy

"""Special functions for Add-on libraries"""
def inject_addons(self):
"""Plug in system. Register additional pickling functions if modules already loaded"""
Expand Down Expand Up @@ -1213,7 +1219,7 @@ def _getobject(modname, attribute):

""" Use copy_reg to extend global pickle definitions """

if sys.version_info < (3, 4):
if sys.version_info < (3, 4): # pragma: no branch
method_descriptor = type(str.upper)

def _reduce_method_descriptor(obj):
Expand Down
15 changes: 15 additions & 0 deletions tests/cloudpickle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1344,6 +1344,21 @@ def __init__(self):
with pytest.raises(AttributeError):
obj.non_registered_attribute = 1

@unittest.skipIf(not hasattr(types, "MappingProxyType"),
"Old versions of Python do not have this type.")
def test_mappingproxy(self):
mp = types.MappingProxyType({"some_key": "some value"})
assert mp == pickle_depickle(mp, protocol=self.protocol)

def test_dataclass(self):
dataclasses = pytest.importorskip("dataclasses")

DataClass = dataclasses.make_dataclass('DataClass', [('x', int)])
data = DataClass(x=42)

pickle_depickle(DataClass, protocol=self.protocol)
assert data.x == pickle_depickle(data, protocol=self.protocol).x == 42


class Protocol2CloudPickleTest(CloudPickleTest):

Expand Down

0 comments on commit 7b31ced

Please sign in to comment.