Skip to content

Commit

Permalink
Jwt client authentication converter detects new key
Browse files Browse the repository at this point in the history
Closes gh-9814
  • Loading branch information
jgrandja committed Jun 16, 2021
1 parent 700bda6 commit 6fbd038
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public final class NimbusJwtClientAuthenticationParametersConverter<T extends Ab

private final Function<ClientRegistration, JWK> jwkResolver;

private final Map<String, NimbusJwsEncoder> jwsEncoders = new ConcurrentHashMap<>();
private final Map<String, JwsEncoderHolder> jwsEncoders = new ConcurrentHashMap<>();

/**
* Constructs a {@code NimbusJwtClientAuthenticationParametersConverter} using the
Expand Down Expand Up @@ -140,12 +140,16 @@ public MultiValueMap<String, String> convert(T authorizationGrantRequest) {
JoseHeader joseHeader = headersBuilder.build();
JwtClaimsSet jwtClaimsSet = claimsBuilder.build();

NimbusJwsEncoder jwsEncoder = this.jwsEncoders.computeIfAbsent(clientRegistration.getRegistrationId(),
(clientRegistrationId) -> {
JwsEncoderHolder jwsEncoderHolder = this.jwsEncoders.compute(clientRegistration.getRegistrationId(),
(clientRegistrationId, currentJwsEncoderHolder) -> {
if (currentJwsEncoderHolder != null && currentJwsEncoderHolder.getJwk().equals(jwk)) {
return currentJwsEncoderHolder;
}
JWKSource<SecurityContext> jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk));
return new NimbusJwsEncoder(jwkSource);
return new JwsEncoderHolder(new NimbusJwsEncoder(jwkSource), jwk);
});

NimbusJwsEncoder jwsEncoder = jwsEncoderHolder.getJwsEncoder();
Jwt jws = jwsEncoder.encode(joseHeader, jwtClaimsSet);

MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
Expand Down Expand Up @@ -180,4 +184,25 @@ else if (KeyType.OCT.equals(jwk.getKeyType())) {
return jwsAlgorithm;
}

private static final class JwsEncoderHolder {

private final NimbusJwsEncoder jwsEncoder;

private final JWK jwk;

private JwsEncoderHolder(NimbusJwsEncoder jwsEncoder, JWK jwk) {
this.jwsEncoder = jwsEncoder;
this.jwk = jwk;
}

private NimbusJwsEncoder getJwsEncoder() {
return this.jwsEncoder;
}

private JWK getJwk() {
return this.jwk;
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@

package org.springframework.security.oauth2.client.endpoint;

import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import java.util.Collections;
import java.util.UUID;
import java.util.function.Function;

import com.nimbusds.jose.jwk.JWK;
Expand All @@ -42,6 +47,7 @@
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verifyNoInteractions;
Expand Down Expand Up @@ -172,4 +178,54 @@ public void convertWhenClientSecretJwtClientAuthenticationMethodThenCustomized()
assertThat(jws.getExpiresAt()).isNotNull();
}

// gh-9814
@Test
public void convertWhenClientKeyChangesThenNewKeyUsed() throws Exception {
// @formatter:off
ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials()
.clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT)
.build();
// @formatter:on

RSAKey rsaJwk1 = TestJwks.DEFAULT_RSA_JWK;
given(this.jwkResolver.apply(eq(clientRegistration))).willReturn(rsaJwk1);

OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(
clientRegistration);
MultiValueMap<String, String> parameters = this.converter.convert(clientCredentialsGrantRequest);

String encodedJws = parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION);
NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withPublicKey(rsaJwk1.toRSAPublicKey()).build();
jwtDecoder.decode(encodedJws);

RSAKey rsaJwk2 = generateRsaJwk();
given(this.jwkResolver.apply(eq(clientRegistration))).willReturn(rsaJwk2);

parameters = this.converter.convert(clientCredentialsGrantRequest);

encodedJws = parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION);
jwtDecoder = NimbusJwtDecoder.withPublicKey(rsaJwk2.toRSAPublicKey()).build();
jwtDecoder.decode(encodedJws);
}

private static RSAKey generateRsaJwk() {
KeyPair keyPair;
try {
KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA");
keyPairGenerator.initialize(2048);
keyPair = keyPairGenerator.generateKeyPair();
}
catch (Exception ex) {
throw new IllegalStateException(ex);
}
RSAPublicKey publicKey = (RSAPublicKey) keyPair.getPublic();
RSAPrivateKey privateKey = (RSAPrivateKey) keyPair.getPrivate();
// @formatter:off
return new RSAKey.Builder(publicKey)
.privateKey(privateKey)
.keyID(UUID.randomUUID().toString())
.build();
// @formatter:on
}

}

0 comments on commit 6fbd038

Please sign in to comment.