From a4376d84e9638c25cd1d24407e30345e1b3efe4f Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Fri, 24 Mar 2023 12:08:01 -0700 Subject: [PATCH] [Extensions] Add DynamicActionRegistry to ActionModule (#6734) * Add dynamic action registry to ActionModule Signed-off-by: Daniel Widdis * Update registration of transport actions Signed-off-by: Daniel Widdis * Generate transport actions dynamically Signed-off-by: Daniel Widdis * Refactor to combine registry internals Signed-off-by: Daniel Widdis * Finally figured out the generics (or lack thereof) Signed-off-by: Daniel Widdis * ExtensionProxyAction is dead! Long live ExtensionAction! Signed-off-by: Daniel Widdis * Simplify ExtensionTransportActionHandler, fix compile issues Signed-off-by: Daniel Widdis * Maybe tests will pass with this commit Signed-off-by: Daniel Widdis * I guess you can't use null as a key in a map Signed-off-by: Daniel Widdis * Lazy test setup, but this should finally work Signed-off-by: Daniel Widdis * Add Tests Signed-off-by: Daniel Widdis * Fix TransportActionRequestFromExtension inheritance Signed-off-by: Daniel Widdis * Fix return type for transport actions from extensions Signed-off-by: Daniel Widdis * Fix ParametersInWrongOrderError and add some preemptive null handling Signed-off-by: Daniel Widdis * NPE is not expected result if params are in correct order Signed-off-by: Daniel Widdis * Remove redundant class and string parsing, add success boolean Signed-off-by: Daniel Widdis * Last fix of params out of order. Working test case! Signed-off-by: Daniel Widdis * Code worked, tests didn't. This is finally done (I think) Signed-off-by: Daniel Widdis * Add more detail to comments on immutable vs. dynamic maps Signed-off-by: Daniel Widdis * Add StreamInput getter to ExtensionActionResponse Signed-off-by: Daniel Widdis * Generalize dynamic action registration Signed-off-by: Daniel Widdis * Comment and naming fixes Signed-off-by: Daniel Widdis * Register method renaming Signed-off-by: Daniel Widdis * Add generic type parameters Signed-off-by: Daniel Widdis * Improve/simplify which parameter types get passed Signed-off-by: Daniel Widdis * Revert removal of ProxyAction and changes to transport and requests Signed-off-by: Daniel Widdis * Wrap ExtensionTransportResponse in a class denoting success Signed-off-by: Daniel Widdis * Remove generic types as they are incompatible with Guice injection Signed-off-by: Daniel Widdis * Fix response handling, it works (again) Signed-off-by: Daniel Widdis * Fix up comments and remove debug logging Signed-off-by: Daniel Widdis --------- Signed-off-by: Daniel Widdis Signed-off-by: Valentin Mitrofanov --- .../org/opensearch/action/ActionModule.java | 92 +++++++++- .../opensearch/client/node/NodeClient.java | 12 +- .../extensions/ExtensionsManager.java | 28 ++- .../extensions/action/ExtensionAction.java | 60 +++++++ .../action/ExtensionProxyTransportAction.java | 50 ++++++ .../action/ExtensionTransportAction.java | 28 +-- .../ExtensionTransportActionsHandler.java | 163 ++++++++++++++---- .../RegisterTransportActionsRequest.java | 2 +- .../action/RemoteExtensionActionResponse.java | 116 +++++++++++++ .../TransportActionRequestFromExtension.java | 10 +- .../TransportActionResponseToExtension.java | 58 ------- .../main/java/org/opensearch/node/Node.java | 14 +- .../opensearch/action/ActionModuleTests.java | 75 ++++++++ .../client/node/NodeClientHeadersTests.java | 9 +- .../extensions/ExtensionsManagerTests.java | 9 +- .../RegisterTransportActionsRequestTests.java | 1 + ...ExtensionTransportActionsHandlerTests.java | 32 +++- ...> RemoteExtensionActionResponseTests.java} | 27 ++- .../indices/RestValidateQueryActionTests.java | 7 +- .../snapshots/SnapshotResiliencyTests.java | 5 +- .../test/client/NoOpNodeClient.java | 5 +- 21 files changed, 645 insertions(+), 158 deletions(-) create mode 100644 server/src/main/java/org/opensearch/extensions/action/ExtensionAction.java create mode 100644 server/src/main/java/org/opensearch/extensions/action/ExtensionProxyTransportAction.java rename server/src/main/java/org/opensearch/extensions/{ => action}/RegisterTransportActionsRequest.java (98%) create mode 100644 server/src/main/java/org/opensearch/extensions/action/RemoteExtensionActionResponse.java delete mode 100644 server/src/main/java/org/opensearch/extensions/action/TransportActionResponseToExtension.java rename server/src/test/java/org/opensearch/extensions/action/{TransportActionResponseToExtensionTests.java => RemoteExtensionActionResponseTests.java} (55%) diff --git a/server/src/main/java/org/opensearch/action/ActionModule.java b/server/src/main/java/org/opensearch/action/ActionModule.java index 2cb11a0586c98..202defad539c4 100644 --- a/server/src/main/java/org/opensearch/action/ActionModule.java +++ b/server/src/main/java/org/opensearch/action/ActionModule.java @@ -282,12 +282,12 @@ import org.opensearch.common.inject.TypeLiteral; import org.opensearch.common.inject.multibindings.MapBinder; import org.opensearch.common.settings.ClusterSettings; -import org.opensearch.extensions.action.ExtensionProxyAction; -import org.opensearch.extensions.action.ExtensionTransportAction; import org.opensearch.common.settings.IndexScopedSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.settings.SettingsFilter; import org.opensearch.common.util.FeatureFlags; +import org.opensearch.extensions.action.ExtensionProxyAction; +import org.opensearch.extensions.action.ExtensionProxyTransportAction; import org.opensearch.index.seqno.RetentionLeaseActions; import org.opensearch.indices.SystemIndices; import org.opensearch.indices.breaker.CircuitBreakerService; @@ -448,6 +448,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.function.Consumer; import java.util.function.Supplier; import java.util.function.UnaryOperator; @@ -455,6 +456,7 @@ import java.util.stream.Stream; import static java.util.Collections.unmodifiableMap; +import static java.util.Objects.requireNonNull; /** * Builds and binds the generic action map, all {@link TransportAction}s, and {@link ActionFilters}. @@ -471,7 +473,17 @@ public class ActionModule extends AbstractModule { private final ClusterSettings clusterSettings; private final SettingsFilter settingsFilter; private final List actionPlugins; + // The unmodifiable map containing OpenSearch and Plugin actions + // This is initialized at node bootstrap and contains same-JVM actions + // It will be wrapped in the Dynamic Action Registry but otherwise + // remains unchanged from its prior purpose, and registered actions + // will remain accessible. private final Map> actions; + // A dynamic action registry which includes the above immutable actions + // and also registers dynamic actions which may be unregistered. Usually + // associated with remote action execution on extensions, possibly in + // a different JVM and possibly on a different server. + private final DynamicActionRegistry dynamicActionRegistry; private final ActionFilters actionFilters; private final AutoCreateIndex autoCreateIndex; private final DestructiveOperations destructiveOperations; @@ -502,6 +514,7 @@ public ActionModule( this.threadPool = threadPool; actions = setupActions(actionPlugins); actionFilters = setupActionFilters(actionPlugins); + dynamicActionRegistry = new DynamicActionRegistry(); autoCreateIndex = new AutoCreateIndex(settings, clusterSettings, indexNameExpressionResolver, systemIndices); destructiveOperations = new DestructiveOperations(settings, clusterSettings); Set headers = Stream.concat( @@ -711,7 +724,7 @@ public void reg if (FeatureFlags.isEnabled(FeatureFlags.EXTENSIONS)) { // ExtensionProxyAction - actions.register(ExtensionProxyAction.INSTANCE, ExtensionTransportAction.class); + actions.register(ExtensionProxyAction.INSTANCE, ExtensionProxyTransportAction.class); } // Decommission actions @@ -954,13 +967,86 @@ protected void configure() { bind(supportAction).asEagerSingleton(); } } + + // register dynamic ActionType -> transportAction Map used by NodeClient + bind(DynamicActionRegistry.class).toInstance(dynamicActionRegistry); } public ActionFilters getActionFilters() { return actionFilters; } + public DynamicActionRegistry getDynamicActionRegistry() { + return dynamicActionRegistry; + } + public RestController getRestController() { return restController; } + + /** + * The DynamicActionRegistry maintains a registry mapping {@link ActionType} instances to {@link TransportAction} instances. + *

+ * This class is modeled after {@link NamedRegistry} but provides both register and unregister capabilities. + * + * @opensearch.internal + */ + public static class DynamicActionRegistry { + // This is the unmodifiable actions map created during node bootstrap, which + // will continue to link ActionType and TransportAction pairs from core and plugin + // action handler registration. + private Map actions = Collections.emptyMap(); + // A dynamic registry to add or remove ActionType / TransportAction pairs + // at times other than node bootstrap. + private final Map, TransportAction> registry = new ConcurrentHashMap<>(); + + /** + * Register the immutable actions in the registry. + * + * @param actions The injected map of {@link ActionType} to {@link TransportAction} + */ + public void registerUnmodifiableActionMap(Map actions) { + this.actions = actions; + } + + /** + * Add a dynamic action to the registry. + * + * @param action The action instance to add + * @param transportAction The corresponding instance of transportAction to execute + */ + public void registerDynamicAction(ActionType action, TransportAction transportAction) { + requireNonNull(action, "action is required"); + requireNonNull(transportAction, "transportAction is required"); + if (actions.containsKey(action) || registry.putIfAbsent(action, transportAction) != null) { + throw new IllegalArgumentException("action [" + action.name() + "] already registered"); + } + } + + /** + * Remove a dynamic action from the registry. + * + * @param action The action to remove + */ + public void unregisterDynamicAction(ActionType action) { + requireNonNull(action, "action is required"); + if (registry.remove(action) == null) { + throw new IllegalArgumentException("action [" + action.name() + "] was not registered"); + } + } + + /** + * Gets the {@link TransportAction} instance corresponding to the {@link ActionType} instance. + * + * @param action The {@link ActionType}. + * @return the corresponding {@link TransportAction} if it is registered, null otherwise. + */ + @SuppressWarnings("unchecked") + public TransportAction get(ActionType action) { + if (actions.containsKey(action)) { + return actions.get(action); + } + return registry.get(action); + } + } } diff --git a/server/src/main/java/org/opensearch/client/node/NodeClient.java b/server/src/main/java/org/opensearch/client/node/NodeClient.java index 56cb7c406744a..3341bfe326990 100644 --- a/server/src/main/java/org/opensearch/client/node/NodeClient.java +++ b/server/src/main/java/org/opensearch/client/node/NodeClient.java @@ -34,6 +34,7 @@ import org.opensearch.action.ActionType; import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionModule.DynamicActionRegistry; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionResponse; import org.opensearch.action.support.TransportAction; @@ -47,7 +48,6 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.RemoteClusterService; -import java.util.Map; import java.util.function.Supplier; /** @@ -57,7 +57,7 @@ */ public class NodeClient extends AbstractClient { - private Map actions; + private DynamicActionRegistry actionRegistry; /** * The id of the local {@link DiscoveryNode}. Useful for generating task ids from tasks returned by * {@link #executeLocally(ActionType, ActionRequest, TaskListener)}. @@ -71,12 +71,12 @@ public NodeClient(Settings settings, ThreadPool threadPool) { } public void initialize( - Map actions, + DynamicActionRegistry actionRegistry, Supplier localNodeId, RemoteClusterService remoteClusterService, NamedWriteableRegistry namedWriteableRegistry ) { - this.actions = actions; + this.actionRegistry = actionRegistry; this.localNodeId = localNodeId; this.remoteClusterService = remoteClusterService; this.namedWriteableRegistry = namedWriteableRegistry; @@ -137,10 +137,10 @@ public String getLocalNodeId() { private TransportAction transportAction( ActionType action ) { - if (actions == null) { + if (actionRegistry == null) { throw new IllegalStateException("NodeClient has not been initialized"); } - TransportAction transportAction = actions.get(action); + TransportAction transportAction = (TransportAction) actionRegistry.get(action); if (transportAction == null) { throw new IllegalStateException("failed to find action [" + action + "] to execute"); } diff --git a/server/src/main/java/org/opensearch/extensions/ExtensionsManager.java b/server/src/main/java/org/opensearch/extensions/ExtensionsManager.java index 4f852ca944966..ccc1bdb620f31 100644 --- a/server/src/main/java/org/opensearch/extensions/ExtensionsManager.java +++ b/server/src/main/java/org/opensearch/extensions/ExtensionsManager.java @@ -30,6 +30,7 @@ import org.apache.logging.log4j.message.ParameterizedMessage; import org.opensearch.OpenSearchException; import org.opensearch.Version; +import org.opensearch.action.ActionModule; import org.opensearch.action.admin.cluster.state.ClusterStateResponse; import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.ClusterSettingsResponse; @@ -47,6 +48,8 @@ import org.opensearch.extensions.action.ExtensionActionRequest; import org.opensearch.extensions.action.ExtensionActionResponse; import org.opensearch.extensions.action.ExtensionTransportActionsHandler; +import org.opensearch.extensions.action.RegisterTransportActionsRequest; +import org.opensearch.extensions.action.RemoteExtensionActionResponse; import org.opensearch.extensions.action.TransportActionRequestFromExtension; import org.opensearch.extensions.rest.RegisterRestActionsRequest; import org.opensearch.extensions.rest.RestActionsRequestHandler; @@ -58,7 +61,6 @@ import org.opensearch.index.IndicesModuleResponse; import org.opensearch.index.shard.IndexEventListener; import org.opensearch.indices.cluster.IndicesClusterStateService; -import org.opensearch.rest.RestController; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.ConnectTransportException; import org.opensearch.transport.TransportException; @@ -89,6 +91,7 @@ public class ExtensionsManager { public static final String REQUEST_OPENSEARCH_PARSE_NAMED_WRITEABLE = "internal:discovery/parsenamedwriteable"; public static final String REQUEST_REST_EXECUTE_ON_EXTENSION_ACTION = "internal:extensions/restexecuteonextensiontaction"; public static final String REQUEST_EXTENSION_HANDLE_TRANSPORT_ACTION = "internal:extensions/handle-transportaction"; + public static final String REQUEST_EXTENSION_HANDLE_REMOTE_TRANSPORT_ACTION = "internal:extensions/handle-remote-transportaction"; public static final String TRANSPORT_ACTION_REQUEST_FROM_EXTENSION = "internal:extensions/request-transportaction-from-extension"; public static final int EXTENSION_REQUEST_WAIT_TIMEOUT = 10; @@ -166,7 +169,7 @@ public ExtensionsManager(Settings settings, Path extensionsPath) throws IOExcept * Initializes the {@link RestActionsRequestHandler}, {@link TransportService}, {@link ClusterService} and environment settings. This is called during Node bootstrap. * Lists/maps of extensions have already been initialized but not yet populated. * - * @param restController The RestController on which to register Rest Actions. + * @param actionModule The ActionModule with the RestController and DynamicActionModule * @param settingsModule The module that binds the provided settings to interface. * @param transportService The Node's transport service. * @param clusterService The Node's cluster service. @@ -174,14 +177,14 @@ public ExtensionsManager(Settings settings, Path extensionsPath) throws IOExcept * @param client The client used to make transport requests */ public void initializeServicesAndRestHandler( - RestController restController, + ActionModule actionModule, SettingsModule settingsModule, TransportService transportService, ClusterService clusterService, Settings initialEnvironmentSettings, NodeClient client ) { - this.restActionsRequestHandler = new RestActionsRequestHandler(restController, extensionIdMap, transportService); + this.restActionsRequestHandler = new RestActionsRequestHandler(actionModule.getRestController(), extensionIdMap, transportService); this.customSettingsRequestHandler = new CustomSettingsRequestHandler(settingsModule); this.transportService = transportService; this.clusterService = clusterService; @@ -192,7 +195,13 @@ public void initializeServicesAndRestHandler( REQUEST_EXTENSION_UPDATE_SETTINGS ); this.client = client; - this.extensionTransportActionsHandler = new ExtensionTransportActionsHandler(extensionIdMap, transportService, client); + this.extensionTransportActionsHandler = new ExtensionTransportActionsHandler( + extensionIdMap, + transportService, + client, + actionModule, + this + ); registerRequestHandler(); } @@ -201,6 +210,15 @@ public void initializeServicesAndRestHandler( * * @param request which was sent by an extension. */ + public RemoteExtensionActionResponse handleRemoteTransportRequest(ExtensionActionRequest request) throws Exception { + return extensionTransportActionsHandler.sendRemoteTransportRequestToExtension(request); + } + + /** + * Handles Transport Request from {@link org.opensearch.extensions.action.ExtensionTransportAction} which was invoked by OpenSearch or a plugin + * + * @param request which was sent by an extension. + */ public ExtensionActionResponse handleTransportRequest(ExtensionActionRequest request) throws Exception { return extensionTransportActionsHandler.sendTransportRequestToExtension(request); } diff --git a/server/src/main/java/org/opensearch/extensions/action/ExtensionAction.java b/server/src/main/java/org/opensearch/extensions/action/ExtensionAction.java new file mode 100644 index 0000000000000..658c114d73c1a --- /dev/null +++ b/server/src/main/java/org/opensearch/extensions/action/ExtensionAction.java @@ -0,0 +1,60 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions.action; + +import org.opensearch.action.ActionType; + +import java.util.Objects; + +/** + * An {@link ActionType} to be used in extension action transport handling. + * + * @opensearch.internal + */ +public class ExtensionAction extends ActionType { + + private final String uniqueId; + + /** + * Create an instance of this action to register in the dynamic actions map. + * + * @param uniqueId The uniqueId of the extension which will run this action. + * @param name The fully qualified class name of the extension's action to execute. + */ + public ExtensionAction(String uniqueId, String name) { + super(name, RemoteExtensionActionResponse::new); + this.uniqueId = uniqueId; + } + + /** + * Gets the uniqueId of the extension which will run this action. + * + * @return the uniqueId + */ + public String uniqueId() { + return this.uniqueId; + } + + @Override + public int hashCode() { + final int prime = 31; + int result = super.hashCode(); + result = prime * result + Objects.hash(uniqueId); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (!super.equals(obj)) return false; + if (getClass() != obj.getClass()) return false; + ExtensionAction other = (ExtensionAction) obj; + return Objects.equals(uniqueId, other.uniqueId); + } +} diff --git a/server/src/main/java/org/opensearch/extensions/action/ExtensionProxyTransportAction.java b/server/src/main/java/org/opensearch/extensions/action/ExtensionProxyTransportAction.java new file mode 100644 index 0000000000000..364965dc582e6 --- /dev/null +++ b/server/src/main/java/org/opensearch/extensions/action/ExtensionProxyTransportAction.java @@ -0,0 +1,50 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions.action; + +import org.opensearch.action.ActionListener; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.extensions.ExtensionsManager; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +/** + * A proxy transport action used to proxy a transport request from OpenSearch or a plugin to execute on an extension + * + * @opensearch.internal + */ +public class ExtensionProxyTransportAction extends HandledTransportAction { + + private final ExtensionsManager extensionsManager; + + @Inject + public ExtensionProxyTransportAction( + Settings settings, + TransportService transportService, + ActionFilters actionFilters, + ClusterService clusterService, + ExtensionsManager extensionsManager + ) { + super(ExtensionProxyAction.NAME, transportService, actionFilters, ExtensionActionRequest::new); + this.extensionsManager = extensionsManager; + } + + @Override + protected void doExecute(Task task, ExtensionActionRequest request, ActionListener listener) { + try { + listener.onResponse(extensionsManager.handleTransportRequest(request)); + } catch (Exception e) { + listener.onFailure(e); + } + } +} diff --git a/server/src/main/java/org/opensearch/extensions/action/ExtensionTransportAction.java b/server/src/main/java/org/opensearch/extensions/action/ExtensionTransportAction.java index 5976db78002eb..4b0b9725e50ae 100644 --- a/server/src/main/java/org/opensearch/extensions/action/ExtensionTransportAction.java +++ b/server/src/main/java/org/opensearch/extensions/action/ExtensionTransportAction.java @@ -10,44 +10,34 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.inject.Inject; -import org.opensearch.common.settings.Settings; +import org.opensearch.action.support.TransportAction; import org.opensearch.extensions.ExtensionsManager; -import org.opensearch.node.Node; import org.opensearch.tasks.Task; -import org.opensearch.transport.TransportService; +import org.opensearch.tasks.TaskManager; /** - * The main proxy transport action used to proxy a transport request from extension to another extension + * A proxy transport action used to proxy a transport request from an extension to execute on another extension * * @opensearch.internal */ -public class ExtensionTransportAction extends HandledTransportAction { +public class ExtensionTransportAction extends TransportAction { - private final String nodeName; - private final ClusterService clusterService; private final ExtensionsManager extensionsManager; - @Inject public ExtensionTransportAction( - Settings settings, - TransportService transportService, + String actionName, ActionFilters actionFilters, - ClusterService clusterService, + TaskManager taskManager, ExtensionsManager extensionsManager ) { - super(ExtensionProxyAction.NAME, transportService, actionFilters, ExtensionActionRequest::new); - this.nodeName = Node.NODE_NAME_SETTING.get(settings); - this.clusterService = clusterService; + super(actionName, actionFilters, taskManager); this.extensionsManager = extensionsManager; } @Override - protected void doExecute(Task task, ExtensionActionRequest request, ActionListener listener) { + protected void doExecute(Task task, ExtensionActionRequest request, ActionListener listener) { try { - listener.onResponse(extensionsManager.handleTransportRequest(request)); + listener.onResponse(extensionsManager.handleRemoteTransportRequest(request)); } catch (Exception e) { listener.onFailure(e); } diff --git a/server/src/main/java/org/opensearch/extensions/action/ExtensionTransportActionsHandler.java b/server/src/main/java/org/opensearch/extensions/action/ExtensionTransportActionsHandler.java index 1f2b58c2bd524..3fba76b7d3c59 100644 --- a/server/src/main/java/org/opensearch/extensions/action/ExtensionTransportActionsHandler.java +++ b/server/src/main/java/org/opensearch/extensions/action/ExtensionTransportActionsHandler.java @@ -11,12 +11,14 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionModule; +import org.opensearch.action.ActionModule.DynamicActionRegistry; +import org.opensearch.action.support.ActionFilters; import org.opensearch.client.node.NodeClient; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.extensions.DiscoveryExtensionNode; import org.opensearch.extensions.AcknowledgedResponse; import org.opensearch.extensions.ExtensionsManager; -import org.opensearch.extensions.RegisterTransportActionsRequest; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.ActionNotFoundTransportException; import org.opensearch.transport.TransportException; @@ -25,11 +27,10 @@ import org.opensearch.transport.TransportService; import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.util.HashMap; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; @@ -40,44 +41,62 @@ */ public class ExtensionTransportActionsHandler { private static final Logger logger = LogManager.getLogger(ExtensionTransportActionsHandler.class); - private Map actionsMap; + // Map of action name to Extension unique ID, populated locally + private final Map actionToIdMap = new ConcurrentHashMap<>(); + // Map of Extension unique ID to Extension Node, populated in Extensions Manager private final Map extensionIdMap; private final TransportService transportService; private final NodeClient client; + private final ActionFilters actionFilters; + private final DynamicActionRegistry dynamicActionRegistry; + private final ExtensionsManager extensionsManager; public ExtensionTransportActionsHandler( Map extensionIdMap, TransportService transportService, - NodeClient client + NodeClient client, + ActionModule actionModule, + ExtensionsManager extensionsManager ) { - this.actionsMap = new HashMap<>(); this.extensionIdMap = extensionIdMap; this.transportService = transportService; this.client = client; + this.actionFilters = actionModule.getActionFilters(); + this.dynamicActionRegistry = actionModule.getDynamicActionRegistry(); + this.extensionsManager = extensionsManager; } /** * Method to register actions for extensions. * * @param action to be registered. - * @param extension for which action is being registered. + * @param uniqueId id of extension for which action is being registered. * @throws IllegalArgumentException when action being registered already is registered. */ - void registerAction(String action, DiscoveryExtensionNode extension) throws IllegalArgumentException { - if (actionsMap.containsKey(action)) { - throw new IllegalArgumentException("The " + action + " you are trying to register is already registered"); + void registerAction(String action, String uniqueId) throws IllegalArgumentException { + // Register the action in this handler so it knows which extension owns it + if (actionToIdMap.putIfAbsent(action, uniqueId) != null) { + throw new IllegalArgumentException("The action [" + action + "] you are trying to register is already registered"); } - actionsMap.putIfAbsent(action, extension); + // Register the action in the action module's dynamic actions map + dynamicActionRegistry.registerDynamicAction( + new ExtensionAction(uniqueId, action), + new ExtensionTransportAction(action, actionFilters, transportService.getTaskManager(), extensionsManager) + ); } /** * Method to get extension for a given action. * * @param action for which to get the registered extension. - * @return the extension. + * @return the extension or null if not found */ public DiscoveryExtensionNode getExtension(String action) { - return actionsMap.get(action); + String uniqueId = actionToIdMap.get(action); + if (uniqueId == null) { + throw new ActionNotFoundTransportException(action); + } + return extensionIdMap.get(uniqueId); } /** @@ -87,17 +106,12 @@ public DiscoveryExtensionNode getExtension(String action) { * @return A {@link AcknowledgedResponse} indicating success. */ public TransportResponse handleRegisterTransportActionsRequest(RegisterTransportActionsRequest transportActionsRequest) { - /* - * We are proxying the transport Actions through ExtensionProxyAction, so we really dont need to register dynamic actions for now. - */ - logger.debug("Register Transport Actions request recieved {}", transportActionsRequest); - DiscoveryExtensionNode extension = extensionIdMap.get(transportActionsRequest.getUniqueId()); try { for (String action : transportActionsRequest.getTransportActions()) { - registerAction(action, extension); + registerAction(action, transportActionsRequest.getUniqueId()); } } catch (Exception e) { - logger.error("Could not register Transport Action " + e); + logger.error("Could not register Transport Action: " + e.getMessage()); return new AcknowledgedResponse(false); } return new AcknowledgedResponse(true); @@ -110,16 +124,37 @@ public TransportResponse handleRegisterTransportActionsRequest(RegisterTransport * @return {@link TransportResponse} which is sent back to the transport action invoker. * @throws InterruptedException when message transport fails. */ - public TransportResponse handleTransportActionRequestFromExtension(TransportActionRequestFromExtension request) throws Exception { - DiscoveryExtensionNode extension = extensionIdMap.get(request.getUniqueId()); - final CompletableFuture inProgressFuture = new CompletableFuture<>(); - final TransportActionResponseToExtension response = new TransportActionResponseToExtension(new byte[0]); + public RemoteExtensionActionResponse handleTransportActionRequestFromExtension(TransportActionRequestFromExtension request) + throws Exception { + String actionName = request.getAction(); + String uniqueId = actionToIdMap.get(actionName); + final RemoteExtensionActionResponse response = new RemoteExtensionActionResponse(false, new byte[0]); + // Fail fast if uniqueId is null + if (uniqueId == null) { + response.setResponseBytesAsString("Request failed: action [" + actionName + "] is not registered for any extension."); + return response; + } + ExtensionAction extensionAction = new ExtensionAction(uniqueId, actionName); + // Validate that this action has been registered + if (dynamicActionRegistry.get(extensionAction) == null) { + response.setResponseBytesAsString( + "Request failed: action [" + actionName + "] is not registered for extension [" + uniqueId + "]." + ); + return response; + } + DiscoveryExtensionNode extension = extensionIdMap.get(uniqueId); + if (extension == null) { + response.setResponseBytesAsString("Request failed: extension [" + uniqueId + "] can not be reached."); + return response; + } + final CompletableFuture inProgressFuture = new CompletableFuture<>(); client.execute( - ExtensionProxyAction.INSTANCE, + extensionAction, new ExtensionActionRequest(request.getAction(), request.getRequestBytes()), - new ActionListener() { + new ActionListener() { @Override - public void onResponse(ExtensionActionResponse actionResponse) { + public void onResponse(RemoteExtensionActionResponse actionResponse) { + response.setSuccess(actionResponse.isSuccess()); response.setResponseBytes(actionResponse.getResponseBytes()); inProgressFuture.complete(actionResponse); } @@ -127,8 +162,7 @@ public void onResponse(ExtensionActionResponse actionResponse) { @Override public void onFailure(Exception exp) { logger.debug("Transport request failed", exp); - byte[] responseBytes = ("Request failed: " + exp.getMessage()).getBytes(StandardCharsets.UTF_8); - response.setResponseBytes(responseBytes); + response.setResponseBytesAsString("Request failed: " + exp.getMessage()); inProgressFuture.completeExceptionally(exp); } } @@ -158,10 +192,7 @@ public void onFailure(Exception exp) { * @throws InterruptedException when message transport fails. */ public ExtensionActionResponse sendTransportRequestToExtension(ExtensionActionRequest request) throws Exception { - DiscoveryExtensionNode extension = actionsMap.get(request.getAction()); - if (extension == null) { - throw new ActionNotFoundTransportException(request.getAction()); - } + DiscoveryExtensionNode extension = getExtension(request.getAction()); final CompletableFuture inProgressFuture = new CompletableFuture<>(); final ExtensionActionResponse extensionActionResponse = new ExtensionActionResponse(new byte[0]); final TransportResponseHandler extensionActionResponseTransportResponseHandler = @@ -181,8 +212,6 @@ public void handleResponse(ExtensionActionResponse response) { @Override public void handleException(TransportException exp) { logger.debug("Transport request failed", exp); - byte[] responseBytes = ("Request failed: " + exp.getMessage()).getBytes(StandardCharsets.UTF_8); - extensionActionResponse.setResponseBytes(responseBytes); inProgressFuture.completeExceptionally(exp); } @@ -217,4 +246,68 @@ public String executor() { } return extensionActionResponse; } + + /** + * Method to send transport action request from a remote extension to another extension to handle. + * + * @param request to extension to handle transport request. + * @return {@link RemoteExtensionActionResponse} which encapsulates the transport response from the extension and its success. + */ + public RemoteExtensionActionResponse sendRemoteTransportRequestToExtension(ExtensionActionRequest request) { + DiscoveryExtensionNode extension = getExtension(request.getAction()); + final CompletableFuture inProgressFuture = new CompletableFuture<>(); + final RemoteExtensionActionResponse extensionActionResponse = new RemoteExtensionActionResponse(false, new byte[0]); + final TransportResponseHandler extensionActionResponseTransportResponseHandler = + new TransportResponseHandler() { + + @Override + public RemoteExtensionActionResponse read(StreamInput in) throws IOException { + return new RemoteExtensionActionResponse(in); + } + + @Override + public void handleResponse(RemoteExtensionActionResponse response) { + extensionActionResponse.setSuccess(response.isSuccess()); + extensionActionResponse.setResponseBytes(response.getResponseBytes()); + inProgressFuture.complete(response); + } + + @Override + public void handleException(TransportException exp) { + logger.debug("Transport request failed", exp); + extensionActionResponse.setResponseBytesAsString("Request failed: " + exp.getMessage()); + inProgressFuture.completeExceptionally(exp); + } + + @Override + public String executor() { + return ThreadPool.Names.GENERIC; + } + }; + try { + transportService.sendRequest( + extension, + ExtensionsManager.REQUEST_EXTENSION_HANDLE_REMOTE_TRANSPORT_ACTION, + new ExtensionHandleTransportRequest(request.getAction(), request.getRequestBytes()), + extensionActionResponseTransportResponseHandler + ); + } catch (Exception e) { + logger.info("Failed to send transport action to extension " + extension.getName(), e); + } + try { + inProgressFuture.orTimeout(ExtensionsManager.EXTENSION_REQUEST_WAIT_TIMEOUT, TimeUnit.SECONDS).join(); + } catch (CompletionException e) { + if (e.getCause() instanceof TimeoutException) { + logger.info("No response from extension to request."); + } + if (e.getCause() instanceof RuntimeException) { + throw (RuntimeException) e.getCause(); + } else if (e.getCause() instanceof Error) { + throw (Error) e.getCause(); + } else { + throw new RuntimeException(e.getCause()); + } + } + return extensionActionResponse; + } } diff --git a/server/src/main/java/org/opensearch/extensions/RegisterTransportActionsRequest.java b/server/src/main/java/org/opensearch/extensions/action/RegisterTransportActionsRequest.java similarity index 98% rename from server/src/main/java/org/opensearch/extensions/RegisterTransportActionsRequest.java rename to server/src/main/java/org/opensearch/extensions/action/RegisterTransportActionsRequest.java index 94b15e2192722..be711ee69dea6 100644 --- a/server/src/main/java/org/opensearch/extensions/RegisterTransportActionsRequest.java +++ b/server/src/main/java/org/opensearch/extensions/action/RegisterTransportActionsRequest.java @@ -6,7 +6,7 @@ * compatible open source license. */ -package org.opensearch.extensions; +package org.opensearch.extensions.action; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; diff --git a/server/src/main/java/org/opensearch/extensions/action/RemoteExtensionActionResponse.java b/server/src/main/java/org/opensearch/extensions/action/RemoteExtensionActionResponse.java new file mode 100644 index 0000000000000..adc269a037231 --- /dev/null +++ b/server/src/main/java/org/opensearch/extensions/action/RemoteExtensionActionResponse.java @@ -0,0 +1,116 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions.action; + +import org.opensearch.action.ActionResponse; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; + +/** + * This class encapsulates the {@link ExtensionActionResponse} from an extension, adding a field denoting success + * + * @opensearch.internal + */ +public class RemoteExtensionActionResponse extends ActionResponse { + /** + * Indicates whether the response was successful. If false, responseBytes will include an error message. + */ + private boolean success; + /** + * responseBytes is the raw bytes being transported between extensions. + */ + private ExtensionActionResponse response; + + /** + * RemoteExtensionActionResponse constructor. + * + * @param success Whether the response was successful. + * @param responseBytes is the raw bytes being transported between extensions. + */ + public RemoteExtensionActionResponse(boolean success, byte[] responseBytes) { + this.success = success; + this.response = new ExtensionActionResponse(responseBytes); + } + + /** + * RemoteExtensionActionResponse constructor from an {@link ExtensionActionResponse}. + * + * @param response an ExtensionActionResponse in which the first byte denotes success or failure + */ + public RemoteExtensionActionResponse(ExtensionActionResponse response) { + byte[] combinedBytes = response.getResponseBytes(); + this.success = combinedBytes[0] != 0; + byte[] responseBytes = new byte[combinedBytes.length - 1]; + System.arraycopy(combinedBytes, 1, responseBytes, 0, responseBytes.length); + this.response = new ExtensionActionResponse(responseBytes); + } + + /** + * RemoteExtensionActionResponse constructor from {@link StreamInput}. + * + * @param in bytes stream input used to de-serialize the message. + * @throws IOException when message de-serialization fails. + */ + public RemoteExtensionActionResponse(StreamInput in) throws IOException { + this.success = in.readBoolean(); + this.response = new ExtensionActionResponse(in); + } + + public boolean isSuccess() { + return success; + } + + public void setSuccess(boolean success) { + this.success = success; + } + + public byte[] getResponseBytes() { + return response.getResponseBytes(); + } + + public void setResponseBytes(byte[] responseBytes) { + this.response = new ExtensionActionResponse(responseBytes); + } + + /** + * Gets the Response bytes as a {@link StreamInput} + * + * @return A StreamInput representation of the response bytes + */ + public StreamInput getResponseBytesAsStream() { + return StreamInput.wrap(response.getResponseBytes()); + } + + /** + * Gets the Response bytes as a UTF-8 string + * + * @return A string representation of the response bytes + */ + public String getResponseBytesAsString() { + return new String(response.getResponseBytes(), StandardCharsets.UTF_8); + } + + /** + * Sets the Response bytes from a UTF-8 string + * + * @param response The response to convert to bytes + */ + public void setResponseBytesAsString(String response) { + this.response = new ExtensionActionResponse(response.getBytes(StandardCharsets.UTF_8)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeBoolean(success); + response.writeTo(out); + } +} diff --git a/server/src/main/java/org/opensearch/extensions/action/TransportActionRequestFromExtension.java b/server/src/main/java/org/opensearch/extensions/action/TransportActionRequestFromExtension.java index df494297559b3..1f90d3224bb82 100644 --- a/server/src/main/java/org/opensearch/extensions/action/TransportActionRequestFromExtension.java +++ b/server/src/main/java/org/opensearch/extensions/action/TransportActionRequestFromExtension.java @@ -8,9 +8,10 @@ package org.opensearch.extensions.action; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.transport.TransportRequest; import java.io.IOException; import java.util.Objects; @@ -20,7 +21,7 @@ * * @opensearch.api */ -public class TransportActionRequestFromExtension extends TransportRequest { +public class TransportActionRequestFromExtension extends ActionRequest { /** * action is the transport action intended to be invoked which is registered by an extension via {@link ExtensionTransportActionsHandler}. */ @@ -80,6 +81,11 @@ public String getUniqueId() { return this.uniqueId; } + @Override + public ActionRequestValidationException validate() { + return null; + } + @Override public String toString() { return "TransportActionRequestFromExtension{action=" + action + ", requestBytes=" + requestBytes + ", uniqueId=" + uniqueId + "}"; diff --git a/server/src/main/java/org/opensearch/extensions/action/TransportActionResponseToExtension.java b/server/src/main/java/org/opensearch/extensions/action/TransportActionResponseToExtension.java deleted file mode 100644 index 2913402bcd5e1..0000000000000 --- a/server/src/main/java/org/opensearch/extensions/action/TransportActionResponseToExtension.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.extensions.action; - -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.transport.TransportResponse; - -import java.io.IOException; - -/** - * This class encapsulates transport response to extension. - * - * @opensearch.api - */ -public class TransportActionResponseToExtension extends TransportResponse { - /** - * responseBytes is the raw bytes being transported between extensions. - */ - private byte[] responseBytes; - - /** - * TransportActionResponseToExtension constructor. - * - * @param responseBytes is the raw bytes being transported between extensions. - */ - public TransportActionResponseToExtension(byte[] responseBytes) { - this.responseBytes = responseBytes; - } - - /** - * TransportActionResponseToExtension constructor from {@link StreamInput} - * @param in bytes stream input used to de-serialize the message. - * @throws IOException when message de-serialization fails. - */ - public TransportActionResponseToExtension(StreamInput in) throws IOException { - this.responseBytes = in.readByteArray(); - } - - public void setResponseBytes(byte[] responseBytes) { - this.responseBytes = responseBytes; - } - - public byte[] getResponseBytes() { - return responseBytes; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeByteArray(responseBytes); - } -} diff --git a/server/src/main/java/org/opensearch/node/Node.java b/server/src/main/java/org/opensearch/node/Node.java index 6b7d810e9f0d9..7fc4412e53d3e 100644 --- a/server/src/main/java/org/opensearch/node/Node.java +++ b/server/src/main/java/org/opensearch/node/Node.java @@ -58,6 +58,7 @@ import org.opensearch.OpenSearchTimeoutException; import org.opensearch.Version; import org.opensearch.action.ActionModule; +import org.opensearch.action.ActionModule.DynamicActionRegistry; import org.opensearch.action.ActionType; import org.opensearch.action.admin.cluster.snapshots.status.TransportNodesSnapshotsStatus; import org.opensearch.action.search.SearchExecutionStatsCollector; @@ -842,7 +843,7 @@ protected Node( ); if (FeatureFlags.isEnabled(FeatureFlags.EXTENSIONS)) { this.extensionsManager.initializeServicesAndRestHandler( - restController, + actionModule, settingsModule, transportService, clusterService, @@ -1112,8 +1113,15 @@ protected Node( resourcesToClose.addAll(pluginLifecycleComponents); resourcesToClose.add(injector.getInstance(PeerRecoverySourceService.class)); this.pluginLifecycleComponents = Collections.unmodifiableList(pluginLifecycleComponents); - client.initialize(injector.getInstance(new Key>() { - }), () -> clusterService.localNode().getId(), transportService.getRemoteClusterService(), namedWriteableRegistry); + DynamicActionRegistry dynamicActionRegistry = actionModule.getDynamicActionRegistry(); + dynamicActionRegistry.registerUnmodifiableActionMap(injector.getInstance(new Key>() { + })); + client.initialize( + dynamicActionRegistry, + () -> clusterService.localNode().getId(), + transportService.getRemoteClusterService(), + namedWriteableRegistry + ); this.namedWriteableRegistry = namedWriteableRegistry; logger.debug("initializing HTTP handlers ..."); diff --git a/server/src/test/java/org/opensearch/action/ActionModuleTests.java b/server/src/test/java/org/opensearch/action/ActionModuleTests.java index 3193a8d953763..94ebf0fcf8816 100644 --- a/server/src/test/java/org/opensearch/action/ActionModuleTests.java +++ b/server/src/test/java/org/opensearch/action/ActionModuleTests.java @@ -32,6 +32,7 @@ package org.opensearch.action; +import org.opensearch.action.ActionModule.DynamicActionRegistry; import org.opensearch.action.main.MainAction; import org.opensearch.action.main.TransportMainAction; import org.opensearch.action.support.ActionFilters; @@ -39,12 +40,16 @@ import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.Writeable.Reader; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.IndexScopedSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.settings.SettingsFilter; import org.opensearch.common.settings.SettingsModule; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.extensions.action.ExtensionAction; +import org.opensearch.extensions.action.ExtensionTransportAction; import org.opensearch.plugins.ActionPlugin; import org.opensearch.plugins.ActionPlugin.ActionHandler; @@ -62,7 +67,9 @@ import org.opensearch.usage.UsageService; import java.io.IOException; +import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.function.Supplier; import static java.util.Collections.emptyList; @@ -262,4 +269,72 @@ public List routes() { threadPool.shutdown(); } } + + public void testDynamicActionRegistry() { + ActionFilters emptyFilters = new ActionFilters(Collections.emptySet()); + Map testMap = Map.of(TestAction.INSTANCE, new TestTransportAction("test-action", emptyFilters, null)); + + DynamicActionRegistry dynamicActionRegistry = new DynamicActionRegistry(); + dynamicActionRegistry.registerUnmodifiableActionMap(testMap); + + // Should contain the immutable map entry + assertNotNull(dynamicActionRegistry.get(TestAction.INSTANCE)); + // Should not contain anything not added + assertNull(dynamicActionRegistry.get(MainAction.INSTANCE)); + + // ExtensionsAction not yet registered + ExtensionAction testExtensionAction = new ExtensionAction("extensionId", "actionName"); + ExtensionTransportAction testExtensionTransportAction = new ExtensionTransportAction("test-action", emptyFilters, null, null); + assertNull(dynamicActionRegistry.get(testExtensionAction)); + + // Register an extension action + // Should insert without problem + try { + dynamicActionRegistry.registerDynamicAction(testExtensionAction, testExtensionTransportAction); + } catch (Exception e) { + fail("Should not have thrown exception registering action: " + e); + } + assertEquals(testExtensionTransportAction, dynamicActionRegistry.get(testExtensionAction)); + + // Should fail inserting twice + IllegalArgumentException ex = assertThrows( + IllegalArgumentException.class, + () -> dynamicActionRegistry.registerDynamicAction(testExtensionAction, testExtensionTransportAction) + ); + assertEquals("action [actionName] already registered", ex.getMessage()); + // Should remove without problem + try { + dynamicActionRegistry.unregisterDynamicAction(testExtensionAction); + } catch (Exception e) { + fail("Should not have thrown exception unregistering action: " + e); + } + // Should have been removed + assertNull(dynamicActionRegistry.get(testExtensionAction)); + + // Should fail removing twice + ex = assertThrows(IllegalArgumentException.class, () -> dynamicActionRegistry.unregisterDynamicAction(testExtensionAction)); + assertEquals("action [actionName] was not registered", ex.getMessage()); + } + + private static final class TestAction extends ActionType { + public static final TestAction INSTANCE = new TestAction(); + + private TestAction() { + super("test-action", new Reader() { + @Override + public ActionResponse read(StreamInput in) throws IOException { + return null; + } + }); + } + }; + + private static final class TestTransportAction extends TransportAction { + protected TestTransportAction(String actionName, ActionFilters actionFilters, TaskManager taskManager) { + super(actionName, actionFilters, taskManager); + } + + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener listener) {} + } } diff --git a/server/src/test/java/org/opensearch/client/node/NodeClientHeadersTests.java b/server/src/test/java/org/opensearch/client/node/NodeClientHeadersTests.java index cb9e3a6a19388..1f63dba4457a9 100644 --- a/server/src/test/java/org/opensearch/client/node/NodeClientHeadersTests.java +++ b/server/src/test/java/org/opensearch/client/node/NodeClientHeadersTests.java @@ -33,6 +33,7 @@ package org.opensearch.client.node; import org.opensearch.action.ActionType; +import org.opensearch.action.ActionModule.DynamicActionRegistry; import org.opensearch.action.ActionListener; import org.opensearch.action.ActionRequest; import org.opensearch.action.support.ActionFilters; @@ -57,14 +58,16 @@ protected Client buildClient(Settings headersSettings, ActionType[] testedAction Settings settings = HEADER_SETTINGS; Actions actions = new Actions(settings, threadPool, testedActions); NodeClient client = new NodeClient(settings, threadPool); - client.initialize(actions, () -> "test", null, new NamedWriteableRegistry(Collections.emptyList())); + DynamicActionRegistry dynamicActionRegistry = new DynamicActionRegistry(); + dynamicActionRegistry.registerUnmodifiableActionMap(actions); + client.initialize(dynamicActionRegistry, () -> "test", null, new NamedWriteableRegistry(Collections.emptyList())); return client; } private static class Actions extends HashMap { - private Actions(Settings settings, ThreadPool threadPool, ActionType[] actions) { - for (ActionType action : actions) { + private Actions(Settings settings, ThreadPool threadPool, ActionType[] actions) { + for (ActionType action : actions) { put(action, new InternalTransportAction(settings, action.name(), threadPool)); } } diff --git a/server/src/test/java/org/opensearch/extensions/ExtensionsManagerTests.java b/server/src/test/java/org/opensearch/extensions/ExtensionsManagerTests.java index a027b09a42da9..bbc115307e26f 100644 --- a/server/src/test/java/org/opensearch/extensions/ExtensionsManagerTests.java +++ b/server/src/test/java/org/opensearch/extensions/ExtensionsManagerTests.java @@ -15,6 +15,7 @@ import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import static org.mockito.Mockito.times; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.mock; @@ -40,6 +41,7 @@ import org.junit.After; import org.junit.Before; import org.opensearch.Version; +import org.opensearch.action.ActionModule; import org.opensearch.action.admin.cluster.state.ClusterStateResponse; import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.ClusterSettingsResponse; @@ -94,6 +96,7 @@ public class ExtensionsManagerTests extends OpenSearchTestCase { private FeatureFlagSetter featureFlagSetter; private TransportService transportService; + private ActionModule actionModule; private RestController restController; private SettingsModule settingsModule; private ClusterService clusterService; @@ -157,6 +160,7 @@ public void setup() throws Exception { null, Collections.emptySet() ); + actionModule = mock(ActionModule.class); restController = new RestController( emptySet(), null, @@ -164,6 +168,7 @@ public void setup() throws Exception { new NoneCircuitBreakerService(), new UsageService() ); + when(actionModule.getRestController()).thenReturn(restController); settingsModule = new SettingsModule(Settings.EMPTY, emptyList(), emptyList(), emptySet()); clusterService = createClusterService(threadPool); @@ -732,7 +737,7 @@ public void testRegisterHandler() throws Exception { ) ); extensionsManager.initializeServicesAndRestHandler( - restController, + actionModule, settingsModule, mockTransportService, clusterService, @@ -812,7 +817,7 @@ private void initialize(ExtensionsManager extensionsManager) { transportService.start(); transportService.acceptIncomingRequests(); extensionsManager.initializeServicesAndRestHandler( - restController, + actionModule, settingsModule, transportService, clusterService, diff --git a/server/src/test/java/org/opensearch/extensions/RegisterTransportActionsRequestTests.java b/server/src/test/java/org/opensearch/extensions/RegisterTransportActionsRequestTests.java index 27f1597e5779f..e819fcd893367 100644 --- a/server/src/test/java/org/opensearch/extensions/RegisterTransportActionsRequestTests.java +++ b/server/src/test/java/org/opensearch/extensions/RegisterTransportActionsRequestTests.java @@ -11,6 +11,7 @@ import org.junit.Before; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.extensions.action.RegisterTransportActionsRequest; import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; diff --git a/server/src/test/java/org/opensearch/extensions/action/ExtensionTransportActionsHandlerTests.java b/server/src/test/java/org/opensearch/extensions/action/ExtensionTransportActionsHandlerTests.java index bb6df2f521e71..3fea207cbb700 100644 --- a/server/src/test/java/org/opensearch/extensions/action/ExtensionTransportActionsHandlerTests.java +++ b/server/src/test/java/org/opensearch/extensions/action/ExtensionTransportActionsHandlerTests.java @@ -11,6 +11,9 @@ import org.junit.After; import org.junit.Before; import org.opensearch.Version; +import org.opensearch.action.ActionModule; +import org.opensearch.action.ActionModule.DynamicActionRegistry; +import org.opensearch.action.support.ActionFilters; import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.io.stream.NamedWriteableRegistry; @@ -20,7 +23,6 @@ import org.opensearch.common.util.PageCacheRecycler; import org.opensearch.extensions.DiscoveryExtensionNode; import org.opensearch.extensions.AcknowledgedResponse; -import org.opensearch.extensions.RegisterTransportActionsRequest; import org.opensearch.extensions.rest.RestSendToExtensionActionTests; import org.opensearch.indices.breaker.NoneCircuitBreakerService; import org.opensearch.test.OpenSearchTestCase; @@ -41,10 +43,14 @@ import java.util.Set; import java.util.concurrent.TimeUnit; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + import static java.util.Collections.emptyMap; import static java.util.Collections.emptySet; public class ExtensionTransportActionsHandlerTests extends OpenSearchTestCase { + private static final ActionFilters EMPTY_FILTERS = new ActionFilters(Collections.emptySet()); private TransportService transportService; private MockNioTransport transport; private DiscoveryExtensionNode discoveryExtensionNode; @@ -90,10 +96,17 @@ public void setup() throws Exception { Collections.emptyList() ); client = new NoOpNodeClient(this.getTestName()); + ActionModule mockActionModule = mock(ActionModule.class); + DynamicActionRegistry dynamicActionRegistry = new DynamicActionRegistry(); + dynamicActionRegistry.registerUnmodifiableActionMap(Collections.emptyMap()); + when(mockActionModule.getDynamicActionRegistry()).thenReturn(dynamicActionRegistry); + when(mockActionModule.getActionFilters()).thenReturn(EMPTY_FILTERS); extensionTransportActionsHandler = new ExtensionTransportActionsHandler( Map.of("uniqueid1", discoveryExtensionNode), transportService, - client + client, + mockActionModule, + null ); } @@ -108,11 +121,14 @@ public void tearDown() throws Exception { public void testRegisterAction() { String action = "test-action"; - extensionTransportActionsHandler.registerAction(action, discoveryExtensionNode); + extensionTransportActionsHandler.registerAction(action, discoveryExtensionNode.getId()); assertEquals(discoveryExtensionNode, extensionTransportActionsHandler.getExtension(action)); // Test duplicate action registration - expectThrows(IllegalArgumentException.class, () -> extensionTransportActionsHandler.registerAction(action, discoveryExtensionNode)); + expectThrows( + IllegalArgumentException.class, + () -> extensionTransportActionsHandler.registerAction(action, discoveryExtensionNode.getId()) + ); assertEquals(discoveryExtensionNode, extensionTransportActionsHandler.getExtension(action)); } @@ -130,12 +146,14 @@ public void testRegisterTransportActionsRequest() { assertFalse(response.getStatus()); } - public void testTransportActionRequestFromExtension() throws InterruptedException { + public void testTransportActionRequestFromExtension() throws Exception { String action = "test-action"; byte[] requestBytes = "requestBytes".getBytes(StandardCharsets.UTF_8); TransportActionRequestFromExtension request = new TransportActionRequestFromExtension(action, requestBytes, "uniqueid1"); - // NoOpNodeClient returns null as response - expectThrows(NullPointerException.class, () -> extensionTransportActionsHandler.handleTransportActionRequestFromExtension(request)); + RemoteExtensionActionResponse response = extensionTransportActionsHandler.handleTransportActionRequestFromExtension(request); + assertFalse(response.isSuccess()); + String responseString = response.getResponseBytesAsString(); + assertEquals("Request failed: action [test-action] is not registered for any extension.", responseString); } public void testSendTransportRequestToExtension() throws InterruptedException { diff --git a/server/src/test/java/org/opensearch/extensions/action/TransportActionResponseToExtensionTests.java b/server/src/test/java/org/opensearch/extensions/action/RemoteExtensionActionResponseTests.java similarity index 55% rename from server/src/test/java/org/opensearch/extensions/action/TransportActionResponseToExtensionTests.java rename to server/src/test/java/org/opensearch/extensions/action/RemoteExtensionActionResponseTests.java index 070feaa240d98..4ce42450bd577 100644 --- a/server/src/test/java/org/opensearch/extensions/action/TransportActionResponseToExtensionTests.java +++ b/server/src/test/java/org/opensearch/extensions/action/RemoteExtensionActionResponseTests.java @@ -13,31 +13,42 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.test.OpenSearchTestCase; -import java.io.IOException; import java.nio.charset.StandardCharsets; -public class TransportActionResponseToExtensionTests extends OpenSearchTestCase { - public void testTransportActionRequestToExtension() throws IOException { +public class RemoteExtensionActionResponseTests extends OpenSearchTestCase { + + public void testExtensionActionResponse() throws Exception { byte[] expectedResponseBytes = "response-bytes".getBytes(StandardCharsets.UTF_8); - TransportActionResponseToExtension response = new TransportActionResponseToExtension(expectedResponseBytes); + RemoteExtensionActionResponse response = new RemoteExtensionActionResponse(true, expectedResponseBytes); + assertTrue(response.isSuccess()); assertEquals(expectedResponseBytes, response.getResponseBytes()); BytesStreamOutput out = new BytesStreamOutput(); response.writeTo(out); BytesStreamInput in = new BytesStreamInput(BytesReference.toBytes(out.bytes())); - response = new TransportActionResponseToExtension(in); + response = new RemoteExtensionActionResponse(in); + assertTrue(response.isSuccess()); assertArrayEquals(expectedResponseBytes, response.getResponseBytes()); } - public void testSetBytes() { - byte[] expectedResponseBytes = "response-bytes".getBytes(StandardCharsets.UTF_8); + public void testSetters() { + String expectedResponse = "response-bytes"; + byte[] expectedResponseBytes = expectedResponse.getBytes(StandardCharsets.UTF_8); byte[] expectedEmptyBytes = new byte[0]; - TransportActionResponseToExtension response = new TransportActionResponseToExtension(expectedEmptyBytes); + RemoteExtensionActionResponse response = new RemoteExtensionActionResponse(false, expectedEmptyBytes); assertArrayEquals(expectedEmptyBytes, response.getResponseBytes()); + assertFalse(response.isSuccess()); + + response.setResponseBytesAsString(expectedResponse); + assertArrayEquals(expectedResponseBytes, response.getResponseBytes()); response.setResponseBytes(expectedResponseBytes); assertArrayEquals(expectedResponseBytes, response.getResponseBytes()); + assertEquals(expectedResponse, response.getResponseBytesAsString()); + + response.setSuccess(true); + assertTrue(response.isSuccess()); } } diff --git a/server/src/test/java/org/opensearch/rest/action/admin/indices/RestValidateQueryActionTests.java b/server/src/test/java/org/opensearch/rest/action/admin/indices/RestValidateQueryActionTests.java index cc1a9d4fd2e40..6971aa866ccb1 100644 --- a/server/src/test/java/org/opensearch/rest/action/admin/indices/RestValidateQueryActionTests.java +++ b/server/src/test/java/org/opensearch/rest/action/admin/indices/RestValidateQueryActionTests.java @@ -34,6 +34,7 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionType; +import org.opensearch.action.ActionModule.DynamicActionRegistry; import org.opensearch.action.admin.indices.validate.query.ValidateQueryAction; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.TransportAction; @@ -84,7 +85,7 @@ public class RestValidateQueryActionTests extends AbstractSearchTestCase { public static void stubValidateQueryAction() { final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Collections.emptySet()); - final TransportAction transportAction = new TransportAction( + final TransportAction transportAction = new TransportAction<>( ValidateQueryAction.NAME, new ActionFilters(Collections.emptySet()), taskManager @@ -96,7 +97,9 @@ protected void doExecute(Task task, ActionRequest request, ActionListener listen final Map actions = new HashMap<>(); actions.put(ValidateQueryAction.INSTANCE, transportAction); - client.initialize(actions, () -> "local", null, new NamedWriteableRegistry(Collections.emptyList())); + DynamicActionRegistry dynamicActionRegistry = new DynamicActionRegistry(); + dynamicActionRegistry.registerUnmodifiableActionMap(actions); + client.initialize(dynamicActionRegistry, () -> "local", null, new NamedWriteableRegistry(Collections.emptyList())); controller.registerHandler(action); } diff --git a/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java b/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java index 45d50ea98bc78..efaab9e11d644 100644 --- a/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java +++ b/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java @@ -37,6 +37,7 @@ import org.opensearch.ExceptionsHelper; import org.opensearch.Version; import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionModule.DynamicActionRegistry; import org.opensearch.action.ActionType; import org.opensearch.action.RequestValidators; import org.opensearch.action.StepListener; @@ -2191,8 +2192,10 @@ public void onFailure(final Exception e) { indexNameExpressionResolver ) ); + DynamicActionRegistry dynamicActionRegistry = new DynamicActionRegistry(); + dynamicActionRegistry.registerUnmodifiableActionMap(actions); client.initialize( - actions, + dynamicActionRegistry, () -> clusterService.localNode().getId(), transportService.getRemoteClusterService(), new NamedWriteableRegistry(Collections.emptyList()) diff --git a/test/framework/src/main/java/org/opensearch/test/client/NoOpNodeClient.java b/test/framework/src/main/java/org/opensearch/test/client/NoOpNodeClient.java index 413c78ba37026..1ff7a287b9b30 100644 --- a/test/framework/src/main/java/org/opensearch/test/client/NoOpNodeClient.java +++ b/test/framework/src/main/java/org/opensearch/test/client/NoOpNodeClient.java @@ -34,10 +34,10 @@ import org.opensearch.OpenSearchException; import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionModule.DynamicActionRegistry; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionResponse; import org.opensearch.action.ActionType; -import org.opensearch.action.support.TransportAction; import org.opensearch.client.Client; import org.opensearch.client.node.NodeClient; import org.opensearch.common.io.stream.NamedWriteableRegistry; @@ -48,7 +48,6 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.RemoteClusterService; -import java.util.Map; import java.util.concurrent.TimeUnit; import java.util.function.Supplier; @@ -86,7 +85,7 @@ public void doE @Override public void initialize( - Map actions, + DynamicActionRegistry dynamicActionRegistry, Supplier localNodeId, RemoteClusterService remoteClusterService, NamedWriteableRegistry namedWriteableRegistry