diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java index b7c80f17b51..68bba82d5c1 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java @@ -80,7 +80,7 @@ public final class NimbusJwtClientAuthenticationParametersConverter jwkResolver; - private final Map jwsEncoders = new ConcurrentHashMap<>(); + private final Map jwsEncoders = new ConcurrentHashMap<>(); /** * Constructs a {@code NimbusJwtClientAuthenticationParametersConverter} using the @@ -140,12 +140,16 @@ public MultiValueMap convert(T authorizationGrantRequest) { JoseHeader joseHeader = headersBuilder.build(); JwtClaimsSet jwtClaimsSet = claimsBuilder.build(); - NimbusJwsEncoder jwsEncoder = this.jwsEncoders.computeIfAbsent(clientRegistration.getRegistrationId(), - (clientRegistrationId) -> { + JwsEncoderHolder jwsEncoderHolder = this.jwsEncoders.compute(clientRegistration.getRegistrationId(), + (clientRegistrationId, currentJwsEncoderHolder) -> { + if (currentJwsEncoderHolder != null && currentJwsEncoderHolder.getJwk().equals(jwk)) { + return currentJwsEncoderHolder; + } JWKSource jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk)); - return new NimbusJwsEncoder(jwkSource); + return new JwsEncoderHolder(new NimbusJwsEncoder(jwkSource), jwk); }); + NimbusJwsEncoder jwsEncoder = jwsEncoderHolder.getJwsEncoder(); Jwt jws = jwsEncoder.encode(joseHeader, jwtClaimsSet); MultiValueMap parameters = new LinkedMultiValueMap<>(); @@ -180,4 +184,25 @@ else if (KeyType.OCT.equals(jwk.getKeyType())) { return jwsAlgorithm; } + private static final class JwsEncoderHolder { + + private final NimbusJwsEncoder jwsEncoder; + + private final JWK jwk; + + private JwsEncoderHolder(NimbusJwsEncoder jwsEncoder, JWK jwk) { + this.jwsEncoder = jwsEncoder; + this.jwk = jwk; + } + + private NimbusJwsEncoder getJwsEncoder() { + return this.jwsEncoder; + } + + private JWK getJwk() { + return this.jwk; + } + + } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverterTests.java index eabeef00c2f..ae8c41a316b 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverterTests.java @@ -16,7 +16,12 @@ package org.springframework.security.oauth2.client.endpoint; +import java.security.KeyPair; +import java.security.KeyPairGenerator; +import java.security.interfaces.RSAPrivateKey; +import java.security.interfaces.RSAPublicKey; import java.util.Collections; +import java.util.UUID; import java.util.function.Function; import com.nimbusds.jose.jwk.JWK; @@ -42,6 +47,7 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verifyNoInteractions; @@ -172,4 +178,54 @@ public void convertWhenClientSecretJwtClientAuthenticationMethodThenCustomized() assertThat(jws.getExpiresAt()).isNotNull(); } + // gh-9814 + @Test + public void convertWhenClientKeyChangesThenNewKeyUsed() throws Exception { + // @formatter:off + ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials() + .clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT) + .build(); + // @formatter:on + + RSAKey rsaJwk1 = TestJwks.DEFAULT_RSA_JWK; + given(this.jwkResolver.apply(eq(clientRegistration))).willReturn(rsaJwk1); + + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( + clientRegistration); + MultiValueMap parameters = this.converter.convert(clientCredentialsGrantRequest); + + String encodedJws = parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION); + NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withPublicKey(rsaJwk1.toRSAPublicKey()).build(); + jwtDecoder.decode(encodedJws); + + RSAKey rsaJwk2 = generateRsaJwk(); + given(this.jwkResolver.apply(eq(clientRegistration))).willReturn(rsaJwk2); + + parameters = this.converter.convert(clientCredentialsGrantRequest); + + encodedJws = parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION); + jwtDecoder = NimbusJwtDecoder.withPublicKey(rsaJwk2.toRSAPublicKey()).build(); + jwtDecoder.decode(encodedJws); + } + + private static RSAKey generateRsaJwk() { + KeyPair keyPair; + try { + KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA"); + keyPairGenerator.initialize(2048); + keyPair = keyPairGenerator.generateKeyPair(); + } + catch (Exception ex) { + throw new IllegalStateException(ex); + } + RSAPublicKey publicKey = (RSAPublicKey) keyPair.getPublic(); + RSAPrivateKey privateKey = (RSAPrivateKey) keyPair.getPrivate(); + // @formatter:off + return new RSAKey.Builder(publicKey) + .privateKey(privateKey) + .keyID(UUID.randomUUID().toString()) + .build(); + // @formatter:on + } + }