From 5bac442dee9ea9b138909f3a0d8f79e96947a388 Mon Sep 17 00:00:00 2001
From: dhirschf
Date: Fri, 13 Jul 2018 12:04:08 +1000
Subject: [PATCH] Add custom serialization support for pyarrow

Closes #2103
---
 distributed/protocol/__init__.py |  5 +++
 distributed/protocol/arrow.py    | 70 ++++++++++++++++++++++++++++++++
 2 files changed, 75 insertions(+)
 create mode 100644 distributed/protocol/arrow.py

diff --git a/distributed/protocol/__init__.py b/distributed/protocol/__init__.py
index a6a9afaf324..565eda7fef2 100644
--- a/distributed/protocol/__init__.py
+++ b/distributed/protocol/__init__.py
@@ -37,3 +37,8 @@ def _register_keras():
 @partial(register_serialization_lazy, "sparse")
 def _register_sparse():
     from . import sparse
+
+
+@partial(register_serialization_lazy, "pyarrow")
+def _register_arrow():
+    from . import arrow
diff --git a/distributed/protocol/arrow.py b/distributed/protocol/arrow.py
new file mode 100644
index 00000000000..2378d7d7feb
--- /dev/null
+++ b/distributed/protocol/arrow.py
@@ -0,0 +1,70 @@
+from __future__ import print_function, division, absolute_import
+
+from .serialize import register_serialization
+
+
+def serialize_batch(batch):
+    """Serialize a ``pyarrow.RecordBatch`` via the Arrow stream format.
+
+    Returns an empty header and a single frame holding the stream bytes.
+    """
+    import pyarrow as pa
+    sink = pa.BufferOutputStream()
+    writer = pa.RecordBatchStreamWriter(sink, batch.schema)
+    writer.write_batch(batch)
+    writer.close()
+    buf = sink.get_result()
+    header = {}
+    frames = [buf.to_pybytes()]
+    return header, frames
+
+
+def deserialize_batch(header, frames):
+    """Rebuild a record batch from the frame made by ``serialize_batch``.
+
+    ``header`` is unused; only the first batch in ``frames[0]`` is read.
+    """
+    import pyarrow as pa
+    blob = frames[0]
+    reader = pa.RecordBatchStreamReader(pa.BufferReader(blob))
+    return reader.read_next_batch()
+
+
+def serialize_table(tbl):
+    """Serialize a ``pyarrow.Table`` via the Arrow stream format.
+
+    Returns an empty header and a single frame holding the stream bytes.
+    """
+    import pyarrow as pa
+    sink = pa.BufferOutputStream()
+    writer = pa.RecordBatchStreamWriter(sink, tbl.schema)
+    writer.write_table(tbl)
+    writer.close()
+    buf = sink.get_result()
+    header = {}
+    frames = [buf.to_pybytes()]
+    return header, frames
+
+
+def deserialize_table(header, frames):
+    """Rebuild a table from the frame made by ``serialize_table``.
+
+    ``header`` is unused; every batch in ``frames[0]`` is read.
+    """
+    import pyarrow as pa
+    blob = frames[0]
+    reader = pa.RecordBatchStreamReader(pa.BufferReader(blob))
+    return reader.read_all()
+
+
+register_serialization(
+    'pyarrow.lib.RecordBatch',
+    serialize_batch,
+    deserialize_batch
+)
+# Tables need write_table, so pair them with serialize_table.
+register_serialization(
+    'pyarrow.lib.Table',
+    serialize_table,
+    deserialize_table
+)