Skip to content

Commit

Permalink
Refactor to combine registry internals
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed Mar 17, 2023
1 parent eea1a21 commit 7893511
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 75 deletions.
102 changes: 79 additions & 23 deletions server/src/main/java/org/opensearch/action/ActionModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@
import org.opensearch.common.inject.TypeLiteral;
import org.opensearch.common.inject.multibindings.MapBinder;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.extensions.ExtensionsManager;
import org.opensearch.extensions.action.ExtensionAction;
import org.opensearch.extensions.action.ExtensionProxyAction;
import org.opensearch.extensions.action.ExtensionTransportAction;
Expand Down Expand Up @@ -442,6 +443,7 @@
import org.opensearch.rest.action.search.RestSearchScrollAction;
import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;
import org.opensearch.usage.UsageService;

import java.util.ArrayList;
Expand Down Expand Up @@ -476,8 +478,8 @@ public class ActionModule extends AbstractModule {
private final List<ActionPlugin> actionPlugins;
// An unmodifiable map containing OpenSearch and Plugin actions
private final Map<String, ActionHandler<?, ?>> actions;
// A dynamic map containing Extension actions
private final DynamicActionRegistry extensionActions;
// A dynamic map combining the above immutable actions with dynamic Extension actions
private final DynamicActionRegistry dynamicActionRegistry;
private final ActionFilters actionFilters;
private final AutoCreateIndex autoCreateIndex;
private final DestructiveOperations destructiveOperations;
Expand Down Expand Up @@ -507,8 +509,8 @@ public ActionModule(
this.actionPlugins = actionPlugins;
this.threadPool = threadPool;
actions = setupActions(actionPlugins);
extensionActions = FeatureFlags.isEnabled(FeatureFlags.EXTENSIONS) ? new DynamicActionRegistry() : null;
actionFilters = setupActionFilters(actionPlugins);
dynamicActionRegistry = new DynamicActionRegistry();
autoCreateIndex = new AutoCreateIndex(settings, clusterSettings, indexNameExpressionResolver, systemIndices);
destructiveOperations = new DestructiveOperations(settings, clusterSettings);
Set<RestHeaderDefinition> headers = Stream.concat(
Expand Down Expand Up @@ -731,13 +733,6 @@ public <Request extends ActionRequest, Response extends ActionResponse> void reg
return unmodifiableMap(actions.getRegistry());
}

public DynamicActionRegistry getExtensionActions() {
if (!FeatureFlags.isEnabled(FeatureFlags.EXTENSIONS)) {
throw new UnsupportedOperationException("This method requires enabling the feature flag [" + FeatureFlags.EXTENSIONS + "].");
}
return extensionActions;
}

private ActionFilters setupActionFilters(List<ActionPlugin> actionPlugins) {
return new ActionFilters(
Collections.unmodifiableSet(actionPlugins.stream().flatMap(p -> p.getActionFilters().stream()).collect(Collectors.toSet()))
Expand Down Expand Up @@ -973,40 +968,101 @@ protected void configure() {

// register dynamic ActionType -> transportAction Map used by NodeClient
if (FeatureFlags.isEnabled(FeatureFlags.EXTENSIONS)) {
bind(DynamicActionRegistry.class).toInstance(extensionActions);
bind(DynamicActionRegistry.class).toInstance(dynamicActionRegistry);
}
}

public ActionFilters getActionFilters() {
return actionFilters;
}

@SuppressWarnings("unchecked")
public DynamicActionRegistry<? extends ActionRequest, ? extends ActionResponse> getDynamicActionRegistry() {
return dynamicActionRegistry;
}

public RestController getRestController() {
return restController;
}

public static class DynamicActionRegistry {
private final Map<String, ExtensionAction> registry = new ConcurrentHashMap<>();
/**
* The DynamicActionRegistry maintains a registry mapping {@link ExtensionAction} instances to {@link TransportAction} instances.
* <p>
* This class is modeled after {@link NamedRegistry} but provides both register and unregister capabilities.
*
* @opensearch.internal
*/
public static class DynamicActionRegistry<Request extends ActionRequest, Response extends ActionResponse> {
// the immutable map of injected transport actions
private Map<ActionType, TransportAction> actions = Collections.emptyMap();
// the dynamic map which can be updated over time
private final Map<ExtensionAction, ExtensionTransportAction> registry = new ConcurrentHashMap<>();

private ActionFilters actionFilters;
private ExtensionsManager extensionsManager;
private TransportService transportService;

/**
* Initialize the immutable actions in the registry.
*
* @param actions The injected map of {@link ActionType} to {@link TransportAction}
* @param actionFilters The action filters
* @param transportService The node's {@link TransportService}
* @param extensionsManager The instance of the {@link ExtensionsManager}
*/
public void initialize(
Map<ActionType, TransportAction> actions,
ActionFilters actionFilters,
TransportService transportService,
ExtensionsManager extensionsManager
) {
this.actions = actions;
this.actionFilters = actionFilters;
this.transportService = transportService;
this.extensionsManager = extensionsManager;
}

/**
* Add an {@link ExtensionAction} to the registry.
*
* @param extensionAction The action to add
*/
public void registerExtensionAction(ExtensionAction extensionAction) {
requireNonNull(extensionAction, "extension action is required");
String name = extensionAction.name();
requireNonNull(name, "name is required");
if (registry.putIfAbsent(name, extensionAction) != null) {
throw new IllegalArgumentException("extension action for name [" + name + "] already registered");
String uniqueId = extensionAction.uniqueId();
if (registry.containsKey(extensionAction)) {
throw new IllegalArgumentException("extension [" + uniqueId + "] action for [" + name + "] already registered");
}
registry.put(extensionAction, new ExtensionTransportAction(name, actionFilters, transportService, extensionsManager));
}

public void unregisterExtensionAction(String name) {
requireNonNull(name, "name is required");
if (registry.remove(name) == null) {
throw new IllegalArgumentException("extension action for name [" + name + "] was not registered");
/**
* Remove an {@link ExtensionAction} from the registry.
*
* @param extensionAction The action to remove
*/
public void unregisterExtensionAction(ExtensionAction extensionAction) {
requireNonNull(extensionAction, "extension action is required");
String name = extensionAction.name();
String uniqueId = extensionAction.uniqueId();
if (registry.remove(extensionAction) == null) {
throw new IllegalArgumentException("extension [" + uniqueId + "] action for [" + name + "] was not registered");
}
}

public ExtensionAction get(String name) {
requireNonNull(name, "name is required");
return registry.get(name);
/**
* Gets the {@link TransportAction} instance corresponding to the {@link ActionType} instance.
*
* @param action The {@link ActionType}. May be an {@link ExtensionAction}.
* @return the corresponding {@link TransportAction} if it is registered, null otherwise.
*/
@SuppressWarnings("unchecked")
public TransportAction<Request, Response> get(ActionType<?> action) {
if (action instanceof ExtensionAction) {
return (TransportAction<Request, Response>) registry.get((ExtensionAction) action);
}
return actions.get(action);
}
}
}
42 changes: 5 additions & 37 deletions server/src/main/java/org/opensearch/client/node/NodeClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@

import org.opensearch.action.ActionType;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionModule;
import org.opensearch.action.ActionModule.DynamicActionRegistry;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionResponse;
Expand All @@ -46,15 +45,12 @@
import org.opensearch.common.io.stream.NamedWriteableRegistry;
import org.opensearch.common.settings.Settings;
import org.opensearch.extensions.ExtensionsManager;
import org.opensearch.extensions.action.ExtensionAction;
import org.opensearch.extensions.action.ExtensionTransportAction;
import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskListener;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.RemoteClusterService;
import org.opensearch.transport.TransportService;

import java.util.Map;
import java.util.function.Supplier;

/**
Expand All @@ -64,8 +60,7 @@
*/
public class NodeClient extends AbstractClient {

private Map<ActionType, TransportAction> actions;
private DynamicActionRegistry extensionActions;
private DynamicActionRegistry actionRegistry;
private ActionFilters actionFilters;
/**
* The id of the local {@link DiscoveryNode}. Useful for generating task ids from tasks returned by
Expand All @@ -82,32 +77,17 @@ public NodeClient(Settings settings, ThreadPool threadPool) {
}

public void initialize(
Map<ActionType, TransportAction> actions,
DynamicActionRegistry actionRegistry,
Supplier<String> localNodeId,
RemoteClusterService remoteClusterService,
NamedWriteableRegistry namedWriteableRegistry
) {
this.actions = actions;
this.actionRegistry = actionRegistry;
this.localNodeId = localNodeId;
this.remoteClusterService = remoteClusterService;
this.namedWriteableRegistry = namedWriteableRegistry;
}

public void initialize(
Map<ActionType, TransportAction> actions,
ActionModule actionModule,
TransportService transportService,
ExtensionsManager extensionsManager,
Supplier<String> localNodeId,
NamedWriteableRegistry namedWriteableRegistry
) {
initialize(actions, localNodeId, transportService.getRemoteClusterService(), namedWriteableRegistry);
this.extensionActions = actionModule.getExtensionActions();
this.actionFilters = actionModule.getActionFilters();
this.transportService = transportService;
this.extensionsManager = extensionsManager;
}

@Override
public void close() {
// nothing really to do
Expand Down Expand Up @@ -163,23 +143,11 @@ public String getLocalNodeId() {
private <Request extends ActionRequest, Response extends ActionResponse> TransportAction<Request, Response> transportAction(
ActionType<Response> action
) {
if (actions == null) {
if (actionRegistry == null) {
throw new IllegalStateException("NodeClient has not been initialized");
}
// Get from action map if it exists
TransportAction<Request, Response> transportAction = actions.get(action);
// Fallback to dynamic extension action map
if (transportAction == null && extensionActions != null && action instanceof ExtensionAction) {
ExtensionAction extensionAction = extensionActions.get(action.name());
if (extensionAction != null) {
transportAction = (TransportAction<Request, Response>) new ExtensionTransportAction(
action.name(),
transportService,
actionFilters,
extensionsManager
);
}
}
TransportAction<Request, Response> transportAction = actionRegistry.get(action);
if (transportAction == null) {
throw new IllegalStateException("failed to find action [" + action + "] to execute");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,6 @@ public void initializeServicesAndRestHandler(
);
this.client = client;
this.extensionTransportActionsHandler = new ExtensionTransportActionsHandler(
this,
extensionIdMap,
transportService,
client,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@

import org.opensearch.action.ActionType;

import java.util.Objects;

/**
* An {@link ActionType} to be used in extension action transport handlers.
* An {@link ActionType} to be used in extension action transport handling.
*
* @opensearch.internal
*/
Expand All @@ -35,7 +37,24 @@ public class ExtensionAction extends ActionType<ExtensionActionResponse> {
*
* @return the uniqueId
*/
public String getUniqueId() {
public String uniqueId() {
return this.uniqueId;
}

@Override
public int hashCode() {
final int prime = 31;
int result = super.hashCode();
result = prime * result + Objects.hash(uniqueId);
return result;
}

@Override
public boolean equals(Object obj) {
if (this == obj) return true;
if (!super.equals(obj)) return false;
if (getClass() != obj.getClass()) return false;
ExtensionAction other = (ExtensionAction) obj;
return Objects.equals(uniqueId, other.uniqueId);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ public class ExtensionTransportAction extends TransportAction<ExtensionActionReq

public ExtensionTransportAction(
String actionName,
TransportService transportService,
ActionFilters actionFilters,
TransportService transportService,
ExtensionsManager extensionsManager
) {
super(actionName, actionFilters, transportService.getTaskManager());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionModule;
import org.opensearch.action.ActionModule.DynamicActionRegistry;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionResponse;
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.extensions.DiscoveryExtensionNode;
Expand Down Expand Up @@ -45,7 +47,7 @@ public class ExtensionTransportActionsHandler {
private final Map<String, DiscoveryExtensionNode> extensionIdMap;
private final TransportService transportService;
private final NodeClient client;
private final DynamicActionRegistry dynamicActionRegistry;
private final DynamicActionRegistry<? extends ActionRequest, ? extends ActionResponse> dynamicActionRegistry;

public ExtensionTransportActionsHandler(
Map<String, DiscoveryExtensionNode> extensionIdMap,
Expand All @@ -56,7 +58,7 @@ public ExtensionTransportActionsHandler(
this.extensionIdMap = extensionIdMap;
this.transportService = transportService;
this.client = client;
this.dynamicActionRegistry = actionModule.getExtensionActions();
this.dynamicActionRegistry = actionModule.getDynamicActionRegistry();
}

/**
Expand Down Expand Up @@ -115,10 +117,10 @@ public TransportResponse handleRegisterTransportActionsRequest(RegisterTransport
public TransportResponse handleTransportActionRequestFromExtension(TransportActionRequestFromExtension request) throws Exception {
String actionName = request.getAction();
String uniqueId = request.getUniqueId();
ExtensionAction extensionAction = new ExtensionAction(actionName, uniqueId);
final TransportActionResponseToExtension response = new TransportActionResponseToExtension(new byte[0]);
// Validate that this action has been registered
ExtensionAction extensionAction = dynamicActionRegistry.get(actionName);
if (extensionAction == null) {
if (dynamicActionRegistry.get(extensionAction) == null) {
byte[] responseBytes = ("Request failed: action [" + actionName + "] is not registered for extension [" + uniqueId + "].")
.getBytes(StandardCharsets.UTF_8);
response.setResponseBytes(responseBytes);
Expand Down
20 changes: 13 additions & 7 deletions server/src/main/java/org/opensearch/node/Node.java
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@
import org.opensearch.OpenSearchTimeoutException;
import org.opensearch.Version;
import org.opensearch.action.ActionModule;
import org.opensearch.action.ActionModule.DynamicActionRegistry;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionResponse;
import org.opensearch.action.ActionType;
import org.opensearch.action.admin.cluster.snapshots.status.TransportNodesSnapshotsStatus;
import org.opensearch.action.search.SearchExecutionStatsCollector;
Expand Down Expand Up @@ -1112,13 +1115,16 @@ protected Node(
resourcesToClose.addAll(pluginLifecycleComponents);
resourcesToClose.add(injector.getInstance(PeerRecoverySourceService.class));
this.pluginLifecycleComponents = Collections.unmodifiableList(pluginLifecycleComponents);
if (FeatureFlags.isEnabled(FeatureFlags.EXTENSIONS)) {
client.initialize(injector.getInstance(new Key<Map<ActionType, TransportAction>>() {
}), actionModule, transportService, extensionsManager, () -> clusterService.localNode().getId(), namedWriteableRegistry);
} else {
client.initialize(injector.getInstance(new Key<Map<ActionType, TransportAction>>() {
}), () -> clusterService.localNode().getId(), transportService.getRemoteClusterService(), namedWriteableRegistry);
}
DynamicActionRegistry<? extends ActionRequest, ? extends ActionResponse> dynamicActionRegistry = actionModule
.getDynamicActionRegistry();
dynamicActionRegistry.initialize(injector.getInstance(new Key<Map<ActionType, TransportAction>>() {
}), actionModule.getActionFilters(), transportService, extensionsManager);
client.initialize(
dynamicActionRegistry,
() -> clusterService.localNode().getId(),
transportService.getRemoteClusterService(),
namedWriteableRegistry
);
this.namedWriteableRegistry = namedWriteableRegistry;

logger.debug("initializing HTTP handlers ...");
Expand Down

0 comments on commit 7893511

Please sign in to comment.