Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make Socket useable after cancellation #99181

Merged
merged 1 commit into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3754,7 +3754,7 @@ internal void UpdateStatusAfterSocketError(SocketError errorCode, bool disconnec

if (disconnectOnFailure && _isConnected && (_handle.IsInvalid || (errorCode != SocketError.WouldBlock &&
errorCode != SocketError.IOPending && errorCode != SocketError.NoBufferSpaceAvailable &&
errorCode != SocketError.TimedOut)))
errorCode != SocketError.TimedOut && errorCode != SocketError.OperationAborted)))
{
// The socket is no longer a valid socket.
if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, "Invalidating socket.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ private static async Task RunWithConnectedNetworkStreamsAsync(Func<NetworkStream
await Task.WhenAll(remoteTask, clientConnectTask);

using (TcpClient remote = remoteTask.Result)
using (NetworkStream serverStream = new NetworkStream(remote.Client, serverAccess, ownsSocket:true))
using (NetworkStream serverStream = new NetworkStream(remote.Client, serverAccess, ownsSocket: true))
using (NetworkStream clientStream = new NetworkStream(client.Client, clientAccess, ownsSocket: true))
{
await func(serverStream, clientStream);
Expand All @@ -560,6 +560,77 @@ private static async Task RunWithConnectedNetworkStreamsAsync(Func<NetworkStream
}
}

[Fact]
public async Task NetworkStream_ReadTimeout_RemainUseable()
{
using StreamPair streams = await CreateConnectedStreamsAsync();
NetworkStream readable = (NetworkStream)streams.Stream1;

Assert.True(readable.Socket.Connected);
readable.Socket.ReceiveTimeout = TestSettings.FailingTestTimeout;
var buffer = new byte[100];
int readBytes;
try
{
readBytes = readable.Read(buffer);
}
catch (IOException ex) when (ex.InnerException is SocketException && ((SocketException)ex.InnerException).SocketErrorCode == SocketError.TimedOut)
{
}
Assert.True(readable.Socket.Connected);

try
{
readBytes = readable.Read(buffer);
}
catch (IOException ex) when (ex.InnerException is SocketException && ((SocketException)ex.InnerException).SocketErrorCode == SocketError.TimedOut)
{
}
Assert.True(readable.Socket.Connected);

streams.Stream2.Write(new byte[] { 65 });
readBytes = readable.Read(buffer);
Assert.Equal(1, readBytes);
Assert.True(readable.Socket.Connected);
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I've understood correctly, the above test passes before this PR and the below test only passes after this PR?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes. This is because of the existing errorCode != SocketError.TimedOut.



[Fact]
public async Task NetworkStream_ReadAsyncTimeout_RemainUseable()
{
using StreamPair streams = await CreateConnectedStreamsAsync();
NetworkStream readable = (NetworkStream)streams.Stream1;

Assert.True(readable.Socket.Connected);

CancellationTokenSource cts = new CancellationTokenSource(TestSettings.FailingTestTimeout);
var buffer = new byte[100];
int readBytes;
try
{
readBytes = await readable.ReadAsync(buffer, cts.Token);
}
catch (OperationCanceledException)
{
}
Assert.True(readable.Socket.Connected);

try
{
cts = new CancellationTokenSource(TestSettings.FailingTestTimeout);
readBytes = await readable.ReadAsync(buffer, cts.Token);
}
catch (OperationCanceledException)
{
}
Assert.True(readable.Socket.Connected);

await streams.Stream2.WriteAsync(new byte[] { 65 });
readBytes = await readable.ReadAsync(buffer);
Assert.Equal(1, readBytes);
Assert.True(readable.Socket.Connected);
}

private sealed class DerivedNetworkStream : NetworkStream
{
public DerivedNetworkStream(Socket socket) : base(socket) { }
Expand Down