diff --git a/eng/Versions.props b/eng/Versions.props
index 57ccfbd0c1fa7..e67e223bb4188 100644
--- a/eng/Versions.props
+++ b/eng/Versions.props
@@ -219,7 +219,7 @@
8.0.0-rtm.23523.2
- 2.2.3
+ 2.2.5-ci.444313
8.0.0-alpha.1.23527.1
16.0.5-alpha.1.23566.1
diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs
index e9cce1c24d34d..14c9685a4f511 100644
--- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs
+++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs
@@ -1446,6 +1446,14 @@ private int WriteHeaderCollection(HttpRequestMessage request, HttpHeaders header
continue;
}
+ // Extended connect requests will use the response content stream for bidirectional communication.
+ // We will ignore any content set for such requests in Http2Stream.SendRequestBodyAsync, as it has no defined semantics.
+ // Drop the Content-Length header as well in the unlikely case it was set.
+ if (knownHeader == KnownHeaders.ContentLength && request.IsExtendedConnectRequest)
+ {
+ continue;
+ }
+
// For all other known headers, send them via their pre-encoded name and the associated value.
WriteBytes(knownHeader.Http2EncodedName, ref headerBuffer);
string? separator = null;
diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs
index de66b7cfa103d..d834679274b4a 100644
--- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs
+++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs
@@ -105,7 +105,9 @@ public Http2Stream(HttpRequestMessage request, Http2Connection connection)
_headerBudgetRemaining = connection._pool.Settings.MaxResponseHeadersByteLength;
- if (_request.Content == null)
+ // Extended connect requests will use the response content stream for bidirectional communication.
+ // We will ignore any content set for such requests in SendRequestBodyAsync, as it has no defined semantics.
+ if (_request.Content == null || _request.IsExtendedConnectRequest)
{
_requestCompletionState = StreamCompletionState.Completed;
if (_request.IsExtendedConnectRequest)
@@ -173,7 +175,9 @@ public HttpResponseMessage GetAndClearResponse()
public async Task SendRequestBodyAsync(CancellationToken cancellationToken)
{
- if (_request.Content == null)
+ // Extended connect requests will use the response content stream for bidirectional communication.
+ // Ignore any content set for such requests, as it has no defined semantics.
+ if (_request.Content == null || _request.IsExtendedConnectRequest)
{
Debug.Assert(_requestCompletionState == StreamCompletionState.Completed);
return;
@@ -250,6 +254,7 @@ public async Task SendRequestBodyAsync(CancellationToken cancellationToken)
// and we also don't want to propagate any error to the caller, in particular for non-duplex scenarios.
Debug.Assert(_responseCompletionState == StreamCompletionState.Completed);
_requestCompletionState = StreamCompletionState.Completed;
+ Debug.Assert(!ConnectProtocolEstablished);
Complete();
return;
}
@@ -261,6 +266,7 @@ public async Task SendRequestBodyAsync(CancellationToken cancellationToken)
_requestCompletionState = StreamCompletionState.Failed;
SendReset();
+ Debug.Assert(!ConnectProtocolEstablished);
Complete();
}
@@ -313,6 +319,7 @@ public async Task SendRequestBodyAsync(CancellationToken cancellationToken)
if (complete)
{
+ Debug.Assert(!ConnectProtocolEstablished);
Complete();
}
}
@@ -420,7 +427,17 @@ private void Cancel()
if (sendReset)
{
SendReset();
- Complete();
+
+ // Extended CONNECT notes:
+ //
+ // To prevent from calling it *twice*, Extended CONNECT stream's Complete() is only
+ // called from CloseResponseBody(), as CloseResponseBody() is *always* called
+ // from Extended CONNECT stream's Dispose().
+
+ if (!ConnectProtocolEstablished)
+ {
+ Complete();
+ }
}
}
@@ -810,7 +827,20 @@ public void OnHeadersComplete(bool endStream)
Debug.Assert(_responseCompletionState == StreamCompletionState.InProgress, $"Response already completed with state={_responseCompletionState}");
_responseCompletionState = StreamCompletionState.Completed;
- if (_requestCompletionState == StreamCompletionState.Completed)
+
+ // Extended CONNECT notes:
+ //
+ // To prevent from calling it *prematurely*, Extended CONNECT stream's Complete() is only
+ // called from CloseResponseBody(), as CloseResponseBody() is *only* called
+ // from Extended CONNECT stream's Dispose().
+ //
+ // Due to bidirectional streaming nature of the Extended CONNECT request,
+ // the *write side* of the stream can only be completed by calling Dispose().
+ //
+ // The streaming in both ways happens over the single "response" stream instance, which makes
+ // _requestCompletionState *not indicative* of the actual state of the write side of the stream.
+
+ if (_requestCompletionState == StreamCompletionState.Completed && !ConnectProtocolEstablished)
{
Complete();
}
@@ -871,7 +901,20 @@ public void OnResponseData(ReadOnlySpan buffer, bool endStream)
Debug.Assert(_responseCompletionState == StreamCompletionState.InProgress, $"Response already completed with state={_responseCompletionState}");
_responseCompletionState = StreamCompletionState.Completed;
- if (_requestCompletionState == StreamCompletionState.Completed)
+
+ // Extended CONNECT notes:
+ //
+ // To prevent from calling it *prematurely*, Extended CONNECT stream's Complete() is only
+ // called from CloseResponseBody(), as CloseResponseBody() is *only* called
+ // from Extended CONNECT stream's Dispose().
+ //
+ // Due to bidirectional streaming nature of the Extended CONNECT request,
+ // the *write side* of the stream can only be completed by calling Dispose().
+ //
+ // The streaming in both ways happens over the single "response" stream instance, which makes
+ // _requestCompletionState *not indicative* of the actual state of the write side of the stream.
+
+ if (_requestCompletionState == StreamCompletionState.Completed && !ConnectProtocolEstablished)
{
Complete();
}
@@ -1036,17 +1079,17 @@ public async Task ReadResponseHeadersAsync(CancellationToken cancellationToken)
Debug.Assert(_response != null && _response.Content != null);
// Start to process the response body.
var responseContent = (HttpConnectionResponseContent)_response.Content;
- if (emptyResponse)
+ if (ConnectProtocolEstablished)
+ {
+ responseContent.SetStream(new Http2ReadWriteStream(this, closeResponseBodyOnDispose: true));
+ }
+ else if (emptyResponse)
{
// If there are any trailers, copy them over to the response. Normally this would be handled by
// the response stream hitting EOF, but if there is no response body, we do it here.
MoveTrailersToResponseMessage(_response);
responseContent.SetStream(EmptyReadStream.Instance);
}
- else if (ConnectProtocolEstablished)
- {
- responseContent.SetStream(new Http2ReadWriteStream(this));
- }
else
{
responseContent.SetStream(new Http2ReadStream(this));
@@ -1309,8 +1352,25 @@ private async ValueTask SendDataAsync(ReadOnlyMemory buffer, CancellationT
}
}
+ // This method should only be called from Http2ReadWriteStream.Dispose()
private void CloseResponseBody()
{
+ // Extended CONNECT notes:
+ //
+ // Due to bidirectional streaming nature of the Extended CONNECT request,
+ // the *write side* of the stream can only be completed by calling Dispose()
+ // (which, for Extended CONNECT case, will in turn call CloseResponseBody())
+ //
+ // Similarly to QuicStream, disposal *gracefully* closes the write side of the stream
+ // (unless we've received RST_STREAM before) and *abortively* closes the read side
+ // of the stream (unless we've received EOS before).
+
+ if (ConnectProtocolEstablished && _resetException is null)
+ {
+ // Gracefully close the write side of the Extended CONNECT stream
+ _connection.LogExceptions(_connection.SendEndStreamAsync(StreamId));
+ }
+
// Check if the response body has been fully consumed.
bool fullyConsumed = false;
Debug.Assert(!Monitor.IsEntered(SyncObject));
@@ -1323,6 +1383,7 @@ private void CloseResponseBody()
}
// If the response body isn't completed, cancel it now.
+ // This includes aborting the read side of the Extended CONNECT stream.
if (!fullyConsumed)
{
Cancel();
@@ -1337,6 +1398,12 @@ private void CloseResponseBody()
lock (SyncObject)
{
+ if (ConnectProtocolEstablished)
+ {
+ // This should be the only place where Extended Connect stream is completed
+ Complete();
+ }
+
_responseBuffer.Dispose();
}
}
@@ -1430,10 +1497,7 @@ private enum StreamCompletionState : byte
private sealed class Http2ReadStream : Http2ReadWriteStream
{
- public Http2ReadStream(Http2Stream http2Stream) : base(http2Stream)
- {
- base.CloseResponseBodyOnDispose = true;
- }
+ public Http2ReadStream(Http2Stream http2Stream) : base(http2Stream, closeResponseBodyOnDispose: true) { }
public override bool CanWrite => false;
@@ -1482,12 +1546,13 @@ public class Http2ReadWriteStream : HttpBaseStream
private Http2Stream? _http2Stream;
private readonly HttpResponseMessage _responseMessage;
- public Http2ReadWriteStream(Http2Stream http2Stream)
+ public Http2ReadWriteStream(Http2Stream http2Stream, bool closeResponseBodyOnDispose = false)
{
Debug.Assert(http2Stream != null);
Debug.Assert(http2Stream._response != null);
_http2Stream = http2Stream;
_responseMessage = _http2Stream._response;
+ CloseResponseBodyOnDispose = closeResponseBodyOnDispose;
}
~Http2ReadWriteStream()
@@ -1503,7 +1568,7 @@ public Http2ReadWriteStream(Http2Stream http2Stream)
}
}
- protected bool CloseResponseBodyOnDispose { get; set; }
+ protected bool CloseResponseBodyOnDispose { get; private init; }
protected override void Dispose(bool disposing)
{
diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.Http2ExtendedConnect.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.Http2ExtendedConnect.cs
index 0ea6ae9e13f60..cb1a15df14e77 100644
--- a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.Http2ExtendedConnect.cs
+++ b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.Http2ExtendedConnect.cs
@@ -2,8 +2,10 @@
// The .NET Foundation licenses this file to you under the MIT license.
using System.Collections.Generic;
+using System.Diagnostics;
using System.IO;
using System.Net.Test.Common;
+using System.Threading;
using System.Threading.Tasks;
using Xunit;
using Xunit.Abstractions;
@@ -31,6 +33,7 @@ public static IEnumerable UseSsl_MemberData()
[MemberData(nameof(UseSsl_MemberData))]
public async Task Connect_ReadWriteResponseStream(bool useSsl)
{
+ const int MessageCount = 3;
byte[] clientMessage = new byte[] { 1, 2, 3 };
byte[] serverMessage = new byte[] { 4, 5, 6, 7 };
@@ -43,19 +46,39 @@ await Http2LoopbackServerFactory.Singleton.CreateClientAndServerAsync(async uri
HttpRequestMessage request = CreateRequest(HttpMethod.Connect, uri, UseVersion, exactVersion: true);
request.Headers.Protocol = "foo";
+ bool readFromContentStream = false;
+
+ // We won't send the content bytes, but we will send content headers.
+ // Since we're dropping the content, we'll also drop the Content-Length header.
+ request.Content = new StreamContent(new DelegateStream(
+ readAsyncFunc: (_, _, _, _) =>
+ {
+ readFromContentStream = true;
+ throw new UnreachableException();
+ }));
+
+ request.Headers.Add("User-Agent", "foo");
+ request.Content.Headers.Add("Content-Language", "bar");
+ request.Content.Headers.ContentLength = 42;
+
using HttpResponseMessage response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);
using Stream responseStream = await response.Content.ReadAsStreamAsync();
- await responseStream.WriteAsync(clientMessage);
- await responseStream.FlushAsync();
+ for (int i = 0; i < MessageCount; i++)
+ {
+ await responseStream.WriteAsync(clientMessage);
+ await responseStream.FlushAsync();
- byte[] readBuffer = new byte[serverMessage.Length];
- await responseStream.ReadExactlyAsync(readBuffer);
- Assert.Equal(serverMessage, readBuffer);
+ byte[] readBuffer = new byte[serverMessage.Length];
+ await responseStream.ReadExactlyAsync(readBuffer);
+ Assert.Equal(serverMessage, readBuffer);
+ }
// Receive server's EOS
- Assert.Equal(0, await responseStream.ReadAsync(readBuffer));
+ Assert.Equal(0, await responseStream.ReadAsync(new byte[1]));
+
+ Assert.False(readFromContentStream);
clientCompleted.SetResult();
},
@@ -63,14 +86,21 @@ await Http2LoopbackServerFactory.Singleton.CreateClientAndServerAsync(async uri
{
await using Http2LoopbackConnection connection = await ((Http2LoopbackServer)server).EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.EnableConnect, Value = 1 });
- (int streamId, _) = await connection.ReadAndParseRequestHeaderAsync(readBody: false);
+ (int streamId, HttpRequestData request) = await connection.ReadAndParseRequestHeaderAsync(readBody: false);
+
+ Assert.Equal("foo", request.GetSingleHeaderValue("User-Agent"));
+ Assert.Equal("bar", request.GetSingleHeaderValue("Content-Language"));
+ Assert.Equal(0, request.GetHeaderValueCount("Content-Length"));
await connection.SendResponseHeadersAsync(streamId, endStream: false).ConfigureAwait(false);
- DataFrame dataFrame = await connection.ReadDataFrameAsync();
- Assert.Equal(clientMessage, dataFrame.Data.ToArray());
+ for (int i = 0; i < MessageCount; i++)
+ {
+ DataFrame dataFrame = await connection.ReadDataFrameAsync();
+ Assert.Equal(clientMessage, dataFrame.Data.ToArray());
- await connection.SendResponseDataAsync(streamId, serverMessage, endStream: true);
+ await connection.SendResponseDataAsync(streamId, serverMessage, endStream: i == MessageCount - 1);
+ }
await clientCompleted.Task.WaitAsync(TestHelper.PassingTestTimeout);
}, options: new GenericLoopbackOptions { UseSsl = useSsl });
@@ -163,5 +193,112 @@ await server.AcceptConnectionAsync(async connection =>
await new[] { serverTask, clientTask }.WhenAllOrAnyFailed().WaitAsync(TestHelper.PassingTestTimeout);
}
+
+ [Theory]
+ [MemberData(nameof(UseSsl_MemberData))]
+ public async Task Connect_ServerSideEOS_ReceivedByClient(bool useSsl)
+ {
+ var timeoutTcs = new CancellationTokenSource(TestHelper.PassingTestTimeout);
+ var serverReceivedEOS = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+
+ await Http2LoopbackServerFactory.Singleton.CreateClientAndServerAsync(
+ clientFunc: async uri =>
+ {
+ var client = CreateHttpClient();
+ var request = CreateRequest(HttpMethod.Connect, uri, UseVersion, exactVersion: true);
+ request.Headers.Protocol = "foo";
+
+ var response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, timeoutTcs.Token);
+ var responseStream = await response.Content.ReadAsStreamAsync(timeoutTcs.Token);
+
+ // receive server's EOS
+ Assert.Equal(0, await responseStream.ReadAsync(new byte[1], timeoutTcs.Token));
+
+ // send client's EOS
+ responseStream.Dispose();
+
+ // wait for "ack" from server
+ await serverReceivedEOS.Task.WaitAsync(timeoutTcs.Token);
+
+ // can dispose handler now
+ client.Dispose();
+ },
+ serverFunc: async server =>
+ {
+ await using var connection = await ((Http2LoopbackServer)server).EstablishConnectionAsync(
+ new SettingsEntry { SettingId = SettingId.EnableConnect, Value = 1 });
+
+ (int streamId, _) = await connection.ReadAndParseRequestHeaderAsync(readBody: false);
+ await connection.SendResponseHeadersAsync(streamId, endStream: false);
+
+ // send server's EOS
+ await connection.SendResponseDataAsync(streamId, Array.Empty(), endStream: true);
+
+ // receive client's EOS "in response" to server's EOS
+ var eosFrame = Assert.IsType(await connection.ReadFrameAsync(timeoutTcs.Token));
+ Assert.Equal(streamId, eosFrame.StreamId);
+ Assert.Equal(0, eosFrame.Data.Length);
+ Assert.True(eosFrame.EndStreamFlag);
+
+ serverReceivedEOS.SetResult();
+
+ // on handler dispose, client should shutdown the connection without sending additional frames
+ await connection.WaitForClientDisconnectAsync().WaitAsync(timeoutTcs.Token);
+ },
+ options: new GenericLoopbackOptions { UseSsl = useSsl });
+ }
+
+ [Theory]
+ [MemberData(nameof(UseSsl_MemberData))]
+ public async Task Connect_ClientSideEOS_ReceivedByServer(bool useSsl)
+ {
+ var timeoutTcs = new CancellationTokenSource(TestHelper.PassingTestTimeout);
+ var serverReceivedRst = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+
+ await Http2LoopbackServerFactory.Singleton.CreateClientAndServerAsync(
+ clientFunc: async uri =>
+ {
+ var client = CreateHttpClient();
+ var request = CreateRequest(HttpMethod.Connect, uri, UseVersion, exactVersion: true);
+ request.Headers.Protocol = "foo";
+
+ var response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, timeoutTcs.Token);
+ var responseStream = await response.Content.ReadAsStreamAsync(timeoutTcs.Token);
+
+ // send client's EOS
+ // this will also send RST_STREAM as we didn't receive server's EOS before
+ responseStream.Dispose();
+
+ // wait for "ack" from server
+ await serverReceivedRst.Task.WaitAsync(timeoutTcs.Token);
+
+ // can dispose handler now
+ client.Dispose();
+ },
+ serverFunc: async server =>
+ {
+ await using var connection = await ((Http2LoopbackServer)server).EstablishConnectionAsync(
+ new SettingsEntry { SettingId = SettingId.EnableConnect, Value = 1 });
+
+ (int streamId, _) = await connection.ReadAndParseRequestHeaderAsync(readBody: false);
+ await connection.SendResponseHeadersAsync(streamId, endStream: false);
+
+ // receive client's EOS
+ var eosFrame = Assert.IsType(await connection.ReadFrameAsync(timeoutTcs.Token));
+ Assert.Equal(streamId, eosFrame.StreamId);
+ Assert.Equal(0, eosFrame.Data.Length);
+ Assert.True(eosFrame.EndStreamFlag);
+
+ // receive client's RST_STREAM as we didn't send server's EOS before
+ var rstFrame = Assert.IsType(await connection.ReadFrameAsync(timeoutTcs.Token));
+ Assert.Equal(streamId, rstFrame.StreamId);
+
+ serverReceivedRst.SetResult();
+
+ // on handler dispose, client should shutdown the connection without sending additional frames
+ await connection.WaitForClientDisconnectAsync().WaitAsync(timeoutTcs.Token);
+ },
+ options: new GenericLoopbackOptions { UseSsl = useSsl });
+ }
}
}
diff --git a/src/libraries/System.Net.WebSockets.Client/tests/AbortTest.Loopback.cs b/src/libraries/System.Net.WebSockets.Client/tests/AbortTest.Loopback.cs
new file mode 100644
index 0000000000000..0aa83697a9de7
--- /dev/null
+++ b/src/libraries/System.Net.WebSockets.Client/tests/AbortTest.Loopback.cs
@@ -0,0 +1,246 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Collections.Generic;
+using System.Threading;
+using System.Threading.Tasks;
+using Xunit;
+using Xunit.Abstractions;
+
+namespace System.Net.WebSockets.Client.Tests
+{
+ [ConditionalClass(typeof(ClientWebSocketTestBase), nameof(WebSocketsSupported))]
+ [SkipOnPlatform(TestPlatforms.Browser, "System.Net.Sockets are not supported on browser")]
+ public abstract class AbortTest_Loopback : ClientWebSocketTestBase
+ {
+ public AbortTest_Loopback(ITestOutputHelper output) : base(output) { }
+
+ protected virtual Version HttpVersion => Net.HttpVersion.Version11;
+
+ [Theory]
+ [MemberData(nameof(AbortClient_MemberData))]
+ public Task AbortClient_ServerGetsCorrectException(AbortType abortType, bool useSsl, bool verifySendReceive)
+ {
+ var clientMsg = new byte[] { 1, 2, 3, 4, 5, 6 };
+ var serverMsg = new byte[] { 42 };
+ var clientAckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var serverAckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+
+ var timeoutCts = new CancellationTokenSource(TimeOutMilliseconds);
+
+ return LoopbackWebSocketServer.RunAsync(
+ async (clientWebSocket, token) =>
+ {
+ if (verifySendReceive)
+ {
+ await VerifySendReceiveAsync(clientWebSocket, clientMsg, serverMsg, clientAckTcs, serverAckTcs.Task, token);
+ }
+
+ switch (abortType)
+ {
+ case AbortType.Abort:
+ clientWebSocket.Abort();
+ break;
+ case AbortType.Dispose:
+ clientWebSocket.Dispose();
+ break;
+ }
+ },
+ async (serverWebSocket, token) =>
+ {
+ if (verifySendReceive)
+ {
+ await VerifySendReceiveAsync(serverWebSocket, serverMsg, clientMsg, serverAckTcs, clientAckTcs.Task, token);
+ }
+
+ var readBuffer = new byte[1];
+ var exception = await Assert.ThrowsAsync(async () =>
+ await serverWebSocket.ReceiveAsync(readBuffer, token));
+
+ Assert.Equal(WebSocketError.ConnectionClosedPrematurely, exception.WebSocketErrorCode);
+ Assert.Equal(WebSocketState.Aborted, serverWebSocket.State);
+ },
+ new LoopbackWebSocketServer.Options(HttpVersion, useSsl, GetInvoker()),
+ timeoutCts.Token);
+ }
+
+ [Theory]
+ [MemberData(nameof(ServerPrematureEos_MemberData))]
+ public Task ServerPrematureEos_ClientGetsCorrectException(ServerEosType serverEosType, bool useSsl)
+ {
+ var clientMsg = new byte[] { 1, 2, 3, 4, 5, 6 };
+ var serverMsg = new byte[] { 42 };
+ var clientAckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var serverAckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+
+ var timeoutCts = new CancellationTokenSource(TimeOutMilliseconds);
+
+ var globalOptions = new LoopbackWebSocketServer.Options(HttpVersion, useSsl, HttpInvoker: null)
+ {
+ DisposeServerWebSocket = false,
+ ManualServerHandshakeResponse = true
+ };
+
+ var serverReceivedEosTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var clientReceivedEosTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+
+ return LoopbackWebSocketServer.RunAsync(
+ async uri =>
+ {
+ var token = timeoutCts.Token;
+ var clientOptions = globalOptions with { HttpInvoker = GetInvoker() };
+ var clientWebSocket = await LoopbackWebSocketServer.GetConnectedClientAsync(uri, clientOptions, token).ConfigureAwait(false);
+
+ if (serverEosType == ServerEosType.AfterSomeData)
+ {
+ await VerifySendReceiveAsync(clientWebSocket, clientMsg, serverMsg, clientAckTcs, serverAckTcs.Task, token).ConfigureAwait(false);
+ }
+
+ // only one side of the stream was closed. the other should work
+ await clientWebSocket.SendAsync(clientMsg, WebSocketMessageType.Binary, endOfMessage: true, token).ConfigureAwait(false);
+
+ var exception = await Assert.ThrowsAsync(() => clientWebSocket.ReceiveAsync(new byte[1], token));
+ Assert.Equal(WebSocketError.ConnectionClosedPrematurely, exception.WebSocketErrorCode);
+
+ clientReceivedEosTcs.SetResult();
+ clientWebSocket.Dispose();
+ },
+ async (requestData, token) =>
+ {
+ WebSocket serverWebSocket = null!;
+ await SendServerResponseAndEosAsync(
+ requestData,
+ serverEosType,
+ (wsData, ct) =>
+ {
+ var wsOptions = new WebSocketCreationOptions { IsServer = true };
+ serverWebSocket = WebSocket.CreateFromStream(wsData.WebSocketStream, wsOptions);
+
+ return serverEosType == ServerEosType.AfterSomeData
+ ? VerifySendReceiveAsync(serverWebSocket, serverMsg, clientMsg, serverAckTcs, clientAckTcs.Task, ct)
+ : Task.CompletedTask;
+ },
+ token);
+
+ Assert.NotNull(serverWebSocket);
+
+ // only one side of the stream was closed. the other should work
+ var readBuffer = new byte[clientMsg.Length];
+ var result = await serverWebSocket.ReceiveAsync(readBuffer, token);
+ Assert.Equal(WebSocketMessageType.Binary, result.MessageType);
+ Assert.Equal(clientMsg.Length, result.Count);
+ Assert.True(result.EndOfMessage);
+ Assert.Equal(clientMsg, readBuffer);
+
+ await clientReceivedEosTcs.Task.WaitAsync(token).ConfigureAwait(false);
+
+ var exception = await Assert.ThrowsAsync(() => serverWebSocket.ReceiveAsync(readBuffer, token));
+ Assert.Equal(WebSocketError.ConnectionClosedPrematurely, exception.WebSocketErrorCode);
+
+ serverWebSocket.Dispose();
+ },
+ globalOptions,
+ timeoutCts.Token);
+ }
+
+ protected virtual Task SendServerResponseAndEosAsync(WebSocketRequestData requestData, ServerEosType serverEosType, Func serverFunc, CancellationToken cancellationToken)
+ => WebSocketHandshakeHelper.SendHttp11ServerResponseAndEosAsync(requestData, serverFunc, cancellationToken); // override for HTTP/2
+
+ private static readonly bool[] Bool_Values = new[] { false, true };
+ private static readonly bool[] UseSsl_Values = PlatformDetection.SupportsAlpn ? Bool_Values : new[] { false };
+
+ public static IEnumerable AbortClient_MemberData()
+ {
+ foreach (var abortType in Enum.GetValues())
+ {
+ foreach (var useSsl in UseSsl_Values)
+ {
+ foreach (var verifySendReceive in Bool_Values)
+ {
+ yield return new object[] { abortType, useSsl, verifySendReceive };
+ }
+ }
+ }
+ }
+
+ public static IEnumerable ServerPrematureEos_MemberData()
+ {
+ foreach (var serverEosType in Enum.GetValues())
+ {
+ foreach (var useSsl in UseSsl_Values)
+ {
+ yield return new object[] { serverEosType, useSsl };
+ }
+ }
+ }
+
+ public enum AbortType
+ {
+ Abort,
+ Dispose
+ }
+
+ public enum ServerEosType
+ {
+ WithHeaders,
+ RightAfterHeaders,
+ AfterSomeData
+ }
+
+ private static async Task VerifySendReceiveAsync(WebSocket ws, byte[] localMsg, byte[] remoteMsg,
+ TaskCompletionSource localAckTcs, Task remoteAck, CancellationToken cancellationToken)
+ {
+ var sendTask = ws.SendAsync(localMsg, WebSocketMessageType.Binary, endOfMessage: true, cancellationToken);
+
+ var recvBuf = new byte[remoteMsg.Length * 2];
+ var recvResult = await ws.ReceiveAsync(recvBuf, cancellationToken).ConfigureAwait(false);
+
+ Assert.Equal(WebSocketMessageType.Binary, recvResult.MessageType);
+ Assert.Equal(remoteMsg.Length, recvResult.Count);
+ Assert.True(recvResult.EndOfMessage);
+ Assert.Equal(remoteMsg, recvBuf[..recvResult.Count]);
+
+ localAckTcs.SetResult();
+
+ await sendTask.ConfigureAwait(false);
+ await remoteAck.WaitAsync(cancellationToken).ConfigureAwait(false);
+ }
+ }
+
+ // --- HTTP/1.1 WebSocket loopback tests ---
+
+ public class AbortTest_Invoker_Loopback : AbortTest_Loopback
+ {
+ public AbortTest_Invoker_Loopback(ITestOutputHelper output) : base(output) { }
+ protected override bool UseCustomInvoker => true;
+ }
+
+ public class AbortTest_HttpClient_Loopback : AbortTest_Loopback
+ {
+ public AbortTest_HttpClient_Loopback(ITestOutputHelper output) : base(output) { }
+ protected override bool UseHttpClient => true;
+ }
+
+ public class AbortTest_SharedHandler_Loopback : AbortTest_Loopback
+ {
+ public AbortTest_SharedHandler_Loopback(ITestOutputHelper output) : base(output) { }
+ }
+
+ // --- HTTP/2 WebSocket loopback tests ---
+
+ public class AbortTest_Invoker_Http2 : AbortTest_Invoker_Loopback
+ {
+ public AbortTest_Invoker_Http2(ITestOutputHelper output) : base(output) { }
+ protected override Version HttpVersion => Net.HttpVersion.Version20;
+ protected override Task SendServerResponseAndEosAsync(WebSocketRequestData rd, ServerEosType eos, Func callback, CancellationToken ct)
+ => WebSocketHandshakeHelper.SendHttp2ServerResponseAndEosAsync(rd, eosInHeadersFrame: eos == ServerEosType.WithHeaders, callback, ct);
+ }
+
+ public class AbortTest_HttpClient_Http2 : AbortTest_HttpClient_Loopback
+ {
+ public AbortTest_HttpClient_Http2(ITestOutputHelper output) : base(output) { }
+ protected override Version HttpVersion => Net.HttpVersion.Version20;
+ protected override Task SendServerResponseAndEosAsync(WebSocketRequestData rd, ServerEosType eos, Func callback, CancellationToken ct)
+ => WebSocketHandshakeHelper.SendHttp2ServerResponseAndEosAsync(rd, eosInHeadersFrame: eos == ServerEosType.WithHeaders, callback, ct);
+ }
+}
diff --git a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackHelper.cs b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackHelper.cs
index 48d167b072f78..cee509ee06846 100644
--- a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackHelper.cs
+++ b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackHelper.cs
@@ -28,14 +28,7 @@ public static async Task> WebSocketHandshakeAsync(Loo
if (headerName == "Sec-WebSocket-Key")
{
string headerValue = tokens[1].Trim();
- string responseSecurityAcceptValue = ComputeWebSocketHandshakeSecurityAcceptValue(headerValue);
- serverResponse =
- "HTTP/1.1 101 Switching Protocols\r\n" +
- "Content-Length: 0\r\n" +
- "Upgrade: websocket\r\n" +
- "Connection: Upgrade\r\n" +
- (extensions is null ? null : $"Sec-WebSocket-Extensions: {extensions}\r\n") +
- "Sec-WebSocket-Accept: " + responseSecurityAcceptValue + "\r\n\r\n";
+ serverResponse = GetServerResponseString(headerValue, extensions);
}
}
}
@@ -50,6 +43,18 @@ public static async Task> WebSocketHandshakeAsync(Loo
return null;
}
+ public static string GetServerResponseString(string secWebSocketKey, string? extensions = null)
+ {
+ var responseSecurityAcceptValue = ComputeWebSocketHandshakeSecurityAcceptValue(secWebSocketKey);
+ return
+ "HTTP/1.1 101 Switching Protocols\r\n" +
+ "Content-Length: 0\r\n" +
+ "Upgrade: websocket\r\n" +
+ "Connection: Upgrade\r\n" +
+ (extensions is null ? null : $"Sec-WebSocket-Extensions: {extensions}\r\n") +
+ "Sec-WebSocket-Accept: " + responseSecurityAcceptValue + "\r\n\r\n";
+ }
+
private static string ComputeWebSocketHandshakeSecurityAcceptValue(string secWebSocketKey)
{
// GUID specified by RFC 6455.
diff --git a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/Http2LoopbackStream.cs b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/Http2LoopbackStream.cs
new file mode 100644
index 0000000000000..1b3b51840ec99
--- /dev/null
+++ b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/Http2LoopbackStream.cs
@@ -0,0 +1,100 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.IO;
+using System.Net.Sockets;
+using System.Threading;
+using System.Threading.Tasks;
+using Xunit;
+
+namespace System.Net.Test.Common
+{
+ public class Http2LoopbackStream : Stream
+ {
+ private readonly Http2LoopbackConnection _connection;
+ private readonly int _streamId;
+ private bool _readEnded;
+ private ReadOnlyMemory _leftoverReadData;
+
+ public override bool CanRead => true;
+ public override bool CanSeek => false;
+ public override bool CanWrite => true;
+
+ public Http2LoopbackConnection Connection => _connection;
+ public int StreamId => _streamId;
+
+ public Http2LoopbackStream(Http2LoopbackConnection connection, int streamId)
+ {
+ _connection = connection;
+ _streamId = streamId;
+ }
+
+ public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default)
+ {
+ if (!_leftoverReadData.IsEmpty)
+ {
+ int read = Math.Min(buffer.Length, _leftoverReadData.Length);
+ _leftoverReadData.Span.Slice(0, read).CopyTo(buffer.Span);
+ _leftoverReadData = _leftoverReadData.Slice(read);
+ return read;
+ }
+
+ if (_readEnded)
+ {
+ return 0;
+ }
+
+ DataFrame dataFrame = (DataFrame)await _connection.ReadFrameAsync(cancellationToken);
+ Assert.Equal(_streamId, dataFrame.StreamId);
+ _leftoverReadData = dataFrame.Data;
+ _readEnded = dataFrame.EndStreamFlag;
+
+ return await ReadAsync(buffer, cancellationToken);
+ }
+
+ public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) =>
+ ReadAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask();
+
+ public override async ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default)
+ {
+ await _connection.SendResponseDataAsync(_streamId, buffer, endStream: false);
+ }
+
+ public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) =>
+ WriteAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask();
+
+ protected override void Dispose(bool disposing) => DisposeAsync().GetAwaiter().GetResult();
+
+ public override async ValueTask DisposeAsync()
+ {
+ try
+ {
+ await _connection.SendResponseDataAsync(_streamId, Memory.Empty, endStream: true).ConfigureAwait(false);
+
+ if (!_readEnded)
+ {
+ var rstFrame = new RstStreamFrame(FrameFlags.None, (int)ProtocolErrors.NO_ERROR, _streamId);
+ await _connection.WriteFrameAsync(rstFrame).ConfigureAwait(false);
+ }
+ }
+ catch (IOException)
+ {
+ // Ignore connection errors
+ }
+ catch (SocketException)
+ {
+ // Ignore connection errors
+ }
+ }
+
+ public override void Flush() { }
+ public override Task FlushAsync(CancellationToken cancellationToken) => Task.CompletedTask;
+
+ public override int Read(byte[] buffer, int offset, int count) => throw new NotImplementedException();
+ public override long Seek(long offset, SeekOrigin origin) => throw new NotImplementedException();
+ public override void SetLength(long value) => throw new NotImplementedException();
+ public override void Write(byte[] buffer, int offset, int count) => throw new NotImplementedException();
+ public override long Length => throw new NotImplementedException();
+ public override long Position { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }
+ }
+}
diff --git a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/LoopbackWebSocketServer.cs b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/LoopbackWebSocketServer.cs
new file mode 100644
index 0000000000000..b24e2e20d40df
--- /dev/null
+++ b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/LoopbackWebSocketServer.cs
@@ -0,0 +1,148 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Net.Http;
+using System.Net.Test.Common;
+using System.Threading;
+using System.Threading.Tasks;
+using Xunit;
+
+namespace System.Net.WebSockets.Client.Tests
+{
+ public static class LoopbackWebSocketServer
+ {
+ public static Task RunAsync(
+ Func clientWebSocketFunc,
+ Func serverWebSocketFunc,
+ Options options,
+ CancellationToken cancellationToken)
+ {
+ Assert.False(options.ManualServerHandshakeResponse, "Not supported in this overload");
+
+ return RunAsyncPrivate(
+ uri => RunClientAsync(uri, clientWebSocketFunc, options, cancellationToken),
+ (requestData, token) => RunServerAsync(requestData, serverWebSocketFunc, options, token),
+ options,
+ cancellationToken);
+ }
+
+ public static Task RunAsync(
+ Func loopbackClientFunc,
+ Func loopbackServerFunc,
+ Options options,
+ CancellationToken cancellationToken)
+ {
+ Assert.False(options.DisposeClientWebSocket, "Not supported in this overload");
+ Assert.False(options.DisposeServerWebSocket, "Not supported in this overload");
+ Assert.False(options.DisposeHttpInvoker, "Not supported in this overload");
+ Assert.Null(options.HttpInvoker); // Not supported in this overload
+
+ return RunAsyncPrivate(loopbackClientFunc, loopbackServerFunc, options, cancellationToken);
+ }
+
+ private static Task RunAsyncPrivate(
+ Func loopbackClientFunc,
+ Func loopbackServerFunc,
+ Options options,
+ CancellationToken cancellationToken)
+ {
+ bool sendDefaultServerHandshakeResponse = !options.ManualServerHandshakeResponse;
+ if (options.HttpVersion == HttpVersion.Version11)
+ {
+ return LoopbackServer.CreateClientAndServerAsync(
+ loopbackClientFunc,
+ async server =>
+ {
+ await server.AcceptConnectionAsync(async connection =>
+ {
+ var requestData = await WebSocketHandshakeHelper.ProcessHttp11RequestAsync(connection, sendDefaultServerHandshakeResponse, cancellationToken).ConfigureAwait(false);
+ await loopbackServerFunc(requestData, cancellationToken).ConfigureAwait(false);
+ });
+ },
+ new LoopbackServer.Options { WebSocketEndpoint = true, UseSsl = options.UseSsl });
+ }
+ else if (options.HttpVersion == HttpVersion.Version20)
+ {
+ return Http2LoopbackServer.CreateClientAndServerAsync(
+ loopbackClientFunc,
+ async server =>
+ {
+ var requestData = await WebSocketHandshakeHelper.ProcessHttp2RequestAsync(server, sendDefaultServerHandshakeResponse, cancellationToken).ConfigureAwait(false);
+ var http2Connection = requestData.Http2Connection!;
+ var http2StreamId = requestData.Http2StreamId.Value;
+
+ await loopbackServerFunc(requestData, cancellationToken).ConfigureAwait(false);
+
+ await http2Connection.DisposeAsync().ConfigureAwait(false);
+ },
+ new Http2Options { WebSocketEndpoint = true, UseSsl = options.UseSsl });
+ }
+ else
+ {
+ throw new ArgumentException(nameof(options.HttpVersion));
+ }
+ }
+
+ private static async Task RunServerAsync(
+ WebSocketRequestData requestData,
+ Func serverWebSocketFunc,
+ Options options,
+ CancellationToken cancellationToken)
+ {
+ var wsOptions = new WebSocketCreationOptions { IsServer = true };
+ var serverWebSocket = WebSocket.CreateFromStream(requestData.WebSocketStream, wsOptions);
+
+ await serverWebSocketFunc(serverWebSocket, cancellationToken).ConfigureAwait(false);
+
+ if (options.DisposeServerWebSocket)
+ {
+ serverWebSocket.Dispose();
+ }
+ }
+
+ private static async Task RunClientAsync(
+ Uri uri,
+ Func clientWebSocketFunc,
+ Options options,
+ CancellationToken cancellationToken)
+ {
+ var clientWebSocket = await GetConnectedClientAsync(uri, options, cancellationToken).ConfigureAwait(false);
+
+ await clientWebSocketFunc(clientWebSocket, cancellationToken).ConfigureAwait(false);
+
+ if (options.DisposeClientWebSocket)
+ {
+ clientWebSocket.Dispose();
+ }
+
+ if (options.DisposeHttpInvoker)
+ {
+ options.HttpInvoker?.Dispose();
+ }
+ }
+
+ public static async Task GetConnectedClientAsync(Uri uri, Options options, CancellationToken cancellationToken)
+ {
+ var clientWebSocket = new ClientWebSocket();
+ clientWebSocket.Options.HttpVersion = options.HttpVersion;
+ clientWebSocket.Options.HttpVersionPolicy = HttpVersionPolicy.RequestVersionExact;
+
+ if (options.UseSsl && options.HttpInvoker is null)
+ {
+ clientWebSocket.Options.RemoteCertificateValidationCallback = delegate { return true; };
+ }
+
+ await clientWebSocket.ConnectAsync(uri, options.HttpInvoker, cancellationToken).ConfigureAwait(false);
+
+ return clientWebSocket;
+ }
+
+ public record class Options(Version HttpVersion, bool UseSsl, HttpMessageInvoker? HttpInvoker)
+ {
+ public bool DisposeServerWebSocket { get; set; } = true;
+ public bool DisposeClientWebSocket { get; set; }
+ public bool DisposeHttpInvoker { get; set; }
+ public bool ManualServerHandshakeResponse { get; set; }
+ }
+ }
+}
diff --git a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketHandshakeHelper.cs b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketHandshakeHelper.cs
new file mode 100644
index 0000000000000..f4d2f42f5edbb
--- /dev/null
+++ b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketHandshakeHelper.cs
@@ -0,0 +1,134 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Collections.Generic;
+using System.Linq;
+using System.Net.Http;
+using System.Net.Sockets;
+using System.Net.Test.Common;
+using System.Threading;
+using System.Threading.Tasks;
+using Xunit;
+
+namespace System.Net.WebSockets.Client.Tests
+{
+ public static class WebSocketHandshakeHelper
+ {
+ public static async Task ProcessHttp11RequestAsync(LoopbackServer.Connection connection, bool sendServerResponse = true, CancellationToken cancellationToken = default)
+ {
+ List headers = await connection.ReadRequestHeaderAsync().WaitAsync(cancellationToken).ConfigureAwait(false);
+
+ var data = new WebSocketRequestData()
+ {
+ HttpVersion = HttpVersion.Version11,
+ Http11Connection = connection
+ };
+
+ foreach (string header in headers.Skip(1))
+ {
+ string[] tokens = header.Split(new char[] { ':' }, StringSplitOptions.RemoveEmptyEntries);
+ if (tokens.Length is 1 or 2)
+ {
+ data.Headers.Add(
+ tokens[0].Trim(),
+ tokens.Length == 2 ? tokens[1].Trim() : null);
+ }
+ }
+
+ var isValidOpeningHandshake = data.Headers.TryGetValue("Sec-WebSocket-Key", out var secWebSocketKey);
+ Assert.True(isValidOpeningHandshake);
+
+ if (sendServerResponse)
+ {
+ await SendHttp11ServerResponseAsync(connection, secWebSocketKey, cancellationToken).ConfigureAwait(false);
+ }
+
+ data.WebSocketStream = connection.Stream;
+ return data;
+ }
+
+ private static async Task SendHttp11ServerResponseAsync(LoopbackServer.Connection connection, string secWebSocketKey, CancellationToken cancellationToken)
+ {
+ var serverResponse = LoopbackHelper.GetServerResponseString(secWebSocketKey);
+ await connection.WriteStringAsync(serverResponse).WaitAsync(cancellationToken).ConfigureAwait(false);
+ }
+
+ public static async Task ProcessHttp2RequestAsync(Http2LoopbackServer server, bool sendServerResponse = true, CancellationToken cancellationToken = default)
+ {
+ var connection = await server.EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.EnableConnect, Value = 1 })
+ .WaitAsync(cancellationToken).ConfigureAwait(false);
+
+ (int streamId, var httpRequestData) = await connection.ReadAndParseRequestHeaderAsync(readBody: false)
+ .WaitAsync(cancellationToken).ConfigureAwait(false);
+
+ var data = new WebSocketRequestData
+ {
+ HttpVersion = HttpVersion.Version20,
+ Http2Connection = connection,
+ Http2StreamId = streamId
+ };
+
+ foreach (var header in httpRequestData.Headers)
+ {
+ Assert.NotNull(header.Name);
+ data.Headers.Add(header.Name, header.Value);
+ }
+
+ var isValidOpeningHandshake = httpRequestData.Method == HttpMethod.Connect.ToString() && data.Headers.ContainsKey(":protocol");
+ Assert.True(isValidOpeningHandshake);
+
+ if (sendServerResponse)
+ {
+ await SendHttp2ServerResponseAsync(connection, streamId, cancellationToken: cancellationToken).ConfigureAwait(false);
+ }
+
+ data.WebSocketStream = new Http2LoopbackStream(connection, streamId);
+ return data;
+ }
+
+ private static async Task SendHttp2ServerResponseAsync(Http2LoopbackConnection connection, int streamId, bool endStream = false, CancellationToken cancellationToken = default)
+ {
+ // send status 200 OK to establish websocket
+ // we don't need to send anything additional as Sec-WebSocket-Key is not used for HTTP/2
+ // note: endStream=true is abnormal and used for testing premature EOS scenarios only
+ await connection.SendResponseHeadersAsync(streamId, endStream: endStream).WaitAsync(cancellationToken).ConfigureAwait(false);
+ }
+
+ public static async Task SendHttp11ServerResponseAndEosAsync(WebSocketRequestData requestData, Func? requestDataCallback, CancellationToken cancellationToken)
+ {
+ Assert.Equal(HttpVersion.Version11, requestData.HttpVersion);
+
+ // sending default handshake response
+ await SendHttp11ServerResponseAsync(requestData.Http11Connection!, requestData.Headers["Sec-WebSocket-Key"], cancellationToken).ConfigureAwait(false);
+
+ if (requestDataCallback is not null)
+ {
+ await requestDataCallback(requestData, cancellationToken).ConfigureAwait(false);
+ }
+
+ // send server EOS (half-closing from server side)
+ requestData.Http11Connection!.Socket.Shutdown(SocketShutdown.Send);
+ }
+
+ public static async Task SendHttp2ServerResponseAndEosAsync(WebSocketRequestData requestData, bool eosInHeadersFrame, Func? requestDataCallback, CancellationToken cancellationToken)
+ {
+ Assert.Equal(HttpVersion.Version20, requestData.HttpVersion);
+
+ var connection = requestData.Http2Connection!;
+ var streamId = requestData.Http2StreamId!.Value;
+
+ await SendHttp2ServerResponseAsync(connection, streamId, endStream: eosInHeadersFrame, cancellationToken).ConfigureAwait(false);
+
+ if (requestDataCallback is not null)
+ {
+ await requestDataCallback(requestData, cancellationToken).ConfigureAwait(false);
+ }
+
+ if (!eosInHeadersFrame)
+ {
+ // send server EOS (half-closing from server side)
+ await connection.SendResponseDataAsync(streamId, Array.Empty(), endStream: true).ConfigureAwait(false);
+ }
+ }
+ }
+}
diff --git a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketRequestData.cs b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketRequestData.cs
new file mode 100644
index 0000000000000..799157a370f07
--- /dev/null
+++ b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketRequestData.cs
@@ -0,0 +1,20 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Collections.Generic;
+using System.IO;
+using System.Net.Test.Common;
+
+namespace System.Net.WebSockets.Client.Tests
+{
+ public class WebSocketRequestData
+ {
+ public Dictionary Headers { get; set; } = new Dictionary();
+ public Stream? WebSocketStream { get; set; }
+
+ public Version HttpVersion { get; set; }
+ public LoopbackServer.Connection? Http11Connection { get; set; }
+ public Http2LoopbackConnection? Http2Connection { get; set; }
+ public int? Http2StreamId { get; set; }
+ }
+}
diff --git a/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj b/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj
index 8f23e7925a451..a4f20d03002e1 100644
--- a/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj
+++ b/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj
@@ -55,6 +55,7 @@
+
@@ -64,6 +65,10 @@
+
+
+
+