Skip to content

Commit

Permalink
Support async recvmsg and sendmsg
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasWoodtli committed Feb 27, 2024
1 parent 10fbcd6 commit ee826c3
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 0 deletions.
105 changes: 105 additions & 0 deletions Lib/asyncio/selector_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .log import logger

_HAS_SENDMSG = hasattr(socket.socket, 'sendmsg')
_HAS_RECVMSG = hasattr(socket.socket, 'recvmsg')

if _HAS_SENDMSG:
try:
Expand Down Expand Up @@ -66,6 +67,12 @@ def __init__(self, selector=None):
self._make_self_pipe()
self._transports = weakref.WeakValueDictionary()

if not _HAS_SENDMSG:
delattr(self, 'sock_sendmsg')

if not _HAS_RECVMSG:
delattr(self, 'sock_recvmsg')

def _make_socket_transport(self, sock, protocol, waiter=None, *,
extra=None, server=None):
self._ensure_fd_no_transport(sock)
Expand Down Expand Up @@ -523,6 +530,52 @@ def _sock_recvfrom_into(self, fut, sock, buf, bufsize):
else:
fut.set_result(result)

async def sock_recvmsg(self, sock, bufsize, ancbufsize=0, flags=0):
"""Receive normal data (up to bufsize bytes) and ancillary data from
the socket (sock). The socket must be non-blocking.
The return value is a tuple of (data, ancdata, msg_flags, address).
data represents the datagram received. ancdata are the ancillary data
(control messages) as a list of tuples (cmsg_level, cmsg_type, cmsg_data),
where cmsg_level and cmsg_type are integers specifying the protocol level
and protocol-specific type respectively, and cmsg_data is a bytes object
holding the associated ancillary data. flags represent various conditions
(bitwise OR) on the received data.
The address is only specified if the receiving socket is unconnected.
Then it is the address of the sending socket.
"""
base_events._check_ssl_socket(sock)
if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking")
try:
return sock.recvmsg(bufsize)
except (BlockingIOError, InterruptedError):
pass
fut = self.create_future()
fd = sock.fileno()
self._ensure_fd_no_transport(fd)
handle = self._add_reader(fd, self._sock_recvmsg, fut, sock, bufsize, ancbufsize, flags)
fut.add_done_callback(
functools.partial(self._sock_read_done, fd, handle=handle))
return await fut

def _sock_recvmsg(self, fut, sock, bufsize, ancbufsize, flags):
# _sock_recvmsg() can add itself as an I/O callback if the operation
# can't be done immediately. Don't use it directly, call
# sock_recvmsg().
if fut.done():
return
try:
result = sock.recvmsg(bufsize, ancbufsize, flags)
except (BlockingIOError, InterruptedError):
return # try again next time
except (SystemExit, KeyboardInterrupt):
raise
except BaseException as exc:
fut.set_exception(exc)
else:
fut.set_result(result)

async def sock_sendall(self, sock, data):
"""Send data to the socket.
Expand Down Expand Up @@ -576,6 +629,58 @@ def _sock_sendall(self, fut, sock, view, pos):
else:
pos[0] = start

async def sock_sendmsg(self, sock, data, ancdata=[], flags=0, address=None):
"""Send datagram (data) and ancillary data to the socket (sock).
The provided ancillary data is a list of zero or more tuples (data, ancdata,
msg_flags, address). flags represent various conditions and have the same
meaning as for send(). If address is supplied and not None, it sets a destination
address for the message it is the address of the sending socket.
"""
base_events._check_ssl_socket(sock)
if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking")
try:
n = sock.sendmsg([data], ancdata, flags, address)
except (BlockingIOError, InterruptedError):
n = 0

if n == len(data):
# all data sent
return

fut = self.create_future()
fd = sock.fileno()
self._ensure_fd_no_transport(fd)
# use a trick with a list in closure to store a mutable state
handle = self._add_writer(fd, self._sock_sendmsg, fut, sock,
memoryview(data), [n], ancdata, flags, address)
fut.add_done_callback(
functools.partial(self._sock_write_done, fd, handle=handle))
return await fut

def _sock_sendmsg(self, fut, sock, view, pos, ancdata, flags, address):
if fut.done():
# Future cancellation can be scheduled on previous loop iteration
return
start = pos[0]
try:
n = sock.sendmsg([view[start:]], ancdata, flags, address)
except (BlockingIOError, InterruptedError):
return
except (SystemExit, KeyboardInterrupt):
raise
except BaseException as exc:
fut.set_exception(exc)
return

start += n

if start == len(view):
fut.set_result(None)
else:
pos[0] = start

async def sock_sendto(self, sock, data, address):
"""Send data to the socket.
Expand Down
50 changes: 50 additions & 0 deletions Lib/test/test_asyncio/test_sock_lowlevel.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import ctypes

