Skip to content

Commit

Permalink
Use TrinoRequestUser to get the user for the query
Browse files Browse the repository at this point in the history
  • Loading branch information
vishalya committed Oct 17, 2024
1 parent a45d249 commit 2168a08
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,10 @@ private Optional<String> extractUserFromBearerAuth(String header, String userFie

String token = header.substring(space + 1).trim();

if (header.split("\\.").length == 3) { //this is probably a JWS
if (token.split("\\.").length == 3) { //this is probably a JWS
log.debug("Trying to extract from JWS");
try {
DecodedJWT jwt = JWT.decode(header);
DecodedJWT jwt = JWT.decode(token);
if (jwt.getClaims().containsKey(userField)) {
return Optional.of(jwt.getClaim(userField).asString());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import io.trino.gateway.ha.router.OAuth2GatewayCookie;
import io.trino.gateway.ha.router.QueryHistoryManager;
import io.trino.gateway.ha.router.RoutingManager;
import io.trino.gateway.ha.router.TrinoRequestUser;
import io.trino.gateway.proxyserver.ProxyResponseHandler.ProxyResponse;
import jakarta.annotation.PreDestroy;
import jakarta.servlet.http.HttpServletRequest;
Expand Down Expand Up @@ -58,11 +59,8 @@
import static io.airlift.http.client.Request.Builder.preparePost;
import static io.airlift.http.client.StaticBodyGenerator.createStaticBodyGenerator;
import static io.airlift.jaxrs.AsyncResponseHandler.bindAsyncResponse;
import static io.trino.gateway.ha.handler.ProxyUtils.AUTHORIZATION;
import static io.trino.gateway.ha.handler.ProxyUtils.QUERY_TEXT_LENGTH_FOR_HISTORY;
import static io.trino.gateway.ha.handler.ProxyUtils.SOURCE_HEADER;
import static io.trino.gateway.ha.handler.ProxyUtils.getQueryUser;
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.USER_HEADER;
import static jakarta.ws.rs.core.MediaType.TEXT_PLAIN_TYPE;
import static jakarta.ws.rs.core.Response.Status.BAD_GATEWAY;
import static jakarta.ws.rs.core.Response.Status.OK;
Expand All @@ -86,6 +84,7 @@ public class ProxyRequestHandler
private final boolean addXForwardedHeaders;
private final List<String> statementPaths;
private final boolean includeClusterInfoInResponse;
private final TrinoRequestUser.TrinoRequestUserProvider trinoRequestUserProvider;

@Inject
public ProxyRequestHandler(
Expand All @@ -97,6 +96,7 @@ public ProxyRequestHandler(
this.httpClient = requireNonNull(httpClient, "httpClient is null");
this.routingManager = requireNonNull(routingManager, "routingManager is null");
this.queryHistoryManager = requireNonNull(queryHistoryManager, "queryHistoryManager is null");
trinoRequestUserProvider = new TrinoRequestUser.TrinoRequestUserProvider(haGatewayConfiguration.getRequestAnalyzerConfig());
cookiesEnabled = GatewayCookieConfigurationPropertiesProvider.getInstance().isEnabled();
asyncTimeout = haGatewayConfiguration.getRouting().getAsyncTimeout();
addXForwardedHeaders = haGatewayConfiguration.getRouting().isAddXForwardedHeaders();
Expand Down Expand Up @@ -173,7 +173,8 @@ private void performRequest(
FluentFuture<ProxyResponse> future = executeHttp(request);

if (statementPaths.stream().anyMatch(request.getUri().getPath()::startsWith) && request.getMethod().equals(HttpMethod.POST)) {
future = future.transform(response -> recordBackendForQueryId(request, response), executor);
String username = trinoRequestUserProvider.getInstance(servletRequest).getUser().orElse("Unknown");
future = future.transform(response -> recordBackendForQueryId(request, response, username), executor);
if (includeClusterInfoInResponse) {
cookieBuilder.add(new NewCookie.Builder("trinoClusterHost").value(remoteUri.getHost()).build());
}
Expand Down Expand Up @@ -250,11 +251,11 @@ private static WebApplicationException badRequest(String message)
.build());
}

private ProxyResponse recordBackendForQueryId(Request request, ProxyResponse response)
private ProxyResponse recordBackendForQueryId(Request request, ProxyResponse response, String username)
{
log.debug("For Request [%s] got Response [%s]", request.getUri(), response.body());

QueryHistoryManager.QueryDetail queryDetail = getQueryDetailsFromRequest(request);
QueryHistoryManager.QueryDetail queryDetail = getQueryDetailsFromRequest(request, username);

log.debug("Extracting proxy destination : [%s] for request : [%s]", queryDetail.getBackendUrl(), request.getUri());

Expand All @@ -276,12 +277,12 @@ private ProxyResponse recordBackendForQueryId(Request request, ProxyResponse res
return response;
}

public static QueryHistoryManager.QueryDetail getQueryDetailsFromRequest(Request request)
public static QueryHistoryManager.QueryDetail getQueryDetailsFromRequest(Request request, String username)
{
QueryHistoryManager.QueryDetail queryDetail = new QueryHistoryManager.QueryDetail();
queryDetail.setBackendUrl(getRemoteTarget(request.getUri()));
queryDetail.setCaptureTime(System.currentTimeMillis());
queryDetail.setUser(getQueryUser(request.getHeader(USER_HEADER), request.getHeader(AUTHORIZATION)));
queryDetail.setUser(username);
queryDetail.setSource(request.getHeader(SOURCE_HEADER));

String queryText = new String(((StaticBodyGenerator) request.getBodyGenerator()).getBody(), UTF_8);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,22 @@
*/
package io.trino.gateway.ha.router;

import com.auth0.jwt.JWT;
import com.auth0.jwt.algorithms.Algorithm;
import io.airlift.json.JsonCodec;
import io.trino.gateway.ha.config.RequestAnalyzerConfig;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.ws.rs.core.HttpHeaders;
import org.junit.jupiter.api.Test;

import java.time.Instant;
import java.util.Date;
import java.util.Optional;

import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.USER_HEADER;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

final class TestTrinoRequestUser
{
Expand All @@ -42,4 +52,31 @@ void testJsonCreator()
assertThat(deserializedTrinoRequestUser.getUser()).isEqualTo(trinoRequestUser.getUser());
assertThat(deserializedTrinoRequestUser.getUserInfo()).isEqualTo(trinoRequestUser.getUserInfo());
}

@Test
void testUserFromJwtToken()
{
String claimUserName = "username";
String claimUserValue = "trino";

RequestAnalyzerConfig requestAnalyzerConfig = new RequestAnalyzerConfig();
requestAnalyzerConfig.setTokenUserField(claimUserName);

Algorithm algorithm = Algorithm.HMAC256("random");

Instant expiryTime = Instant.now().plusSeconds(60);
String token = JWT.create()
.withIssuer("gateway")
.withClaim(claimUserName, claimUserValue)
.withExpiresAt(Date.from(expiryTime))
.sign(algorithm);

HttpServletRequest mockRequest = mock(HttpServletRequest.class);
when(mockRequest.getHeader(USER_HEADER)).thenReturn(null);
when(mockRequest.getHeader(HttpHeaders.AUTHORIZATION)).thenReturn("Bearer " + token);

TrinoRequestUser trinoRequestUser = new TrinoRequestUser.TrinoRequestUserProvider(requestAnalyzerConfig).getInstance(mockRequest);

assertThat(trinoRequestUser.getUser()).hasValue(claimUserValue);
}
}

0 comments on commit 2168a08

Please sign in to comment.