diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/config/HaGatewayConfiguration.java b/gateway-ha/src/main/java/io/trino/gateway/ha/config/HaGatewayConfiguration.java index 4a2134c8a..5f74a3ca4 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/config/HaGatewayConfiguration.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/config/HaGatewayConfiguration.java @@ -41,6 +41,7 @@ public class HaGatewayConfiguration private OAuth2GatewayCookieConfiguration oauth2GatewayCookieConfiguration = new OAuth2GatewayCookieConfiguration(); private GatewayCookieConfiguration gatewayCookieConfiguration = new GatewayCookieConfiguration(); private List statementPaths = ImmutableList.of(V1_STATEMENT_PATH); + private boolean includeClusterHostInResponse; private RequestAnalyzerConfig requestAnalyzerConfig = new RequestAnalyzerConfig(); @@ -244,6 +245,16 @@ public void setAdditionalStatementPaths(List statementPaths) statementPaths.stream().peek(s -> validateStatementPath(s, statementPaths)).map(s -> s.replaceAll("/+$", ""))).toList(); } + public boolean isIncludeClusterHostInResponse() + { + return includeClusterHostInResponse; + } + + public void setIncludeClusterHostInResponse(boolean includeClusterHostInResponse) + { + this.includeClusterHostInResponse = includeClusterHostInResponse; + } + private void validateStatementPath(String statementPath, List statementPaths) { if (statementPath.startsWith(V1_STATEMENT_PATH) || diff --git a/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyRequestHandler.java b/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyRequestHandler.java index 9a180eb1d..17a77e92b 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyRequestHandler.java +++ b/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyRequestHandler.java @@ -85,6 +85,7 @@ public class ProxyRequestHandler private final boolean cookiesEnabled; private final boolean addXForwardedHeaders; private final List statementPaths; + private final boolean includeClusterInfoInResponse; @Inject public ProxyRequestHandler( @@ -100,6 +101,7 @@ public ProxyRequestHandler( asyncTimeout = haGatewayConfiguration.getRouting().getAsyncTimeout(); addXForwardedHeaders = haGatewayConfiguration.getRouting().isAddXForwardedHeaders(); statementPaths = haGatewayConfiguration.getStatementPaths(); + this.includeClusterInfoInResponse = haGatewayConfiguration.isIncludeClusterHostInResponse(); } @PreDestroy @@ -160,7 +162,8 @@ private void performRequest( addXForwardedHeaders(servletRequest, requestBuilder); } - ImmutableList oauth2GatewayCookie = getOAuth2GatewayCookie(remoteUri, servletRequest); + ImmutableList.Builder cookieBuilder = ImmutableList.builder(); + cookieBuilder.addAll(getOAuth2GatewayCookie(remoteUri, servletRequest)); Request request = requestBuilder .setPreserveAuthorizationOnRedirect(true) @@ -171,11 +174,14 @@ private void performRequest( if (statementPaths.stream().anyMatch(request.getUri().getPath()::startsWith) && request.getMethod().equals(HttpMethod.POST)) { future = future.transform(response -> recordBackendForQueryId(request, response), executor); + if (includeClusterInfoInResponse) { + cookieBuilder.add(new NewCookie.Builder("trinoClusterHost").value(remoteUri.getHost()).build()); + } } setupAsyncResponse( asyncResponse, - future.transform(response -> buildResponse(response, oauth2GatewayCookie), executor) + future.transform(response -> buildResponse(response, cookieBuilder.build()), executor) .catching(ProxyException.class, e -> handleProxyException(request, e), directExecutor())); } diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/TestGatewayHaMultipleBackend.java b/gateway-ha/src/test/java/io/trino/gateway/ha/TestGatewayHaMultipleBackend.java index 1ffcc6dd5..7f7e5f0da 100644 --- a/gateway-ha/src/test/java/io/trino/gateway/ha/TestGatewayHaMultipleBackend.java +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/TestGatewayHaMultipleBackend.java @@ -40,6 +40,7 @@ import java.util.Optional; import java.util.concurrent.TimeUnit; +import static com.google.common.collect.MoreCollectors.onlyElement; import static org.assertj.core.api.Assertions.assertThat; import static org.testcontainers.utility.MountableFile.forClasspathResource; @@ -158,6 +159,42 @@ public void testQueryDeliveryToMultipleRoutingGroups() assertThat(response4.body().string()).contains("http://localhost:" + routerPort); } + @Test + public void testTrinoClusterHostCookie() + throws Exception + { + RequestBody requestBody = RequestBody.create("SELECT 1", MediaType.get("application/json; charset=utf-8")); + + // When X-Trino-Routing-Group is set in header, query should be routed to cluster under the routing group + Request requestOne = + new Request.Builder() + .url("http://localhost:" + routerPort + "/v1/statement") + .addHeader("X-Trino-User", "test") + .post(requestBody) + .addHeader("X-Trino-Routing-Group", "scheduled") + .build(); + Response responseOne = httpClient.newCall(requestOne).execute(); + assertThat(responseOne.body().string()).contains("http://localhost:" + routerPort); + List cookies = Cookie.parseAll(responseOne.request().url(), responseOne.headers()); + Cookie cookie = cookies.stream().filter(c -> c.name().equals("trinoClusterHost")).collect(onlyElement()); + assertThat(cookie.value()).isEqualTo("localhost"); + // test with sending the request which includes trinoClusterHost in the cookie + // when X-Trino-Routing-Group is set in header, query should be routed to cluster under the routing group + Request requestTwo = + new Request.Builder() + .url("http://localhost:" + routerPort + "/v1/statement") + .addHeader("X-Trino-User", "test") + .post(requestBody) + .addHeader("X-Trino-Routing-Group", "scheduled") + .addHeader("Cookie", "trinoClientHost=foo.example.com") + .build(); + Response responseTwo = httpClient.newCall(requestTwo).execute(); + assertThat(responseTwo.body().string()).contains("http://localhost:" + routerPort); + cookies = Cookie.parseAll(responseTwo.request().url(), responseTwo.headers()); + cookie = cookies.stream().filter(c -> c.name().equals("trinoClusterHost")).collect(onlyElement()); + assertThat(cookie.value()).isEqualTo("localhost"); + } + @Test public void testDeleteQueryId() throws IOException diff --git a/gateway-ha/src/test/resources/test-config-template.yml b/gateway-ha/src/test/resources/test-config-template.yml index b22e27fe9..beafe30c7 100644 --- a/gateway-ha/src/test/resources/test-config-template.yml +++ b/gateway-ha/src/test/resources/test-config-template.yml @@ -2,6 +2,7 @@ serverConfig: node.environment: test http-server.http.port: REQUEST_ROUTER_PORT +includeClusterHostInResponse: true dataStore: jdbcUrl: jdbc:h2:DB_FILE_PATH user: sa