diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index 2117b220b..98c32b189 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -131,7 +131,7 @@ public void inferenceSentences( @NonNull final Map inputObjects, @NonNull final ActionListener> listener ) { - inferenceSentencesWithRetry(TARGET_RESPONSE_FILTERS, modelId, inputObjects, 0, listener); + retryableInferenceSentencesWithSingleVectorResult(TARGET_RESPONSE_FILTERS, modelId, inputObjects, 0, listener); } private void retryableInferenceSentencesWithMapResult( @@ -140,7 +140,7 @@ private void retryableInferenceSentencesWithMapResult( final int retryTime, final ActionListener>> listener ) { - MLInput mlInput = createMLInput(null, inputText); + MLInput mlInput = createMLTextInput(null, inputText); mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { final List> result = buildMapResultFromResponse(mlOutput); listener.onResponse(result); @@ -181,12 +181,6 @@ private MLInput createMLTextInput(final List targetResponseFilters, List return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset); } - private MLInput createMLInput(final List targetResponseFilters, List inputText) { - final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null); - final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter); - return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset); - } - private List> buildVectorFromResponse(MLOutput mlOutput) { final List> vector = new ArrayList<>(); final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput; @@ -223,8 +217,8 @@ private List buildSingleVectorFromResponse(MLOutput mlOutput) { return vector.isEmpty() ? new ArrayList<>() : vector.get(0); } - private void inferenceSentencesWithRetry( - @NonNull final List targetResponseFilters, + private void retryableInferenceSentencesWithSingleVectorResult( + final List targetResponseFilters, final String modelId, final Map inputObjects, final int retryTime, @@ -238,7 +232,7 @@ private void inferenceSentencesWithRetry( }, e -> { if (RetryUtil.shouldRetry(e, retryTime)) { final int retryTimeAdd = retryTime + 1; - inferenceSentencesWithRetry(targetResponseFilters, modelId, inputObjects, retryTimeAdd, listener); + retryableInferenceSentencesWithSingleVectorResult(targetResponseFilters, modelId, inputObjects, retryTimeAdd, listener); } else { listener.onFailure(e); }