Skip to content

Commit

Permalink
OIDC UserInfo Endpoint
Browse files Browse the repository at this point in the history
Signed-off-by: Stephen Crawford <[email protected]>
  • Loading branch information
stephen-crawford committed Aug 27, 2024
1 parent ec73fd2 commit ca57141
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ public abstract class AbstractHTTPJwtAuthenticator implements HTTPAuthenticator

private KeyProvider keyProvider;
protected JwtVerifier jwtVerifier;
private final String jwtHeaderName;
private final boolean isDefaultAuthHeader;
private final String jwtUrlParameter;
protected final String jwtHeaderName;
protected final boolean isDefaultAuthHeader;
protected final String jwtUrlParameter;
private final String subjectKey;
private final String rolesKey;
private final List<String> requiredAudience;
Expand Down Expand Up @@ -117,7 +117,7 @@ public AuthCredentials run() {

private AuthCredentials extractCredentials0(final SecurityRequest request) throws OpenSearchSecurityException {

String jwtString = getJwtTokenString(request);
String jwtString = getJwtTokenString(request, jwtHeaderName, jwtUrlParameter, isDefaultAuthHeader);

if (Strings.isNullOrEmpty(jwtString)) {
return null;
Expand Down Expand Up @@ -155,34 +155,6 @@ private AuthCredentials extractCredentials0(final SecurityRequest request) throw
return ac;
}

protected String getJwtTokenString(SecurityRequest request) {
String jwtToken = request.header(jwtHeaderName);
if (isDefaultAuthHeader && jwtToken != null && BASIC.matcher(jwtToken).matches()) {
jwtToken = null;
}

if (jwtUrlParameter != null) {
if (jwtToken == null || jwtToken.isEmpty()) {
jwtToken = request.params().get(jwtUrlParameter);
} else {
// just consume to avoid "contains unrecognized parameter"
request.params().get(jwtUrlParameter);
}
}

if (jwtToken == null) {
return null;
}

int index;

if ((index = jwtToken.toLowerCase().indexOf(BEARER)) > -1) { // detect Bearer
jwtToken = jwtToken.substring(index + BEARER.length());
}

return jwtToken;
}

@VisibleForTesting
public String extractSubject(JWTClaimsSet claims) {
String subject = claims.getSubject();
Expand Down Expand Up @@ -256,6 +228,39 @@ public Optional<SecurityResponse> reRequestAuthentication(final SecurityRequest
);
}

public static String getJwtTokenString(
SecurityRequest request,
String jwtHeaderName,
String jwtUrlParameter,
boolean isDefaultAuthHeader
) {
String jwtToken = request.header(jwtHeaderName);
if (isDefaultAuthHeader && jwtToken != null && BASIC.matcher(jwtToken).matches()) {
jwtToken = null;
}

if (jwtUrlParameter != null) {
if (jwtToken == null || jwtToken.isEmpty()) {
jwtToken = request.params().get(jwtUrlParameter);
} else {
// just consume to avoid "contains unrecognized parameter"
request.params().get(jwtUrlParameter);
}
}

if (jwtToken == null) {
return null;
}

int index;

if ((index = jwtToken.toLowerCase().indexOf(BEARER)) > -1) { // detect Bearer
jwtToken = jwtToken.substring(index + BEARER.length());
}

return jwtToken;
}

public List<String> getRequiredAudience() {
return requiredAudience;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,10 @@
import com.nimbusds.jwt.SignedJWT;
import com.nimbusds.oauth2.sdk.http.HTTPRequest;
import com.nimbusds.oauth2.sdk.http.HTTPResponse;
import com.nimbusds.oauth2.sdk.token.AccessToken;
import com.nimbusds.oauth2.sdk.token.AccessTokenType;
import com.nimbusds.oauth2.sdk.util.StringUtils;
import com.nimbusds.oauth2.sdk.token.BearerAccessToken;
import com.nimbusds.openid.connect.sdk.UserInfoRequest;
import com.nimbusds.openid.connect.sdk.UserInfoResponse;
import com.nimbusds.openid.connect.sdk.UserInfoSuccessResponse;

import static org.apache.hc.core5.http.HttpHeaders.AUTHORIZATION;
import static com.amazon.dlic.auth.http.jwt.keybyoidc.OpenIdConstants.CLIENT_ID;
Expand Down Expand Up @@ -115,32 +114,9 @@ public AuthCredentials extractCredentials0(SecurityRequest request, ThreadContex

URI userInfoEndpointURI = new URI(this.userInfoEndpoint);

String bearerHeader = request.getHeaders().get(AUTHORIZATION).getFirst();
if (!StringUtils.isBlank(bearerHeader)) {
if (bearerHeader.contains("Bearer ")) {
bearerHeader = bearerHeader.substring(7);
}
}

String finalBearerHeader = bearerHeader;

AccessToken accessToken = new AccessToken(AccessTokenType.BEARER, finalBearerHeader) {
@Override
public String toAuthorizationHeader() {
return "Bearer " + finalBearerHeader;
}
};

UserInfoRequest userInfoRequest = new UserInfoRequest(userInfoEndpointURI, accessToken);
String bearerHeader = AbstractHTTPJwtAuthenticator.getJwtTokenString(request, AUTHORIZATION, null, false);

HTTPRequest httpRequest = userInfoRequest.toHTTPRequest();

HTTPResponse httpResponse = httpRequest.send();
if (httpResponse.getStatusCode() < 200 || httpResponse.getStatusCode() >= 300) {
throw new AuthenticatorUnavailableException(
"Error while getting " + this.userInfoEndpoint + ": " + httpResponse.getStatusMessage()
);
}
HTTPResponse httpResponse = getHttpResponse(bearerHeader, userInfoEndpointURI);

try {

Expand All @@ -152,17 +128,19 @@ public String toAuthorizationHeader() {
);
}

String contentType = String.valueOf(httpResponse.getHeaderValues("content-type"));
UserInfoSuccessResponse userInfoSuccessResponse = userInfoResponse.toSuccessResponse();

String contentType = userInfoSuccessResponse.getEntityContentType().getType();

JWTClaimsSet claims;
boolean isSigned = contentType.contains(ContentType.APPLICATION_JWT.toString());
boolean isSigned = contentType.contains(ContentType.APPLICATION_JWT.getType());
if (isSigned) { // We don't need the userinfo_encrypted_response_alg since the
// selfRefreshingKeyProvider has access to the keys
claims = openIdJwtAuthenticator.getJwtClaimsSetFromInfoContent(
userInfoResponse.toSuccessResponse().getUserInfoJWT().getParsedString()
userInfoSuccessResponse.getUserInfoJWT().getParsedString()
);
} else {
claims = JWTClaimsSet.parse(userInfoResponse.toSuccessResponse().getUserInfo().toString());
claims = JWTClaimsSet.parse(userInfoSuccessResponse.getUserInfo().toString());
}

String id = openIdJwtAuthenticator.getJwtClaimsSet(request).getSubject();
Expand Down Expand Up @@ -196,26 +174,42 @@ public String toAuthorizationHeader() {
}
}

private HTTPResponse getHttpResponse(String bearerHeader, URI userInfoEndpointURI) throws IOException {
BearerAccessToken accessToken = new BearerAccessToken(bearerHeader);

UserInfoRequest userInfoRequest = new UserInfoRequest(userInfoEndpointURI, accessToken);

HTTPRequest httpRequest = userInfoRequest.toHTTPRequest();

HTTPResponse httpResponse = httpRequest.send();
if (httpResponse.getStatusCode() < 200 || httpResponse.getStatusCode() >= 300) {
throw new AuthenticatorUnavailableException(
"Error while getting " + this.userInfoEndpoint + ": " + httpResponse.getStatusMessage()
);
}
return httpResponse;
}

private String validateResponseClaims(JWTClaimsSet claims, String id, boolean isSigned) {

String missing = "";
StringBuilder missing = new StringBuilder();

if (claims.getClaim(SUB_CLAIM) == null || claims.getClaim(SUB_CLAIM).toString().isBlank() || !claims.getClaim("sub").equals(id)) {
missing = missing.concat(SUB_CLAIM);
missing = missing.append(SUB_CLAIM);
}

if (isSigned) {
if (claims.getIssuer() == null || claims.getIssuer().isBlank() || !claims.getIssuer().equals(settings.get(ISSUER_ID_URL))) {
missing = missing.concat("iss");
missing = missing.append("iss");
}
if (claims.getAudience() == null
|| claims.getAudience().toString().isBlank()
|| !claims.getAudience().contains(settings.get(CLIENT_ID))) {
missing = missing.concat("aud");
missing = missing.append("aud");
}
}

return missing;
return missing.toString();
}

private final class HTTPJwtKeyByOpenIdConnectAuthenticator extends AbstractHTTPJwtAuthenticator {
Expand Down Expand Up @@ -260,7 +254,7 @@ protected KeyProvider initKeyProvider(Settings settings, Path configPath) throws
}

private JWTClaimsSet getJwtClaimsSet(SecurityRequest request) throws OpenSearchSecurityException {
String parsedToken = super.getJwtTokenString(request);
String parsedToken = getJwtTokenString(request, jwtHeaderName, jwtUrlParameter, isDefaultAuthHeader);
return getJwtClaimsSetFromInfoContent(parsedToken);
}

Expand Down

0 comments on commit ca57141

Please sign in to comment.