From 7db244b3f5c55947e2001ea337878e8a10c3fa72 Mon Sep 17 00:00:00 2001 From: shawn-sher <5396793+shawn-sher@users.noreply.github.com> Date: Wed, 25 Jan 2023 13:49:25 -0800 Subject: [PATCH] Feat/add jwt exchange (#1067) * added endpoint for access token * Refined and updated tests * Added a feature flag to exchange endpoint * Corrected SpEL Value string Co-authored-by: Shawn Sherwood --- cerberus-auth-connector-okta/build.gradle | 4 + .../connector/okta/OktaAuthConnector.java | 90 +++++++++++++++++- .../connector/okta/OktaAuthConnectorTest.java | 91 ++++++++++++++++++- .../onelogin/OneLoginAuthConnector.java | 7 ++ .../onelogin/OneLoginAuthConnectorTest.java | 12 +++ .../auth/connector/AuthConnector.java | 3 + .../nike/cerberus/error/DefaultApiError.java | 3 + cerberus-web/build.gradle | 4 + .../UserAuthenticationController.java | 40 +++++++- .../security/WebSecurityConfiguration.java | 1 + .../service/AuthenticationService.java | 26 ++++++ .../com/nike/cerberus/service/JwtService.java | 8 +- .../UserAuthenticationControllerTest.java | 88 +++++++++++++++++- .../service/AuthenticationServiceTest.java | 88 +++++++++++++++++- .../nike/cerberus/service/JwtServiceTest.java | 2 +- 15 files changed, 458 insertions(+), 9 deletions(-) diff --git a/cerberus-auth-connector-okta/build.gradle b/cerberus-auth-connector-okta/build.gradle index 1c26bf486..6a0ccca8e 100644 --- a/cerberus-auth-connector-okta/build.gradle +++ b/cerberus-auth-connector-okta/build.gradle @@ -26,4 +26,8 @@ dependencies { // The Okta SDKs pull in an outdated version of guava that the OWASP Dep checker doesn't like implementation group: 'com.google.guava', name: 'guava', version: "${versions.guava}" + + // Okta jwt verfier libraries + implementation 'com.okta.jwt:okta-jwt-verifier:0.5.7' + implementation 'com.okta.jwt:okta-jwt-verifier-impl:0.5.7' } diff --git a/cerberus-auth-connector-okta/src/main/java/com/nike/cerberus/auth/connector/okta/OktaAuthConnector.java b/cerberus-auth-connector-okta/src/main/java/com/nike/cerberus/auth/connector/okta/OktaAuthConnector.java index 02c0c2614..dbb07bee6 100644 --- a/cerberus-auth-connector-okta/src/main/java/com/nike/cerberus/auth/connector/okta/OktaAuthConnector.java +++ b/cerberus-auth-connector-okta/src/main/java/com/nike/cerberus/auth/connector/okta/OktaAuthConnector.java @@ -19,6 +19,7 @@ import static java.lang.Thread.sleep; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableMap; import com.nike.backstopper.exception.ApiException; import com.nike.cerberus.auth.connector.AuthConnector; import com.nike.cerberus.auth.connector.AuthData; @@ -30,16 +31,22 @@ import com.okta.authn.sdk.FactorValidationException; import com.okta.authn.sdk.client.AuthenticationClient; import com.okta.authn.sdk.impl.resource.DefaultVerifyPassCodeFactorRequest; +import com.okta.jwt.AccessTokenVerifier; +import com.okta.jwt.Jwt; +import com.okta.jwt.JwtVerificationException; +import com.okta.jwt.JwtVerifiers; import com.okta.sdk.authc.credentials.TokenClientCredentials; import com.okta.sdk.client.Client; import com.okta.sdk.client.Clients; import com.okta.sdk.resource.group.GroupList; import com.okta.sdk.resource.user.User; import java.util.HashSet; +import java.util.Map; import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Component; /** Okta version 1 API implementation of the AuthConnector interface. */ @@ -50,18 +57,36 @@ public class OktaAuthConnector implements AuthConnector { private final Client sdkClient; + private final String jwtIssuer; + + private final String jwtAudience; + + protected AccessTokenVerifier jwtVerifier; + @Autowired public OktaAuthConnector( AuthenticationClient oktaAuthenticationClient, - OktaConfigurationProperties oktaConfigurationProperties) { + OktaConfigurationProperties oktaConfigurationProperties, + @Value("${cerberus.auth.jwt.issuer}") String jwtIssuer, + @Value("${cerberus.auth.jwt.audience}") String jwtAudience) { this.oktaAuthenticationClient = oktaAuthenticationClient; this.sdkClient = getSdkClient(oktaConfigurationProperties); + this.jwtIssuer = jwtIssuer; + this.jwtAudience = jwtAudience; } /** Alternate constructor to facilitate unit testing */ - public OktaAuthConnector(AuthenticationClient oktaAuthenticationClient, Client sdkClient) { + public OktaAuthConnector( + AuthenticationClient oktaAuthenticationClient, + Client sdkClient, + String jwtIssuer, + String jwtAudience, + AccessTokenVerifier jwtVerifier) { this.oktaAuthenticationClient = oktaAuthenticationClient; this.sdkClient = sdkClient; + this.jwtIssuer = jwtIssuer; + this.jwtAudience = jwtAudience; + this.jwtVerifier = jwtVerifier; } private Client getSdkClient(OktaConfigurationProperties oktaConfigurationProperties) { @@ -209,4 +234,65 @@ public Set getGroups(AuthData authData) { return groups; } + + /** + * Validates a JWT and retunrs the subject and userId in a map + * + * @param jwtString String jwt access token + * @return Map of username and userId + * @throws ApiException if JWT cannot be verified + */ + @Override + public Map getValidatedUserPrincipal(String jwtString) { + try { + Jwt jwt = this.getAccessTokenVerifier().decode(jwtString); + Map claims = jwt.getClaims(); + + String username = claims.getOrDefault("sub", "").toString(); + String userId = claims.getOrDefault("uid", "").toString(); + + if (username.isEmpty() || userId.isEmpty()) { + throw new JwtVerificationException("sub and uid claims are required"); + } + + Map principalInfoMap = + ImmutableMap.of("username", username, "userId", userId); + return principalInfoMap; + } catch (JwtVerificationException jve) { + throw this.buildJwtVerificationApiException(jve, "Failed to verify JWT access token"); + } + } + + /** + * Convert JwtVerificationException to ApiException + * + * @param jve JwtVerificationException + * @param msg Message + * @return ApiException + */ + private ApiException buildJwtVerificationApiException(JwtVerificationException jve, String msg) { + ApiException exc = + ApiException.Builder.newBuilder() + .withApiErrors(DefaultApiError.BEARER_TOKEN_INVALID) + .withExceptionMessage(msg) + .withExceptionCause(jve) + .build(); + return exc; + } + + /** + * Creates an access token verifier with the configured issuer and audience + * + * @return AccessTokenVerifier + */ + protected AccessTokenVerifier getAccessTokenVerifier() { + if (this.jwtVerifier == null) { + this.jwtVerifier = + JwtVerifiers.accessTokenVerifierBuilder() + .setIssuer(this.jwtIssuer) + .setAudience(this.jwtAudience) + .build(); + } + return this.jwtVerifier; + } } diff --git a/cerberus-auth-connector-okta/src/test/java/com/nike/cerberus/auth/connector/okta/OktaAuthConnectorTest.java b/cerberus-auth-connector-okta/src/test/java/com/nike/cerberus/auth/connector/okta/OktaAuthConnectorTest.java index 3aa2acf78..8ff70e93b 100644 --- a/cerberus-auth-connector-okta/src/test/java/com/nike/cerberus/auth/connector/okta/OktaAuthConnectorTest.java +++ b/cerberus-auth-connector-okta/src/test/java/com/nike/cerberus/auth/connector/okta/OktaAuthConnectorTest.java @@ -17,7 +17,9 @@ package com.nike.cerberus.auth.connector.okta; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyString; import static org.mockito.Mockito.*; import static org.mockito.MockitoAnnotations.initMocks; @@ -29,7 +31,12 @@ import com.nike.cerberus.auth.connector.okta.statehandlers.MfaStateHandler; import com.okta.authn.sdk.client.AuthenticationClient; import com.okta.authn.sdk.impl.resource.DefaultVerifyPassCodeFactorRequest; +import com.okta.jwt.AccessTokenVerifier; +import com.okta.jwt.Jwt; +import com.okta.jwt.JwtVerificationException; import com.okta.sdk.client.Client; +import java.util.HashMap; +import java.util.Map; import org.junit.Before; import org.junit.Test; import org.mockito.Mock; @@ -48,7 +55,9 @@ public void setup() { initMocks(this); - oktaAuthConnector = new OktaAuthConnector(client, sdkClient); + this.oktaAuthConnector = + new OktaAuthConnector( + client, sdkClient, "https://foo.bar", "dogs", mock(AccessTokenVerifier.class)); } ///////////////////////// @@ -238,4 +247,84 @@ public void mfaCheckFails() throws Exception { // verify results assertEquals(expectedResponse, actualResponse); } + + @Test + public void testGetValidatedOktaPrincipalOkay() { + try { + Map claims = new HashMap(); + claims.put("sub", "tester"); + claims.put("uid", "freeter"); + + Jwt mockJwt = mock(Jwt.class); + when(mockJwt.getClaims()).thenReturn(claims); + AccessTokenVerifier verifier = mock(AccessTokenVerifier.class); + when(verifier.decode(anyString())).thenReturn(mockJwt); + OktaAuthConnector connector = + new OktaAuthConnector( + client, sdkClient, "https://foo.bar/oauth2/skiddleydee", "dogs", verifier); + Map principal = connector.getValidatedUserPrincipal("us"); + + assertEquals(principal.get("username"), "tester"); + assertEquals(principal.get("userId"), "freeter"); + } catch (JwtVerificationException jve) { + assert false; + } + } + + @Test(expected = ApiException.class) + public void testGetValidatedOktaPrincipalMissingUserId() { + try { + Map claims = new HashMap(); + claims.put("sub", "tester"); + // claims.put("uid", "freeter"); + + Jwt mockJwt = mock(Jwt.class); + when(mockJwt.getClaims()).thenReturn(claims); + AccessTokenVerifier verifier = mock(AccessTokenVerifier.class); + when(verifier.decode(anyString())).thenReturn(mockJwt); + OktaAuthConnector connector = + new OktaAuthConnector( + client, sdkClient, "https://foo.bar/oauth2/skiddleydee", "dogs", verifier); + Map principal = connector.getValidatedUserPrincipal("us"); + + assertEquals(principal.get("username"), "tester"); + assertEquals(principal.get("userId"), "freeter"); + } catch (JwtVerificationException jve) { + assert false; + } + } + + @Test(expected = ApiException.class) + public void testGetValidatedOktaPrincipalBadClaims() { + try { + Map claims = new HashMap(); + + Jwt mockJwt = mock(Jwt.class); + when(mockJwt.getClaims()).thenReturn(claims); + AccessTokenVerifier verifier = mock(AccessTokenVerifier.class); + when(verifier.decode(anyString())).thenReturn(mockJwt); + OktaAuthConnector connector = + new OktaAuthConnector( + client, sdkClient, "https://foo.bar/oauth2/skiddleydee", "dogs", verifier); + + connector.getValidatedUserPrincipal("us"); + } catch (JwtVerificationException jve) { + assert false; + } + } + + @Test + public void testGetAccessTokenVerifierInitialNull() { + OktaAuthConnector connector = + new OktaAuthConnector( + client, sdkClient, "https://foo.bar/oauth2/skiddleydee", "dogs", null); + AccessTokenVerifier verifier = connector.getAccessTokenVerifier(); + assertNotNull(verifier); + } + + @Test + public void testGetAccessTokenVerifier() { + AccessTokenVerifier verifier = this.oktaAuthConnector.getAccessTokenVerifier(); + assertNotNull(verifier); + } } diff --git a/cerberus-auth-connector-onelogin/src/main/java/com/nike/cerberus/auth/connector/onelogin/OneLoginAuthConnector.java b/cerberus-auth-connector-onelogin/src/main/java/com/nike/cerberus/auth/connector/onelogin/OneLoginAuthConnector.java index 2a745ef7b..26ef2c312 100644 --- a/cerberus-auth-connector-onelogin/src/main/java/com/nike/cerberus/auth/connector/onelogin/OneLoginAuthConnector.java +++ b/cerberus-auth-connector-onelogin/src/main/java/com/nike/cerberus/auth/connector/onelogin/OneLoginAuthConnector.java @@ -22,7 +22,9 @@ import com.nike.cerberus.auth.connector.*; import com.nike.cerberus.error.DefaultApiError; import java.util.HashSet; +import java.util.Map; import java.util.Set; +import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.lang3.StringEscapeUtils; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.factory.annotation.Autowired; @@ -255,4 +257,9 @@ protected SessionLoginTokenData createSessionLoginToken( return createSessionLoginTokenResponse.getData().get(0); } + + @Override + public Map getValidatedUserPrincipal(String jwtString) { + throw new NotImplementedException("Not implemented for OneLogin"); + } } diff --git a/cerberus-auth-connector-onelogin/src/test/java/com/nike/cerberus/auth/connector/onelogin/OneLoginAuthConnectorTest.java b/cerberus-auth-connector-onelogin/src/test/java/com/nike/cerberus/auth/connector/onelogin/OneLoginAuthConnectorTest.java index 2ea38e812..9b5779281 100644 --- a/cerberus-auth-connector-onelogin/src/test/java/com/nike/cerberus/auth/connector/onelogin/OneLoginAuthConnectorTest.java +++ b/cerberus-auth-connector-onelogin/src/test/java/com/nike/cerberus/auth/connector/onelogin/OneLoginAuthConnectorTest.java @@ -28,6 +28,7 @@ import com.nike.cerberus.auth.connector.AuthStatus; import com.nike.cerberus.error.DefaultApiError; import java.util.Set; +import org.apache.commons.lang3.NotImplementedException; import org.junit.Before; import org.junit.Test; @@ -369,4 +370,15 @@ public void test_createSessionLoginToken_fails_with_when_MFA_setup_is_required() MFA_SETUP_REQUIRED.getHttpStatusCode(), ae.getApiErrors().get(0).getHttpStatusCode()); } } + + @Test + public void testgetValidatedUserPrincipalNotImplemented() { + NotImplementedException nie = null; + try { + oneLoginAuthConnector.getValidatedUserPrincipal("this won't work"); + } catch (NotImplementedException exc) { + nie = exc; + } + assertNotNull(nie); + } } diff --git a/cerberus-core/src/main/java/com/nike/cerberus/auth/connector/AuthConnector.java b/cerberus-core/src/main/java/com/nike/cerberus/auth/connector/AuthConnector.java index 365a03550..444beb53f 100644 --- a/cerberus-core/src/main/java/com/nike/cerberus/auth/connector/AuthConnector.java +++ b/cerberus-core/src/main/java/com/nike/cerberus/auth/connector/AuthConnector.java @@ -16,6 +16,7 @@ package com.nike.cerberus.auth.connector; +import java.util.Map; import java.util.Set; public interface AuthConnector { @@ -29,4 +30,6 @@ public interface AuthConnector { AuthResponse mfaCheck(final String stateToken, final String deviceId, final String otpToken); Set getGroups(final AuthData data); + + Map getValidatedUserPrincipal(String jwtString); } diff --git a/cerberus-core/src/main/java/com/nike/cerberus/error/DefaultApiError.java b/cerberus-core/src/main/java/com/nike/cerberus/error/DefaultApiError.java index f7296c19c..98db2eda8 100644 --- a/cerberus-core/src/main/java/com/nike/cerberus/error/DefaultApiError.java +++ b/cerberus-core/src/main/java/com/nike/cerberus/error/DefaultApiError.java @@ -51,6 +51,9 @@ public enum DefaultApiError implements ApiError { AUTH_TOKEN_INVALID( 99105, "X-Vault-Token or X-Cerberus-Token header is malformed or invalid.", SC_UNAUTHORIZED), + /** Authorization Bearer header was blank or invalid. */ + BEARER_TOKEN_INVALID(99100, "Authorization Bearer header was blank or invalid.", SC_UNAUTHORIZED), + /** Supplied credentials are invalid. */ AUTH_BAD_CREDENTIALS(99106, "Invalid credentials", SC_UNAUTHORIZED), diff --git a/cerberus-web/build.gradle b/cerberus-web/build.gradle index 30b6ca031..f5660a5ca 100644 --- a/cerberus-web/build.gradle +++ b/cerberus-web/build.gradle @@ -73,6 +73,10 @@ dependencies { implementation "io.jsonwebtoken:jjwt-api:${versions.jjwt}" implementation "io.jsonwebtoken:jjwt-impl:${versions.jjwt}" implementation "io.jsonwebtoken:jjwt-jackson:${versions.jjwt}" + implementation 'com.okta.jwt:okta-jwt-verifier:0.5.7' + implementation 'com.okta.jwt:okta-jwt-verifier-impl:0.5.7' + + //dist tracing diff --git a/cerberus-web/src/main/java/com/nike/cerberus/controller/authentication/UserAuthenticationController.java b/cerberus-web/src/main/java/com/nike/cerberus/controller/authentication/UserAuthenticationController.java index fb3f53f71..4c1bb7f03 100644 --- a/cerberus-web/src/main/java/com/nike/cerberus/controller/authentication/UserAuthenticationController.java +++ b/cerberus-web/src/main/java/com/nike/cerberus/controller/authentication/UserAuthenticationController.java @@ -29,12 +29,14 @@ import com.nike.cerberus.security.CerberusPrincipal; import com.nike.cerberus.service.AuthenticationService; import java.nio.charset.Charset; +import java.util.Locale; import javax.validation.Valid; import lombok.extern.slf4j.Slf4j; import org.apache.commons.codec.binary.Base64; import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; import org.springframework.http.HttpHeaders; import org.springframework.security.core.Authentication; import org.springframework.web.bind.annotation.RequestBody; @@ -47,15 +49,20 @@ @RequestMapping("/v2/auth") public class UserAuthenticationController { + private static final String BEARER_AUTH_PREFIX = "bearer"; private final AuthenticationService authenticationService; private final AuditLoggingFilterDetails auditLoggingFilterDetails; + private final boolean accessTokenExchangeEnabled; @Autowired public UserAuthenticationController( AuthenticationService authenticationService, - AuditLoggingFilterDetails auditLoggingFilterDetails) { + AuditLoggingFilterDetails auditLoggingFilterDetails, + @Value("${cerberus.auth.jwt.accessTokenExchangeEnabled:false}") + boolean accessTokenExchangeEnabled) { this.authenticationService = authenticationService; this.auditLoggingFilterDetails = auditLoggingFilterDetails; + this.accessTokenExchangeEnabled = accessTokenExchangeEnabled; } @RequestMapping(value = "/user", method = GET) @@ -76,6 +83,37 @@ public AuthResponse authenticate( return authResponse; } + @RequestMapping(value = "/exchange", method = GET) + public AuthResponse exchangeToken( + @RequestHeader(value = HttpHeaders.AUTHORIZATION) String authHeader) { + + if (!this.accessTokenExchangeEnabled) { + throw ApiException.Builder.newBuilder() + .withApiErrors(DefaultApiError.ENTITY_NOT_FOUND) + .build(); + } + + if (authHeader == null || !authHeader.toLowerCase(Locale.ROOT).startsWith(BEARER_AUTH_PREFIX)) { + final String msg = "Wrong authentication header"; + auditLoggingFilterDetails.setAction(msg); + throw ApiException.Builder.newBuilder() + .withApiErrors(DefaultApiError.BEARER_TOKEN_INVALID) + .withExceptionMessage(msg) + .build(); + } + + AuthResponse authResponse; + try { + final String jwtString = authHeader.replace(BEARER_AUTH_PREFIX, "").trim(); + authResponse = this.authenticationService.exchangeJwtAccessToken(jwtString); + auditLoggingFilterDetails.setAction("Authenticated"); + } catch (ApiException e) { + auditLoggingFilterDetails.setAction("Failed to authenticate"); + throw e; + } + return authResponse; + } + @RequestMapping(value = "/mfa_check", method = POST, consumes = APPLICATION_JSON_VALUE) public AuthResponse handleMfaCheck(@Valid @RequestBody MfaCheckRequest request) { if (request.isPush()) { diff --git a/cerberus-web/src/main/java/com/nike/cerberus/security/WebSecurityConfiguration.java b/cerberus-web/src/main/java/com/nike/cerberus/security/WebSecurityConfiguration.java index c5aee2c65..db738e494 100644 --- a/cerberus-web/src/main/java/com/nike/cerberus/security/WebSecurityConfiguration.java +++ b/cerberus-web/src/main/java/com/nike/cerberus/security/WebSecurityConfiguration.java @@ -52,6 +52,7 @@ public class WebSecurityConfiguration extends WebSecurityConfigurerAdapter { "/dashboard", "/dashboard/**", "/healthcheck", + "/v2/auth/exchange", "/v2/auth/sts-identity", "/v2/auth/iam-principal", "/v1/auth/iam-role", diff --git a/cerberus-web/src/main/java/com/nike/cerberus/service/AuthenticationService.java b/cerberus-web/src/main/java/com/nike/cerberus/service/AuthenticationService.java index 425619b60..2580b7851 100644 --- a/cerberus-web/src/main/java/com/nike/cerberus/service/AuthenticationService.java +++ b/cerberus-web/src/main/java/com/nike/cerberus/service/AuthenticationService.java @@ -172,6 +172,32 @@ public AuthResponse authenticate(final UserCredentials credentials) { return authResponse; } + /** + * Attempt to exchange an access token from an IdP for a Cerberus token based on their groups + * + * @param jwtString String jwt access token + * @return The auth response + */ + public AuthResponse exchangeJwtAccessToken(String jwtString) { + + final Map claims = + this.authServiceConnector.getValidatedUserPrincipal(jwtString); + + final String username = claims.get("username"); + final String userId = claims.get("userId"); + + final AuthData authData = + AuthData.builder().username(username).factorResult("SUCCESS").userId(userId).build(); + + final Set groups = this.authServiceConnector.getGroups(authData); + AuthTokenResponse token = this.generateToken(username, groups, 0); + authData.setClientToken(token); + + final AuthResponse authResponse = + AuthResponse.builder().data(authData).status(AuthStatus.SUCCESS).build(); + return authResponse; + } + /** * Enables a user to trigger a factor challenge. * diff --git a/cerberus-web/src/main/java/com/nike/cerberus/service/JwtService.java b/cerberus-web/src/main/java/com/nike/cerberus/service/JwtService.java index 5de65c34a..f98854dd0 100644 --- a/cerberus-web/src/main/java/com/nike/cerberus/service/JwtService.java +++ b/cerberus-web/src/main/java/com/nike/cerberus/service/JwtService.java @@ -59,6 +59,8 @@ public class JwtService { private final CerberusSigningKeyResolver signingKeyResolver; private final String environmentName; private final JwtBlocklistDao jwtBlocklistDao; + private final String jwtIssuer; + private final String jwtAudience; private HashSet blocklist; @@ -66,10 +68,14 @@ public class JwtService { public JwtService( CerberusSigningKeyResolver signingKeyResolver, @Value("${cerberus.environmentName}") String environmentName, - JwtBlocklistDao jwtBlocklistDao) { + JwtBlocklistDao jwtBlocklistDao, + @Value("${cerberus.auth.jwt.issuer}") String jwtIssuer, + @Value("${cerberus.auth.jwt.audience}") String jwtAudience) { this.signingKeyResolver = signingKeyResolver; this.environmentName = environmentName; this.jwtBlocklistDao = jwtBlocklistDao; + this.jwtIssuer = jwtIssuer; + this.jwtAudience = jwtAudience; refreshBlocklist(); } diff --git a/cerberus-web/src/test/java/com/nike/cerberus/controller/authentication/UserAuthenticationControllerTest.java b/cerberus-web/src/test/java/com/nike/cerberus/controller/authentication/UserAuthenticationControllerTest.java index f5b062630..a24bfaa68 100644 --- a/cerberus-web/src/test/java/com/nike/cerberus/controller/authentication/UserAuthenticationControllerTest.java +++ b/cerberus-web/src/test/java/com/nike/cerberus/controller/authentication/UserAuthenticationControllerTest.java @@ -1,7 +1,9 @@ package com.nike.cerberus.controller.authentication; import com.nike.backstopper.exception.ApiException; +import com.nike.cerberus.auth.connector.AuthData; import com.nike.cerberus.auth.connector.AuthResponse; +import com.nike.cerberus.auth.connector.AuthStatus; import com.nike.cerberus.domain.MfaCheckRequest; import com.nike.cerberus.domain.UserCredentials; import com.nike.cerberus.error.DefaultApiError; @@ -14,7 +16,7 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.mockito.InjectMocks; +// import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.MockitoAnnotations; @@ -23,12 +25,15 @@ public class UserAuthenticationControllerTest { @Mock private AuthenticationService authenticationService; @Mock private AuditLoggingFilterDetails auditLoggingFilterDetails; + private UserAuthenticationController userAuthenticationController; - @InjectMocks private UserAuthenticationController userAuthenticationController; + // @InjectMocks private UserAuthenticationController userAuthenticationController; @Before public void setup() { MockitoAnnotations.initMocks(this); + this.userAuthenticationController = + new UserAuthenticationController(authenticationService, auditLoggingFilterDetails, true); } @Test @@ -40,6 +45,7 @@ public void testAuthenticateWhenAuthHeaderIsNull() { } catch (ApiException e) { apiException = e; } + Assert.assertNotNull(apiException); Assert.assertEquals(DefaultApiError.AUTH_BAD_CREDENTIALS, apiException.getApiErrors().get(0)); } @@ -52,6 +58,7 @@ public void testAuthenticateWhenAuthHeaderIsEmpty() { } catch (ApiException e) { apiException = e; } + Assert.assertNotNull(apiException); Assert.assertEquals(DefaultApiError.AUTH_BAD_CREDENTIALS, apiException.getApiErrors().get(0)); } @@ -64,6 +71,7 @@ public void testAuthenticateWhenAuthHeaderIsDoesNotStartWithBasic() { } catch (ApiException e) { apiException = e; } + Assert.assertNotNull(apiException); Assert.assertEquals(DefaultApiError.AUTH_BAD_CREDENTIALS, apiException.getApiErrors().get(0)); } @@ -81,6 +89,7 @@ public void testAuthenticateWhenAuthenticationServiceAuthenticateThrowsException } catch (ApiException e) { apiException = e; } + Assert.assertNotNull(apiException); Assert.assertEquals(DefaultApiError.LOGIN_FAILED, apiException.getApiErrors().get(0)); Mockito.verify(auditLoggingFilterDetails).setAction("Failed to authenticate"); } @@ -131,4 +140,79 @@ public void testRefreshToken() { userAuthenticationController.refreshToken(cerberusPrincipal); Mockito.verify(authenticationService).refreshUserToken(cerberusPrincipal); } + + @Test + public void testExchangeTokenDisabled() { + ApiException apiException = null; + try { + UserAuthenticationController uac = + new UserAuthenticationController(authenticationService, auditLoggingFilterDetails, false); + uac.exchangeToken("bearer foobar"); + } catch (ApiException exc) { + apiException = exc; + } + Assert.assertNotNull(apiException); + Assert.assertEquals(DefaultApiError.ENTITY_NOT_FOUND, apiException.getApiErrors().get(0)); + } + + @Test + public void testExchangeTokenNullHeader() { + ApiException apiException = null; + try { + userAuthenticationController.exchangeToken(null); + } catch (ApiException e) { + apiException = e; + } + Assert.assertNotNull(apiException); + Assert.assertEquals(DefaultApiError.BEARER_TOKEN_INVALID, apiException.getApiErrors().get(0)); + } + + @Test + public void testExchangeTokenWrongHeader() { + ApiException apiException = null; + try { + userAuthenticationController.exchangeToken("bear dogs"); + } catch (ApiException e) { + apiException = e; + } + Assert.assertNotNull(apiException); + Assert.assertEquals(DefaultApiError.BEARER_TOKEN_INVALID, apiException.getApiErrors().get(0)); + } + + @Test + public void testExchangeTokenDecodeError() { + Mockito.when(authenticationService.exchangeJwtAccessToken(Mockito.anyString())) + .thenThrow(buildApiException("error")); + ApiException apiException = null; + try { + userAuthenticationController.exchangeToken("bearer dogs"); + } catch (ApiException e) { + apiException = e; + } + Assert.assertNotNull(apiException); + String expected = DefaultApiError.BEARER_TOKEN_INVALID.toString(); + String actual = apiException.getApiErrors().get(0).toString(); + Assert.assertEquals(expected, actual); + } + + @Test + public void testExchangeTokenHappy() { + final AuthData authData = AuthData.builder().username("tester").userId("aardvark").build(); + final AuthResponse authResponse = + AuthResponse.builder().data(authData).status(AuthStatus.SUCCESS).build(); + Mockito.when(authenticationService.exchangeJwtAccessToken(Mockito.anyString())) + .thenReturn(authResponse); + + AuthResponse response = userAuthenticationController.exchangeToken("bearer dogs"); + Assert.assertEquals(response.getData().getUserId(), "aardvark"); + } + + private ApiException buildApiException(String msg) { + ApiException exc = + ApiException.Builder.newBuilder() + .withApiErrors(DefaultApiError.BEARER_TOKEN_INVALID) + .withExceptionMessage(msg) + .build(); + return exc; + } } diff --git a/cerberus-web/src/test/java/com/nike/cerberus/service/AuthenticationServiceTest.java b/cerberus-web/src/test/java/com/nike/cerberus/service/AuthenticationServiceTest.java index cc55c06a7..ba18f853e 100644 --- a/cerberus-web/src/test/java/com/nike/cerberus/service/AuthenticationServiceTest.java +++ b/cerberus-web/src/test/java/com/nike/cerberus/service/AuthenticationServiceTest.java @@ -30,12 +30,14 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import com.nike.backstopper.exception.ApiException; import com.nike.cerberus.PrincipalType; import com.nike.cerberus.auth.connector.AuthConnector; import com.nike.cerberus.auth.connector.AuthData; import com.nike.cerberus.auth.connector.AuthResponse; +import com.nike.cerberus.auth.connector.AuthStatus; import com.nike.cerberus.aws.KmsClientFactory; import com.nike.cerberus.config.ApplicationConfiguration; import com.nike.cerberus.dao.AwsIamRoleDao; @@ -49,6 +51,8 @@ import com.nike.cerberus.security.CerberusPrincipal; import com.nike.cerberus.util.AwsIamRoleArnParser; import com.nike.cerberus.util.DateTimeSupplier; +import com.okta.jwt.JwtVerificationException; +import java.time.Duration; import java.time.OffsetDateTime; import java.time.ZoneOffset; import java.util.*; @@ -57,6 +61,7 @@ import org.junit.Before; import org.junit.Test; import org.mockito.Mock; +import org.mockito.Mockito; /** Tests the AuthenticationService class */ public class AuthenticationServiceTest { @@ -129,7 +134,6 @@ public void triggerChallengeSuccess() { assertEquals( expectedResponse.getData().getStateToken(), actualResponse.getData().getStateToken()); } - ; @Test public void tests_that_generateCommonVaultPrincipalAuthMetadata_contains_expected_fields() { @@ -490,4 +494,86 @@ public void tests_that_refreshUserToken_refreshes_token_when_count_is_less_than_ .getApiErrors() .contains(DefaultApiError.MAXIMUM_TOKEN_REFRESH_COUNT_REACHED)); } + + @Test + public void exchangeJwtAccessTokenOkay() { + + Map claims = ImmutableMap.of("username", "someone", "userId", "cataphract"); + when(authConnector.getGroups(Mockito.any(AuthData.class))).thenReturn(Set.of("cat", "dog")); + when(authConnector.getValidatedUserPrincipal(Mockito.anyString())).thenReturn(claims); + + CerberusAuthToken newToken = getNewCerberusAuthToken(); + when(authTokenService.generateToken( + anyString(), any(PrincipalType.class), anyBoolean(), anyObject(), anyInt(), anyInt())) + .thenReturn(newToken); + + AuthResponse response = this.authenticationService.exchangeJwtAccessToken("us"); + + AuthData data = response.getData(); + assertEquals(data.getUsername(), "someone"); + assertEquals(data.getUserId(), "cataphract"); + + Duration expectedDuration = Duration.between(newToken.getCreated(), newToken.getExpires()); + assertEquals(data.getClientToken().getLeaseDuration(), expectedDuration.getSeconds()); + assertEquals(response.getStatus().toString(), AuthStatus.SUCCESS.toString()); + } + + @Test + public void exchangeJwtAccessTokenBadJwt() { + + JwtVerificationException jve = new JwtVerificationException("oops"); + ApiException apiException = + ApiException.Builder.newBuilder() + .withApiErrors(DefaultApiError.BEARER_TOKEN_INVALID) + .withExceptionMessage(jve.getMessage()) + .withExceptionCause(jve) + .build(); + when(authConnector.getValidatedUserPrincipal(Mockito.anyString())).thenThrow(apiException); + + ApiException actualException = null; + try { + this.authenticationService.exchangeJwtAccessToken("us"); + } catch (ApiException caught) { + actualException = caught; + } + assertEquals(actualException, apiException); + } + + @Test + public void exchangeJwtAccessTokenDogs() { + + JwtVerificationException jve = new JwtVerificationException("oops"); + ApiException apiException = + ApiException.Builder.newBuilder() + .withApiErrors(DefaultApiError.BEARER_TOKEN_INVALID) + .withExceptionMessage(jve.getMessage()) + .withExceptionCause(jve) + .build(); + when(authConnector.getValidatedUserPrincipal(Mockito.anyString())).thenThrow(apiException); + + ApiException actualException = null; + try { + this.authenticationService.exchangeJwtAccessToken("us"); + } catch (ApiException caught) { + actualException = caught; + } + assertEquals(actualException, apiException); + } + + CerberusAuthToken getNewCerberusAuthToken() { + OffsetDateTime now = OffsetDateTime.now(); + OffsetDateTime later = now.plusHours(1); + return CerberusAuthToken.Builder.create().withCreated(now).withExpires(later).build(); + } + + AuthTokenResponse getAuthTokenReponse(CerberusAuthToken tokenResult) { + OffsetDateTime now = OffsetDateTime.now(); + OffsetDateTime later = now.plusHours(1); + AuthTokenResponse response = + new AuthTokenResponse() + .setClientToken(tokenResult.getToken()) + .setPolicies(Collections.emptySet()) + .setLeaseDuration(Duration.between(now, later).getSeconds()); + return response; + } } diff --git a/cerberus-web/src/test/java/com/nike/cerberus/service/JwtServiceTest.java b/cerberus-web/src/test/java/com/nike/cerberus/service/JwtServiceTest.java index 38271fb3a..f65080224 100644 --- a/cerberus-web/src/test/java/com/nike/cerberus/service/JwtServiceTest.java +++ b/cerberus-web/src/test/java/com/nike/cerberus/service/JwtServiceTest.java @@ -36,7 +36,7 @@ public class JwtServiceTest { @Before public void setUp() throws Exception { initMocks(this); - jwtService = new JwtService(signingKeyResolver, "local", jwtBlocklistDao); + jwtService = new JwtService(signingKeyResolver, "local", jwtBlocklistDao, "iss", "aud"); ReflectionTestUtils.setField(jwtService, "maxTokenLength", 1600); cerberusJwtKeySpec = new CerberusJwtKeySpec(new byte[64], "HmacSHA512", "key id"); cerberusJwtClaims = new CerberusJwtClaims();