From dcb42e861e3343e7a9825c3eba5d8a8edba4ab65 Mon Sep 17 00:00:00 2001 From: Pieter-Jan Briers Date: Mon, 10 May 2021 12:12:04 +0200 Subject: [PATCH] Adds synchronous span APIs for datagram sockets. (#51956) * Adds synchronous span APIs for datagram sockets. Part of #33418/#938 * Doc comments for new APIs. * Fix review nits. * Add 0-byte send tests. --- .../ref/System.Net.Sockets.cs | 4 + .../src/System/Net/Sockets/Socket.cs | 148 ++++++++++++++++++ .../src/System/Net/Sockets/SocketPal.Unix.cs | 25 +++ .../System/Net/Sockets/SocketPal.Windows.cs | 45 ++---- .../SendReceive/SendReceive.cs | 109 ++++++++++++- .../tests/FunctionalTests/SocketTestHelper.cs | 15 ++ 6 files changed, 312 insertions(+), 34 deletions(-) diff --git a/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs b/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs index 68b963f24bcb1..5238ff8d8dad9 100644 --- a/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs +++ b/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs @@ -385,6 +385,8 @@ public void Listen(int backlog) { } public int ReceiveFrom(byte[] buffer, int size, System.Net.Sockets.SocketFlags socketFlags, ref System.Net.EndPoint remoteEP) { throw null; } public int ReceiveFrom(byte[] buffer, ref System.Net.EndPoint remoteEP) { throw null; } public int ReceiveFrom(byte[] buffer, System.Net.Sockets.SocketFlags socketFlags, ref System.Net.EndPoint remoteEP) { throw null; } + public int ReceiveFrom(System.Span buffer, ref System.Net.EndPoint remoteEP) { throw null; } + public int ReceiveFrom(System.Span buffer, System.Net.Sockets.SocketFlags socketFlags, ref System.Net.EndPoint remoteEP) { throw null; } public System.Threading.Tasks.Task ReceiveFromAsync(System.ArraySegment buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEndPoint) { throw null; } public System.Threading.Tasks.ValueTask ReceiveFromAsync(System.Memory buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEndPoint, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public bool ReceiveFromAsync(System.Net.Sockets.SocketAsyncEventArgs e) { throw null; } @@ -419,6 +421,8 @@ public void SendFile(string? fileName, System.ReadOnlySpan preBuffer, Syst public int SendTo(byte[] buffer, int size, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEP) { throw null; } public int SendTo(byte[] buffer, System.Net.EndPoint remoteEP) { throw null; } public int SendTo(byte[] buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEP) { throw null; } + public int SendTo(System.ReadOnlySpan buffer, System.Net.EndPoint remoteEP) { throw null; } + public int SendTo(System.ReadOnlySpan buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEP) { throw null; } public System.Threading.Tasks.Task SendToAsync(System.ArraySegment buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEP) { throw null; } public System.Threading.Tasks.ValueTask SendToAsync(System.ReadOnlyMemory buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEP, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public bool SendToAsync(System.Net.Sockets.SocketAsyncEventArgs e) { throw null; } diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index b7cddab344b5c..b75858ddfd347 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -1380,6 +1380,67 @@ public int SendTo(byte[] buffer, EndPoint remoteEP) return SendTo(buffer, 0, buffer != null ? buffer.Length : 0, SocketFlags.None, remoteEP); } + /// + /// Sends data to the specified endpoint. + /// + /// A span of bytes that contains the data to be sent. + /// The that represents the destination for the data. + /// The number of bytes sent. + /// remoteEP is . + /// An error occurred when attempting to access the socket. + /// The has been closed. + public int SendTo(ReadOnlySpan buffer, EndPoint remoteEP) + { + return SendTo(buffer, SocketFlags.None, remoteEP); + } + + /// + /// Sends data to a specific endpoint using the specified . + /// + /// A span of bytes that contains the data to be sent. + /// A bitwise combination of the values. + /// The that represents the destination for the data. + /// The number of bytes sent. + /// remoteEP is . + /// An error occurred when attempting to access the socket. + /// The has been closed. + public int SendTo(ReadOnlySpan buffer, SocketFlags socketFlags, EndPoint remoteEP) + { + ThrowIfDisposed(); + if (remoteEP == null) + { + throw new ArgumentNullException(nameof(remoteEP)); + } + + ValidateBlockingMode(); + + Internals.SocketAddress socketAddress = Serialize(ref remoteEP); + + int bytesTransferred; + SocketError errorCode = SocketPal.SendTo(_handle, buffer, socketFlags, socketAddress.Buffer, socketAddress.Size, out bytesTransferred); + + // Throw an appropriate SocketException if the native call fails. + if (errorCode != SocketError.Success) + { + UpdateSendSocketErrorForDisposed(ref errorCode); + + UpdateStatusAfterSocketErrorAndThrowException(errorCode); + } + else if (SocketsTelemetry.Log.IsEnabled()) + { + SocketsTelemetry.Log.BytesSent(bytesTransferred); + if (SocketType == SocketType.Dgram) SocketsTelemetry.Log.DatagramSent(); + } + + if (_rightEndPoint == null) + { + // Save a copy of the EndPoint so we can use it for Create(). + _rightEndPoint = remoteEP; + } + + return bytesTransferred; + } + // Receives data from a connected socket. public int Receive(byte[] buffer, int size, SocketFlags socketFlags) { @@ -1764,6 +1825,93 @@ public int ReceiveFrom(byte[] buffer, ref EndPoint remoteEP) return ReceiveFrom(buffer, 0, buffer != null ? buffer.Length : 0, SocketFlags.None, ref remoteEP); } + /// + /// Receives a datagram into the data buffer and stores the endpoint. + /// + /// A span of bytes that is the storage location for received data. + /// An , passed by reference, that represents the remote server. + /// The number of bytes received. + /// remoteEP is . + /// An error occurred when attempting to access the socket. + /// The has been closed. + public int ReceiveFrom(Span buffer, ref EndPoint remoteEP) + { + return ReceiveFrom(buffer, SocketFlags.None, ref remoteEP); + } + + /// + /// Receives a datagram into the data buffer, using the specified , and stores the endpoint. + /// + /// A span of bytes that is the storage location for received data. + /// A bitwise combination of the values. + /// An , passed by reference, that represents the remote server. + /// The number of bytes received. + /// remoteEP is . + /// An error occurred when attempting to access the socket. + /// The has been closed. + public int ReceiveFrom(Span buffer, SocketFlags socketFlags, ref EndPoint remoteEP) + { + ThrowIfDisposed(); + ValidateReceiveFromEndpointAndState(remoteEP, nameof(remoteEP)); + + SocketPal.CheckDualModeReceiveSupport(this); + + ValidateBlockingMode(); + + // We don't do a CAS demand here because the contents of remoteEP aren't used by + // WSARecvFrom; all that matters is that we generate a unique-to-this-call SocketAddress + // with the right address family. + EndPoint endPointSnapshot = remoteEP; + Internals.SocketAddress socketAddress = Serialize(ref endPointSnapshot); + Internals.SocketAddress socketAddressOriginal = IPEndPointExtensions.Serialize(endPointSnapshot); + + int bytesTransferred; + SocketError errorCode = SocketPal.ReceiveFrom(_handle, buffer, socketFlags, socketAddress.Buffer, ref socketAddress.InternalSize, out bytesTransferred); + + UpdateReceiveSocketErrorForDisposed(ref errorCode, bytesTransferred); + // If the native call fails we'll throw a SocketException. + SocketException? socketException = null; + if (errorCode != SocketError.Success) + { + socketException = new SocketException((int)errorCode); + UpdateStatusAfterSocketError(socketException); + if (NetEventSource.Log.IsEnabled()) NetEventSource.Error(this, socketException); + + if (socketException.SocketErrorCode != SocketError.MessageSize) + { + throw socketException; + } + } + else if (SocketsTelemetry.Log.IsEnabled()) + { + SocketsTelemetry.Log.BytesReceived(bytesTransferred); + if (SocketType == SocketType.Dgram) SocketsTelemetry.Log.DatagramReceived(); + } + + if (!socketAddressOriginal.Equals(socketAddress)) + { + try + { + remoteEP = endPointSnapshot.Create(socketAddress); + } + catch + { + } + if (_rightEndPoint == null) + { + // Save a copy of the EndPoint so we can use it for Create(). + _rightEndPoint = endPointSnapshot; + } + } + + if (socketException != null) + { + throw socketException; + } + + return bytesTransferred; + } + public int IOControl(int ioControlCode, byte[]? optionInValue, byte[]? optionOutValue) { ThrowIfDisposed(); diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs index 5b44cf408afe9..0e64fc1f16d88 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs @@ -1203,6 +1203,19 @@ public static SocketError SendTo(SafeSocketHandle handle, byte[] buffer, int off return errorCode; } + public static SocketError SendTo(SafeSocketHandle handle, ReadOnlySpan buffer, SocketFlags socketFlags, byte[] socketAddress, int socketAddressLen, out int bytesTransferred) + { + if (!handle.IsNonBlocking) + { + return handle.AsyncContext.SendTo(buffer, socketFlags, socketAddress, socketAddressLen, handle.SendTimeout, out bytesTransferred); + } + + bytesTransferred = 0; + SocketError errorCode; + TryCompleteSendTo(handle, buffer, socketFlags, socketAddress, socketAddressLen, ref bytesTransferred, out errorCode); + return errorCode; + } + public static SocketError Receive(SafeSocketHandle handle, IList> buffers, SocketFlags socketFlags, out int bytesTransferred) { SocketError errorCode; @@ -1311,6 +1324,18 @@ public static SocketError ReceiveFrom(SafeSocketHandle handle, byte[] buffer, in return completed ? errorCode : SocketError.WouldBlock; } + public static SocketError ReceiveFrom(SafeSocketHandle handle, Span buffer, SocketFlags socketFlags, byte[] socketAddress, ref int socketAddressLen, out int bytesTransferred) + { + if (!handle.IsNonBlocking) + { + return handle.AsyncContext.ReceiveFrom(buffer, ref socketFlags, socketAddress, ref socketAddressLen, handle.ReceiveTimeout, out bytesTransferred); + } + + SocketError errorCode; + bool completed = TryCompleteReceiveFrom(handle, buffer, socketFlags, socketAddress, ref socketAddressLen, out bytesTransferred, out socketFlags, out errorCode); + return completed ? errorCode : SocketError.WouldBlock; + } + public static SocketError WindowsIoctl(SafeSocketHandle handle, int ioControlCode, byte[]? optionInValue, byte[]? optionOutValue, out int optionLength) { // Three codes are called out in the Winsock IOCTLs documentation as "The following Unix IOCTL codes (commands) are supported." They are diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs index 6e59eda00d6e0..e157c687fd4a1 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs @@ -300,31 +300,15 @@ public static unsafe SocketError SendFile(SafeSocketHandle handle, SafeFileHandl } } - public static unsafe SocketError SendTo(SafeSocketHandle handle, byte[] buffer, int offset, int size, SocketFlags socketFlags, byte[] peerAddress, int peerAddressSize, out int bytesTransferred) + public static SocketError SendTo(SafeSocketHandle handle, byte[] buffer, int offset, int size, SocketFlags socketFlags, byte[] peerAddress, int peerAddressSize, out int bytesTransferred) => + SendTo(handle, buffer.AsSpan(offset, size), socketFlags, peerAddress, peerAddressSize, out bytesTransferred); + + public static unsafe SocketError SendTo(SafeSocketHandle handle, ReadOnlySpan buffer, SocketFlags socketFlags, byte[] peerAddress, int peerAddressSize, out int bytesTransferred) { int bytesSent; - if (buffer.Length == 0) + fixed (byte* bufferPtr = &MemoryMarshal.GetReference(buffer)) { - bytesSent = Interop.Winsock.sendto( - handle, - null, - 0, - socketFlags, - peerAddress, - peerAddressSize); - } - else - { - fixed (byte* pinnedBuffer = &buffer[0]) - { - bytesSent = Interop.Winsock.sendto( - handle, - pinnedBuffer + offset, - size, - socketFlags, - peerAddress, - peerAddressSize); - } + bytesSent = Interop.Winsock.sendto(handle, bufferPtr, buffer.Length, socketFlags, peerAddress, peerAddressSize); } if (bytesSent == (int)SocketError.SocketError) @@ -528,19 +512,16 @@ public static unsafe SocketError ReceiveMessageFrom(Socket socket, SafeSocketHan return SocketError.Success; } - public static unsafe SocketError ReceiveFrom(SafeSocketHandle handle, byte[] buffer, int offset, int size, SocketFlags socketFlags, byte[] socketAddress, ref int addressLength, out int bytesTransferred) + public static unsafe SocketError ReceiveFrom(SafeSocketHandle handle, byte[] buffer, int offset, int size, SocketFlags socketFlags, byte[] socketAddress, ref int addressLength, out int bytesTransferred) => + ReceiveFrom(handle, buffer.AsSpan(offset, size), SocketFlags.None, socketAddress, ref addressLength, out bytesTransferred); + + public static unsafe SocketError ReceiveFrom(SafeSocketHandle handle, Span buffer, SocketFlags socketFlags, byte[] socketAddress, ref int addressLength, out int bytesTransferred) { int bytesReceived; - if (buffer.Length == 0) - { - bytesReceived = Interop.Winsock.recvfrom(handle, null, 0, socketFlags, socketAddress, ref addressLength); - } - else + + fixed (byte* bufferPtr = &MemoryMarshal.GetReference(buffer)) { - fixed (byte* pinnedBuffer = &buffer[0]) - { - bytesReceived = Interop.Winsock.recvfrom(handle, pinnedBuffer + offset, size, socketFlags, socketAddress, ref addressLength); - } + bytesReceived = Interop.Winsock.recvfrom(handle, bufferPtr, buffer.Length, socketFlags, socketAddress, ref addressLength); } if (bytesReceived == (int)SocketError.SocketError) diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendReceive/SendReceive.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendReceive/SendReceive.cs index b06faaf5a7580..e9d1bf1e42418 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendReceive/SendReceive.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendReceive/SendReceive.cs @@ -592,7 +592,7 @@ await Task.WhenAll( { for (int i = 0; i < 3; i++) { - // Zero byte send should be a no-op + // Zero byte send should be a no-op int bytesSent = await SendAsync(client, new ArraySegment(Array.Empty())); Assert.Equal(0, bytesSent); @@ -650,6 +650,41 @@ await Task.WhenAll( } } + [Fact] + public async Task Send_0ByteSendTo_Success() + { + using (Socket server = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp)) + using (Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp)) + { + server.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + client.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + + for (int i = 0; i < 3; i++) + { + // Send empty packet then real data. + int bytesSent = await SendToAsync( + client, new ArraySegment(Array.Empty()), server.LocalEndPoint!); + Assert.Equal(0, bytesSent); + + await SendToAsync(client, new byte[] { 99 }, server.LocalEndPoint); + + // Read empty packet + byte[] buffer = new byte[10]; + SocketReceiveFromResult result = await ReceiveFromAsync(server, buffer, new IPEndPoint(IPAddress.Any, 0)); + + Assert.Equal(0, result.ReceivedBytes); + Assert.Equal(client.LocalEndPoint, result.RemoteEndPoint); + + // Read real packet. + result = await ReceiveFromAsync(server, buffer, new IPEndPoint(IPAddress.Any, 0)); + + Assert.Equal(1, result.ReceivedBytes); + Assert.Equal(client.LocalEndPoint, result.RemoteEndPoint); + Assert.Equal(99, buffer[0]); + } + } + } + [Fact] public async Task Receive0ByteReturns_WhenPeerDisconnects() { @@ -1201,7 +1236,7 @@ await Task.WhenAll( { for (int i = 0; i < 3; i++) { - // Zero byte send should be a no-op + // Zero byte send should be a no-op int bytesSent = client.Send(ReadOnlySpan.Empty, SocketFlags.None); Assert.Equal(0, bytesSent); @@ -1215,6 +1250,41 @@ await Task.WhenAll( } } } + + [Fact] + public async Task Send_0ByteSendTo_Span_Success() + { + using (Socket server = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp)) + using (Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp)) + { + server.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + client.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + + for (int i = 0; i < 3; i++) + { + // Send empty packet then real data. + int bytesSent = client.SendTo(ReadOnlySpan.Empty, server.LocalEndPoint!); + Assert.Equal(0, bytesSent); + + client.SendTo(new byte[] { 99 }, server.LocalEndPoint); + + // Read empty packet + byte[] buffer = new byte[10]; + SocketReceiveFromResult result = await ReceiveFromAsync(server, buffer, new IPEndPoint(IPAddress.Any, 0)); + + Assert.Equal(0, result.ReceivedBytes); + Assert.Equal(client.LocalEndPoint, result.RemoteEndPoint); + + // Read real packet. + result = await ReceiveFromAsync(server, buffer, new IPEndPoint(IPAddress.Any, 0)); + + Assert.Equal(1, result.ReceivedBytes); + Assert.Equal(client.LocalEndPoint, result.RemoteEndPoint); + Assert.Equal(99, buffer[0]); + } + } + } + } public sealed class SendReceive_SpanSyncForceNonBlocking : SendReceive @@ -1260,6 +1330,41 @@ await Task.WhenAll( } } + [Fact] + public async Task Send_0ByteSendTo_Memory_Success() + { + using (Socket server = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp)) + using (Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp)) + { + server.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + client.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + + for (int i = 0; i < 3; i++) + { + // Send empty packet then real data. + int bytesSent = await client.SendToAsync( + ReadOnlyMemory.Empty, SocketFlags.None, server.LocalEndPoint!); + Assert.Equal(0, bytesSent); + + await client.SendToAsync(new byte[] { 99 }, SocketFlags.None, server.LocalEndPoint); + + // Read empty packet + byte[] buffer = new byte[10]; + SocketReceiveFromResult result = await ReceiveFromAsync(server, buffer, new IPEndPoint(IPAddress.Any, 0)); + + Assert.Equal(0, result.ReceivedBytes); + Assert.Equal(client.LocalEndPoint, result.RemoteEndPoint); + + // Read real packet. + result = await ReceiveFromAsync(server, buffer, new IPEndPoint(IPAddress.Any, 0)); + + Assert.Equal(1, result.ReceivedBytes); + Assert.Equal(client.LocalEndPoint, result.RemoteEndPoint); + Assert.Equal(99, buffer[0]); + } + } + } + [Fact] public async Task Precanceled_Throws() { diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTestHelper.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTestHelper.cs index 1c27291c683ba..7c40f8334a055 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTestHelper.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTestHelper.cs @@ -487,6 +487,20 @@ public override Task ReceiveAsync(Socket s, ArraySegment buffer) => Task.Run(() => s.Receive((Span)buffer, SocketFlags.None)); public override Task SendAsync(Socket s, ArraySegment buffer) => Task.Run(() => s.Send((ReadOnlySpan)buffer, SocketFlags.None)); + public override Task ReceiveFromAsync(Socket s, ArraySegment buffer, + EndPoint endPoint) => + Task.Run(() => + { + SocketFlags socketFlags = SocketFlags.None; + int received = s.ReceiveFrom((Span)buffer, socketFlags, ref endPoint); + return new SocketReceiveFromResult + { + ReceivedBytes = received, + RemoteEndPoint = endPoint, + }; + }); + public override Task SendToAsync(Socket s, ArraySegment buffer, EndPoint endPoint) => + Task.Run(() => s.SendTo((ReadOnlySpan)buffer, endPoint)); public override Task ReceiveMessageFromAsync(Socket s, ArraySegment buffer, EndPoint endPoint) => Task.Run(() => { @@ -501,6 +515,7 @@ public override Task ReceiveMessageFromAsync(Soc PacketInformation = ipPacketInformation }; }); + public override Task SendFileAsync(Socket s, string fileName, ArraySegment preBuffer, ArraySegment postBuffer, TransmitFileOptions flags) => Task.Run(() => s.SendFile(fileName, preBuffer, postBuffer, flags)); public override bool UsesSync => true;