Skip to content

Commit

Permalink
Fixes dotnet#4957
Browse files Browse the repository at this point in the history
Replaces usage of ResilienceKeys.RequestMessage variable with corresponding Get/Set methods
  • Loading branch information
iliar-turdushev committed Sep 30, 2024
1 parent 5631d0d commit 69a412c
Show file tree
Hide file tree
Showing 8 changed files with 35 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ protected override async ValueTask<Outcome<TResult>> ExecuteCore<TResult, TState
ResilienceContext context,
TState state)
{
if (!context.Properties.TryGetValue(ResilienceKeys.RequestMessage, out var request) || request is null)
HttpRequestMessage? request = context.GetRequestMessage();

if (request is null)
{
Throw.InvalidOperationException("The HTTP request message was not found in the resilience context.");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ public static IStandardHedgingHandlerBuilder AddStandardHedgingHandler(this IHtt
requestMessage.SetResilienceContext(args.ActionContext);

// replace the request message
args.ActionContext.Properties.Set(ResilienceKeys.RequestMessage, requestMessage);
args.ActionContext.SetRequestMessage(requestMessage);

if (args.PrimaryContext.Properties.TryGetValue(ResilienceKeys.RoutingStrategy, out var routingPipeline))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage

ResilienceContext context = GetOrSetResilienceContext(request, cancellationToken, out bool created);
TrySetRequestMetadata(context, request);
SetRequestMessage(context, request);
context.SetRequestMessage(request);

try
{
Expand Down Expand Up @@ -117,7 +117,7 @@ protected override HttpResponseMessage Send(HttpRequestMessage request, Cancella

ResilienceContext context = GetOrSetResilienceContext(request, cancellationToken, out bool created);
TrySetRequestMetadata(context, request);
SetRequestMessage(context, request);
context.SetRequestMessage(request);

try
{
Expand Down Expand Up @@ -165,11 +165,8 @@ private static void TrySetRequestMetadata(ResilienceContext context, HttpRequest
}
}

private static void SetRequestMessage(ResilienceContext context, HttpRequestMessage request)
=> context.Properties.Set(ResilienceKeys.RequestMessage, request);

private static HttpRequestMessage GetRequestMessage(ResilienceContext context, HttpRequestMessage request)
=> context.Properties.GetValue(ResilienceKeys.RequestMessage, request);
=> context.GetRequestMessage() ?? request;

private static void RestoreResilienceContext(ResilienceContext context, HttpRequestMessage request, bool created)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Net.Http;
using System.Threading.Tasks;
using Microsoft.Extensions.Http.Resilience.Internal;
using Microsoft.Shared.Diagnostics;
Expand All @@ -26,7 +27,9 @@ protected override async ValueTask<Outcome<TResult>> ExecuteCore<TResult, TState
ResilienceContext context,
TState state)
{
if (!context.Properties.TryGetValue(ResilienceKeys.RequestMessage, out var request))
HttpRequestMessage? request = context.GetRequestMessage();

if (request is null)
{
Throw.InvalidOperationException("The HTTP request message was not found in the resilience context.");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public async Task ExecuteAsync_EnsureSnapshotAttached()
var strategy = Create();
var context = ResilienceContextPool.Shared.Get();
using var request = new HttpRequestMessage();
context.Properties.Set(ResilienceKeys.RequestMessage, request);
context.SetRequestMessage(request);

using var response = await strategy.ExecuteAsync(
context =>
Expand All @@ -39,5 +39,15 @@ public void ExecuteAsync_RequestMessageNotFound_Throws()
strategy.Invoking(s => s.Execute(() => { })).Should().Throw<InvalidOperationException>();
}

[Fact]
public void ExecuteAsync_RequestMessageIsNull_Throws()
{
var strategy = Create();
var context = ResilienceContextPool.Shared.Get();
context.SetRequestMessage(null);

strategy.Invoking(s => s.Execute(_ => { }, context)).Should().Throw<InvalidOperationException>();
}

private static ResiliencePipeline Create() => new ResiliencePipelineBuilder().AddStrategy(_ => new RequestMessageSnapshotStrategy(), Mock.Of<ResilienceStrategyOptions>()).Build();
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ public void GetRequestMessage_RequestMessageIsMissing_ReturnsNull()
var context = ResilienceContextPool.Shared.Get();

Assert.Null(context.GetRequestMessage());

ResilienceContextPool.Shared.Return(context);
}

[Fact]
Expand All @@ -35,8 +33,6 @@ public void GetRequestMessage_RequestMessageIsNull_ReturnsNull()
context.Properties.Set(ResilienceKeys.RequestMessage, null);

Assert.Null(context.GetRequestMessage());

ResilienceContextPool.Shared.Return(context);
}

[Fact]
Expand All @@ -47,15 +43,14 @@ public void GetRequestMessage_RequestMessageIsPresent_ReturnsRequestMessage()
context.Properties.Set(ResilienceKeys.RequestMessage, request);

Assert.Same(request, context.GetRequestMessage());

ResilienceContextPool.Shared.Return(context);
}

[Fact]
public void SetRequestMessage_ResilienceContextIsNull_Throws()
{
ResilienceContext context = null!;
using var request = new HttpRequestMessage();

Assert.Throws<ArgumentNullException>(() => context.SetRequestMessage(request));
}

Expand All @@ -67,8 +62,6 @@ public void SetRequestMessage_RequestMessageIsNull_SetsNullRequestMessage()

Assert.True(context.Properties.TryGetValue(ResilienceKeys.RequestMessage, out HttpRequestMessage? request));
Assert.Null(request);

ResilienceContextPool.Shared.Return(context);
}

[Fact]
Expand All @@ -80,7 +73,5 @@ public void SetRequestMessage_RequestMessageIsNotNull_SetsRequestMessage()

Assert.True(context.Properties.TryGetValue(ResilienceKeys.RequestMessage, out HttpRequestMessage? actualRequest));
Assert.Same(request, actualRequest);

ResilienceContextPool.Shared.Return(context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
using System.Threading.Tasks;
using FluentAssertions;
using Microsoft.Extensions.Http.Diagnostics;
using Microsoft.Extensions.Http.Resilience.Internal;
using Microsoft.Extensions.Http.Resilience.Test.Helpers;
using Polly;
using Xunit;
Expand Down Expand Up @@ -108,7 +107,7 @@ public async Task Send_EnsureInvoker(bool executionContextSet, bool asynchronous
handler.InnerHandler = new TestHandlerStub((r, _) =>
{
r.GetResilienceContext().Should().NotBeNull();
r.GetResilienceContext()!.Properties.GetValue(ResilienceKeys.RequestMessage, null!).Should().BeSameAs(r);
r.GetResilienceContext()!.GetRequestMessage().Should().BeSameAs(r);

return Task.FromResult(new HttpResponseMessage { StatusCode = HttpStatusCode.Created });
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
using System;
using System.Net.Http;
using FluentAssertions;
using Microsoft.Extensions.Http.Resilience.Internal;
using Microsoft.Extensions.Http.Resilience.Routing.Internal;
using Moq;
using Polly;
Expand All @@ -22,14 +21,24 @@ public void NoRequestMessage_Throws()
strategy.Invoking(s => s.Execute(() => { })).Should().Throw<InvalidOperationException>().WithMessage("The HTTP request message was not found in the resilience context.");
}

[Fact]
public void RequestMessageIsNull_Throws()
{
var strategy = Create(() => Mock.Of<RequestRoutingStrategy>());
var context = ResilienceContextPool.Shared.Get();
context.SetRequestMessage(null);

strategy.Invoking(s => s.Execute(_ => { }, context)).Should().Throw<InvalidOperationException>().WithMessage("The HTTP request message was not found in the resilience context.");
}

[Fact]
public void NoRoutingProvider_Ok()
{
using var request = new HttpRequestMessage();

var strategy = Create(null);
var context = ResilienceContextPool.Shared.Get();
context.Properties.Set(ResilienceKeys.RequestMessage, request);
context.SetRequestMessage(request);

strategy.Invoking(s => s.Execute(_ => { }, context)).Should().NotThrow();
}
Expand Down

0 comments on commit 69a412c

Please sign in to comment.