Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix pickling dataclasses #245

Merged
merged 10 commits into from
Feb 6, 2019
Merged
7 changes: 7 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
0.8.0
=====

- Add support for pickling interactively defined dataclasses.
([issue #245](https://github.com/cloudpipe/cloudpickle/pull/245))


0.7.0
=====

Expand Down
28 changes: 17 additions & 11 deletions cloudpickle/cloudpickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
DEFAULT_PROTOCOL = pickle.HIGHEST_PROTOCOL


if sys.version < '3':
if sys.version_info[0] < 3: # pragma: no branch
from pickle import Pickler
try:
from cStringIO import StringIO
Expand Down Expand Up @@ -128,7 +128,7 @@ def inner(value):
# NOTE: we are marking the cell variable as a free variable intentionally
# so that we simulate an inner function instead of the outer function. This
# is what gives us the ``nonlocal`` behavior in a Python 2 compatible way.
if not PY3:
if not PY3: # pragma: no branch
return types.CodeType(
co.co_argcount,
co.co_nlocals,
Expand Down Expand Up @@ -229,14 +229,14 @@ def _factory():
}


if sys.version_info < (3, 4):
if sys.version_info < (3, 4): # pragma: no branch
def _walk_global_ops(code):
"""
Yield (opcode, argument number) tuples for all
global-referencing instructions in *code*.
"""
code = getattr(code, 'co_code', b'')
if not PY3:
if not PY3: # pragma: no branch
code = map(ord, code)

n = len(code)
Expand Down Expand Up @@ -293,7 +293,7 @@ def save_memoryview(self, obj):

dispatch[memoryview] = save_memoryview

if not PY3:
if not PY3: # pragma: no branch
def save_buffer(self, obj):
self.save(str(obj))

Expand All @@ -315,7 +315,7 @@ def save_codeobject(self, obj):
"""
Save a code object
"""
if PY3:
if PY3: # pragma: no branch
args = (
obj.co_argcount, obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize,
obj.co_flags, obj.co_code, obj.co_consts, obj.co_names, obj.co_varnames,
Expand Down Expand Up @@ -393,7 +393,7 @@ def save_function(self, obj, name=None):
# So we pickle them here using save_reduce; have to do it differently
# for different python versions.
if not hasattr(obj, '__code__'):
if PY3:
if PY3: # pragma: no branch
rv = obj.__reduce_ex__(self.proto)
else:
if hasattr(obj, '__self__'):
Expand Down Expand Up @@ -730,7 +730,7 @@ def save_instancemethod(self, obj):
if obj.__self__ is None:
self.save_reduce(getattr, (obj.im_class, obj.__name__))
else:
if PY3:
if PY3: # pragma: no branch
self.save_reduce(types.MethodType, (obj.__func__, obj.__self__), obj=obj)
else:
self.save_reduce(types.MethodType, (obj.__func__, obj.__self__, obj.__self__.__class__),
Expand Down Expand Up @@ -783,7 +783,7 @@ def save_inst(self, obj):
save(stuff)
write(pickle.BUILD)

if not PY3:
if not PY3: # pragma: no branch
dispatch[types.InstanceType] = save_inst

def save_property(self, obj):
Expand Down Expand Up @@ -883,7 +883,7 @@ def save_not_implemented(self, obj):

try: # Python 2
dispatch[file] = save_file
except NameError: # Python 3
except NameError: # Python 3 # pragma: no branch
dispatch[io.TextIOWrapper] = save_file

dispatch[type(Ellipsis)] = save_ellipsis
Expand All @@ -904,6 +904,12 @@ def save_root_logger(self, obj):

dispatch[logging.RootLogger] = save_root_logger

if hasattr(types, "MappingProxyType"): # pragma: no branch
def save_mappingproxy(self, obj):
self.save_reduce(types.MappingProxyType, (dict(obj),), obj=obj)

dispatch[types.MappingProxyType] = save_mappingproxy

"""Special functions for Add-on libraries"""
def inject_addons(self):
"""Plug in system. Register additional pickling functions if modules already loaded"""
Expand Down Expand Up @@ -1213,7 +1219,7 @@ def _getobject(modname, attribute):

""" Use copy_reg to extend global pickle definitions """

if sys.version_info < (3, 4):
if sys.version_info < (3, 4): # pragma: no branch
method_descriptor = type(str.upper)

def _reduce_method_descriptor(obj):
Expand Down
15 changes: 15 additions & 0 deletions tests/cloudpickle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1344,6 +1344,21 @@ def __init__(self):
with pytest.raises(AttributeError):
obj.non_registered_attribute = 1

@unittest.skipIf(not hasattr(types, "MappingProxyType"),
"Old versions of Python do not have this type.")
def test_mappingproxy(self):
mp = types.MappingProxyType({"some_key": "some value"})
assert mp == pickle_depickle(mp, protocol=self.protocol)

def test_dataclass(self):
dataclasses = pytest.importorskip("dataclasses")

DataClass = dataclasses.make_dataclass('DataClass', [('x', int)])
data = DataClass(x=42)

pickle_depickle(DataClass, protocol=self.protocol)
assert data.x == pickle_depickle(data, protocol=self.protocol).x == 42


class Protocol2CloudPickleTest(CloudPickleTest):

Expand Down