Skip to content

Commit

Permalink
Add dask serialization of CUDA objects (#3482)
Browse files Browse the repository at this point in the history
* Run `isort` on CUDA protocol `import`s

* Align CuPy serialize/deserialize function names

* Prefix CUDA serializers with `cuda_`

This should make room for Dask serializers to also be specified and
added.

* Add Dask serializers for RMM `DeviceBuffer`s

To make TCP a bit more performant with RMM, provide Dask serializers to
allow going to and from host memory.

* Add Dask serializers for Numba `DeviceNDArray`s

* Add Dask serializers for CuPy `ndarray`s

* Parametrize serializers in CUDA object tests

To make sure that different CUDA objects can use different serialization
protocols, test with each one individually and ensure it completes. In
particular test both "cuda" and "dask". Where supported also test
"pickle", but skip it when it is not (like with Numba).

* Check frames are the expected type

To make sure Dask can handle transmission of the frames serialized, test
they match the type expected by the protocol used. With "cuda" ensure we
get something that supports `__cuda_array_interface__`. With "dask" make
sure we get a `memoryview`.
  • Loading branch information
jakirkham authored Feb 19, 2020
1 parent 5465739 commit b5e95ed
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 19 deletions.
6 changes: 6 additions & 0 deletions distributed/protocol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,18 +70,24 @@ def _register_torch():

@cuda_serialize.register_lazy("cupy")
@cuda_deserialize.register_lazy("cupy")
@dask_serialize.register_lazy("cupy")
@dask_deserialize.register_lazy("cupy")
def _register_cupy():
    """Import the ``cupy`` protocol submodule for its registration side effects.

    The function is attached, under the name "cupy", to the CUDA and Dask
    serialize/deserialize dispatchers. Importing the submodule runs its
    ``register`` decorators for ``cupy.ndarray``.

    NOTE(review): presumably the dispatchers call this the first time an
    object from the ``cupy`` package is seen -- confirm against
    ``register_lazy``.
    """
    # Imported only for its side effects; the name itself is not used here.
    from . import cupy


@cuda_serialize.register_lazy("numba")
@cuda_deserialize.register_lazy("numba")
@dask_serialize.register_lazy("numba")
@dask_deserialize.register_lazy("numba")
def _register_numba():
    """Import the ``numba`` protocol submodule for its registration side effects.

    The function is attached, under the name "numba", to the CUDA and Dask
    serialize/deserialize dispatchers. Importing the submodule runs its
    ``register`` decorators for Numba's ``DeviceNDArray``.

    NOTE(review): presumably the dispatchers call this the first time an
    object from the ``numba`` package is seen -- confirm against
    ``register_lazy``.
    """
    # Imported only for its side effects; the name itself is not used here.
    from . import numba


@cuda_serialize.register_lazy("rmm")
@cuda_deserialize.register_lazy("rmm")
@dask_serialize.register_lazy("rmm")
@dask_deserialize.register_lazy("rmm")
def _register_rmm():
    """Import the ``rmm`` protocol submodule for its registration side effects.

    The function is attached, under the name "rmm", to the CUDA and Dask
    serialize/deserialize dispatchers. Importing the submodule runs its
    ``register`` decorators for ``rmm.DeviceBuffer`` (when RMM provides it).

    NOTE(review): presumably the dispatchers call this the first time an
    object from the ``rmm`` package is seen -- confirm against
    ``register_lazy``.
    """
    # Imported only for its side effects; the name itself is not used here.
    from . import rmm

Expand Down
27 changes: 24 additions & 3 deletions distributed/protocol/cupy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@
Efficient serialization of GPU arrays.
"""
import cupy
from .cuda import cuda_serialize, cuda_deserialize

from .cuda import cuda_deserialize, cuda_serialize
from .serialize import dask_deserialize, dask_serialize

try:
from .rmm import dask_deserialize_rmm_device_buffer as dask_deserialize_cuda_buffer
except ImportError:
from .numba import dask_deserialize_numba_array as dask_deserialize_cuda_buffer


class PatchedCudaArrayInterface:
Expand Down Expand Up @@ -31,7 +38,7 @@ def __del__(self):


@cuda_serialize.register(cupy.ndarray)
def serialize_cupy_ndarray(x):
def cuda_serialize_cupy_ndarray(x):
# Making sure `x` is behaving
if not (x.flags["C_CONTIGUOUS"] or x.flags["F_CONTIGUOUS"]):
x = cupy.array(x, copy=True)
Expand All @@ -48,7 +55,7 @@ def serialize_cupy_ndarray(x):


@cuda_deserialize.register(cupy.ndarray)
def deserialize_cupy_array(header, frames):
def cuda_deserialize_cupy_ndarray(header, frames):
(frame,) = frames
if not isinstance(frame, cupy.ndarray):
frame = PatchedCudaArrayInterface(frame)
Expand All @@ -59,3 +66,17 @@ def deserialize_cupy_array(header, frames):
strides=header["strides"],
)
return arr


@dask_serialize.register(cupy.ndarray)
def dask_serialize_cupy_ndarray(x):
    """Serialize a CuPy array into host-memory frames for the "dask" protocol.

    The header comes unchanged from the CUDA serializer. Each frame it
    produced is converted with ``cupy.asnumpy`` and wrapped in a
    ``memoryview``.

    Returns a ``(header, frames)`` pair where every frame is a ``memoryview``.
    """
    header, device_frames = cuda_serialize_cupy_ndarray(x)

    host_frames = []
    for device_frame in device_frames:
        host_frames.append(memoryview(cupy.asnumpy(device_frame)))

    return header, host_frames


@dask_deserialize.register(cupy.ndarray)
def dask_deserialize_cupy_ndarray(header, frames):
    """Rebuild a CuPy array from frames produced by the "dask" serializer.

    The frames are first turned into a single CUDA buffer by
    ``dask_deserialize_cuda_buffer`` (the RMM helper when importable,
    otherwise the Numba one). That buffer is then handed to the CUDA
    deserializer together with the original header.
    """
    device_buffer = dask_deserialize_cuda_buffer(header, frames)
    return cuda_deserialize_cupy_ndarray(header, [device_buffer])
33 changes: 29 additions & 4 deletions distributed/protocol/numba.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
import numpy as np
import numba.cuda
from .cuda import cuda_serialize, cuda_deserialize
import numpy as np

from .cuda import cuda_deserialize, cuda_serialize
from .serialize import dask_deserialize, dask_serialize

try:
from .rmm import dask_deserialize_rmm_device_buffer
except ImportError:
dask_deserialize_rmm_device_buffer = None


@cuda_serialize.register(numba.cuda.devicearray.DeviceNDArray)
def serialize_numba_ndarray(x):
def cuda_serialize_numba_ndarray(x):
# Making sure `x` is behaving
if not (x.flags["C_CONTIGUOUS"] or x.flags["F_CONTIGUOUS"]):
shape = x.shape
Expand All @@ -24,7 +31,7 @@ def serialize_numba_ndarray(x):


@cuda_deserialize.register(numba.cuda.devicearray.DeviceNDArray)
def deserialize_numba_ndarray(header, frames):
def cuda_deserialize_numba_ndarray(header, frames):
(frame,) = frames
shape = header["shape"]
strides = header["strides"]
Expand All @@ -36,3 +43,21 @@ def deserialize_numba_ndarray(header, frames):
gpu_data=numba.cuda.as_cuda_array(frame).gpu_data,
)
return arr


@dask_serialize.register(numba.cuda.devicearray.DeviceNDArray)
def dask_serialize_numba_ndarray(x):
    """Serialize a Numba device array into host-memory frames ("dask" protocol).

    The header comes unchanged from the CUDA serializer. Each frame it
    produced is copied to the host with ``copy_to_host`` and exposed as a
    ``memoryview``.

    Returns a ``(header, frames)`` pair where every frame is a ``memoryview``.
    """
    header, device_frames = cuda_serialize_numba_ndarray(x)

    host_frames = []
    for device_frame in device_frames:
        host_frames.append(memoryview(device_frame.copy_to_host()))

    return header, host_frames


@dask_deserialize.register(numba.cuda.devicearray.DeviceNDArray)
def dask_deserialize_numba_array(header, frames):
    """Rebuild a Numba device array from frames produced by the "dask" serializer.

    When the RMM helper was importable, the frames are turned into a single
    buffer by ``dask_deserialize_rmm_device_buffer``. Otherwise each frame is
    viewed as a NumPy array and moved with ``numba.cuda.to_device``. The
    resulting frames then go through the CUDA deserializer with the original
    header.
    """
    if not dask_deserialize_rmm_device_buffer:
        # RMM helper unavailable: move every host frame with Numba itself.
        device_frames = []
        for host_frame in frames:
            host_array = np.asarray(memoryview(host_frame))
            device_frames.append(numba.cuda.to_device(host_array))
    else:
        device_frames = [dask_deserialize_rmm_device_buffer(header, frames)]

    return cuda_deserialize_numba_ndarray(header, device_frames)
28 changes: 25 additions & 3 deletions distributed/protocol/rmm.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
import numba
import numba.cuda
import numpy
import rmm
from .cuda import cuda_serialize, cuda_deserialize

from .cuda import cuda_deserialize, cuda_serialize
from .serialize import dask_deserialize, dask_serialize

# Used for RMM 0.11.0+ otherwise Numba serializers used
if hasattr(rmm, "DeviceBuffer"):

@cuda_serialize.register(rmm.DeviceBuffer)
def serialize_rmm_device_buffer(x):
def cuda_serialize_rmm_device_buffer(x):
header = x.__cuda_array_interface__.copy()
frames = [x]
return header, frames

@cuda_deserialize.register(rmm.DeviceBuffer)
def deserialize_rmm_device_buffer(header, frames):
def cuda_deserialize_rmm_device_buffer(header, frames):
(arr,) = frames

# We should already have `DeviceBuffer`
Expand All @@ -21,3 +25,21 @@ def deserialize_rmm_device_buffer(header, frames):
assert isinstance(arr, rmm.DeviceBuffer)

return arr

@dask_serialize.register(rmm.DeviceBuffer)
def dask_serialize_rmm_device_buffer(x):
    """Serialize an RMM ``DeviceBuffer`` into one host-memory frame.

    The header is a copy of the buffer's ``__cuda_array_interface__``. The
    single frame is the ``.data`` of the array that ``copy_to_host`` returns
    for the buffer viewed as a Numba CUDA array.

    Returns a ``(header, frames)`` pair with exactly one frame.
    """
    cai_header = x.__cuda_array_interface__.copy()

    device_view = numba.cuda.as_cuda_array(x)
    host_copy = device_view.copy_to_host()

    return cai_header, [host_copy.data]

@dask_deserialize.register(rmm.DeviceBuffer)
def dask_deserialize_rmm_device_buffer(header, frames):
    """Build an RMM ``DeviceBuffer`` from a single host-memory frame.

    ``frames`` must contain exactly one item (unpacking raises otherwise).
    The frame is viewed as a NumPy array; the array's data address and byte
    count are passed to ``rmm.DeviceBuffer``. ``header`` is accepted but
    not used.
    """
    (host_frame,) = frames

    host_array = numpy.asarray(memoryview(host_frame))
    # ``__array_interface__["data"]`` is a tuple whose first item is the address.
    address = host_array.__array_interface__["data"][0]

    return rmm.DeviceBuffer(ptr=address, size=host_array.nbytes)
12 changes: 9 additions & 3 deletions distributed/protocol/tests/test_cupy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,17 @@
@pytest.mark.parametrize("shape", [(0,), (5,), (4, 6), (10, 11), (2, 3, 5)])
@pytest.mark.parametrize("dtype", ["u1", "u4", "u8", "f4"])
@pytest.mark.parametrize("order", ["C", "F"])
def test_serialize_cupy(shape, dtype, order):
@pytest.mark.parametrize("serializers", [("cuda",), ("dask",), ("pickle",)])
def test_serialize_cupy(shape, dtype, order, serializers):
x = cupy.arange(numpy.product(shape), dtype=dtype)
x = cupy.ndarray(shape, dtype=x.dtype, memptr=x.data, order=order)
header, frames = serialize(x, serializers=("cuda", "dask", "pickle"))
y = deserialize(header, frames, deserializers=("cuda", "dask", "pickle", "error"))
header, frames = serialize(x, serializers=serializers)
y = deserialize(header, frames, deserializers=serializers)

if serializers[0] == "cuda":
assert all(hasattr(f, "__cuda_array_interface__") for f in frames)
elif serializers[0] == "dask":
assert all(isinstance(f, memoryview) for f in frames)

assert (x == y).all()

Expand Down
12 changes: 9 additions & 3 deletions distributed/protocol/tests/test_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,21 @@
@pytest.mark.parametrize("shape", [(0,), (5,), (4, 6), (10, 11), (2, 3, 5)])
@pytest.mark.parametrize("dtype", ["u1", "u4", "u8", "f4"])
@pytest.mark.parametrize("order", ["C", "F"])
def test_serialize_numba(shape, dtype, order):
@pytest.mark.parametrize("serializers", [("cuda",), ("dask",)])
def test_serialize_numba(shape, dtype, order, serializers):
if not cuda.is_available():
pytest.skip("CUDA is not available")

ary = np.arange(np.product(shape), dtype=dtype)
ary = np.ndarray(shape, dtype=ary.dtype, buffer=ary.data, order=order)
x = cuda.to_device(ary)
header, frames = serialize(x, serializers=("cuda", "dask", "pickle"))
y = deserialize(header, frames, deserializers=("cuda", "dask", "pickle", "error"))
header, frames = serialize(x, serializers=serializers)
y = deserialize(header, frames, deserializers=serializers)

if serializers[0] == "cuda":
assert all(hasattr(f, "__cuda_array_interface__") for f in frames)
elif serializers[0] == "dask":
assert all(isinstance(f, memoryview) for f in frames)

hx = np.empty_like(ary)
hy = np.empty_like(ary)
Expand Down
12 changes: 9 additions & 3 deletions distributed/protocol/tests/test_rmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,22 @@


@pytest.mark.parametrize("size", [0, 3, 10])
def test_serialize_rmm_device_buffer(size):
@pytest.mark.parametrize("serializers", [("cuda",), ("dask",), ("pickle",)])
def test_serialize_rmm_device_buffer(size, serializers):
if not hasattr(rmm, "DeviceBuffer"):
pytest.skip("RMM pre-0.11.0 does not have DeviceBuffer")

x_np = numpy.arange(size, dtype="u1")
x = rmm.DeviceBuffer(size=size)
cuda.to_device(x_np, to=cuda.as_cuda_array(x))

header, frames = serialize(x, serializers=("cuda", "dask", "pickle"))
y = deserialize(header, frames, deserializers=("cuda", "dask", "pickle", "error"))
header, frames = serialize(x, serializers=serializers)
y = deserialize(header, frames, deserializers=serializers)
y_np = y.copy_to_host()

if serializers[0] == "cuda":
assert all(hasattr(f, "__cuda_array_interface__") for f in frames)
elif serializers[0] == "dask":
assert all(isinstance(f, memoryview) for f in frames)

assert (x_np == y_np).all()

0 comments on commit b5e95ed

Please sign in to comment.