Skip to content

Commit

Permalink
Fix oauth2 and jwt authenticator interference
Browse files Browse the repository at this point in the history
Before this fix, when oauth2 with refresh token is enabled
along with jwt authenticator, user couldn't log in
by using standard jwt token.

It was occuring due to incorect handling of tokens
that are in different format than the one issued
by the OAuth2 implementation to store refresh tokens.
After the fix in such case OAuth2 will just issue another
challenge to the client.
  • Loading branch information
s2lomon authored and Praveen2112 committed Sep 21, 2022
1 parent e17dc3f commit ee753bf
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package io.trino.server.security.oauth2;

import com.google.common.collect.ImmutableSet;
import io.airlift.log.Logger;
import io.trino.server.security.AbstractBearerAuthenticator;
import io.trino.server.security.AuthenticationException;
import io.trino.server.security.UserMapping;
Expand Down Expand Up @@ -42,6 +43,7 @@
public class OAuth2Authenticator
extends AbstractBearerAuthenticator
{
private static final Logger log = Logger.get(OAuth2Authenticator.class);
private final OAuth2Client client;
private final String principalField;
private final Optional<String> groupsField;
Expand All @@ -64,7 +66,12 @@ public OAuth2Authenticator(OAuth2Client client, OAuth2Config config, TokenRefres
protected Optional<Identity> createIdentity(String token)
throws UserMappingException
{
TokenPair tokenPair = tokenPairSerializer.deserialize(token);
Optional<TokenPair> deserializeToken = deserializeToken(token);
if (deserializeToken.isEmpty()) {
return Optional.empty();
}

TokenPair tokenPair = deserializeToken.get();
if (tokenPair.getExpiration().before(Date.from(Instant.now()))) {
return Optional.empty();
}
Expand All @@ -80,11 +87,22 @@ protected Optional<Identity> createIdentity(String token)
return Optional.of(builder.build());
}

private Optional<TokenPair> deserializeToken(String token)
{
try {
return Optional.of(tokenPairSerializer.deserialize(token));
}
catch (RuntimeException ex) {
log.debug(ex, "Failed to deserialize token");
return Optional.empty();
}
}

@Override
protected AuthenticationException needAuthentication(ContainerRequestContext request, Optional<String> currentToken, String message)
{
return currentToken
.map(tokenPairSerializer::deserialize)
.flatMap(this::deserializeToken)
.flatMap(tokenRefresher::refreshToken)
.map(refreshId -> request.getUriInfo().getBaseUri().resolve(getTokenUri(refreshId)))
.map(tokenUri -> new AuthenticationException(message, format("Bearer x_token_server=\"%s\"", tokenUri)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -816,6 +816,48 @@ public void testJwtAndOAuth2AuthenticatorsSeparation()
}
}

@Test
public void testJwtWithRefreshTokensForOAuth2Enabled()
throws Exception
{
TestingHttpServer jwkServer = createTestingJwkServer();
jwkServer.start();
try (TokenServer tokenServer = new TokenServer(Optional.empty());
TestingTrinoServer server = TestingTrinoServer.builder()
.setProperties(
ImmutableMap.<String, String>builder()
.putAll(SECURE_PROPERTIES)
.put("http-server.authentication.type", "oauth2,jwt")
.put("http-server.authentication.jwt.key-file", jwkServer.getBaseUrl().toString())
.putAll(ImmutableMap.<String, String>builder()
.putAll(getOAuth2Properties(tokenServer))
.put("http-server.authentication.oauth2.refresh-tokens", "true")
.buildOrThrow())
.put("web-ui.enabled", "true")
.buildOrThrow())
.setAdditionalModule(oauth2Module(tokenServer))
.build()) {
server.getInstance(Key.get(AccessControlManager.class)).addSystemAccessControl(TestSystemAccessControl.NO_IMPERSONATION);
HttpServerInfo httpServerInfo = server.getInstance(Key.get(HttpServerInfo.class));

assertAuthenticationDisabled(httpServerInfo.getHttpUri());

String token = newJwtBuilder()
.signWith(JWK_PRIVATE_KEY)
.setHeaderParam(JwsHeader.KEY_ID, JWK_KEY_ID)
.setSubject("test-user")
.setExpiration(Date.from(ZonedDateTime.now().plusMinutes(5).toInstant()))
.compact();

OkHttpClient clientWithJwt = client.newBuilder()
.authenticator((route, response) -> response.request().newBuilder()
.header(AUTHORIZATION, "Bearer " + token)
.build())
.build();
assertAuthenticationAutomatic(httpServerInfo.getHttpsUri(), clientWithJwt);
}
}

private static Module oauth2Module(TokenServer tokenServer)
{
return binder -> {
Expand Down

0 comments on commit ee753bf

Please sign in to comment.