From 7f37d18987b6e09f6b2f6e95296fdad1dbd56483 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Mon, 16 Nov 2020 19:56:45 -0500 Subject: [PATCH 01/20] Introduce JwtEncoder Closes gh-9208 --- ...ientAuthenticationParametersConverter.java | 12 +- ...uthenticationParametersConverterTests.java | 1 + .../security/oauth2/jwt}/JoseHeader.java | 84 ++- .../security/oauth2/jwt}/JoseHeaderNames.java | 43 +- .../security/oauth2/jwt}/JwtClaimsSet.java | 52 +- .../security/oauth2/jwt/JwtEncoder.java | 59 ++ .../oauth2/jwt}/JwtEncodingException.java | 27 +- .../oauth2/jwt}/NimbusJwsEncoder.java | 37 +- .../security/oauth2/jwt}/JoseHeaderTests.java | 17 +- .../oauth2/jwt}/JwtClaimsSetTests.java | 17 +- .../oauth2/jwt/NimbusJweEncoderTests.java | 543 ++++++++++++++++++ .../oauth2/jwt}/NimbusJwsEncoderTests.java | 19 +- .../security/oauth2/jwt}/TestJoseHeaders.java | 23 +- .../oauth2/jwt}/TestJwtClaimsSets.java | 21 +- 14 files changed, 702 insertions(+), 253 deletions(-) rename oauth2/{oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint => oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt}/JoseHeader.java (84%) rename oauth2/{oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint => oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt}/JoseHeaderNames.java (75%) rename oauth2/{oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint => oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt}/JwtClaimsSet.java (77%) create mode 100644 oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncoder.java rename oauth2/{oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint => oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt}/JwtEncodingException.java (55%) rename oauth2/{oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint => oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt}/NimbusJwsEncoder.java (89%) rename oauth2/{oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint => oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt}/JoseHeaderTests.java (88%) rename oauth2/{oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint => oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt}/JwtClaimsSetTests.java (85%) create mode 100644 oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java rename oauth2/{oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint => oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt}/NimbusJwsEncoderTests.java (94%) rename oauth2/{oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint => oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt}/TestJoseHeaders.java (67%) rename oauth2/{oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint => oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt}/TestJwtClaimsSets.java (63%) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java index 68bba82d5c1..65b5ffcfb79 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java @@ -40,7 +40,11 @@ import org.springframework.security.oauth2.jose.jws.JwsAlgorithm; import org.springframework.security.oauth2.jose.jws.MacAlgorithm; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; +import org.springframework.security.oauth2.jwt.JoseHeader; import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtClaimsSet; +import org.springframework.security.oauth2.jwt.JwtEncoder; +import org.springframework.security.oauth2.jwt.NimbusJwsEncoder; import org.springframework.util.Assert; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; @@ -149,7 +153,7 @@ public MultiValueMap convert(T authorizationGrantRequest) { return new JwsEncoderHolder(new NimbusJwsEncoder(jwkSource), jwk); }); - NimbusJwsEncoder jwsEncoder = jwsEncoderHolder.getJwsEncoder(); + JwtEncoder jwsEncoder = jwsEncoderHolder.getJwsEncoder(); Jwt jws = jwsEncoder.encode(joseHeader, jwtClaimsSet); MultiValueMap parameters = new LinkedMultiValueMap<>(); @@ -186,16 +190,16 @@ else if (KeyType.OCT.equals(jwk.getKeyType())) { private static final class JwsEncoderHolder { - private final NimbusJwsEncoder jwsEncoder; + private final JwtEncoder jwsEncoder; private final JWK jwk; - private JwsEncoderHolder(NimbusJwsEncoder jwsEncoder, JWK jwk) { + private JwsEncoderHolder(JwtEncoder jwsEncoder, JWK jwk) { this.jwsEncoder = jwsEncoder; this.jwk = jwk; } - private NimbusJwsEncoder getJwsEncoder() { + private JwtEncoder getJwsEncoder() { return this.jwsEncoder; } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverterTests.java index 42fb1dc9f84..d916424582b 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverterTests.java @@ -38,6 +38,7 @@ import org.springframework.security.oauth2.jose.TestJwks; import org.springframework.security.oauth2.jose.jws.MacAlgorithm; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; +import org.springframework.security.oauth2.jwt.JoseHeaderNames; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.JwtClaimNames; import org.springframework.security.oauth2.jwt.NimbusJwtDecoder; diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JoseHeader.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JoseHeader.java similarity index 84% rename from oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JoseHeader.java rename to oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JoseHeader.java index b148e670a3e..a386d0feb54 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JoseHeader.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JoseHeader.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.security.oauth2.client.endpoint; +package org.springframework.security.oauth2.jwt; import java.net.URL; import java.util.Collections; @@ -26,24 +26,8 @@ import org.springframework.security.oauth2.core.converter.ClaimConversionService; import org.springframework.security.oauth2.jose.JwaAlgorithm; -import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.util.Assert; -/* - * NOTE: - * This originated in gh-9208 (JwtEncoder), - * which is required to realize the feature in gh-8175 (JWT Client Authentication). - * However, we decided not to merge gh-9208 as part of the 5.5.0 release - * and instead packaged it up privately with the gh-8175 feature. - * We MAY merge gh-9208 in a later release but that is yet to be determined. - * - * gh-9208 Introduce JwtEncoder - * https://github.com/spring-projects/spring-security/pull/9208 - * - * gh-8175 Support JWT for Client Authentication - * https://github.com/spring-projects/spring-security/issues/8175 - */ - /** * The JOSE header is a JSON object representing the header parameters of a JSON Web * Token, whether the JWT is a JWS or JWE, that describe the cryptographic operations @@ -51,7 +35,7 @@ * * @author Anoop Garlapati * @author Joe Grandja - * @since 5.5 + * @since 5.6 * @see Jwt * @see JWT JOSE * Header @@ -60,7 +44,7 @@ * @see JWE JOSE * Header */ -final class JoseHeader { +public final class JoseHeader { private final Map headers; @@ -74,7 +58,7 @@ private JoseHeader(Map headers) { * @return the {@link JwaAlgorithm} */ @SuppressWarnings("unchecked") - T getAlgorithm() { + public T getAlgorithm() { return (T) getHeader(JoseHeaderNames.ALG); } @@ -84,7 +68,7 @@ T getAlgorithm() { * the JWE. * @return the JWK Set URL */ - URL getJwkSetUrl() { + public URL getJwkSetUrl() { return getHeader(JoseHeaderNames.JKU); } @@ -93,7 +77,7 @@ URL getJwkSetUrl() { * to digitally sign the JWS or encrypt the JWE. * @return the JSON Web Key */ - Map getJwk() { + public Map getJwk() { return getHeader(JoseHeaderNames.JWK); } @@ -102,7 +86,7 @@ Map getJwk() { * or JWE. * @return the key ID */ - String getKeyId() { + public String getKeyId() { return getHeader(JoseHeaderNames.KID); } @@ -112,7 +96,7 @@ String getKeyId() { * the JWS or encrypt the JWE. * @return the X.509 URL */ - URL getX509Url() { + public URL getX509Url() { return getHeader(JoseHeaderNames.X5U); } @@ -124,7 +108,7 @@ URL getX509Url() { * {@code List} is a Base64-encoded DER PKIX certificate value. * @return the X.509 certificate chain */ - List getX509CertificateChain() { + public List getX509CertificateChain() { return getHeader(JoseHeaderNames.X5C); } @@ -134,7 +118,7 @@ List getX509CertificateChain() { * corresponding to the key used to digitally sign the JWS or encrypt the JWE. * @return the X.509 certificate SHA-1 thumbprint */ - String getX509SHA1Thumbprint() { + public String getX509SHA1Thumbprint() { return getHeader(JoseHeaderNames.X5T); } @@ -144,7 +128,7 @@ String getX509SHA1Thumbprint() { * corresponding to the key used to digitally sign the JWS or encrypt the JWE. * @return the X.509 certificate SHA-256 thumbprint */ - String getX509SHA256Thumbprint() { + public String getX509SHA256Thumbprint() { return getHeader(JoseHeaderNames.X5T_S256); } @@ -152,7 +136,7 @@ String getX509SHA256Thumbprint() { * Returns the type header that declares the media type of the JWS/JWE. * @return the type header */ - String getType() { + public String getType() { return getHeader(JoseHeaderNames.TYP); } @@ -161,7 +145,7 @@ String getType() { * (the payload). * @return the content type header */ - String getContentType() { + public String getContentType() { return getHeader(JoseHeaderNames.CTY); } @@ -170,7 +154,7 @@ String getContentType() { * specifications are being used that MUST be understood and processed. * @return the critical headers */ - Set getCritical() { + public Set getCritical() { return getHeader(JoseHeaderNames.CRIT); } @@ -178,7 +162,7 @@ Set getCritical() { * Returns the headers. * @return the headers */ - Map getHeaders() { + public Map getHeaders() { return this.headers; } @@ -189,7 +173,7 @@ Map getHeaders() { * @return the header value */ @SuppressWarnings("unchecked") - T getHeader(String name) { + public T getHeader(String name) { Assert.hasText(name, "name cannot be empty"); return (T) getHeaders().get(name); } @@ -199,7 +183,7 @@ T getHeader(String name) { * @param jwaAlgorithm the {@link JwaAlgorithm} * @return the {@link Builder} */ - static Builder withAlgorithm(JwaAlgorithm jwaAlgorithm) { + public static Builder withAlgorithm(JwaAlgorithm jwaAlgorithm) { return new Builder(jwaAlgorithm); } @@ -208,16 +192,16 @@ static Builder withAlgorithm(JwaAlgorithm jwaAlgorithm) { * @param headers the headers * @return the {@link Builder} */ - static Builder from(JoseHeader headers) { + public static Builder from(JoseHeader headers) { return new Builder(headers); } /** * A builder for {@link JoseHeader}. */ - static final class Builder { + public static final class Builder { - final Map headers = new HashMap<>(); + private final Map headers = new HashMap<>(); private Builder(JwaAlgorithm jwaAlgorithm) { algorithm(jwaAlgorithm); @@ -234,7 +218,7 @@ private Builder(JoseHeader headers) { * @param jwaAlgorithm the {@link JwaAlgorithm} * @return the {@link Builder} */ - Builder algorithm(JwaAlgorithm jwaAlgorithm) { + public Builder algorithm(JwaAlgorithm jwaAlgorithm) { Assert.notNull(jwaAlgorithm, "jwaAlgorithm cannot be null"); return header(JoseHeaderNames.ALG, jwaAlgorithm); } @@ -246,7 +230,7 @@ Builder algorithm(JwaAlgorithm jwaAlgorithm) { * @param jwkSetUrl the JWK Set URL * @return the {@link Builder} */ - Builder jwkSetUrl(String jwkSetUrl) { + public Builder jwkSetUrl(String jwkSetUrl) { return header(JoseHeaderNames.JKU, convertAsURL(JoseHeaderNames.JKU, jwkSetUrl)); } @@ -256,7 +240,7 @@ Builder jwkSetUrl(String jwkSetUrl) { * @param jwk the JSON Web Key * @return the {@link Builder} */ - Builder jwk(Map jwk) { + public Builder jwk(Map jwk) { return header(JoseHeaderNames.JWK, jwk); } @@ -266,7 +250,7 @@ Builder jwk(Map jwk) { * @param keyId the key ID * @return the {@link Builder} */ - Builder keyId(String keyId) { + public Builder keyId(String keyId) { return header(JoseHeaderNames.KID, keyId); } @@ -277,7 +261,7 @@ Builder keyId(String keyId) { * @param x509Url the X.509 URL * @return the {@link Builder} */ - Builder x509Url(String x509Url) { + public Builder x509Url(String x509Url) { return header(JoseHeaderNames.X5U, convertAsURL(JoseHeaderNames.X5U, x509Url)); } @@ -290,7 +274,7 @@ Builder x509Url(String x509Url) { * @param x509CertificateChain the X.509 certificate chain * @return the {@link Builder} */ - Builder x509CertificateChain(List x509CertificateChain) { + public Builder x509CertificateChain(List x509CertificateChain) { return header(JoseHeaderNames.X5C, x509CertificateChain); } @@ -301,7 +285,7 @@ Builder x509CertificateChain(List x509CertificateChain) { * @param x509SHA1Thumbprint the X.509 certificate SHA-1 thumbprint * @return the {@link Builder} */ - Builder x509SHA1Thumbprint(String x509SHA1Thumbprint) { + public Builder x509SHA1Thumbprint(String x509SHA1Thumbprint) { return header(JoseHeaderNames.X5T, x509SHA1Thumbprint); } @@ -312,7 +296,7 @@ Builder x509SHA1Thumbprint(String x509SHA1Thumbprint) { * @param x509SHA256Thumbprint the X.509 certificate SHA-256 thumbprint * @return the {@link Builder} */ - Builder x509SHA256Thumbprint(String x509SHA256Thumbprint) { + public Builder x509SHA256Thumbprint(String x509SHA256Thumbprint) { return header(JoseHeaderNames.X5T_S256, x509SHA256Thumbprint); } @@ -321,7 +305,7 @@ Builder x509SHA256Thumbprint(String x509SHA256Thumbprint) { * @param type the type header * @return the {@link Builder} */ - Builder type(String type) { + public Builder type(String type) { return header(JoseHeaderNames.TYP, type); } @@ -331,7 +315,7 @@ Builder type(String type) { * @param contentType the content type header * @return the {@link Builder} */ - Builder contentType(String contentType) { + public Builder contentType(String contentType) { return header(JoseHeaderNames.CTY, contentType); } @@ -341,7 +325,7 @@ Builder contentType(String contentType) { * @param headerNames the critical header names * @return the {@link Builder} */ - Builder critical(Set headerNames) { + public Builder critical(Set headerNames) { return header(JoseHeaderNames.CRIT, headerNames); } @@ -351,7 +335,7 @@ Builder critical(Set headerNames) { * @param value the header value * @return the {@link Builder} */ - Builder header(String name, Object value) { + public Builder header(String name, Object value) { Assert.hasText(name, "name cannot be empty"); Assert.notNull(value, "value cannot be null"); this.headers.put(name, value); @@ -364,7 +348,7 @@ Builder header(String name, Object value) { * @param headersConsumer a {@code Consumer} of the headers * @return the {@link Builder} */ - Builder headers(Consumer> headersConsumer) { + public Builder headers(Consumer> headersConsumer) { headersConsumer.accept(this.headers); return this; } @@ -373,7 +357,7 @@ Builder headers(Consumer> headersConsumer) { * Builds a new {@link JoseHeader}. * @return a {@link JoseHeader} */ - JoseHeader build() { + public JoseHeader build() { Assert.notEmpty(this.headers, "headers cannot be empty"); return new JoseHeader(this.headers); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JoseHeaderNames.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JoseHeaderNames.java similarity index 75% rename from oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JoseHeaderNames.java rename to oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JoseHeaderNames.java index 41abd7eeba0..9e5f04f8673 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JoseHeaderNames.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JoseHeaderNames.java @@ -14,22 +14,7 @@ * limitations under the License. */ -package org.springframework.security.oauth2.client.endpoint; - -/* - * NOTE: - * This originated in gh-9208 (JwtEncoder), - * which is required to realize the feature in gh-8175 (JWT Client Authentication). - * However, we decided not to merge gh-9208 as part of the 5.5.0 release - * and instead packaged it up privately with the gh-8175 feature. - * We MAY merge gh-9208 in a later release but that is yet to be determined. - * - * gh-9208 Introduce JwtEncoder - * https://github.com/spring-projects/spring-security/pull/9208 - * - * gh-8175 Support JWT for Client Authentication - * https://github.com/spring-projects/spring-security/issues/8175 - */ +package org.springframework.security.oauth2.jwt; /** * The Registered Header Parameter Names defined by the JSON Web Token (JWT), JSON Web @@ -38,7 +23,7 @@ * * @author Anoop Garlapati * @author Joe Grandja - * @since 5.5 + * @since 5.6 * @see JoseHeader * @see JWT JOSE * Header @@ -47,53 +32,53 @@ * @see JWE JOSE * Header */ -final class JoseHeaderNames { +public final class JoseHeaderNames { /** * {@code alg} - the algorithm header identifies the cryptographic algorithm used to * secure a JWS or JWE */ - static final String ALG = "alg"; + public static final String ALG = "alg"; /** * {@code jku} - the JWK Set URL header is a URI that refers to a resource for a set * of JSON-encoded public keys, one of which corresponds to the key used to digitally * sign a JWS or encrypt a JWE */ - static final String JKU = "jku"; + public static final String JKU = "jku"; /** * {@code jwk} - the JSON Web Key header is the public key that corresponds to the key * used to digitally sign a JWS or encrypt a JWE */ - static final String JWK = "jwk"; + public static final String JWK = "jwk"; /** * {@code kid} - the key ID header is a hint indicating which key was used to secure a * JWS or JWE */ - static final String KID = "kid"; + public static final String KID = "kid"; /** * {@code x5u} - the X.509 URL header is a URI that refers to a resource for the X.509 * public key certificate or certificate chain corresponding to the key used to * digitally sign a JWS or encrypt a JWE */ - static final String X5U = "x5u"; + public static final String X5U = "x5u"; /** * {@code x5c} - the X.509 certificate chain header contains the X.509 public key * certificate or certificate chain corresponding to the key used to digitally sign a * JWS or encrypt a JWE */ - static final String X5C = "x5c"; + public static final String X5C = "x5c"; /** * {@code x5t} - the X.509 certificate SHA-1 thumbprint header is a base64url-encoded * SHA-1 thumbprint (a.k.a. digest) of the DER encoding of the X.509 certificate * corresponding to the key used to digitally sign a JWS or encrypt a JWE */ - static final String X5T = "x5t"; + public static final String X5T = "x5t"; /** * {@code x5t#S256} - the X.509 certificate SHA-256 thumbprint header is a @@ -101,25 +86,25 @@ final class JoseHeaderNames { * X.509 certificate corresponding to the key used to digitally sign a JWS or encrypt * a JWE */ - static final String X5T_S256 = "x5t#S256"; + public static final String X5T_S256 = "x5t#S256"; /** * {@code typ} - the type header is used by JWS/JWE applications to declare the media * type of a JWS/JWE */ - static final String TYP = "typ"; + public static final String TYP = "typ"; /** * {@code cty} - the content type header is used by JWS/JWE applications to declare * the media type of the secured content (the payload) */ - static final String CTY = "cty"; + public static final String CTY = "cty"; /** * {@code crit} - the critical header indicates that extensions to the JWS/JWE/JWA * specifications are being used that MUST be understood and processed */ - static final String CRIT = "crit"; + public static final String CRIT = "crit"; private JoseHeaderNames() { } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtClaimsSet.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtClaimsSet.java similarity index 77% rename from oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtClaimsSet.java rename to oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtClaimsSet.java index d383c04b661..eb70bf60f65 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtClaimsSet.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtClaimsSet.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.security.oauth2.client.endpoint; +package org.springframework.security.oauth2.jwt; import java.net.URL; import java.time.Instant; @@ -25,39 +25,21 @@ import java.util.function.Consumer; import org.springframework.security.oauth2.core.converter.ClaimConversionService; -import org.springframework.security.oauth2.jwt.Jwt; -import org.springframework.security.oauth2.jwt.JwtClaimAccessor; -import org.springframework.security.oauth2.jwt.JwtClaimNames; import org.springframework.util.Assert; -/* - * NOTE: - * This originated in gh-9208 (JwtEncoder), - * which is required to realize the feature in gh-8175 (JWT Client Authentication). - * However, we decided not to merge gh-9208 as part of the 5.5.0 release - * and instead packaged it up privately with the gh-8175 feature. - * We MAY merge gh-9208 in a later release but that is yet to be determined. - * - * gh-9208 Introduce JwtEncoder - * https://github.com/spring-projects/spring-security/pull/9208 - * - * gh-8175 Support JWT for Client Authentication - * https://github.com/spring-projects/spring-security/issues/8175 - */ - /** * The {@link Jwt JWT} Claims Set is a JSON object representing the claims conveyed by a * JSON Web Token. * * @author Anoop Garlapati * @author Joe Grandja - * @since 5.5 + * @since 5.6 * @see Jwt * @see JwtClaimAccessor * @see JWT Claims * Set */ -final class JwtClaimsSet implements JwtClaimAccessor { +public final class JwtClaimsSet implements JwtClaimAccessor { private final Map claims; @@ -74,7 +56,7 @@ public Map getClaims() { * Returns a new {@link Builder}. * @return the {@link Builder} */ - static Builder builder() { + public static Builder builder() { return new Builder(); } @@ -83,16 +65,16 @@ static Builder builder() { * @param claims a JWT claims set * @return the {@link Builder} */ - static Builder from(JwtClaimsSet claims) { + public static Builder from(JwtClaimsSet claims) { return new Builder(claims); } /** * A builder for {@link JwtClaimsSet}. */ - static final class Builder { + public static final class Builder { - final Map claims = new HashMap<>(); + private final Map claims = new HashMap<>(); private Builder() { } @@ -108,7 +90,7 @@ private Builder(JwtClaimsSet claims) { * @param issuer the issuer identifier * @return the {@link Builder} */ - Builder issuer(String issuer) { + public Builder issuer(String issuer) { return claim(JwtClaimNames.ISS, issuer); } @@ -118,7 +100,7 @@ Builder issuer(String issuer) { * @param subject the subject identifier * @return the {@link Builder} */ - Builder subject(String subject) { + public Builder subject(String subject) { return claim(JwtClaimNames.SUB, subject); } @@ -128,7 +110,7 @@ Builder subject(String subject) { * @param audience the audience that this JWT is intended for * @return the {@link Builder} */ - Builder audience(List audience) { + public Builder audience(List audience) { return claim(JwtClaimNames.AUD, audience); } @@ -139,7 +121,7 @@ Builder audience(List audience) { * processing * @return the {@link Builder} */ - Builder expiresAt(Instant expiresAt) { + public Builder expiresAt(Instant expiresAt) { return claim(JwtClaimNames.EXP, expiresAt); } @@ -150,7 +132,7 @@ Builder expiresAt(Instant expiresAt) { * processing * @return the {@link Builder} */ - Builder notBefore(Instant notBefore) { + public Builder notBefore(Instant notBefore) { return claim(JwtClaimNames.NBF, notBefore); } @@ -160,7 +142,7 @@ Builder notBefore(Instant notBefore) { * @param issuedAt the time at which the JWT was issued * @return the {@link Builder} */ - Builder issuedAt(Instant issuedAt) { + public Builder issuedAt(Instant issuedAt) { return claim(JwtClaimNames.IAT, issuedAt); } @@ -170,7 +152,7 @@ Builder issuedAt(Instant issuedAt) { * @param jti the unique identifier for the JWT * @return the {@link Builder} */ - Builder id(String jti) { + public Builder id(String jti) { return claim(JwtClaimNames.JTI, jti); } @@ -180,7 +162,7 @@ Builder id(String jti) { * @param value the claim value * @return the {@link Builder} */ - Builder claim(String name, Object value) { + public Builder claim(String name, Object value) { Assert.hasText(name, "name cannot be empty"); Assert.notNull(value, "value cannot be null"); this.claims.put(name, value); @@ -192,7 +174,7 @@ Builder claim(String name, Object value) { * add, replace, or remove. * @param claimsConsumer a {@code Consumer} of the claims */ - Builder claims(Consumer> claimsConsumer) { + public Builder claims(Consumer> claimsConsumer) { claimsConsumer.accept(this.claims); return this; } @@ -201,7 +183,7 @@ Builder claims(Consumer> claimsConsumer) { * Builds a new {@link JwtClaimsSet}. * @return a {@link JwtClaimsSet} */ - JwtClaimsSet build() { + public JwtClaimsSet build() { Assert.notEmpty(this.claims, "claims cannot be empty"); // The value of the 'iss' claim is a String or URL (StringOrURI). diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncoder.java new file mode 100644 index 00000000000..44d6a14a88a --- /dev/null +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncoder.java @@ -0,0 +1,59 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.jwt; + +/** + * Implementations of this interface are responsible for encoding a JSON Web Token (JWT) + * to it's compact claims representation format. + * + *

+ * JWTs may be represented using the JWS Compact Serialization format for a JSON Web + * Signature (JWS) structure or JWE Compact Serialization format for a JSON Web Encryption + * (JWE) structure. Therefore, implementors are responsible for signing a JWS and/or + * encrypting a JWE. + * + * @author Anoop Garlapati + * @author Joe Grandja + * @since 5.6 + * @see Jwt + * @see JoseHeader + * @see JwtClaimsSet + * @see JwtDecoder + * @see JSON Web Token + * (JWT) + * @see JSON Web Signature + * (JWS) + * @see JSON Web Encryption + * (JWE) + * @see JWS + * Compact Serialization + * @see JWE + * Compact Serialization + */ +@FunctionalInterface +public interface JwtEncoder { + + /** + * Encode the JWT to it's compact claims representation format. + * @param headers the JOSE header + * @param claims the JWT Claims Set + * @return a {@link Jwt} + * @throws JwtEncodingException if an error occurs while attempting to encode the JWT + */ + Jwt encode(JoseHeader headers, JwtClaimsSet claims) throws JwtEncodingException; + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtEncodingException.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncodingException.java similarity index 55% rename from oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtEncodingException.java rename to oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncodingException.java index 53c82b13bd5..9b48f5c4a2d 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/JwtEncodingException.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncodingException.java @@ -14,39 +14,22 @@ * limitations under the License. */ -package org.springframework.security.oauth2.client.endpoint; - -import org.springframework.security.oauth2.jwt.JwtException; - -/* - * NOTE: - * This originated in gh-9208 (JwtEncoder), - * which is required to realize the feature in gh-8175 (JWT Client Authentication). - * However, we decided not to merge gh-9208 as part of the 5.5.0 release - * and instead packaged it up privately with the gh-8175 feature. - * We MAY merge gh-9208 in a later release but that is yet to be determined. - * - * gh-9208 Introduce JwtEncoder - * https://github.com/spring-projects/spring-security/pull/9208 - * - * gh-8175 Support JWT for Client Authentication - * https://github.com/spring-projects/spring-security/issues/8175 - */ +package org.springframework.security.oauth2.jwt; /** * This exception is thrown when an error occurs while attempting to encode a JSON Web * Token (JWT). * * @author Joe Grandja - * @since 5.5 + * @since 5.6 */ -class JwtEncodingException extends JwtException { +public class JwtEncodingException extends JwtException { /** * Constructs a {@code JwtEncodingException} using the provided parameters. * @param message the detail message */ - JwtEncodingException(String message) { + public JwtEncodingException(String message) { super(message); } @@ -55,7 +38,7 @@ class JwtEncodingException extends JwtException { * @param message the detail message * @param cause the root cause */ - JwtEncodingException(String message, Throwable cause) { + public JwtEncodingException(String message, Throwable cause) { super(message, cause); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwsEncoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoder.java similarity index 89% rename from oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwsEncoder.java rename to oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoder.java index c6d681a6232..f0d01fb8c18 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwsEncoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoder.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.security.oauth2.client.endpoint; +package org.springframework.security.oauth2.jwt; import java.net.URI; import java.net.URL; @@ -46,38 +46,22 @@ import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.SignedJWT; -import org.springframework.security.oauth2.jwt.Jwt; -import org.springframework.security.oauth2.jwt.JwtClaimNames; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; -/* - * NOTE: - * This originated in gh-9208 (JwtEncoder), - * which is required to realize the feature in gh-8175 (JWT Client Authentication). - * However, we decided not to merge gh-9208 as part of the 5.5.0 release - * and instead packaged it up privately with the gh-8175 feature. - * We MAY merge gh-9208 in a later release but that is yet to be determined. - * - * gh-9208 Introduce JwtEncoder - * https://github.com/spring-projects/spring-security/pull/9208 - * - * gh-8175 Support JWT for Client Authentication - * https://github.com/spring-projects/spring-security/issues/8175 - */ - /** - * A JWT encoder that encodes a JSON Web Token (JWT) using the JSON Web Signature (JWS) - * Compact Serialization format. The private/secret key used for signing the JWS is - * supplied by the {@code com.nimbusds.jose.jwk.source.JWKSource} provided via the - * constructor. + * An implementation of a {@link JwtEncoder} that encodes a JSON Web Token (JWT) using the + * JSON Web Signature (JWS) Compact Serialization format. The private/secret key used for + * signing the JWS is supplied by the {@code com.nimbusds.jose.jwk.source.JWKSource} + * provided via the constructor. * *

* NOTE: This implementation uses the Nimbus JOSE + JWT SDK. * * @author Joe Grandja - * @since 5.5 + * @since 5.6 + * @see JwtEncoder * @see com.nimbusds.jose.jwk.source.JWKSource * @see com.nimbusds.jose.jwk.JWK * @see JSON Web Token @@ -89,7 +73,7 @@ * @see Nimbus * JOSE + JWT SDK */ -final class NimbusJwsEncoder { +public final class NimbusJwsEncoder implements JwtEncoder { private static final String ENCODING_ERROR_MESSAGE_TEMPLATE = "An error occurred while attempting to encode the Jwt: %s"; @@ -103,12 +87,13 @@ final class NimbusJwsEncoder { * Constructs a {@code NimbusJwsEncoder} using the provided parameters. * @param jwkSource the {@code com.nimbusds.jose.jwk.source.JWKSource} */ - NimbusJwsEncoder(JWKSource jwkSource) { + public NimbusJwsEncoder(JWKSource jwkSource) { Assert.notNull(jwkSource, "jwkSource cannot be null"); this.jwkSource = jwkSource; } - Jwt encode(JoseHeader headers, JwtClaimsSet claims) throws JwtEncodingException { + @Override + public Jwt encode(JoseHeader headers, JwtClaimsSet claims) throws JwtEncodingException { Assert.notNull(headers, "headers cannot be null"); Assert.notNull(claims, "claims cannot be null"); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/JoseHeaderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JoseHeaderTests.java similarity index 88% rename from oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/JoseHeaderTests.java rename to oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JoseHeaderTests.java index 2d69831c374..4aaeb773b77 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/JoseHeaderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JoseHeaderTests.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.security.oauth2.client.endpoint; +package org.springframework.security.oauth2.jwt; import org.junit.jupiter.api.Test; @@ -24,21 +24,6 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -/* - * NOTE: - * This originated in gh-9208 (JwtEncoder), - * which is required to realize the feature in gh-8175 (JWT Client Authentication). - * However, we decided not to merge gh-9208 as part of the 5.5.0 release - * and instead packaged it up privately with the gh-8175 feature. - * We MAY merge gh-9208 in a later release but that is yet to be determined. - * - * gh-9208 Introduce JwtEncoder - * https://github.com/spring-projects/spring-security/pull/9208 - * - * gh-8175 Support JWT for Client Authentication - * https://github.com/spring-projects/spring-security/issues/8175 - */ - /** * Tests for {@link JoseHeader}. * diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/JwtClaimsSetTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtClaimsSetTests.java similarity index 85% rename from oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/JwtClaimsSetTests.java rename to oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtClaimsSetTests.java index 9a23a2fb528..00e2784d9eb 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/JwtClaimsSetTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtClaimsSetTests.java @@ -14,28 +14,13 @@ * limitations under the License. */ -package org.springframework.security.oauth2.client.endpoint; +package org.springframework.security.oauth2.jwt; import org.junit.jupiter.api.Test; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -/* - * NOTE: - * This originated in gh-9208 (JwtEncoder), - * which is required to realize the feature in gh-8175 (JWT Client Authentication). - * However, we decided not to merge gh-9208 as part of the 5.5.0 release - * and instead packaged it up privately with the gh-8175 feature. - * We MAY merge gh-9208 in a later release but that is yet to be determined. - * - * gh-9208 Introduce JwtEncoder - * https://github.com/spring-projects/spring-security/pull/9208 - * - * gh-8175 Support JWT for Client Authentication - * https://github.com/spring-projects/spring-security/issues/8175 - */ - /** * Tests for {@link JwtClaimsSet}. * diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java new file mode 100644 index 00000000000..10fedcb44ad --- /dev/null +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java @@ -0,0 +1,543 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.jwt; + +import java.net.URL; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Date; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; + +import com.nimbusds.jose.EncryptionMethod; +import com.nimbusds.jose.JOSEException; +import com.nimbusds.jose.JOSEObjectType; +import com.nimbusds.jose.JWEAlgorithm; +import com.nimbusds.jose.JWEHeader; +import com.nimbusds.jose.JWEObject; +import com.nimbusds.jose.KeySourceException; +import com.nimbusds.jose.Payload; +import com.nimbusds.jose.crypto.RSAEncrypter; +import com.nimbusds.jose.jwk.JWK; +import com.nimbusds.jose.jwk.JWKMatcher; +import com.nimbusds.jose.jwk.JWKSelector; +import com.nimbusds.jose.jwk.JWKSet; +import com.nimbusds.jose.jwk.RSAKey; +import com.nimbusds.jose.jwk.source.JWKSource; +import com.nimbusds.jose.proc.SecurityContext; +import com.nimbusds.jose.util.Base64; +import com.nimbusds.jose.util.Base64URL; +import com.nimbusds.jwt.JWTClaimsSet; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.security.oauth2.core.AbstractOAuth2Token; +import org.springframework.security.oauth2.jose.JwaAlgorithm; +import org.springframework.security.oauth2.jose.TestJwks; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link NimbusJweEncoder} (future support for JWE). + * + * @author Joe Grandja + */ +public class NimbusJweEncoderTests { + + private List jwkList; + + private JWKSource jwkSource; + + private NimbusJweEncoder jweEncoder; + + private NimbusJwsEncoder jwsEncoder; + + @BeforeEach + public void setUp() { + this.jwkList = new ArrayList<>(); + this.jwkSource = (jwkSelector, securityContext) -> jwkSelector.select(new JWKSet(this.jwkList)); + this.jweEncoder = new NimbusJweEncoder(this.jwkSource); + this.jwsEncoder = new NimbusJwsEncoder(this.jwkSource); + } + + @Test + public void encodeWhenJwtClaimsSetThenEncodes() { + RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK; + this.jwkList.add(rsaJwk); + + // @formatter:off + JoseHeader jweHeader = JoseHeader.withAlgorithm(JweAlgorithm.RSA_OAEP_256) + .header("enc", EncryptionMethod.A256GCM.getName()) + .build(); + // @formatter:on + JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); + + Jwt encodedJwe = this.jweEncoder.encode(jweHeader, jwtClaimsSet); + + assertThat(encodedJwe.getHeaders().get(JoseHeaderNames.ALG)).isEqualTo(jweHeader.getAlgorithm()); + assertThat(encodedJwe.getHeaders().get("enc")).isEqualTo(jweHeader.getHeader("enc")); + assertThat(encodedJwe.getHeaders().get(JoseHeaderNames.JKU)).isNull(); + assertThat(encodedJwe.getHeaders().get(JoseHeaderNames.JWK)).isNull(); + assertThat(encodedJwe.getHeaders().get(JoseHeaderNames.KID)).isEqualTo(rsaJwk.getKeyID()); + assertThat(encodedJwe.getHeaders().get(JoseHeaderNames.X5U)).isNull(); + assertThat(encodedJwe.getHeaders().get(JoseHeaderNames.X5C)).isNull(); + assertThat(encodedJwe.getHeaders().get(JoseHeaderNames.X5T)).isNull(); + assertThat(encodedJwe.getHeaders().get(JoseHeaderNames.X5T_S256)).isNull(); + assertThat(encodedJwe.getHeaders().get(JoseHeaderNames.TYP)).isNull(); + assertThat(encodedJwe.getHeaders().get(JoseHeaderNames.CTY)).isNull(); + assertThat(encodedJwe.getHeaders().get(JoseHeaderNames.CRIT)).isNull(); + + assertThat(encodedJwe.getIssuer()).isEqualTo(jwtClaimsSet.getIssuer()); + assertThat(encodedJwe.getSubject()).isEqualTo(jwtClaimsSet.getSubject()); + assertThat(encodedJwe.getAudience()).isEqualTo(jwtClaimsSet.getAudience()); + assertThat(encodedJwe.getExpiresAt()).isEqualTo(jwtClaimsSet.getExpiresAt()); + assertThat(encodedJwe.getNotBefore()).isEqualTo(jwtClaimsSet.getNotBefore()); + assertThat(encodedJwe.getIssuedAt()).isEqualTo(jwtClaimsSet.getIssuedAt()); + assertThat(encodedJwe.getId()).isEqualTo(jwtClaimsSet.getId()); + assertThat(encodedJwe.getClaim("custom-claim-name")).isEqualTo("custom-claim-value"); + + assertThat(encodedJwe.getTokenValue()).isNotNull(); + } + + @Test + public void encodeWhenNestedJwsThenEncodes() { + // See Nimbus example -> Nested signed and encrypted JWT + // https://connect2id.com/products/nimbus-jose-jwt/examples/signed-and-encrypted-jwt + + RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK; + this.jwkList.add(rsaJwk); + + JoseHeader jwsHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build(); + JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); + + Jwt encodedJws = this.jwsEncoder.encode(jwsHeader, jwtClaimsSet); + + // @formatter:off + JoseHeader jweHeader = JoseHeader.withAlgorithm(JweAlgorithm.RSA_OAEP_256) + .header("enc", EncryptionMethod.A256GCM.getName()) + .contentType("JWT") // Indicates Nested JWT (REQUIRED) + .build(); + // @formatter:on + + JoseToken encodedJweNestedJws = this.jweEncoder.encode(jweHeader, + new JosePayload<>(encodedJws.getTokenValue())); + + assertThat(encodedJweNestedJws.getHeaders().getAlgorithm()).isEqualTo(jweHeader.getAlgorithm()); + assertThat(encodedJweNestedJws.getHeaders().getHeader("enc")).isEqualTo(jweHeader.getHeader("enc")); + assertThat(encodedJweNestedJws.getHeaders().getJwkSetUrl()).isNull(); + assertThat(encodedJweNestedJws.getHeaders().getJwk()).isNull(); + assertThat(encodedJweNestedJws.getHeaders().getKeyId()).isEqualTo(rsaJwk.getKeyID()); + assertThat(encodedJweNestedJws.getHeaders().getX509Url()).isNull(); + assertThat(encodedJweNestedJws.getHeaders().getX509CertificateChain()).isNull(); + assertThat(encodedJweNestedJws.getHeaders().getX509SHA1Thumbprint()).isNull(); + assertThat(encodedJweNestedJws.getHeaders().getX509SHA256Thumbprint()).isNull(); + assertThat(encodedJweNestedJws.getHeaders().getType()).isNull(); + assertThat(encodedJweNestedJws.getHeaders().getContentType()).isEqualTo("JWT"); + assertThat(encodedJweNestedJws.getHeaders().getCritical()).isNull(); + + assertThat(encodedJweNestedJws.getTokenValue()).isNotNull(); + } + + enum JweAlgorithm implements JwaAlgorithm { + + RSA_OAEP_256("RSA-OAEP-256"); + + private final String name; + + JweAlgorithm(String name) { + this.name = name; + } + + @Override + public String getName() { + return this.name; + } + + } + + private static final class NimbusJweEncoder implements JwtEncoder, JoseEncoder { + + private static final String ENCODING_ERROR_MESSAGE_TEMPLATE = "An error occurred while attempting to encode the Jwt: %s"; + + private static final Converter JWE_HEADER_CONVERTER = new JweHeaderConverter(); + + private static final Converter JWT_CLAIMS_SET_CONVERTER = new JwtClaimsSetConverter(); + + private final JWKSource jwkSource; + + private NimbusJweEncoder(JWKSource jwkSource) { + Assert.notNull(jwkSource, "jwkSource cannot be null"); + this.jwkSource = jwkSource; + } + + @Override + public Jwt encode(JoseHeader headers, JwtClaimsSet claims) throws JwtEncodingException { + Assert.notNull(headers, "headers cannot be null"); + Assert.notNull(claims, "claims cannot be null"); + + JWTClaimsSet jwtClaimsSet = JWT_CLAIMS_SET_CONVERTER.convert(claims); + + JoseToken joseToken = encode(headers, new JosePayload<>(jwtClaimsSet.toString())); + + return new Jwt(joseToken.getTokenValue(), claims.getIssuedAt(), claims.getExpiresAt(), + joseToken.getHeaders().getHeaders(), claims.getClaims()); + } + + @Override + public JoseToken encode(JoseHeader headers, JosePayload payload) throws JwtEncodingException { + Assert.notNull(headers, "headers cannot be null"); + Assert.notNull(payload, "payload cannot be null"); + + JWEHeader jweHeader; + try { + jweHeader = JWE_HEADER_CONVERTER.convert(headers); + } + catch (Exception ex) { + throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, ex.getMessage()), ex); + } + + JWK jwk = selectJwk(jweHeader); + if (jwk == null) { + throw new JwtEncodingException( + String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to select a JWK encryption key")); + } + + jweHeader = addKeyIdentifierHeadersIfNecessary(jweHeader, jwk); + headers = syncKeyIdentifierHeadersIfNecessary(headers, jweHeader); + + // FIXME + // Resolve type of JosePayload.content + // For now, assuming String type + String payloadContent = (String) payload.getContent(); + + JWEObject jweObject = new JWEObject(jweHeader, new Payload(payloadContent)); + try { + // FIXME + // Resolve type of JWEEncrypter using the JWK key type + // For now, assuming RSA key type + jweObject.encrypt(new RSAEncrypter(jwk.toRSAKey())); + } + catch (JOSEException ex) { + throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, + "Failed to encrypt the JWT -> " + ex.getMessage()), ex); + } + String jwe = jweObject.serialize(); + + return new JoseToken(jwe, null, null, headers, payload); + } + + private JWK selectJwk(JWEHeader jweHeader) { + JWKSelector jwkSelector = new JWKSelector(JWKMatcher.forJWEHeader(jweHeader)); + + List jwks; + try { + jwks = this.jwkSource.get(jwkSelector, null); + } + catch (KeySourceException ex) { + throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, + "Failed to select a JWK encryption key -> " + ex.getMessage()), ex); + } + + if (jwks.size() > 1) { + throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, + "Found multiple JWK encryption keys for algorithm '" + jweHeader.getAlgorithm().getName() + + "'")); + } + + return !jwks.isEmpty() ? jwks.get(0) : null; + } + + private static JWEHeader addKeyIdentifierHeadersIfNecessary(JWEHeader jweHeader, JWK jwk) { + // Check if headers have already been added + if (StringUtils.hasText(jweHeader.getKeyID()) && jweHeader.getX509CertSHA256Thumbprint() != null) { + return jweHeader; + } + // Check if headers can be added from JWK + if (!StringUtils.hasText(jwk.getKeyID()) && jwk.getX509CertSHA256Thumbprint() == null) { + return jweHeader; + } + + JWEHeader.Builder headerBuilder = new JWEHeader.Builder(jweHeader); + if (!StringUtils.hasText(jweHeader.getKeyID()) && StringUtils.hasText(jwk.getKeyID())) { + headerBuilder.keyID(jwk.getKeyID()); + } + if (jweHeader.getX509CertSHA256Thumbprint() == null && jwk.getX509CertSHA256Thumbprint() != null) { + headerBuilder.x509CertSHA256Thumbprint(jwk.getX509CertSHA256Thumbprint()); + } + + return headerBuilder.build(); + } + + private static JoseHeader syncKeyIdentifierHeadersIfNecessary(JoseHeader joseHeader, JWEHeader jweHeader) { + String jweHeaderX509SHA256Thumbprint = null; + if (jweHeader.getX509CertSHA256Thumbprint() != null) { + jweHeaderX509SHA256Thumbprint = jweHeader.getX509CertSHA256Thumbprint().toString(); + } + if (Objects.equals(joseHeader.getKeyId(), jweHeader.getKeyID()) + && Objects.equals(joseHeader.getX509SHA256Thumbprint(), jweHeaderX509SHA256Thumbprint)) { + return joseHeader; + } + + JoseHeader.Builder headerBuilder = JoseHeader.from(joseHeader); + if (!Objects.equals(joseHeader.getKeyId(), jweHeader.getKeyID())) { + headerBuilder.keyId(jweHeader.getKeyID()); + } + if (!Objects.equals(joseHeader.getX509SHA256Thumbprint(), jweHeaderX509SHA256Thumbprint)) { + headerBuilder.x509SHA256Thumbprint(jweHeaderX509SHA256Thumbprint); + } + + return headerBuilder.build(); + } + + } + + private static class JweHeaderConverter implements Converter { + + @Override + public JWEHeader convert(JoseHeader headers) { + JWEAlgorithm jweAlgorithm = JWEAlgorithm.parse(headers.getAlgorithm().getName()); + EncryptionMethod encryptionMethod = EncryptionMethod.parse(headers.getHeader("enc")); + JWEHeader.Builder builder = new JWEHeader.Builder(jweAlgorithm, encryptionMethod); + + URL jwkSetUri = headers.getJwkSetUrl(); + if (jwkSetUri != null) { + try { + builder.jwkURL(jwkSetUri.toURI()); + } + catch (Exception ex) { + throw new IllegalArgumentException( + "Unable to convert '" + JoseHeaderNames.JKU + "' JOSE header to a URI", ex); + } + } + + Map jwk = headers.getJwk(); + if (!CollectionUtils.isEmpty(jwk)) { + try { + builder.jwk(JWK.parse(jwk)); + } + catch (Exception ex) { + throw new IllegalArgumentException("Unable to convert '" + JoseHeaderNames.JWK + "' JOSE header", + ex); + } + } + + String keyId = headers.getKeyId(); + if (StringUtils.hasText(keyId)) { + builder.keyID(keyId); + } + + URL x509Uri = headers.getX509Url(); + if (x509Uri != null) { + try { + builder.x509CertURL(x509Uri.toURI()); + } + catch (Exception ex) { + throw new IllegalArgumentException( + "Unable to convert '" + JoseHeaderNames.X5U + "' JOSE header to a URI", ex); + } + } + + List x509CertificateChain = headers.getX509CertificateChain(); + if (!CollectionUtils.isEmpty(x509CertificateChain)) { + builder.x509CertChain(x509CertificateChain.stream().map(Base64::new).collect(Collectors.toList())); + } + + String x509SHA1Thumbprint = headers.getX509SHA1Thumbprint(); + if (StringUtils.hasText(x509SHA1Thumbprint)) { + builder.x509CertThumbprint(new Base64URL(x509SHA1Thumbprint)); + } + + String x509SHA256Thumbprint = headers.getX509SHA256Thumbprint(); + if (StringUtils.hasText(x509SHA256Thumbprint)) { + builder.x509CertSHA256Thumbprint(new Base64URL(x509SHA256Thumbprint)); + } + + String type = headers.getType(); + if (StringUtils.hasText(type)) { + builder.type(new JOSEObjectType(type)); + } + + String contentType = headers.getContentType(); + if (StringUtils.hasText(contentType)) { + builder.contentType(contentType); + } + + Set critical = headers.getCritical(); + if (!CollectionUtils.isEmpty(critical)) { + builder.criticalParams(critical); + } + + Map customHeaders = headers.getHeaders().entrySet().stream() + .filter((header) -> !JWEHeader.getRegisteredParameterNames().contains(header.getKey())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + if (!CollectionUtils.isEmpty(customHeaders)) { + builder.customParams(customHeaders); + } + + return builder.build(); + } + + } + + private static class JwtClaimsSetConverter implements Converter { + + @Override + public JWTClaimsSet convert(JwtClaimsSet claims) { + JWTClaimsSet.Builder builder = new JWTClaimsSet.Builder(); + + // NOTE: The value of the 'iss' claim is a String or URL (StringOrURI). + Object issuer = claims.getClaim(JwtClaimNames.ISS); + if (issuer != null) { + builder.issuer(issuer.toString()); + } + + String subject = claims.getSubject(); + if (StringUtils.hasText(subject)) { + builder.subject(subject); + } + + List audience = claims.getAudience(); + if (!CollectionUtils.isEmpty(audience)) { + builder.audience(audience); + } + + Instant expiresAt = claims.getExpiresAt(); + if (expiresAt != null) { + builder.expirationTime(Date.from(expiresAt)); + } + + Instant notBefore = claims.getNotBefore(); + if (notBefore != null) { + builder.notBeforeTime(Date.from(notBefore)); + } + + Instant issuedAt = claims.getIssuedAt(); + if (issuedAt != null) { + builder.issueTime(Date.from(issuedAt)); + } + + String jwtId = claims.getId(); + if (StringUtils.hasText(jwtId)) { + builder.jwtID(jwtId); + } + + Map customClaims = new HashMap<>(); + claims.getClaims().forEach((name, value) -> { + if (!JWTClaimsSet.getRegisteredNames().contains(name)) { + customClaims.put(name, value); + } + }); + if (!customClaims.isEmpty()) { + customClaims.forEach(builder::claim); + } + + return builder.build(); + } + + } + + static class JoseToken extends AbstractOAuth2Token { + + private final JoseHeader headers; + + private final JosePayload payload; + + JoseToken(String tokenValue, Instant issuedAt, Instant expiresAt, JoseHeader headers, JosePayload payload) { + super(tokenValue, issuedAt, expiresAt); + this.headers = headers; + this.payload = payload; + } + + JoseHeader getHeaders() { + return this.headers; + } + + JosePayload getPayload() { + return this.payload; + } + + } + + static class JosePayload { + + private final T content; + + JosePayload(T content) { + this.content = content; + } + + T getContent() { + return this.content; + } + + } + + // @formatter:off + /* + * IMPORTANT DESIGN DECISION + * ------------------------- + * + * This API is needed in order to support "Nested JWT". + * + * See section 2. Terminology + * https://tools.ietf.org/html/rfc7519#section-2 + * + * Nested JWT + * A JWT in which nested signing and/or encryption are employed. + * In Nested JWTs, a JWT is used as the payload or plaintext value of an + * enclosing JWS or JWE structure, respectively. + * + * See section 3. JSON Web Token (JWT) Overview + * https://tools.ietf.org/html/rfc7519#section-3 + * + * JWTs represent a set of claims as a JSON object that is encoded in a + * JWS and/or JWE structure. This JSON object is the JWT Claims Set. + * + * The contents of the JOSE Header describe the cryptographic operations + * applied to the JWT Claims Set. If the JOSE Header is for a JWS, the + * JWT is represented as a JWS and the claims are digitally signed or + * MACed, with the JWT Claims Set being the JWS Payload. If the JOSE + * Header is for a JWE, the JWT is represented as a JWE and the claims + * are encrypted, with the JWT Claims Set being the plaintext encrypted + * by the JWE. A JWT may be enclosed in another JWE or JWS structure to + * create a Nested JWT, enabling nested signing and encryption to be + * performed. + * + * ----------------------- + * + * In summary, the `JwtEncoder` API is designed for signing (JWS) and encrypting (JWE) a JWT Claims Set. + * Whereas, the `JoseEncoder` API is a higher level of abstraction that can be used for Nested JWT (signing and encryption). + * NOTE: The `JosePayload` type provides the flexibility to support any data type, + * e.g. JWT/JWS, JwtClaimsSet, String, Map, byte[], etc. + */ + interface JoseEncoder { + + JoseToken encode(JoseHeader headers, JosePayload payload) throws JwtEncodingException; + + } + // @formatter:on + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusJwsEncoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoderTests.java similarity index 94% rename from oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusJwsEncoderTests.java rename to oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoderTests.java index b97d638b073..53964d8485a 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusJwsEncoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoderTests.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.security.oauth2.client.endpoint; +package org.springframework.security.oauth2.jwt; import java.security.interfaces.ECPrivateKey; import java.security.interfaces.ECPublicKey; @@ -42,8 +42,6 @@ import org.springframework.security.oauth2.jose.TestJwks; import org.springframework.security.oauth2.jose.TestKeys; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; -import org.springframework.security.oauth2.jwt.Jwt; -import org.springframework.security.oauth2.jwt.NimbusJwtDecoder; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -54,21 +52,6 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; -/* - * NOTE: - * This originated in gh-9208 (JwtEncoder), - * which is required to realize the feature in gh-8175 (JWT Client Authentication). - * However, we decided not to merge gh-9208 as part of the 5.5.0 release - * and instead packaged it up privately with the gh-8175 feature. - * We MAY merge gh-9208 in a later release but that is yet to be determined. - * - * gh-9208 Introduce JwtEncoder - * https://github.com/spring-projects/spring-security/pull/9208 - * - * gh-8175 Support JWT for Client Authentication - * https://github.com/spring-projects/spring-security/issues/8175 - */ - /** * Tests for {@link NimbusJwsEncoder}. * diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/TestJoseHeaders.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/TestJoseHeaders.java similarity index 67% rename from oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/TestJoseHeaders.java rename to oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/TestJoseHeaders.java index ffda877694f..5d30f58b778 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/TestJoseHeaders.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/TestJoseHeaders.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.security.oauth2.client.endpoint; +package org.springframework.security.oauth2.jwt; import java.util.Arrays; import java.util.HashMap; @@ -22,34 +22,19 @@ import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; -/* - * NOTE: - * This originated in gh-9208 (JwtEncoder), - * which is required to realize the feature in gh-8175 (JWT Client Authentication). - * However, we decided not to merge gh-9208 as part of the 5.5.0 release - * and instead packaged it up privately with the gh-8175 feature. - * We MAY merge gh-9208 in a later release but that is yet to be determined. - * - * gh-9208 Introduce JwtEncoder - * https://github.com/spring-projects/spring-security/pull/9208 - * - * gh-8175 Support JWT for Client Authentication - * https://github.com/spring-projects/spring-security/issues/8175 - */ - /** * @author Joe Grandja */ -final class TestJoseHeaders { +public final class TestJoseHeaders { private TestJoseHeaders() { } - static JoseHeader.Builder joseHeader() { + public static JoseHeader.Builder joseHeader() { return joseHeader(SignatureAlgorithm.RS256); } - static JoseHeader.Builder joseHeader(SignatureAlgorithm signatureAlgorithm) { + public static JoseHeader.Builder joseHeader(SignatureAlgorithm signatureAlgorithm) { // @formatter:off return JoseHeader.withAlgorithm(signatureAlgorithm) .jwkSetUrl("https://provider.com/oauth2/jwks") diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/TestJwtClaimsSets.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/TestJwtClaimsSets.java similarity index 63% rename from oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/TestJwtClaimsSets.java rename to oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/TestJwtClaimsSets.java index 1b311979457..4cb79f6192f 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/TestJwtClaimsSets.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/TestJwtClaimsSets.java @@ -14,36 +14,21 @@ * limitations under the License. */ -package org.springframework.security.oauth2.client.endpoint; +package org.springframework.security.oauth2.jwt; import java.time.Instant; import java.time.temporal.ChronoUnit; import java.util.Collections; -/* - * NOTE: - * This originated in gh-9208 (JwtEncoder), - * which is required to realize the feature in gh-8175 (JWT Client Authentication). - * However, we decided not to merge gh-9208 as part of the 5.5.0 release - * and instead packaged it up privately with the gh-8175 feature. - * We MAY merge gh-9208 in a later release but that is yet to be determined. - * - * gh-9208 Introduce JwtEncoder - * https://github.com/spring-projects/spring-security/pull/9208 - * - * gh-8175 Support JWT for Client Authentication - * https://github.com/spring-projects/spring-security/issues/8175 - */ - /** * @author Joe Grandja */ -final class TestJwtClaimsSets { +public final class TestJwtClaimsSets { private TestJwtClaimsSets() { } - static JwtClaimsSet.Builder jwtClaimsSet() { + public static JwtClaimsSet.Builder jwtClaimsSet() { String issuer = "https://provider.com"; Instant issuedAt = Instant.now(); Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS); From 9b7b085ca443dbcd3b426458d632072ef53ba7ab Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Thu, 16 Sep 2021 04:30:33 -0400 Subject: [PATCH 02/20] Introduce JwtEncoderParameters --- ...ientAuthenticationParametersConverter.java | 3 +- .../security/oauth2/jwt/JwtEncoder.java | 6 +- .../oauth2/jwt/JwtEncoderParameters.java | 68 +++++++++++++++++++ .../security/oauth2/jwt/NimbusJwsEncoder.java | 8 ++- .../oauth2/jwt/NimbusJweEncoderTests.java | 12 ++-- .../oauth2/jwt/NimbusJwsEncoderTests.java | 31 ++++++--- 6 files changed, 105 insertions(+), 23 deletions(-) create mode 100644 oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncoderParameters.java diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java index 65b5ffcfb79..04aa62f2101 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java @@ -44,6 +44,7 @@ import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.JwtClaimsSet; import org.springframework.security.oauth2.jwt.JwtEncoder; +import org.springframework.security.oauth2.jwt.JwtEncoderParameters; import org.springframework.security.oauth2.jwt.NimbusJwsEncoder; import org.springframework.util.Assert; import org.springframework.util.LinkedMultiValueMap; @@ -154,7 +155,7 @@ public MultiValueMap convert(T authorizationGrantRequest) { }); JwtEncoder jwsEncoder = jwsEncoderHolder.getJwsEncoder(); - Jwt jws = jwsEncoder.encode(joseHeader, jwtClaimsSet); + Jwt jws = jwsEncoder.encode(JwtEncoderParameters.with(joseHeader, jwtClaimsSet)); MultiValueMap parameters = new LinkedMultiValueMap<>(); parameters.set(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE, CLIENT_ASSERTION_TYPE_VALUE); diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncoder.java index 44d6a14a88a..14e426ddd59 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncoder.java @@ -30,6 +30,7 @@ * @author Joe Grandja * @since 5.6 * @see Jwt + * @see JwtEncoderParameters * @see JoseHeader * @see JwtClaimsSet * @see JwtDecoder @@ -49,11 +50,10 @@ public interface JwtEncoder { /** * Encode the JWT to it's compact claims representation format. - * @param headers the JOSE header - * @param claims the JWT Claims Set + * @param parameters the parameters containing the JOSE header and JWT Claims Set * @return a {@link Jwt} * @throws JwtEncodingException if an error occurs while attempting to encode the JWT */ - Jwt encode(JoseHeader headers, JwtClaimsSet claims) throws JwtEncodingException; + Jwt encode(JwtEncoderParameters parameters) throws JwtEncodingException; } diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncoderParameters.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncoderParameters.java new file mode 100644 index 00000000000..0ecdc7200fd --- /dev/null +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncoderParameters.java @@ -0,0 +1,68 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.jwt; + +import org.springframework.util.Assert; + +/** + * A holder of parameters containing the JOSE header and JWT Claims Set. + * + * @author Joe Grandja + * @since 5.6 + * @see JwtEncoder + */ +public final class JwtEncoderParameters { + + private final JoseHeader headers; + + private final JwtClaimsSet claims; + + private JwtEncoderParameters(JoseHeader headers, JwtClaimsSet claims) { + Assert.notNull(headers, "headers cannot be null"); + Assert.notNull(claims, "claims cannot be null"); + this.headers = headers; + this.claims = claims; + } + + /** + * Returns a new {@link JwtEncoderParameters}, initialized with the provided + * {@link JoseHeader} and {@link JwtClaimsSet}. + * @param headers the {@link JoseHeader} + * @param claims the {@link JwtClaimsSet} + * @return the {@link JwtEncoderParameters} + */ + public static JwtEncoderParameters with(JoseHeader headers, JwtClaimsSet claims) { + return new JwtEncoderParameters(headers, claims); + } + + /** + * Returns the {@link JoseHeader headers}. + * @return the {@link JoseHeader} + */ + public JoseHeader getHeaders() { + return this.headers; + } + + /** + * Returns the {@link JwtClaimsSet claims}. + * @return the {@link JwtClaimsSet} + */ + public JwtClaimsSet getClaims() { + return this.claims; + } + +} diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoder.java index f0d01fb8c18..a469bef666b 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoder.java @@ -93,9 +93,11 @@ public NimbusJwsEncoder(JWKSource jwkSource) { } @Override - public Jwt encode(JoseHeader headers, JwtClaimsSet claims) throws JwtEncodingException { - Assert.notNull(headers, "headers cannot be null"); - Assert.notNull(claims, "claims cannot be null"); + public Jwt encode(JwtEncoderParameters parameters) throws JwtEncodingException { + Assert.notNull(parameters, "parameters cannot be null"); + + JoseHeader headers = parameters.getHeaders(); + JwtClaimsSet claims = parameters.getClaims(); JWK jwk = selectJwk(headers); headers = addKeyIdentifierHeadersIfNecessary(headers, jwk); diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java index 10fedcb44ad..4d02bb3becd 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java @@ -95,7 +95,7 @@ public void encodeWhenJwtClaimsSetThenEncodes() { // @formatter:on JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); - Jwt encodedJwe = this.jweEncoder.encode(jweHeader, jwtClaimsSet); + Jwt encodedJwe = this.jweEncoder.encode(JwtEncoderParameters.with(jweHeader, jwtClaimsSet)); assertThat(encodedJwe.getHeaders().get(JoseHeaderNames.ALG)).isEqualTo(jweHeader.getAlgorithm()); assertThat(encodedJwe.getHeaders().get("enc")).isEqualTo(jweHeader.getHeader("enc")); @@ -133,7 +133,7 @@ public void encodeWhenNestedJwsThenEncodes() { JoseHeader jwsHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build(); JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); - Jwt encodedJws = this.jwsEncoder.encode(jwsHeader, jwtClaimsSet); + Jwt encodedJws = this.jwsEncoder.encode(JwtEncoderParameters.with(jwsHeader, jwtClaimsSet)); // @formatter:off JoseHeader jweHeader = JoseHeader.withAlgorithm(JweAlgorithm.RSA_OAEP_256) @@ -194,9 +194,11 @@ private NimbusJweEncoder(JWKSource jwkSource) { } @Override - public Jwt encode(JoseHeader headers, JwtClaimsSet claims) throws JwtEncodingException { - Assert.notNull(headers, "headers cannot be null"); - Assert.notNull(claims, "claims cannot be null"); + public Jwt encode(JwtEncoderParameters parameters) throws JwtEncodingException { + Assert.notNull(parameters, "parameters cannot be null"); + + JoseHeader headers = parameters.getHeaders(); + JwtClaimsSet claims = parameters.getClaims(); JWTClaimsSet jwtClaimsSet = JWT_CLAIMS_SET_CONVERTER.convert(claims); diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoderTests.java index 53964d8485a..0ed6c436738 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoderTests.java @@ -78,11 +78,18 @@ public void constructorWhenJwkSourceNullThenThrowIllegalArgumentException() { .withMessage("jwkSource cannot be null"); } + @Test + public void encodeWhenParametersNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.jwsEncoder.encode(null)) + .withMessage("parameters cannot be null"); + } + @Test public void encodeWhenHeadersNullThenThrowIllegalArgumentException() { JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); - assertThatIllegalArgumentException().isThrownBy(() -> this.jwsEncoder.encode(null, jwtClaimsSet)) + assertThatIllegalArgumentException() + .isThrownBy(() -> this.jwsEncoder.encode(JwtEncoderParameters.with(null, jwtClaimsSet))) .withMessage("headers cannot be null"); } @@ -90,7 +97,8 @@ public void encodeWhenHeadersNullThenThrowIllegalArgumentException() { public void encodeWhenClaimsNullThenThrowIllegalArgumentException() { JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build(); - assertThatIllegalArgumentException().isThrownBy(() -> this.jwsEncoder.encode(joseHeader, null)) + assertThatIllegalArgumentException() + .isThrownBy(() -> this.jwsEncoder.encode(JwtEncoderParameters.with(joseHeader, null))) .withMessage("claims cannot be null"); } @@ -104,7 +112,7 @@ public void encodeWhenJwkSelectFailedThenThrowJwtEncodingException() throws Exce JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); assertThatExceptionOfType(JwtEncodingException.class) - .isThrownBy(() -> this.jwsEncoder.encode(joseHeader, jwtClaimsSet)) + .isThrownBy(() -> this.jwsEncoder.encode(JwtEncoderParameters.with(joseHeader, jwtClaimsSet))) .withMessageContaining("Failed to select a JWK signing key -> key source error"); } @@ -118,7 +126,7 @@ public void encodeWhenJwkMultipleSelectedThenThrowJwtEncodingException() throws JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); assertThatExceptionOfType(JwtEncodingException.class) - .isThrownBy(() -> this.jwsEncoder.encode(joseHeader, jwtClaimsSet)) + .isThrownBy(() -> this.jwsEncoder.encode(JwtEncoderParameters.with(joseHeader, jwtClaimsSet))) .withMessageContaining("Found multiple JWK signing keys for algorithm 'RS256'"); } @@ -128,7 +136,7 @@ public void encodeWhenJwkSelectEmptyThenThrowJwtEncodingException() { JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); assertThatExceptionOfType(JwtEncodingException.class) - .isThrownBy(() -> this.jwsEncoder.encode(joseHeader, jwtClaimsSet)) + .isThrownBy(() -> this.jwsEncoder.encode(JwtEncoderParameters.with(joseHeader, jwtClaimsSet))) .withMessageContaining("Failed to select a JWK signing key"); } @@ -148,7 +156,7 @@ public void encodeWhenJwkSelectWithProvidedKidThenSelected() { JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).keyId(rsaJwk2.getKeyID()).build(); JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); - Jwt encodedJws = this.jwsEncoder.encode(joseHeader, jwtClaimsSet); + Jwt encodedJws = this.jwsEncoder.encode(JwtEncoderParameters.with(joseHeader, jwtClaimsSet)); assertThat(encodedJws.getHeaders().get(JoseHeaderNames.KID)).isEqualTo(rsaJwk2.getKeyID()); } @@ -172,7 +180,7 @@ public void encodeWhenJwkSelectWithProvidedX5TS256ThenSelected() { .x509SHA256Thumbprint(rsaJwk1.getX509CertSHA256Thumbprint().toString()).build(); JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); - Jwt encodedJws = this.jwsEncoder.encode(joseHeader, jwtClaimsSet); + Jwt encodedJws = this.jwsEncoder.encode(JwtEncoderParameters.with(joseHeader, jwtClaimsSet)); assertThat(encodedJws.getHeaders().get(JoseHeaderNames.X5T_S256)) .isEqualTo(rsaJwk1.getX509CertSHA256Thumbprint().toString()); @@ -195,7 +203,8 @@ public void encodeWhenJwkUseEncryptionThenThrowJwtEncodingException() throws Exc JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); assertThatExceptionOfType(JwtEncodingException.class) - .isThrownBy(() -> this.jwsEncoder.encode(joseHeader, jwtClaimsSet)).withMessageContaining( + .isThrownBy(() -> this.jwsEncoder.encode(JwtEncoderParameters.with(joseHeader, jwtClaimsSet))) + .withMessageContaining( "Failed to create a JWS Signer -> The JWK use must be sig (signature) or unspecified"); } @@ -212,7 +221,7 @@ public void encodeWhenSuccessThenDecodes() throws Exception { JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build(); JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); - Jwt encodedJws = this.jwsEncoder.encode(joseHeader, jwtClaimsSet); + Jwt encodedJws = this.jwsEncoder.encode(JwtEncoderParameters.with(joseHeader, jwtClaimsSet)); assertThat(encodedJws.getHeaders().get(JoseHeaderNames.ALG)).isEqualTo(joseHeader.getAlgorithm()); assertThat(encodedJws.getHeaders().get(JoseHeaderNames.JKU)).isNull(); @@ -257,7 +266,7 @@ public List get(JWKSelector jwkSelector, SecurityContext context) { JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build(); JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); - Jwt encodedJws = jwsEncoder.encode(joseHeader, jwtClaimsSet); + Jwt encodedJws = jwsEncoder.encode(JwtEncoderParameters.with(joseHeader, jwtClaimsSet)); JWK jwk1 = jwkListResultCaptor.getResult().get(0); NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withPublicKey(((RSAKey) jwk1).toRSAPublicKey()).build(); @@ -265,7 +274,7 @@ public List get(JWKSelector jwkSelector, SecurityContext context) { jwkSource.rotate(); // Simulate key rotation - encodedJws = jwsEncoder.encode(joseHeader, jwtClaimsSet); + encodedJws = jwsEncoder.encode(JwtEncoderParameters.with(joseHeader, jwtClaimsSet)); JWK jwk2 = jwkListResultCaptor.getResult().get(0); jwtDecoder = NimbusJwtDecoder.withPublicKey(((RSAKey) jwk2).toRSAPublicKey()).build(); From 8c21bdaf8fa48b8ddcbdda717cdf19caae31a288 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Thu, 16 Sep 2021 05:58:40 -0400 Subject: [PATCH 03/20] Rename NimbusJwsEncoder -> NimbusJwtEncoder --- ...tClientAuthenticationParametersConverter.java | 4 ++-- ...mbusJwsEncoder.java => NimbusJwtEncoder.java} | 8 ++++---- .../oauth2/jwt/NimbusJweEncoderTests.java | 4 ++-- ...oderTests.java => NimbusJwtEncoderTests.java} | 16 ++++++++-------- 4 files changed, 16 insertions(+), 16 deletions(-) rename oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/{NimbusJwsEncoder.java => NimbusJwtEncoder.java} (98%) rename oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/{NimbusJwsEncoderTests.java => NimbusJwtEncoderTests.java} (97%) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java index 04aa62f2101..330be84561e 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java @@ -45,7 +45,7 @@ import org.springframework.security.oauth2.jwt.JwtClaimsSet; import org.springframework.security.oauth2.jwt.JwtEncoder; import org.springframework.security.oauth2.jwt.JwtEncoderParameters; -import org.springframework.security.oauth2.jwt.NimbusJwsEncoder; +import org.springframework.security.oauth2.jwt.NimbusJwtEncoder; import org.springframework.util.Assert; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; @@ -151,7 +151,7 @@ public MultiValueMap convert(T authorizationGrantRequest) { return currentJwsEncoderHolder; } JWKSource jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk)); - return new JwsEncoderHolder(new NimbusJwsEncoder(jwkSource), jwk); + return new JwsEncoderHolder(new NimbusJwtEncoder(jwkSource), jwk); }); JwtEncoder jwsEncoder = jwsEncoderHolder.getJwsEncoder(); diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java similarity index 98% rename from oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoder.java rename to oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java index a469bef666b..ed944f719d4 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java @@ -73,7 +73,7 @@ * @see Nimbus * JOSE + JWT SDK */ -public final class NimbusJwsEncoder implements JwtEncoder { +public final class NimbusJwtEncoder implements JwtEncoder { private static final String ENCODING_ERROR_MESSAGE_TEMPLATE = "An error occurred while attempting to encode the Jwt: %s"; @@ -84,10 +84,10 @@ public final class NimbusJwsEncoder implements JwtEncoder { private final JWKSource jwkSource; /** - * Constructs a {@code NimbusJwsEncoder} using the provided parameters. + * Constructs a {@code NimbusJwtEncoder} using the provided parameters. * @param jwkSource the {@code com.nimbusds.jose.jwk.source.JWKSource} */ - public NimbusJwsEncoder(JWKSource jwkSource) { + public NimbusJwtEncoder(JWKSource jwkSource) { Assert.notNull(jwkSource, "jwkSource cannot be null"); this.jwkSource = jwkSource; } @@ -135,7 +135,7 @@ private String serialize(JoseHeader headers, JwtClaimsSet claims, JWK jwk) { JWSHeader jwsHeader = convert(headers); JWTClaimsSet jwtClaimsSet = convert(claims); - JWSSigner jwsSigner = this.jwsSigners.computeIfAbsent(jwk, NimbusJwsEncoder::createSigner); + JWSSigner jwsSigner = this.jwsSigners.computeIfAbsent(jwk, NimbusJwtEncoder::createSigner); SignedJWT signedJwt = new SignedJWT(jwsHeader, jwtClaimsSet); try { diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java index 4d02bb3becd..0c7643140cf 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java @@ -73,14 +73,14 @@ public class NimbusJweEncoderTests { private NimbusJweEncoder jweEncoder; - private NimbusJwsEncoder jwsEncoder; + private NimbusJwtEncoder jwsEncoder; @BeforeEach public void setUp() { this.jwkList = new ArrayList<>(); this.jwkSource = (jwkSelector, securityContext) -> jwkSelector.select(new JWKSet(this.jwkList)); this.jweEncoder = new NimbusJweEncoder(this.jwkSource); - this.jwsEncoder = new NimbusJwsEncoder(this.jwkSource); + this.jwsEncoder = new NimbusJwtEncoder(this.jwkSource); } @Test diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java similarity index 97% rename from oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoderTests.java rename to oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java index 0ed6c436738..5ee04b3e3c1 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java @@ -53,28 +53,28 @@ import static org.mockito.Mockito.spy; /** - * Tests for {@link NimbusJwsEncoder}. + * Tests for {@link NimbusJwtEncoder}. * * @author Joe Grandja */ -public class NimbusJwsEncoderTests { +public class NimbusJwtEncoderTests { private List jwkList; private JWKSource jwkSource; - private NimbusJwsEncoder jwsEncoder; + private NimbusJwtEncoder jwsEncoder; @BeforeEach public void setUp() { this.jwkList = new ArrayList<>(); this.jwkSource = (jwkSelector, securityContext) -> jwkSelector.select(new JWKSet(this.jwkList)); - this.jwsEncoder = new NimbusJwsEncoder(this.jwkSource); + this.jwsEncoder = new NimbusJwtEncoder(this.jwkSource); } @Test public void constructorWhenJwkSourceNullThenThrowIllegalArgumentException() { - assertThatIllegalArgumentException().isThrownBy(() -> new NimbusJwsEncoder(null)) + assertThatIllegalArgumentException().isThrownBy(() -> new NimbusJwtEncoder(null)) .withMessage("jwkSource cannot be null"); } @@ -105,7 +105,7 @@ public void encodeWhenClaimsNullThenThrowIllegalArgumentException() { @Test public void encodeWhenJwkSelectFailedThenThrowJwtEncodingException() throws Exception { this.jwkSource = mock(JWKSource.class); - this.jwsEncoder = new NimbusJwsEncoder(this.jwkSource); + this.jwsEncoder = new NimbusJwtEncoder(this.jwkSource); given(this.jwkSource.get(any(), any())).willThrow(new KeySourceException("key source error")); JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build(); @@ -196,7 +196,7 @@ public void encodeWhenJwkUseEncryptionThenThrowJwtEncodingException() throws Exc // @formatter:on this.jwkSource = mock(JWKSource.class); - this.jwsEncoder = new NimbusJwsEncoder(this.jwkSource); + this.jwsEncoder = new NimbusJwtEncoder(this.jwkSource); given(this.jwkSource.get(any(), any())).willReturn(Collections.singletonList(rsaJwk)); JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build(); @@ -258,7 +258,7 @@ public List get(JWKSelector jwkSelector, SecurityContext context) { return jwkSource.get(jwkSelector, context); } }); - NimbusJwsEncoder jwsEncoder = new NimbusJwsEncoder(jwkSourceDelegate); + NimbusJwtEncoder jwsEncoder = new NimbusJwtEncoder(jwkSourceDelegate); JwkListResultCaptor jwkListResultCaptor = new JwkListResultCaptor(); willAnswer(jwkListResultCaptor).given(jwkSourceDelegate).get(any(), any()); From f977ae61592a209a04aa46845acba1b71ed41f55 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Thu, 16 Sep 2021 06:38:54 -0400 Subject: [PATCH 04/20] Validate critical headers --- .../security/oauth2/jwt/JoseHeader.java | 12 ++++++++++++ .../security/oauth2/jwt/JoseHeaderTests.java | 10 ++++++++++ 2 files changed, 22 insertions(+) diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JoseHeader.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JoseHeader.java index a386d0feb54..58c0881e0b9 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JoseHeader.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JoseHeader.java @@ -359,9 +359,21 @@ public Builder headers(Consumer> headersConsumer) { */ public JoseHeader build() { Assert.notEmpty(this.headers, "headers cannot be empty"); + validateCriticalHeaders(); return new JoseHeader(this.headers); } + @SuppressWarnings("unchecked") + private void validateCriticalHeaders() { + Set criticalHeaderNames = (Set) this.headers.get(JoseHeaderNames.CRIT); + if (criticalHeaderNames == null) { + return; + } + criticalHeaderNames + .forEach((criticalHeaderName) -> Assert.state(this.headers.containsKey(criticalHeaderName), + "Missing critical (crit) header '" + criticalHeaderName + "'.")); + } + private static URL convertAsURL(String header, String value) { URL convertedValue = ClaimConversionService.getSharedInstance().convert(value, URL.class); Assert.isTrue(convertedValue != null, diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JoseHeaderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JoseHeaderTests.java index 4aaeb773b77..96332d450a4 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JoseHeaderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JoseHeaderTests.java @@ -16,6 +16,8 @@ package org.springframework.security.oauth2.jwt; +import java.util.Collections; + import org.junit.jupiter.api.Test; import org.springframework.security.oauth2.jose.JwaAlgorithm; @@ -70,6 +72,14 @@ public void buildWhenAllHeadersProvidedThenAllHeadersAreSet() { assertThat(joseHeader.getHeaders()).isEqualTo(expectedJoseHeader.getHeaders()); } + @Test + public void buildWhenMissingCriticalHeaderThenThrowIllegalStateException() { + // @formatter:off + assertThatExceptionOfType(IllegalStateException.class).isThrownBy(() -> + TestJoseHeaders.joseHeader().critical(Collections.singleton("critical-header-name")).build()) + .withMessage("Missing critical (crit) header 'critical-header-name'."); + } + @Test public void fromWhenNullThenThrowIllegalArgumentException() { assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> JoseHeader.from(null)) From c88b3de512ee9015f1d475e5f8ac03103b80347c Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Thu, 16 Sep 2021 06:59:26 -0400 Subject: [PATCH 05/20] Polish tests --- .../security/oauth2/jwt/JoseHeaderTests.java | 6 +++--- .../security/oauth2/jwt/JwtClaimsSetTests.java | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JoseHeaderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JoseHeaderTests.java index 96332d450a4..f777ff3d1cd 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JoseHeaderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JoseHeaderTests.java @@ -36,7 +36,7 @@ public class JoseHeaderTests { @Test public void withAlgorithmWhenNullThenThrowIllegalArgumentException() { assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> JoseHeader.withAlgorithm(null)) - .isInstanceOf(IllegalArgumentException.class).withMessage("jwaAlgorithm cannot be null"); + .withMessage("jwaAlgorithm cannot be null"); } @Test @@ -83,7 +83,7 @@ public void buildWhenMissingCriticalHeaderThenThrowIllegalStateException() { @Test public void fromWhenNullThenThrowIllegalArgumentException() { assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> JoseHeader.from(null)) - .isInstanceOf(IllegalArgumentException.class).withMessage("headers cannot be null"); + .withMessage("headers cannot be null"); } @Test @@ -112,7 +112,7 @@ public void getHeaderWhenNullThenThrowIllegalArgumentException() { JoseHeader joseHeader = TestJoseHeaders.joseHeader().build(); assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> joseHeader.getHeader(null)) - .isInstanceOf(IllegalArgumentException.class).withMessage("name cannot be empty"); + .withMessage("name cannot be empty"); } } diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtClaimsSetTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtClaimsSetTests.java index 00e2784d9eb..3e0c8028424 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtClaimsSetTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtClaimsSetTests.java @@ -31,7 +31,7 @@ public class JwtClaimsSetTests { @Test public void buildWhenClaimsEmptyThenThrowIllegalArgumentException() { assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> JwtClaimsSet.builder().build()) - .isInstanceOf(IllegalArgumentException.class).withMessage("claims cannot be empty"); + .withMessage("claims cannot be empty"); } @Test @@ -65,7 +65,7 @@ public void buildWhenAllClaimsProvidedThenAllClaimsAreSet() { @Test public void fromWhenNullThenThrowIllegalArgumentException() { assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> JwtClaimsSet.from(null)) - .isInstanceOf(IllegalArgumentException.class).withMessage("claims cannot be null"); + .withMessage("claims cannot be null"); } @Test From 48b63cde5ac3c87a40bebb260e5a85cf85743dac Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Fri, 17 Sep 2021 04:07:43 -0400 Subject: [PATCH 06/20] Fix typo --- .../org/springframework/security/oauth2/jwt/JoseHeaderTests.java | 1 - 1 file changed, 1 deletion(-) diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JoseHeaderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JoseHeaderTests.java index f777ff3d1cd..78fcc6e6255 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JoseHeaderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JoseHeaderTests.java @@ -74,7 +74,6 @@ public void buildWhenAllHeadersProvidedThenAllHeadersAreSet() { @Test public void buildWhenMissingCriticalHeaderThenThrowIllegalStateException() { - // @formatter:off assertThatExceptionOfType(IllegalStateException.class).isThrownBy(() -> TestJoseHeaders.joseHeader().critical(Collections.singleton("critical-header-name")).build()) .withMessage("Missing critical (crit) header 'critical-header-name'."); From f5b3c51c0855d1a904363fefcff94ee892a53ebd Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Fri, 17 Sep 2021 06:16:26 -0400 Subject: [PATCH 07/20] Allow extension of JoseHeader --- ...ientAuthenticationParametersConverter.java | 2 +- .../security/oauth2/jwt/JoseHeader.java | 112 +++++++++++------- .../security/oauth2/jwt/JoseHeaderTests.java | 14 +-- .../oauth2/jwt/NimbusJweEncoderTests.java | 6 +- .../oauth2/jwt/NimbusJwtEncoderTests.java | 18 +-- .../security/oauth2/jwt/TestJoseHeaders.java | 2 +- 6 files changed, 91 insertions(+), 63 deletions(-) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java index 330be84561e..204dc07e161 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java @@ -127,7 +127,7 @@ public MultiValueMap convert(T authorizationGrantRequest) { throw new OAuth2AuthorizationException(oauth2Error); } - JoseHeader.Builder headersBuilder = JoseHeader.withAlgorithm(jwsAlgorithm); + JoseHeader.Builder headersBuilder = JoseHeader.with(jwsAlgorithm); Instant issuedAt = Instant.now(); Instant expiresAt = issuedAt.plus(Duration.ofSeconds(60)); diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JoseHeader.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JoseHeader.java index 58c0881e0b9..1f20fbdb496 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JoseHeader.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JoseHeader.java @@ -44,11 +44,12 @@ * @see JWE JOSE * Header */ -public final class JoseHeader { +public class JoseHeader { private final Map headers; - private JoseHeader(Map headers) { + protected JoseHeader(Map headers) { + Assert.notEmpty(headers, "headers cannot be empty"); this.headers = Collections.unmodifiableMap(new HashMap<>(headers)); } @@ -183,7 +184,7 @@ public T getHeader(String name) { * @param jwaAlgorithm the {@link JwaAlgorithm} * @return the {@link Builder} */ - public static Builder withAlgorithm(JwaAlgorithm jwaAlgorithm) { + public static Builder with(JwaAlgorithm jwaAlgorithm) { return new Builder(jwaAlgorithm); } @@ -199,27 +200,58 @@ public static Builder from(JoseHeader headers) { /** * A builder for {@link JoseHeader}. */ - public static final class Builder { - - private final Map headers = new HashMap<>(); + public static final class Builder extends AbstractBuilder { private Builder(JwaAlgorithm jwaAlgorithm) { + Assert.notNull(jwaAlgorithm, "jwaAlgorithm cannot be null"); algorithm(jwaAlgorithm); } private Builder(JoseHeader headers) { Assert.notNull(headers, "headers cannot be null"); - this.headers.putAll(headers.getHeaders()); + Consumer> headersConsumer = (h) -> h.putAll(headers.getHeaders()); + headers(headersConsumer); + } + + /** + * Builds a new {@link JoseHeader}. + * @return a {@link JoseHeader} + */ + @Override + public JoseHeader build() { + validate(); + return new JoseHeader(getHeaders()); + } + + } + + /** + * A builder for {@link JoseHeader} and subclasses. + */ + protected abstract static class AbstractBuilder> { + + private final Map headers = new HashMap<>(); + + protected AbstractBuilder() { + } + + protected Map getHeaders() { + return this.headers; + } + + @SuppressWarnings("unchecked") + protected final B getThis() { + return (B) this; // avoid unchecked casts in subclasses by using "getThis()" + // instead of "(B) this" } /** * Sets the {@link JwaAlgorithm JWA algorithm} used to digitally sign the JWS or * encrypt the JWE. * @param jwaAlgorithm the {@link JwaAlgorithm} - * @return the {@link Builder} + * @return the {@link AbstractBuilder} */ - public Builder algorithm(JwaAlgorithm jwaAlgorithm) { - Assert.notNull(jwaAlgorithm, "jwaAlgorithm cannot be null"); + public B algorithm(JwaAlgorithm jwaAlgorithm) { return header(JoseHeaderNames.ALG, jwaAlgorithm); } @@ -228,9 +260,9 @@ public Builder algorithm(JwaAlgorithm jwaAlgorithm) { * public keys, one of which corresponds to the key used to digitally sign the JWS * or encrypt the JWE. * @param jwkSetUrl the JWK Set URL - * @return the {@link Builder} + * @return the {@link AbstractBuilder} */ - public Builder jwkSetUrl(String jwkSetUrl) { + public B jwkSetUrl(String jwkSetUrl) { return header(JoseHeaderNames.JKU, convertAsURL(JoseHeaderNames.JKU, jwkSetUrl)); } @@ -238,9 +270,9 @@ public Builder jwkSetUrl(String jwkSetUrl) { * Sets the JSON Web Key which is the public key that corresponds to the key used * to digitally sign the JWS or encrypt the JWE. * @param jwk the JSON Web Key - * @return the {@link Builder} + * @return the {@link AbstractBuilder} */ - public Builder jwk(Map jwk) { + public B jwk(Map jwk) { return header(JoseHeaderNames.JWK, jwk); } @@ -248,9 +280,9 @@ public Builder jwk(Map jwk) { * Sets the key ID that is a hint indicating which key was used to secure the JWS * or JWE. * @param keyId the key ID - * @return the {@link Builder} + * @return the {@link AbstractBuilder} */ - public Builder keyId(String keyId) { + public B keyId(String keyId) { return header(JoseHeaderNames.KID, keyId); } @@ -259,9 +291,9 @@ public Builder keyId(String keyId) { * certificate or certificate chain corresponding to the key used to digitally * sign the JWS or encrypt the JWE. * @param x509Url the X.509 URL - * @return the {@link Builder} + * @return the {@link AbstractBuilder} */ - public Builder x509Url(String x509Url) { + public B x509Url(String x509Url) { return header(JoseHeaderNames.X5U, convertAsURL(JoseHeaderNames.X5U, x509Url)); } @@ -272,9 +304,9 @@ public Builder x509Url(String x509Url) { * {@code List} of certificate value {@code String}s. Each {@code String} in the * {@code List} is a Base64-encoded DER PKIX certificate value. * @param x509CertificateChain the X.509 certificate chain - * @return the {@link Builder} + * @return the {@link AbstractBuilder} */ - public Builder x509CertificateChain(List x509CertificateChain) { + public B x509CertificateChain(List x509CertificateChain) { return header(JoseHeaderNames.X5C, x509CertificateChain); } @@ -283,9 +315,9 @@ public Builder x509CertificateChain(List x509CertificateChain) { * thumbprint (a.k.a. digest) of the DER encoding of the X.509 certificate * corresponding to the key used to digitally sign the JWS or encrypt the JWE. * @param x509SHA1Thumbprint the X.509 certificate SHA-1 thumbprint - * @return the {@link Builder} + * @return the {@link AbstractBuilder} */ - public Builder x509SHA1Thumbprint(String x509SHA1Thumbprint) { + public B x509SHA1Thumbprint(String x509SHA1Thumbprint) { return header(JoseHeaderNames.X5T, x509SHA1Thumbprint); } @@ -294,18 +326,18 @@ public Builder x509SHA1Thumbprint(String x509SHA1Thumbprint) { * SHA-256 thumbprint (a.k.a. digest) of the DER encoding of the X.509 certificate * corresponding to the key used to digitally sign the JWS or encrypt the JWE. * @param x509SHA256Thumbprint the X.509 certificate SHA-256 thumbprint - * @return the {@link Builder} + * @return the {@link AbstractBuilder} */ - public Builder x509SHA256Thumbprint(String x509SHA256Thumbprint) { + public B x509SHA256Thumbprint(String x509SHA256Thumbprint) { return header(JoseHeaderNames.X5T_S256, x509SHA256Thumbprint); } /** * Sets the type header that declares the media type of the JWS/JWE. * @param type the type header - * @return the {@link Builder} + * @return the {@link AbstractBuilder} */ - public Builder type(String type) { + public B type(String type) { return header(JoseHeaderNames.TYP, type); } @@ -313,9 +345,9 @@ public Builder type(String type) { * Sets the content type header that declares the media type of the secured * content (the payload). * @param contentType the content type header - * @return the {@link Builder} + * @return the {@link AbstractBuilder} */ - public Builder contentType(String contentType) { + public B contentType(String contentType) { return header(JoseHeaderNames.CTY, contentType); } @@ -323,9 +355,9 @@ public Builder contentType(String contentType) { * Sets the critical headers that indicates which extensions to the JWS/JWE/JWA * specifications are being used that MUST be understood and processed. * @param headerNames the critical header names - * @return the {@link Builder} + * @return the {@link AbstractBuilder} */ - public Builder critical(Set headerNames) { + public B critical(Set headerNames) { return header(JoseHeaderNames.CRIT, headerNames); } @@ -333,38 +365,34 @@ public Builder critical(Set headerNames) { * Sets the header. * @param name the header name * @param value the header value - * @return the {@link Builder} + * @return the {@link AbstractBuilder} */ - public Builder header(String name, Object value) { + public B header(String name, Object value) { Assert.hasText(name, "name cannot be empty"); Assert.notNull(value, "value cannot be null"); this.headers.put(name, value); - return this; + return getThis(); } /** * A {@code Consumer} to be provided access to the headers allowing the ability to * add, replace, or remove. * @param headersConsumer a {@code Consumer} of the headers - * @return the {@link Builder} + * @return the {@link AbstractBuilder} */ - public Builder headers(Consumer> headersConsumer) { + public B headers(Consumer> headersConsumer) { headersConsumer.accept(this.headers); - return this; + return getThis(); } /** * Builds a new {@link JoseHeader}. * @return a {@link JoseHeader} */ - public JoseHeader build() { - Assert.notEmpty(this.headers, "headers cannot be empty"); - validateCriticalHeaders(); - return new JoseHeader(this.headers); - } + public abstract T build(); @SuppressWarnings("unchecked") - private void validateCriticalHeaders() { + protected void validate() { Set criticalHeaderNames = (Set) this.headers.get(JoseHeaderNames.CRIT); if (criticalHeaderNames == null) { return; diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JoseHeaderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JoseHeaderTests.java index 78fcc6e6255..89701ad508a 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JoseHeaderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JoseHeaderTests.java @@ -34,8 +34,8 @@ public class JoseHeaderTests { @Test - public void withAlgorithmWhenNullThenThrowIllegalArgumentException() { - assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> JoseHeader.withAlgorithm(null)) + public void withWhenNullThenThrowIllegalArgumentException() { + assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> JoseHeader.with(null)) .withMessage("jwaAlgorithm cannot be null"); } @@ -44,7 +44,7 @@ public void buildWhenAllHeadersProvidedThenAllHeadersAreSet() { JoseHeader expectedJoseHeader = TestJoseHeaders.joseHeader().build(); // @formatter:off - JoseHeader joseHeader = JoseHeader.withAlgorithm(expectedJoseHeader.getAlgorithm()) + JoseHeader joseHeader = JoseHeader.with(expectedJoseHeader.getAlgorithm()) .jwkSetUrl(expectedJoseHeader.getJwkSetUrl().toExternalForm()) .jwk(expectedJoseHeader.getJwk()) .keyId(expectedJoseHeader.getKeyId()) @@ -74,8 +74,8 @@ public void buildWhenAllHeadersProvidedThenAllHeadersAreSet() { @Test public void buildWhenMissingCriticalHeaderThenThrowIllegalStateException() { - assertThatExceptionOfType(IllegalStateException.class).isThrownBy(() -> - TestJoseHeaders.joseHeader().critical(Collections.singleton("critical-header-name")).build()) + assertThatExceptionOfType(IllegalStateException.class).isThrownBy( + () -> TestJoseHeaders.joseHeader().critical(Collections.singleton("critical-header-name")).build()) .withMessage("Missing critical (crit) header 'critical-header-name'."); } @@ -95,14 +95,14 @@ public void fromWhenHeadersProvidedThenCopied() { @Test public void headerWhenNameNullThenThrowIllegalArgumentException() { assertThatExceptionOfType(IllegalArgumentException.class) - .isThrownBy(() -> JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).header(null, "value")) + .isThrownBy(() -> JoseHeader.with(SignatureAlgorithm.RS256).header(null, "value")) .withMessage("name cannot be empty"); } @Test public void headerWhenValueNullThenThrowIllegalArgumentException() { assertThatExceptionOfType(IllegalArgumentException.class) - .isThrownBy(() -> JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).header("name", null)) + .isThrownBy(() -> JoseHeader.with(SignatureAlgorithm.RS256).header("name", null)) .withMessage("value cannot be null"); } diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java index 0c7643140cf..188808acc90 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java @@ -89,7 +89,7 @@ public void encodeWhenJwtClaimsSetThenEncodes() { this.jwkList.add(rsaJwk); // @formatter:off - JoseHeader jweHeader = JoseHeader.withAlgorithm(JweAlgorithm.RSA_OAEP_256) + JoseHeader jweHeader = JoseHeader.with(JweAlgorithm.RSA_OAEP_256) .header("enc", EncryptionMethod.A256GCM.getName()) .build(); // @formatter:on @@ -130,13 +130,13 @@ public void encodeWhenNestedJwsThenEncodes() { RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK; this.jwkList.add(rsaJwk); - JoseHeader jwsHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build(); + JoseHeader jwsHeader = JoseHeader.with(SignatureAlgorithm.RS256).build(); JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); Jwt encodedJws = this.jwsEncoder.encode(JwtEncoderParameters.with(jwsHeader, jwtClaimsSet)); // @formatter:off - JoseHeader jweHeader = JoseHeader.withAlgorithm(JweAlgorithm.RSA_OAEP_256) + JoseHeader jweHeader = JoseHeader.with(JweAlgorithm.RSA_OAEP_256) .header("enc", EncryptionMethod.A256GCM.getName()) .contentType("JWT") // Indicates Nested JWT (REQUIRED) .build(); diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java index 5ee04b3e3c1..0dbc4925ad3 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java @@ -95,7 +95,7 @@ public void encodeWhenHeadersNullThenThrowIllegalArgumentException() { @Test public void encodeWhenClaimsNullThenThrowIllegalArgumentException() { - JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build(); + JoseHeader joseHeader = JoseHeader.with(SignatureAlgorithm.RS256).build(); assertThatIllegalArgumentException() .isThrownBy(() -> this.jwsEncoder.encode(JwtEncoderParameters.with(joseHeader, null))) @@ -108,7 +108,7 @@ public void encodeWhenJwkSelectFailedThenThrowJwtEncodingException() throws Exce this.jwsEncoder = new NimbusJwtEncoder(this.jwkSource); given(this.jwkSource.get(any(), any())).willThrow(new KeySourceException("key source error")); - JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build(); + JoseHeader joseHeader = JoseHeader.with(SignatureAlgorithm.RS256).build(); JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); assertThatExceptionOfType(JwtEncodingException.class) @@ -122,7 +122,7 @@ public void encodeWhenJwkMultipleSelectedThenThrowJwtEncodingException() throws this.jwkList.add(rsaJwk); this.jwkList.add(rsaJwk); - JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build(); + JoseHeader joseHeader = JoseHeader.with(SignatureAlgorithm.RS256).build(); JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); assertThatExceptionOfType(JwtEncodingException.class) @@ -132,7 +132,7 @@ public void encodeWhenJwkMultipleSelectedThenThrowJwtEncodingException() throws @Test public void encodeWhenJwkSelectEmptyThenThrowJwtEncodingException() { - JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build(); + JoseHeader joseHeader = JoseHeader.with(SignatureAlgorithm.RS256).build(); JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); assertThatExceptionOfType(JwtEncodingException.class) @@ -153,7 +153,7 @@ public void encodeWhenJwkSelectWithProvidedKidThenSelected() { this.jwkList.add(rsaJwk2); // @formatter:on - JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).keyId(rsaJwk2.getKeyID()).build(); + JoseHeader joseHeader = JoseHeader.with(SignatureAlgorithm.RS256).keyId(rsaJwk2.getKeyID()).build(); JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); Jwt encodedJws = this.jwsEncoder.encode(JwtEncoderParameters.with(joseHeader, jwtClaimsSet)); @@ -176,7 +176,7 @@ public void encodeWhenJwkSelectWithProvidedX5TS256ThenSelected() { this.jwkList.add(rsaJwk2); // @formatter:on - JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256) + JoseHeader joseHeader = JoseHeader.with(SignatureAlgorithm.RS256) .x509SHA256Thumbprint(rsaJwk1.getX509CertSHA256Thumbprint().toString()).build(); JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); @@ -199,7 +199,7 @@ public void encodeWhenJwkUseEncryptionThenThrowJwtEncodingException() throws Exc this.jwsEncoder = new NimbusJwtEncoder(this.jwkSource); given(this.jwkSource.get(any(), any())).willReturn(Collections.singletonList(rsaJwk)); - JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build(); + JoseHeader joseHeader = JoseHeader.with(SignatureAlgorithm.RS256).build(); JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); assertThatExceptionOfType(JwtEncodingException.class) @@ -218,7 +218,7 @@ public void encodeWhenSuccessThenDecodes() throws Exception { this.jwkList.add(rsaJwk); // @formatter:on - JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build(); + JoseHeader joseHeader = JoseHeader.with(SignatureAlgorithm.RS256).build(); JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); Jwt encodedJws = this.jwsEncoder.encode(JwtEncoderParameters.with(joseHeader, jwtClaimsSet)); @@ -263,7 +263,7 @@ public List get(JWKSelector jwkSelector, SecurityContext context) { JwkListResultCaptor jwkListResultCaptor = new JwkListResultCaptor(); willAnswer(jwkListResultCaptor).given(jwkSourceDelegate).get(any(), any()); - JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build(); + JoseHeader joseHeader = JoseHeader.with(SignatureAlgorithm.RS256).build(); JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); Jwt encodedJws = jwsEncoder.encode(JwtEncoderParameters.with(joseHeader, jwtClaimsSet)); diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/TestJoseHeaders.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/TestJoseHeaders.java index 5d30f58b778..12dce7f53c1 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/TestJoseHeaders.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/TestJoseHeaders.java @@ -36,7 +36,7 @@ public static JoseHeader.Builder joseHeader() { public static JoseHeader.Builder joseHeader(SignatureAlgorithm signatureAlgorithm) { // @formatter:off - return JoseHeader.withAlgorithm(signatureAlgorithm) + return JoseHeader.with(signatureAlgorithm) .jwkSetUrl("https://provider.com/oauth2/jwks") .jwk(rsaJwk()) .keyId("keyId") From 19a79d6b28e34883e9001fa0236525d7fc3befc9 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Fri, 17 Sep 2021 09:28:59 -0400 Subject: [PATCH 08/20] Introduce JwsHeader --- ...ientAuthenticationParametersConverter.java | 8 +- .../security/oauth2/jwt/JwsHeader.java | 94 +++++++++++++++++++ .../security/oauth2/jwt/JwtEncoder.java | 2 - .../oauth2/jwt/JwtEncoderParameters.java | 28 +++--- .../security/oauth2/jwt/NimbusJwtEncoder.java | 14 +-- .../security/oauth2/jwt/JoseHeaderTests.java | 8 +- .../oauth2/jwt/NimbusJwtEncoderTests.java | 56 +++++------ .../security/oauth2/jwt/TestJoseHeaders.java | 8 +- 8 files changed, 156 insertions(+), 62 deletions(-) create mode 100644 oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwsHeader.java diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java index 204dc07e161..ef333201c18 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java @@ -40,7 +40,7 @@ import org.springframework.security.oauth2.jose.jws.JwsAlgorithm; import org.springframework.security.oauth2.jose.jws.MacAlgorithm; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; -import org.springframework.security.oauth2.jwt.JoseHeader; +import org.springframework.security.oauth2.jwt.JwsHeader; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.JwtClaimsSet; import org.springframework.security.oauth2.jwt.JwtEncoder; @@ -127,7 +127,7 @@ public MultiValueMap convert(T authorizationGrantRequest) { throw new OAuth2AuthorizationException(oauth2Error); } - JoseHeader.Builder headersBuilder = JoseHeader.with(jwsAlgorithm); + JwsHeader.Builder headersBuilder = JwsHeader.with(jwsAlgorithm); Instant issuedAt = Instant.now(); Instant expiresAt = issuedAt.plus(Duration.ofSeconds(60)); @@ -142,7 +142,7 @@ public MultiValueMap convert(T authorizationGrantRequest) { .expiresAt(expiresAt); // @formatter:on - JoseHeader joseHeader = headersBuilder.build(); + JwsHeader jwsHeader = headersBuilder.build(); JwtClaimsSet jwtClaimsSet = claimsBuilder.build(); JwsEncoderHolder jwsEncoderHolder = this.jwsEncoders.compute(clientRegistration.getRegistrationId(), @@ -155,7 +155,7 @@ public MultiValueMap convert(T authorizationGrantRequest) { }); JwtEncoder jwsEncoder = jwsEncoderHolder.getJwsEncoder(); - Jwt jws = jwsEncoder.encode(JwtEncoderParameters.with(joseHeader, jwtClaimsSet)); + Jwt jws = jwsEncoder.encode(JwtEncoderParameters.with(jwsHeader, jwtClaimsSet)); MultiValueMap parameters = new LinkedMultiValueMap<>(); parameters.set(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE, CLIENT_ASSERTION_TYPE_VALUE); diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwsHeader.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwsHeader.java new file mode 100644 index 00000000000..0354253233e --- /dev/null +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwsHeader.java @@ -0,0 +1,94 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.jwt; + +import java.util.Map; +import java.util.function.Consumer; + +import org.springframework.security.oauth2.jose.jws.JwsAlgorithm; +import org.springframework.util.Assert; + +/** + * The JSON Web Signature (JWS) header is a JSON object representing the header parameters + * of a JSON Web Token, that describe the cryptographic operations used to digitally sign + * or create a MAC of the contents of the JWS Protected Header and JWS Payload. + * + * @author Joe Grandja + * @since 5.6 + * @see JoseHeader + * @see JWS JOSE + * Header + */ +public final class JwsHeader extends JoseHeader { + + private JwsHeader(Map headers) { + super(headers); + } + + @SuppressWarnings("unchecked") + @Override + public JwsAlgorithm getAlgorithm() { + return super.getAlgorithm(); + } + + /** + * Returns a new {@link Builder}, initialized with the provided {@link JwsAlgorithm}. + * @param jwsAlgorithm the {@link JwsAlgorithm} + * @return the {@link Builder} + */ + public static Builder with(JwsAlgorithm jwsAlgorithm) { + return new Builder(jwsAlgorithm); + } + + /** + * Returns a new {@link Builder}, initialized with the provided {@code headers}. + * @param headers the headers + * @return the {@link Builder} + */ + public static Builder from(JwsHeader headers) { + return new Builder(headers); + } + + /** + * A builder for {@link JwsHeader}. + */ + public static final class Builder extends AbstractBuilder { + + private Builder(JwsAlgorithm jwsAlgorithm) { + Assert.notNull(jwsAlgorithm, "jwsAlgorithm cannot be null"); + algorithm(jwsAlgorithm); + } + + private Builder(JwsHeader headers) { + Assert.notNull(headers, "headers cannot be null"); + Consumer> headersConsumer = (h) -> h.putAll(headers.getHeaders()); + headers(headersConsumer); + } + + /** + * Builds a new {@link JwsHeader}. + * @return a {@link JwsHeader} + */ + @Override + public JwsHeader build() { + validate(); + return new JwsHeader(getHeaders()); + } + + } + +} diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncoder.java index 14e426ddd59..4799fbe50c0 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncoder.java @@ -31,8 +31,6 @@ * @since 5.6 * @see Jwt * @see JwtEncoderParameters - * @see JoseHeader - * @see JwtClaimsSet * @see JwtDecoder * @see JSON Web Token * (JWT) diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncoderParameters.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncoderParameters.java index 0ecdc7200fd..2c14e6012b6 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncoderParameters.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncoderParameters.java @@ -19,42 +19,44 @@ import org.springframework.util.Assert; /** - * A holder of parameters containing the JOSE header and JWT Claims Set. + * A holder of parameters containing the JWS headers and JWT Claims Set. * * @author Joe Grandja * @since 5.6 + * @see JwsHeader + * @see JwtClaimsSet * @see JwtEncoder */ public final class JwtEncoderParameters { - private final JoseHeader headers; + private final JwsHeader jwsHeader; private final JwtClaimsSet claims; - private JwtEncoderParameters(JoseHeader headers, JwtClaimsSet claims) { - Assert.notNull(headers, "headers cannot be null"); + private JwtEncoderParameters(JwsHeader jwsHeader, JwtClaimsSet claims) { + Assert.notNull(jwsHeader, "jwsHeader cannot be null"); Assert.notNull(claims, "claims cannot be null"); - this.headers = headers; + this.jwsHeader = jwsHeader; this.claims = claims; } /** * Returns a new {@link JwtEncoderParameters}, initialized with the provided - * {@link JoseHeader} and {@link JwtClaimsSet}. - * @param headers the {@link JoseHeader} + * {@link JwsHeader} and {@link JwtClaimsSet}. + * @param jwsHeader the {@link JwsHeader} * @param claims the {@link JwtClaimsSet} * @return the {@link JwtEncoderParameters} */ - public static JwtEncoderParameters with(JoseHeader headers, JwtClaimsSet claims) { - return new JwtEncoderParameters(headers, claims); + public static JwtEncoderParameters with(JwsHeader jwsHeader, JwtClaimsSet claims) { + return new JwtEncoderParameters(jwsHeader, claims); } /** - * Returns the {@link JoseHeader headers}. - * @return the {@link JoseHeader} + * Returns the {@link JwsHeader JWS headers}. + * @return the {@link JwsHeader} */ - public JoseHeader getHeaders() { - return this.headers; + public JwsHeader getJwsHeader() { + return this.jwsHeader; } /** diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java index ed944f719d4..5f2ebbca56d 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java @@ -96,7 +96,7 @@ public NimbusJwtEncoder(JWKSource jwkSource) { public Jwt encode(JwtEncoderParameters parameters) throws JwtEncodingException { Assert.notNull(parameters, "parameters cannot be null"); - JoseHeader headers = parameters.getHeaders(); + JwsHeader headers = parameters.getJwsHeader(); JwtClaimsSet claims = parameters.getClaims(); JWK jwk = selectJwk(headers); @@ -107,7 +107,7 @@ public Jwt encode(JwtEncoderParameters parameters) throws JwtEncodingException { return new Jwt(jws, claims.getIssuedAt(), claims.getExpiresAt(), headers.getHeaders(), claims.getClaims()); } - private JWK selectJwk(JoseHeader headers) { + private JWK selectJwk(JwsHeader headers) { List jwks; try { JWKSelector jwkSelector = new JWKSelector(createJwkMatcher(headers)); @@ -131,7 +131,7 @@ private JWK selectJwk(JoseHeader headers) { return jwks.get(0); } - private String serialize(JoseHeader headers, JwtClaimsSet claims, JWK jwk) { + private String serialize(JwsHeader headers, JwtClaimsSet claims, JWK jwk) { JWSHeader jwsHeader = convert(headers); JWTClaimsSet jwtClaimsSet = convert(claims); @@ -148,7 +148,7 @@ private String serialize(JoseHeader headers, JwtClaimsSet claims, JWK jwk) { return signedJwt.serialize(); } - private static JWKMatcher createJwkMatcher(JoseHeader headers) { + private static JWKMatcher createJwkMatcher(JwsHeader headers) { JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(headers.getAlgorithm().getName()); if (JWSAlgorithm.Family.RSA.contains(jwsAlgorithm) || JWSAlgorithm.Family.EC.contains(jwsAlgorithm)) { @@ -176,7 +176,7 @@ else if (JWSAlgorithm.Family.HMAC_SHA.contains(jwsAlgorithm)) { return null; } - private static JoseHeader addKeyIdentifierHeadersIfNecessary(JoseHeader headers, JWK jwk) { + private static JwsHeader addKeyIdentifierHeadersIfNecessary(JwsHeader headers, JWK jwk) { // Check if headers have already been added if (StringUtils.hasText(headers.getKeyId()) && StringUtils.hasText(headers.getX509SHA256Thumbprint())) { return headers; @@ -186,7 +186,7 @@ private static JoseHeader addKeyIdentifierHeadersIfNecessary(JoseHeader headers, return headers; } - JoseHeader.Builder headersBuilder = JoseHeader.from(headers); + JwsHeader.Builder headersBuilder = JwsHeader.from(headers); if (!StringUtils.hasText(headers.getKeyId()) && StringUtils.hasText(jwk.getKeyID())) { headersBuilder.keyId(jwk.getKeyID()); } @@ -207,7 +207,7 @@ private static JWSSigner createSigner(JWK jwk) { } } - private static JWSHeader convert(JoseHeader headers) { + private static JWSHeader convert(JwsHeader headers) { JWSHeader.Builder builder = new JWSHeader.Builder(JWSAlgorithm.parse(headers.getAlgorithm().getName())); if (headers.getJwkSetUrl() != null) { diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JoseHeaderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JoseHeaderTests.java index 89701ad508a..570e513b9e9 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JoseHeaderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JoseHeaderTests.java @@ -41,7 +41,7 @@ public void withWhenNullThenThrowIllegalArgumentException() { @Test public void buildWhenAllHeadersProvidedThenAllHeadersAreSet() { - JoseHeader expectedJoseHeader = TestJoseHeaders.joseHeader().build(); + JoseHeader expectedJoseHeader = TestJoseHeaders.jwsHeader().build(); // @formatter:off JoseHeader joseHeader = JoseHeader.with(expectedJoseHeader.getAlgorithm()) @@ -75,7 +75,7 @@ public void buildWhenAllHeadersProvidedThenAllHeadersAreSet() { @Test public void buildWhenMissingCriticalHeaderThenThrowIllegalStateException() { assertThatExceptionOfType(IllegalStateException.class).isThrownBy( - () -> TestJoseHeaders.joseHeader().critical(Collections.singleton("critical-header-name")).build()) + () -> TestJoseHeaders.jwsHeader().critical(Collections.singleton("critical-header-name")).build()) .withMessage("Missing critical (crit) header 'critical-header-name'."); } @@ -87,7 +87,7 @@ public void fromWhenNullThenThrowIllegalArgumentException() { @Test public void fromWhenHeadersProvidedThenCopied() { - JoseHeader expectedJoseHeader = TestJoseHeaders.joseHeader().build(); + JoseHeader expectedJoseHeader = TestJoseHeaders.jwsHeader().build(); JoseHeader joseHeader = JoseHeader.from(expectedJoseHeader).build(); assertThat(joseHeader.getHeaders()).isEqualTo(expectedJoseHeader.getHeaders()); } @@ -108,7 +108,7 @@ public void headerWhenValueNullThenThrowIllegalArgumentException() { @Test public void getHeaderWhenNullThenThrowIllegalArgumentException() { - JoseHeader joseHeader = TestJoseHeaders.joseHeader().build(); + JoseHeader joseHeader = TestJoseHeaders.jwsHeader().build(); assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> joseHeader.getHeader(null)) .withMessage("name cannot be empty"); diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java index 0dbc4925ad3..13e8f001736 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java @@ -63,13 +63,13 @@ public class NimbusJwtEncoderTests { private JWKSource jwkSource; - private NimbusJwtEncoder jwsEncoder; + private NimbusJwtEncoder jwtEncoder; @BeforeEach public void setUp() { this.jwkList = new ArrayList<>(); this.jwkSource = (jwkSelector, securityContext) -> jwkSelector.select(new JWKSet(this.jwkList)); - this.jwsEncoder = new NimbusJwtEncoder(this.jwkSource); + this.jwtEncoder = new NimbusJwtEncoder(this.jwkSource); } @Test @@ -80,7 +80,7 @@ public void constructorWhenJwkSourceNullThenThrowIllegalArgumentException() { @Test public void encodeWhenParametersNullThenThrowIllegalArgumentException() { - assertThatIllegalArgumentException().isThrownBy(() -> this.jwsEncoder.encode(null)) + assertThatIllegalArgumentException().isThrownBy(() -> this.jwtEncoder.encode(null)) .withMessage("parameters cannot be null"); } @@ -89,30 +89,30 @@ public void encodeWhenHeadersNullThenThrowIllegalArgumentException() { JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); assertThatIllegalArgumentException() - .isThrownBy(() -> this.jwsEncoder.encode(JwtEncoderParameters.with(null, jwtClaimsSet))) - .withMessage("headers cannot be null"); + .isThrownBy(() -> this.jwtEncoder.encode(JwtEncoderParameters.with(null, jwtClaimsSet))) + .withMessage("jwsHeader cannot be null"); } @Test public void encodeWhenClaimsNullThenThrowIllegalArgumentException() { - JoseHeader joseHeader = JoseHeader.with(SignatureAlgorithm.RS256).build(); + JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256).build(); assertThatIllegalArgumentException() - .isThrownBy(() -> this.jwsEncoder.encode(JwtEncoderParameters.with(joseHeader, null))) + .isThrownBy(() -> this.jwtEncoder.encode(JwtEncoderParameters.with(jwsHeader, null))) .withMessage("claims cannot be null"); } @Test public void encodeWhenJwkSelectFailedThenThrowJwtEncodingException() throws Exception { this.jwkSource = mock(JWKSource.class); - this.jwsEncoder = new NimbusJwtEncoder(this.jwkSource); + this.jwtEncoder = new NimbusJwtEncoder(this.jwkSource); given(this.jwkSource.get(any(), any())).willThrow(new KeySourceException("key source error")); - JoseHeader joseHeader = JoseHeader.with(SignatureAlgorithm.RS256).build(); + JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256).build(); JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); assertThatExceptionOfType(JwtEncodingException.class) - .isThrownBy(() -> this.jwsEncoder.encode(JwtEncoderParameters.with(joseHeader, jwtClaimsSet))) + .isThrownBy(() -> this.jwtEncoder.encode(JwtEncoderParameters.with(jwsHeader, jwtClaimsSet))) .withMessageContaining("Failed to select a JWK signing key -> key source error"); } @@ -122,21 +122,21 @@ public void encodeWhenJwkMultipleSelectedThenThrowJwtEncodingException() throws this.jwkList.add(rsaJwk); this.jwkList.add(rsaJwk); - JoseHeader joseHeader = JoseHeader.with(SignatureAlgorithm.RS256).build(); + JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256).build(); JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); assertThatExceptionOfType(JwtEncodingException.class) - .isThrownBy(() -> this.jwsEncoder.encode(JwtEncoderParameters.with(joseHeader, jwtClaimsSet))) + .isThrownBy(() -> this.jwtEncoder.encode(JwtEncoderParameters.with(jwsHeader, jwtClaimsSet))) .withMessageContaining("Found multiple JWK signing keys for algorithm 'RS256'"); } @Test public void encodeWhenJwkSelectEmptyThenThrowJwtEncodingException() { - JoseHeader joseHeader = JoseHeader.with(SignatureAlgorithm.RS256).build(); + JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256).build(); JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); assertThatExceptionOfType(JwtEncodingException.class) - .isThrownBy(() -> this.jwsEncoder.encode(JwtEncoderParameters.with(joseHeader, jwtClaimsSet))) + .isThrownBy(() -> this.jwtEncoder.encode(JwtEncoderParameters.with(jwsHeader, jwtClaimsSet))) .withMessageContaining("Failed to select a JWK signing key"); } @@ -153,10 +153,10 @@ public void encodeWhenJwkSelectWithProvidedKidThenSelected() { this.jwkList.add(rsaJwk2); // @formatter:on - JoseHeader joseHeader = JoseHeader.with(SignatureAlgorithm.RS256).keyId(rsaJwk2.getKeyID()).build(); + JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256).keyId(rsaJwk2.getKeyID()).build(); JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); - Jwt encodedJws = this.jwsEncoder.encode(JwtEncoderParameters.with(joseHeader, jwtClaimsSet)); + Jwt encodedJws = this.jwtEncoder.encode(JwtEncoderParameters.with(jwsHeader, jwtClaimsSet)); assertThat(encodedJws.getHeaders().get(JoseHeaderNames.KID)).isEqualTo(rsaJwk2.getKeyID()); } @@ -176,11 +176,11 @@ public void encodeWhenJwkSelectWithProvidedX5TS256ThenSelected() { this.jwkList.add(rsaJwk2); // @formatter:on - JoseHeader joseHeader = JoseHeader.with(SignatureAlgorithm.RS256) + JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256) .x509SHA256Thumbprint(rsaJwk1.getX509CertSHA256Thumbprint().toString()).build(); JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); - Jwt encodedJws = this.jwsEncoder.encode(JwtEncoderParameters.with(joseHeader, jwtClaimsSet)); + Jwt encodedJws = this.jwtEncoder.encode(JwtEncoderParameters.with(jwsHeader, jwtClaimsSet)); assertThat(encodedJws.getHeaders().get(JoseHeaderNames.X5T_S256)) .isEqualTo(rsaJwk1.getX509CertSHA256Thumbprint().toString()); @@ -196,14 +196,14 @@ public void encodeWhenJwkUseEncryptionThenThrowJwtEncodingException() throws Exc // @formatter:on this.jwkSource = mock(JWKSource.class); - this.jwsEncoder = new NimbusJwtEncoder(this.jwkSource); + this.jwtEncoder = new NimbusJwtEncoder(this.jwkSource); given(this.jwkSource.get(any(), any())).willReturn(Collections.singletonList(rsaJwk)); - JoseHeader joseHeader = JoseHeader.with(SignatureAlgorithm.RS256).build(); + JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256).build(); JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); assertThatExceptionOfType(JwtEncodingException.class) - .isThrownBy(() -> this.jwsEncoder.encode(JwtEncoderParameters.with(joseHeader, jwtClaimsSet))) + .isThrownBy(() -> this.jwtEncoder.encode(JwtEncoderParameters.with(jwsHeader, jwtClaimsSet))) .withMessageContaining( "Failed to create a JWS Signer -> The JWK use must be sig (signature) or unspecified"); } @@ -218,12 +218,12 @@ public void encodeWhenSuccessThenDecodes() throws Exception { this.jwkList.add(rsaJwk); // @formatter:on - JoseHeader joseHeader = JoseHeader.with(SignatureAlgorithm.RS256).build(); + JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256).build(); JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); - Jwt encodedJws = this.jwsEncoder.encode(JwtEncoderParameters.with(joseHeader, jwtClaimsSet)); + Jwt encodedJws = this.jwtEncoder.encode(JwtEncoderParameters.with(jwsHeader, jwtClaimsSet)); - assertThat(encodedJws.getHeaders().get(JoseHeaderNames.ALG)).isEqualTo(joseHeader.getAlgorithm()); + assertThat(encodedJws.getHeaders().get(JoseHeaderNames.ALG)).isEqualTo(jwsHeader.getAlgorithm()); assertThat(encodedJws.getHeaders().get(JoseHeaderNames.JKU)).isNull(); assertThat(encodedJws.getHeaders().get(JoseHeaderNames.JWK)).isNull(); assertThat(encodedJws.getHeaders().get(JoseHeaderNames.KID)).isEqualTo(rsaJwk.getKeyID()); @@ -258,15 +258,15 @@ public List get(JWKSelector jwkSelector, SecurityContext context) { return jwkSource.get(jwkSelector, context); } }); - NimbusJwtEncoder jwsEncoder = new NimbusJwtEncoder(jwkSourceDelegate); + NimbusJwtEncoder jwtEncoder = new NimbusJwtEncoder(jwkSourceDelegate); JwkListResultCaptor jwkListResultCaptor = new JwkListResultCaptor(); willAnswer(jwkListResultCaptor).given(jwkSourceDelegate).get(any(), any()); - JoseHeader joseHeader = JoseHeader.with(SignatureAlgorithm.RS256).build(); + JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256).build(); JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); - Jwt encodedJws = jwsEncoder.encode(JwtEncoderParameters.with(joseHeader, jwtClaimsSet)); + Jwt encodedJws = jwtEncoder.encode(JwtEncoderParameters.with(jwsHeader, jwtClaimsSet)); JWK jwk1 = jwkListResultCaptor.getResult().get(0); NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withPublicKey(((RSAKey) jwk1).toRSAPublicKey()).build(); @@ -274,7 +274,7 @@ public List get(JWKSelector jwkSelector, SecurityContext context) { jwkSource.rotate(); // Simulate key rotation - encodedJws = jwsEncoder.encode(JwtEncoderParameters.with(joseHeader, jwtClaimsSet)); + encodedJws = jwtEncoder.encode(JwtEncoderParameters.with(jwsHeader, jwtClaimsSet)); JWK jwk2 = jwkListResultCaptor.getResult().get(0); jwtDecoder = NimbusJwtDecoder.withPublicKey(((RSAKey) jwk2).toRSAPublicKey()).build(); diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/TestJoseHeaders.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/TestJoseHeaders.java index 12dce7f53c1..7ca19be877d 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/TestJoseHeaders.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/TestJoseHeaders.java @@ -30,13 +30,13 @@ public final class TestJoseHeaders { private TestJoseHeaders() { } - public static JoseHeader.Builder joseHeader() { - return joseHeader(SignatureAlgorithm.RS256); + public static JwsHeader.Builder jwsHeader() { + return jwsHeader(SignatureAlgorithm.RS256); } - public static JoseHeader.Builder joseHeader(SignatureAlgorithm signatureAlgorithm) { + public static JwsHeader.Builder jwsHeader(SignatureAlgorithm signatureAlgorithm) { // @formatter:off - return JoseHeader.with(signatureAlgorithm) + return JwsHeader.with(signatureAlgorithm) .jwkSetUrl("https://provider.com/oauth2/jwks") .jwk(rsaJwk()) .keyId("keyId") From 57625d35becf4b6338e39849310324163e6cdf29 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Fri, 17 Sep 2021 12:10:24 -0400 Subject: [PATCH 09/20] Disable JWE tests --- .../security/oauth2/jwt/NimbusJweEncoderTests.java | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java index 188808acc90..7aa4425456b 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java @@ -83,7 +83,7 @@ public void setUp() { this.jwsEncoder = new NimbusJwtEncoder(this.jwkSource); } - @Test + // @Test public void encodeWhenJwtClaimsSetThenEncodes() { RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK; this.jwkList.add(rsaJwk); @@ -95,7 +95,10 @@ public void encodeWhenJwtClaimsSetThenEncodes() { // @formatter:on JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); - Jwt encodedJwe = this.jweEncoder.encode(JwtEncoderParameters.with(jweHeader, jwtClaimsSet)); + // FIXME + // Jwt encodedJwe = this.jweEncoder.encode(JwtEncoderParameters.with(jweHeader, + // jwtClaimsSet)); + Jwt encodedJwe = null; assertThat(encodedJwe.getHeaders().get(JoseHeaderNames.ALG)).isEqualTo(jweHeader.getAlgorithm()); assertThat(encodedJwe.getHeaders().get("enc")).isEqualTo(jweHeader.getHeader("enc")); @@ -130,7 +133,7 @@ public void encodeWhenNestedJwsThenEncodes() { RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK; this.jwkList.add(rsaJwk); - JoseHeader jwsHeader = JoseHeader.with(SignatureAlgorithm.RS256).build(); + JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256).build(); JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); Jwt encodedJws = this.jwsEncoder.encode(JwtEncoderParameters.with(jwsHeader, jwtClaimsSet)); @@ -197,7 +200,7 @@ private NimbusJweEncoder(JWKSource jwkSource) { public Jwt encode(JwtEncoderParameters parameters) throws JwtEncodingException { Assert.notNull(parameters, "parameters cannot be null"); - JoseHeader headers = parameters.getHeaders(); + JwsHeader headers = parameters.getJwsHeader(); JwtClaimsSet claims = parameters.getClaims(); JWTClaimsSet jwtClaimsSet = JWT_CLAIMS_SET_CONVERTER.convert(claims); From 0f4e593e84ce18b7016085e0600446cf4532617a Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Fri, 17 Sep 2021 12:38:52 -0400 Subject: [PATCH 10/20] Allow headers to be optional --- .../oauth2/jwt/JwtEncoderParameters.java | 17 ++++++++++--- .../security/oauth2/jwt/NimbusJwtEncoder.java | 6 +++++ .../oauth2/jwt/NimbusJwtEncoderTests.java | 25 ++++++++++++------- 3 files changed, 36 insertions(+), 12 deletions(-) diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncoderParameters.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncoderParameters.java index 2c14e6012b6..f6504c6718d 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncoderParameters.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncoderParameters.java @@ -16,6 +16,7 @@ package org.springframework.security.oauth2.jwt; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; /** @@ -34,12 +35,21 @@ public final class JwtEncoderParameters { private final JwtClaimsSet claims; private JwtEncoderParameters(JwsHeader jwsHeader, JwtClaimsSet claims) { - Assert.notNull(jwsHeader, "jwsHeader cannot be null"); Assert.notNull(claims, "claims cannot be null"); this.jwsHeader = jwsHeader; this.claims = claims; } + /** + * Returns a new {@link JwtEncoderParameters}, initialized with the provided + * {@link JwtClaimsSet}. + * @param claims the {@link JwtClaimsSet} + * @return the {@link JwtEncoderParameters} + */ + public static JwtEncoderParameters with(JwtClaimsSet claims) { + return with(null, claims); + } + /** * Returns a new {@link JwtEncoderParameters}, initialized with the provided * {@link JwsHeader} and {@link JwtClaimsSet}. @@ -47,14 +57,15 @@ private JwtEncoderParameters(JwsHeader jwsHeader, JwtClaimsSet claims) { * @param claims the {@link JwtClaimsSet} * @return the {@link JwtEncoderParameters} */ - public static JwtEncoderParameters with(JwsHeader jwsHeader, JwtClaimsSet claims) { + public static JwtEncoderParameters with(@Nullable JwsHeader jwsHeader, JwtClaimsSet claims) { return new JwtEncoderParameters(jwsHeader, claims); } /** * Returns the {@link JwsHeader JWS headers}. - * @return the {@link JwsHeader} + * @return the {@link JwsHeader}, or {@code null} if not available */ + @Nullable public JwsHeader getJwsHeader() { return this.jwsHeader; } diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java index 5f2ebbca56d..2de3e64a815 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java @@ -46,6 +46,7 @@ import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.SignedJWT; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; @@ -77,6 +78,8 @@ public final class NimbusJwtEncoder implements JwtEncoder { private static final String ENCODING_ERROR_MESSAGE_TEMPLATE = "An error occurred while attempting to encode the Jwt: %s"; + private static final JwsHeader DEFAULT_JWS_HEADER = JwsHeader.with(SignatureAlgorithm.RS256).build(); + private static final JWSSignerFactory JWS_SIGNER_FACTORY = new DefaultJWSSignerFactory(); private final Map jwsSigners = new ConcurrentHashMap<>(); @@ -97,6 +100,9 @@ public Jwt encode(JwtEncoderParameters parameters) throws JwtEncodingException { Assert.notNull(parameters, "parameters cannot be null"); JwsHeader headers = parameters.getJwsHeader(); + if (headers == null) { + headers = DEFAULT_JWS_HEADER; + } JwtClaimsSet claims = parameters.getClaims(); JWK jwk = selectJwk(headers); diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java index 13e8f001736..079eac5b9c1 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java @@ -84,15 +84,6 @@ public void encodeWhenParametersNullThenThrowIllegalArgumentException() { .withMessage("parameters cannot be null"); } - @Test - public void encodeWhenHeadersNullThenThrowIllegalArgumentException() { - JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); - - assertThatIllegalArgumentException() - .isThrownBy(() -> this.jwtEncoder.encode(JwtEncoderParameters.with(null, jwtClaimsSet))) - .withMessage("jwsHeader cannot be null"); - } - @Test public void encodeWhenClaimsNullThenThrowIllegalArgumentException() { JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256).build(); @@ -140,6 +131,22 @@ public void encodeWhenJwkSelectEmptyThenThrowJwtEncodingException() { .withMessageContaining("Failed to select a JWK signing key"); } + @Test + public void encodeWhenHeadersNotProvidedThenDefaulted() { + // @formatter:off + RSAKey rsaJwk = TestJwks.jwk(TestKeys.DEFAULT_PUBLIC_KEY, TestKeys.DEFAULT_PRIVATE_KEY) + .keyID("rsa-jwk-1") + .build(); + this.jwkList.add(rsaJwk); + // @formatter:on + + JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); + + Jwt encodedJws = this.jwtEncoder.encode(JwtEncoderParameters.with(jwtClaimsSet)); + + assertThat(encodedJws.getHeaders().get(JoseHeaderNames.ALG)).isEqualTo(SignatureAlgorithm.RS256); + } + @Test public void encodeWhenJwkSelectWithProvidedKidThenSelected() { // @formatter:off From 0d68314116d77912730c851d0f44d732b283f5a7 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Fri, 17 Sep 2021 14:53:40 -0400 Subject: [PATCH 11/20] Update JWE tests --- .../oauth2/jwt/NimbusJweEncoderTests.java | 356 ++++++++---------- 1 file changed, 160 insertions(+), 196 deletions(-) diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java index 7aa4425456b..d2ed41363d3 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java @@ -23,8 +23,8 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.Set; +import java.util.function.Consumer; import java.util.stream.Collectors; import com.nimbusds.jose.EncryptionMethod; @@ -33,13 +33,14 @@ import com.nimbusds.jose.JWEAlgorithm; import com.nimbusds.jose.JWEHeader; import com.nimbusds.jose.JWEObject; -import com.nimbusds.jose.KeySourceException; import com.nimbusds.jose.Payload; import com.nimbusds.jose.crypto.RSAEncrypter; import com.nimbusds.jose.jwk.JWK; import com.nimbusds.jose.jwk.JWKMatcher; import com.nimbusds.jose.jwk.JWKSelector; import com.nimbusds.jose.jwk.JWKSet; +import com.nimbusds.jose.jwk.KeyType; +import com.nimbusds.jose.jwk.KeyUse; import com.nimbusds.jose.jwk.RSAKey; import com.nimbusds.jose.jwk.source.JWKSource; import com.nimbusds.jose.proc.SecurityContext; @@ -50,7 +51,6 @@ import org.junit.jupiter.api.Test; import org.springframework.core.convert.converter.Converter; -import org.springframework.security.oauth2.core.AbstractOAuth2Token; import org.springframework.security.oauth2.jose.JwaAlgorithm; import org.springframework.security.oauth2.jose.TestJwks; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; @@ -67,6 +67,11 @@ */ public class NimbusJweEncoderTests { + // @formatter:off + private static final JweHeader DEFAULT_JWE_HEADER = + JweHeader.with(JweAlgorithm.RSA_OAEP_256, EncryptionMethod.A256GCM.getName()).build(); + // @formatter:on + private List jwkList; private JWKSource jwkSource; @@ -79,29 +84,25 @@ public class NimbusJweEncoderTests { public void setUp() { this.jwkList = new ArrayList<>(); this.jwkSource = (jwkSelector, securityContext) -> jwkSelector.select(new JWKSet(this.jwkList)); - this.jweEncoder = new NimbusJweEncoder(this.jwkSource); this.jwsEncoder = new NimbusJwtEncoder(this.jwkSource); + this.jweEncoder = new NimbusJweEncoder(this.jwkSource, this.jwsEncoder); } - // @Test + @Test public void encodeWhenJwtClaimsSetThenEncodes() { RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK; this.jwkList.add(rsaJwk); - // @formatter:off - JoseHeader jweHeader = JoseHeader.with(JweAlgorithm.RSA_OAEP_256) - .header("enc", EncryptionMethod.A256GCM.getName()) - .build(); - // @formatter:on JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); - // FIXME - // Jwt encodedJwe = this.jweEncoder.encode(JwtEncoderParameters.with(jweHeader, - // jwtClaimsSet)); - Jwt encodedJwe = null; + // ********************** + // Assume future API: + // JwtEncoderParameters.with(JweHeader jweHeader, JwtClaimsSet claims) + // ********************** + Jwt encodedJwe = this.jweEncoder.encode(JwtEncoderParameters.with(jwtClaimsSet)); - assertThat(encodedJwe.getHeaders().get(JoseHeaderNames.ALG)).isEqualTo(jweHeader.getAlgorithm()); - assertThat(encodedJwe.getHeaders().get("enc")).isEqualTo(jweHeader.getHeader("enc")); + assertThat(encodedJwe.getHeaders().get(JoseHeaderNames.ALG)).isEqualTo(DEFAULT_JWE_HEADER.getAlgorithm()); + assertThat(encodedJwe.getHeaders().get("enc")).isEqualTo(DEFAULT_JWE_HEADER.getHeader("enc")); assertThat(encodedJwe.getHeaders().get(JoseHeaderNames.JKU)).isNull(); assertThat(encodedJwe.getHeaders().get(JoseHeaderNames.JWK)).isNull(); assertThat(encodedJwe.getHeaders().get(JoseHeaderNames.KID)).isEqualTo(rsaJwk.getKeyID()); @@ -136,30 +137,35 @@ public void encodeWhenNestedJwsThenEncodes() { JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256).build(); JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); - Jwt encodedJws = this.jwsEncoder.encode(JwtEncoderParameters.with(jwsHeader, jwtClaimsSet)); - - // @formatter:off - JoseHeader jweHeader = JoseHeader.with(JweAlgorithm.RSA_OAEP_256) - .header("enc", EncryptionMethod.A256GCM.getName()) - .contentType("JWT") // Indicates Nested JWT (REQUIRED) - .build(); - // @formatter:on - - JoseToken encodedJweNestedJws = this.jweEncoder.encode(jweHeader, - new JosePayload<>(encodedJws.getTokenValue())); - - assertThat(encodedJweNestedJws.getHeaders().getAlgorithm()).isEqualTo(jweHeader.getAlgorithm()); - assertThat(encodedJweNestedJws.getHeaders().getHeader("enc")).isEqualTo(jweHeader.getHeader("enc")); - assertThat(encodedJweNestedJws.getHeaders().getJwkSetUrl()).isNull(); - assertThat(encodedJweNestedJws.getHeaders().getJwk()).isNull(); - assertThat(encodedJweNestedJws.getHeaders().getKeyId()).isEqualTo(rsaJwk.getKeyID()); - assertThat(encodedJweNestedJws.getHeaders().getX509Url()).isNull(); - assertThat(encodedJweNestedJws.getHeaders().getX509CertificateChain()).isNull(); - assertThat(encodedJweNestedJws.getHeaders().getX509SHA1Thumbprint()).isNull(); - assertThat(encodedJweNestedJws.getHeaders().getX509SHA256Thumbprint()).isNull(); - assertThat(encodedJweNestedJws.getHeaders().getType()).isNull(); - assertThat(encodedJweNestedJws.getHeaders().getContentType()).isEqualTo("JWT"); - assertThat(encodedJweNestedJws.getHeaders().getCritical()).isNull(); + // ********************** + // Assume future API: + // JwtEncoderParameters.with(JweHeader jweHeader, JwsHeader jwsHeader, + // JwtClaimsSet claims) + // ********************** + Jwt encodedJweNestedJws = this.jweEncoder.encode(JwtEncoderParameters.with(jwsHeader, jwtClaimsSet)); + + assertThat(encodedJweNestedJws.getHeaders().get(JoseHeaderNames.ALG)) + .isEqualTo(DEFAULT_JWE_HEADER.getAlgorithm()); + assertThat(encodedJweNestedJws.getHeaders().get("enc")).isEqualTo(DEFAULT_JWE_HEADER.getHeader("enc")); + assertThat(encodedJweNestedJws.getHeaders().get(JoseHeaderNames.JKU)).isNull(); + assertThat(encodedJweNestedJws.getHeaders().get(JoseHeaderNames.JWK)).isNull(); + assertThat(encodedJweNestedJws.getHeaders().get(JoseHeaderNames.KID)).isEqualTo(rsaJwk.getKeyID()); + assertThat(encodedJweNestedJws.getHeaders().get(JoseHeaderNames.X5U)).isNull(); + assertThat(encodedJweNestedJws.getHeaders().get(JoseHeaderNames.X5C)).isNull(); + assertThat(encodedJweNestedJws.getHeaders().get(JoseHeaderNames.X5T)).isNull(); + assertThat(encodedJweNestedJws.getHeaders().get(JoseHeaderNames.X5T_S256)).isNull(); + assertThat(encodedJweNestedJws.getHeaders().get(JoseHeaderNames.TYP)).isNull(); + assertThat(encodedJweNestedJws.getHeaders().get(JoseHeaderNames.CTY)).isEqualTo("JWT"); + assertThat(encodedJweNestedJws.getHeaders().get(JoseHeaderNames.CRIT)).isNull(); + + assertThat(encodedJweNestedJws.getIssuer()).isEqualTo(jwtClaimsSet.getIssuer()); + assertThat(encodedJweNestedJws.getSubject()).isEqualTo(jwtClaimsSet.getSubject()); + assertThat(encodedJweNestedJws.getAudience()).isEqualTo(jwtClaimsSet.getAudience()); + assertThat(encodedJweNestedJws.getExpiresAt()).isEqualTo(jwtClaimsSet.getExpiresAt()); + assertThat(encodedJweNestedJws.getNotBefore()).isEqualTo(jwtClaimsSet.getNotBefore()); + assertThat(encodedJweNestedJws.getIssuedAt()).isEqualTo(jwtClaimsSet.getIssuedAt()); + assertThat(encodedJweNestedJws.getId()).isEqualTo(jwtClaimsSet.getId()); + assertThat(encodedJweNestedJws.getClaim("custom-claim-name")).isEqualTo("custom-claim-value"); assertThat(encodedJweNestedJws.getTokenValue()).isNotNull(); } @@ -181,64 +187,106 @@ public String getName() { } - private static final class NimbusJweEncoder implements JwtEncoder, JoseEncoder { + private static final class JweHeader extends JoseHeader { + + private JweHeader(Map headers) { + super(headers); + } + + @SuppressWarnings("unchecked") + @Override + public JweAlgorithm getAlgorithm() { + return super.getAlgorithm(); + } + + private static Builder with(JweAlgorithm jweAlgorithm, String enc) { + return new Builder(jweAlgorithm, enc); + } + + private static Builder from(JweHeader headers) { + return new Builder(headers); + } + + private static final class Builder extends AbstractBuilder { + + private Builder(JweAlgorithm jweAlgorithm, String enc) { + Assert.notNull(jweAlgorithm, "jweAlgorithm cannot be null"); + Assert.hasText(enc, "enc cannot be empty"); + algorithm(jweAlgorithm); + header("enc", enc); + } + + private Builder(JweHeader headers) { + Assert.notNull(headers, "headers cannot be null"); + Consumer> headersConsumer = (h) -> h.putAll(headers.getHeaders()); + headers(headersConsumer); + } + + @Override + public JweHeader build() { + validate(); + return new JweHeader(getHeaders()); + } + + } + + } + + private static final class NimbusJweEncoder implements JwtEncoder { private static final String ENCODING_ERROR_MESSAGE_TEMPLATE = "An error occurred while attempting to encode the Jwt: %s"; - private static final Converter JWE_HEADER_CONVERTER = new JweHeaderConverter(); + private static final Converter JWE_HEADER_CONVERTER = new JweHeaderConverter(); private static final Converter JWT_CLAIMS_SET_CONVERTER = new JwtClaimsSetConverter(); private final JWKSource jwkSource; - private NimbusJweEncoder(JWKSource jwkSource) { + private final JwtEncoder jwsEncoder; + + private NimbusJweEncoder(JWKSource jwkSource, JwtEncoder jwsEncoder) { Assert.notNull(jwkSource, "jwkSource cannot be null"); + Assert.notNull(jwsEncoder, "jwsEncoder cannot be null"); this.jwkSource = jwkSource; + this.jwsEncoder = jwsEncoder; } @Override public Jwt encode(JwtEncoderParameters parameters) throws JwtEncodingException { Assert.notNull(parameters, "parameters cannot be null"); - JwsHeader headers = parameters.getJwsHeader(); + // ********************** + // Assume future API: + // JwtEncoderParameters.getJweHeader() + // ********************** + JweHeader jweHeader = DEFAULT_JWE_HEADER; // Assume this is accessed via + // JwtEncoderParameters.getJweHeader() + + JwsHeader jwsHeader = parameters.getJwsHeader(); JwtClaimsSet claims = parameters.getClaims(); - JWTClaimsSet jwtClaimsSet = JWT_CLAIMS_SET_CONVERTER.convert(claims); + JWK jwk = selectJwk(jweHeader); + jweHeader = addKeyIdentifierHeadersIfNecessary(jweHeader, jwk); - JoseToken joseToken = encode(headers, new JosePayload<>(jwtClaimsSet.toString())); + JWEHeader jweHeader2 = JWE_HEADER_CONVERTER.convert(jweHeader); + JWTClaimsSet jwtClaimsSet = JWT_CLAIMS_SET_CONVERTER.convert(claims); - return new Jwt(joseToken.getTokenValue(), claims.getIssuedAt(), claims.getExpiresAt(), - joseToken.getHeaders().getHeaders(), claims.getClaims()); - } + String payload; + if (jwsHeader != null) { + Jwt jws = this.jwsEncoder.encode(JwtEncoderParameters.with(jwsHeader, claims)); + payload = jws.getTokenValue(); - @Override - public JoseToken encode(JoseHeader headers, JosePayload payload) throws JwtEncodingException { - Assert.notNull(headers, "headers cannot be null"); - Assert.notNull(payload, "payload cannot be null"); - - JWEHeader jweHeader; - try { - jweHeader = JWE_HEADER_CONVERTER.convert(headers); + // @formatter:off + jweHeader = JweHeader.from(jweHeader) + .contentType("JWT") // Indicates Nested JWT (REQUIRED) + .build(); + // @formatter:on } - catch (Exception ex) { - throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, ex.getMessage()), ex); - } - - JWK jwk = selectJwk(jweHeader); - if (jwk == null) { - throw new JwtEncodingException( - String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to select a JWK encryption key")); + else { + payload = jwtClaimsSet.toString(); } - jweHeader = addKeyIdentifierHeadersIfNecessary(jweHeader, jwk); - headers = syncKeyIdentifierHeadersIfNecessary(headers, jweHeader); - - // FIXME - // Resolve type of JosePayload.content - // For now, assuming String type - String payloadContent = (String) payload.getContent(); - - JWEObject jweObject = new JWEObject(jweHeader, new Payload(payloadContent)); + JWEObject jweObject = new JWEObject(jweHeader2, new Payload(payload)); try { // FIXME // Resolve type of JWEEncrypter using the JWK key type @@ -251,78 +299,75 @@ public JoseToken encode(JoseHeader headers, JosePayload payload) throws JwtEn } String jwe = jweObject.serialize(); - return new JoseToken(jwe, null, null, headers, payload); + return new Jwt(jwe, claims.getIssuedAt(), claims.getExpiresAt(), jweHeader.getHeaders(), + claims.getClaims()); } - private JWK selectJwk(JWEHeader jweHeader) { - JWKSelector jwkSelector = new JWKSelector(JWKMatcher.forJWEHeader(jweHeader)); - + private JWK selectJwk(JweHeader headers) { List jwks; try { + JWKSelector jwkSelector = new JWKSelector(createJwkMatcher(headers)); jwks = this.jwkSource.get(jwkSelector, null); } - catch (KeySourceException ex) { + catch (Exception ex) { throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to select a JWK encryption key -> " + ex.getMessage()), ex); } if (jwks.size() > 1) { throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, - "Found multiple JWK encryption keys for algorithm '" + jweHeader.getAlgorithm().getName() - + "'")); + "Found multiple JWK encryption keys for algorithm '" + headers.getAlgorithm().getName() + "'")); } - return !jwks.isEmpty() ? jwks.get(0) : null; - } - - private static JWEHeader addKeyIdentifierHeadersIfNecessary(JWEHeader jweHeader, JWK jwk) { - // Check if headers have already been added - if (StringUtils.hasText(jweHeader.getKeyID()) && jweHeader.getX509CertSHA256Thumbprint() != null) { - return jweHeader; - } - // Check if headers can be added from JWK - if (!StringUtils.hasText(jwk.getKeyID()) && jwk.getX509CertSHA256Thumbprint() == null) { - return jweHeader; + if (jwks.isEmpty()) { + throw new JwtEncodingException( + String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to select a JWK encryption key")); } - JWEHeader.Builder headerBuilder = new JWEHeader.Builder(jweHeader); - if (!StringUtils.hasText(jweHeader.getKeyID()) && StringUtils.hasText(jwk.getKeyID())) { - headerBuilder.keyID(jwk.getKeyID()); - } - if (jweHeader.getX509CertSHA256Thumbprint() == null && jwk.getX509CertSHA256Thumbprint() != null) { - headerBuilder.x509CertSHA256Thumbprint(jwk.getX509CertSHA256Thumbprint()); - } + return jwks.get(0); + } - return headerBuilder.build(); + private static JWKMatcher createJwkMatcher(JweHeader headers) { + JWEAlgorithm jweAlgorithm = JWEAlgorithm.parse(headers.getAlgorithm().getName()); + + // @formatter:off + return new JWKMatcher.Builder() + .keyType(KeyType.forAlgorithm(jweAlgorithm)) + .keyID(headers.getKeyId()) + .keyUses(KeyUse.ENCRYPTION, null) + .algorithms(jweAlgorithm, null) + .x509CertSHA256Thumbprint(Base64URL.from(headers.getX509SHA256Thumbprint())) + .build(); + // @formatter:on } - private static JoseHeader syncKeyIdentifierHeadersIfNecessary(JoseHeader joseHeader, JWEHeader jweHeader) { - String jweHeaderX509SHA256Thumbprint = null; - if (jweHeader.getX509CertSHA256Thumbprint() != null) { - jweHeaderX509SHA256Thumbprint = jweHeader.getX509CertSHA256Thumbprint().toString(); + private static JweHeader addKeyIdentifierHeadersIfNecessary(JweHeader headers, JWK jwk) { + // Check if headers have already been added + if (StringUtils.hasText(headers.getKeyId()) && StringUtils.hasText(headers.getX509SHA256Thumbprint())) { + return headers; } - if (Objects.equals(joseHeader.getKeyId(), jweHeader.getKeyID()) - && Objects.equals(joseHeader.getX509SHA256Thumbprint(), jweHeaderX509SHA256Thumbprint)) { - return joseHeader; + // Check if headers can be added from JWK + if (!StringUtils.hasText(jwk.getKeyID()) && jwk.getX509CertSHA256Thumbprint() == null) { + return headers; } - JoseHeader.Builder headerBuilder = JoseHeader.from(joseHeader); - if (!Objects.equals(joseHeader.getKeyId(), jweHeader.getKeyID())) { - headerBuilder.keyId(jweHeader.getKeyID()); + JweHeader.Builder headersBuilder = JweHeader.from(headers); + if (!StringUtils.hasText(headers.getKeyId()) && StringUtils.hasText(jwk.getKeyID())) { + headersBuilder.keyId(jwk.getKeyID()); } - if (!Objects.equals(joseHeader.getX509SHA256Thumbprint(), jweHeaderX509SHA256Thumbprint)) { - headerBuilder.x509SHA256Thumbprint(jweHeaderX509SHA256Thumbprint); + if (!StringUtils.hasText(headers.getX509SHA256Thumbprint()) && jwk.getX509CertSHA256Thumbprint() != null) { + headersBuilder.x509SHA256Thumbprint(jwk.getX509CertSHA256Thumbprint().toString()); } - return headerBuilder.build(); + return headersBuilder.build(); } } - private static class JweHeaderConverter implements Converter { + private static class JweHeaderConverter implements Converter { @Override - public JWEHeader convert(JoseHeader headers) { + public JWEHeader convert(JweHeader headers) { JWEAlgorithm jweAlgorithm = JWEAlgorithm.parse(headers.getAlgorithm().getName()); EncryptionMethod encryptionMethod = EncryptionMethod.parse(headers.getHeader("enc")); JWEHeader.Builder builder = new JWEHeader.Builder(jweAlgorithm, encryptionMethod); @@ -464,85 +509,4 @@ public JWTClaimsSet convert(JwtClaimsSet claims) { } - static class JoseToken extends AbstractOAuth2Token { - - private final JoseHeader headers; - - private final JosePayload payload; - - JoseToken(String tokenValue, Instant issuedAt, Instant expiresAt, JoseHeader headers, JosePayload payload) { - super(tokenValue, issuedAt, expiresAt); - this.headers = headers; - this.payload = payload; - } - - JoseHeader getHeaders() { - return this.headers; - } - - JosePayload getPayload() { - return this.payload; - } - - } - - static class JosePayload { - - private final T content; - - JosePayload(T content) { - this.content = content; - } - - T getContent() { - return this.content; - } - - } - - // @formatter:off - /* - * IMPORTANT DESIGN DECISION - * ------------------------- - * - * This API is needed in order to support "Nested JWT". - * - * See section 2. Terminology - * https://tools.ietf.org/html/rfc7519#section-2 - * - * Nested JWT - * A JWT in which nested signing and/or encryption are employed. - * In Nested JWTs, a JWT is used as the payload or plaintext value of an - * enclosing JWS or JWE structure, respectively. - * - * See section 3. JSON Web Token (JWT) Overview - * https://tools.ietf.org/html/rfc7519#section-3 - * - * JWTs represent a set of claims as a JSON object that is encoded in a - * JWS and/or JWE structure. This JSON object is the JWT Claims Set. - * - * The contents of the JOSE Header describe the cryptographic operations - * applied to the JWT Claims Set. If the JOSE Header is for a JWS, the - * JWT is represented as a JWS and the claims are digitally signed or - * MACed, with the JWT Claims Set being the JWS Payload. If the JOSE - * Header is for a JWE, the JWT is represented as a JWE and the claims - * are encrypted, with the JWT Claims Set being the plaintext encrypted - * by the JWE. A JWT may be enclosed in another JWE or JWS structure to - * create a Nested JWT, enabling nested signing and encryption to be - * performed. - * - * ----------------------- - * - * In summary, the `JwtEncoder` API is designed for signing (JWS) and encrypting (JWE) a JWT Claims Set. - * Whereas, the `JoseEncoder` API is a higher level of abstraction that can be used for Nested JWT (signing and encryption). - * NOTE: The `JosePayload` type provides the flexibility to support any data type, - * e.g. JWT/JWS, JwtClaimsSet, String, Map, byte[], etc. - */ - interface JoseEncoder { - - JoseToken encode(JoseHeader headers, JosePayload payload) throws JwtEncodingException; - - } - // @formatter:on - } From 4ccef3845f8f8e3e1bf766c9fa0a07fb2d2955b9 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Mon, 20 Sep 2021 07:18:58 -0400 Subject: [PATCH 12/20] Polish NimbusJweEncoderTests --- .../oauth2/jwt/NimbusJweEncoderTests.java | 31 ++++++++++++------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java index d2ed41363d3..61b3389218a 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java @@ -61,7 +61,7 @@ import static org.assertj.core.api.Assertions.assertThat; /** - * Tests for {@link NimbusJweEncoder} (future support for JWE). + * Tests for proofing out future support of JWE. * * @author Joe Grandja */ @@ -78,14 +78,11 @@ public class NimbusJweEncoderTests { private NimbusJweEncoder jweEncoder; - private NimbusJwtEncoder jwsEncoder; - @BeforeEach public void setUp() { this.jwkList = new ArrayList<>(); this.jwkSource = (jwkSelector, securityContext) -> jwkSelector.select(new JWKSet(this.jwkList)); - this.jwsEncoder = new NimbusJwtEncoder(this.jwkSource); - this.jweEncoder = new NimbusJweEncoder(this.jwkSource, this.jwsEncoder); + this.jweEncoder = new NimbusJweEncoder(this.jwkSource); } @Test @@ -95,10 +92,12 @@ public void encodeWhenJwtClaimsSetThenEncodes() { JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); + // @formatter:off // ********************** // Assume future API: - // JwtEncoderParameters.with(JweHeader jweHeader, JwtClaimsSet claims) + // JwtEncoderParameters.with(JweHeader jweHeader, JwtClaimsSet claims) // ********************** + // @formatter:on Jwt encodedJwe = this.jweEncoder.encode(JwtEncoderParameters.with(jwtClaimsSet)); assertThat(encodedJwe.getHeaders().get(JoseHeaderNames.ALG)).isEqualTo(DEFAULT_JWE_HEADER.getAlgorithm()); @@ -137,11 +136,12 @@ public void encodeWhenNestedJwsThenEncodes() { JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256).build(); JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); + // @formatter:off // ********************** // Assume future API: - // JwtEncoderParameters.with(JweHeader jweHeader, JwsHeader jwsHeader, - // JwtClaimsSet claims) + // JwtEncoderParameters.with(JwsHeader jwsHeader, JweHeader jweHeader, JwtClaimsSet claims) // ********************** + // @formatter:on Jwt encodedJweNestedJws = this.jweEncoder.encode(JwtEncoderParameters.with(jwsHeader, jwtClaimsSet)); assertThat(encodedJweNestedJws.getHeaders().get(JoseHeaderNames.ALG)) @@ -244,21 +244,22 @@ private static final class NimbusJweEncoder implements JwtEncoder { private final JwtEncoder jwsEncoder; - private NimbusJweEncoder(JWKSource jwkSource, JwtEncoder jwsEncoder) { + private NimbusJweEncoder(JWKSource jwkSource) { Assert.notNull(jwkSource, "jwkSource cannot be null"); - Assert.notNull(jwsEncoder, "jwsEncoder cannot be null"); this.jwkSource = jwkSource; - this.jwsEncoder = jwsEncoder; + this.jwsEncoder = new NimbusJwtEncoder(jwkSource); } @Override public Jwt encode(JwtEncoderParameters parameters) throws JwtEncodingException { Assert.notNull(parameters, "parameters cannot be null"); + // @formatter:off // ********************** // Assume future API: - // JwtEncoderParameters.getJweHeader() + // JwtEncoderParameters.getJweHeader() // ********************** + // @formatter:on JweHeader jweHeader = DEFAULT_JWE_HEADER; // Assume this is accessed via // JwtEncoderParameters.getJweHeader() @@ -273,6 +274,7 @@ public Jwt encode(JwtEncoderParameters parameters) throws JwtEncodingException { String payload; if (jwsHeader != null) { + // Sign then encrypt Jwt jws = this.jwsEncoder.encode(JwtEncoderParameters.with(jwsHeader, claims)); payload = jws.getTokenValue(); @@ -283,6 +285,7 @@ public Jwt encode(JwtEncoderParameters parameters) throws JwtEncodingException { // @formatter:on } else { + // Encrypt only payload = jwtClaimsSet.toString(); } @@ -299,6 +302,10 @@ public Jwt encode(JwtEncoderParameters parameters) throws JwtEncodingException { } String jwe = jweObject.serialize(); + // NOTE: + // For the Nested JWS use case, we lose access to the JWS Header in the + // returned JWT. + // If this is needed, we can simply add the new method Jwt.getNestedHeaders(). return new Jwt(jwe, claims.getIssuedAt(), claims.getExpiresAt(), jweHeader.getHeaders(), claims.getClaims()); } From 553f890a4b1e41023c7b528f3e1a0423522af67d Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Mon, 20 Sep 2021 07:28:42 -0400 Subject: [PATCH 13/20] Add tests --- .../security/oauth2/jwt/JwsHeaderTests.java | 83 +++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwsHeaderTests.java diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwsHeaderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwsHeaderTests.java new file mode 100644 index 00000000000..18265241562 --- /dev/null +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwsHeaderTests.java @@ -0,0 +1,83 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.jwt; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; + +/** + * Tests for {@link JwsHeader}. + * + * @author Joe Grandja + */ +public class JwsHeaderTests { + + @Test + public void withWhenNullThenThrowIllegalArgumentException() { + assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> JwsHeader.with(null)) + .withMessage("jwsAlgorithm cannot be null"); + } + + @Test + public void fromWhenNullThenThrowIllegalArgumentException() { + assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> JwsHeader.from(null)) + .withMessage("headers cannot be null"); + } + + @Test + public void fromWhenHeadersProvidedThenCopied() { + JwsHeader expectedJwsHeader = TestJoseHeaders.jwsHeader().build(); + JwsHeader jwsHeader = JwsHeader.from(expectedJwsHeader).build(); + assertThat(jwsHeader.getHeaders()).isEqualTo(expectedJwsHeader.getHeaders()); + } + + @Test + public void buildWhenAllHeadersProvidedThenAllHeadersAreSet() { + JwsHeader expectedJwsHeader = TestJoseHeaders.jwsHeader().build(); + + // @formatter:off + JwsHeader jwsHeader = JwsHeader.with(expectedJwsHeader.getAlgorithm()) + .jwkSetUrl(expectedJwsHeader.getJwkSetUrl().toExternalForm()) + .jwk(expectedJwsHeader.getJwk()) + .keyId(expectedJwsHeader.getKeyId()) + .x509Url(expectedJwsHeader.getX509Url().toExternalForm()) + .x509CertificateChain(expectedJwsHeader.getX509CertificateChain()) + .x509SHA1Thumbprint(expectedJwsHeader.getX509SHA1Thumbprint()) + .x509SHA256Thumbprint(expectedJwsHeader.getX509SHA256Thumbprint()) + .type(expectedJwsHeader.getType()) + .contentType(expectedJwsHeader.getContentType()) + .headers((headers) -> headers.put("custom-header-name", "custom-header-value")) + .build(); + // @formatter:on + + assertThat(jwsHeader.getAlgorithm()).isEqualTo(expectedJwsHeader.getAlgorithm()); + assertThat(jwsHeader.getJwkSetUrl()).isEqualTo(expectedJwsHeader.getJwkSetUrl()); + assertThat(jwsHeader.getJwk()).isEqualTo(expectedJwsHeader.getJwk()); + assertThat(jwsHeader.getKeyId()).isEqualTo(expectedJwsHeader.getKeyId()); + assertThat(jwsHeader.getX509Url()).isEqualTo(expectedJwsHeader.getX509Url()); + assertThat(jwsHeader.getX509CertificateChain()).isEqualTo(expectedJwsHeader.getX509CertificateChain()); + assertThat(jwsHeader.getX509SHA1Thumbprint()).isEqualTo(expectedJwsHeader.getX509SHA1Thumbprint()); + assertThat(jwsHeader.getX509SHA256Thumbprint()).isEqualTo(expectedJwsHeader.getX509SHA256Thumbprint()); + assertThat(jwsHeader.getType()).isEqualTo(expectedJwsHeader.getType()); + assertThat(jwsHeader.getContentType()).isEqualTo(expectedJwsHeader.getContentType()); + assertThat(jwsHeader.getHeader("custom-header-name")).isEqualTo("custom-header-value"); + assertThat(jwsHeader.getHeaders()).isEqualTo(expectedJwsHeader.getHeaders()); + } + +} From f6f96d27240b25cf488f6544726b8729df9cb055 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Mon, 20 Sep 2021 09:32:58 -0400 Subject: [PATCH 14/20] Make JoseHeader package-private --- etc/checkstyle/checkstyle-suppressions.xml | 1 + .../security/oauth2/jwt/JoseHeader.java | 52 +------- .../security/oauth2/jwt/JoseHeaderNames.java | 1 - .../security/oauth2/jwt/JwsHeader.java | 1 - .../security/oauth2/jwt/JoseHeaderTests.java | 117 ------------------ .../security/oauth2/jwt/JwsHeaderTests.java | 33 +++++ 6 files changed, 37 insertions(+), 168 deletions(-) delete mode 100644 oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JoseHeaderTests.java diff --git a/etc/checkstyle/checkstyle-suppressions.xml b/etc/checkstyle/checkstyle-suppressions.xml index cd90b3cd049..956a3c074e9 100644 --- a/etc/checkstyle/checkstyle-suppressions.xml +++ b/etc/checkstyle/checkstyle-suppressions.xml @@ -51,4 +51,5 @@ + diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JoseHeader.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JoseHeader.java index 1f20fbdb496..ec53883e473 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JoseHeader.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JoseHeader.java @@ -44,7 +44,7 @@ * @see JWE JOSE * Header */ -public class JoseHeader { +class JoseHeader { private final Map headers; @@ -180,55 +180,9 @@ public T getHeader(String name) { } /** - * Returns a new {@link Builder}, initialized with the provided {@link JwaAlgorithm}. - * @param jwaAlgorithm the {@link JwaAlgorithm} - * @return the {@link Builder} + * A builder for subclasses of {@link JoseHeader}. */ - public static Builder with(JwaAlgorithm jwaAlgorithm) { - return new Builder(jwaAlgorithm); - } - - /** - * Returns a new {@link Builder}, initialized with the provided {@code headers}. - * @param headers the headers - * @return the {@link Builder} - */ - public static Builder from(JoseHeader headers) { - return new Builder(headers); - } - - /** - * A builder for {@link JoseHeader}. - */ - public static final class Builder extends AbstractBuilder { - - private Builder(JwaAlgorithm jwaAlgorithm) { - Assert.notNull(jwaAlgorithm, "jwaAlgorithm cannot be null"); - algorithm(jwaAlgorithm); - } - - private Builder(JoseHeader headers) { - Assert.notNull(headers, "headers cannot be null"); - Consumer> headersConsumer = (h) -> h.putAll(headers.getHeaders()); - headers(headersConsumer); - } - - /** - * Builds a new {@link JoseHeader}. - * @return a {@link JoseHeader} - */ - @Override - public JoseHeader build() { - validate(); - return new JoseHeader(getHeaders()); - } - - } - - /** - * A builder for {@link JoseHeader} and subclasses. - */ - protected abstract static class AbstractBuilder> { + abstract static class AbstractBuilder> { private final Map headers = new HashMap<>(); diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JoseHeaderNames.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JoseHeaderNames.java index 9e5f04f8673..a53318584fc 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JoseHeaderNames.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JoseHeaderNames.java @@ -24,7 +24,6 @@ * @author Anoop Garlapati * @author Joe Grandja * @since 5.6 - * @see JoseHeader * @see JWT JOSE * Header * @see JWS JOSE diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwsHeader.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwsHeader.java index 0354253233e..e6d30be6568 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwsHeader.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwsHeader.java @@ -29,7 +29,6 @@ * * @author Joe Grandja * @since 5.6 - * @see JoseHeader * @see JWS JOSE * Header */ diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JoseHeaderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JoseHeaderTests.java deleted file mode 100644 index 570e513b9e9..00000000000 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JoseHeaderTests.java +++ /dev/null @@ -1,117 +0,0 @@ -/* - * Copyright 2002-2021 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.oauth2.jwt; - -import java.util.Collections; - -import org.junit.jupiter.api.Test; - -import org.springframework.security.oauth2.jose.JwaAlgorithm; -import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; - -/** - * Tests for {@link JoseHeader}. - * - * @author Joe Grandja - */ -public class JoseHeaderTests { - - @Test - public void withWhenNullThenThrowIllegalArgumentException() { - assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> JoseHeader.with(null)) - .withMessage("jwaAlgorithm cannot be null"); - } - - @Test - public void buildWhenAllHeadersProvidedThenAllHeadersAreSet() { - JoseHeader expectedJoseHeader = TestJoseHeaders.jwsHeader().build(); - - // @formatter:off - JoseHeader joseHeader = JoseHeader.with(expectedJoseHeader.getAlgorithm()) - .jwkSetUrl(expectedJoseHeader.getJwkSetUrl().toExternalForm()) - .jwk(expectedJoseHeader.getJwk()) - .keyId(expectedJoseHeader.getKeyId()) - .x509Url(expectedJoseHeader.getX509Url().toExternalForm()) - .x509CertificateChain(expectedJoseHeader.getX509CertificateChain()) - .x509SHA1Thumbprint(expectedJoseHeader.getX509SHA1Thumbprint()) - .x509SHA256Thumbprint(expectedJoseHeader.getX509SHA256Thumbprint()) - .type(expectedJoseHeader.getType()) - .contentType(expectedJoseHeader.getContentType()) - .headers((headers) -> headers.put("custom-header-name", "custom-header-value")) - .build(); - // @formatter:on - - assertThat(joseHeader.getAlgorithm()).isEqualTo(expectedJoseHeader.getAlgorithm()); - assertThat(joseHeader.getJwkSetUrl()).isEqualTo(expectedJoseHeader.getJwkSetUrl()); - assertThat(joseHeader.getJwk()).isEqualTo(expectedJoseHeader.getJwk()); - assertThat(joseHeader.getKeyId()).isEqualTo(expectedJoseHeader.getKeyId()); - assertThat(joseHeader.getX509Url()).isEqualTo(expectedJoseHeader.getX509Url()); - assertThat(joseHeader.getX509CertificateChain()).isEqualTo(expectedJoseHeader.getX509CertificateChain()); - assertThat(joseHeader.getX509SHA1Thumbprint()).isEqualTo(expectedJoseHeader.getX509SHA1Thumbprint()); - assertThat(joseHeader.getX509SHA256Thumbprint()).isEqualTo(expectedJoseHeader.getX509SHA256Thumbprint()); - assertThat(joseHeader.getType()).isEqualTo(expectedJoseHeader.getType()); - assertThat(joseHeader.getContentType()).isEqualTo(expectedJoseHeader.getContentType()); - assertThat(joseHeader.getHeader("custom-header-name")).isEqualTo("custom-header-value"); - assertThat(joseHeader.getHeaders()).isEqualTo(expectedJoseHeader.getHeaders()); - } - - @Test - public void buildWhenMissingCriticalHeaderThenThrowIllegalStateException() { - assertThatExceptionOfType(IllegalStateException.class).isThrownBy( - () -> TestJoseHeaders.jwsHeader().critical(Collections.singleton("critical-header-name")).build()) - .withMessage("Missing critical (crit) header 'critical-header-name'."); - } - - @Test - public void fromWhenNullThenThrowIllegalArgumentException() { - assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> JoseHeader.from(null)) - .withMessage("headers cannot be null"); - } - - @Test - public void fromWhenHeadersProvidedThenCopied() { - JoseHeader expectedJoseHeader = TestJoseHeaders.jwsHeader().build(); - JoseHeader joseHeader = JoseHeader.from(expectedJoseHeader).build(); - assertThat(joseHeader.getHeaders()).isEqualTo(expectedJoseHeader.getHeaders()); - } - - @Test - public void headerWhenNameNullThenThrowIllegalArgumentException() { - assertThatExceptionOfType(IllegalArgumentException.class) - .isThrownBy(() -> JoseHeader.with(SignatureAlgorithm.RS256).header(null, "value")) - .withMessage("name cannot be empty"); - } - - @Test - public void headerWhenValueNullThenThrowIllegalArgumentException() { - assertThatExceptionOfType(IllegalArgumentException.class) - .isThrownBy(() -> JoseHeader.with(SignatureAlgorithm.RS256).header("name", null)) - .withMessage("value cannot be null"); - } - - @Test - public void getHeaderWhenNullThenThrowIllegalArgumentException() { - JoseHeader joseHeader = TestJoseHeaders.jwsHeader().build(); - - assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> joseHeader.getHeader(null)) - .withMessage("name cannot be empty"); - } - -} diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwsHeaderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwsHeaderTests.java index 18265241562..d05e5e586ec 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwsHeaderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwsHeaderTests.java @@ -16,8 +16,12 @@ package org.springframework.security.oauth2.jwt; +import java.util.Collections; + import org.junit.jupiter.api.Test; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; + import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -80,4 +84,33 @@ public void buildWhenAllHeadersProvidedThenAllHeadersAreSet() { assertThat(jwsHeader.getHeaders()).isEqualTo(expectedJwsHeader.getHeaders()); } + @Test + public void buildWhenMissingCriticalHeaderThenThrowIllegalStateException() { + assertThatExceptionOfType(IllegalStateException.class).isThrownBy( + () -> TestJoseHeaders.jwsHeader().critical(Collections.singleton("critical-header-name")).build()) + .withMessage("Missing critical (crit) header 'critical-header-name'."); + } + + @Test + public void headerWhenNameNullThenThrowIllegalArgumentException() { + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> JwsHeader.with(SignatureAlgorithm.RS256).header(null, "value")) + .withMessage("name cannot be empty"); + } + + @Test + public void headerWhenValueNullThenThrowIllegalArgumentException() { + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> JwsHeader.with(SignatureAlgorithm.RS256).header("name", null)) + .withMessage("value cannot be null"); + } + + @Test + public void getHeaderWhenNullThenThrowIllegalArgumentException() { + JwsHeader jwsHeader = TestJoseHeaders.jwsHeader().build(); + + assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> jwsHeader.getHeader(null)) + .withMessage("name cannot be empty"); + } + } From ee142b08536aa555920f4c2d76aca3d1da12783b Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Mon, 20 Sep 2021 10:19:10 -0400 Subject: [PATCH 15/20] Rename TestJoseHeaders to TestJwsHeaders --- .../security/oauth2/jwt/JwsHeaderTests.java | 8 ++++---- .../jwt/{TestJoseHeaders.java => TestJwsHeaders.java} | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) rename oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/{TestJoseHeaders.java => TestJwsHeaders.java} (96%) diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwsHeaderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwsHeaderTests.java index d05e5e586ec..518c3179a24 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwsHeaderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwsHeaderTests.java @@ -46,14 +46,14 @@ public void fromWhenNullThenThrowIllegalArgumentException() { @Test public void fromWhenHeadersProvidedThenCopied() { - JwsHeader expectedJwsHeader = TestJoseHeaders.jwsHeader().build(); + JwsHeader expectedJwsHeader = TestJwsHeaders.jwsHeader().build(); JwsHeader jwsHeader = JwsHeader.from(expectedJwsHeader).build(); assertThat(jwsHeader.getHeaders()).isEqualTo(expectedJwsHeader.getHeaders()); } @Test public void buildWhenAllHeadersProvidedThenAllHeadersAreSet() { - JwsHeader expectedJwsHeader = TestJoseHeaders.jwsHeader().build(); + JwsHeader expectedJwsHeader = TestJwsHeaders.jwsHeader().build(); // @formatter:off JwsHeader jwsHeader = JwsHeader.with(expectedJwsHeader.getAlgorithm()) @@ -87,7 +87,7 @@ public void buildWhenAllHeadersProvidedThenAllHeadersAreSet() { @Test public void buildWhenMissingCriticalHeaderThenThrowIllegalStateException() { assertThatExceptionOfType(IllegalStateException.class).isThrownBy( - () -> TestJoseHeaders.jwsHeader().critical(Collections.singleton("critical-header-name")).build()) + () -> TestJwsHeaders.jwsHeader().critical(Collections.singleton("critical-header-name")).build()) .withMessage("Missing critical (crit) header 'critical-header-name'."); } @@ -107,7 +107,7 @@ public void headerWhenValueNullThenThrowIllegalArgumentException() { @Test public void getHeaderWhenNullThenThrowIllegalArgumentException() { - JwsHeader jwsHeader = TestJoseHeaders.jwsHeader().build(); + JwsHeader jwsHeader = TestJwsHeaders.jwsHeader().build(); assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> jwsHeader.getHeader(null)) .withMessage("name cannot be empty"); diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/TestJoseHeaders.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/TestJwsHeaders.java similarity index 96% rename from oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/TestJoseHeaders.java rename to oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/TestJwsHeaders.java index 7ca19be877d..6cbd78edcd3 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/TestJoseHeaders.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/TestJwsHeaders.java @@ -25,9 +25,9 @@ /** * @author Joe Grandja */ -public final class TestJoseHeaders { +public final class TestJwsHeaders { - private TestJoseHeaders() { + private TestJwsHeaders() { } public static JwsHeader.Builder jwsHeader() { From 74f15f55a103b931d8180cf49bcc2a9dfbbc0aa9 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Mon, 20 Sep 2021 10:41:49 -0400 Subject: [PATCH 16/20] JwsHeader parameter should be required --- .../security/oauth2/jwt/JwtEncoderParameters.java | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncoderParameters.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncoderParameters.java index f6504c6718d..ec8a8305104 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncoderParameters.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncoderParameters.java @@ -35,7 +35,6 @@ public final class JwtEncoderParameters { private final JwtClaimsSet claims; private JwtEncoderParameters(JwsHeader jwsHeader, JwtClaimsSet claims) { - Assert.notNull(claims, "claims cannot be null"); this.jwsHeader = jwsHeader; this.claims = claims; } @@ -47,7 +46,8 @@ private JwtEncoderParameters(JwsHeader jwsHeader, JwtClaimsSet claims) { * @return the {@link JwtEncoderParameters} */ public static JwtEncoderParameters with(JwtClaimsSet claims) { - return with(null, claims); + Assert.notNull(claims, "claims cannot be null"); + return new JwtEncoderParameters(null, claims); } /** @@ -57,13 +57,15 @@ public static JwtEncoderParameters with(JwtClaimsSet claims) { * @param claims the {@link JwtClaimsSet} * @return the {@link JwtEncoderParameters} */ - public static JwtEncoderParameters with(@Nullable JwsHeader jwsHeader, JwtClaimsSet claims) { + public static JwtEncoderParameters with(JwsHeader jwsHeader, JwtClaimsSet claims) { + Assert.notNull(jwsHeader, "jwsHeader cannot be null"); + Assert.notNull(claims, "claims cannot be null"); return new JwtEncoderParameters(jwsHeader, claims); } /** * Returns the {@link JwsHeader JWS headers}. - * @return the {@link JwsHeader}, or {@code null} if not available + * @return the {@link JwsHeader}, or {@code null} if not specified */ @Nullable public JwsHeader getJwsHeader() { From a5b54fd1ae01ef003ad8196ce8e9d6348929d2b0 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Mon, 20 Sep 2021 11:32:02 -0400 Subject: [PATCH 17/20] Rename JwtEncoderParameters.with() to from() --- ...ientAuthenticationParametersConverter.java | 2 +- .../oauth2/jwt/JwtEncoderParameters.java | 4 ++-- .../oauth2/jwt/NimbusJweEncoderTests.java | 6 ++--- .../oauth2/jwt/NimbusJwtEncoderTests.java | 22 +++++++++---------- 4 files changed, 17 insertions(+), 17 deletions(-) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java index ef333201c18..03c8611af30 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java @@ -155,7 +155,7 @@ public MultiValueMap convert(T authorizationGrantRequest) { }); JwtEncoder jwsEncoder = jwsEncoderHolder.getJwsEncoder(); - Jwt jws = jwsEncoder.encode(JwtEncoderParameters.with(jwsHeader, jwtClaimsSet)); + Jwt jws = jwsEncoder.encode(JwtEncoderParameters.from(jwsHeader, jwtClaimsSet)); MultiValueMap parameters = new LinkedMultiValueMap<>(); parameters.set(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE, CLIENT_ASSERTION_TYPE_VALUE); diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncoderParameters.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncoderParameters.java index ec8a8305104..03bfc6ac96e 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncoderParameters.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtEncoderParameters.java @@ -45,7 +45,7 @@ private JwtEncoderParameters(JwsHeader jwsHeader, JwtClaimsSet claims) { * @param claims the {@link JwtClaimsSet} * @return the {@link JwtEncoderParameters} */ - public static JwtEncoderParameters with(JwtClaimsSet claims) { + public static JwtEncoderParameters from(JwtClaimsSet claims) { Assert.notNull(claims, "claims cannot be null"); return new JwtEncoderParameters(null, claims); } @@ -57,7 +57,7 @@ public static JwtEncoderParameters with(JwtClaimsSet claims) { * @param claims the {@link JwtClaimsSet} * @return the {@link JwtEncoderParameters} */ - public static JwtEncoderParameters with(JwsHeader jwsHeader, JwtClaimsSet claims) { + public static JwtEncoderParameters from(JwsHeader jwsHeader, JwtClaimsSet claims) { Assert.notNull(jwsHeader, "jwsHeader cannot be null"); Assert.notNull(claims, "claims cannot be null"); return new JwtEncoderParameters(jwsHeader, claims); diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java index 61b3389218a..55d652f036e 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java @@ -98,7 +98,7 @@ public void encodeWhenJwtClaimsSetThenEncodes() { // JwtEncoderParameters.with(JweHeader jweHeader, JwtClaimsSet claims) // ********************** // @formatter:on - Jwt encodedJwe = this.jweEncoder.encode(JwtEncoderParameters.with(jwtClaimsSet)); + Jwt encodedJwe = this.jweEncoder.encode(JwtEncoderParameters.from(jwtClaimsSet)); assertThat(encodedJwe.getHeaders().get(JoseHeaderNames.ALG)).isEqualTo(DEFAULT_JWE_HEADER.getAlgorithm()); assertThat(encodedJwe.getHeaders().get("enc")).isEqualTo(DEFAULT_JWE_HEADER.getHeader("enc")); @@ -142,7 +142,7 @@ public void encodeWhenNestedJwsThenEncodes() { // JwtEncoderParameters.with(JwsHeader jwsHeader, JweHeader jweHeader, JwtClaimsSet claims) // ********************** // @formatter:on - Jwt encodedJweNestedJws = this.jweEncoder.encode(JwtEncoderParameters.with(jwsHeader, jwtClaimsSet)); + Jwt encodedJweNestedJws = this.jweEncoder.encode(JwtEncoderParameters.from(jwsHeader, jwtClaimsSet)); assertThat(encodedJweNestedJws.getHeaders().get(JoseHeaderNames.ALG)) .isEqualTo(DEFAULT_JWE_HEADER.getAlgorithm()); @@ -275,7 +275,7 @@ public Jwt encode(JwtEncoderParameters parameters) throws JwtEncodingException { String payload; if (jwsHeader != null) { // Sign then encrypt - Jwt jws = this.jwsEncoder.encode(JwtEncoderParameters.with(jwsHeader, claims)); + Jwt jws = this.jwsEncoder.encode(JwtEncoderParameters.from(jwsHeader, claims)); payload = jws.getTokenValue(); // @formatter:off diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java index 079eac5b9c1..60c99b13ce9 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java @@ -89,7 +89,7 @@ public void encodeWhenClaimsNullThenThrowIllegalArgumentException() { JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256).build(); assertThatIllegalArgumentException() - .isThrownBy(() -> this.jwtEncoder.encode(JwtEncoderParameters.with(jwsHeader, null))) + .isThrownBy(() -> this.jwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, null))) .withMessage("claims cannot be null"); } @@ -103,7 +103,7 @@ public void encodeWhenJwkSelectFailedThenThrowJwtEncodingException() throws Exce JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); assertThatExceptionOfType(JwtEncodingException.class) - .isThrownBy(() -> this.jwtEncoder.encode(JwtEncoderParameters.with(jwsHeader, jwtClaimsSet))) + .isThrownBy(() -> this.jwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, jwtClaimsSet))) .withMessageContaining("Failed to select a JWK signing key -> key source error"); } @@ -117,7 +117,7 @@ public void encodeWhenJwkMultipleSelectedThenThrowJwtEncodingException() throws JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); assertThatExceptionOfType(JwtEncodingException.class) - .isThrownBy(() -> this.jwtEncoder.encode(JwtEncoderParameters.with(jwsHeader, jwtClaimsSet))) + .isThrownBy(() -> this.jwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, jwtClaimsSet))) .withMessageContaining("Found multiple JWK signing keys for algorithm 'RS256'"); } @@ -127,7 +127,7 @@ public void encodeWhenJwkSelectEmptyThenThrowJwtEncodingException() { JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); assertThatExceptionOfType(JwtEncodingException.class) - .isThrownBy(() -> this.jwtEncoder.encode(JwtEncoderParameters.with(jwsHeader, jwtClaimsSet))) + .isThrownBy(() -> this.jwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, jwtClaimsSet))) .withMessageContaining("Failed to select a JWK signing key"); } @@ -142,7 +142,7 @@ public void encodeWhenHeadersNotProvidedThenDefaulted() { JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); - Jwt encodedJws = this.jwtEncoder.encode(JwtEncoderParameters.with(jwtClaimsSet)); + Jwt encodedJws = this.jwtEncoder.encode(JwtEncoderParameters.from(jwtClaimsSet)); assertThat(encodedJws.getHeaders().get(JoseHeaderNames.ALG)).isEqualTo(SignatureAlgorithm.RS256); } @@ -163,7 +163,7 @@ public void encodeWhenJwkSelectWithProvidedKidThenSelected() { JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256).keyId(rsaJwk2.getKeyID()).build(); JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); - Jwt encodedJws = this.jwtEncoder.encode(JwtEncoderParameters.with(jwsHeader, jwtClaimsSet)); + Jwt encodedJws = this.jwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, jwtClaimsSet)); assertThat(encodedJws.getHeaders().get(JoseHeaderNames.KID)).isEqualTo(rsaJwk2.getKeyID()); } @@ -187,7 +187,7 @@ public void encodeWhenJwkSelectWithProvidedX5TS256ThenSelected() { .x509SHA256Thumbprint(rsaJwk1.getX509CertSHA256Thumbprint().toString()).build(); JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); - Jwt encodedJws = this.jwtEncoder.encode(JwtEncoderParameters.with(jwsHeader, jwtClaimsSet)); + Jwt encodedJws = this.jwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, jwtClaimsSet)); assertThat(encodedJws.getHeaders().get(JoseHeaderNames.X5T_S256)) .isEqualTo(rsaJwk1.getX509CertSHA256Thumbprint().toString()); @@ -210,7 +210,7 @@ public void encodeWhenJwkUseEncryptionThenThrowJwtEncodingException() throws Exc JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); assertThatExceptionOfType(JwtEncodingException.class) - .isThrownBy(() -> this.jwtEncoder.encode(JwtEncoderParameters.with(jwsHeader, jwtClaimsSet))) + .isThrownBy(() -> this.jwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, jwtClaimsSet))) .withMessageContaining( "Failed to create a JWS Signer -> The JWK use must be sig (signature) or unspecified"); } @@ -228,7 +228,7 @@ public void encodeWhenSuccessThenDecodes() throws Exception { JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256).build(); JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); - Jwt encodedJws = this.jwtEncoder.encode(JwtEncoderParameters.with(jwsHeader, jwtClaimsSet)); + Jwt encodedJws = this.jwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, jwtClaimsSet)); assertThat(encodedJws.getHeaders().get(JoseHeaderNames.ALG)).isEqualTo(jwsHeader.getAlgorithm()); assertThat(encodedJws.getHeaders().get(JoseHeaderNames.JKU)).isNull(); @@ -273,7 +273,7 @@ public List get(JWKSelector jwkSelector, SecurityContext context) { JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.RS256).build(); JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build(); - Jwt encodedJws = jwtEncoder.encode(JwtEncoderParameters.with(jwsHeader, jwtClaimsSet)); + Jwt encodedJws = jwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, jwtClaimsSet)); JWK jwk1 = jwkListResultCaptor.getResult().get(0); NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withPublicKey(((RSAKey) jwk1).toRSAPublicKey()).build(); @@ -281,7 +281,7 @@ public List get(JWKSelector jwkSelector, SecurityContext context) { jwkSource.rotate(); // Simulate key rotation - encodedJws = jwtEncoder.encode(JwtEncoderParameters.with(jwsHeader, jwtClaimsSet)); + encodedJws = jwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, jwtClaimsSet)); JWK jwk2 = jwkListResultCaptor.getResult().get(0); jwtDecoder = NimbusJwtDecoder.withPublicKey(((RSAKey) jwk2).toRSAPublicKey()).build(); From 8dbae128b62fd5898bf59191beaabfdace2b3689 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Mon, 20 Sep 2021 16:15:08 -0400 Subject: [PATCH 18/20] Add criticalHeader() convenience methods --- .../security/oauth2/jwt/JoseHeader.java | 51 +++++++++++++------ .../security/oauth2/jwt/JwsHeader.java | 3 +- .../security/oauth2/jwt/JwsHeaderTests.java | 20 ++++---- .../oauth2/jwt/NimbusJweEncoderTests.java | 3 +- 4 files changed, 48 insertions(+), 29 deletions(-) diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JoseHeader.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JoseHeader.java index ec53883e473..4be71daf6d0 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JoseHeader.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JoseHeader.java @@ -186,6 +186,8 @@ abstract static class AbstractBuilder headers = new HashMap<>(); + private final Map criticalHeaders = new HashMap<>(); + protected AbstractBuilder() { } @@ -193,6 +195,21 @@ protected Map getHeaders() { return this.headers; } + protected Map getCriticalHeaders() { + return this.criticalHeaders; + } + + protected Map getMergedHeaders() { + if (getCriticalHeaders().isEmpty()) { + return getHeaders(); + } + Map mergedHeaders = new HashMap<>(getHeaders()); + Set crit = getCriticalHeaders().keySet(); + mergedHeaders.put(JoseHeaderNames.CRIT, crit); + mergedHeaders.putAll(getCriticalHeaders()); + return mergedHeaders; + } + @SuppressWarnings("unchecked") protected final B getThis() { return (B) this; // avoid unchecked casts in subclasses by using "getThis()" @@ -306,13 +323,28 @@ public B contentType(String contentType) { } /** - * Sets the critical headers that indicates which extensions to the JWS/JWE/JWA + * Sets the critical header that indicates which extensions to the JWS/JWE/JWA * specifications are being used that MUST be understood and processed. - * @param headerNames the critical header names + * @param name the critical header name + * @param value the critical header value * @return the {@link AbstractBuilder} */ - public B critical(Set headerNames) { - return header(JoseHeaderNames.CRIT, headerNames); + public B criticalHeader(String name, Object value) { + Assert.hasText(name, "name cannot be empty"); + Assert.notNull(value, "value cannot be null"); + this.criticalHeaders.put(name, value); + return getThis(); + } + + /** + * A {@code Consumer} to be provided access to the critical headers allowing the + * ability to add, replace, or remove. + * @param headersConsumer a {@code Consumer} of the critical headers + * @return the {@link AbstractBuilder} + */ + public B criticalHeaders(Consumer> headersConsumer) { + headersConsumer.accept(this.criticalHeaders); + return getThis(); } /** @@ -345,17 +377,6 @@ public B headers(Consumer> headersConsumer) { */ public abstract T build(); - @SuppressWarnings("unchecked") - protected void validate() { - Set criticalHeaderNames = (Set) this.headers.get(JoseHeaderNames.CRIT); - if (criticalHeaderNames == null) { - return; - } - criticalHeaderNames - .forEach((criticalHeaderName) -> Assert.state(this.headers.containsKey(criticalHeaderName), - "Missing critical (crit) header '" + criticalHeaderName + "'.")); - } - private static URL convertAsURL(String header, String value) { URL convertedValue = ClaimConversionService.getSharedInstance().convert(value, URL.class); Assert.isTrue(convertedValue != null, diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwsHeader.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwsHeader.java index e6d30be6568..955a6e0bb90 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwsHeader.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwsHeader.java @@ -84,8 +84,7 @@ private Builder(JwsHeader headers) { */ @Override public JwsHeader build() { - validate(); - return new JwsHeader(getHeaders()); + return new JwsHeader(getMergedHeaders()); } } diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwsHeaderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwsHeaderTests.java index 518c3179a24..eae43f54715 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwsHeaderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwsHeaderTests.java @@ -16,8 +16,6 @@ package org.springframework.security.oauth2.jwt; -import java.util.Collections; - import org.junit.jupiter.api.Test; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; @@ -66,6 +64,11 @@ public void buildWhenAllHeadersProvidedThenAllHeadersAreSet() { .x509SHA256Thumbprint(expectedJwsHeader.getX509SHA256Thumbprint()) .type(expectedJwsHeader.getType()) .contentType(expectedJwsHeader.getContentType()) + .criticalHeader("critical-header1-name", "critical-header1-value") + .criticalHeaders((criticalHeaders) -> { + criticalHeaders.put("critical-header2-name", "critical-header2-value"); + criticalHeaders.put("critical-header3-name", "critical-header3-value"); + }) .headers((headers) -> headers.put("custom-header-name", "custom-header-value")) .build(); // @formatter:on @@ -80,15 +83,12 @@ public void buildWhenAllHeadersProvidedThenAllHeadersAreSet() { assertThat(jwsHeader.getX509SHA256Thumbprint()).isEqualTo(expectedJwsHeader.getX509SHA256Thumbprint()); assertThat(jwsHeader.getType()).isEqualTo(expectedJwsHeader.getType()); assertThat(jwsHeader.getContentType()).isEqualTo(expectedJwsHeader.getContentType()); + assertThat(jwsHeader.getCritical()).containsExactlyInAnyOrder("critical-header1-name", "critical-header2-name", + "critical-header3-name"); + assertThat(jwsHeader.getHeader("critical-header1-name")).isEqualTo("critical-header1-value"); + assertThat(jwsHeader.getHeader("critical-header2-name")).isEqualTo("critical-header2-value"); + assertThat(jwsHeader.getHeader("critical-header3-name")).isEqualTo("critical-header3-value"); assertThat(jwsHeader.getHeader("custom-header-name")).isEqualTo("custom-header-value"); - assertThat(jwsHeader.getHeaders()).isEqualTo(expectedJwsHeader.getHeaders()); - } - - @Test - public void buildWhenMissingCriticalHeaderThenThrowIllegalStateException() { - assertThatExceptionOfType(IllegalStateException.class).isThrownBy( - () -> TestJwsHeaders.jwsHeader().critical(Collections.singleton("critical-header-name")).build()) - .withMessage("Missing critical (crit) header 'critical-header-name'."); } @Test diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java index 55d652f036e..d3871989860 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java @@ -224,8 +224,7 @@ private Builder(JweHeader headers) { @Override public JweHeader build() { - validate(); - return new JweHeader(getHeaders()); + return new JweHeader(getMergedHeaders()); } } From 7a6218fdc331b5d92c73cf140f6da9cee6bbedb8 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Tue, 21 Sep 2021 05:48:08 -0400 Subject: [PATCH 19/20] Polish "Add criticalHeader() convenience methods" --- .../security/oauth2/jwt/JwsHeader.java | 18 ++++++++++++----- .../security/oauth2/jwt/JwsHeaderTests.java | 20 ++++++++++++++++++- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwsHeader.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwsHeader.java index 955a6e0bb90..c71cc18f0bd 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwsHeader.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwsHeader.java @@ -16,8 +16,8 @@ package org.springframework.security.oauth2.jwt; +import java.util.HashMap; import java.util.Map; -import java.util.function.Consumer; import org.springframework.security.oauth2.jose.jws.JwsAlgorithm; import org.springframework.util.Assert; @@ -72,10 +72,18 @@ private Builder(JwsAlgorithm jwsAlgorithm) { algorithm(jwsAlgorithm); } - private Builder(JwsHeader headers) { - Assert.notNull(headers, "headers cannot be null"); - Consumer> headersConsumer = (h) -> h.putAll(headers.getHeaders()); - headers(headersConsumer); + private Builder(JwsHeader jwsHeader) { + Assert.notNull(jwsHeader, "jwsHeader cannot be null"); + Map headers = new HashMap<>(jwsHeader.getHeaders()); + Map criticalHeaders = new HashMap<>(); + if (jwsHeader.getCritical() != null) { + jwsHeader.getCritical().forEach( + (criticalHeader) -> criticalHeaders.put(criticalHeader, jwsHeader.getHeader(criticalHeader))); + headers.keySet().removeAll(criticalHeaders.keySet()); + headers.remove(JoseHeaderNames.CRIT); + } + getHeaders().putAll(headers); + getCriticalHeaders().putAll(criticalHeaders); } /** diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwsHeaderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwsHeaderTests.java index eae43f54715..3c3db2572c0 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwsHeaderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwsHeaderTests.java @@ -22,6 +22,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.entry; /** * Tests for {@link JwsHeader}. @@ -39,7 +40,7 @@ public void withWhenNullThenThrowIllegalArgumentException() { @Test public void fromWhenNullThenThrowIllegalArgumentException() { assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> JwsHeader.from(null)) - .withMessage("headers cannot be null"); + .withMessage("jwsHeader cannot be null"); } @Test @@ -49,6 +50,23 @@ public void fromWhenHeadersProvidedThenCopied() { assertThat(jwsHeader.getHeaders()).isEqualTo(expectedJwsHeader.getHeaders()); } + @Test + public void fromWhenHeadersProvidedThenCriticalHeadersCopied() { + JwsHeader expectedJwsHeader = TestJwsHeaders.jwsHeader() + .criticalHeader("critical-header1-name", "critical-header1-value") + .criticalHeaders((criticalHeaders) -> { + criticalHeaders.put("critical-header2-name", "critical-header2-value"); + criticalHeaders.put("critical-header3-name", "critical-header3-value"); + }).build(); + + JwsHeader.Builder jwsHeaderBuilder = JwsHeader.from(expectedJwsHeader); + assertThat(jwsHeaderBuilder.getHeaders()).doesNotContainKey(JoseHeaderNames.CRIT); + assertThat(jwsHeaderBuilder.getCriticalHeaders()).containsOnly( + entry("critical-header1-name", "critical-header1-value"), + entry("critical-header2-name", "critical-header2-value"), + entry("critical-header3-name", "critical-header3-value")); + } + @Test public void buildWhenAllHeadersProvidedThenAllHeadersAreSet() { JwsHeader expectedJwsHeader = TestJwsHeaders.jwsHeader().build(); From 132f833031bc8a6c3f24a5fa316a1f8d8b0e6bf0 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Tue, 21 Sep 2021 06:12:54 -0400 Subject: [PATCH 20/20] Simplify "Add criticalHeader() convenience methods" --- .../security/oauth2/jwt/JoseHeader.java | 36 +++---------------- .../security/oauth2/jwt/JwsHeader.java | 18 +++------- .../security/oauth2/jwt/JwsHeaderTests.java | 29 ++------------- .../oauth2/jwt/NimbusJweEncoderTests.java | 2 +- 4 files changed, 13 insertions(+), 72 deletions(-) diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JoseHeader.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JoseHeader.java index 4be71daf6d0..f40fb09a558 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JoseHeader.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JoseHeader.java @@ -19,6 +19,7 @@ import java.net.URL; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -186,8 +187,6 @@ abstract static class AbstractBuilder headers = new HashMap<>(); - private final Map criticalHeaders = new HashMap<>(); - protected AbstractBuilder() { } @@ -195,21 +194,6 @@ protected Map getHeaders() { return this.headers; } - protected Map getCriticalHeaders() { - return this.criticalHeaders; - } - - protected Map getMergedHeaders() { - if (getCriticalHeaders().isEmpty()) { - return getHeaders(); - } - Map mergedHeaders = new HashMap<>(getHeaders()); - Set crit = getCriticalHeaders().keySet(); - mergedHeaders.put(JoseHeaderNames.CRIT, crit); - mergedHeaders.putAll(getCriticalHeaders()); - return mergedHeaders; - } - @SuppressWarnings("unchecked") protected final B getThis() { return (B) this; // avoid unchecked casts in subclasses by using "getThis()" @@ -329,21 +313,11 @@ public B contentType(String contentType) { * @param value the critical header value * @return the {@link AbstractBuilder} */ + @SuppressWarnings("unchecked") public B criticalHeader(String name, Object value) { - Assert.hasText(name, "name cannot be empty"); - Assert.notNull(value, "value cannot be null"); - this.criticalHeaders.put(name, value); - return getThis(); - } - - /** - * A {@code Consumer} to be provided access to the critical headers allowing the - * ability to add, replace, or remove. - * @param headersConsumer a {@code Consumer} of the critical headers - * @return the {@link AbstractBuilder} - */ - public B criticalHeaders(Consumer> headersConsumer) { - headersConsumer.accept(this.criticalHeaders); + header(name, value); + getHeaders().computeIfAbsent(JoseHeaderNames.CRIT, (k) -> new HashSet()); + ((Set) getHeaders().get(JoseHeaderNames.CRIT)).add(name); return getThis(); } diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwsHeader.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwsHeader.java index c71cc18f0bd..9b8ee4721a0 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwsHeader.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwsHeader.java @@ -16,7 +16,6 @@ package org.springframework.security.oauth2.jwt; -import java.util.HashMap; import java.util.Map; import org.springframework.security.oauth2.jose.jws.JwsAlgorithm; @@ -72,18 +71,9 @@ private Builder(JwsAlgorithm jwsAlgorithm) { algorithm(jwsAlgorithm); } - private Builder(JwsHeader jwsHeader) { - Assert.notNull(jwsHeader, "jwsHeader cannot be null"); - Map headers = new HashMap<>(jwsHeader.getHeaders()); - Map criticalHeaders = new HashMap<>(); - if (jwsHeader.getCritical() != null) { - jwsHeader.getCritical().forEach( - (criticalHeader) -> criticalHeaders.put(criticalHeader, jwsHeader.getHeader(criticalHeader))); - headers.keySet().removeAll(criticalHeaders.keySet()); - headers.remove(JoseHeaderNames.CRIT); - } - getHeaders().putAll(headers); - getCriticalHeaders().putAll(criticalHeaders); + private Builder(JwsHeader headers) { + Assert.notNull(headers, "headers cannot be null"); + getHeaders().putAll(headers.getHeaders()); } /** @@ -92,7 +82,7 @@ private Builder(JwsHeader jwsHeader) { */ @Override public JwsHeader build() { - return new JwsHeader(getMergedHeaders()); + return new JwsHeader(getHeaders()); } } diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwsHeaderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwsHeaderTests.java index 3c3db2572c0..a1262a850f8 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwsHeaderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwsHeaderTests.java @@ -22,7 +22,6 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -import static org.assertj.core.api.Assertions.entry; /** * Tests for {@link JwsHeader}. @@ -40,7 +39,7 @@ public void withWhenNullThenThrowIllegalArgumentException() { @Test public void fromWhenNullThenThrowIllegalArgumentException() { assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> JwsHeader.from(null)) - .withMessage("jwsHeader cannot be null"); + .withMessage("headers cannot be null"); } @Test @@ -50,23 +49,6 @@ public void fromWhenHeadersProvidedThenCopied() { assertThat(jwsHeader.getHeaders()).isEqualTo(expectedJwsHeader.getHeaders()); } - @Test - public void fromWhenHeadersProvidedThenCriticalHeadersCopied() { - JwsHeader expectedJwsHeader = TestJwsHeaders.jwsHeader() - .criticalHeader("critical-header1-name", "critical-header1-value") - .criticalHeaders((criticalHeaders) -> { - criticalHeaders.put("critical-header2-name", "critical-header2-value"); - criticalHeaders.put("critical-header3-name", "critical-header3-value"); - }).build(); - - JwsHeader.Builder jwsHeaderBuilder = JwsHeader.from(expectedJwsHeader); - assertThat(jwsHeaderBuilder.getHeaders()).doesNotContainKey(JoseHeaderNames.CRIT); - assertThat(jwsHeaderBuilder.getCriticalHeaders()).containsOnly( - entry("critical-header1-name", "critical-header1-value"), - entry("critical-header2-name", "critical-header2-value"), - entry("critical-header3-name", "critical-header3-value")); - } - @Test public void buildWhenAllHeadersProvidedThenAllHeadersAreSet() { JwsHeader expectedJwsHeader = TestJwsHeaders.jwsHeader().build(); @@ -83,10 +65,7 @@ public void buildWhenAllHeadersProvidedThenAllHeadersAreSet() { .type(expectedJwsHeader.getType()) .contentType(expectedJwsHeader.getContentType()) .criticalHeader("critical-header1-name", "critical-header1-value") - .criticalHeaders((criticalHeaders) -> { - criticalHeaders.put("critical-header2-name", "critical-header2-value"); - criticalHeaders.put("critical-header3-name", "critical-header3-value"); - }) + .criticalHeader("critical-header2-name", "critical-header2-value") .headers((headers) -> headers.put("custom-header-name", "custom-header-value")) .build(); // @formatter:on @@ -101,11 +80,9 @@ public void buildWhenAllHeadersProvidedThenAllHeadersAreSet() { assertThat(jwsHeader.getX509SHA256Thumbprint()).isEqualTo(expectedJwsHeader.getX509SHA256Thumbprint()); assertThat(jwsHeader.getType()).isEqualTo(expectedJwsHeader.getType()); assertThat(jwsHeader.getContentType()).isEqualTo(expectedJwsHeader.getContentType()); - assertThat(jwsHeader.getCritical()).containsExactlyInAnyOrder("critical-header1-name", "critical-header2-name", - "critical-header3-name"); + assertThat(jwsHeader.getCritical()).containsExactlyInAnyOrder("critical-header1-name", "critical-header2-name"); assertThat(jwsHeader.getHeader("critical-header1-name")).isEqualTo("critical-header1-value"); assertThat(jwsHeader.getHeader("critical-header2-name")).isEqualTo("critical-header2-value"); - assertThat(jwsHeader.getHeader("critical-header3-name")).isEqualTo("critical-header3-value"); assertThat(jwsHeader.getHeader("custom-header-name")).isEqualTo("custom-header-value"); } diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java index d3871989860..cfb1d7cd11c 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJweEncoderTests.java @@ -224,7 +224,7 @@ private Builder(JweHeader headers) { @Override public JweHeader build() { - return new JweHeader(getMergedHeaders()); + return new JweHeader(getHeaders()); } }