Skip to content

Commit

Permalink
Use TrinoRequestUser to get the user for the query
Browse files Browse the repository at this point in the history
  • Loading branch information
vishalya authored and ebyhr committed Oct 25, 2024
1 parent a45d249 commit 5d8a741
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,18 @@
*/
package io.trino.gateway.ha.handler;

import com.google.common.base.Splitter;
import com.google.common.io.CharStreams;
import io.airlift.log.Logger;
import jakarta.servlet.http.HttpServletRequest;

import java.io.InputStreamReader;
import java.util.Base64;
import java.util.List;
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static com.google.common.base.Strings.isNullOrEmpty;
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.TRINO_UI_PATH;
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.USER_HEADER;
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.V1_QUERY_PATH;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Locale.ENGLISH;
Expand Down Expand Up @@ -56,41 +53,6 @@ public final class ProxyUtils

private ProxyUtils() {}

public static String getQueryUser(String userHeader, String authorization)
{
if (!isNullOrEmpty(userHeader)) {
log.debug("User from header %s", USER_HEADER);
return userHeader;
}

log.debug("User from basic authentication");
String user = "";
if (authorization == null) {
log.debug("No basic auth header found.");
return user;
}

int space = authorization.indexOf(' ');
if ((space < 0) || !authorization.substring(0, space).equalsIgnoreCase("basic")) {
log.error("Basic auth format is invalid");
return user;
}

String headerInfo = authorization.substring(space + 1).trim();
if (isNullOrEmpty(headerInfo)) {
log.error("Encoded value of basic auth doesn't exist");
return user;
}

String info = new String(Base64.getDecoder().decode(headerInfo), UTF_8);
List<String> parts = Splitter.on(':').limit(2).splitToList(info);
if (parts.size() < 1) {
log.error("No user inside the basic auth text");
return user;
}
return parts.get(0);
}

