Skip to content
This repository has been archived by the owner on Jan 12, 2024. It is now read-only.

Commit

Permalink
Handle errors when calling Okta to get user groups (#1068)
Browse files Browse the repository at this point in the history
Co-authored-by: Shawn Sherwood <[email protected]>
  • Loading branch information
shawn-sher and shawn-sher authored Feb 2, 2023
1 parent 7db244b commit 9c61c0c
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@
import com.okta.sdk.authc.credentials.TokenClientCredentials;
import com.okta.sdk.client.Client;
import com.okta.sdk.client.Clients;
import com.okta.sdk.resource.ResourceException;
import com.okta.sdk.resource.group.Group;
import com.okta.sdk.resource.group.GroupList;
import com.okta.sdk.resource.group.GroupProfile;
import com.okta.sdk.resource.user.User;
import java.util.HashSet;
import java.util.Map;
Expand Down Expand Up @@ -217,20 +220,60 @@ public AuthResponse mfaCheck(String stateToken, String deviceId, String otpToken
}
}

/**
* Get a valid user from the identity provider if possible
*
* @param userId
* @return User corresponding to the id
* @throws ApiException if user cannot be resolved
*/
protected User getUserFromIDP(String userId) {
try {
return sdkClient.getUser(userId);
} catch (IllegalStateException ise) {
throw ApiException.newBuilder()
.withExceptionCause(ise)
.withApiErrors(DefaultApiError.IDENTITY_PROVIDER_BAD_GATEWAY)
.withExceptionMessage("Could not communicate properly with identity provider")
.build();
} catch (ResourceException rexc) {
String msg =
String.format("Got invalid response from identity providers: %s", rexc.getMessage());
throw ApiException.newBuilder()
.withExceptionCause(rexc)
.withApiErrors(DefaultApiError.IDENTITY_PROVIDER_BAD_GATEWAY)
.withExceptionMessage(msg)
.build();
} catch (Exception exc) {
throw ApiException.newBuilder()
.withExceptionCause(exc)
.withApiErrors(DefaultApiError.INTERNAL_SERVER_ERROR)
.withExceptionMessage("Unknown error trying to getUser from identity provider")
.build();
}
}

