Skip to content

Commit

Permalink
Replace service resolver with keyed services
Browse files Browse the repository at this point in the history
  • Loading branch information
AndersAbel committed Feb 3, 2024
1 parent d813c80 commit d23857e
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 150 deletions.
10 changes: 10 additions & 0 deletions src/Sustainsys.Saml2.AspNetCore/Saml2Extensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
using Microsoft.Extensions.DependencyInjection.Extensions;
using Microsoft.Extensions.Options;
using Sustainsys.Saml2.AspNetCore;
using Sustainsys.Saml2.Bindings;
using Sustainsys.Saml2.Serialization;

namespace Microsoft.Extensions.DependencyInjection;

Expand Down Expand Up @@ -63,6 +65,14 @@ public static AuthenticationBuilder AddSaml2(
ServiceDescriptor.Singleton<IPostConfigureOptions<Saml2Options>,
Saml2PostConfigureOptions>());

builder.Services.TryAddSingleton<ISamlXmlReader, SamlXmlReader>();
builder.Services.TryAddSingleton<ISamlXmlWriter, SamlXmlWriter>();

builder.Services.TryAddEnumerable(
ServiceDescriptor.Singleton<IFrontChannelBinding, HttpRedirectBinding>());
builder.Services.TryAddEnumerable(
ServiceDescriptor.Singleton<IFrontChannelBinding, HttpPostBinding>());

return builder.AddRemoteScheme<Saml2Options, Saml2Handler>(authenticationScheme, displayName, configureOptions);
}
}
36 changes: 22 additions & 14 deletions src/Sustainsys.Saml2.AspNetCore/Saml2Handler.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
using Microsoft.AspNetCore.Authentication;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Sustainsys.Saml2.AspNetCore.Events;
using Sustainsys.Saml2.Bindings;
using Sustainsys.Saml2.Samlp;
using Sustainsys.Saml2.Serialization;
using Sustainsys.Saml2.Xml;
using System.Runtime.Serialization;
using System.Text.Encodings.Web;