public static String extractQueryIdIfPresent(HttpServletRequest request, List<String> statementPaths)
{
String path = request.getRequestURI();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,10 @@ private Optional<String> extractUserFromBearerAuth(String header, String userFie

String token = header.substring(space + 1).trim();

if (header.split("\\.").length == 3) { //this is probably a JWS
if (token.split("\\.").length == 3) { //this is probably a JWS
log.debug("Trying to extract from JWS");
try {
DecodedJWT jwt = JWT.decode(header);
DecodedJWT jwt = JWT.decode(token);
if (jwt.getClaims().containsKey(userField)) {
return Optional.of(jwt.getClaim(userField).asString());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import io.trino.gateway.ha.router.OAuth2GatewayCookie;
import io.trino.gateway.ha.router.QueryHistoryManager;
import io.trino.gateway.ha.router.RoutingManager;
import io.trino.gateway.ha.router.TrinoRequestUser;
import io.trino.gateway.proxyserver.ProxyResponseHandler.ProxyResponse;
import jakarta.annotation.PreDestroy;
import jakarta.servlet.http.HttpServletRequest;
Expand All @@ -43,6 +44,7 @@
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ExecutorService;

import static com.google.common.collect.ImmutableList.toImmutableList;
Expand All @@ -58,11 +60,8 @@
import static io.airlift.http.client.Request.Builder.preparePost;
import static io.airlift.http.client.StaticBodyGenerator.createStaticBodyGenerator;
import static io.airlift.jaxrs.AsyncResponseHandler.bindAsyncResponse;
import static io.trino.gateway.ha.handler.ProxyUtils.AUTHORIZATION;
import static io.trino.gateway.ha.handler.ProxyUtils.QUERY_TEXT_LENGTH_FOR_HISTORY;
import static io.trino.gateway.ha.handler.ProxyUtils.SOURCE_HEADER;
import static io.trino.gateway.ha.handler.ProxyUtils.getQueryUser;
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.USER_HEADER;
import static jakarta.ws.rs.core.MediaType.TEXT_PLAIN_TYPE;
import static jakarta.ws.rs.core.Response.Status.BAD_GATEWAY;
import static jakarta.ws.rs.core.Response.Status.OK;
Expand All @@ -86,6 +85,7 @@ public class ProxyRequestHandler
private final boolean addXForwardedHeaders;
private final List<String> statementPaths;
private final boolean includeClusterInfoInResponse;
private final TrinoRequestUser.TrinoRequestUserProvider trinoRequestUserProvider;

@Inject
public ProxyRequestHandler(
Expand All @@ -97,6 +97,7 @@ public ProxyRequestHandler(
this.httpClient = requireNonNull(httpClient, "httpClient is null");
this.routingManager = requireNonNull(routingManager, "routingManager is null");
this.queryHistoryManager = requireNonNull(queryHistoryManager, "queryHistoryManager is null");
trinoRequestUserProvider = new TrinoRequestUser.TrinoRequestUserProvider(haGatewayConfiguration.getRequestAnalyzerConfig());
cookiesEnabled = GatewayCookieConfigurationPropertiesProvider.getInstance().isEnabled();
asyncTimeout = haGatewayConfiguration.getRouting().getAsyncTimeout();
addXForwardedHeaders = haGatewayConfiguration.getRouting().isAddXForwardedHeaders();
Expand Down Expand Up @@ -173,7 +174,8 @@ private void performRequest(
FluentFuture<ProxyResponse> future = executeHttp(request);

if (statementPaths.stream().anyMatch(request.getUri().getPath()::startsWith) && request.getMethod().equals(HttpMethod.POST)) {
future = future.transform(response -> recordBackendForQueryId(request, response), executor);
Optional<String> username = trinoRequestUserProvider.getInstance(servletRequest).getUser();
future = future.transform(response -> recordBackendForQueryId(request, response, username), executor);
if (includeClusterInfoInResponse) {
cookieBuilder.add(new NewCookie.Builder("trinoClusterHost").value(remoteUri.getHost()).build());
}
Expand Down Expand Up @@ -250,11 +252,11 @@ private static WebApplicationException badRequest(String message)
.build());
}

private ProxyResponse recordBackendForQueryId(Request request, ProxyResponse response)
private ProxyResponse recordBackendForQueryId(Request request, ProxyResponse response, Optional<String> username)
{
log.debug("For Request [%s] got Response [%s]", request.getUri(), response.body());

QueryHistoryManager.QueryDetail queryDetail = getQueryDetailsFromRequest(request);
QueryHistoryManager.QueryDetail queryDetail = getQueryDetailsFromRequest(request, username);

log.debug("Extracting proxy destination : [%s] for request : [%s]", queryDetail.getBackendUrl(), request.getUri());

Expand All @@ -276,12 +278,12 @@ private ProxyResponse recordBackendForQueryId(Request request, ProxyResponse res
return response;
}

public static QueryHistoryManager.QueryDetail getQueryDetailsFromRequest(Request request)
public static QueryHistoryManager.QueryDetail getQueryDetailsFromRequest(Request request, Optional<String> username)
{
QueryHistoryManager.QueryDetail queryDetail = new QueryHistoryManager.QueryDetail();
queryDetail.setBackendUrl(getRemoteTarget(request.getUri()));
queryDetail.setCaptureTime(System.currentTimeMillis());
queryDetail.setUser(getQueryUser(request.getHeader(USER_HEADER), request.getHeader(AUTHORIZATION)));
username.ifPresent(queryDetail::setUser);
queryDetail.setSource(request.getHeader(SOURCE_HEADER));

String queryText = new String(((StaticBodyGenerator) request.getBodyGenerator()).getBody(), UTF_8);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import java.util.List;

import static io.trino.gateway.ha.handler.ProxyUtils.extractQueryIdIfPresent;
import static io.trino.gateway.ha.handler.ProxyUtils.getQueryUser;
import static org.assertj.core.api.Assertions.assertThat;

@TestInstance(Lifecycle.PER_CLASS)
Expand Down Expand Up @@ -61,11 +60,4 @@ void testExtractQueryIdFromUrl()
assertThat(extractQueryIdIfPresent("/ui/", "lang=en&p=1&id=0_1_2_a", statementPaths))
.isNull();
}

@Test
void testGetQueryUser()
{
assertThat(getQueryUser(null, "Basic dGVzdDoxMjPCow==")).isEqualTo("test");
assertThat(getQueryUser("trino_user", "Basic dGVzdDoxMjPCow==")).isEqualTo("trino_user");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,25 @@
*/
package io.trino.gateway.ha.router;

import com.auth0.jwt.JWT;
import com.auth0.jwt.algorithms.Algorithm;
import io.airlift.json.JsonCodec;
import io.trino.gateway.ha.config.RequestAnalyzerConfig;
import jakarta.servlet.http.HttpServletRequest;
import org.junit.jupiter.api.Test;

import java.time.Instant;
import java.util.Base64;
import java.util.Date;
import java.util.Optional;

import static com.auth0.jwt.algorithms.Algorithm.HMAC256;
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.USER_HEADER;
import static jakarta.ws.rs.core.HttpHeaders.AUTHORIZATION;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

final class TestTrinoRequestUser
{
Expand All @@ -42,4 +55,49 @@ void testJsonCreator()
assertThat(deserializedTrinoRequestUser.getUser()).isEqualTo(trinoRequestUser.getUser());
assertThat(deserializedTrinoRequestUser.getUserInfo()).isEqualTo(trinoRequestUser.getUserInfo());
}

@Test
void testUserFromJwtToken()
{
String claimUserName = "username";
String claimUserValue = "trino";

RequestAnalyzerConfig requestAnalyzerConfig = new RequestAnalyzerConfig();
requestAnalyzerConfig.setTokenUserField(claimUserName);

Algorithm algorithm = HMAC256("random");

Instant expiryTime = Instant.now().plusSeconds(60);
String token = JWT.create()
.withIssuer("gateway")
.withClaim(claimUserName, claimUserValue)
.withExpiresAt(Date.from(expiryTime))
.sign(algorithm);

HttpServletRequest mockRequest = mock(HttpServletRequest.class);
when(mockRequest.getHeader(USER_HEADER)).thenReturn(null);
when(mockRequest.getHeader(AUTHORIZATION)).thenReturn("Bearer " + token);

TrinoRequestUser trinoRequestUser = new TrinoRequestUser.TrinoRequestUserProvider(requestAnalyzerConfig).getInstance(mockRequest);

assertThat(trinoRequestUser.getUser()).hasValue(claimUserValue);
}

@Test
void testGetBasicAuthUser()
{
String username = "trino_user";
String password = "don't care";
String credentials = username + ":" + password;
String encodedCredentials = Base64.getEncoder().encodeToString(credentials.getBytes(UTF_8));

HttpServletRequest mockRequest = mock(HttpServletRequest.class);
when(mockRequest.getHeader(USER_HEADER)).thenReturn(null);
when(mockRequest.getHeader(AUTHORIZATION)).thenReturn("Basic " + encodedCredentials);

RequestAnalyzerConfig requestAnalyzerConfig = new RequestAnalyzerConfig();
TrinoRequestUser trinoRequestUser = new TrinoRequestUser.TrinoRequestUserProvider(requestAnalyzerConfig).getInstance(mockRequest);

assertThat(trinoRequestUser.getUser()).hasValue(username);
}
}

0 comments on commit 5d8a741

Please sign in to comment.