Skip to content

Commit

Permalink
Revert pickle change (#8456)
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter authored Jan 12, 2024
1 parent 58e2d99 commit c3aa960
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 10 deletions.
13 changes: 4 additions & 9 deletions distributed/protocol/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,19 +67,14 @@ def dumps(x, *, buffer_callback=None, protocol=HIGHEST_PROTOCOL):
buffers.clear()
pickler.dump(x)
result = f.getvalue()

if not _always_use_pickle_for(x) and (
if b"__main__" in result or (
CLOUDPICKLE_GE_20
and getattr(inspect.getmodule(x), "__name__", None)
in cloudpickle.list_registry_pickle_by_value()
or (
len(result) < 1000
# Do this very last since it's expensive
and b"__main__" in result
)
):
buffers.clear()
result = cloudpickle.dumps(x, **dump_kwargs)
if len(result) < 1000 or not _always_use_pickle_for(x):
buffers.clear()
result = cloudpickle.dumps(x, **dump_kwargs)
except Exception:
try:
buffers.clear()
Expand Down
21 changes: 20 additions & 1 deletion distributed/protocol/tests/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
loads,
)
from distributed.protocol.serialize import dask_deserialize, dask_serialize
from distributed.utils_test import save_sys_modules
from distributed.utils_test import popen, save_sys_modules


class MemoryviewHolder:
Expand Down Expand Up @@ -278,3 +278,22 @@ def test_nopickle_nested():
finally:
del dask_serialize._lookup[NoPickle]
del dask_deserialize._lookup[NoPickle]


@pytest.mark.slow()
def test_pickle_functions_in_main(tmp_path):
script = """
from dask.distributed import Client
if __name__ == "__main__":
with Client(n_workers=1) as client:
def func(df):
return (df + 5)
client.submit(func, 5).result()
print("script successful", flush=True)
"""
with open(tmp_path / "script.py", mode="w") as f:
f.write(script)
with popen([sys.executable, tmp_path / "script.py"], capture_output=True) as proc:
out, _ = proc.communicate(timeout=60)

assert "script successful" in out.decode("utf-8")

0 comments on commit c3aa960

Please sign in to comment.