Skip to content

Commit

Permalink
Fix race when receiving HEADERS and RST_STREAM in rapid succession. (#…
Browse files Browse the repository at this point in the history
…72932)

* Fix race when receiving HEADERS and RST_STREAM in rapid succession.

* Improve test
  • Loading branch information
rzikm authored Jul 28, 2022
1 parent 50d69de commit a8789a6
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,19 @@ public async Task WriteFrameAsync(Frame frame, CancellationToken cancellationTok
await _connectionStream.WriteAsync(writeBuffer, 0, writeBuffer.Length, cancellationToken).ConfigureAwait(false);
}

public async Task WriteFramesAsync(Frame[] frames, CancellationToken cancellationToken = default)
{
byte[] writeBuffer = new byte[frames.Sum(frame => Frame.FrameHeaderLength + frame.Length)];

int offset = 0;
foreach (Frame frame in frames)
{
frame.WriteTo(writeBuffer.AsSpan(offset));
offset += Frame.FrameHeaderLength + frame.Length;
}
await _connectionStream.WriteAsync(writeBuffer, 0, writeBuffer.Length, cancellationToken).ConfigureAwait(false);
}

// Read until the buffer is full
// Return false on EOF, throw on partial read
private async Task<bool> FillBufferAsync(Memory<byte> buffer, CancellationToken cancellationToken = default(CancellationToken))
Expand Down Expand Up @@ -567,7 +580,7 @@ public async Task<byte[]> ReadBodyAsync(bool expectEndOfStream = false)
}
else if (frame == null || frame.Type == FrameType.RstStream)
{
throw new IOException( frame == null ? "End of stream" : "Got RST");
throw new IOException(frame == null ? "End of stream" : "Got RST");
}

Assert.Equal(FrameType.Data, frame.Type);
Expand All @@ -586,7 +599,7 @@ public async Task<byte[]> ReadBodyAsync(bool expectEndOfStream = false)

