Skip to content

Commit

Permalink
fix zero byte Send on linux (#51473)
Browse files Browse the repository at this point in the history
* fix zero byte send on linux and update tests

Co-authored-by: Geoffrey Kizer <[email protected]>
  • Loading branch information
geoffkizer and Geoffrey Kizer authored Apr 20, 2021
1 parent a0811a1 commit 333a6c7
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Microsoft.Win32.SafeHandles;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Win32.SafeHandles;

namespace System.Net.Sockets
{
Expand Down Expand Up @@ -197,10 +197,17 @@ private static unsafe int SysWrite(SafeSocketHandle handle, ReadOnlySpan<byte> b
return sent;
}

// The Linux kernel doesn't like it if we pass a null reference for buffer pointers, even if the length is 0.
// Replace any null pointer (e.g. from Memory<byte>.Empty) with a valid pointer.
private static ReadOnlySpan<byte> AvoidNullReference(ReadOnlySpan<byte> buffer) =>
Unsafe.IsNullRef(ref MemoryMarshal.GetReference(buffer)) ? Array.Empty<byte>() : buffer;

private static unsafe int SysSend(SafeSocketHandle socket, SocketFlags flags, ReadOnlySpan<byte> buffer, ref int offset, ref int count, out Interop.Error errno)
{
Debug.Assert(socket.IsSocket);

buffer = AvoidNullReference(buffer);

int sent;
fixed (byte* b = &MemoryMarshal.GetReference(buffer))
{
Expand All @@ -226,6 +233,8 @@ private static unsafe int SysSend(SafeSocketHandle socket, SocketFlags flags, Re
{
Debug.Assert(socket.IsSocket);

buffer = AvoidNullReference(buffer);

int sent;
fixed (byte* sockAddr = socketAddress)
fixed (byte* b = &MemoryMarshal.GetReference(buffer))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,39 @@ public async Task SendRecvPollSync_TcpListener_Socket(IPAddress listenAt, bool p
}
}

[Fact]
public async Task Send_0ByteSend_Success()
{
using (Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
using (Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
{
listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
listener.Listen(1);

Task<Socket> acceptTask = AcceptAsync(listener);
await Task.WhenAll(
acceptTask,
ConnectAsync(client, new IPEndPoint(IPAddress.Loopback, ((IPEndPoint)listener.LocalEndPoint).Port)));

using (Socket server = await acceptTask)
{
for (int i = 0; i < 3; i++)
{
// Zero byte send should be a no-op
int bytesSent = await SendAsync(client, new ArraySegment<byte>(Array.Empty<byte>()));
Assert.Equal(0, bytesSent);

// Socket should still be usable
await SendAsync(client, new byte[] { 99 });
byte[] buffer = new byte[10];
int bytesReceived = await ReceiveAsync(server, buffer);
Assert.Equal(1, bytesReceived);
Assert.Equal(99, buffer[0]);
}
}
}
}

[Fact]
public async Task SendRecv_0ByteReceive_Success()
{
Expand Down Expand Up @@ -1149,6 +1182,39 @@ public SendReceive_Eap(ITestOutputHelper output) : base(output) {}
public sealed class SendReceive_SpanSync : SendReceive<SocketHelperSpanSync>
{
public SendReceive_SpanSync(ITestOutputHelper output) : base(output) { }

[Fact]
public async Task Send_0ByteSend_Span_Success()
{
using (Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
using (Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
{
listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
listener.Listen(1);

Task<Socket> acceptTask = AcceptAsync(listener);
await Task.WhenAll(
acceptTask,
ConnectAsync(client, new IPEndPoint(IPAddress.Loopback, ((IPEndPoint)listener.LocalEndPoint).Port)));

using (Socket server = await acceptTask)
{
for (int i = 0; i < 3; i++)
{
// Zero byte send should be a no-op
int bytesSent = client.Send(ReadOnlySpan<byte>.Empty, SocketFlags.None);
Assert.Equal(0, bytesSent);

// Socket should still be usable
await SendAsync(client, new byte[] { 99 });
byte[] buffer = new byte[10];
int bytesReceived = await ReceiveAsync(server, buffer);
Assert.Equal(1, bytesReceived);
Assert.Equal(99, buffer[0]);
}
}
}
}
}

public sealed class SendReceive_SpanSyncForceNonBlocking : SendReceive<SocketHelperSpanSyncForceNonBlocking>
Expand All @@ -1160,6 +1226,40 @@ public sealed class SendReceive_MemoryArrayTask : SendReceive<SocketHelperMemory
{
public SendReceive_MemoryArrayTask(ITestOutputHelper output) : base(output) { }

[Fact]
public async Task Send_0ByteSend_Memory_Success()
{
using (Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
using (Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
{
listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
listener.Listen(1);

Task<Socket> acceptTask = AcceptAsync(listener);
await Task.WhenAll(
acceptTask,
ConnectAsync(client, new IPEndPoint(IPAddress.Loopback, ((IPEndPoint)listener.LocalEndPoint).Port)));

using (Socket server = await acceptTask)
{
for (int i = 0; i < 3; i++)
{
// Zero byte send should be a no-op and complete immediately
Task<int> sendTask = client.SendAsync(ReadOnlyMemory<byte>.Empty, SocketFlags.None).AsTask();
Assert.True(sendTask.IsCompleted);
Assert.Equal(0, await sendTask);

// Socket should still be usable
await SendAsync(client, new byte[] { 99 });
byte[] buffer = new byte[10];
int bytesReceived = await ReceiveAsync(server, buffer);
Assert.Equal(1, bytesReceived);
Assert.Equal(99, buffer[0]);
}
}
}
}

[Fact]
public async Task Precanceled_Throws()
{
Expand Down

0 comments on commit 333a6c7

Please sign in to comment.