diff --git a/tests/test_send_recv.py b/tests/test_send_recv.py index cd7accc20..c3e749072 100644 --- a/tests/test_send_recv.py +++ b/tests/test_send_recv.py @@ -98,13 +98,25 @@ async def test_send_recv_cupy(size, dtype, blocking_progress_mode): ucp.init(blocking_progress_mode=blocking_progress_mode) cupy = pytest.importorskip("cupy") - msg = cupy.arange(size, dtype=dtype) - msg_size = np.array([msg.nbytes], dtype=np.uint64) - listener = ucp.create_listener( make_echo_server(lambda n: cupy.empty((n,), dtype=np.uint8)) ) client = await ucp.create_endpoint(ucp.get_address(), listener.port) + + msg = cupy.arange(size, dtype=dtype) + msg_size = np.array([msg.nbytes], dtype=np.uint64) + + await client.send(msg_size) + await client.send(msg) + resp = cupy.empty_like(msg) + await client.recv(resp) + np.testing.assert_array_equal(cupy.asnumpy(resp), cupy.asnumpy(msg)) + + msg = cupy.concatenate( + [cupy.arange(0, dtype=dtype), cupy.arange(size, dtype=dtype)] + ) + msg_size = np.array([msg.nbytes], dtype=np.uint64) + await client.send(msg_size) await client.send(msg) resp = cupy.empty_like(msg)