From ee913e6e2d58dfac20f3f06ff306081bd0e48066 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Sun, 6 Mar 2016 08:57:01 -0800 Subject: [PATCH] [SPARK-13697] [PYSPARK] Fix the missing module name of TransformFunctionSerializer.loads ## What changes were proposed in this pull request? Set the function's module name to `__main__` if it's missing in `TransformFunctionSerializer.loads`. ## How was this patch tested? Manually test in the shell. Before this patch: ``` >>> from pyspark.streaming import StreamingContext >>> from pyspark.streaming.util import TransformFunction >>> ssc = StreamingContext(sc, 1) >>> func = TransformFunction(sc, lambda x: x, sc.serializer) >>> func.rdd_wrapper(lambda x: x) TransformFunction( at 0x106ac8b18>) >>> bytes = bytearray(ssc._transformerSerializer.serializer.dumps((func.func, func.rdd_wrap_func, func.deserializers))) >>> func2 = ssc._transformerSerializer.loads(bytes) >>> print(func2.func.__module__) None >>> print(func2.rdd_wrap_func.__module__) None >>> ``` After this patch: ``` >>> from pyspark.streaming import StreamingContext >>> from pyspark.streaming.util import TransformFunction >>> ssc = StreamingContext(sc, 1) >>> func = TransformFunction(sc, lambda x: x, sc.serializer) >>> func.rdd_wrapper(lambda x: x) TransformFunction( at 0x108bf1b90>) >>> bytes = bytearray(ssc._transformerSerializer.serializer.dumps((func.func, func.rdd_wrap_func, func.deserializers))) >>> func2 = ssc._transformerSerializer.loads(bytes) >>> print(func2.func.__module__) __main__ >>> print(func2.rdd_wrap_func.__module__) __main__ >>> ``` Author: Shixiong Zhu Closes #11535 from zsxwing/loads-module. --- python/pyspark/cloudpickle.py | 4 +++- python/pyspark/tests.py | 6 ++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 95b3abc74244b..e56e22a9b920e 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -241,6 +241,7 @@ def save_function_tuple(self, func): save(f_globals) save(defaults) save(dct) + save(func.__module__) write(pickle.TUPLE) write(pickle.REDUCE) # applies _fill_function on the tuple @@ -698,13 +699,14 @@ def _genpartial(func, args, kwds): return partial(func, *args, **kwds) -def _fill_function(func, globals, defaults, dict): +def _fill_function(func, globals, defaults, dict, module): """ Fills in the rest of function data into the skeleton function object that were created via _make_skel_func(). """ func.__globals__.update(globals) func.__defaults__ = defaults func.__dict__ = dict + func.__module__ = module return func diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 23720502a82c8..a5a83c7e38b3c 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -228,6 +228,12 @@ def test_itemgetter(self): getter2 = ser.loads(ser.dumps(getter)) self.assertEqual(getter(d), getter2(d)) + def test_function_module_name(self): + ser = CloudPickleSerializer() + func = lambda x: x + func2 = ser.loads(ser.dumps(func)) + self.assertEqual(func.__module__, func2.__module__) + def test_attrgetter(self): from operator import attrgetter ser = CloudPickleSerializer()