From 49fcb0f149840507a99a12c6e3eeb8db6264b7dd Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Mon, 3 Aug 2020 14:25:27 -0700 Subject: [PATCH 1/4] Allow specifying `protocol` in `dumps` Provide users the option to override `protocol` in our `dumps` function. If it is not specified, default to the `HIGHEST_PROTOCOL` just as before. --- distributed/protocol/pickle.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/distributed/protocol/pickle.py b/distributed/protocol/pickle.py index fd2343756a4..48b82202b2f 100644 --- a/distributed/protocol/pickle.py +++ b/distributed/protocol/pickle.py @@ -33,16 +33,17 @@ def _always_use_pickle_for(x): return False -def dumps(x, *, buffer_callback=None): +def dumps(x, protocol=None, *, buffer_callback=None): """ Manage between cloudpickle and pickle 1. Try pickle 2. If it is short then check if it contains __main__ 3. If it is long, then first check type, then check __main__ """ + protocol = protocol or HIGHEST_PROTOCOL buffers = [] - dump_kwargs = {"protocol": HIGHEST_PROTOCOL} - if HIGHEST_PROTOCOL >= 5 and buffer_callback is not None: + dump_kwargs = {"protocol": protocol} + if protocol >= 5 and buffer_callback is not None: dump_kwargs["buffer_callback"] = buffers.append try: buffers.clear() From c11efbf96230b5dae23f998ca733565dc411ad9f Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Mon, 3 Aug 2020 14:30:41 -0700 Subject: [PATCH 2/4] Test pickling with all supported protocols --- distributed/protocol/tests/test_pickle.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/distributed/protocol/tests/test_pickle.py b/distributed/protocol/tests/test_pickle.py index ea2143c5358..8fba8a261b9 100644 --- a/distributed/protocol/tests/test_pickle.py +++ b/distributed/protocol/tests/test_pickle.py @@ -25,6 +25,13 @@ def test_pickle_data(): assert deserialize(*serialize(d, serializers=("pickle",))) == d +def test_pickle_protocol(): + data = {"int": 1, "float": 2, "unicode": "abc", "bytes": b"def", "set": set()} + for p in range(HIGHEST_PROTOCOL): + assert loads(dumps(data, p)) == data + assert deserialize(*serialize(data, serializers=("pickle",))) == data + + def test_pickle_out_of_band(): class MemoryviewHolder: def __init__(self, mv): From bd907019dcaac8b508cb09a17c263b2d52ca471a Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Mon, 3 Aug 2020 14:35:59 -0700 Subject: [PATCH 3/4] Add pickle protocol to the config --- distributed/distributed-schema.yaml | 12 ++++++++++++ distributed/distributed.yaml | 2 ++ 2 files changed, 14 insertions(+) diff --git a/distributed/distributed-schema.yaml b/distributed/distributed-schema.yaml index f67cdca84f2..f6dc7b23187 100644 --- a/distributed/distributed-schema.yaml +++ b/distributed/distributed-schema.yaml @@ -759,6 +759,18 @@ properties: type: boolean description: Enter Python Debugger on scheduling error + pickle: + type: object + description: | + Configuration for pickle serialization + properties: + protocol: + type: + - integer + - "null" + description: + The protocol version to use with pickle + rmm: type: object description: | diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index f815fadf830..1ac52211095 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -170,6 +170,8 @@ distributed: log-length: 10000 # default length of logs to keep in memory log-format: '%(name)s - %(levelname)s - %(message)s' pdb-on-err: False # enter debug mode on scheduling error +pickle: + protocol: null # specify the pickle protocol to use rmm: pool-size: null ucx: From 0c1a79d66cc7e038f560ae754ac89cb90b07502f Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Mon, 3 Aug 2020 14:47:07 -0700 Subject: [PATCH 4/4] Override `protocol` based on the config value --- distributed/protocol/pickle.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/distributed/protocol/pickle.py b/distributed/protocol/pickle.py index 48b82202b2f..30bc7f42f9c 100644 --- a/distributed/protocol/pickle.py +++ b/distributed/protocol/pickle.py @@ -2,6 +2,7 @@ import sys import cloudpickle +import dask if sys.version_info < (3, 8): try: @@ -40,7 +41,7 @@ def dumps(x, protocol=None, *, buffer_callback=None): 2. If it is short then check if it contains __main__ 3. If it is long, then first check type, then check __main__ """ - protocol = protocol or HIGHEST_PROTOCOL + protocol = protocol or dask.config.get("pickle.protocol") or HIGHEST_PROTOCOL buffers = [] dump_kwargs = {"protocol": protocol} if protocol >= 5 and buffer_callback is not None: