From 7422a1134a7f8d0cabd97b081a2e2c0fb5d2f50d Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Fri, 31 May 2024 14:41:05 -0600 Subject: [PATCH] Allow logout+jwt JWT type Closes gh-15003 --- ...ckChannelLogoutAuthenticationProvider.java | 32 +++++++++++++++++-- .../client/OidcLogoutConfigurerTests.java | 12 ++++--- 2 files changed, 37 insertions(+), 7 deletions(-) diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcBackChannelLogoutAuthenticationProvider.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcBackChannelLogoutAuthenticationProvider.java index d8a217f2632..16731f19561 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcBackChannelLogoutAuthenticationProvider.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcBackChannelLogoutAuthenticationProvider.java @@ -16,6 +16,11 @@ package org.springframework.security.config.annotation.web.configurers.oauth2.client; +import com.nimbusds.jose.JOSEObjectType; +import com.nimbusds.jose.proc.DefaultJOSEObjectTypeVerifier; +import com.nimbusds.jose.proc.JOSEObjectTypeVerifier; +import com.nimbusds.jose.proc.SecurityContext; + import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.AuthenticationServiceException; import org.springframework.security.core.Authentication; @@ -26,11 +31,14 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.converter.ClaimTypeConverter; import org.springframework.security.oauth2.jwt.BadJwtException; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.JwtDecoder; import org.springframework.security.oauth2.jwt.JwtDecoderFactory; +import org.springframework.security.oauth2.jwt.NimbusJwtDecoder; import org.springframework.util.Assert; +import org.springframework.util.StringUtils; /** * An {@link AuthenticationProvider} that authenticates an OIDC Logout Token; namely @@ -56,9 +64,27 @@ final class OidcBackChannelLogoutAuthenticationProvider implements Authenticatio * Construct an {@link OidcBackChannelLogoutAuthenticationProvider} */ OidcBackChannelLogoutAuthenticationProvider() { - OidcIdTokenDecoderFactory logoutTokenDecoderFactory = new OidcIdTokenDecoderFactory(); - logoutTokenDecoderFactory.setJwtValidatorFactory(new DefaultOidcLogoutTokenValidatorFactory()); - this.logoutTokenDecoderFactory = logoutTokenDecoderFactory; + DefaultOidcLogoutTokenValidatorFactory jwtValidator = new DefaultOidcLogoutTokenValidatorFactory(); + this.logoutTokenDecoderFactory = (clientRegistration) -> { + String jwkSetUri = clientRegistration.getProviderDetails().getJwkSetUri(); + if (!StringUtils.hasText(jwkSetUri)) { + OAuth2Error oauth2Error = new OAuth2Error("missing_signature_verifier", + "Failed to find a Signature Verifier for Client Registration: '" + + clientRegistration.getRegistrationId() + + "'. Check to ensure you have configured the JwkSet URI.", + null); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + } + JOSEObjectTypeVerifier typeVerifier = new DefaultJOSEObjectTypeVerifier<>(null, + JOSEObjectType.JWT, new JOSEObjectType("logout+jwt")); + NimbusJwtDecoder decoder = NimbusJwtDecoder.withJwkSetUri(jwkSetUri) + .jwtProcessorCustomizer((processor) -> processor.setJWSTypeVerifier(typeVerifier)) + .build(); + decoder.setJwtValidator(jwtValidator.apply(clientRegistration)); + decoder.setClaimSetConverter( + new ClaimTypeConverter(OidcIdTokenDecoderFactory.createDefaultClaimTypeConverters())); + return decoder; + }; } /** diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcLogoutConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcLogoutConfigurerTests.java index 934c8bede91..eccf675f3df 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcLogoutConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcLogoutConfigurerTests.java @@ -73,6 +73,8 @@ import org.springframework.security.oauth2.core.oidc.OidcIdToken; import org.springframework.security.oauth2.core.oidc.TestOidcIdTokens; import org.springframework.security.oauth2.core.oidc.user.OidcUser; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; +import org.springframework.security.oauth2.jwt.JwsHeader; import org.springframework.security.oauth2.jwt.JwtClaimsSet; import org.springframework.security.oauth2.jwt.JwtEncoder; import org.springframework.security.oauth2.jwt.JwtEncoderParameters; @@ -513,8 +515,9 @@ String jwks() { String logoutToken(@AuthenticationPrincipal OidcUser user) { OidcLogoutToken token = TestOidcLogoutTokens.withUser(user) .audience(List.of(this.registration.getClientId())).build(); - JwtEncoderParameters parameters = JwtEncoderParameters - .from(JwtClaimsSet.builder().claims((claims) -> claims.putAll(token.getClaims())).build()); + JwsHeader header = JwsHeader.with(SignatureAlgorithm.RS256).type("logout+jwt").build(); + JwtClaimsSet claims = JwtClaimsSet.builder().claims((c) -> c.putAll(token.getClaims())).build(); + JwtEncoderParameters parameters = JwtEncoderParameters.from(header, claims); return this.encoder.encode(parameters).getTokenValue(); } @@ -523,8 +526,9 @@ String logoutTokenAll(@AuthenticationPrincipal OidcUser user) { OidcLogoutToken token = TestOidcLogoutTokens.withUser(user) .audience(List.of(this.registration.getClientId())) .claims((claims) -> claims.remove(LogoutTokenClaimNames.SID)).build(); - JwtEncoderParameters parameters = JwtEncoderParameters - .from(JwtClaimsSet.builder().claims((claims) -> claims.putAll(token.getClaims())).build()); + JwsHeader header = JwsHeader.with(SignatureAlgorithm.RS256).type("JWT").build(); + JwtClaimsSet claims = JwtClaimsSet.builder().claims((c) -> c.putAll(token.getClaims())).build(); + JwtEncoderParameters parameters = JwtEncoderParameters.from(header, claims); return this.encoder.encode(parameters).getTokenValue(); } }