Skip to content

Commit

Permalink
Simplify ExtensionTransportActionHandler, fix compile issues
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed Mar 17, 2023
1 parent 7c8061e commit b22dcad
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,8 @@ public ExtensionsManager(Settings settings, Path extensionsPath) throws IOExcept
* Initializes the {@link RestActionsRequestHandler}, {@link TransportService}, {@link ClusterService} and environment settings. This is called during Node bootstrap.
* Lists/maps of extensions have already been initialized but not yet populated.
*
* @param restController The RestController on which to register Rest Actions.
* @param actionModule The ActionModule with the RestController and DynamicActionModule
* @param settingsModule The module that binds the provided settings to interface.
* @param actionsModule The module that binds transport actions.
* @param transportService The Node's transport service.
* @param clusterService The Node's cluster service.
* @param initialEnvironmentSettings The finalized view of settings for the Environment
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionModule;
import org.opensearch.action.ActionModule.DynamicActionRegistry;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionResponse;
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.extensions.DiscoveryExtensionNode;
Expand Down Expand Up @@ -43,7 +41,9 @@
*/
public class ExtensionTransportActionsHandler {
private static final Logger logger = LogManager.getLogger(ExtensionTransportActionsHandler.class);
private final Map<String, DiscoveryExtensionNode> actionsMap = new ConcurrentHashMap<>();
// Map of action name to Extension unique ID, populated locally
private final Map<String, String> actionToIdMap = new ConcurrentHashMap<>();
// Map of Extension unique ID to Extension Node, populated in Extensions Manager
private final Map<String, DiscoveryExtensionNode> extensionIdMap;
private final TransportService transportService;
private final NodeClient client;
Expand All @@ -65,16 +65,16 @@ public ExtensionTransportActionsHandler(
* Method to register actions for extensions.
*
* @param action to be registered.
* @param extension for which action is being registered.
* @param uniqueId id of extension for which action is being registered.
* @throws IllegalArgumentException when action being registered already is registered.
*/
void registerAction(String action, DiscoveryExtensionNode extension) throws IllegalArgumentException {
void registerAction(String action, String uniqueId) throws IllegalArgumentException {
// Register the action in this handler so it knows which extension owns it
if (actionsMap.putIfAbsent(action, extension) != null) {
if (actionToIdMap.putIfAbsent(action, uniqueId) != null) {
throw new IllegalArgumentException("The action [" + action + "] you are trying to register is already registered");
}
// Register the action in the action module's extension actions map
dynamicActionRegistry.registerExtensionAction(new ExtensionAction(action, extension.getId()));
dynamicActionRegistry.registerExtensionAction(new ExtensionAction(action, uniqueId));
}

/**
Expand All @@ -84,7 +84,7 @@ void registerAction(String action, DiscoveryExtensionNode extension) throws Ille
* @return the extension.
*/
public DiscoveryExtensionNode getExtension(String action) {
return actionsMap.get(action);
return extensionIdMap.get(actionToIdMap.get(action));
}

/**
Expand All @@ -95,10 +95,9 @@ public DiscoveryExtensionNode getExtension(String action) {
*/
public TransportResponse handleRegisterTransportActionsRequest(RegisterTransportActionsRequest transportActionsRequest) {
logger.debug("Register Transport Actions request recieved {}", transportActionsRequest);
DiscoveryExtensionNode extension = extensionIdMap.get(transportActionsRequest.getUniqueId());
try {
for (String action : transportActionsRequest.getTransportActions()) {
registerAction(action, extension);
registerAction(action, transportActionsRequest.getUniqueId());
}
} catch (Exception e) {
logger.error("Could not register Transport Action " + e);
Expand All @@ -116,7 +115,7 @@ public TransportResponse handleRegisterTransportActionsRequest(RegisterTransport
*/
public TransportResponse handleTransportActionRequestFromExtension(TransportActionRequestFromExtension request) throws Exception {
String actionName = request.getAction();
String uniqueId = request.getUniqueId();
String uniqueId = actionToIdMap.get(actionName);
ExtensionAction extensionAction = new ExtensionAction(actionName, uniqueId);
final TransportActionResponseToExtension response = new TransportActionResponseToExtension(new byte[0]);
// Validate that this action has been registered
Expand Down Expand Up @@ -177,7 +176,7 @@ public void onFailure(Exception exp) {
* @throws InterruptedException when message transport fails.
*/
public ExtensionActionResponse sendTransportRequestToExtension(ExtensionActionRequest request) throws Exception {
DiscoveryExtensionNode extension = actionsMap.get(request.getAction());
DiscoveryExtensionNode extension = getExtension(request.getAction());
if (extension == null) {
throw new ActionNotFoundTransportException(request.getAction());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
package org.opensearch.client.node;

import org.opensearch.action.ActionType;
import org.opensearch.action.ActionModule.DynamicActionRegistry;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.support.ActionFilters;
Expand All @@ -57,7 +58,9 @@ protected Client buildClient(Settings headersSettings, ActionType[] testedAction
Settings settings = HEADER_SETTINGS;
Actions actions = new Actions(settings, threadPool, testedActions);
NodeClient client = new NodeClient(settings, threadPool);
client.initialize(actions, () -> "test", null, new NamedWriteableRegistry(Collections.emptyList()));
DynamicActionRegistry dynamicActionRegistry = new DynamicActionRegistry();
dynamicActionRegistry.initialize(actions, EMPTY_FILTERS, null, null);
client.initialize(dynamicActionRegistry, () -> "test", null, new NamedWriteableRegistry(Collections.emptyList()));
return client;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.mock;
Expand All @@ -40,6 +41,7 @@
import org.junit.After;
import org.junit.Before;
import org.opensearch.Version;
import org.opensearch.action.ActionModule;
import org.opensearch.action.admin.cluster.state.ClusterStateResponse;
import org.opensearch.client.node.NodeClient;
import org.opensearch.cluster.ClusterSettingsResponse;
Expand Down Expand Up @@ -94,6 +96,7 @@ public class ExtensionsManagerTests extends OpenSearchTestCase {

private FeatureFlagSetter featureFlagSetter;
private TransportService transportService;
private ActionModule actionModule;
private RestController restController;
private SettingsModule settingsModule;
private ClusterService clusterService;
Expand Down Expand Up @@ -157,13 +160,15 @@ public void setup() throws Exception {
null,
Collections.emptySet()
);
actionModule = mock(ActionModule.class);
restController = new RestController(
emptySet(),
null,
new NodeClient(Settings.EMPTY, threadPool),
new NoneCircuitBreakerService(),
new UsageService()
);
when(actionModule.getRestController()).thenReturn(restController);
settingsModule = new SettingsModule(Settings.EMPTY, emptyList(), emptyList(), emptySet());
clusterService = createClusterService(threadPool);

Expand Down Expand Up @@ -732,7 +737,7 @@ public void testRegisterHandler() throws Exception {
)
);
extensionsManager.initializeServicesAndRestHandler(
restController,
actionModule,
settingsModule,
mockTransportService,
clusterService,
Expand Down Expand Up @@ -812,7 +817,7 @@ private void initialize(ExtensionsManager extensionsManager) {
transportService.start();
transportService.acceptIncomingRequests();
extensionsManager.initializeServicesAndRestHandler(
restController,
actionModule,
settingsModule,
transportService,
clusterService,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import org.junit.After;
import org.junit.Before;
import org.opensearch.Version;
import org.opensearch.action.ActionModule;
import org.opensearch.action.ActionModule.DynamicActionRegistry;
import org.opensearch.client.node.NodeClient;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.common.io.stream.NamedWriteableRegistry;
Expand Down Expand Up @@ -40,6 +42,9 @@
import java.util.Set;
import java.util.concurrent.TimeUnit;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import static java.util.Collections.emptyMap;
import static java.util.Collections.emptySet;

Expand Down Expand Up @@ -89,10 +94,14 @@ public void setup() throws Exception {
Collections.emptyList()
);
client = new NoOpNodeClient(this.getTestName());
ActionModule mockActionModule = mock(ActionModule.class);
DynamicActionRegistry dynamicActionRegistry = new DynamicActionRegistry();
when(mockActionModule.getDynamicActionRegistry()).thenReturn(dynamicActionRegistry);
extensionTransportActionsHandler = new ExtensionTransportActionsHandler(
Map.of("uniqueid1", discoveryExtensionNode),
transportService,
client
client,
mockActionModule
);
}

Expand All @@ -107,11 +116,14 @@ public void tearDown() throws Exception {

public void testRegisterAction() {
String action = "test-action";
extensionTransportActionsHandler.registerAction(action, discoveryExtensionNode);
extensionTransportActionsHandler.registerAction(action, discoveryExtensionNode.getId());
assertEquals(discoveryExtensionNode, extensionTransportActionsHandler.getExtension(action));

// Test duplicate action registration
expectThrows(IllegalArgumentException.class, () -> extensionTransportActionsHandler.registerAction(action, discoveryExtensionNode));
expectThrows(
IllegalArgumentException.class,
() -> extensionTransportActionsHandler.registerAction(action, discoveryExtensionNode.getId())
);
assertEquals(discoveryExtensionNode, extensionTransportActionsHandler.getExtension(action));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.ActionModule.DynamicActionRegistry;
import org.opensearch.action.admin.indices.validate.query.ValidateQueryAction;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.TransportAction;
Expand Down Expand Up @@ -96,7 +97,9 @@ protected void doExecute(Task task, ActionRequest request, ActionListener listen
final Map<ActionType, TransportAction> actions = new HashMap<>();
actions.put(ValidateQueryAction.INSTANCE, transportAction);

client.initialize(actions, () -> "local", null, new NamedWriteableRegistry(Collections.emptyList()));
DynamicActionRegistry dynamicActionRegistry = new DynamicActionRegistry();
dynamicActionRegistry.initialize(actions, null, null, null);
client.initialize(dynamicActionRegistry, () -> "local", null, new NamedWriteableRegistry(Collections.emptyList()));
controller.registerHandler(action);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.opensearch.ExceptionsHelper;
import org.opensearch.Version;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionModule.DynamicActionRegistry;
import org.opensearch.action.ActionType;
import org.opensearch.action.RequestValidators;
import org.opensearch.action.StepListener;
Expand Down Expand Up @@ -2191,8 +2192,10 @@ public void onFailure(final Exception e) {
indexNameExpressionResolver
)
);
DynamicActionRegistry dynamicActionRegistry = new DynamicActionRegistry();
dynamicActionRegistry.initialize(actions, actionFilters, transportService, null);
client.initialize(
actions,
dynamicActionRegistry,
() -> clusterService.localNode().getId(),
transportService.getRemoteClusterService(),
new NamedWriteableRegistry(Collections.emptyList())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@

import org.opensearch.OpenSearchException;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionModule.DynamicActionRegistry;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionResponse;
import org.opensearch.action.ActionType;
import org.opensearch.action.support.TransportAction;
import org.opensearch.client.Client;
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.io.stream.NamedWriteableRegistry;
Expand All @@ -48,7 +48,6 @@
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.RemoteClusterService;

import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;

Expand Down Expand Up @@ -86,7 +85,7 @@ public <Request extends ActionRequest, Response extends ActionResponse> void doE

@Override
public void initialize(
Map<ActionType, TransportAction> actions,
DynamicActionRegistry dynamicActionRegistry,
Supplier<String> localNodeId,
RemoteClusterService remoteClusterService,
NamedWriteableRegistry namedWriteableRegistry
Expand Down

0 comments on commit b22dcad

Please sign in to comment.