[SPARK-29341][PYTHON] Upgrade cloudpickle to 1.0.0
### What changes were proposed in this pull request?

This patch upgrades the bundled cloudpickle to version 1.0.0.

Main changes:

1. Clean up unused functions: cloudpipe/cloudpickle@936f16f
2. Fix relative imports inside function bodies: cloudpipe/cloudpickle@31ecdd6 (a sketch of this case follows the list)
3. Write keyword-only arguments to the pickle: cloudpipe/cloudpickle@6cb4718
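
As a rough illustration of the relative-import fix (change 2), consider a function that performs a relative import inside its body; the package and module names below (`mypkg`, `helpers`) are hypothetical:

```python
# mypkg/jobs.py (hypothetical package module)
def make_mapper():
    def mapper(x):
        # The relative import runs inside the function body, so it is resolved
        # only when the unpickled function executes on the worker.
        # cloudpipe/cloudpickle@31ecdd6 keeps module attributes such as
        # __package__ and __name__ in the reconstructed globals so this import
        # can still be resolved after unpickling.
        from . import helpers
        return helpers.transform(x)
    return mapper
```

Without those module attributes, the relative import can fail once the function is rebuilt on a worker with an empty globals namespace.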

### Why are the changes needed?

We should include new bug fixes such as cloudpipe/cloudpickle@6cb4718, because users might use such Python functions in PySpark.

Before:

```python
>>> def f(a, *, b=1):
...   return a + b
...
>>> rdd = sc.parallelize([1, 2, 3])
>>> rdd.map(f).collect()
[Stage 0:>                                                        (0 + 12) / 12]19/10/03 00:42:24 ERROR Executor: Exception in task 3.0 in stage 0.0 (TID 3)
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/spark/python/lib/pyspark.zip/pyspark/worker.py", line 598, in main
    process()
  File "/spark/python/lib/pyspark.zip/pyspark/worker.py", line 590, in process
    serializer.dump_stream(out_iter, outfile)
  File "/spark/python/lib/pyspark.zip/pyspark/serializers.py", line 513, in dump_stream
    vs = list(itertools.islice(iterator, batch))
  File "/spark/python/lib/pyspark.zip/pyspark/util.py", line 99, in wrapper
    return f(*args, **kwargs)
TypeError: f() missing 1 required keyword-only argument: 'b'
```

After:

```python
>>> def f(a, *, b=1):
...   return a + b
...
>>> rdd = sc.parallelize([1, 2, 3])
>>> rdd.map(f).collect()
[2, 3, 4]
```

### Does this PR introduce any user-facing change?

Yes. This fixes two bugs when pickling Python functions.

### How was this patch tested?

Existing tests.

Closes #26009 from viirya/upgrade-cloudpickle.

Authored-by: Liang-Chi Hsieh <[email protected]>
Signed-off-by: HyukjinKwon <[email protected]>
viirya authored and HyukjinKwon committed Oct 3, 2019
1 parent 858bf76 commit 2bc3fff
Showing 4 changed files with 22 additions and 50 deletions.

python/pyspark/broadcast.py (3 changes: 1 addition & 2 deletions)

```diff
@@ -21,10 +21,9 @@
 from tempfile import NamedTemporaryFile
 import threading
 
-from pyspark.cloudpickle import print_exec
 from pyspark.java_gateway import local_connect_and_auth
 from pyspark.serializers import ChunkedStream, pickle_protocol
-from pyspark.util import _exception_message
+from pyspark.util import _exception_message, print_exec
 
 if sys.version < '3':
     import cPickle as pickle
```

python/pyspark/cloudpickle.py (59 changes: 13 additions & 46 deletions)

```diff
@@ -591,6 +591,8 @@ def save_function_tuple(self, func):
             state['annotations'] = func.__annotations__
         if hasattr(func, '__qualname__'):
             state['qualname'] = func.__qualname__
+        if hasattr(func, '__kwdefaults__'):
+            state['kwdefaults'] = func.__kwdefaults__
         save(state)
         write(pickle.TUPLE)
         write(pickle.REDUCE)  # applies _fill_function on the tuple
@@ -666,6 +668,15 @@ def extract_func_data(self, func):
         # multiple invokations are bound to the same Cloudpickler.
         base_globals = self.globals_ref.setdefault(id(func.__globals__), {})
 
+        if base_globals == {}:
+            # Add module attributes used to resolve relative imports
+            # instructions inside func.
+            for k in ["__package__", "__name__", "__path__", "__file__"]:
+                # Some built-in functions/methods such as object.__new__ have
+                # their __globals__ set to None in PyPy
+                if func.__globals__ is not None and k in func.__globals__:
+                    base_globals[k] = func.__globals__[k]
+
         return (code, f_globals, defaults, closure, dct, base_globals)
 
     def save_builtin_function(self, obj):
@@ -979,43 +990,6 @@ def _restore_attr(obj, attr):
     return obj
 
 
-def _get_module_builtins():
-    return pickle.__builtins__
-
-
-def print_exec(stream):
-    ei = sys.exc_info()
-    traceback.print_exception(ei[0], ei[1], ei[2], None, stream)
-
-
-def _modules_to_main(modList):
-    """Force every module in modList to be placed into main"""
-    if not modList:
-        return
-
-    main = sys.modules['__main__']
-    for modname in modList:
-        if type(modname) is str:
-            try:
-                mod = __import__(modname)
-            except Exception:
-                sys.stderr.write('warning: could not import %s\n. '
-                                 'Your function may unexpectedly error due to this import failing;'
-                                 'A version mismatch is likely. Specific error was:\n' % modname)
-                print_exec(sys.stderr)
-            else:
-                setattr(main, mod.__name__, mod)
-
-
-# object generators:
-def _genpartial(func, args, kwds):
-    if not args:
-        args = ()
-    if not kwds:
-        kwds = {}
-    return partial(func, *args, **kwds)
-
-
 def _gen_ellipsis():
     return Ellipsis
 
@@ -1103,6 +1077,8 @@ def _fill_function(*args):
         func.__module__ = state['module']
     if 'qualname' in state:
         func.__qualname__ = state['qualname']
+    if 'kwdefaults' in state:
+        func.__kwdefaults__ = state['kwdefaults']
 
     cells = func.__closure__
     if cells is not None:
@@ -1188,15 +1164,6 @@ def _is_dynamic(module):
         return False
 
 
-"""Constructors for 3rd party libraries
-Note: These can never be renamed due to client compatibility issues"""
-
-
-def _getobject(modname, attribute):
-    mod = __import__(modname, fromlist=[attribute])
-    return mod.__dict__[attribute]
-
-
 """ Use copy_reg to extend global pickle definitions """
 
 if sys.version_info < (3, 4):  # pragma: no branch
```
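
For context on the `kwdefaults` entry above: defaults for keyword-only parameters live on a function's `__kwdefaults__` attribute rather than `__defaults__`, so both the pickling side (`save_function_tuple`) and the reconstruction side (`_fill_function`) have to carry them explicitly. A quick check in a plain Python 3 interpreter:

```python
>>> def f(a, *, b=1):
...     return a + b
...
>>> f.__defaults__ is None   # no defaults for positional-or-keyword params
True
>>> f.__kwdefaults__         # keyword-only defaults live here
{'b': 1}
```
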
python/pyspark/serializers.py (4 changes: 2 additions & 2 deletions)

```diff
@@ -69,7 +69,7 @@
 pickle_protocol = pickle.HIGHEST_PROTOCOL
 
 from pyspark import cloudpickle
-from pyspark.util import _exception_message
+from pyspark.util import _exception_message, print_exec
 
 
 __all__ = ["PickleSerializer", "MarshalSerializer", "UTF8Deserializer"]
@@ -716,7 +716,7 @@ def dumps(self, obj):
                 msg = "Object too large to serialize: %s" % emsg
             else:
                 msg = "Could not serialize object: %s: %s" % (e.__class__.__name__, emsg)
-            cloudpickle.print_exec(sys.stderr)
+            print_exec(sys.stderr)
             raise pickle.PicklingError(msg)
 
 
```

python/pyspark/util.py (6 changes: 6 additions & 0 deletions)

```diff
@@ -18,6 +18,7 @@
 
 import re
 import sys
+import traceback
 import inspect
 from py4j.protocol import Py4JJavaError
 
@@ -62,6 +63,11 @@ def _get_argspec(f):
     return argspec
 
 
+def print_exec(stream):
+    ei = sys.exc_info()
+    traceback.print_exception(ei[0], ei[1], ei[2], None, stream)
+
+
 class VersionUtils(object):
     """
     Provides utility method to determine Spark versions with given input string.
```
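
The relocated `print_exec` helper simply writes the traceback of the currently handled exception to the given stream, which is how the serializer's `dumps` reports pickling failures. A minimal usage sketch (the lambda is only there to provoke a pickling error with the standard `pickle` module):

```python
import pickle
import sys

from pyspark.util import print_exec

try:
    pickle.dumps(lambda x: x + 1)  # plain pickle cannot serialize lambdas
except Exception:
    print_exec(sys.stderr)         # print the active traceback to stderr
```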
