diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index 2718a1781..f7c9d2e00 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -113,6 +113,7 @@ public Collection createComponents( FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = new FlowFrameworkIndicesHandler(client, clusterService, encryptorUtils); WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory( settings, + threadPool, clusterService, client, mlClient, diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java index 44650061e..293674f52 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -38,6 +38,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; @@ -213,12 +214,19 @@ private void executeWorkflow(List workflowSequence, String workflow ); } catch (Exception ex) { logger.error("Provisioning failed for workflow: {}", workflowId, ex); + String errorMessage; + if (ex instanceof CancellationException) { + errorMessage = "A step in the workflow was cancelled."; + } else if (ex.getCause() != null) { + errorMessage = ex.getCause().getMessage(); + } else { + errorMessage = ex.getMessage(); + } flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( workflowId, Map.ofEntries( Map.entry(STATE_FIELD, State.FAILED), - // TODO: potentially improve the error message here - Map.entry(ERROR_FIELD, ex.getMessage()), + Map.entry(ERROR_FIELD, errorMessage), Map.entry(PROVISIONING_PROGRESS_FIELD, ProvisioningProgress.FAILED), Map.entry(PROVISION_END_TIME_FIELD, Instant.now().toEpochMilli()) ), diff --git a/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java b/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java index 6096f8f29..b68374252 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java @@ -19,15 +19,16 @@ import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; -import org.opensearch.ml.common.MLTaskState; +import org.opensearch.ml.common.MLTask; +import org.opensearch.threadpool.ThreadPool; import java.util.Map; import java.util.concurrent.CompletableFuture; -import java.util.stream.Stream; +import java.util.concurrent.atomic.AtomicInteger; +import static org.opensearch.flowframework.common.CommonValue.PROVISION_THREAD_POOL; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; -import static org.opensearch.flowframework.common.WorkflowResources.DEPLOY_MODEL; import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; /** @@ -39,20 +40,24 @@ public abstract class AbstractRetryableWorkflowStep implements WorkflowStep { protected volatile Integer maxRetry; private final MachineLearningNodeClient mlClient; private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + private ThreadPool threadPool; /** * Instantiates a new Retryable workflow step * @param settings Environment settings + * @param threadPool The OpenSearch thread pool * @param clusterService the cluster service * @param mlClient machine learning client * @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices */ protected AbstractRetryableWorkflowStep( Settings settings, + ThreadPool threadPool, ClusterService clusterService, MachineLearningNodeClient mlClient, FlowFrameworkIndicesHandler flowFrameworkIndicesHandler ) { + this.threadPool = threadPool; this.maxRetry = MAX_GET_TASK_REQUEST_RETRY.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_GET_TASK_REQUEST_RETRY, it -> maxRetry = it); this.mlClient = mlClient; @@ -65,7 +70,6 @@ protected AbstractRetryableWorkflowStep( * @param nodeId the workflow node id * @param future the workflow step future * @param taskId the ml task id - * @param retries the current number of request retries * @param workflowStep the workflow step which requires a retry get ml task functionality */ protected void retryableGetMlTask( @@ -73,74 +77,86 @@ protected void retryableGetMlTask( String nodeId, CompletableFuture future, String taskId, - int retries, String workflowStep ) { - mlClient.getTask(taskId, ActionListener.wrap(response -> { - MLTaskState currentState = response.getState(); - if (currentState != MLTaskState.COMPLETED) { - if (Stream.of(MLTaskState.FAILED, MLTaskState.COMPLETED_WITH_ERROR).anyMatch(x -> x == currentState)) { - // Model registration failed or completed with errors - String errorMessage = workflowStep + " failed with error : " + response.getError(); + AtomicInteger retries = new AtomicInteger(); + CompletableFuture.runAsync(() -> { + while (retries.getAndIncrement() < this.maxRetry && !future.isDone()) { + mlClient.getTask(taskId, ActionListener.wrap(response -> { + switch (response.getState()) { + case COMPLETED: + try { + String resourceName = getResourceByWorkflowStep(getName()); + String id = getResourceId(response); + logger.info("{} successful for {} and {} {}", workflowStep, workflowId, resourceName, id); + flowFrameworkIndicesHandler.updateResourceInStateIndex( + workflowId, + nodeId, + getName(), + id, + ActionListener.wrap(updateResponse -> { + logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex()); + future.complete( + new WorkflowData( + Map.ofEntries( + Map.entry(resourceName, id), + Map.entry(REGISTER_MODEL_STATUS, response.getState().name()) + ), + workflowId, + nodeId + ) + ); + }, exception -> { + logger.error("Failed to update new created resource", exception); + future.completeExceptionally( + new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) + ); + }) + ); + } catch (Exception e) { + logger.error("Failed to parse and update new created resource", e); + future.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + } + break; + case FAILED: + case COMPLETED_WITH_ERROR: + String errorMessage = workflowStep + " failed with error : " + response.getError(); + logger.error(errorMessage); + future.completeExceptionally(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST)); + break; + case CANCELLED: + errorMessage = workflowStep + " task was cancelled."; + logger.error(errorMessage); + future.completeExceptionally(new FlowFrameworkException(errorMessage, RestStatus.REQUEST_TIMEOUT)); + break; + default: + // Task started or running, do nothing + } + }, exception -> { + String errorMessage = workflowStep + " failed with error : " + exception.getMessage(); logger.error(errorMessage); future.completeExceptionally(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST)); - } else { - // Task still in progress, attempt retry - throw new IllegalStateException(workflowStep + " is not yet completed"); - } - } else { - try { - logger.info(workflowStep + " successful for {} and modelId {}", workflowId, response.getModelId()); - String resourceName = getResourceByWorkflowStep(getName()); - String id; - if (getName().equals(DEPLOY_MODEL.getWorkflowStep())) { - id = response.getModelId(); - } else { - id = response.getTaskId(); - } - flowFrameworkIndicesHandler.updateResourceInStateIndex( - workflowId, - nodeId, - getName(), - id, - ActionListener.wrap(updateResponse -> { - logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex()); - future.complete( - new WorkflowData( - Map.ofEntries( - Map.entry(resourceName, response.getModelId()), - Map.entry(REGISTER_MODEL_STATUS, response.getState().name()) - ), - workflowId, - nodeId - ) - ); - }, exception -> { - logger.error("Failed to update new created resource", exception); - future.completeExceptionally( - new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) - ); - }) - ); - } catch (Exception e) { - logger.error("Failed to parse and update new created resource", e); - future.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); - } - } - }, exception -> { - if (retries < maxRetry) { - // Sleep thread prior to retrying request + })); + // Wait long enough for future to possibly complete try { Thread.sleep(5000); - } catch (Exception e) { + } catch (InterruptedException e) { FutureUtils.cancel(future); + Thread.currentThread().interrupt(); } - retryableGetMlTask(workflowId, nodeId, future, taskId, retries + 1, workflowStep); - } else { - logger.error("Failed to retrieve" + workflowStep + ",maximum retries exceeded"); - future.completeExceptionally(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))); } - })); + if (!future.isDone()) { + String errorMessage = workflowStep + " did not complete after " + maxRetry + " retries"; + logger.error(errorMessage); + future.completeExceptionally(new FlowFrameworkException(errorMessage, RestStatus.REQUEST_TIMEOUT)); + } + }, threadPool.executor(PROVISION_THREAD_POOL)); } + /** + * Returns the resourceId associated with the task + * @param response The Task response + * @return the resource ID, such as a model id + */ + protected abstract String getResourceId(MLTask response); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java index b8447d5e0..21f1eaca5 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java @@ -18,7 +18,9 @@ import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; +import org.opensearch.threadpool.ThreadPool; import java.util.Collections; import java.util.Map; @@ -42,17 +44,19 @@ public class DeployModelStep extends AbstractRetryableWorkflowStep { /** * Instantiate this class * @param settings The OpenSearch settings + * @param threadPool The OpenSearch thread pool * @param clusterService The cluster service * @param mlClient client to instantiate MLClient * @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices */ public DeployModelStep( Settings settings, + ThreadPool threadPool, ClusterService clusterService, MachineLearningNodeClient mlClient, FlowFrameworkIndicesHandler flowFrameworkIndicesHandler ) { - super(settings, clusterService, mlClient, flowFrameworkIndicesHandler); + super(settings, threadPool, clusterService, mlClient, flowFrameworkIndicesHandler); this.mlClient = mlClient; this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; } @@ -74,7 +78,7 @@ public void onResponse(MLDeployModelResponse mlDeployModelResponse) { String taskId = mlDeployModelResponse.getTaskId(); // Attempt to retrieve the model ID - retryableGetMlTask(currentNodeInputs.getWorkflowId(), currentNodeId, deployModelFuture, taskId, 0, "Deploy model"); + retryableGetMlTask(currentNodeInputs.getWorkflowId(), currentNodeId, deployModelFuture, taskId, "Deploy model"); } @Override @@ -105,6 +109,11 @@ public void onFailure(Exception e) { return deployModelFuture; } + @Override + protected String getResourceId(MLTask response) { + return response.getModelId(); + } + @Override public String getName() { return NAME; diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java index cb0442cbc..103a6d643 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java @@ -18,6 +18,8 @@ import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; @@ -26,6 +28,7 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; +import org.opensearch.threadpool.ThreadPool; import java.util.Map; import java.util.Set; @@ -35,6 +38,7 @@ import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; import static org.opensearch.flowframework.common.CommonValue.EMBEDDING_DIMENSION; import static org.opensearch.flowframework.common.CommonValue.FRAMEWORK_TYPE; +import static org.opensearch.flowframework.common.CommonValue.FUNCTION_NAME; import static org.opensearch.flowframework.common.CommonValue.MODEL_CONTENT_HASH_VALUE; import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT; import static org.opensearch.flowframework.common.CommonValue.MODEL_TYPE; @@ -60,17 +64,19 @@ public class RegisterLocalModelStep extends AbstractRetryableWorkflowStep { /** * Instantiate this class * @param settings The OpenSearch settings + * @param threadPool The OpenSearch thread pool * @param clusterService The cluster service * @param mlClient client to instantiate MLClient * @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices */ public RegisterLocalModelStep( Settings settings, + ThreadPool threadPool, ClusterService clusterService, MachineLearningNodeClient mlClient, FlowFrameworkIndicesHandler flowFrameworkIndicesHandler ) { - super(settings, clusterService, mlClient, flowFrameworkIndicesHandler); + super(settings, threadPool, clusterService, mlClient, flowFrameworkIndicesHandler); this.mlClient = mlClient; this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; } @@ -98,7 +104,6 @@ public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { currentNodeId, registerLocalModelFuture, taskId, - 0, "Local model registration" ); } @@ -120,7 +125,7 @@ public void onFailure(Exception e) { MODEL_CONTENT_HASH_VALUE, URL ); - Set optionalKeys = Set.of(DESCRIPTION_FIELD, MODEL_GROUP_ID, ALL_CONFIG); + Set optionalKeys = Set.of(DESCRIPTION_FIELD, MODEL_GROUP_ID, ALL_CONFIG, FUNCTION_NAME); try { Map inputs = ParseUtils.getInputsFromPreviousSteps( @@ -142,6 +147,7 @@ public void onFailure(Exception e) { FrameworkType frameworkType = FrameworkType.from((String) inputs.get(FRAMEWORK_TYPE)); String allConfig = (String) inputs.get(ALL_CONFIG); String url = (String) inputs.get(URL); + String functionName = (String) inputs.get(FUNCTION_NAME); // Create Model configuration TextEmbeddingModelConfigBuilder modelConfigBuilder = TextEmbeddingModelConfig.builder() @@ -158,13 +164,18 @@ public void onFailure(Exception e) { .modelName(modelName) .version(modelVersion) .modelFormat(modelFormat) - .modelGroupId(modelGroupId) .hashValue(modelContentHashValue) .modelConfig(modelConfig) .url(url); if (description != null) { mlInputBuilder.description(description); } + if (modelGroupId != null) { + mlInputBuilder.modelGroupId(modelGroupId); + } + if (functionName != null) { + mlInputBuilder.functionName(FunctionName.from(functionName)); + } MLRegisterModelInput mlInput = mlInputBuilder.build(); @@ -175,6 +186,11 @@ public void onFailure(Exception e) { return registerLocalModelFuture; } + @Override + protected String getResourceId(MLTask response) { + return response.getModelId(); + } + @Override public String getName() { return NAME; diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index c2e55b100..14a9d5dbd 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -15,6 +15,7 @@ import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.threadpool.ThreadPool; import java.util.HashMap; import java.util.Map; @@ -31,6 +32,7 @@ public class WorkflowStepFactory { * Instantiate this class. * * @param settings The OpenSearch settings + * @param threadPool The OpenSearch thread pool * @param clusterService The OpenSearch cluster service * @param client The OpenSearch client steps can use * @param mlClient Machine Learning client to perform ml operations @@ -38,6 +40,7 @@ public class WorkflowStepFactory { */ public WorkflowStepFactory( Settings settings, + ThreadPool threadPool, ClusterService clusterService, Client client, MachineLearningNodeClient mlClient, @@ -48,11 +51,14 @@ public WorkflowStepFactory( stepMap.put(CreateIngestPipelineStep.NAME, () -> new CreateIngestPipelineStep(client, flowFrameworkIndicesHandler)); stepMap.put( RegisterLocalModelStep.NAME, - () -> new RegisterLocalModelStep(settings, clusterService, mlClient, flowFrameworkIndicesHandler) + () -> new RegisterLocalModelStep(settings, threadPool, clusterService, mlClient, flowFrameworkIndicesHandler) ); stepMap.put(RegisterRemoteModelStep.NAME, () -> new RegisterRemoteModelStep(mlClient, flowFrameworkIndicesHandler)); stepMap.put(DeleteModelStep.NAME, () -> new DeleteModelStep(mlClient)); - stepMap.put(DeployModelStep.NAME, () -> new DeployModelStep(settings, clusterService, mlClient, flowFrameworkIndicesHandler)); + stepMap.put( + DeployModelStep.NAME, + () -> new DeployModelStep(settings, threadPool, clusterService, mlClient, flowFrameworkIndicesHandler) + ); stepMap.put(UndeployModelStep.NAME, () -> new UndeployModelStep(mlClient)); stepMap.put(CreateConnectorStep.NAME, () -> new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler)); stepMap.put(DeleteConnectorStep.NAME, () -> new DeleteConnectorStep(mlClient)); diff --git a/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java b/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java index a6cbcf9bd..73ddf0349 100644 --- a/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java +++ b/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java @@ -20,6 +20,7 @@ import org.opensearch.flowframework.workflow.WorkflowStepFactory; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; import java.io.IOException; import java.util.ArrayList; @@ -71,6 +72,7 @@ public void testFailedParseWorkflowValidator() throws IOException { public void testWorkflowStepFactoryHasValidators() throws IOException { + ThreadPool threadPool = mock(ThreadPool.class); ClusterService clusterService = mock(ClusterService.class); ClusterAdminClient clusterAdminClient = mock(ClusterAdminClient.class); AdminClient adminClient = mock(AdminClient.class); @@ -89,6 +91,7 @@ public void testWorkflowStepFactoryHasValidators() throws IOException { WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory( Settings.EMPTY, + threadPool, clusterService, client, mlClient, diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java index 0d1d16c74..10547d972 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java @@ -15,6 +15,7 @@ import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.rest.RestStatus; @@ -26,6 +27,10 @@ import org.opensearch.ml.common.MLTaskType; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.FixedExecutorBuilder; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; +import org.junit.AfterClass; import java.io.IOException; import java.util.Collections; @@ -33,6 +38,7 @@ import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -41,6 +47,8 @@ import org.mockito.MockitoAnnotations; import static org.opensearch.action.DocWriteResponse.Result.UPDATED; +import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX; +import static org.opensearch.flowframework.common.CommonValue.PROVISION_THREAD_POOL; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; @@ -57,6 +65,7 @@ @ThreadLeakScope(ThreadLeakScope.Scope.NONE) public class DeployModelStepTests extends OpenSearchTestCase { + private static TestThreadPool testThreadPool; private WorkflowData inputData = WorkflowData.EMPTY; @Mock @@ -77,14 +86,34 @@ public void setUp() throws Exception { Stream.of(MAX_GET_TASK_REQUEST_RETRY) ).collect(Collectors.toSet()); - // Set max request retry setting to 0 to avoid sleeping the thread during unit test failure cases - Settings testMaxRetrySetting = Settings.builder().put(MAX_GET_TASK_REQUEST_RETRY.getKey(), 0).build(); + // Set max request retry setting to 1 to limit sleeping the thread to one retry iteration + Settings testMaxRetrySetting = Settings.builder().put(MAX_GET_TASK_REQUEST_RETRY.getKey(), 1).build(); ClusterSettings clusterSettings = new ClusterSettings(testMaxRetrySetting, settingsSet); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - this.deployModel = new DeployModelStep(testMaxRetrySetting, clusterService, machineLearningNodeClient, flowFrameworkIndicesHandler); - this.inputData = new WorkflowData(Map.ofEntries(Map.entry(MODEL_ID, "modelId")), "test-id", "test-node-id"); + testThreadPool = new TestThreadPool( + DeployModelStepTests.class.getName(), + new FixedExecutorBuilder( + Settings.EMPTY, + PROVISION_THREAD_POOL, + OpenSearchExecutors.allocatedProcessors(Settings.EMPTY), + 100, + FLOW_FRAMEWORK_THREAD_POOL_PREFIX + PROVISION_THREAD_POOL + ) + ); + this.deployModel = new DeployModelStep( + testMaxRetrySetting, + testThreadPool, + clusterService, + machineLearningNodeClient, + flowFrameworkIndicesHandler + ); + this.inputData = new WorkflowData(Map.ofEntries(Map.entry("model_id", "modelId")), "test-id", "test-node-id"); + } + @AfterClass + public static void cleanup() { + ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS); } public void testDeployModel() throws ExecutionException, InterruptedException, IOException { @@ -140,19 +169,17 @@ public void testDeployModel() throws ExecutionException, InterruptedException, I Collections.emptyMap() ); + future.join(); + verify(machineLearningNodeClient, times(1)).deploy(any(String.class), any()); verify(machineLearningNodeClient, times(1)).getTask(any(), any()); - assertTrue(future.isDone()); - assertFalse(future.isCompletedExceptionally()); assertEquals(modelId, future.get().getContent().get(MODEL_ID)); assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); } public void testDeployModelFailure() { - String modelId = "modelId"; - String taskId = "taskId"; @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); @@ -171,13 +198,12 @@ public void testDeployModelFailure() { verify(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); - assertTrue(future.isCompletedExceptionally()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); assertTrue(ex.getCause() instanceof FlowFrameworkException); assertEquals("Failed to deploy model", ex.getCause().getMessage()); } - public void testDeployModelTaskFailure() throws IOException { + public void testDeployModelTaskFailure() throws IOException, InterruptedException, ExecutionException { String modelId = "modelId"; String taskId = "taskId"; @@ -225,11 +251,8 @@ public void testDeployModelTaskFailure() throws IOException { Collections.emptyMap() ); - assertTrue(future.isDone()); - assertTrue(future.isCompletedExceptionally()); ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass()); assertTrue(ex.getCause() instanceof FlowFrameworkException); assertEquals("Deploy model failed with error : " + testErrorMessage, ex.getCause().getMessage()); - } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java index 366d798db..518b904e3 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java @@ -15,6 +15,7 @@ import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.shard.ShardId; import org.opensearch.flowframework.exception.FlowFrameworkException; @@ -25,12 +26,17 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.FixedExecutorBuilder; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; +import org.junit.AfterClass; import java.util.Collections; import java.util.Map; import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -38,6 +44,8 @@ import org.mockito.MockitoAnnotations; import static org.opensearch.action.DocWriteResponse.Result.UPDATED; +import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX; +import static org.opensearch.flowframework.common.CommonValue.PROVISION_THREAD_POOL; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; @@ -54,6 +62,7 @@ @ThreadLeakScope(ThreadLeakScope.Scope.NONE) public class RegisterLocalModelStepTests extends OpenSearchTestCase { + private static TestThreadPool testThreadPool; private RegisterLocalModelStep registerLocalModelStep; private WorkflowData workflowData; private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; @@ -72,13 +81,24 @@ public void setUp() throws Exception { Stream.of(MAX_GET_TASK_REQUEST_RETRY) ).collect(Collectors.toSet()); - // Set max request retry setting to 0 to avoid sleeping the thread during unit test failure cases - Settings testMaxRetrySetting = Settings.builder().put(MAX_GET_TASK_REQUEST_RETRY.getKey(), 0).build(); + // Set max request retry setting to 1 to limit sleeping the thread to one retry iteration + Settings testMaxRetrySetting = Settings.builder().put(MAX_GET_TASK_REQUEST_RETRY.getKey(), 1).build(); ClusterSettings clusterSettings = new ClusterSettings(testMaxRetrySetting, settingsSet); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + testThreadPool = new TestThreadPool( + RegisterLocalModelStepTests.class.getName(), + new FixedExecutorBuilder( + Settings.EMPTY, + PROVISION_THREAD_POOL, + OpenSearchExecutors.allocatedProcessors(Settings.EMPTY), + 100, + FLOW_FRAMEWORK_THREAD_POOL_PREFIX + PROVISION_THREAD_POOL + ) + ); this.registerLocalModelStep = new RegisterLocalModelStep( testMaxRetrySetting, + testThreadPool, clusterService, machineLearningNodeClient, flowFrameworkIndicesHandler @@ -89,6 +109,7 @@ public void setUp() throws Exception { Map.entry("name", "xyz"), Map.entry("version", "1.0.0"), Map.entry("description", "description"), + Map.entry("function_name", "SPARSE_TOKENIZE"), Map.entry("model_format", "TORCH_SCRIPT"), Map.entry(MODEL_GROUP_ID, "abcdefg"), Map.entry("model_content_hash_value", "aiwoeifjoaijeofiwe"), @@ -103,6 +124,11 @@ public void setUp() throws Exception { } + @AfterClass + public static void cleanup() { + ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS); + } + public void testRegisterLocalModelSuccess() throws Exception { String taskId = "abcd"; @@ -152,15 +178,14 @@ public void testRegisterLocalModelSuccess() throws Exception { Collections.emptyMap(), Collections.emptyMap() ); - ; + + future.join(); + verify(machineLearningNodeClient, times(1)).register(any(MLRegisterModelInput.class), any()); verify(machineLearningNodeClient, times(1)).getTask(any(), any()); - assertTrue(future.isDone()); - assertFalse(future.isCompletedExceptionally()); assertEquals(modelId, future.get().getContent().get(MODEL_ID)); assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); - } public void testRegisterLocalModelFailure() { @@ -177,8 +202,7 @@ public void testRegisterLocalModelFailure() { Collections.emptyMap(), Collections.emptyMap() ); - assertTrue(future.isDone()); - assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass()); assertTrue(ex.getCause() instanceof FlowFrameworkException); assertEquals("test", ex.getCause().getMessage()); @@ -227,12 +251,10 @@ public void testRegisterLocalModelTaskFailure() { Collections.emptyMap(), Collections.emptyMap() ); - assertTrue(future.isDone()); - assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass()); assertTrue(ex.getCause() instanceof FlowFrameworkException); assertEquals("Local model registration failed with error : " + testErrorMessage, ex.getCause().getMessage()); - } public void testMissingInputs() { diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index b2bf1a6e7..41a65141f 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -22,6 +22,7 @@ import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.common.FlowFrameworkSettings; @@ -35,6 +36,7 @@ import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.plugins.PluginInfo; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.FixedExecutorBuilder; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; import org.junit.AfterClass; @@ -50,6 +52,8 @@ import java.util.stream.Collectors; import java.util.stream.Stream; +import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX; +import static org.opensearch.flowframework.common.CommonValue.PROVISION_THREAD_POOL; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; @@ -111,9 +115,19 @@ public static void setup() throws IOException { when(client.admin()).thenReturn(adminClient); - testThreadPool = new TestThreadPool(WorkflowProcessSorterTests.class.getName()); + testThreadPool = new TestThreadPool( + WorkflowProcessSorterTests.class.getName(), + new FixedExecutorBuilder( + Settings.EMPTY, + PROVISION_THREAD_POOL, + OpenSearchExecutors.allocatedProcessors(Settings.EMPTY), + 100, + FLOW_FRAMEWORK_THREAD_POOL_PREFIX + PROVISION_THREAD_POOL + ) + ); WorkflowStepFactory factory = new WorkflowStepFactory( Settings.EMPTY, + testThreadPool, clusterService, client, mlClient,