Skip to content

Commit

Permalink
Add proxy handler for PUT request
Browse files Browse the repository at this point in the history
  • Loading branch information
siminyou authored and ebyhr committed Nov 15, 2024
1 parent 62bbe94 commit a11e5a4
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import static io.airlift.http.client.Request.Builder.prepareDelete;
import static io.airlift.http.client.Request.Builder.prepareGet;
import static io.airlift.http.client.Request.Builder.preparePost;
import static io.airlift.http.client.Request.Builder.preparePut;
import static io.airlift.http.client.StaticBodyGenerator.createStaticBodyGenerator;
import static io.airlift.jaxrs.AsyncResponseHandler.bindAsyncResponse;
import static io.trino.gateway.ha.handler.ProxyUtils.QUERY_TEXT_LENGTH_FOR_HISTORY;
Expand Down Expand Up @@ -140,6 +141,17 @@ public void postRequest(
performRequest(remoteUri, servletRequest, asyncResponse, request);
}

public void putRequest(
String statement,
HttpServletRequest servletRequest,
AsyncResponse asyncResponse,
URI remoteUri)
{
Request.Builder request = preparePut()
.setBodyGenerator(createStaticBodyGenerator(statement, UTF_8));
performRequest(remoteUri, servletRequest, asyncResponse, request);
}

private void performRequest(
URI remoteUri,
HttpServletRequest servletRequest,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import jakarta.ws.rs.DELETE;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.PUT;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.container.AsyncResponse;
import jakarta.ws.rs.container.Suspended;
Expand Down Expand Up @@ -85,4 +86,15 @@ public void deleteHandler(
String remoteUri = routingTargetHandler.getRoutingDestination(servletRequest);
proxyRequestHandler.deleteRequest(servletRequest, asyncResponse, URI.create(remoteUri));
}

@PUT
public void putHandler(
String body,
@Context HttpServletRequest servletRequest,
@Suspended AsyncResponse asyncResponse)
{
MultiReadHttpServletRequest multiReadHttpServletRequest = new MultiReadHttpServletRequest(servletRequest, body);
String remoteUri = routingTargetHandler.getRoutingDestination(multiReadHttpServletRequest);
proxyRequestHandler.putRequest(body, multiReadHttpServletRequest, asyncResponse, URI.create(remoteUri));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.gateway.proxyserver;

import io.trino.gateway.ha.HaGatewayLauncher;
import io.trino.gateway.ha.HaGatewayTestUtils;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import okhttp3.mockwebserver.Dispatcher;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;

import static com.google.common.net.HttpHeaders.CONTENT_TYPE;
import static com.google.common.net.MediaType.JSON_UTF_8;
import static io.trino.gateway.ha.HaGatewayTestUtils.buildGatewayConfigAndSeedDb;
import static io.trino.gateway.ha.HaGatewayTestUtils.prepareMockBackend;
import static io.trino.gateway.ha.HaGatewayTestUtils.setUpBackend;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS;

@TestInstance(PER_CLASS)
final class TestProxyRequestHandler
{
private final OkHttpClient httpClient = new OkHttpClient();
private final MockWebServer mockTrinoServer = new MockWebServer();

private final int routerPort = 21001 + (int) (Math.random() * 1000);
private final int customBackendPort = 21000 + (int) (Math.random() * 1000);

private static final String OK = "OK";
private static final int NOT_FOUND = 404;
private static final MediaType MEDIA_TYPE = MediaType.parse("application/json; charset=utf-8");

private final String customPutEndpoint = "/v1/custom"; // this is enabled in test-config-template.yml
private final String healthCheckEndpoint = "/v1/info";

@BeforeAll
void setup()
throws Exception
{
prepareMockBackend(mockTrinoServer, customBackendPort, "default custom response");
mockTrinoServer.setDispatcher(new Dispatcher() {
@Override
public MockResponse dispatch(RecordedRequest request)
{
if (request.getPath().equals(healthCheckEndpoint)) {
return new MockResponse().setResponseCode(200)
.setHeader(CONTENT_TYPE, JSON_UTF_8)
.setBody("{\"starting\": false}");
}

if (request.getMethod().equals("PUT") && request.getPath().equals(customPutEndpoint)) {
return new MockResponse().setResponseCode(200)
.setHeader(CONTENT_TYPE, JSON_UTF_8)
.setBody(OK);
}

return new MockResponse().setResponseCode(NOT_FOUND);
}
});

HaGatewayTestUtils.TestConfig testConfig = buildGatewayConfigAndSeedDb(routerPort, "test-config-template.yml");

String[] args = {testConfig.configFilePath()};
HaGatewayLauncher.main(args);

setUpBackend("custom", "http://localhost:" + customBackendPort, "externalUrl", true, "adhoc", routerPort);
}

@AfterAll
void cleanup()
throws Exception
{
mockTrinoServer.shutdown();
}

@Test
void testPutRequestHandler()
throws Exception
{
String url = "http://localhost:" + routerPort + customPutEndpoint;
RequestBody requestBody = RequestBody.create("SELECT 1", MEDIA_TYPE);

Request putRequest = new Request.Builder().url(url).put(requestBody).build();
try (Response response = httpClient.newCall(putRequest).execute()) {
assertThat(response.body()).isNotNull();
assertThat(response.body().string()).isEqualTo(OK);
}

Request postRequest = new Request.Builder().url(url).post(requestBody).build();
try (Response response = httpClient.newCall(postRequest).execute()) {
assertThat(response.code()).isEqualTo(NOT_FOUND);
}
}
}

0 comments on commit a11e5a4

Please sign in to comment.