namespace Sustainsys.Saml2.AspNetCore;
Expand All @@ -17,24 +20,30 @@ namespace Sustainsys.Saml2.AspNetCore;
/// <param name="options">Options</param>
/// <param name="logger">Logger factory</param>
/// <param name="encoder">Url encoder</param>
/// <param name="keyedServiceProvider">Keyed service provider used to resolve services</param>
public class Saml2Handler(
IOptionsMonitor<Saml2Options> options,
ILoggerFactory logger,
UrlEncoder encoder)
UrlEncoder encoder,
IKeyedServiceProvider keyedServiceProvider)
: RemoteAuthenticationHandler<Saml2Options>(options, logger, encoder)
{
private TService GetRequiredService<TService>() where TService : notnull =>
keyedServiceProvider.GetKeyedService<TService>(Scheme.Name) ??
keyedServiceProvider.GetRequiredService<TService>();

private IFrontChannelBinding GetFrontChannelBinding(string uri) =>
GetRequiredService<IEnumerable<IFrontChannelBinding>>()
.Single(b => b.Identifier == uri);

/// <summary>
/// Create events by invoking Options.ServiceResolver.CreateEventsAsync()
/// Resolves events as keyed service from DI, falls back to non-keyed service and
/// finally falls back to creating an events instance.
/// </summary>
/// <returns><see cref="Saml2Events"/>Saml2 events instance</returns>
protected override Task<object> CreateEventsAsync() => Task.FromResult<object>(GetService(s => s.CreateEvents));

private T GetService<T>(
Func<ServiceResolver, Func<ServiceResolver.ResolverContext, T>> factorySelector,
AuthenticationProperties? authenticationProperties = null) =>
factorySelector(Options.ServiceResolver)
(new ServiceResolver.ResolverContext(Context, Options, Scheme, authenticationProperties));
protected override Task<object> CreateEventsAsync() => Task.FromResult<object>(
keyedServiceProvider.GetKeyedService<Saml2Events>(Scheme.Name) ??
new Saml2Events());

/// <summary>
/// Events represents the easiest and most straight forward way to customize the
Expand All @@ -52,7 +61,7 @@ private T GetService<T>(
/// <returns></returns>
protected override async Task<HandleRequestResult> HandleRemoteAuthenticateAsync()
{
var bindings = GetService(sr => sr.GetAllBindings);
var bindings = GetRequiredService<IEnumerable<IFrontChannelBinding>>();

var binding = bindings.SingleOrDefault(b => b.CanUnbind(Context.Request));

Expand All @@ -64,7 +73,7 @@ protected override async Task<HandleRequestResult> HandleRemoteAuthenticateAsync
var samlMessage = await binding.UnbindAsync(Context.Request, str => Task.FromResult<Saml2Entity>(Options.IdentityProvider!));

var source = XmlHelpers.GetXmlTraverser(samlMessage.Xml);
var reader = GetService(sr => sr.GetSamlXmlReader);
var reader = GetRequiredService<ISamlXmlReader>();
var samlResponse = reader.ReadSamlResponse(source);

// For now, to make half-baked test pass.
Expand All @@ -88,7 +97,7 @@ protected override async Task HandleChallengeAsync(AuthenticationProperties prop
var authnRequestGeneratedContext = new AuthnRequestGeneratedContext(Context, Scheme, Options, properties, authnRequest);
await Events.AuthnRequestGeneratedAsync(authnRequestGeneratedContext);

var xmlDoc = GetService(sr => sr.GetSamlXmlWriter, properties).Write(authnRequest);
var xmlDoc = GetRequiredService<ISamlXmlWriter>().Write(authnRequest);

var message = new Saml2Message
{
Expand All @@ -97,8 +106,7 @@ protected override async Task HandleChallengeAsync(AuthenticationProperties prop
Xml = xmlDoc.DocumentElement!,
};

var binding = Options.ServiceResolver.GetBinding(
new(Context, Options, Scheme, properties, Options.IdentityProvider.SsoServiceBinding!));
var binding = GetFrontChannelBinding(Options.IdentityProvider.SsoServiceBinding!);

await binding.BindAsync(Response, message);
}
Expand Down
9 changes: 3 additions & 6 deletions src/Sustainsys.Saml2.AspNetCore/Saml2Options.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,11 @@ public Saml2Options()
CallbackPath = new PathString("/Saml2/Acs");
}

/// <summary>
/// The service resolver can
/// </summary>
public ServiceResolver ServiceResolver { get; set; } = new ServiceResolver();

/// <summary>
/// Events can be used to override behaviour. Setting this property is the easy way.
/// To resolve the events form DI, use <see cref="ServiceResolver.CreateEvents"/>
/// To resolve from DI register Saml2Events as a keyed service with the scheme name
/// as the key, or to use the same events for all schemes register as an normal
/// service
/// </summary>
public new Saml2Events Events
{
Expand Down
129 changes: 0 additions & 129 deletions src/Sustainsys.Saml2.AspNetCore/ServiceResolver.cs

This file was deleted.

14 changes: 13 additions & 1 deletion src/Tests/Sustainsys.Saml2.AspNetCore.Tests/Saml2HandlerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
using Sustainsys.Saml2.Tests.Helpers;
using System.Text;
using Sustainsys.Saml2.Serialization;
using Microsoft.Extensions.DependencyInjection;

namespace Sustainsys.Saml2.AspNetCore.Tests;
public class Saml2HandlerTests
Expand All @@ -26,10 +27,21 @@ public class Saml2HandlerTests

var loggerFactory = Substitute.For<ILoggerFactory>();

var keyedServiceProvider = Substitute.For<IKeyedServiceProvider>();
keyedServiceProvider.GetService(typeof(ISamlXmlReader)).Returns(new SamlXmlReader());
keyedServiceProvider.GetService(typeof(ISamlXmlWriter)).Returns(new SamlXmlWriter());
keyedServiceProvider.GetService(typeof(IEnumerable<IFrontChannelBinding>)).Returns(
new IFrontChannelBinding[]
{
new HttpRedirectBinding(),
new HttpPostBinding()
});

var handler = new Saml2Handler(
optionsMonitor,
loggerFactory,
UrlEncoder.Default);
UrlEncoder.Default,
keyedServiceProvider);

var scheme = new AuthenticationScheme("Saml2", "Saml2", typeof(Saml2Handler));

Expand Down

0 comments on commit d23857e

Please sign in to comment.