diff --git a/server/src/main/java/org/opensearch/extensions/ExtensionsManager.java b/server/src/main/java/org/opensearch/extensions/ExtensionsManager.java index df2bf2765840d..85317fa4d40ca 100644 --- a/server/src/main/java/org/opensearch/extensions/ExtensionsManager.java +++ b/server/src/main/java/org/opensearch/extensions/ExtensionsManager.java @@ -167,9 +167,8 @@ public ExtensionsManager(Settings settings, Path extensionsPath) throws IOExcept * Initializes the {@link RestActionsRequestHandler}, {@link TransportService}, {@link ClusterService} and environment settings. This is called during Node bootstrap. * Lists/maps of extensions have already been initialized but not yet populated. * - * @param restController The RestController on which to register Rest Actions. + * @param actionModule The ActionModule with the RestController and DynamicActionModule * @param settingsModule The module that binds the provided settings to interface. - * @param actionsModule The module that binds transport actions. * @param transportService The Node's transport service. * @param clusterService The Node's cluster service. * @param initialEnvironmentSettings The finalized view of settings for the Environment diff --git a/server/src/main/java/org/opensearch/extensions/action/ExtensionTransportActionsHandler.java b/server/src/main/java/org/opensearch/extensions/action/ExtensionTransportActionsHandler.java index 4a6fe15d79328..0c0d874cd0f9e 100644 --- a/server/src/main/java/org/opensearch/extensions/action/ExtensionTransportActionsHandler.java +++ b/server/src/main/java/org/opensearch/extensions/action/ExtensionTransportActionsHandler.java @@ -13,8 +13,6 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.ActionModule; import org.opensearch.action.ActionModule.DynamicActionRegistry; -import org.opensearch.action.ActionRequest; -import org.opensearch.action.ActionResponse; import org.opensearch.client.node.NodeClient; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.extensions.DiscoveryExtensionNode; @@ -43,7 +41,9 @@ */ public class ExtensionTransportActionsHandler { private static final Logger logger = LogManager.getLogger(ExtensionTransportActionsHandler.class); - private final Map actionsMap = new ConcurrentHashMap<>(); + // Map of action name to Extension unique ID, populated locally + private final Map actionToIdMap = new ConcurrentHashMap<>(); + // Map of Extension unique ID to Extension Node, populated in Extensions Manager private final Map extensionIdMap; private final TransportService transportService; private final NodeClient client; @@ -65,16 +65,16 @@ public ExtensionTransportActionsHandler( * Method to register actions for extensions. * * @param action to be registered. - * @param extension for which action is being registered. + * @param uniqueId id of extension for which action is being registered. * @throws IllegalArgumentException when action being registered already is registered. */ - void registerAction(String action, DiscoveryExtensionNode extension) throws IllegalArgumentException { + void registerAction(String action, String uniqueId) throws IllegalArgumentException { // Register the action in this handler so it knows which extension owns it - if (actionsMap.putIfAbsent(action, extension) != null) { + if (actionToIdMap.putIfAbsent(action, uniqueId) != null) { throw new IllegalArgumentException("The action [" + action + "] you are trying to register is already registered"); } // Register the action in the action module's extension actions map - dynamicActionRegistry.registerExtensionAction(new ExtensionAction(action, extension.getId())); + dynamicActionRegistry.registerExtensionAction(new ExtensionAction(action, uniqueId)); } /** @@ -84,7 +84,7 @@ void registerAction(String action, DiscoveryExtensionNode extension) throws Ille * @return the extension. */ public DiscoveryExtensionNode getExtension(String action) { - return actionsMap.get(action); + return extensionIdMap.get(actionToIdMap.get(action)); } /** @@ -95,10 +95,9 @@ public DiscoveryExtensionNode getExtension(String action) { */ public TransportResponse handleRegisterTransportActionsRequest(RegisterTransportActionsRequest transportActionsRequest) { logger.debug("Register Transport Actions request recieved {}", transportActionsRequest); - DiscoveryExtensionNode extension = extensionIdMap.get(transportActionsRequest.getUniqueId()); try { for (String action : transportActionsRequest.getTransportActions()) { - registerAction(action, extension); + registerAction(action, transportActionsRequest.getUniqueId()); } } catch (Exception e) { logger.error("Could not register Transport Action " + e); @@ -116,7 +115,7 @@ public TransportResponse handleRegisterTransportActionsRequest(RegisterTransport */ public TransportResponse handleTransportActionRequestFromExtension(TransportActionRequestFromExtension request) throws Exception { String actionName = request.getAction(); - String uniqueId = request.getUniqueId(); + String uniqueId = actionToIdMap.get(actionName); ExtensionAction extensionAction = new ExtensionAction(actionName, uniqueId); final TransportActionResponseToExtension response = new TransportActionResponseToExtension(new byte[0]); // Validate that this action has been registered @@ -177,7 +176,7 @@ public void onFailure(Exception exp) { * @throws InterruptedException when message transport fails. */ public ExtensionActionResponse sendTransportRequestToExtension(ExtensionActionRequest request) throws Exception { - DiscoveryExtensionNode extension = actionsMap.get(request.getAction()); + DiscoveryExtensionNode extension = getExtension(request.getAction()); if (extension == null) { throw new ActionNotFoundTransportException(request.getAction()); } diff --git a/server/src/test/java/org/opensearch/client/node/NodeClientHeadersTests.java b/server/src/test/java/org/opensearch/client/node/NodeClientHeadersTests.java index cb9e3a6a19388..bb4aa6f2ab7a8 100644 --- a/server/src/test/java/org/opensearch/client/node/NodeClientHeadersTests.java +++ b/server/src/test/java/org/opensearch/client/node/NodeClientHeadersTests.java @@ -33,6 +33,7 @@ package org.opensearch.client.node; import org.opensearch.action.ActionType; +import org.opensearch.action.ActionModule.DynamicActionRegistry; import org.opensearch.action.ActionListener; import org.opensearch.action.ActionRequest; import org.opensearch.action.support.ActionFilters; @@ -57,7 +58,9 @@ protected Client buildClient(Settings headersSettings, ActionType[] testedAction Settings settings = HEADER_SETTINGS; Actions actions = new Actions(settings, threadPool, testedActions); NodeClient client = new NodeClient(settings, threadPool); - client.initialize(actions, () -> "test", null, new NamedWriteableRegistry(Collections.emptyList())); + DynamicActionRegistry dynamicActionRegistry = new DynamicActionRegistry(); + dynamicActionRegistry.initialize(actions, EMPTY_FILTERS, null, null); + client.initialize(dynamicActionRegistry, () -> "test", null, new NamedWriteableRegistry(Collections.emptyList())); return client; } diff --git a/server/src/test/java/org/opensearch/extensions/ExtensionsManagerTests.java b/server/src/test/java/org/opensearch/extensions/ExtensionsManagerTests.java index a027b09a42da9..bbc115307e26f 100644 --- a/server/src/test/java/org/opensearch/extensions/ExtensionsManagerTests.java +++ b/server/src/test/java/org/opensearch/extensions/ExtensionsManagerTests.java @@ -15,6 +15,7 @@ import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import static org.mockito.Mockito.times; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.mock; @@ -40,6 +41,7 @@ import org.junit.After; import org.junit.Before; import org.opensearch.Version; +import org.opensearch.action.ActionModule; import org.opensearch.action.admin.cluster.state.ClusterStateResponse; import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.ClusterSettingsResponse; @@ -94,6 +96,7 @@ public class ExtensionsManagerTests extends OpenSearchTestCase { private FeatureFlagSetter featureFlagSetter; private TransportService transportService; + private ActionModule actionModule; private RestController restController; private SettingsModule settingsModule; private ClusterService clusterService; @@ -157,6 +160,7 @@ public void setup() throws Exception { null, Collections.emptySet() ); + actionModule = mock(ActionModule.class); restController = new RestController( emptySet(), null, @@ -164,6 +168,7 @@ public void setup() throws Exception { new NoneCircuitBreakerService(), new UsageService() ); + when(actionModule.getRestController()).thenReturn(restController); settingsModule = new SettingsModule(Settings.EMPTY, emptyList(), emptyList(), emptySet()); clusterService = createClusterService(threadPool); @@ -732,7 +737,7 @@ public void testRegisterHandler() throws Exception { ) ); extensionsManager.initializeServicesAndRestHandler( - restController, + actionModule, settingsModule, mockTransportService, clusterService, @@ -812,7 +817,7 @@ private void initialize(ExtensionsManager extensionsManager) { transportService.start(); transportService.acceptIncomingRequests(); extensionsManager.initializeServicesAndRestHandler( - restController, + actionModule, settingsModule, transportService, clusterService, diff --git a/server/src/test/java/org/opensearch/extensions/action/ExtensionTransportActionsHandlerTests.java b/server/src/test/java/org/opensearch/extensions/action/ExtensionTransportActionsHandlerTests.java index 2621753e52267..f13e4b569dcf0 100644 --- a/server/src/test/java/org/opensearch/extensions/action/ExtensionTransportActionsHandlerTests.java +++ b/server/src/test/java/org/opensearch/extensions/action/ExtensionTransportActionsHandlerTests.java @@ -11,6 +11,8 @@ import org.junit.After; import org.junit.Before; import org.opensearch.Version; +import org.opensearch.action.ActionModule; +import org.opensearch.action.ActionModule.DynamicActionRegistry; import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.io.stream.NamedWriteableRegistry; @@ -40,6 +42,9 @@ import java.util.Set; import java.util.concurrent.TimeUnit; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + import static java.util.Collections.emptyMap; import static java.util.Collections.emptySet; @@ -89,10 +94,14 @@ public void setup() throws Exception { Collections.emptyList() ); client = new NoOpNodeClient(this.getTestName()); + ActionModule mockActionModule = mock(ActionModule.class); + DynamicActionRegistry dynamicActionRegistry = new DynamicActionRegistry(); + when(mockActionModule.getDynamicActionRegistry()).thenReturn(dynamicActionRegistry); extensionTransportActionsHandler = new ExtensionTransportActionsHandler( Map.of("uniqueid1", discoveryExtensionNode), transportService, - client + client, + mockActionModule ); } @@ -107,11 +116,14 @@ public void tearDown() throws Exception { public void testRegisterAction() { String action = "test-action"; - extensionTransportActionsHandler.registerAction(action, discoveryExtensionNode); + extensionTransportActionsHandler.registerAction(action, discoveryExtensionNode.getId()); assertEquals(discoveryExtensionNode, extensionTransportActionsHandler.getExtension(action)); // Test duplicate action registration - expectThrows(IllegalArgumentException.class, () -> extensionTransportActionsHandler.registerAction(action, discoveryExtensionNode)); + expectThrows( + IllegalArgumentException.class, + () -> extensionTransportActionsHandler.registerAction(action, discoveryExtensionNode.getId()) + ); assertEquals(discoveryExtensionNode, extensionTransportActionsHandler.getExtension(action)); } diff --git a/server/src/test/java/org/opensearch/rest/action/admin/indices/RestValidateQueryActionTests.java b/server/src/test/java/org/opensearch/rest/action/admin/indices/RestValidateQueryActionTests.java index cc1a9d4fd2e40..8e27e1860989a 100644 --- a/server/src/test/java/org/opensearch/rest/action/admin/indices/RestValidateQueryActionTests.java +++ b/server/src/test/java/org/opensearch/rest/action/admin/indices/RestValidateQueryActionTests.java @@ -34,6 +34,7 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionType; +import org.opensearch.action.ActionModule.DynamicActionRegistry; import org.opensearch.action.admin.indices.validate.query.ValidateQueryAction; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.TransportAction; @@ -96,7 +97,9 @@ protected void doExecute(Task task, ActionRequest request, ActionListener listen final Map actions = new HashMap<>(); actions.put(ValidateQueryAction.INSTANCE, transportAction); - client.initialize(actions, () -> "local", null, new NamedWriteableRegistry(Collections.emptyList())); + DynamicActionRegistry dynamicActionRegistry = new DynamicActionRegistry(); + dynamicActionRegistry.initialize(actions, null, null, null); + client.initialize(dynamicActionRegistry, () -> "local", null, new NamedWriteableRegistry(Collections.emptyList())); controller.registerHandler(action); } diff --git a/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java b/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java index 45d50ea98bc78..008d5127ba5a7 100644 --- a/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java +++ b/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java @@ -37,6 +37,7 @@ import org.opensearch.ExceptionsHelper; import org.opensearch.Version; import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionModule.DynamicActionRegistry; import org.opensearch.action.ActionType; import org.opensearch.action.RequestValidators; import org.opensearch.action.StepListener; @@ -2191,8 +2192,10 @@ public void onFailure(final Exception e) { indexNameExpressionResolver ) ); + DynamicActionRegistry dynamicActionRegistry = new DynamicActionRegistry(); + dynamicActionRegistry.initialize(actions, actionFilters, transportService, null); client.initialize( - actions, + dynamicActionRegistry, () -> clusterService.localNode().getId(), transportService.getRemoteClusterService(), new NamedWriteableRegistry(Collections.emptyList()) diff --git a/test/framework/src/main/java/org/opensearch/test/client/NoOpNodeClient.java b/test/framework/src/main/java/org/opensearch/test/client/NoOpNodeClient.java index 413c78ba37026..1ff7a287b9b30 100644 --- a/test/framework/src/main/java/org/opensearch/test/client/NoOpNodeClient.java +++ b/test/framework/src/main/java/org/opensearch/test/client/NoOpNodeClient.java @@ -34,10 +34,10 @@ import org.opensearch.OpenSearchException; import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionModule.DynamicActionRegistry; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionResponse; import org.opensearch.action.ActionType; -import org.opensearch.action.support.TransportAction; import org.opensearch.client.Client; import org.opensearch.client.node.NodeClient; import org.opensearch.common.io.stream.NamedWriteableRegistry; @@ -48,7 +48,6 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.RemoteClusterService; -import java.util.Map; import java.util.concurrent.TimeUnit; import java.util.function.Supplier; @@ -86,7 +85,7 @@ public void doE @Override public void initialize( - Map actions, + DynamicActionRegistry dynamicActionRegistry, Supplier localNodeId, RemoteClusterService remoteClusterService, NamedWriteableRegistry namedWriteableRegistry