From 8cbac4cd86d795d8ad3e0ea1bcab0cf258d8f55e Mon Sep 17 00:00:00 2001 From: Austin Lee Date: Tue, 5 Sep 2023 14:21:53 -0700 Subject: [PATCH] Address comments. Signed-off-by: Austin Lee --- .../ml/plugin/MachineLearningPlugin.java | 41 ++++++++++++------- .../client/ConversationalMemoryClient.java | 6 +-- 2 files changed, 29 insertions(+), 18 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index a7c6cf1d65..5587099797 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -9,8 +9,9 @@ import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; import java.nio.file.Path; +import java.util.ArrayList; import java.util.Collection; -import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; @@ -680,32 +681,42 @@ public List> getSettings() { @Override public List> getSearchExts() { - return ragSearchPipelineEnabled - ? List - .of( + List> searchExts = new ArrayList<>(); + + if (ragSearchPipelineEnabled) { + searchExts + .add( new SearchPlugin.SearchExtSpec<>( GenerativeQAParamExtBuilder.PARAMETER_NAME, input -> new GenerativeQAParamExtBuilder(input), parser -> GenerativeQAParamExtBuilder.parse(parser) ) - ) - // Feature not enabled - : Collections.emptyList(); + ); + } + + return searchExts; } @Override public Map> getRequestProcessors(Parameters parameters) { - return ragSearchPipelineEnabled - ? Map.of(GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE, new GenerativeQARequestProcessor.Factory()) - // Feature not enabled - : Collections.emptyMap(); + Map> requestProcessors = new HashMap<>(); + + if (ragSearchPipelineEnabled) { + requestProcessors.put(GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE, new GenerativeQARequestProcessor.Factory()); + } + + return requestProcessors; } @Override public Map> getResponseProcessors(Parameters parameters) { - return ragSearchPipelineEnabled - ? Map.of(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, new GenerativeQAResponseProcessor.Factory(this.client)) - // Feature not enabled - : Collections.emptyMap(); + Map> responseProcessors = new HashMap<>(); + + if (ragSearchPipelineEnabled) { + responseProcessors + .put(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, new GenerativeQAResponseProcessor.Factory(this.client)); + } + + return responseProcessors; } } diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClient.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClient.java index d95461bcd2..84a32b2368 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClient.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClient.java @@ -74,7 +74,7 @@ public List getInteractions(String conversationId, int lastN) { List interactions = new ArrayList<>(); int from = 0; - boolean done = false; + boolean allInteractionsFetched = false; int maxResults = lastN; do { GetInteractionsResponse response = @@ -92,8 +92,8 @@ public List getInteractions(String conversationId, int lastN) { break; } log.info("Interactions: {}, from: {}, maxResults: {}", interactions, from, maxResults); - done = !response.hasMorePages(); - } while (from < lastN && !done); + allInteractionsFetched = !response.hasMorePages(); + } while (from < lastN && !allInteractionsFetched); return interactions; }