Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[H/3] Fix code passed into QuicConnection.CloseAsync and QuicStream.Abort #55282

Merged
merged 11 commits into from
May 3, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,9 @@ private void AbortCore(Exception exception, Http3ErrorCode errorCode)
abortReason = new ConnectionAbortedException(exception.Message, exception);
}

// This has the side-effect of validating the error code, so do it before we consume the error code
_errorCodeFeature.Error = (long)errorCode;

_context.WebTransportSession?.Abort(abortReason, errorCode);

Log.Http3StreamAbort(TraceIdentifier, errorCode, abortReason);
Expand All @@ -181,7 +184,6 @@ private void AbortCore(Exception exception, Http3ErrorCode errorCode)
RequestBodyPipe.Writer.Complete(exception);

// Abort framewriter and underlying transport after stopping output.
_errorCodeFeature.Error = (long)errorCode;
_frameWriter.Abort(abortReason);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@ internal sealed partial class QuicConnectionContext : IProtocolErrorCodeFeature,
public long Error
{
get => _error ?? -1;
set => _error = value;
set
{
QuicTransportOptions.ValidateErrorCode(value);
_error = value;
}
}

public X509Certificate2? ClientCertificate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ public override async ValueTask DisposeAsync()
{
lock (_shutdownLock)
{
// The DefaultCloseErrorCode setter validates that the error code is within the valid range
_closeTask ??= _connection.CloseAsync(errorCode: _context.Options.DefaultCloseErrorCode).AsTask();
}

Expand All @@ -81,7 +82,7 @@ public override void Abort(ConnectionAbortedException abortReason)
return;
}

