Skip to content

Commit

Permalink
Always initialize SDKActionModule (#537) (#548)
Browse files Browse the repository at this point in the history
(cherry picked from commit 860c72e)

Signed-off-by: Daniel Widdis <[email protected]>
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
1 parent 5dab9f9 commit 980dd82
Show file tree
Hide file tree
Showing 9 changed files with 61 additions and 84 deletions.
2 changes: 1 addition & 1 deletion src/main/java/org/opensearch/sdk/BaseExtension.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
/**
* An abstract class that simplifies extension initialization and provides an instance of the runner.
*/
public abstract class BaseExtension implements Extension, ActionExtension {
public abstract class BaseExtension implements Extension {
/**
* The {@link ExtensionsRunner} instance running this extension
*/
Expand Down
40 changes: 13 additions & 27 deletions src/main/java/org/opensearch/sdk/ExtensionsRunner.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import org.opensearch.extensions.rest.ExtensionRestRequest;
import org.opensearch.extensions.rest.RegisterRestActionsRequest;
import org.opensearch.extensions.settings.RegisterCustomSettingsRequest;
import org.opensearch.common.Nullable;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.NamedXContentRegistry;
Expand Down Expand Up @@ -133,8 +132,7 @@ public class ExtensionsRunner {
private final SDKNamedXContentRegistry sdkNamedXContentRegistry;
private final SDKClient sdkClient = new SDKClient();
private final SDKClusterService sdkClusterService = new SDKClusterService(this);
@Nullable
private SDKActionModule sdkActionModule = null;
private final SDKActionModule sdkActionModule;

private ExtensionsInitRequestHandler extensionsInitRequestHandler = new ExtensionsInitRequestHandler(this);
private ExtensionsIndicesModuleRequestHandler extensionsIndicesModuleRequestHandler = new ExtensionsIndicesModuleRequestHandler();
Expand Down Expand Up @@ -191,17 +189,16 @@ protected ExtensionsRunner(Extension extension) throws IOException {
// Bind the return values from create components
modules.add(this::injectComponents);
// Bind actions from getActions
if (extension instanceof ActionExtension) {
this.sdkActionModule = new SDKActionModule((ActionExtension) extension);
modules.add(this.sdkActionModule);
}
this.sdkActionModule = new SDKActionModule(extension);
modules.add(this.sdkActionModule);
// Finally, perform the injection
this.injector = Guice.createInjector(modules);

// Perform other initialization. These should have access to injected classes.
// initialize SDKClient action map
initializeSdkClient();

if (extension instanceof ActionExtension) {
// initialize SDKClient action map
initializeSdkClient();
// store REST handlers in the registry
for (ExtensionRestHandler extensionRestHandler : ((ActionExtension) extension).getExtensionRestHandlers()) {
for (Route route : extensionRestHandler.routes()) {
Expand Down Expand Up @@ -398,25 +395,20 @@ public void startTransportService(TransportService transportService) {
* @return A list of strings matching the interface name.
*/
public List<String> getExtensionImplementedInterfaces() {
Extension extension = getExtension();

Set<Class<?>> interfaceSet = new HashSet<>();
Class<?> extensionClass = extension.getClass();
Class<?> extensionClass = getExtension().getClass();
do {
interfaceSet.addAll(Arrays.stream(extensionClass.getInterfaces()).collect(Collectors.toSet()));
extensionClass = extensionClass.getSuperclass();
} while (extensionClass != null);

List<String> interfacesOfOpenSearch = new ArrayList<String>();
// we are making an assumption here that all the other Interfaces will be in the same package ( or will be in subpackage ) in which
// ActionExtension Interface belongs.
String packageNameOfActionExtension = ActionExtension.class.getPackageName();
for (Class<?> anInterface : interfaceSet) {
if (anInterface.getPackageName().startsWith(packageNameOfActionExtension)) {
interfacesOfOpenSearch.add(anInterface.getSimpleName());
}
}
return interfacesOfOpenSearch;
// Extension Interface belongs.
String extensionInterfacePackageName = Extension.class.getPackageName();
return interfaceSet.stream()
.filter(i -> i.getPackageName().startsWith(extensionInterfacePackageName))
.map(Class::getSimpleName)
.collect(Collectors.toList());
}

/**
Expand Down Expand Up @@ -627,12 +619,6 @@ public SDKClusterService getSdkClusterService() {
return sdkClusterService;
}

/**
* Returns the {@link SDKActionModule} if actions have been registered
*
* @return The SDKActionModule instance if it exists, null otherwise.
*/
@Nullable
public SDKActionModule getSdkActionModule() {
return sdkActionModule;
}
Expand Down
3 changes: 2 additions & 1 deletion src/main/java/org/opensearch/sdk/SDKClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import java.io.Closeable;
import java.io.IOException;
import java.util.Collections;
import java.util.Map;

import com.fasterxml.jackson.annotation.JsonTypeInfo;
Expand Down Expand Up @@ -77,7 +78,7 @@ public class SDKClient implements Closeable {

// Used by client.execute, populated by initialize method
@SuppressWarnings("rawtypes")
private Map<ActionType, TransportAction> actions;
private Map<ActionType, TransportAction> actions = Collections.emptyMap();

/**
* Initialize this client.
Expand Down
42 changes: 25 additions & 17 deletions src/main/java/org/opensearch/sdk/action/SDKActionModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

package org.opensearch.sdk.action;

import java.util.Collections;
import java.util.Map;
import java.util.stream.Collectors;

Expand All @@ -22,6 +23,7 @@
import org.opensearch.extensions.ExtensionsManager;
import org.opensearch.extensions.RegisterTransportActionsRequest;
import org.opensearch.sdk.ActionExtension.ActionHandler;
import org.opensearch.sdk.Extension;
import org.opensearch.sdk.handlers.AcknowledgedResponseHandler;
import org.opensearch.transport.TransportService;

Expand All @@ -46,10 +48,9 @@ public class SDKActionModule extends AbstractModule {
*
* @param extension An instance of {@link ActionExtension}.
*/
public SDKActionModule(ActionExtension extension) {
public SDKActionModule(Extension extension) {
this.actions = setupActions(extension);
this.actionFilters = setupActionFilters(extension);
// TODO: consider moving Rest Handler registration here
}

public Map<String, ActionHandler<?, ?>> getActions() {
Expand All @@ -60,26 +61,33 @@ public ActionFilters getActionFilters() {
return actionFilters;
}

private static Map<String, ActionHandler<?, ?>> setupActions(ActionExtension extension) {
// Subclass NamedRegistry for easy registration
class ActionRegistry extends NamedRegistry<ActionHandler<?, ?>> {
ActionRegistry() {
super("action");
private static Map<String, ActionHandler<?, ?>> setupActions(Extension extension) {
if (extension instanceof ActionExtension) {
// Subclass NamedRegistry for easy registration
class ActionRegistry extends NamedRegistry<ActionHandler<?, ?>> {
ActionRegistry() {
super("action");
}

public void register(ActionHandler<?, ?> handler) {
register(handler.getAction().name(), handler);
}
}
ActionRegistry actions = new ActionRegistry();
// Register getActions in it
((ActionExtension) extension).getActions().stream().forEach(actions::register);

public void register(ActionHandler<?, ?> handler) {
register(handler.getAction().name(), handler);
}
return unmodifiableMap(actions.getRegistry());
}
ActionRegistry actions = new ActionRegistry();
// Register getActions in it
extension.getActions().stream().forEach(actions::register);

return unmodifiableMap(actions.getRegistry());
return Collections.emptyMap();
}

private static ActionFilters setupActionFilters(ActionExtension extension) {
return new ActionFilters(extension.getActionFilters().stream().collect(Collectors.toSet()));
private static ActionFilters setupActionFilters(Extension extension) {
return new ActionFilters(
extension instanceof ActionExtension
? ((ActionExtension) extension).getActionFilters().stream().collect(Collectors.toSet())
: Collections.emptySet()
);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,12 @@ public InitializeExtensionResponse handleExtensionInitRequest(InitializeExtensio
extensionTransportService.connectToNode(extensionsRunner.opensearchNode);
extensionsRunner.sendRegisterRestActionsRequest(extensionTransportService);
extensionsRunner.sendRegisterCustomSettingsRequest(extensionTransportService);
if (extensionsRunner.getSdkActionModule() != null) {
extensionsRunner.getSdkActionModule()
.sendRegisterTransportActionsRequest(
extensionTransportService,
extensionsRunner.opensearchNode,
extensionsRunner.getUniqueId()
);
}
extensionsRunner.getSdkActionModule()
.sendRegisterTransportActionsRequest(
extensionTransportService,
extensionsRunner.opensearchNode,
extensionsRunner.getUniqueId()
);
// Get OpenSearch Settings and set values on ExtensionsRunner
Settings settings = extensionsRunner.sendEnvironmentSettingsRequest(extensionTransportService);
extensionsRunner.setEnvironmentSettings(settings);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.opensearch.sdk.ExtensionRestHandler;
import org.opensearch.sdk.ExtensionSettings;
import org.opensearch.sdk.ExtensionsRunner;
import org.opensearch.sdk.ActionExtension;
import org.opensearch.sdk.ActionExtension.ActionHandler;
import org.opensearch.sdk.sample.helloworld.rest.RestHelloAction;
import org.opensearch.sdk.sample.helloworld.transport.SampleAction;
Expand All @@ -34,7 +35,7 @@
* <p>
* To execute, pass an instatiated object of this class to {@link ExtensionsRunner#run(Extension)}.
*/
public class HelloWorldExtension extends BaseExtension {
public class HelloWorldExtension extends BaseExtension implements ActionExtension {

/**
* Optional classpath-relative path to a yml file containing extension settings.
Expand Down
3 changes: 1 addition & 2 deletions src/test/java/org/opensearch/sdk/TestExtensionsRunner.java
Original file line number Diff line number Diff line change
Expand Up @@ -240,9 +240,8 @@ public void testGettersAndSetters() throws IOException {
@Test
public void testGetExtensionImplementedInterfaces() {
List<String> implementedInterfaces = extensionsRunner.getExtensionImplementedInterfaces();
assertTrue(!implementedInterfaces.isEmpty());
assertFalse(implementedInterfaces.isEmpty());
assertTrue(implementedInterfaces.contains("Extension"));
assertTrue(implementedInterfaces.contains("ActionExtension"));
}

}
13 changes: 11 additions & 2 deletions src/test/java/org/opensearch/sdk/action/TestSDKActionModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import org.opensearch.extensions.ExtensionsManager;
import org.opensearch.extensions.RegisterTransportActionsRequest;
import org.opensearch.sdk.ActionExtension;
import org.opensearch.sdk.Extension;
import org.opensearch.sdk.ExtensionSettings;
import org.opensearch.sdk.handlers.AcknowledgedResponseHandler;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.transport.Transport;
Expand Down Expand Up @@ -51,7 +53,12 @@ public class TestSDKActionModule extends OpenSearchTestCase {
private TransportService transportService;
private DiscoveryNode opensearchNode;

private SDKActionModule sdkActionModule = new SDKActionModule(new ActionExtension() {
private static class TestActionExtension implements Extension, ActionExtension {
@Override
public ExtensionSettings getExtensionSettings() {
return null;
}

@Override
public List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> getActions() {
@SuppressWarnings("unchecked")
Expand All @@ -60,7 +67,9 @@ public class TestSDKActionModule extends OpenSearchTestCase {

return Arrays.asList(new ActionHandler<ActionRequest, ActionResponse>(testAction, null));
}
});
}

private SDKActionModule sdkActionModule = new SDKActionModule(new TestActionExtension());

@Override
@BeforeEach
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,29 +237,4 @@ public void onFailure(Exception e) {

assertEquals("failed to find action [" + UnregisteredAction.INSTANCE + "] to execute", ex.getMessage());
}

@Test
public void testSdkClientNotInitialized() throws Exception {
String expectedName = "";
SampleRequest request = new SampleRequest(expectedName);
CompletableFuture<SampleResponse> responseFuture = new CompletableFuture<>();
SDKClient uninitializedSdkClient = new SDKClient();

IllegalStateException ex = assertThrows(
IllegalStateException.class,
() -> uninitializedSdkClient.execute(SampleAction.INSTANCE, request, new ActionListener<SampleResponse>() {
@Override
public void onResponse(SampleResponse response) {
responseFuture.complete(response);
}

@Override
public void onFailure(Exception e) {
responseFuture.completeExceptionally(e);
}
})
);

assertEquals("SDKClient was not initialized because the Extension does not implement ActionExtension.", ex.getMessage());
}
}

0 comments on commit 980dd82

Please sign in to comment.