diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index e6b6be2c62..55313bc986 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -119,8 +119,6 @@ public void dispatchTask( ActionListener listener ) { String modelId = request.getModelId(); - MLInput input = request.getMlInput(); - FunctionName algorithm = input.getAlgorithm(); try { ActionListener actionListener = ActionListener.wrap(node -> { if (clusterService.localNode().getId().equals(node.getId())) { @@ -133,9 +131,9 @@ public void dispatchTask( transportService.sendRequest(node, getTransportActionName(), request, getResponseHandler(listener)); } }, e -> { listener.onFailure(e); }); - String[] workerNodes = mlModelManager.getWorkerNodes(modelId, algorithm, true); + String[] workerNodes = mlModelManager.getWorkerNodes(modelId, functionName, true); if (workerNodes == null || workerNodes.length == 0) { - if (algorithm == FunctionName.TEXT_EMBEDDING || algorithm == FunctionName.REMOTE) { + if (functionName == FunctionName.TEXT_EMBEDDING || functionName == FunctionName.REMOTE) { listener .onFailure( new IllegalArgumentException( @@ -144,7 +142,7 @@ public void dispatchTask( ); return; } else { - workerNodes = nodeHelper.getEligibleNodeIds(algorithm); + workerNodes = nodeHelper.getEligibleNodeIds(functionName); } } mlTaskDispatcher.dispatchPredictTask(workerNodes, actionListener); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java index 5f18d974a7..0d0c594458 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java @@ -212,7 +212,7 @@ public void setup() throws IOException { public void testExecuteTask_OnLocalNode() { setupMocks(true, false, false, false); - taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener); verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); // verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); verify(mlTaskManager).add(any(MLTask.class)); @@ -220,10 +220,22 @@ public void testExecuteTask_OnLocalNode() { verify(mlTaskManager).remove(anyString()); } + public void testExecuteTask_OnLocalNode_RemoteModel() { + setupMocks(true, false, false, false); + + taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener); + verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(argumentCaptor.capture()); + assertTrue(argumentCaptor.getValue().getMessage().contains("Model not ready yet.")); + verify(mlTaskManager, never()).add(any(MLTask.class)); + verify(client, never()).get(any(), any()); + } + public void testExecuteTask_OnLocalNode_QueryInput() { setupMocks(true, false, false, false); - taskRunner.dispatchTask(FunctionName.REMOTE, requestWithQuery, transportService, listener); + taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithQuery, transportService, listener); verify(mlInputDatasetHandler).parseSearchQueryInput(any(), any()); // verify(mlInputDatasetHandler, never()).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); verify(mlTaskManager).add(any(MLTask.class)); @@ -234,7 +246,7 @@ public void testExecuteTask_OnLocalNode_QueryInput() { public void testExecuteTask_OnLocalNode_QueryInput_Failure() { setupMocks(true, true, false, false); - taskRunner.dispatchTask(FunctionName.REMOTE, requestWithQuery, transportService, listener); + taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithQuery, transportService, listener); verify(mlInputDatasetHandler).parseSearchQueryInput(any(), any()); // verify(mlInputDatasetHandler, never()).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); verify(mlTaskManager, never()).add(any(MLTask.class)); @@ -245,7 +257,7 @@ public void testExecuteTask_NoPermission() { setupMocks(true, true, false, false); threadContext.stashContext(); threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "test_user|test_role|test_tenant"); - taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener); verify(mlTaskManager).add(any(MLTask.class)); verify(mlTaskManager).remove(anyString()); verify(client).get(any(), any()); @@ -256,14 +268,14 @@ public void testExecuteTask_NoPermission() { public void testExecuteTask_OnRemoteNode() { setupMocks(false, false, false, false); - taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener); verify(transportService).sendRequest(eq(remoteNode), eq(MLPredictionTaskAction.NAME), eq(requestWithDataFrame), any()); } public void testExecuteTask_OnLocalNode_GetModelFail() { setupMocks(true, false, true, false); - taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener); verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); // verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); verify(mlTaskManager).add(any(MLTask.class)); @@ -277,7 +289,7 @@ public void testExecuteTask_OnLocalNode_NullModelIdException() { setupMocks(true, false, false, false); requestWithDataFrame = MLPredictionTaskRequest.builder().mlInput(mlInputWithDataFrame).build(); - taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener); verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); // verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); verify(mlTaskManager).add(any(MLTask.class)); @@ -291,7 +303,7 @@ public void testExecuteTask_OnLocalNode_NullModelIdException() { public void testExecuteTask_OnLocalNode_NullGetResponse() { setupMocks(true, false, false, true); - taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener); verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); // verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); verify(mlTaskManager).add(any(MLTask.class));