/** Obtains groups user belongs to. */
@Override
public Set<String> getGroups(AuthData authData) {

Preconditions.checkNotNull(authData, "auth data cannot be null.");

User user = sdkClient.getUser(authData.getUserId());
String userId = authData.getUserId();
User user = getUserFromIDP(userId);
GroupList userGroups = user.listGroups();

final Set<String> groups = new HashSet<>();
if (userGroups == null) {
return groups;
}
userGroups.forEach(group -> groups.add(group.getProfile().getName()));

for (Group group : userGroups) {
GroupProfile profile = group.getProfile();
if (profile != null) {
groups.add(profile.getName());
}
}

return groups;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,34 @@
import static org.mockito.Mockito.*;
import static org.mockito.MockitoAnnotations.initMocks;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.nike.backstopper.apierror.ApiError;
import com.nike.backstopper.exception.ApiException;
import com.nike.cerberus.auth.connector.AuthData;
import com.nike.cerberus.auth.connector.AuthResponse;
import com.nike.cerberus.auth.connector.AuthStatus;
import com.nike.cerberus.auth.connector.okta.statehandlers.InitialLoginStateHandler;
import com.nike.cerberus.auth.connector.okta.statehandlers.MfaStateHandler;
import com.nike.cerberus.error.DefaultApiError;
import com.okta.authn.sdk.client.AuthenticationClient;
import com.okta.authn.sdk.impl.resource.DefaultVerifyPassCodeFactorRequest;
import com.okta.jwt.AccessTokenVerifier;
import com.okta.jwt.Jwt;
import com.okta.jwt.JwtVerificationException;
import com.okta.sdk.client.Client;
import com.okta.sdk.impl.client.DefaultClient;
import com.okta.sdk.impl.error.DefaultError;
import com.okta.sdk.resource.ResourceException;
import com.okta.sdk.resource.group.Group;
import com.okta.sdk.resource.group.GroupList;
import com.okta.sdk.resource.group.GroupProfile;
import com.okta.sdk.resource.user.User;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
Expand Down Expand Up @@ -327,4 +341,130 @@ public void testGetAccessTokenVerifier() {
AccessTokenVerifier verifier = this.oktaAuthConnector.getAccessTokenVerifier();
assertNotNull(verifier);
}

@Test
public void testGetGroups() {
AccessTokenVerifier verifier = mock(AccessTokenVerifier.class);

GroupProfile groupProfile = mock(GroupProfile.class);
when(groupProfile.getName()).thenReturn("testGroup");

Group fakeGroup = mock(Group.class);
when(fakeGroup.getProfile()).thenReturn(groupProfile);

List<Group> groupIteraterList = Lists.newArrayList(fakeGroup);
GroupList groupList = mock(GroupList.class);
when(groupList.iterator()).thenReturn(groupIteraterList.iterator());

User mockUser = mock(User.class);
when(mockUser.listGroups()).thenReturn(groupList);

DefaultClient mockClient = mock(DefaultClient.class);
when(mockClient.getUser(anyString())).thenReturn(mockUser);

OktaAuthConnector connector =
new OktaAuthConnector(
client, mockClient, "https://foo.bar/oauth2/skiddleydee", "dogs", verifier);
AuthData authData = AuthData.builder().userId("deadbeef").build();
Set<String> groups = connector.getGroups(authData);
assertEquals(groups, Set.of("testGroup"));
}

@Test
public void testGetGroupsMissingProfile() {
AccessTokenVerifier verifier = mock(AccessTokenVerifier.class);

Group fakeGroup = mock(Group.class);
when(fakeGroup.getProfile()).thenReturn(null);

List<Group> groupIteraterList = Lists.newArrayList(fakeGroup);
GroupList groupList = mock(GroupList.class);
when(groupList.iterator()).thenReturn(groupIteraterList.iterator());

User mockUser = mock(User.class);
when(mockUser.listGroups()).thenReturn(groupList);

DefaultClient mockClient = mock(DefaultClient.class);
when(mockClient.getUser(anyString())).thenReturn(mockUser);

OktaAuthConnector connector =
new OktaAuthConnector(
client, mockClient, "https://foo.bar/oauth2/skiddleydee", "dogs", verifier);
AuthData authData = AuthData.builder().userId("deadbeef").build();
Set<String> groups = connector.getGroups(authData);
assertEquals(groups, new HashSet<String>());
}

@Test
public void testGetGroupsNullGroups() {
AccessTokenVerifier verifier = mock(AccessTokenVerifier.class);

User mockUser = mock(User.class);
when(mockUser.listGroups()).thenReturn(null);

DefaultClient mockClient = mock(DefaultClient.class);
when(mockClient.getUser(anyString())).thenReturn(mockUser);

OktaAuthConnector connector =
new OktaAuthConnector(
client, mockClient, "https://foo.bar/oauth2/skiddleydee", "dogs", verifier);

AuthData authData = AuthData.builder().userId("deadbeef").build();
Set<String> groups = connector.getGroups(authData);
assertEquals(groups, new HashSet<>());
}

@Test(expected = ApiException.class)
public void testBadGetUser() {
AccessTokenVerifier verifier = mock(AccessTokenVerifier.class);
Client mockClient = mock(Client.class);
when(mockClient.getUser(anyString())).thenThrow(new RuntimeException("it's broke"));
OktaAuthConnector connector =
new OktaAuthConnector(
client, mockClient, "https://foo.bar/oauth2/skiddleydee", "dogs", verifier);
AuthData authData = AuthData.builder().userId("deadbeef").build();
connector.getGroups(authData);
}

@Test
public void testGetUserFromIdpCompletelyBrokenOkta() {
AccessTokenVerifier verifier = mock(AccessTokenVerifier.class);
Client mockClient = mock(Client.class);
String exceptionMessage = "who knows what broke?";
when(mockClient.getUser(anyString())).thenThrow(new IllegalStateException(exceptionMessage));
OktaAuthConnector connector =
new OktaAuthConnector(
client, mockClient, "https://foo.bar/oauth2/skiddleydee", "dogs", verifier);
try {
connector.getUserFromIDP("fooUser");
} catch (ApiException exc) {
String actualMessage = exc.getMessage();
assertEquals(actualMessage, "Could not communicate properly with identity provider");
ApiError apiError = exc.getApiErrors().get(0);
assertEquals(apiError, DefaultApiError.IDENTITY_PROVIDER_BAD_GATEWAY);
String causeMessage = exc.getCause().getMessage();
assertEquals(causeMessage, exceptionMessage);
}
}

@Test
public void testGetUserFromIdpOktaProblem() {
AccessTokenVerifier verifier = mock(AccessTokenVerifier.class);
Client mockClient = mock(Client.class);
String excMessage = "A specific thing had a problem";
String excpetionPrefix = "Got invalid response from identity providers";
ResourceException resourceException =
new ResourceException(new DefaultError(ImmutableMap.of("message", excMessage)));
when(mockClient.getUser(anyString())).thenThrow(resourceException);
OktaAuthConnector connector =
new OktaAuthConnector(
client, mockClient, "https://foo.bar/oauth2/skiddleydee", "dogs", verifier);
try {
connector.getUserFromIDP("fooUser");
} catch (ApiException exc) {
String actualMessage = exc.getMessage();
assert actualMessage.startsWith(excpetionPrefix);
assertEquals(exc.getApiErrors().get(0), DefaultApiError.IDENTITY_PROVIDER_BAD_GATEWAY);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,9 @@ public enum DefaultApiError implements ApiError {
/** Generic bad requests. This is useful because the blueprint error handling sucks. */
GENERIC_BAD_REQUEST(99999, "Request will not be completed.", SC_BAD_REQUEST),

/** Bad response from identity provider */
IDENTITY_PROVIDER_BAD_GATEWAY(99988, "Bad response from identity provider", SC_BAD_GATEWAY),

/**
* If we encounter an error where something expected is not setup correctly, meaning the service
* is not functional.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import static com.nike.backstopper.apierror.ApiErrorConstants.*;
import static com.nike.backstopper.apierror.projectspecificinfo.ProjectSpecificErrorCodeRange.ALLOW_ALL_ERROR_CODES;
import static javax.servlet.http.HttpServletResponse.SC_BAD_GATEWAY;
import static javax.servlet.http.HttpServletResponse.SC_NOT_IMPLEMENTED;

import com.nike.backstopper.apierror.ApiError;
Expand All @@ -34,6 +35,7 @@ public class DefaultApiErrorsImpl extends SampleProjectApiErrorsBase {
Arrays.asList(
HTTP_STATUS_CODE_FORBIDDEN,
HTTP_STATUS_CODE_UNAUTHORIZED,
SC_BAD_GATEWAY,
HTTP_STATUS_CODE_SERVICE_UNAVAILABLE,
HTTP_STATUS_CODE_TOO_MANY_REQUESTS,
HTTP_STATUS_CODE_INTERNAL_SERVER_ERROR,
Expand Down

0 comments on commit 9c61c0c

Please sign in to comment.