Skip to content

Commit

Permalink
Support customizing Jwt claims and headers
Browse files Browse the repository at this point in the history
  • Loading branch information
jgrandja committed Jan 8, 2021
1 parent 4fe2d52 commit 8368d84
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -86,6 +87,9 @@ public final class NimbusJwsEncoder implements JwtEncoder {

private final Function<JoseHeader, JWK> jwkSelector;

private BiConsumer<JoseHeader.Builder, JwtClaimsSet.Builder> jwtCustomizer = (headers, claims) -> {
};

/**
* Constructs a {@code NimbusJwsEncoder} using the provided parameters.
* @param jwkSelector the {@code com.nimbusds.jose.jwk.JWK} selector
Expand All @@ -95,6 +99,17 @@ public NimbusJwsEncoder(Function<JoseHeader, JWK> jwkSelector) {
this.jwkSelector = jwkSelector;
}

/**
* Sets the {@link Jwt} customizer to be provided the {@link JoseHeader.Builder} and
* {@link JwtClaimsSet.Builder} allowing for further customizations.
* @param jwtCustomizer the {@link Jwt} customizer to be provided the
* {@link JoseHeader.Builder} and {@link JwtClaimsSet.Builder}
*/
public void setJwtCustomizer(BiConsumer<JoseHeader.Builder, JwtClaimsSet.Builder> jwtCustomizer) {
Assert.notNull(jwtCustomizer, "jwtCustomizer cannot be null");
this.jwtCustomizer = jwtCustomizer;
}

@Override
public Jwt encode(JoseHeader headers, JwtClaimsSet claims) throws JwtEncodingException {
Assert.notNull(headers, "headers cannot be null");
Expand All @@ -121,18 +136,19 @@ else if (!StringUtils.hasText(jwk.getKeyID())) {
});

// @formatter:off
headers = JoseHeader.from(headers)
JoseHeader.Builder headersBuilder = JoseHeader.from(headers)
.type(JOSEObjectType.JWT.getType())
.keyId(jwk.getKeyID())
.build();
.keyId(jwk.getKeyID());
JwtClaimsSet.Builder claimsBuilder = JwtClaimsSet.from(claims)
.id(UUID.randomUUID().toString());
// @formatter:on
JWSHeader jwsHeader = JWS_HEADER_CONVERTER.convert(headers);

// @formatter:off
claims = JwtClaimsSet.from(claims)
.id(UUID.randomUUID().toString())
.build();
// @formatter:on
this.jwtCustomizer.accept(headersBuilder, claimsBuilder);

headers = headersBuilder.build();
claims = claimsBuilder.build();

JWSHeader jwsHeader = JWS_HEADER_CONVERTER.convert(headers);
JWTClaimsSet jwtClaimsSet = JWT_CLAIMS_SET_CONVERTER.convert(claims);

SignedJWT signedJwt = new SignedJWT(jwsHeader, jwtClaimsSet);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.function.Supplier;

Expand Down Expand Up @@ -47,6 +48,7 @@
import static org.mockito.BDDMockito.willAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;

/**
* Tests for {@link NimbusJwsEncoder}.
Expand All @@ -71,6 +73,12 @@ public void constructorWhenJwkSelectorNullThenThrowIllegalArgumentException() {
.withMessage("jwkSelector cannot be null");
}

@Test
public void setJwtCustomizerWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.jwsEncoder.setJwtCustomizer(null))
.withMessage("jwtCustomizer cannot be null");
}

@Test
public void encodeWhenHeadersNullThenThrowIllegalArgumentException() {
JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build();
Expand Down Expand Up @@ -160,6 +168,28 @@ public void encodeWhenSuccessThenDecodes() throws Exception {
jwtDecoder.decode(encodedJws.getTokenValue());
}

@Test
public void encodeWhenCustomizerSetThenCalled() {
// @formatter:off
RSAKey rsaJwk = new RSAKey.Builder(TestKeys.DEFAULT_PUBLIC_KEY)
.privateKey(TestKeys.DEFAULT_PRIVATE_KEY)
.keyID("keyId")
.build();
// @formatter:on

given(this.jwkSelector.apply(any())).willReturn(rsaJwk);

BiConsumer<JoseHeader.Builder, JwtClaimsSet.Builder> jwtCustomizer = mock(BiConsumer.class);
this.jwsEncoder.setJwtCustomizer(jwtCustomizer);

JoseHeader joseHeader = TestJoseHeaders.joseHeader().build();
JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build();

this.jwsEncoder.encode(joseHeader, jwtClaimsSet);

verify(jwtCustomizer).accept(any(JoseHeader.Builder.class), any(JwtClaimsSet.Builder.class));
}

@Test
public void defaultJwkSelectorConstructorWhenJwkSetProviderNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> new NimbusJwsEncoder.DefaultJwkSelector(null))
Expand Down

0 comments on commit 8368d84

Please sign in to comment.