diff --git a/server/src/main/java/org/opensearch/action/ActionModule.java b/server/src/main/java/org/opensearch/action/ActionModule.java index 89f185e302f69..1cb96a08999da 100644 --- a/server/src/main/java/org/opensearch/action/ActionModule.java +++ b/server/src/main/java/org/opensearch/action/ActionModule.java @@ -282,6 +282,7 @@ import org.opensearch.common.inject.TypeLiteral; import org.opensearch.common.inject.multibindings.MapBinder; import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.extensions.ExtensionsManager; import org.opensearch.extensions.action.ExtensionAction; import org.opensearch.extensions.action.ExtensionProxyAction; import org.opensearch.extensions.action.ExtensionTransportAction; @@ -442,6 +443,7 @@ import org.opensearch.rest.action.search.RestSearchScrollAction; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; import org.opensearch.usage.UsageService; import java.util.ArrayList; @@ -476,8 +478,8 @@ public class ActionModule extends AbstractModule { private final List actionPlugins; // An unmodifiable map containing OpenSearch and Plugin actions private final Map> actions; - // A dynamic map containing Extension actions - private final DynamicActionRegistry extensionActions; + // A dynamic map combining the above immutable actions with dynamic Extension actions + private final DynamicActionRegistry dynamicActionRegistry; private final ActionFilters actionFilters; private final AutoCreateIndex autoCreateIndex; private final DestructiveOperations destructiveOperations; @@ -507,8 +509,8 @@ public ActionModule( this.actionPlugins = actionPlugins; this.threadPool = threadPool; actions = setupActions(actionPlugins); - extensionActions = FeatureFlags.isEnabled(FeatureFlags.EXTENSIONS) ? new DynamicActionRegistry() : null; actionFilters = setupActionFilters(actionPlugins); + dynamicActionRegistry = new DynamicActionRegistry(); autoCreateIndex = new AutoCreateIndex(settings, clusterSettings, indexNameExpressionResolver, systemIndices); destructiveOperations = new DestructiveOperations(settings, clusterSettings); Set headers = Stream.concat( @@ -731,13 +733,6 @@ public void reg return unmodifiableMap(actions.getRegistry()); } - public DynamicActionRegistry getExtensionActions() { - if (!FeatureFlags.isEnabled(FeatureFlags.EXTENSIONS)) { - throw new UnsupportedOperationException("This method requires enabling the feature flag [" + FeatureFlags.EXTENSIONS + "]."); - } - return extensionActions; - } - private ActionFilters setupActionFilters(List actionPlugins) { return new ActionFilters( Collections.unmodifiableSet(actionPlugins.stream().flatMap(p -> p.getActionFilters().stream()).collect(Collectors.toSet())) @@ -973,7 +968,7 @@ protected void configure() { // register dynamic ActionType -> transportAction Map used by NodeClient if (FeatureFlags.isEnabled(FeatureFlags.EXTENSIONS)) { - bind(DynamicActionRegistry.class).toInstance(extensionActions); + bind(DynamicActionRegistry.class).toInstance(dynamicActionRegistry); } } @@ -981,32 +976,93 @@ public ActionFilters getActionFilters() { return actionFilters; } + @SuppressWarnings("unchecked") + public DynamicActionRegistry getDynamicActionRegistry() { + return dynamicActionRegistry; + } + public RestController getRestController() { return restController; } - public static class DynamicActionRegistry { - private final Map registry = new ConcurrentHashMap<>(); + /** + * The DynamicActionRegistry maintains a registry mapping {@link ExtensionAction} instances to {@link TransportAction} instances. + *

+ * This class is modeled after {@link NamedRegistry} but provides both register and unregister capabilities. + * + * @opensearch.internal + */ + public static class DynamicActionRegistry { + // the immutable map of injected transport actions + private Map actions = Collections.emptyMap(); + // the dynamic map which can be updated over time + private final Map registry = new ConcurrentHashMap<>(); + + private ActionFilters actionFilters; + private ExtensionsManager extensionsManager; + private TransportService transportService; + + /** + * Initialize the immutable actions in the registry. + * + * @param actions The injected map of {@link ActionType} to {@link TransportAction} + * @param actionFilters The action filters + * @param transportService The node's {@link TransportService} + * @param extensionsManager The instance of the {@link ExtensionsManager} + */ + public void initialize( + Map actions, + ActionFilters actionFilters, + TransportService transportService, + ExtensionsManager extensionsManager + ) { + this.actions = actions; + this.actionFilters = actionFilters; + this.transportService = transportService; + this.extensionsManager = extensionsManager; + } + /** + * Add an {@link ExtensionAction} to the registry. + * + * @param extensionAction The action to add + */ public void registerExtensionAction(ExtensionAction extensionAction) { requireNonNull(extensionAction, "extension action is required"); String name = extensionAction.name(); - requireNonNull(name, "name is required"); - if (registry.putIfAbsent(name, extensionAction) != null) { - throw new IllegalArgumentException("extension action for name [" + name + "] already registered"); + String uniqueId = extensionAction.uniqueId(); + if (registry.containsKey(extensionAction)) { + throw new IllegalArgumentException("extension [" + uniqueId + "] action for [" + name + "] already registered"); } + registry.put(extensionAction, new ExtensionTransportAction(name, actionFilters, transportService, extensionsManager)); } - public void unregisterExtensionAction(String name) { - requireNonNull(name, "name is required"); - if (registry.remove(name) == null) { - throw new IllegalArgumentException("extension action for name [" + name + "] was not registered"); + /** + * Remove an {@link ExtensionAction} from the registry. + * + * @param extensionAction The action to remove + */ + public void unregisterExtensionAction(ExtensionAction extensionAction) { + requireNonNull(extensionAction, "extension action is required"); + String name = extensionAction.name(); + String uniqueId = extensionAction.uniqueId(); + if (registry.remove(extensionAction) == null) { + throw new IllegalArgumentException("extension [" + uniqueId + "] action for [" + name + "] was not registered"); } } - public ExtensionAction get(String name) { - requireNonNull(name, "name is required"); - return registry.get(name); + /** + * Gets the {@link TransportAction} instance corresponding to the {@link ActionType} instance. + * + * @param action The {@link ActionType}. May be an {@link ExtensionAction}. + * @return the corresponding {@link TransportAction} if it is registered, null otherwise. + */ + @SuppressWarnings("unchecked") + public TransportAction get(ActionType action) { + if (action instanceof ExtensionAction) { + return (TransportAction) registry.get((ExtensionAction) action); + } + return actions.get(action); } } } diff --git a/server/src/main/java/org/opensearch/client/node/NodeClient.java b/server/src/main/java/org/opensearch/client/node/NodeClient.java index a3be70f5ba5ec..68d17e1af65d5 100644 --- a/server/src/main/java/org/opensearch/client/node/NodeClient.java +++ b/server/src/main/java/org/opensearch/client/node/NodeClient.java @@ -34,7 +34,6 @@ import org.opensearch.action.ActionType; import org.opensearch.action.ActionListener; -import org.opensearch.action.ActionModule; import org.opensearch.action.ActionModule.DynamicActionRegistry; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionResponse; @@ -46,15 +45,12 @@ import org.opensearch.common.io.stream.NamedWriteableRegistry; import org.opensearch.common.settings.Settings; import org.opensearch.extensions.ExtensionsManager; -import org.opensearch.extensions.action.ExtensionAction; -import org.opensearch.extensions.action.ExtensionTransportAction; import org.opensearch.tasks.Task; import org.opensearch.tasks.TaskListener; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.RemoteClusterService; import org.opensearch.transport.TransportService; -import java.util.Map; import java.util.function.Supplier; /** @@ -64,8 +60,7 @@ */ public class NodeClient extends AbstractClient { - private Map actions; - private DynamicActionRegistry extensionActions; + private DynamicActionRegistry actionRegistry; private ActionFilters actionFilters; /** * The id of the local {@link DiscoveryNode}. Useful for generating task ids from tasks returned by @@ -82,32 +77,17 @@ public NodeClient(Settings settings, ThreadPool threadPool) { } public void initialize( - Map actions, + DynamicActionRegistry actionRegistry, Supplier localNodeId, RemoteClusterService remoteClusterService, NamedWriteableRegistry namedWriteableRegistry ) { - this.actions = actions; + this.actionRegistry = actionRegistry; this.localNodeId = localNodeId; this.remoteClusterService = remoteClusterService; this.namedWriteableRegistry = namedWriteableRegistry; } - public void initialize( - Map actions, - ActionModule actionModule, - TransportService transportService, - ExtensionsManager extensionsManager, - Supplier localNodeId, - NamedWriteableRegistry namedWriteableRegistry - ) { - initialize(actions, localNodeId, transportService.getRemoteClusterService(), namedWriteableRegistry); - this.extensionActions = actionModule.getExtensionActions(); - this.actionFilters = actionModule.getActionFilters(); - this.transportService = transportService; - this.extensionsManager = extensionsManager; - } - @Override public void close() { // nothing really to do @@ -163,23 +143,11 @@ public String getLocalNodeId() { private TransportAction transportAction( ActionType action ) { - if (actions == null) { + if (actionRegistry == null) { throw new IllegalStateException("NodeClient has not been initialized"); } // Get from action map if it exists - TransportAction transportAction = actions.get(action); - // Fallback to dynamic extension action map - if (transportAction == null && extensionActions != null && action instanceof ExtensionAction) { - ExtensionAction extensionAction = extensionActions.get(action.name()); - if (extensionAction != null) { - transportAction = (TransportAction) new ExtensionTransportAction( - action.name(), - transportService, - actionFilters, - extensionsManager - ); - } - } + TransportAction transportAction = actionRegistry.get(action); if (transportAction == null) { throw new IllegalStateException("failed to find action [" + action + "] to execute"); } diff --git a/server/src/main/java/org/opensearch/extensions/ExtensionsManager.java b/server/src/main/java/org/opensearch/extensions/ExtensionsManager.java index dda19718e49d8..df2bf2765840d 100644 --- a/server/src/main/java/org/opensearch/extensions/ExtensionsManager.java +++ b/server/src/main/java/org/opensearch/extensions/ExtensionsManager.java @@ -195,7 +195,6 @@ public void initializeServicesAndRestHandler( ); this.client = client; this.extensionTransportActionsHandler = new ExtensionTransportActionsHandler( - this, extensionIdMap, transportService, client, diff --git a/server/src/main/java/org/opensearch/extensions/action/ExtensionAction.java b/server/src/main/java/org/opensearch/extensions/action/ExtensionAction.java index 023a8d50bcd55..3a6a4eb46d6f2 100644 --- a/server/src/main/java/org/opensearch/extensions/action/ExtensionAction.java +++ b/server/src/main/java/org/opensearch/extensions/action/ExtensionAction.java @@ -10,8 +10,10 @@ import org.opensearch.action.ActionType; +import java.util.Objects; + /** - * An {@link ActionType} to be used in extension action transport handlers. + * An {@link ActionType} to be used in extension action transport handling. * * @opensearch.internal */ @@ -35,7 +37,24 @@ public class ExtensionAction extends ActionType { * * @return the uniqueId */ - public String getUniqueId() { + public String uniqueId() { return this.uniqueId; } + + @Override + public int hashCode() { + final int prime = 31; + int result = super.hashCode(); + result = prime * result + Objects.hash(uniqueId); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (!super.equals(obj)) return false; + if (getClass() != obj.getClass()) return false; + ExtensionAction other = (ExtensionAction) obj; + return Objects.equals(uniqueId, other.uniqueId); + } } diff --git a/server/src/main/java/org/opensearch/extensions/action/ExtensionTransportAction.java b/server/src/main/java/org/opensearch/extensions/action/ExtensionTransportAction.java index 389edf769f39c..1d89aef69c9dc 100644 --- a/server/src/main/java/org/opensearch/extensions/action/ExtensionTransportAction.java +++ b/server/src/main/java/org/opensearch/extensions/action/ExtensionTransportAction.java @@ -26,8 +26,8 @@ public class ExtensionTransportAction extends TransportAction extensionIdMap; private final TransportService transportService; private final NodeClient client; - private final DynamicActionRegistry dynamicActionRegistry; + private final DynamicActionRegistry dynamicActionRegistry; public ExtensionTransportActionsHandler( Map extensionIdMap, @@ -56,7 +58,7 @@ public ExtensionTransportActionsHandler( this.extensionIdMap = extensionIdMap; this.transportService = transportService; this.client = client; - this.dynamicActionRegistry = actionModule.getExtensionActions(); + this.dynamicActionRegistry = actionModule.getDynamicActionRegistry(); } /** @@ -115,10 +117,10 @@ public TransportResponse handleRegisterTransportActionsRequest(RegisterTransport public TransportResponse handleTransportActionRequestFromExtension(TransportActionRequestFromExtension request) throws Exception { String actionName = request.getAction(); String uniqueId = request.getUniqueId(); + ExtensionAction extensionAction = new ExtensionAction(actionName, uniqueId); final TransportActionResponseToExtension response = new TransportActionResponseToExtension(new byte[0]); // Validate that this action has been registered - ExtensionAction extensionAction = dynamicActionRegistry.get(actionName); - if (extensionAction == null) { + if (dynamicActionRegistry.get(extensionAction) == null) { byte[] responseBytes = ("Request failed: action [" + actionName + "] is not registered for extension [" + uniqueId + "].") .getBytes(StandardCharsets.UTF_8); response.setResponseBytes(responseBytes); diff --git a/server/src/main/java/org/opensearch/node/Node.java b/server/src/main/java/org/opensearch/node/Node.java index e94516012760f..6363b4acf63a9 100644 --- a/server/src/main/java/org/opensearch/node/Node.java +++ b/server/src/main/java/org/opensearch/node/Node.java @@ -58,6 +58,9 @@ import org.opensearch.OpenSearchTimeoutException; import org.opensearch.Version; import org.opensearch.action.ActionModule; +import org.opensearch.action.ActionModule.DynamicActionRegistry; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionResponse; import org.opensearch.action.ActionType; import org.opensearch.action.admin.cluster.snapshots.status.TransportNodesSnapshotsStatus; import org.opensearch.action.search.SearchExecutionStatsCollector; @@ -1112,13 +1115,16 @@ protected Node( resourcesToClose.addAll(pluginLifecycleComponents); resourcesToClose.add(injector.getInstance(PeerRecoverySourceService.class)); this.pluginLifecycleComponents = Collections.unmodifiableList(pluginLifecycleComponents); - if (FeatureFlags.isEnabled(FeatureFlags.EXTENSIONS)) { - client.initialize(injector.getInstance(new Key>() { - }), actionModule, transportService, extensionsManager, () -> clusterService.localNode().getId(), namedWriteableRegistry); - } else { - client.initialize(injector.getInstance(new Key>() { - }), () -> clusterService.localNode().getId(), transportService.getRemoteClusterService(), namedWriteableRegistry); - } + DynamicActionRegistry dynamicActionRegistry = actionModule + .getDynamicActionRegistry(); + dynamicActionRegistry.initialize(injector.getInstance(new Key>() { + }), actionModule.getActionFilters(), transportService, extensionsManager); + client.initialize( + dynamicActionRegistry, + () -> clusterService.localNode().getId(), + transportService.getRemoteClusterService(), + namedWriteableRegistry + ); this.namedWriteableRegistry = namedWriteableRegistry; logger.debug("initializing HTTP handlers ...");