From f111f7ab6d302e9b1e2a568d0e4c574895db6a6e Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 8 Apr 2024 19:09:24 +0200 Subject: [PATCH] Fix `_whichmodule` with `multiprocessing` (#529) Co-authored-by: Olivier Grisel --- CHANGES.md | 3 +++ cloudpickle/cloudpickle.py | 1 + tests/cloudpickle_test.py | 24 ++++++++++++++++++++++++ 3 files changed, 28 insertions(+) diff --git a/CHANGES.md b/CHANGES.md index ab86e7af..73677a96 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -5,6 +5,9 @@ dynamic functions and classes. ([PR #524](https://github.com/cloudpipe/cloudpickle/pull/524)) +- Fix a problem with the joint usage of cloudpickle's `_whichmodule` and + `multiprocessing`. + ([PR #529](https://github.com/cloudpipe/cloudpickle/pull/529)) 3.0.0 ===== diff --git a/cloudpickle/cloudpickle.py b/cloudpickle/cloudpickle.py index ec61002e..88f9b12a 100644 --- a/cloudpickle/cloudpickle.py +++ b/cloudpickle/cloudpickle.py @@ -213,6 +213,7 @@ def _whichmodule(obj, name): # sys.modules if ( module_name == "__main__" + or module_name == "__mp_main__" or module is None or not isinstance(module, types.ModuleType) ): diff --git a/tests/cloudpickle_test.py b/tests/cloudpickle_test.py index ed18285c..5aa4baca 100644 --- a/tests/cloudpickle_test.py +++ b/tests/cloudpickle_test.py @@ -29,6 +29,7 @@ import pickle import pytest +from pathlib import Path try: # try importing numpy and scipy. These are not hard dependencies and @@ -1479,6 +1480,29 @@ def __getattr__(self, name): finally: sys.modules.pop("NonModuleObject") + def test_importing_multiprocessing_does_not_impact_whichmodule(self): + # non-regression test for #528 + pytest.importorskip("numpy") + script = textwrap.dedent(""" + import multiprocessing + import cloudpickle + from numpy import exp + + print(cloudpickle.cloudpickle._whichmodule(exp, exp.__name__)) + """) + script_path = Path(self.tmpdir) / "whichmodule_and_multiprocessing.py" + with open(script_path, mode="w") as f: + f.write(script) + + proc = subprocess.Popen( + [sys.executable, str(script_path)], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + out, _ = proc.communicate() + self.assertEqual(proc.wait(), 0) + self.assertEqual(out, b"numpy.core._multiarray_umath\n") + def test_unrelated_faulty_module(self): # Check that pickling a dynamically defined function or class does not # fail when introspecting the currently loaded modules in sys.modules