Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch JWT library implementations from cxf to nimbus #3421

Merged
merged 37 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
1c75b4a
Replace JWT library with Nimbus Jose + JWT
peternied Aug 25, 2023
e9e5457
Merge remote-tracking branch 'peternied/nimbus-jose-jwt'
MaciejMierzwa Sep 26, 2023
68fb56c
test jwt content
MaciejMierzwa Sep 26, 2023
48658dd
swap cxf jwt to nimbus jwt
MaciejMierzwa Sep 28, 2023
93c1bce
remove all usages of cxf.rs.security.jose
MaciejMierzwa Sep 29, 2023
5fc3b8b
tests, encoding fixes
MaciejMierzwa Oct 2, 2023
1eb397c
naming, add padding to JwtVendor secret
MaciejMierzwa Oct 3, 2023
9a2ef33
small refactor, spotless, tests, use raw settings to create jwk
MaciejMierzwa Oct 3, 2023
0ee2de6
Merge remote-tracking branch 'origin/main' into nimbus-jose-jwt
MaciejMierzwa Oct 3, 2023
f52ca23
Merge remote-tracking branch 'origin/main' into nimbus-jose-jwt
MaciejMierzwa Oct 3, 2023
46eb723
test build after merge
MaciejMierzwa Oct 3, 2023
10fe305
revert misc changes
MaciejMierzwa Oct 4, 2023
da51bec
correct HMAC padding, escape chars in tests
MaciejMierzwa Oct 4, 2023
c151696
PR changes, style
MaciejMierzwa Oct 4, 2023
a597cf5
remove org.apache.cxf:cxf-rt-rs-security-jose import, add rule forbid…
MaciejMierzwa Oct 4, 2023
7e2c6ca
PR suggestions, null checks, java.util.Date
MaciejMierzwa Oct 4, 2023
88de2cc
Merge branch 'main_origin' into nimbus-jose-jwt
MaciejMierzwa Oct 5, 2023
48c3b5a
PR suggestions, spotless
MaciejMierzwa Oct 5, 2023
a245f58
Merge branch 'main_origin' into nimbus-jose-jwt
MaciejMierzwa Oct 5, 2023
3002f11
Exception -> IllegalArgumentException
MaciejMierzwa Oct 5, 2023
f0e19bd
Merge remote-tracking branch 'origin/main' into nimbus-jose-jwt
MaciejMierzwa Oct 6, 2023
0d21bd1
Merge remote-tracking branch 'origin/main' into nimbus-jose-jwt
MaciejMierzwa Oct 16, 2023
8d210a1
Class raw use fix
MaciejMierzwa Oct 17, 2023
480ba8a
Fix the seconds into milli seconds in jwt vendor
RyanL1997 Oct 20, 2023
f74edd9
Fixed obo integ test
RyanL1997 Oct 20, 2023
c899710
Refactor the matcher library
RyanL1997 Oct 20, 2023
1d5fcb4
Fix saml authenticator test
RyanL1997 Oct 20, 2023
2848560
Add padding back but not for obo
RyanL1997 Oct 23, 2023
c9fd75f
test cxf, nimbus compability
MaciejMierzwa Oct 24, 2023
4402b14
default encoding fix
MaciejMierzwa Oct 24, 2023
de31e00
default encoding fix
MaciejMierzwa Oct 24, 2023
2bdd1de
Add tests and relocate KeyPaddingUtil
RyanL1997 Oct 24, 2023
85d7eaa
Add comment for cxf code generation
RyanL1997 Oct 24, 2023
2c2560b
Reloacate the comment for cxf lib
RyanL1997 Oct 24, 2023
78e94fa
Fix testParsePrevGeneratedJwt()
RyanL1997 Oct 24, 2023
1004595
Fix the comment description
RyanL1997 Oct 24, 2023
9c667e4
Revert the changes for padding test config
RyanL1997 Oct 24, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 10 additions & 13 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,6 @@ test {
if (JavaVersion.current() > JavaVersion.VERSION_1_8) {
jvmArgs += "--add-opens=java.base/java.io=ALL-UNNAMED"
}
retry {
failOnPassedAfterRetry = false
maxRetries = 5
}
jacoco {
excludes = [
"com.sun.jndi.dns.*",
Expand Down Expand Up @@ -245,10 +241,10 @@ def setCommonTestConfig(Test task) {
if (JavaVersion.current() > JavaVersion.VERSION_1_8) {
task.jvmArgs += "--add-opens=java.base/java.io=ALL-UNNAMED"
}
task.retry {
failOnPassedAfterRetry = false
maxRetries = 5
}
// task.retry {
// failOnPassedAfterRetry = false
// maxRetries = 5
// }
task.jacoco {
excludes = [
"com.sun.jndi.dns.*",
Expand Down Expand Up @@ -464,11 +460,11 @@ task integrationTest(type: Test) {
systemProperty "java.util.logging.manager", "org.apache.logging.log4j.jul.LogManager"
testClassesDirs = sourceSets.integrationTest.output.classesDirs
classpath = sourceSets.integrationTest.runtimeClasspath
retry {
failOnPassedAfterRetry = false
maxRetries = 2
maxFailures = 10
}
// retry {
// failOnPassedAfterRetry = false
// maxRetries = 2
// maxFailures = 10
// }
MaciejMierzwa marked this conversation as resolved.
Show resolved Hide resolved
//run the integrationTest task after the test task
shouldRunAfter test
}
Expand All @@ -488,6 +484,7 @@ dependencies {
implementation 'commons-cli:commons-cli:1.5.0'
implementation "org.bouncycastle:bcprov-jdk15to18:${versions.bouncycastle}"
implementation 'org.ldaptive:ldaptive:1.2.3'
implementation 'com.nimbusds:nimbus-jose-jwt:9.31'
MaciejMierzwa marked this conversation as resolved.
Show resolved Hide resolved

//JWT
implementation "io.jsonwebtoken:jjwt-api:${jjwt_version}"
Expand Down
7 changes: 7 additions & 0 deletions checkstyle/checkstyle.xml
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,13 @@
<property name="severity" value="error"/>
</module>

<module name="RegexpSingleline">
<property name="format" value="println"/>
<property name="ignoreCase" value="true"/>
<property name="message" value="SYSTEM.OUT.PRINTLN NONONO!" />
<property name="severity" value="error"/>
</module>
MaciejMierzwa marked this conversation as resolved.
Show resolved Hide resolved

<module name="RegexpSingleline">
<property name="format" value="extension"/>
<property name="ignoreCase" value="true"/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@
import java.nio.file.Path;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.text.ParseException;
import java.util.Collection;
import java.util.Map.Entry;
import java.util.regex.Pattern;

import com.google.common.annotations.VisibleForTesting;
import org.apache.cxf.rs.security.jose.jwt.JwtClaims;
import org.apache.cxf.rs.security.jose.jwt.JwtToken;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import org.apache.hc.core5.http.HttpHeaders;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
Expand Down Expand Up @@ -108,7 +109,7 @@ private AuthCredentials extractCredentials0(final RestRequest request) throws Op
return null;
}

JwtToken jwt;
SignedJWT jwt;

try {
jwt = jwtVerifier.getVerifiedJwtToken(jwtString);
Expand All @@ -120,7 +121,13 @@ private AuthCredentials extractCredentials0(final RestRequest request) throws Op
return null;
}

JwtClaims claims = jwt.getClaims();
JWTClaimsSet claims = null;
MaciejMierzwa marked this conversation as resolved.
Show resolved Hide resolved
try {
claims = jwt.getJWTClaimsSet();
} catch (ParseException e) {
log.info("Extracting JWT token from {} failed", jwtString, e);
return null;
}

final String subject = extractSubject(claims);

Expand All @@ -133,7 +140,7 @@ private AuthCredentials extractCredentials0(final RestRequest request) throws Op

final AuthCredentials ac = new AuthCredentials(subject, roles).markComplete();

for (Entry<String, Object> claim : claims.asMap().entrySet()) {
for (Entry<String, Object> claim : claims.getClaims().entrySet()) {
ac.addAttribute("attr.jwt." + claim.getKey(), String.valueOf(claim.getValue()));
}

Expand Down Expand Up @@ -170,7 +177,7 @@ protected String getJwtTokenString(RestRequest request) {
}

@VisibleForTesting
public String extractSubject(JwtClaims claims) {
public String extractSubject(JWTClaimsSet claims) {
String subject = claims.getSubject();

if (subjectKey != null) {
Expand Down Expand Up @@ -200,7 +207,7 @@ public String extractSubject(JwtClaims claims) {

@SuppressWarnings("unchecked")
@VisibleForTesting
public String[] extractRoles(JwtClaims claims) {
public String[] extractRoles(JWTClaimsSet claims) {
if (rolesKey == null) {
return new String[0];
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,24 @@
package com.amazon.dlic.auth.http.jwt.keybyoidc;

import com.google.common.base.Strings;
import com.nimbusds.jose.Algorithm;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSVerifier;
import com.nimbusds.jose.crypto.Ed25519Verifier;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.KeyType;
import com.nimbusds.jose.jwk.KeyUse;
import com.nimbusds.jose.jwk.OctetKeyPair;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import com.nimbusds.jwt.proc.BadJWTException;
import org.apache.commons.lang3.StringEscapeUtils;
import org.apache.cxf.rs.security.jose.jwa.SignatureAlgorithm;
import org.apache.cxf.rs.security.jose.jwk.JsonWebKey;
import org.apache.cxf.rs.security.jose.jwk.KeyType;
import org.apache.cxf.rs.security.jose.jwk.PublicKeyUse;
import org.apache.cxf.rs.security.jose.jws.JwsJwtCompactConsumer;
import org.apache.cxf.rs.security.jose.jws.JwsSignatureVerifier;
import org.apache.cxf.rs.security.jose.jws.JwsUtils;
import org.apache.cxf.rs.security.jose.jwt.JwtClaims;
import org.apache.cxf.rs.security.jose.jwt.JwtException;
import org.apache.cxf.rs.security.jose.jwt.JwtToken;
import org.apache.cxf.rs.security.jose.jwt.JwtUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.text.ParseException;
import java.util.List;

public class JwtVerifier {

private final static Logger log = LogManager.getLogger(JwtVerifier.class);
Expand All @@ -43,31 +46,29 @@ public JwtVerifier(KeyProvider keyProvider, int clockSkewToleranceSeconds, Strin
this.requiredAudience = requiredAudience;
}

public JwtToken getVerifiedJwtToken(String encodedJwt) throws BadCredentialsException {
public SignedJWT getVerifiedJwtToken(String encodedJwt) throws BadCredentialsException {
try {
JwsJwtCompactConsumer jwtConsumer = new JwsJwtCompactConsumer(encodedJwt);
JwtToken jwt = jwtConsumer.getJwtToken();
SignedJWT jwt = SignedJWT.parse(encodedJwt);

String escapedKid = jwt.getJwsHeaders().getKeyId();
String escapedKid = jwt.getHeader().getKeyID();
String kid = escapedKid;
if (!Strings.isNullOrEmpty(kid)) {
kid = StringEscapeUtils.unescapeJava(escapedKid);
}
JsonWebKey key = keyProvider.getKey(kid);
JWK key = keyProvider.getKey(kid);

// Algorithm is not mandatory for the key material, so we set it to the same as the JWT
if (key.getAlgorithm() == null && key.getPublicKeyUse() == PublicKeyUse.SIGN && key.getKeyType() == KeyType.RSA) {
key.setAlgorithm(jwt.getJwsHeaders().getAlgorithm());
// TODO algorithm is final in jose implementation. Algorithm is not mandatory for the key material, so we set it to the same as the JWT, check if it's even necessary
if (key.getAlgorithm() == null && key.getKeyUse() == KeyUse.SIGNATURE && key.getKeyType() == KeyType.RSA) {
// key.setAlgorithm(jwt.getJwsHeaders().getAlgorithm());
MaciejMierzwa marked this conversation as resolved.
Show resolved Hide resolved
}

JwsSignatureVerifier signatureVerifier = getInitializedSignatureVerifier(key, jwt);

boolean signatureValid = jwtConsumer.verifySignatureWith(signatureVerifier);
JWSVerifier signatureVerifier = getInitializedSignatureVerifier(key, jwt);
boolean signatureValid = jwt.verify(signatureVerifier);

if (!signatureValid && Strings.isNullOrEmpty(kid)) {
key = keyProvider.getKeyAfterRefresh(null);
signatureVerifier = getInitializedSignatureVerifier(key, jwt);
signatureValid = jwtConsumer.verifySignatureWith(signatureVerifier);
signatureValid = jwt.verify(signatureVerifier);
peternied marked this conversation as resolved.
Show resolved Hide resolved
}

if (!signatureValid) {
Expand All @@ -77,18 +78,20 @@ public JwtToken getVerifiedJwtToken(String encodedJwt) throws BadCredentialsExce
validateClaims(jwt);

return jwt;
} catch (JwtException e) {
} catch (JOSEException | ParseException e) {
throw new BadCredentialsException(e.getMessage(), e);
} catch (BadJWTException e) {
throw new RuntimeException(e);
}
}

private void validateSignatureAlgorithm(JsonWebKey key, JwtToken jwt) throws BadCredentialsException {
if (Strings.isNullOrEmpty(key.getAlgorithm())) {
private void validateSignatureAlgorithm(JWK key, SignedJWT jwt) throws BadCredentialsException {
if (key.getAlgorithm() == null || jwt.getHeader().getAlgorithm() == null) {
return;
}

SignatureAlgorithm keyAlgorithm = SignatureAlgorithm.getAlgorithm(key.getAlgorithm());
SignatureAlgorithm tokenAlgorithm = SignatureAlgorithm.getAlgorithm(jwt.getJwsHeaders().getAlgorithm());
Algorithm keyAlgorithm = key.getAlgorithm();
Algorithm tokenAlgorithm = jwt.getHeader().getAlgorithm();

if (!keyAlgorithm.equals(tokenAlgorithm)) {
throw new BadCredentialsException(
Expand All @@ -97,38 +100,41 @@ private void validateSignatureAlgorithm(JsonWebKey key, JwtToken jwt) throws Bad
}
}

private JwsSignatureVerifier getInitializedSignatureVerifier(JsonWebKey key, JwtToken jwt) throws BadCredentialsException,
JwtException {
private JWSVerifier getInitializedSignatureVerifier(JWK key, SignedJWT jwt) throws BadCredentialsException, JOSEException {

validateSignatureAlgorithm(key, jwt);
JwsSignatureVerifier result = JwsUtils.getSignatureVerifier(key, jwt.getJwsHeaders().getSignatureAlgorithm());
if(key.getClass() != OctetKeyPair.class) {
throw new BadCredentialsException("Cannot verify JWT");
}
JWSVerifier result = new Ed25519Verifier((OctetKeyPair) key);
if (result == null) {
throw new BadCredentialsException("Cannot verify JWT");
} else {
return result;
}
}

private void validateClaims(JwtToken jwt) throws JwtException {
JwtClaims claims = jwt.getClaims();
private void validateClaims(SignedJWT jwt) throws ParseException, BadJWTException {
JWTClaimsSet claims = jwt.getJWTClaimsSet();

if (claims != null) {
JwtUtils.validateJwtExpiry(claims, clockSkewToleranceSeconds, false);
JwtUtils.validateJwtNotBefore(claims, clockSkewToleranceSeconds, false);
//TODO
// JwtUtils.validateJwtExpiry(claims, clockSkewToleranceSeconds, false);
// JwtUtils.validateJwtNotBefore(claims, clockSkewToleranceSeconds, false);
validateRequiredAudienceAndIssuer(claims);
}
}

private void validateRequiredAudienceAndIssuer(JwtClaims claims) {
String audience = claims.getAudience();
private void validateRequiredAudienceAndIssuer(JWTClaimsSet claims) throws BadJWTException {
List<String> audience = claims.getAudience();
String issuer = claims.getIssuer();

if (!Strings.isNullOrEmpty(requiredAudience) && !requiredAudience.equals(audience)) {
throw new JwtException("Invalid audience");
if (!Strings.isNullOrEmpty(requiredAudience) && !requiredAudience.equals(audience.stream().findFirst().orElse(""))) {
willyborankin marked this conversation as resolved.
Show resolved Hide resolved
throw new BadJWTException("Invalid audience");
}

if (!Strings.isNullOrEmpty(requiredIssuer) && !requiredIssuer.equals(issuer)) {
throw new JwtException("Invalid issuer");
throw new BadJWTException("Invalid issuer");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@

package com.amazon.dlic.auth.http.jwt.keybyoidc;

import org.apache.cxf.rs.security.jose.jwk.JsonWebKey;
import com.nimbusds.jose.jwk.JWK;

public interface KeyProvider {
public JsonWebKey getKey(String kid) throws AuthenticatorUnavailableException, BadCredentialsException;

public JsonWebKey getKeyAfterRefresh(String kid) throws AuthenticatorUnavailableException, BadCredentialsException;
JWK getKey(String kid) throws AuthenticatorUnavailableException, BadCredentialsException;
JWK getKeyAfterRefresh(String kid) throws AuthenticatorUnavailableException, BadCredentialsException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@

package com.amazon.dlic.auth.http.jwt.keybyoidc;

import org.apache.cxf.rs.security.jose.jwk.JsonWebKeys;
import com.nimbusds.jose.jwk.JWKSet;

@FunctionalInterface
public interface KeySetProvider {
JsonWebKeys get() throws AuthenticatorUnavailableException;
JWKSet get() throws AuthenticatorUnavailableException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
package com.amazon.dlic.auth.http.jwt.keybyoidc;

import java.io.IOException;
import java.text.ParseException;
import java.util.concurrent.TimeUnit;

import com.nimbusds.jose.jwk.JWKSet;
import joptsimple.internal.Strings;
import org.apache.cxf.rs.security.jose.jwk.JsonWebKeys;
import org.apache.cxf.rs.security.jose.jwk.JwkUtils;
import org.apache.hc.client5.http.cache.HttpCacheContext;
import org.apache.hc.client5.http.cache.HttpCacheStorage;
import org.apache.hc.client5.http.classic.methods.HttpGet;
Expand Down Expand Up @@ -70,7 +70,7 @@ public class KeySetRetriever implements KeySetProvider {
configureCache(useCacheForOidConnectEndpoint);
}

public JsonWebKeys get() throws AuthenticatorUnavailableException {
public JWKSet get() throws AuthenticatorUnavailableException {
String uri = getJwksUri();

try (CloseableHttpClient httpClient = createHttpClient(null)) {
Expand All @@ -94,10 +94,13 @@ public JsonWebKeys get() throws AuthenticatorUnavailableException {
if (httpEntity == null) {
throw new AuthenticatorUnavailableException("Error while getting " + uri + ": Empty response entity");
}

JsonWebKeys keySet = JwkUtils.readJwkSet(httpEntity.getContent());
//TODO
JWKSet keySet = JWKSet.load(httpEntity.getContent());
// JWKSet keySet = JwkUtils.readJwkSet(httpEntity.getContent());

return keySet;
} catch (ParseException e) {
throw new RuntimeException(e);
}
} catch (IOException e) {
throw new AuthenticatorUnavailableException("Error while getting " + uri + ": " + e, e);
Expand Down
Loading