var resolvedErrorCode = _error ?? 0;
var resolvedErrorCode = _error ?? 0; // Only valid error codes are assigned to _error
_abortReason = ExceptionDispatchInfo.Capture(abortReason);
QuicLog.ConnectionAbort(_log, this, resolvedErrorCode, abortReason.Message);
_closeTask = _connection.CloseAsync(errorCode: resolvedErrorCode).AsTask();
Expand Down Expand Up @@ -130,7 +131,7 @@ public override void Abort(ConnectionAbortedException abortReason)
catch (QuicException ex) when (ex.QuicError == QuicError.ConnectionAborted)
{
// Shutdown initiated by peer, abortive.
_error = ex.ApplicationErrorCode;
_error = ex.ApplicationErrorCode; // Trust Quic to provide us a valid error code
QuicLog.ConnectionAborted(_log, this, ex.ApplicationErrorCode.GetValueOrDefault(), ex);

ThreadPool.UnsafeQueueUserWorkItem(state =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ public OnCloseRegistration(Action<object?> callback, object? state)
public long Error
{
get => _error ?? -1;
set => _error = value;
set
{
QuicTransportOptions.ValidateErrorCode(value);
_error = value;
}
}

public long StreamId { get; private set; }
Expand All @@ -54,6 +58,8 @@ public long Error

public void AbortRead(long errorCode, ConnectionAbortedException abortReason)
{
QuicTransportOptions.ValidateErrorCode(errorCode);

lock (_shutdownLock)
{
if (_stream != null)
Expand All @@ -74,6 +80,8 @@ public void AbortRead(long errorCode, ConnectionAbortedException abortReason)

public void AbortWrite(long errorCode, ConnectionAbortedException abortReason)
{
QuicTransportOptions.ValidateErrorCode(errorCode);

lock (_shutdownLock)
{
if (_stream != null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ private async ValueTask DoReceiveAsync()
catch (QuicException ex) when (ex.QuicError is QuicError.StreamAborted or QuicError.ConnectionAborted)
{
// Abort from peer.
_error = ex.ApplicationErrorCode;
_error = ex.ApplicationErrorCode; // Trust Quic to provide us a valid error code
QuicLog.StreamAbortedRead(_log, this, ex.ApplicationErrorCode.GetValueOrDefault());

// This could be ignored if _shutdownReason is already set.
Expand Down Expand Up @@ -434,7 +434,7 @@ private async ValueTask DoSendAsync()
catch (QuicException ex) when (ex.QuicError is QuicError.StreamAborted or QuicError.ConnectionAborted)
{
// Abort from peer.
_error = ex.ApplicationErrorCode;
_error = ex.ApplicationErrorCode; // Trust Quic to provide us a valid error code
QuicLog.StreamAbortedWrite(_log, this, ex.ApplicationErrorCode.GetValueOrDefault());

// This could be ignored if _shutdownReason is already set.
Expand Down Expand Up @@ -501,7 +501,7 @@ public override void Abort(ConnectionAbortedException abortReason)
_shutdownReason = abortReason;
}

var resolvedErrorCode = _error ?? 0;
var resolvedErrorCode = _error ?? 0; // _error is validated on assignment
QuicLog.StreamAbort(_log, this, resolvedErrorCode, abortReason.Message);

if (stream.CanRead)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,15 @@ public long DefaultCloseErrorCode
}
}

private static void ValidateErrorCode(long errorCode)
internal static void ValidateErrorCode(long errorCode)
{
const long MinErrorCode = 0;
const long MaxErrorCode = (1L << 62) - 1;

if (errorCode < MinErrorCode || errorCode > MaxErrorCode)
{
throw new ArgumentOutOfRangeException(nameof(errorCode), errorCode, $"A value between {MinErrorCode} and {MaxErrorCode} is required.");
// Print the values in hex since the max is unintelligible in decimal
throw new ArgumentOutOfRangeException(nameof(errorCode), errorCode, $"A value between 0x{MinErrorCode:x} and 0x{MaxErrorCode:x} is required.");
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,33 @@ public async Task PersistentState_StreamsReused_StatePersisted()
Assert.Equal(true, state);
}

[ConditionalTheory]
[MsQuicSupported]
[InlineData(-1L)] // Too small
[InlineData(1L << 62)] // Too big
public async Task IProtocolErrorFeature_InvalidErrorCode(long errorCode)
{
// Arrange
await using var connectionListener = await QuicTestHelpers.CreateConnectionListenerFactory(LoggerFactory);

var options = QuicTestHelpers.CreateClientConnectionOptions(connectionListener.EndPoint);
await using var clientConnection = await QuicConnection.ConnectAsync(options);

await using var serverConnection = await connectionListener.AcceptAndAddFeatureAsync().DefaultTimeout();

// Act
var clientStream = await clientConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional);
await clientStream.WriteAsync(TestData).DefaultTimeout();

var serverStream = await serverConnection.AcceptAsync().DefaultTimeout();

var protocolErrorCodeFeature = serverConnection.Features.Get<IProtocolErrorCodeFeature>();

// Assert
Assert.IsType<QuicConnectionContext>(protocolErrorCodeFeature);
Assert.Throws<ArgumentOutOfRangeException>(() => protocolErrorCodeFeature.Error = errorCode);
}

private record RequestState(
QuicConnection QuicConnection,
MultiplexedConnectionContext ServerConnection,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -526,4 +526,59 @@ public async Task StreamAbortFeature_AbortWrite_ClientReceivesAbort()
var serverEx = await Assert.ThrowsAsync<ConnectionAbortedException>(() => serverReadTask).DefaultTimeout();
Assert.Equal("Test reason", serverEx.Message);
}

[ConditionalTheory]
[MsQuicSupported]
[InlineData(-1L)] // Too small
[InlineData(1L << 62)] // Too big
public async Task IProtocolErrorFeature_InvalidErrorCode(long errorCode)
{
// Arrange
await using var connectionListener = await QuicTestHelpers.CreateConnectionListenerFactory(LoggerFactory);

var options = QuicTestHelpers.CreateClientConnectionOptions(connectionListener.EndPoint);
await using var clientConnection = await QuicConnection.ConnectAsync(options);

await using var serverConnection = await connectionListener.AcceptAndAddFeatureAsync().DefaultTimeout();

// Act
var clientStream = await clientConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional);
await clientStream.WriteAsync(TestData).DefaultTimeout();

var serverStream = await serverConnection.AcceptAsync().DefaultTimeout();

var protocolErrorCodeFeature = serverStream.Features.Get<IProtocolErrorCodeFeature>();

// Assert
Assert.IsType<QuicStreamContext>(protocolErrorCodeFeature);
Assert.Throws<ArgumentOutOfRangeException>(() => protocolErrorCodeFeature.Error = errorCode);
}

[ConditionalTheory]
[MsQuicSupported]
[InlineData(-1L)] // Too small
[InlineData(1L << 62)] // Too big
public async Task IStreamAbortFeature_InvalidErrorCode(long errorCode)
{
// Arrange
await using var connectionListener = await QuicTestHelpers.CreateConnectionListenerFactory(LoggerFactory);

var options = QuicTestHelpers.CreateClientConnectionOptions(connectionListener.EndPoint);
await using var clientConnection = await QuicConnection.ConnectAsync(options);

await using var serverConnection = await connectionListener.AcceptAndAddFeatureAsync().DefaultTimeout();

// Act
var clientStream = await clientConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional);
await clientStream.WriteAsync(TestData).DefaultTimeout();

var serverStream = await serverConnection.AcceptAsync().DefaultTimeout();

var protocolErrorCodeFeature = serverStream.Features.Get<IStreamAbortFeature>();

// Assert
Assert.IsType<QuicStreamContext>(protocolErrorCodeFeature);
Assert.Throws<ArgumentOutOfRangeException>(() => protocolErrorCodeFeature.AbortRead(errorCode, new ConnectionAbortedException()));
Assert.Throws<ArgumentOutOfRangeException>(() => protocolErrorCodeFeature.AbortWrite(errorCode, new ConnectionAbortedException()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,16 @@
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Internal;
using Microsoft.AspNetCore.InternalTesting;
using Microsoft.AspNetCore.Server.Kestrel.Core;
using Microsoft.AspNetCore.Server.Kestrel.Https;
using Microsoft.AspNetCore.InternalTesting;
using Microsoft.AspNetCore.Server.Kestrel.Transport.Quic;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Diagnostics.Metrics;
using Microsoft.Extensions.Diagnostics.Metrics.Testing;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Testing;
using Microsoft.Extensions.Primitives;
using Xunit;

namespace Interop.FunctionalTests.Http3;

Expand Down Expand Up @@ -2031,6 +2030,33 @@ public async Task GET_GracefulServerShutdown_RequestCompleteSuccessfullyInsideHo
}
}

[Fact]
amcasey marked this conversation as resolved.
Show resolved Hide resolved
public async Task ServerReset_InvalidErrorCode()
{
var ranHandler = false;
var hostBuilder = CreateHostBuilder(context =>
{
ranHandler = true;
// Can't test a too-large value since it's bigger than int
//Assert.Throws<ArgumentOutOfRangeException>(() => context.Features.Get<IHttpResetFeature>().Reset(-1)); // Invalid negative value
context.Features.Get<IHttpResetFeature>().Reset(-1);
return Task.CompletedTask;
});

using var host = await hostBuilder.StartAsync().DefaultTimeout();
using var client = HttpHelpers.CreateClient();

var request = new HttpRequestMessage(HttpMethod.Get, $"https://127.0.0.1:{host.GetPort()}/");
request.Version = GetProtocol(HttpProtocols.Http3);
request.VersionPolicy = HttpVersionPolicy.RequestVersionExact;

var response = await client.SendAsync(request, CancellationToken.None).DefaultTimeout();
await host.StopAsync().DefaultTimeout();

Assert.True(ranHandler);
Assert.Equal(HttpStatusCode.InternalServerError, response.StatusCode);
}

private IHostBuilder CreateHostBuilder(RequestDelegate requestDelegate, HttpProtocols? protocol = null, Action<KestrelServerOptions> configureKestrel = null)
{
return HttpHelpers.CreateHostBuilder(AddTestLogging, requestDelegate, protocol, configureKestrel);
Expand Down
Loading