Skip to content

Commit

Permalink
Fix: register mulitple extensions. (opensearch-project#10256)
Browse files Browse the repository at this point in the history
* Fix: register mulitple extensions.

Signed-off-by: dblock <[email protected]>

* Updated CHANGELOG.

Signed-off-by: dblock <[email protected]>

* Added tests.

Signed-off-by: dblock <[email protected]>

---------

Signed-off-by: dblock <[email protected]>
Signed-off-by: Shivansh Arora <[email protected]>
  • Loading branch information
dblock authored and shiv0408 committed Apr 25, 2024
1 parent d81429b commit bdb5d5b
Show file tree
Hide file tree
Showing 8 changed files with 176 additions and 24 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Fix concurrent search NPE when track_total_hits, terminate_after and size=0 are used ([#10082](https://github.com/opensearch-project/OpenSearch/pull/10082))
- Fix remove ingest processor handing ignore_missing parameter not correctly ([10089](https://github.com/opensearch-project/OpenSearch/pull/10089))
- Fix circular dependency in Settings initialization ([10194](https://github.com/opensearch-project/OpenSearch/pull/10194))
- Fix registration and initialization of multiple extensions ([10256](https://github.com/opensearch-project/OpenSearch/pull/10256))

### Security

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ private void registerRequestHandler(DynamicActionRegistry dynamicActionRegistry)
* Loads a single extension
* @param extension The extension to be loaded
*/
public void loadExtension(Extension extension) throws IOException {
public DiscoveryExtensionNode loadExtension(Extension extension) throws IOException {
validateExtension(extension);
DiscoveryExtensionNode discoveryExtensionNode = new DiscoveryExtensionNode(
extension.getName(),
Expand All @@ -314,6 +314,12 @@ public void loadExtension(Extension extension) throws IOException {
extensionIdMap.put(extension.getUniqueId(), discoveryExtensionNode);
extensionSettingsMap.put(extension.getUniqueId(), extension);
logger.info("Loaded extension with uniqueId " + extension.getUniqueId() + ": " + extension);
return discoveryExtensionNode;
}

public void initializeExtension(Extension extension) throws IOException {
DiscoveryExtensionNode node = loadExtension(extension);
initializeExtensionNode(node);
}

private void validateField(String fieldName, String value) throws IOException {
Expand All @@ -340,11 +346,11 @@ private void validateExtension(Extension extension) throws IOException {
*/
public void initialize() {
for (DiscoveryExtensionNode extension : extensionIdMap.values()) {
initializeExtension(extension);
initializeExtensionNode(extension);
}
}

private void initializeExtension(DiscoveryExtensionNode extension) {
public void initializeExtensionNode(DiscoveryExtensionNode extensionNode) {

final CompletableFuture<InitializeExtensionResponse> inProgressFuture = new CompletableFuture<>();
final TransportResponseHandler<InitializeExtensionResponse> initializeExtensionResponseHandler = new TransportResponseHandler<
Expand Down Expand Up @@ -384,7 +390,8 @@ public String executor() {
transportService.getThreadPool().generic().execute(new AbstractRunnable() {
@Override
public void onFailure(Exception e) {
extensionIdMap.remove(extension.getId());
logger.warn("Error registering extension: " + extensionNode.getId(), e);
extensionIdMap.remove(extensionNode.getId());
if (e.getCause() instanceof ConnectTransportException) {
logger.info("No response from extension to request.", e);
throw (ConnectTransportException) e.getCause();
Expand All @@ -399,11 +406,11 @@ public void onFailure(Exception e) {

@Override
protected void doRun() throws Exception {
transportService.connectToExtensionNode(extension);
transportService.connectToExtensionNode(extensionNode);
transportService.sendRequest(
extension,
extensionNode,
REQUEST_EXTENSION_ACTION_NAME,
new InitializeExtensionRequest(transportService.getLocalNode(), extension, issueServiceAccount(extension)),
new InitializeExtensionRequest(transportService.getLocalNode(), extensionNode, issueServiceAccount(extensionNode)),
initializeExtensionResponseHandler
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ public TransportResponse handleRegisterRestActionsRequest(
DynamicActionRegistry dynamicActionRegistry
) throws Exception {
DiscoveryExtensionNode discoveryExtensionNode = extensionIdMap.get(restActionsRequest.getUniqueId());
if (discoveryExtensionNode == null) {
throw new IllegalStateException("Missing extension node for " + restActionsRequest.getUniqueId());
}
RestHandler handler = new RestSendToExtensionAction(
restActionsRequest,
discoveryExtensionNode,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,7 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
extAdditionalSettings
);
try {
extensionsManager.loadExtension(extension);
extensionsManager.initialize();
extensionsManager.initializeExtension(extension);
} catch (CompletionException e) {
Throwable cause = e.getCause();
if (cause instanceof TimeoutException) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ public RestSendToExtensionAction(

@Override
public String getName() {
return SEND_TO_EXTENSION_ACTION;
return this.discoveryExtensionNode.getId() + ":" + SEND_TO_EXTENSION_ACTION;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.opensearch.core.common.transport.TransportAddress;
import org.opensearch.core.indices.breaker.NoneCircuitBreakerService;
import org.opensearch.core.transport.TransportResponse;
import org.opensearch.discovery.InitializeExtensionRequest;
import org.opensearch.env.Environment;
import org.opensearch.env.EnvironmentSettingsResponse;
import org.opensearch.extensions.ExtensionsSettings.Extension;
Expand Down Expand Up @@ -77,6 +78,7 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
Expand Down Expand Up @@ -409,19 +411,94 @@ public void testInitialize() throws Exception {
)
);

// Test needs to be changed to mock the connection between the local node and an extension. Assert statment is commented out for
// now.
// Test needs to be changed to mock the connection between the local node and an extension.
// Link to issue: https://github.com/opensearch-project/OpenSearch/issues/4045
// mockLogAppender.assertAllExpectationsMatched();
}
}

public void testInitializeExtension() throws Exception {
ExtensionsManager extensionsManager = new ExtensionsManager(Set.of(), identityService);

TransportService mockTransportService = spy(
new TransportService(
Settings.EMPTY,
mock(Transport.class),
threadPool,
TransportService.NOOP_TRANSPORT_INTERCEPTOR,
x -> null,
null,
Collections.emptySet(),
NoopTracer.INSTANCE
)
);

doNothing().when(mockTransportService).connectToExtensionNode(any(DiscoveryExtensionNode.class));

doNothing().when(mockTransportService)
.sendRequest(any(DiscoveryExtensionNode.class), anyString(), any(InitializeExtensionRequest.class), any());

extensionsManager.initializeServicesAndRestHandler(
actionModule,
settingsModule,
mockTransportService,
clusterService,
settings,
client,
identityService
);

Extension firstExtension = new Extension(
"firstExtension",
"uniqueid1",
"127.0.0.0",
"9301",
"0.0.7",
"2.0.0",
"2.0.0",
List.of(),
null
);

extensionsManager.initializeExtension(firstExtension);

Extension secondExtension = new Extension(
"secondExtension",
"uniqueid2",
"127.0.0.0",
"9301",
"0.0.7",
"2.0.0",
"2.0.0",
List.of(),
null
);

extensionsManager.initializeExtension(secondExtension);

ThreadPool.terminate(threadPool, 3, TimeUnit.SECONDS);

verify(mockTransportService, times(2)).connectToExtensionNode(any(DiscoveryExtensionNode.class));

verify(mockTransportService, times(2)).sendRequest(
any(DiscoveryExtensionNode.class),
anyString(),
any(InitializeExtensionRequest.class),
any()
);
}

public void testHandleRegisterRestActionsRequest() throws Exception {

ExtensionsManager extensionsManager = new ExtensionsManager(Set.of(), identityService);
initialize(extensionsManager);

String uniqueIdStr = "uniqueid1";

extensionsManager.loadExtension(
new Extension("firstExtension", uniqueIdStr, "127.0.0.0", "9300", "0.0.7", "3.0.0", "3.0.0", List.of(), null)
);

List<String> actionsList = List.of("GET /foo foo", "PUT /bar bar", "POST /baz baz");
List<String> deprecatedActionsList = List.of("GET /deprecated/foo foo_deprecated", "It's deprecated!");
RegisterRestActionsRequest registerActionsRequest = new RegisterRestActionsRequest(uniqueIdStr, actionsList, deprecatedActionsList);
Expand All @@ -431,6 +508,58 @@ public void testHandleRegisterRestActionsRequest() throws Exception {
assertTrue(((AcknowledgedResponse) response).getStatus());
}

public void testHandleRegisterRestActionsRequestRequiresDiscoveryNode() throws Exception {

ExtensionsManager extensionsManager = new ExtensionsManager(Set.of(), identityService);
initialize(extensionsManager);

RegisterRestActionsRequest registerActionsRequest = new RegisterRestActionsRequest("uniqueId1", List.of(), List.of());

expectThrows(
IllegalStateException.class,
() -> extensionsManager.getRestActionsRequestHandler()
.handleRegisterRestActionsRequest(registerActionsRequest, actionModule.getDynamicActionRegistry())
);
}

public void testHandleRegisterRestActionsRequestMultiple() throws Exception {

ExtensionsManager extensionsManager = new ExtensionsManager(Set.of(), identityService);
initialize(extensionsManager);

List<String> actionsList = List.of("GET /foo foo", "PUT /bar bar", "POST /baz baz");
List<String> deprecatedActionsList = List.of("GET /deprecated/foo foo_deprecated", "It's deprecated!");
for (int i = 0; i < 2; i++) {
String uniqueIdStr = "uniqueid-%d" + i;

Set<Setting<?>> additionalSettings = extAwarePlugin.getExtensionSettings().stream().collect(Collectors.toSet());
ExtensionScopedSettings extensionScopedSettings = new ExtensionScopedSettings(additionalSettings);
Extension firstExtension = new Extension(
"Extension %s" + i,
uniqueIdStr,
"127.0.0.0",
"9300",
"0.0.7",
"3.0.0",
"3.0.0",
List.of(),
extensionScopedSettings
);

extensionsManager.loadExtension(firstExtension);

RegisterRestActionsRequest registerActionsRequest = new RegisterRestActionsRequest(
uniqueIdStr,
actionsList,
deprecatedActionsList
);
TransportResponse response = extensionsManager.getRestActionsRequestHandler()
.handleRegisterRestActionsRequest(registerActionsRequest, actionModule.getDynamicActionRegistry());
assertEquals(AcknowledgedResponse.class, response.getClass());
assertTrue(((AcknowledgedResponse) response).getStatus());
}
}

public void testHandleRegisterSettingsRequest() throws Exception {
ExtensionsManager extensionsManager = new ExtensionsManager(Set.of(), identityService);
initialize(extensionsManager);
Expand All @@ -452,6 +581,9 @@ public void testHandleRegisterRestActionsRequestWithInvalidMethod() throws Excep
initialize(extensionsManager);

String uniqueIdStr = "uniqueid1";
extensionsManager.loadExtension(
new Extension("firstExtension", uniqueIdStr, "127.0.0.0", "9300", "0.0.7", "3.0.0", "3.0.0", List.of(), null)
);
List<String> actionsList = List.of("FOO /foo", "PUT /bar", "POST /baz");
List<String> deprecatedActionsList = List.of("GET /deprecated/foo", "It's deprecated!");
RegisterRestActionsRequest registerActionsRequest = new RegisterRestActionsRequest(uniqueIdStr, actionsList, deprecatedActionsList);
Expand All @@ -467,6 +599,9 @@ public void testHandleRegisterRestActionsRequestWithInvalidDeprecatedMethod() th
initialize(extensionsManager);

String uniqueIdStr = "uniqueid1";
extensionsManager.loadExtension(
new Extension("firstExtension", uniqueIdStr, "127.0.0.0", "9300", "0.0.7", "3.0.0", "3.0.0", List.of(), null)
);
List<String> actionsList = List.of("GET /foo", "PUT /bar", "POST /baz");
List<String> deprecatedActionsList = List.of("FOO /deprecated/foo", "It's deprecated!");
RegisterRestActionsRequest registerActionsRequest = new RegisterRestActionsRequest(uniqueIdStr, actionsList, deprecatedActionsList);
Expand All @@ -481,6 +616,9 @@ public void testHandleRegisterRestActionsRequestWithInvalidUri() throws Exceptio
ExtensionsManager extensionsManager = new ExtensionsManager(Set.of(), identityService);
initialize(extensionsManager);
String uniqueIdStr = "uniqueid1";
extensionsManager.loadExtension(
new Extension("firstExtension", uniqueIdStr, "127.0.0.0", "9300", "0.0.7", "3.0.0", "3.0.0", List.of(), null)
);
List<String> actionsList = List.of("GET", "PUT /bar", "POST /baz");
List<String> deprecatedActionsList = List.of("GET /deprecated/foo", "It's deprecated!");
RegisterRestActionsRequest registerActionsRequest = new RegisterRestActionsRequest(uniqueIdStr, actionsList, deprecatedActionsList);
Expand All @@ -495,6 +633,9 @@ public void testHandleRegisterRestActionsRequestWithInvalidDeprecatedUri() throw
ExtensionsManager extensionsManager = new ExtensionsManager(Set.of(), identityService);
initialize(extensionsManager);
String uniqueIdStr = "uniqueid1";
extensionsManager.loadExtension(
new Extension("firstExtension", uniqueIdStr, "127.0.0.0", "9300", "0.0.7", "3.0.0", "3.0.0", List.of(), null)
);
List<String> actionsList = List.of("GET /foo", "PUT /bar", "POST /baz");
List<String> deprecatedActionsList = List.of("GET", "It's deprecated!");
RegisterRestActionsRequest registerActionsRequest = new RegisterRestActionsRequest(uniqueIdStr, actionsList, deprecatedActionsList);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
import org.opensearch.core.indices.breaker.NoneCircuitBreakerService;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.extensions.DiscoveryExtensionNode;
import org.opensearch.extensions.ExtensionsManager;
import org.opensearch.extensions.ExtensionsSettings;
import org.opensearch.extensions.ExtensionsSettings.Extension;
import org.opensearch.identity.IdentityService;
import org.opensearch.rest.RestRequest;
import org.opensearch.telemetry.tracing.noop.NoopTracer;
Expand Down Expand Up @@ -160,8 +161,8 @@ public void testRestInitializeExtensionActionResponseWithAdditionalSettings() th

// optionally, you can stub out some methods:
when(spy.getAdditionalSettings()).thenCallRealMethod();
Mockito.doCallRealMethod().when(spy).loadExtension(any(ExtensionsSettings.Extension.class));
Mockito.doNothing().when(spy).initialize();
Mockito.doCallRealMethod().when(spy).loadExtension(any(Extension.class));
Mockito.doNothing().when(spy).initializeExtensionNode(any(DiscoveryExtensionNode.class));
RestInitializeExtensionAction restInitializeExtensionAction = new RestInitializeExtensionAction(spy);
final String content = "{\"name\":\"ad-extension\",\"uniqueId\":\"ad-extension\",\"hostAddress\":\"127.0.0.1\","
+ "\"port\":\"4532\",\"version\":\"1.0\",\"opensearchVersion\":\""
Expand All @@ -177,10 +178,10 @@ public void testRestInitializeExtensionActionResponseWithAdditionalSettings() th
FakeRestChannel channel = new FakeRestChannel(request, false, 0);
restInitializeExtensionAction.handleRequest(request, channel, null);

assertEquals(channel.capturedResponse().status(), RestStatus.ACCEPTED);
assertEquals(RestStatus.ACCEPTED, channel.capturedResponse().status());
assertTrue(channel.capturedResponse().content().utf8ToString().contains("A request to initialize an extension has been sent."));

Optional<ExtensionsSettings.Extension> extension = spy.lookupExtensionSettingsById("ad-extension");
Optional<Extension> extension = spy.lookupExtensionSettingsById("ad-extension");
assertTrue(extension.isPresent());
assertEquals(true, extension.get().getAdditionalSettings().get(boolSetting));
assertEquals("customSetting", extension.get().getAdditionalSettings().get(stringSetting));
Expand Down Expand Up @@ -210,8 +211,8 @@ public void testRestInitializeExtensionActionResponseWithAdditionalSettingsUsing

// optionally, you can stub out some methods:
when(spy.getAdditionalSettings()).thenCallRealMethod();
Mockito.doCallRealMethod().when(spy).loadExtension(any(ExtensionsSettings.Extension.class));
Mockito.doNothing().when(spy).initialize();
Mockito.doCallRealMethod().when(spy).loadExtension(any(Extension.class));
Mockito.doNothing().when(spy).initializeExtensionNode(any(DiscoveryExtensionNode.class));
RestInitializeExtensionAction restInitializeExtensionAction = new RestInitializeExtensionAction(spy);
final String content = "{\"name\":\"ad-extension\",\"uniqueId\":\"ad-extension\",\"hostAddress\":\"127.0.0.1\","
+ "\"port\":\"4532\",\"version\":\"1.0\",\"opensearchVersion\":\""
Expand All @@ -227,10 +228,10 @@ public void testRestInitializeExtensionActionResponseWithAdditionalSettingsUsing
FakeRestChannel channel = new FakeRestChannel(request, false, 0);
restInitializeExtensionAction.handleRequest(request, channel, null);

assertEquals(channel.capturedResponse().status(), RestStatus.ACCEPTED);
assertEquals(RestStatus.ACCEPTED, channel.capturedResponse().status());
assertTrue(channel.capturedResponse().content().utf8ToString().contains("A request to initialize an extension has been sent."));

Optional<ExtensionsSettings.Extension> extension = spy.lookupExtensionSettingsById("ad-extension");
Optional<Extension> extension = spy.lookupExtensionSettingsById("ad-extension");
assertTrue(extension.isPresent());
assertEquals(false, extension.get().getAdditionalSettings().get(boolSetting));
assertEquals("default", extension.get().getAdditionalSettings().get(stringSetting));
Expand Down
Loading

0 comments on commit bdb5d5b

Please sign in to comment.