body.CopyTo(newBuffer, 0);
dataFrame.Data.Span.CopyTo(newBuffer.AsSpan(body.Length));
body= newBuffer;
body = newBuffer;
}
}
}
Expand Down Expand Up @@ -947,11 +960,11 @@ public override async Task<HttpRequestData> HandleRequestAsync(HttpStatusCode st

if (string.IsNullOrEmpty(content))
{
await SendResponseHeadersAsync(streamId, endStream: true, statusCode, isTrailingHeader: false, headers : headers).ConfigureAwait(false);
await SendResponseHeadersAsync(streamId, endStream: true, statusCode, isTrailingHeader: false, headers: headers).ConfigureAwait(false);
}
else
{
await SendResponseHeadersAsync(streamId, endStream: false, statusCode, isTrailingHeader: false, headers : headers).ConfigureAwait(false);
await SendResponseHeadersAsync(streamId, endStream: false, statusCode, isTrailingHeader: false, headers: headers).ConfigureAwait(false);
await SendResponseBodyAsync(streamId, Encoding.ASCII.GetBytes(content)).ConfigureAwait(false);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ private sealed class Http2Stream : IValueTaskSource, IHttpStreamHeadersHandler,
private StreamCompletionState _requestCompletionState;
private StreamCompletionState _responseCompletionState;
private ResponseProtocolState _responseProtocolState;
private bool _responseHeadersReceived;
private bool _webSocketEstablished;

// If this is not null, then we have received a reset from the server
Expand Down Expand Up @@ -775,6 +776,7 @@ public void OnHeadersComplete(bool endStream)

case ResponseProtocolState.ExpectingHeaders:
_responseProtocolState = endStream ? ResponseProtocolState.Complete : ResponseProtocolState.ExpectingData;
_responseHeadersReceived = true;
break;

case ResponseProtocolState.ExpectingTrailingHeaders:
Expand Down Expand Up @@ -988,24 +990,16 @@ private void CheckResponseBodyState()
Debug.Assert(!Monitor.IsEntered(SyncObject));
lock (SyncObject)
{
CheckResponseBodyState();

if (_responseProtocolState == ResponseProtocolState.ExpectingHeaders || _responseProtocolState == ResponseProtocolState.ExpectingIgnoredHeaders || _responseProtocolState == ResponseProtocolState.ExpectingStatus)
if (!_responseHeadersReceived)
{
CheckResponseBodyState();
Debug.Assert(!_hasWaiter);
_hasWaiter = true;
_waitSource.Reset();
return (true, false);
}
else if (_responseProtocolState == ResponseProtocolState.ExpectingData || _responseProtocolState == ResponseProtocolState.ExpectingTrailingHeaders)
{
return (false, false);
}
else
{
Debug.Assert(_responseProtocolState == ResponseProtocolState.Complete);
return (false, _responseBuffer.IsEmpty);
}

return (false, _responseProtocolState == ResponseProtocolState.Complete && _responseBuffer.IsEmpty);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,31 @@ public async Task Http2_StreamResetByServerAfterHeadersSent_RequestFails()
}
}

[ConditionalFact(nameof(SupportsAlpn))]
public async Task Http2_StreamResetByServerAfterHeadersSent_ResponseHeadersRead_ContentThrows()
{
using (Http2LoopbackServer server = Http2LoopbackServer.CreateServer())
using (HttpClient client = CreateHttpClient())
{
Task<HttpResponseMessage> sendTask = client.GetAsync(server.Address, HttpCompletionOption.ResponseHeadersRead);

Http2LoopbackConnection connection = await server.EstablishConnectionAsync();
int streamId = await connection.ReadRequestHeaderAsync();

// Send response headers and RST_STREAM combined
await connection.WriteFramesAsync(new Frame[] {
new HeadersFrame(new byte[] { 0x88 /* :status: 200 */}, FrameFlags.EndHeaders, 0, 0, 0, streamId),
new RstStreamFrame(FrameFlags.None, (int)ProtocolErrors.NO_ERROR, streamId)
});

// Headers should be received successfully
HttpResponseMessage response = await sendTask;

// Reading the actual content should throw
await AssertHttpProtocolException((await response.Content.ReadAsStreamAsync()).ReadAsync(new byte[10]).AsTask(), ProtocolErrors.NO_ERROR);
}
}

[ConditionalFact(nameof(SupportsAlpn))]
public async Task Http2_StreamResetByServerAfterPartialBodySent_RequestFails()
{
Expand Down Expand Up @@ -2021,7 +2046,7 @@ await Http2LoopbackServer.CreateClientAndServerAsync(async url =>
using (HttpClient client = CreateHttpClient())
{
var request = new HttpRequestMessage(HttpMethod.Post, url);
request.Version = new Version(2,0);
request.Version = new Version(2, 0);
request.Content = new CustomContent(stream);

await Assert.ThrowsAnyAsync<OperationCanceledException>(async () => await client.SendAsync(request, cts.Token));
Expand All @@ -2031,7 +2056,7 @@ await Http2LoopbackServer.CreateClientAndServerAsync(async url =>

// Send another request to verify that connection is still functional.
request = new HttpRequestMessage(HttpMethod.Get, url);
request.Version = new Version(2,0);
request.Version = new Version(2, 0);

await client.SendAsync(request);
}
Expand All @@ -2040,20 +2065,21 @@ await Http2LoopbackServer.CreateClientAndServerAsync(async url =>
{
Http2LoopbackConnection connection = await server.EstablishConnectionAsync();

(int streamId, HttpRequestData requestData) = await connection.ReadAndParseRequestHeaderAsync(readBody : false);
(int streamId, HttpRequestData requestData) = await connection.ReadAndParseRequestHeaderAsync(readBody: false);
int frameCount = 0;
Frame frame;
do
{
if (frameCount == (waitForData ? 1 : 0)) {
if (frameCount == (waitForData ? 1 : 0))
{
// Cancel client after receiving Headers or part of request body.
cts.Cancel();
}
frame = await connection.ReadFrameAsync(TestHelper.PassingTestTimeout);
Assert.NotNull(frame); // We should get Rst before closing connection.
Assert.Equal(0, (int)(frame.Flags & FrameFlags.EndStream));
frameCount++;
} while (frame.Type != FrameType.RstStream);
} while (frame.Type != FrameType.RstStream);

Assert.Equal(1, frame.StreamId);

Expand Down Expand Up @@ -2089,7 +2115,7 @@ public async Task Http2_PendingReceive_SendsReset(bool doRead)
await Http2LoopbackServer.CreateClientAndServerAsync(async url =>
{
var request = new HttpRequestMessage(HttpMethod.Get, url);
request.Version = new Version(2,0);
request.Version = new Version(2, 0);

response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cts.Token);
using (Stream stream = await response.Content.ReadAsStreamAsync())
Expand Down Expand Up @@ -2123,7 +2149,7 @@ await Http2LoopbackServer.CreateClientAndServerAsync(async url =>
{
Http2LoopbackConnection connection = await server.EstablishConnectionAsync();

(int streamId, HttpRequestData requestData) = await connection.ReadAndParseRequestHeaderAsync(readBody : false);
(int streamId, HttpRequestData requestData) = await connection.ReadAndParseRequestHeaderAsync(readBody: false);
_output.WriteLine($"{DateTime.Now} Connection established");

await connection.SendResponseHeadersAsync(streamId, endStream: false, HttpStatusCode.OK);
Expand Down Expand Up @@ -2213,7 +2239,7 @@ await Http2LoopbackServer.CreateClientAndServerAsync(async url =>
using (HttpClient client = CreateHttpClient())
{
var request = new HttpRequestMessage(HttpMethod.Post, url);
request.Version = new Version(2,0);
request.Version = new Version(2, 0);
request.Content = new StringContent(new string('*', 3000));
request.Headers.ExpectContinue = true;
request.Headers.Add("x-test", $"PostAsyncExpect100Continue_SendRequest_Ok({send100Continue}");
Expand All @@ -2226,7 +2252,7 @@ await Http2LoopbackServer.CreateClientAndServerAsync(async url =>
{
Http2LoopbackConnection connection = await server.EstablishConnectionAsync();

(int streamId, HttpRequestData requestData) = await connection.ReadAndParseRequestHeaderAsync(readBody : false);
(int streamId, HttpRequestData requestData) = await connection.ReadAndParseRequestHeaderAsync(readBody: false);
Assert.Equal("100-continue", requestData.GetSingleHeaderValue("Expect"));

if (send100Continue)
Expand Down Expand Up @@ -2254,7 +2280,7 @@ await Http2LoopbackServer.CreateClientAndServerAsync(async url =>
handler.Expect100ContinueTimeout = TimeSpan.FromSeconds(300);

var request = new HttpRequestMessage(HttpMethod.Post, url);
request.Version = new Version(2,0);
request.Version = new Version(2, 0);
request.VersionPolicy = HttpVersionPolicy.RequestVersionExact;
request.Content = new StringContent(new string('*', 3000));
request.Headers.ExpectContinue = true;
Expand All @@ -2269,15 +2295,15 @@ await Http2LoopbackServer.CreateClientAndServerAsync(async url =>
{
Http2LoopbackConnection connection = await server.EstablishConnectionAsync();

(int streamId, HttpRequestData requestData) = await connection.ReadAndParseRequestHeaderAsync(readBody : false);
(int streamId, HttpRequestData requestData) = await connection.ReadAndParseRequestHeaderAsync(readBody: false);
Assert.Equal("100-continue", requestData.GetSingleHeaderValue("Expect"));

// Reject content with 403.
await connection.SendResponseHeadersAsync(streamId, endStream: false, HttpStatusCode.Forbidden);
await connection.SendResponseBodyAsync(streamId, Encoding.ASCII.GetBytes(responseContent));

// Client should send empty request body
byte[] requestBody = await connection.ReadBodyAsync(expectEndOfStream:true);
byte[] requestBody = await connection.ReadBodyAsync(expectEndOfStream: true);
Assert.Null(requestBody);

await connection.ShutdownIgnoringErrorsAsync(streamId);
Expand Down Expand Up @@ -3136,12 +3162,12 @@ public async Task SendAsync_ConcurentSendReceive_Ok(bool shouldWaitForRequestBod
Task<HttpResponseMessage> responseTask = client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);
connection = await server.EstablishConnectionAsync();

// Client should have sent the request headers, and the request stream should now be available
// Client should have sent the request headers, and the request stream should now be available
Stream requestStream = await duplexContent.WaitForStreamAsync();
// Flush the content stream. Otherwise, the request headers are not guaranteed to be sent.
await requestStream.FlushAsync();

(int streamId, HttpRequestData requestData) = await connection.ReadAndParseRequestHeaderAsync(readBody : false);
(int streamId, HttpRequestData requestData) = await connection.ReadAndParseRequestHeaderAsync(readBody: false);

// Client finished sending request headers and we received them.
// Send request body.
Expand Down Expand Up @@ -3170,7 +3196,7 @@ public async Task SendAsync_ConcurentSendReceive_Ok(bool shouldWaitForRequestBod

// Send trailing headers for good measure and close stream.
var headers = new HttpHeaderData[] { new HttpHeaderData("x-last", "done") };
await connection.SendResponseHeadersAsync(streamId, endStream: true, isTrailingHeader : true, headers: headers);
await connection.SendResponseHeadersAsync(streamId, endStream: true, isTrailingHeader: true, headers: headers);

// Finish reading response body and verify it for all cases.
string responseBody = await response.Content.ReadAsStringAsync();
Expand All @@ -3189,15 +3215,15 @@ public async Task SendAsync_ConcurentSendReceive_Fail()
TaskCompletionSource<bool> tsc = new TaskCompletionSource<bool>();
string requestContent = new string('*', 300);
const string responseContent = "SendAsync_ConcurentSendReceive_Fail";
var stream = new CustomContent.SlowTestStream(Encoding.UTF8.GetBytes(requestContent), tsc, trigger : 1, count : 50);
var stream = new CustomContent.SlowTestStream(Encoding.UTF8.GetBytes(requestContent), tsc, trigger: 1, count: 50);
bool stopSending = false;

await Http2LoopbackServer.CreateClientAndServerAsync(async url =>
{
using (HttpClient client = CreateHttpClient())
{
var request = new HttpRequestMessage(HttpMethod.Post, url);
request.Version = new Version(2,0);
request.Version = new Version(2, 0);
request.Content = new CustomContent(stream);

// This should fail either while getting response headers or while reading response body.
Expand All @@ -3217,7 +3243,7 @@ await Http2LoopbackServer.CreateClientAndServerAsync(async url =>
{
Http2LoopbackConnection connection = await server.EstablishConnectionAsync();

(int streamId, HttpRequestData requestData) = await connection.ReadAndParseRequestHeaderAsync(readBody : false);
(int streamId, HttpRequestData requestData) = await connection.ReadAndParseRequestHeaderAsync(readBody: false);
await connection.SendResponseHeadersAsync(streamId, endStream: false, HttpStatusCode.OK);

// Wait for client so start sending body.
Expand All @@ -3226,7 +3252,7 @@ await Http2LoopbackServer.CreateClientAndServerAsync(async url =>
int maxCount = 120;
while (!stopSending && maxCount != 0)
{
try
try
{
await connection.SendResponseDataAsync(streamId, Encoding.ASCII.GetBytes(responseContent), endStream: false);
}
Expand All @@ -3237,7 +3263,7 @@ await Http2LoopbackServer.CreateClientAndServerAsync(async url =>
break;
}
await Task.Delay(500);
maxCount --;
maxCount--;
}
// We should not reach retry limit without failing.
Assert.NotEqual(0, maxCount);
Expand Down Expand Up @@ -3349,7 +3375,7 @@ public async Task Http2GetAsync_MultipleStatusHeaders_Throws()

Http2LoopbackConnection connection = await server.EstablishConnectionAsync();
int streamId = await connection.ReadRequestHeaderAsync();
await connection.SendResponseHeadersAsync(streamId, endStream : true, headers: headers);
await connection.SendResponseHeadersAsync(streamId, endStream: true, headers: headers);
await Assert.ThrowsAsync<HttpRequestException>(() => sendTask);
}
}
Expand All @@ -3366,7 +3392,7 @@ public async Task Http2GetAsync_StatusHeaderNotFirst_Throws()

Http2LoopbackConnection connection = await server.EstablishConnectionAsync();
int streamId = await connection.ReadRequestHeaderAsync();
await connection.SendResponseHeadersAsync(streamId, endStream : true, isTrailingHeader : true, headers: headers);
await connection.SendResponseHeadersAsync(streamId, endStream: true, isTrailingHeader: true, headers: headers);

await Assert.ThrowsAsync<HttpRequestException>(() => sendTask);
}
Expand All @@ -3386,7 +3412,7 @@ public async Task Http2GetAsync_TrailigPseudo_Throw()
int streamId = await connection.ReadRequestHeaderAsync();
await connection.SendDefaultResponseHeadersAsync(streamId);
await connection.SendResponseDataAsync(streamId, "hello"u8.ToArray(), endStream: false);
await connection.SendResponseHeadersAsync(streamId, endStream : true, isTrailingHeader : true, headers: headers);
await connection.SendResponseHeadersAsync(streamId, endStream: true, isTrailingHeader: true, headers: headers);

await Assert.ThrowsAsync<HttpRequestException>(() => sendTask);
}
Expand Down

0 comments on commit a8789a6

Please sign in to comment.