diff --git a/src/Libraries/Microsoft.Extensions.Http.Resilience/Hedging/Internals/RequestMessageSnapshotStrategy.cs b/src/Libraries/Microsoft.Extensions.Http.Resilience/Hedging/Internals/RequestMessageSnapshotStrategy.cs index 8f484f1cbcd..9b4fe7e9c8d 100644 --- a/src/Libraries/Microsoft.Extensions.Http.Resilience/Hedging/Internals/RequestMessageSnapshotStrategy.cs +++ b/src/Libraries/Microsoft.Extensions.Http.Resilience/Hedging/Internals/RequestMessageSnapshotStrategy.cs @@ -20,7 +20,9 @@ protected override async ValueTask> ExecuteCore SendAsync(HttpRequestMessage ResilienceContext context = GetOrSetResilienceContext(request, cancellationToken, out bool created); TrySetRequestMetadata(context, request); - SetRequestMessage(context, request); + context.SetRequestMessage(request); try { @@ -117,7 +117,7 @@ protected override HttpResponseMessage Send(HttpRequestMessage request, Cancella ResilienceContext context = GetOrSetResilienceContext(request, cancellationToken, out bool created); TrySetRequestMetadata(context, request); - SetRequestMessage(context, request); + context.SetRequestMessage(request); try { @@ -165,11 +165,8 @@ private static void TrySetRequestMetadata(ResilienceContext context, HttpRequest } } - private static void SetRequestMessage(ResilienceContext context, HttpRequestMessage request) - => context.Properties.Set(ResilienceKeys.RequestMessage, request); - private static HttpRequestMessage GetRequestMessage(ResilienceContext context, HttpRequestMessage request) - => context.Properties.GetValue(ResilienceKeys.RequestMessage, request); + => context.GetRequestMessage() ?? request; private static void RestoreResilienceContext(ResilienceContext context, HttpRequestMessage request, bool created) { diff --git a/src/Libraries/Microsoft.Extensions.Http.Resilience/Routing/Internal/RoutingResilienceStrategy.cs b/src/Libraries/Microsoft.Extensions.Http.Resilience/Routing/Internal/RoutingResilienceStrategy.cs index 0aefc6cf171..2054cbcccbc 100644 --- a/src/Libraries/Microsoft.Extensions.Http.Resilience/Routing/Internal/RoutingResilienceStrategy.cs +++ b/src/Libraries/Microsoft.Extensions.Http.Resilience/Routing/Internal/RoutingResilienceStrategy.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Net.Http; using System.Threading.Tasks; using Microsoft.Extensions.Http.Resilience.Internal; using Microsoft.Shared.Diagnostics; @@ -26,7 +27,9 @@ protected override async ValueTask> ExecuteCore @@ -39,5 +39,15 @@ public void ExecuteAsync_RequestMessageNotFound_Throws() strategy.Invoking(s => s.Execute(() => { })).Should().Throw(); } + [Fact] + public void ExecuteAsync_RequestMessageIsNull_Throws() + { + var strategy = Create(); + var context = ResilienceContextPool.Shared.Get(); + context.SetRequestMessage(null); + + strategy.Invoking(s => s.Execute(_ => { }, context)).Should().Throw(); + } + private static ResiliencePipeline Create() => new ResiliencePipelineBuilder().AddStrategy(_ => new RequestMessageSnapshotStrategy(), Mock.Of()).Build(); } diff --git a/test/Libraries/Microsoft.Extensions.Http.Resilience.Tests/Resilience/HttpResilienceContextExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.Http.Resilience.Tests/Resilience/HttpResilienceContextExtensionsTests.cs index 3ca02116e16..9cd1255c992 100644 --- a/test/Libraries/Microsoft.Extensions.Http.Resilience.Tests/Resilience/HttpResilienceContextExtensionsTests.cs +++ b/test/Libraries/Microsoft.Extensions.Http.Resilience.Tests/Resilience/HttpResilienceContextExtensionsTests.cs @@ -24,8 +24,6 @@ public void GetRequestMessage_RequestMessageIsMissing_ReturnsNull() var context = ResilienceContextPool.Shared.Get(); Assert.Null(context.GetRequestMessage()); - - ResilienceContextPool.Shared.Return(context); } [Fact] @@ -35,8 +33,6 @@ public void GetRequestMessage_RequestMessageIsNull_ReturnsNull() context.Properties.Set(ResilienceKeys.RequestMessage, null); Assert.Null(context.GetRequestMessage()); - - ResilienceContextPool.Shared.Return(context); } [Fact] @@ -47,8 +43,6 @@ public void GetRequestMessage_RequestMessageIsPresent_ReturnsRequestMessage() context.Properties.Set(ResilienceKeys.RequestMessage, request); Assert.Same(request, context.GetRequestMessage()); - - ResilienceContextPool.Shared.Return(context); } [Fact] @@ -56,6 +50,7 @@ public void SetRequestMessage_ResilienceContextIsNull_Throws() { ResilienceContext context = null!; using var request = new HttpRequestMessage(); + Assert.Throws(() => context.SetRequestMessage(request)); } @@ -67,8 +62,6 @@ public void SetRequestMessage_RequestMessageIsNull_SetsNullRequestMessage() Assert.True(context.Properties.TryGetValue(ResilienceKeys.RequestMessage, out HttpRequestMessage? request)); Assert.Null(request); - - ResilienceContextPool.Shared.Return(context); } [Fact] @@ -80,7 +73,5 @@ public void SetRequestMessage_RequestMessageIsNotNull_SetsRequestMessage() Assert.True(context.Properties.TryGetValue(ResilienceKeys.RequestMessage, out HttpRequestMessage? actualRequest)); Assert.Same(request, actualRequest); - - ResilienceContextPool.Shared.Return(context); } } diff --git a/test/Libraries/Microsoft.Extensions.Http.Resilience.Tests/Resilience/ResilienceHandlerTest.cs b/test/Libraries/Microsoft.Extensions.Http.Resilience.Tests/Resilience/ResilienceHandlerTest.cs index 8a1831d67b8..9de72555401 100644 --- a/test/Libraries/Microsoft.Extensions.Http.Resilience.Tests/Resilience/ResilienceHandlerTest.cs +++ b/test/Libraries/Microsoft.Extensions.Http.Resilience.Tests/Resilience/ResilienceHandlerTest.cs @@ -8,7 +8,6 @@ using System.Threading.Tasks; using FluentAssertions; using Microsoft.Extensions.Http.Diagnostics; -using Microsoft.Extensions.Http.Resilience.Internal; using Microsoft.Extensions.Http.Resilience.Test.Helpers; using Polly; using Xunit; @@ -108,7 +107,7 @@ public async Task Send_EnsureInvoker(bool executionContextSet, bool asynchronous handler.InnerHandler = new TestHandlerStub((r, _) => { r.GetResilienceContext().Should().NotBeNull(); - r.GetResilienceContext()!.Properties.GetValue(ResilienceKeys.RequestMessage, null!).Should().BeSameAs(r); + r.GetResilienceContext()!.GetRequestMessage().Should().BeSameAs(r); return Task.FromResult(new HttpResponseMessage { StatusCode = HttpStatusCode.Created }); }); diff --git a/test/Libraries/Microsoft.Extensions.Http.Resilience.Tests/Routing/RoutingResilienceStrategyTests.cs b/test/Libraries/Microsoft.Extensions.Http.Resilience.Tests/Routing/RoutingResilienceStrategyTests.cs index 341d32087f8..ed1a6d5c6b2 100644 --- a/test/Libraries/Microsoft.Extensions.Http.Resilience.Tests/Routing/RoutingResilienceStrategyTests.cs +++ b/test/Libraries/Microsoft.Extensions.Http.Resilience.Tests/Routing/RoutingResilienceStrategyTests.cs @@ -4,7 +4,6 @@ using System; using System.Net.Http; using FluentAssertions; -using Microsoft.Extensions.Http.Resilience.Internal; using Microsoft.Extensions.Http.Resilience.Routing.Internal; using Moq; using Polly; @@ -22,6 +21,16 @@ public void NoRequestMessage_Throws() strategy.Invoking(s => s.Execute(() => { })).Should().Throw().WithMessage("The HTTP request message was not found in the resilience context."); } + [Fact] + public void RequestMessageIsNull_Throws() + { + var strategy = Create(() => Mock.Of()); + var context = ResilienceContextPool.Shared.Get(); + context.SetRequestMessage(null); + + strategy.Invoking(s => s.Execute(_ => { }, context)).Should().Throw().WithMessage("The HTTP request message was not found in the resilience context."); + } + [Fact] public void NoRoutingProvider_Ok() { @@ -29,7 +38,7 @@ public void NoRoutingProvider_Ok() var strategy = Create(null); var context = ResilienceContextPool.Shared.Get(); - context.Properties.Set(ResilienceKeys.RequestMessage, request); + context.SetRequestMessage(request); strategy.Invoking(s => s.Execute(_ => { }, context)).Should().NotThrow(); }