diff --git a/src/Sustainsys.Saml2.AspNetCore/Saml2Extensions.cs b/src/Sustainsys.Saml2.AspNetCore/Saml2Extensions.cs index 344a76a75..2078a1b81 100644 --- a/src/Sustainsys.Saml2.AspNetCore/Saml2Extensions.cs +++ b/src/Sustainsys.Saml2.AspNetCore/Saml2Extensions.cs @@ -5,6 +5,8 @@ using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.Extensions.Options; using Sustainsys.Saml2.AspNetCore; +using Sustainsys.Saml2.Bindings; +using Sustainsys.Saml2.Serialization; namespace Microsoft.Extensions.DependencyInjection; @@ -63,6 +65,14 @@ public static AuthenticationBuilder AddSaml2( ServiceDescriptor.Singleton, Saml2PostConfigureOptions>()); + builder.Services.TryAddSingleton(); + builder.Services.TryAddSingleton(); + + builder.Services.TryAddEnumerable( + ServiceDescriptor.Singleton()); + builder.Services.TryAddEnumerable( + ServiceDescriptor.Singleton()); + return builder.AddRemoteScheme(authenticationScheme, displayName, configureOptions); } } diff --git a/src/Sustainsys.Saml2.AspNetCore/Saml2Handler.cs b/src/Sustainsys.Saml2.AspNetCore/Saml2Handler.cs index 13a02a833..452b6d126 100644 --- a/src/Sustainsys.Saml2.AspNetCore/Saml2Handler.cs +++ b/src/Sustainsys.Saml2.AspNetCore/Saml2Handler.cs @@ -1,10 +1,13 @@ using Microsoft.AspNetCore.Authentication; +using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using Sustainsys.Saml2.AspNetCore.Events; using Sustainsys.Saml2.Bindings; using Sustainsys.Saml2.Samlp; +using Sustainsys.Saml2.Serialization; using Sustainsys.Saml2.Xml; +using System.Runtime.Serialization; using System.Text.Encodings.Web; namespace Sustainsys.Saml2.AspNetCore; @@ -17,24 +20,30 @@ namespace Sustainsys.Saml2.AspNetCore; /// Options /// Logger factory /// Url encoder +/// Keyed service provider used to resolve services public class Saml2Handler( IOptionsMonitor options, ILoggerFactory logger, - UrlEncoder encoder) + UrlEncoder encoder, + IKeyedServiceProvider keyedServiceProvider) : RemoteAuthenticationHandler(options, logger, encoder) { + private TService GetRequiredService() where TService : notnull => + keyedServiceProvider.GetKeyedService(Scheme.Name) ?? + keyedServiceProvider.GetRequiredService(); + + private IFrontChannelBinding GetFrontChannelBinding(string uri) => + GetRequiredService>() + .Single(b => b.Identifier == uri); /// - /// Create events by invoking Options.ServiceResolver.CreateEventsAsync() + /// Resolves events as keyed service from DI, falls back to non-keyed service and + /// finally falls back to creating an events instance. /// /// Saml2 events instance - protected override Task CreateEventsAsync() => Task.FromResult(GetService(s => s.CreateEvents)); - - private T GetService( - Func> factorySelector, - AuthenticationProperties? authenticationProperties = null) => - factorySelector(Options.ServiceResolver) - (new ServiceResolver.ResolverContext(Context, Options, Scheme, authenticationProperties)); + protected override Task CreateEventsAsync() => Task.FromResult( + keyedServiceProvider.GetKeyedService(Scheme.Name) ?? + new Saml2Events()); /// /// Events represents the easiest and most straight forward way to customize the @@ -52,7 +61,7 @@ private T GetService( /// protected override async Task HandleRemoteAuthenticateAsync() { - var bindings = GetService(sr => sr.GetAllBindings); + var bindings = GetRequiredService>(); var binding = bindings.SingleOrDefault(b => b.CanUnbind(Context.Request)); @@ -64,7 +73,7 @@ protected override async Task HandleRemoteAuthenticateAsync var samlMessage = await binding.UnbindAsync(Context.Request, str => Task.FromResult(Options.IdentityProvider!)); var source = XmlHelpers.GetXmlTraverser(samlMessage.Xml); - var reader = GetService(sr => sr.GetSamlXmlReader); + var reader = GetRequiredService(); var samlResponse = reader.ReadSamlResponse(source); // For now, to make half-baked test pass. @@ -88,7 +97,7 @@ protected override async Task HandleChallengeAsync(AuthenticationProperties prop var authnRequestGeneratedContext = new AuthnRequestGeneratedContext(Context, Scheme, Options, properties, authnRequest); await Events.AuthnRequestGeneratedAsync(authnRequestGeneratedContext); - var xmlDoc = GetService(sr => sr.GetSamlXmlWriter, properties).Write(authnRequest); + var xmlDoc = GetRequiredService().Write(authnRequest); var message = new Saml2Message { @@ -97,8 +106,7 @@ protected override async Task HandleChallengeAsync(AuthenticationProperties prop Xml = xmlDoc.DocumentElement!, }; - var binding = Options.ServiceResolver.GetBinding( - new(Context, Options, Scheme, properties, Options.IdentityProvider.SsoServiceBinding!)); + var binding = GetFrontChannelBinding(Options.IdentityProvider.SsoServiceBinding!); await binding.BindAsync(Response, message); } diff --git a/src/Sustainsys.Saml2.AspNetCore/Saml2Options.cs b/src/Sustainsys.Saml2.AspNetCore/Saml2Options.cs index d18bd7002..8f9f466de 100644 --- a/src/Sustainsys.Saml2.AspNetCore/Saml2Options.cs +++ b/src/Sustainsys.Saml2.AspNetCore/Saml2Options.cs @@ -17,14 +17,11 @@ public Saml2Options() CallbackPath = new PathString("/Saml2/Acs"); } - /// - /// The service resolver can - /// - public ServiceResolver ServiceResolver { get; set; } = new ServiceResolver(); - /// /// Events can be used to override behaviour. Setting this property is the easy way. - /// To resolve the events form DI, use + /// To resolve from DI register Saml2Events as a keyed service with the scheme name + /// as the key, or to use the same events for all schemes register as an normal + /// service /// public new Saml2Events Events { diff --git a/src/Sustainsys.Saml2.AspNetCore/ServiceResolver.cs b/src/Sustainsys.Saml2.AspNetCore/ServiceResolver.cs deleted file mode 100644 index 498f1af5d..000000000 --- a/src/Sustainsys.Saml2.AspNetCore/ServiceResolver.cs +++ /dev/null @@ -1,129 +0,0 @@ -using Microsoft.AspNetCore.Authentication; -using Microsoft.AspNetCore.Http; -using Sustainsys.Saml2.Bindings; -using Sustainsys.Saml2.Samlp; -using Sustainsys.Saml2.Serialization; -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; - -namespace Sustainsys.Saml2.AspNetCore; - -// TODO: Replace with .NET 8 keyed services -// Resolve services from key (= scheme) first, then fallback to -// normal registration. Add default service implementations as singletons to DI. - -/// -/// The Sustainsys.Saml2 library uses multiple loosely coupled services internally. The -/// default implementation is to not register these in the main dependency injection -/// system to avoid clutter and to allow per scheme registrations. All services are -/// resolved using the service resolver. To override services, override the factory -/// method here. The resolver context always contains the HttpContext, which can be -/// used to resolve services from DI. -/// -public class ServiceResolver -{ - /// - /// Context for service resolver - /// - /// Current http context - /// Current options - /// Current authentication scheme - /// Authentication properties, if available - public class ResolverContext( - HttpContext context, - Saml2Options options, - AuthenticationScheme scheme, - AuthenticationProperties? authenticationProperties) - { - /// - /// Current HttpContext - /// - public HttpContext Context { get; } = context; - - /// - /// Current options - /// - public Saml2Options Options { get; } = options; - - /// - /// Current authentication scheme - /// - public AuthenticationScheme Scheme { get; } = scheme; - - /// - /// Current AuthenticationProperties, if available - /// - public AuthenticationProperties? AuthenticationProperties { get; } = authenticationProperties; - } - - /// - /// Factory for the events class. Defaults to returning a new Saml2Events instance. It's usually - /// easier to just set the Events property on the options than to use this. If you want to - /// resolve the events from DI, this is the best place to do it. - /// - public Func CreateEvents { get; set; } - = _ => new Saml2Events(); - - // TODO: Can this be a shared instance? - /// - /// Factory for - /// - public Func GetSamlXmlReader { get; set; } - = _ => new SamlXmlReader(); - - // TODO: Can this be a static instance? - /// - /// Factory for - /// - public Func GetSamlXmlWriter { get; set; } - = _ => new SamlXmlWriter(); - - // TODO: Can this be a static instance? - /// - /// Factory for collection of front channel bindings. - /// - public Func> GetAllBindings { get; set; } - = _ => new IFrontChannelBinding[] { new HttpRedirectBinding(), new HttpPostBinding() }; - - /// - /// Context for resolving binding - /// - /// Current http context - /// Current options - /// Current authentication scheme - /// Authentication properties, if available - /// Uri for requested binding - public class BindingResolverContext( - HttpContext context, - Saml2Options options, - AuthenticationScheme scheme, - AuthenticationProperties? authenticationProperties, - string binding) - : ResolverContext(context, options, scheme, authenticationProperties) - { - - /// - /// Requested binding - /// - public string Binding { get; } = binding; - } - - /// - /// Factory for front channel bindings - /// - public Func GetBinding { get; set; } - = ctx => - { - if(string.IsNullOrEmpty(ctx.Binding)) - { - throw new ArgumentNullException("Binding property must have value to get binding"); - } - - return ctx.Options.ServiceResolver.GetAllBindings(ctx) - .SingleOrDefault(b => b.Identifier == ctx.Binding) - ?? throw new NotImplementedException($"Unknown binding {ctx.Binding} requested"); - }; -} \ No newline at end of file diff --git a/src/Tests/Sustainsys.Saml2.AspNetCore.Tests/Saml2HandlerTests.cs b/src/Tests/Sustainsys.Saml2.AspNetCore.Tests/Saml2HandlerTests.cs index e417f9987..62418f3e5 100644 --- a/src/Tests/Sustainsys.Saml2.AspNetCore.Tests/Saml2HandlerTests.cs +++ b/src/Tests/Sustainsys.Saml2.AspNetCore.Tests/Saml2HandlerTests.cs @@ -13,6 +13,7 @@ using Sustainsys.Saml2.Tests.Helpers; using System.Text; using Sustainsys.Saml2.Serialization; +using Microsoft.Extensions.DependencyInjection; namespace Sustainsys.Saml2.AspNetCore.Tests; public class Saml2HandlerTests @@ -26,10 +27,21 @@ public class Saml2HandlerTests var loggerFactory = Substitute.For(); + var keyedServiceProvider = Substitute.For(); + keyedServiceProvider.GetService(typeof(ISamlXmlReader)).Returns(new SamlXmlReader()); + keyedServiceProvider.GetService(typeof(ISamlXmlWriter)).Returns(new SamlXmlWriter()); + keyedServiceProvider.GetService(typeof(IEnumerable)).Returns( + new IFrontChannelBinding[] + { + new HttpRedirectBinding(), + new HttpPostBinding() + }); + var handler = new Saml2Handler( optionsMonitor, loggerFactory, - UrlEncoder.Default); + UrlEncoder.Default, + keyedServiceProvider); var scheme = new AuthenticationScheme("Saml2", "Saml2", typeof(Saml2Handler));