diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoder.java index 72d799ece53..6ea9aed12b6 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoder.java @@ -25,6 +25,7 @@ import java.util.Set; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; +import java.util.function.BiConsumer; import java.util.function.Function; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -86,6 +87,9 @@ public final class NimbusJwsEncoder implements JwtEncoder { private final Function jwkSelector; + private BiConsumer jwtCustomizer = (headers, claims) -> { + }; + /** * Constructs a {@code NimbusJwsEncoder} using the provided parameters. * @param jwkSelector the {@code com.nimbusds.jose.jwk.JWK} selector @@ -95,6 +99,17 @@ public NimbusJwsEncoder(Function jwkSelector) { this.jwkSelector = jwkSelector; } + /** + * Sets the {@link Jwt} customizer to be provided the {@link JoseHeader.Builder} and + * {@link JwtClaimsSet.Builder} allowing for further customizations. + * @param jwtCustomizer the {@link Jwt} customizer to be provided the + * {@link JoseHeader.Builder} and {@link JwtClaimsSet.Builder} + */ + public void setJwtCustomizer(BiConsumer jwtCustomizer) { + Assert.notNull(jwtCustomizer, "jwtCustomizer cannot be null"); + this.jwtCustomizer = jwtCustomizer; + } + @Override public Jwt encode(JoseHeader headers, JwtClaimsSet claims) throws JwtEncodingException { Assert.notNull(headers, "headers cannot be null"); @@ -121,18 +136,19 @@ else if (!StringUtils.hasText(jwk.getKeyID())) { }); // @formatter:off - headers = JoseHeader.from(headers) + JoseHeader.Builder headersBuilder = JoseHeader.from(headers) .type(JOSEObjectType.JWT.getType()) - .keyId(jwk.getKeyID()) - .build(); + .keyId(jwk.getKeyID()); + JwtClaimsSet.Builder claimsBuilder = JwtClaimsSet.from(claims) + .id(UUID.randomUUID().toString()); // @formatter:on - JWSHeader jwsHeader = JWS_HEADER_CONVERTER.convert(headers); - // @formatter:off - claims = JwtClaimsSet.from(claims) - .id(UUID.randomUUID().toString()) - .build(); - // @formatter:on + this.jwtCustomizer.accept(headersBuilder, claimsBuilder); + + headers = headersBuilder.build(); + claims = claimsBuilder.build(); + + JWSHeader jwsHeader = JWS_HEADER_CONVERTER.convert(headers); JWTClaimsSet jwtClaimsSet = JWT_CLAIMS_SET_CONVERTER.convert(claims); SignedJWT signedJwt = new SignedJWT(jwsHeader, jwtClaimsSet); diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoderTests.java index 666b192e161..9eda10a0101 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoderTests.java @@ -20,6 +20,7 @@ import java.util.Arrays; import java.util.LinkedHashSet; import java.util.Set; +import java.util.function.BiConsumer; import java.util.function.Function; import java.util.function.Supplier; @@ -47,6 +48,7 @@ import static org.mockito.BDDMockito.willAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; /** * Tests for {@link NimbusJwsEncoder}. @@ -71,6 +73,12 @@ public void constructorWhenJwkSelectorNullThenThrowIllegalArgumentException() { .withMessage("jwkSelector cannot be null"); } + @Test + public void setJwtCustomizerWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.jwsEncoder.setJwtCustomizer(null)) + .withMessage("jwtCustomizer cannot be null"); + } + @Test public void encodeWhenHeadersNullThenThrowIllegalArgumentException() { JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); @@ -160,6 +168,28 @@ public void encodeWhenSuccessThenDecodes() throws Exception { jwtDecoder.decode(encodedJws.getTokenValue()); } + @Test + public void encodeWhenCustomizerSetThenCalled() { + // @formatter:off + RSAKey rsaJwk = new RSAKey.Builder(TestKeys.DEFAULT_PUBLIC_KEY) + .privateKey(TestKeys.DEFAULT_PRIVATE_KEY) + .keyID("keyId") + .build(); + // @formatter:on + + given(this.jwkSelector.apply(any())).willReturn(rsaJwk); + + BiConsumer jwtCustomizer = mock(BiConsumer.class); + this.jwsEncoder.setJwtCustomizer(jwtCustomizer); + + JoseHeader joseHeader = TestJoseHeaders.joseHeader().build(); + JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); + + this.jwsEncoder.encode(joseHeader, jwtClaimsSet); + + verify(jwtCustomizer).accept(any(JoseHeader.Builder.class), any(JwtClaimsSet.Builder.class)); + } + @Test public void defaultJwkSelectorConstructorWhenJwkSetProviderNullThenThrowIllegalArgumentException() { assertThatIllegalArgumentException().isThrownBy(() -> new NimbusJwsEncoder.DefaultJwkSelector(null))