Skip to content

Commit

Permalink
Error handling callback for ReadAuthnRequest
Browse files Browse the repository at this point in the history
  • Loading branch information
AndersAbel committed Feb 3, 2024
1 parent 269ecb6 commit d813c80
Show file tree
Hide file tree
Showing 8 changed files with 112 additions and 5 deletions.
5 changes: 5 additions & 0 deletions src/Sustainsys.Saml2.AspNetCore/ServiceResolver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@

namespace Sustainsys.Saml2.AspNetCore;

// TODO: Replace with .NET 8 keyed services
// Resolve services from key (= scheme) first, then fallback to
// normal registration. Add default service implementations as singletons to DI.

/// <summary>
/// The Sustainsys.Saml2 library uses multiple loosely coupled services internally. The
/// default implementation is to not register these in the main dependency injection
Expand Down Expand Up @@ -63,6 +67,7 @@ public class ResolverContext(
public Func<ResolverContext, Saml2Events> CreateEvents { get; set; }
= _ => new Saml2Events();

// TODO: Can this be a shared instance?
/// <summary>
/// Factory for <see cref="ISamlXmlReader"/>
/// </summary>
Expand Down
4 changes: 3 additions & 1 deletion src/Sustainsys.Saml2/Serialization/ISamlXmlReader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,7 @@ public interface ISamlXmlReader
/// <param name="source">Xml Traverser to read from</param>
/// <param name="errorInspector">Callback that can inspect and alter errors before throwing</param>
/// <returns><see cref="AuthnRequest"/></returns>
AuthnRequest ReadAuthnRequest(XmlTraverser source, Action<AuthnRequest, IList<Error>>? errorInspector = null);
AuthnRequest ReadAuthnRequest(
XmlTraverser source,
Action<ReadErrorInspectorContext<AuthnRequest>>? errorInspector = null);
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Sustainsys.Saml2.Samlp;
using Sustainsys.Saml2.Xml;
using System.Xml;
using static Sustainsys.Saml2.Constants;

namespace Sustainsys.Saml2.Serialization;
Expand All @@ -14,12 +15,14 @@ public partial class SamlXmlReader
//TODO: Convert other reads to follow this pattern with a callback for errors

/// <inheritdoc/>
public virtual AuthnRequest ReadAuthnRequest(
public AuthnRequest ReadAuthnRequest(
XmlTraverser source,
Action<AuthnRequest, IList<Error>>? errorInspector = null)
Action<ReadErrorInspectorContext<AuthnRequest>>? errorInspector = null)
{
var authnRequest = ReadAuthnRequest(source);

CallErrorInspector(errorInspector, authnRequest, source);

source.ThrowOnErrors();

return authnRequest;
Expand Down
19 changes: 19 additions & 0 deletions src/Sustainsys.Saml2/Serialization/SamlXmlReader.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Sustainsys.Saml2.Common;
using Sustainsys.Saml2.Saml;
using Sustainsys.Saml2.Samlp;
using Sustainsys.Saml2.Xml;
using System;
using System.Collections.Generic;
Expand All @@ -8,6 +9,7 @@
using System.Security.Cryptography.Xml;
using System.Text;
using System.Threading.Tasks;
using System.Xml;
using static Sustainsys.Saml2.Constants;

namespace Sustainsys.Saml2.Serialization;
Expand Down Expand Up @@ -79,4 +81,21 @@ protected virtual void ThrowOnErrors(XmlTraverser source)
return (trustedSigningKeys, allowedHashAlgorithms);
}

private void CallErrorInspector<TData>(
Action<ReadErrorInspectorContext<TData>>? errorInspector,
TData data,
XmlTraverser source)
{
if (errorInspector != null)
{
var context = new ReadErrorInspectorContext<TData>()
{
Data = data,
Errors = source.Errors,
XmlSource = source.RootNode
};

errorInspector(context);
}
}
}
31 changes: 31 additions & 0 deletions src/Sustainsys.Saml2/Xml/ReadErrorInspectorContext.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
using Microsoft.Extensions.Configuration.Xml;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Xml;

namespace Sustainsys.Saml2.Xml;

/// <summary>
/// Context for an error inspector.
/// </summary>
/// <typeparam name="TData">Type of the data read</typeparam>
public class ReadErrorInspectorContext<TData>
{
/// <summary>
/// The data read
/// </summary>
public required TData Data { get; set; }

/// <summary>
/// The XML source, if this was a parsing event.
/// </summary>
public required XmlNode? XmlSource { get; set; }

/// <summary>
/// The errors found
/// </summary>
public required IList<Error> Errors { get; set; }
}
3 changes: 3 additions & 0 deletions src/Sustainsys.Saml2/Xml/XmlTraverser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ public class XmlTraverser
/// </summary>
private bool childrenHandled = true;

internal XmlNode? RootNode { get; set; }

/// <summary>
/// The current node being processed.
/// </summary>
Expand All @@ -50,6 +52,7 @@ public class XmlTraverser
/// <param name="rootNode">Root node for this traverser</param>
public XmlTraverser(XmlNode rootNode)
{
RootNode = rootNode;
CurrentNode = rootNode;
Errors = [];
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,49 @@ public void ReadAuthnRequest_CanReadOptional()
actual.Should().BeEquivalentTo(expected);
}

// TODO: Test with AssertionConsumerServiceIndex - note mutually exclusive to AcsUrl + Binding
[Fact]
public void ReadAuthnRequest_ErrorCallback()
{
var source = GetXmlTraverser(nameof(ReadAuthnRequest_Error));

var subject = new SamlXmlReader();

bool errorInspectorCalled = false;

void errorInspector(ReadErrorInspectorContext<AuthnRequest> context)
{
context.Data.Id.Should().Be("x123");

var xmlSourceElement = context.XmlSource as XmlElement;
xmlSourceElement.Should().NotBeNull();
xmlSourceElement!.GetAttribute("ID").Should().Be("x123");
context.Errors.Count.Should().Be(1);

var error = context.Errors.Single();
error.Node.Should().BeSameAs(context.XmlSource);
error.LocalName.Should().Be("Version");
error.Reason.Should().Be(ErrorReason.MissingAttribute);
error.Ignore.Should().BeFalse();

error.Ignore = true;

errorInspectorCalled = true;
}

// TODO: Test error callback
var actual = subject.ReadAuthnRequest(source, errorInspector);

errorInspectorCalled.Should().BeTrue();
}

[Fact]
public void ReadAuthnRequest_Error()
{
var source = GetXmlTraverser();

var subject = new SamlXmlReader();

subject.Invoking(s => s.ReadAuthnRequest(source))
.Should().Throw<SamlXmlException>()
.WithMessage("*Version*not found*");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
<saml:AuthnRequest xmlns:saml="urn:oasis:names:tc:SAML:2.0:protocol"
ID="x123" IssueInstant="2023-11-24T22:44:14Z" />

0 comments on commit d813c80

Please sign in to comment.