Skip to content

Commit

Permalink
Enable Serverless API protections dynamically
Browse files Browse the repository at this point in the history
This commit changes the method for enabling Serverless API protections
(see elastic#93607) from a system property to a runtime property.

Within this project there is no external way to configure this
property - it must be controlled by a plugin - but the RestController
has been changed to dynamically adapt to a change in that property.
  • Loading branch information
tvernum committed Jun 26, 2023
1 parent afbf1f5 commit 0834fe3
Show file tree
Hide file tree
Showing 8 changed files with 55 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@ public ActionModule(
actionPlugins.stream().flatMap(p -> p.indicesAliasesRequestValidators().stream()).toList()
);
headersToCopy = headers;
restController = new RestController(restInterceptor, nodeClient, circuitBreakerService, usageService, tracer, serverlessEnabled);
restController = new RestController(restInterceptor, nodeClient, circuitBreakerService, usageService, tracer);
reservedClusterStateService = new ReservedClusterStateService(clusterService, reservedStateHandlers);
}

Expand Down
13 changes: 8 additions & 5 deletions server/src/main/java/org/elasticsearch/rest/RestController.java
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,14 @@ public class RestController implements HttpServerTransport.Dispatcher {
private final UsageService usageService;
private final Tracer tracer;
// If true, the ServerlessScope annotations will be enforced
private final boolean serverlessEnabled;
private final ServerlessApiProtections apiProtections;

public RestController(
UnaryOperator<RestHandler> handlerWrapper,
NodeClient client,
CircuitBreakerService circuitBreakerService,
UsageService usageService,
Tracer tracer,
boolean serverlessEnabled
Tracer tracer
) {
this.usageService = usageService;
this.tracer = tracer;
Expand All @@ -115,7 +114,11 @@ public RestController(
this.client = client;
this.circuitBreakerService = circuitBreakerService;
registerHandlerNoWrap(RestRequest.Method.GET, "/favicon.ico", RestApiVersion.current(), new RestFavIconHandler());
this.serverlessEnabled = serverlessEnabled;
this.apiProtections = new ServerlessApiProtections(false);
}

public ServerlessApiProtections getApiProtections() {
return apiProtections;
}

/**
Expand Down Expand Up @@ -374,7 +377,7 @@ private void dispatchRequest(RestRequest request, RestChannel channel, RestHandl
}
}
RestChannel responseChannel = channel;
if (serverlessEnabled) {
if (apiProtections.isEnabled()) {
Scope scope = handler.getServerlessScope();
if (scope == null) {
handleServerlessRequestToProtectedResource(request.uri(), request.method(), responseChannel);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.rest;

public class ServerlessApiProtections {

private volatile boolean enabled;

public ServerlessApiProtections(boolean enabled) {
this.enabled = enabled;
}

public boolean isEnabled() {
return enabled;
}

public void setEnabled(boolean enabled) {
this.enabled = enabled;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ public void setup() {
HttpServerTransport httpServerTransport = new TestHttpServerTransport();
client = new NoOpNodeClient(this.getTestName());
tracer = mock(Tracer.class);
restController = new RestController(null, client, circuitBreakerService, usageService, tracer, false);
restController = new RestController(null, client, circuitBreakerService, usageService, tracer);
restController.registerHandler(
new Route(GET, "/"),
(request, channel, client) -> channel.sendResponse(
Expand All @@ -129,7 +129,7 @@ public void teardown() throws IOException {

public void testApplyProductSpecificResponseHeaders() {
final ThreadContext threadContext = client.threadPool().getThreadContext();
final RestController restController = new RestController(null, null, circuitBreakerService, usageService, tracer, false);
final RestController restController = new RestController(null, null, circuitBreakerService, usageService, tracer);
RestRequest fakeRequest = new FakeRestRequest.Builder(xContentRegistry()).build();
AssertingChannel channel = new AssertingChannel(fakeRequest, false, RestStatus.BAD_REQUEST);
restController.dispatchRequest(fakeRequest, channel, threadContext);
Expand All @@ -145,7 +145,7 @@ public void testRequestWithDisallowedMultiValuedHeader() {
Set<RestHeaderDefinition> headers = new HashSet<>(
Arrays.asList(new RestHeaderDefinition("header.1", true), new RestHeaderDefinition("header.2", false))
);
final RestController restController = new RestController(null, null, circuitBreakerService, usageService, tracer, false);
final RestController restController = new RestController(null, null, circuitBreakerService, usageService, tracer);
Map<String, List<String>> restHeaders = new HashMap<>();
restHeaders.put("header.1", Collections.singletonList("boo"));
restHeaders.put("header.2", List.of("foo", "bar"));
Expand All @@ -160,7 +160,7 @@ public void testRequestWithDisallowedMultiValuedHeader() {
*/
public void testDispatchStartsTrace() {
final ThreadContext threadContext = client.threadPool().getThreadContext();
final RestController restController = new RestController(null, null, circuitBreakerService, usageService, tracer, false);
final RestController restController = new RestController(null, null, circuitBreakerService, usageService, tracer);
RestRequest fakeRequest = new FakeRestRequest.Builder(xContentRegistry()).build();
final RestController spyRestController = spy(restController);
when(spyRestController.getAllHandlers(null, fakeRequest.rawPath())).thenReturn(new Iterator<>() {
Expand Down Expand Up @@ -189,7 +189,7 @@ public void testRequestWithDisallowedMultiValuedHeaderButSameValues() {
Set<RestHeaderDefinition> headers = new HashSet<>(
Arrays.asList(new RestHeaderDefinition("header.1", true), new RestHeaderDefinition("header.2", false))
);
final RestController restController = new RestController(null, client, circuitBreakerService, usageService, tracer, false);
final RestController restController = new RestController(null, client, circuitBreakerService, usageService, tracer);
Map<String, List<String>> restHeaders = new HashMap<>();
restHeaders.put("header.1", Collections.singletonList("boo"));
restHeaders.put("header.2", List.of("foo", "foo"));
Expand Down Expand Up @@ -260,7 +260,7 @@ public void testRegisterAsReplacedHandler() {
}

public void testRegisterSecondMethodWithDifferentNamedWildcard() {
final RestController restController = new RestController(null, null, circuitBreakerService, usageService, tracer, false);
final RestController restController = new RestController(null, null, circuitBreakerService, usageService, tracer);

RestRequest.Method firstMethod = randomFrom(methodList);
RestRequest.Method secondMethod = randomFrom(methodList.stream().filter(m -> m != firstMethod).toList());
Expand All @@ -287,7 +287,7 @@ public void testRestHandlerWrapper() throws Exception {
final RestController restController = new RestController(h -> {
assertSame(handler, h);
return (RestRequest request, RestChannel channel, NodeClient client) -> wrapperCalled.set(true);
}, client, circuitBreakerService, usageService, tracer, false);
}, client, circuitBreakerService, usageService, tracer);
restController.registerHandler(new Route(GET, "/wrapped"), handler);
RestRequest request = testRestRequest("/wrapped", "{}", XContentType.JSON);
AssertingChannel channel = new AssertingChannel(request, true, RestStatus.BAD_REQUEST);
Expand Down Expand Up @@ -374,7 +374,7 @@ public void testDispatchRequiresContentTypeForRequestsWithContent() {
String content = randomAlphaOfLength((int) Math.round(BREAKER_LIMIT.getBytes() / inFlightRequestsBreaker.getOverhead()));
RestRequest request = testRestRequest("/", content, null);
AssertingChannel channel = new AssertingChannel(request, true, RestStatus.NOT_ACCEPTABLE);
restController = new RestController(null, null, circuitBreakerService, usageService, tracer, false);
restController = new RestController(null, null, circuitBreakerService, usageService, tracer);
restController.registerHandler(
new Route(GET, "/"),
(r, c, client) -> c.sendResponse(new RestResponse(RestStatus.OK, RestResponse.TEXT_CONTENT_TYPE, BytesArray.EMPTY))
Expand Down Expand Up @@ -761,7 +761,7 @@ public Method method() {

public void testDispatchCompatibleHandler() {

RestController restController = new RestController(null, client, circuitBreakerService, usageService, tracer, false);
RestController restController = new RestController(null, client, circuitBreakerService, usageService, tracer);

final RestApiVersion version = RestApiVersion.minimumSupported();

Expand All @@ -785,7 +785,7 @@ public void testDispatchCompatibleHandler() {

public void testDispatchCompatibleRequestToNewlyAddedHandler() {

RestController restController = new RestController(null, client, circuitBreakerService, usageService, tracer, false);
RestController restController = new RestController(null, client, circuitBreakerService, usageService, tracer);

final RestApiVersion version = RestApiVersion.minimumSupported();

Expand Down Expand Up @@ -820,7 +820,7 @@ private FakeRestRequest requestWithContent(String mediaType) {
}

public void testCurrentVersionVNDMediaTypeIsNotUsingCompatibility() {
RestController restController = new RestController(null, client, circuitBreakerService, usageService, tracer, false);
RestController restController = new RestController(null, client, circuitBreakerService, usageService, tracer);

final RestApiVersion version = RestApiVersion.current();

Expand All @@ -845,7 +845,7 @@ public void testCurrentVersionVNDMediaTypeIsNotUsingCompatibility() {
}

public void testCustomMediaTypeValidation() {
RestController restController = new RestController(null, client, circuitBreakerService, usageService, tracer, false);
RestController restController = new RestController(null, client, circuitBreakerService, usageService, tracer);

final String mediaType = "application/x-protobuf";
FakeRestRequest fakeRestRequest = requestWithContent(mediaType);
Expand All @@ -871,7 +871,7 @@ public void handleRequest(RestRequest request, RestChannel channel, NodeClient c
}

public void testBrowserSafelistedContentTypesAreRejected() {
RestController restController = new RestController(null, client, circuitBreakerService, usageService, tracer, false);
RestController restController = new RestController(null, client, circuitBreakerService, usageService, tracer);

final String mediaType = randomFrom(RestController.SAFELISTED_MEDIA_TYPES);
FakeRestRequest fakeRestRequest = requestWithContent(mediaType);
Expand All @@ -892,7 +892,7 @@ public void handleRequest(RestRequest request, RestChannel channel, NodeClient c
}

public void testRegisterWithReservedPath() {
final RestController restController = new RestController(null, client, circuitBreakerService, usageService, tracer, false);
final RestController restController = new RestController(null, client, circuitBreakerService, usageService, tracer);
for (String path : RestController.RESERVED_PATHS) {
IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> {
restController.registerHandler(
Expand All @@ -910,7 +910,7 @@ public void testRegisterWithReservedPath() {
* Test that when serverless is disabled, all endpoints are available regardless of ServerlessScope annotations.
*/
public void testApiProtectionWithServerlessDisabled() {
final RestController restController = new RestController(null, client, circuitBreakerService, new UsageService(), tracer, false);
final RestController restController = new RestController(null, client, circuitBreakerService, new UsageService(), tracer);
restController.registerHandler(new PublicRestHandler());
restController.registerHandler(new InternalRestHandler());
restController.registerHandler(new HiddenRestHandler());
Expand All @@ -926,7 +926,7 @@ public void testApiProtectionWithServerlessDisabled() {
* Test that when serverless is enabled, a normal user can not access endpoints without a ServerlessScope annotation.
*/
public void testApiProtectionWithServerlessEnabledAsEndUser() {
final RestController restController = new RestController(null, client, circuitBreakerService, new UsageService(), tracer, true);
final RestController restController = new RestController(null, client, circuitBreakerService, new UsageService(), tracer);
restController.registerHandler(new PublicRestHandler());
restController.registerHandler(new InternalRestHandler());
restController.registerHandler(new HiddenRestHandler());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public void testUnsupportedMethodResponseHttpHeader() throws Exception {
);

UsageService usageService = new UsageService();
RestController restController = new RestController(null, null, circuitBreakerService, usageService, Tracer.NOOP, false);
RestController restController = new RestController(null, null, circuitBreakerService, usageService, Tracer.NOOP);

// A basic RestHandler handles requests to the endpoint
RestHandler restHandler = (request, channel, client) -> channel.sendResponse(new RestResponse(RestStatus.OK, ""));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public class RestValidateQueryActionTests extends AbstractSearchTestCase {
private NodeClient client = new NodeClient(Settings.EMPTY, threadPool);

private UsageService usageService = new UsageService();
private RestController controller = new RestController(null, client, new NoneCircuitBreakerService(), usageService, Tracer.NOOP, false);
private RestController controller = new RestController(null, client, new NoneCircuitBreakerService(), usageService, Tracer.NOOP);
private RestValidateQueryAction action = new RestValidateQueryAction();

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public abstract class RestActionTestCase extends ESTestCase {
@Before
public void setUpController() {
verifyingClient = new VerifyingClient(this.getTestName());
controller = new RestController(null, verifyingClient, new NoneCircuitBreakerService(), new UsageService(), Tracer.NOOP, false);
controller = new RestController(null, verifyingClient, new NoneCircuitBreakerService(), new UsageService(), Tracer.NOOP);
}

@After
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ public class RestTermsEnumActionTests extends ESTestCase {
client,
new NoneCircuitBreakerService(),
usageService,
Tracer.NOOP,
false
Tracer.NOOP
);
private static RestTermsEnumAction action = new RestTermsEnumAction();

Expand Down

0 comments on commit 0834fe3

Please sign in to comment.