import socket
import asyncio
import sys
import struct
import unittest

from asyncio import proactor_events
from itertools import cycle, islice

from ipaddress import IPv4Address
from unittest.mock import Mock
from test.test_asyncio import utils as test_utils
from test import support
Expand Down Expand Up @@ -427,6 +432,51 @@ def test_recvfrom_into(self):
self.loop.run_until_complete(
self._basetest_datagram_recvfrom_into(server_address))

async def _basetest_datagram_sendmsg_recvmsg(self, server_address):
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
sock.setblocking(False)
sock.setsockopt(socket.IPPROTO_IP, socket.IP_PKTINFO, 1)
sock.setsockopt(socket.IPPROTO_IP, socket.IP_RECVTOS, 1)
sock.setsockopt(socket.IPPROTO_IP, socket.IP_TOS, 1)

data = b'\x01' * 4096
ancsize = 10240

ancillary_data = [(socket.IPPROTO_IP, socket.IP_TOS, b"\x08")]
await self.loop.sock_sendmsg(sock, data, ancillary_data, address=server_address)
rec_data, ancdata_rcv, msg_flags, address = await self.loop.sock_recvmsg(
sock, len(data), ancsize)

# Sent data is echoed back
self.assertEqual(data, rec_data)

# ancillary data
self.assertEqual(2, len(ancdata_rcv))
self.assertTrue(all(a[0] == socket.IPPROTO_IP for a in ancdata_rcv))
# PKTINFO
ancdata_rcv_pktinfo = [d for d in ancdata_rcv if d[1] == socket.IP_PKTINFO]
self.assertEqual(1, len(ancdata_rcv_pktinfo))
ancdata_rcv_pktinfo = ancdata_rcv_pktinfo[0]
# Not decoding the data. Just assert length as sanity check.
self.assertEqual(12, len(ancdata_rcv_pktinfo[2]))
# IP_RECVTOS
ancdata_rcv_rectos = [d for d in ancdata_rcv if d[1] == socket.IP_TOS]
self.assertEqual(1, len(ancdata_rcv_rectos))
ancdata_rcv_rectos = ancdata_rcv_rectos[0]
tos = int.from_bytes(struct.unpack("c", ancdata_rcv_rectos[2])[0], "big")
# the testing server is sending an empty TOS
self.assertEqual(tos, 0)

self.assertEqual(msg_flags, 0)
self.assertEqual(address[0], '127.0.0.1')

@unittest.skipUnless(hasattr(socket.socket, 'sendmsg') and hasattr(socket.socket, 'recvmsg'),
"The OS does not support sockets with 'sendmsg' or 'recvmsg'")
def test_sendmsg_recvmsg(self):
with test_utils.run_udp_echo_server() as server_address:
self.loop.run_until_complete(
self._basetest_datagram_sendmsg_recvmsg(server_address))

async def _basetest_datagram_sendto_blocking(self, server_address):
# Sad path, sock.sendto() raises BlockingIOError
# This involves patching sock.sendto() to raise BlockingIOError but
Expand Down

0 comments on commit ee826c3

Please sign in to comment.