diff --git a/plugin/src/test/java/org/opensearch/ml/plugin/MachineLearningPluginTests.java b/plugin/src/test/java/org/opensearch/ml/plugin/MachineLearningPluginTests.java new file mode 100644 index 0000000000..b810f7c439 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/plugin/MachineLearningPluginTests.java @@ -0,0 +1,105 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.plugin; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; + +import java.util.List; +import java.util.Map; + +import org.junit.Test; +import org.opensearch.common.settings.Settings; +import org.opensearch.plugins.SearchPipelinePlugin; +import org.opensearch.plugins.SearchPlugin; +import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants; +import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQARequestProcessor; +import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAResponseProcessor; +import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamExtBuilder; + +public class MachineLearningPluginTests { + + @Test + public void testGetSearchExtsFeatureDisabled() { + Settings settings = Settings.builder().build(); + MachineLearningPlugin plugin = new MachineLearningPlugin(settings); + List> searchExts = plugin.getSearchExts(); + assertEquals(0, searchExts.size()); + } + + @Test + public void testGetSearchExtsFeatureDisabledExplicit() { + Settings settings = Settings.builder().put("plugins.ml_commons.rag_pipeline_feature_enabled", "false").build(); + MachineLearningPlugin plugin = new MachineLearningPlugin(settings); + List> searchExts = plugin.getSearchExts(); + assertEquals(0, searchExts.size()); + } + + @Test + public void testGetRequestProcessorsFeatureDisabled() { + Settings settings = Settings.builder().build(); + MachineLearningPlugin plugin = new MachineLearningPlugin(settings); + SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class); + Map requestProcessors = plugin.getRequestProcessors(parameters); + assertEquals(0, requestProcessors.size()); + } + + @Test + public void testGetResponseProcessorsFeatureDisabled() { + Settings settings = Settings.builder().build(); + MachineLearningPlugin plugin = new MachineLearningPlugin(settings); + SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class); + Map responseProcessors = plugin.getResponseProcessors(parameters); + assertEquals(0, responseProcessors.size()); + } + + @Test + public void testGetSearchExts() { + Settings settings = Settings.builder().put("plugins.ml_commons.rag_pipeline_feature_enabled", "true").build(); + MachineLearningPlugin plugin = new MachineLearningPlugin(settings); + List> searchExts = plugin.getSearchExts(); + assertEquals(1, searchExts.size()); + SearchPlugin.SearchExtSpec spec = searchExts.get(0); + assertEquals(GenerativeQAParamExtBuilder.PARAMETER_NAME, spec.getName().getPreferredName()); + } + + @Test + public void testGetRequestProcessors() { + Settings settings = Settings.builder().put("plugins.ml_commons.rag_pipeline_feature_enabled", "true").build(); + MachineLearningPlugin plugin = new MachineLearningPlugin(settings); + SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class); + Map requestProcessors = plugin.getRequestProcessors(parameters); + assertEquals(1, requestProcessors.size()); + assertTrue( + requestProcessors.get(GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE) instanceof GenerativeQARequestProcessor.Factory + ); + } + + @Test + public void testGetResponseProcessors() { + Settings settings = Settings.builder().put("plugins.ml_commons.rag_pipeline_feature_enabled", "true").build(); + MachineLearningPlugin plugin = new MachineLearningPlugin(settings); + SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class); + Map responseProcessors = plugin.getResponseProcessors(parameters); + assertEquals(1, responseProcessors.size()); + assertTrue( + responseProcessors.get(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE) instanceof GenerativeQAResponseProcessor.Factory + ); + } +} diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java index 45b72a41b8..10e5a924c6 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java @@ -59,7 +59,7 @@ public static String getChatCompletionPrompt(String question, List return buildMessageParameter(question, chatHistory, contexts); } - enum Role { + enum ChatRole { USER("user"), ASSISTANT("assistant"), SYSTEM("system"); @@ -69,7 +69,7 @@ enum Role { @Getter private String name; - Role(String name) { + ChatRole(String name) { this.name = name; } } @@ -80,15 +80,15 @@ static String buildMessageParameter(String question, List chatHisto // TODO better prompt template management is needed here. JsonArray messageArray = new JsonArray(); - messageArray.add(new Message(Role.USER, DEFAULT_CHAT_COMPLETION_PROMPT_TEMPLATE).toJson()); + messageArray.add(new Message(ChatRole.USER, DEFAULT_CHAT_COMPLETION_PROMPT_TEMPLATE).toJson()); for (String result : contexts) { - messageArray.add(new Message(Role.USER, "SEARCH RESULT: " + result).toJson()); + messageArray.add(new Message(ChatRole.USER, "SEARCH RESULT: " + result).toJson()); } if (!chatHistory.isEmpty()) { Messages.fromInteractions(chatHistory).getMessages().forEach(m -> messageArray.add(m.toJson())); } - messageArray.add(new Message(Role.USER, "QUESTION: " + question).toJson()); - messageArray.add(new Message(Role.USER, "ANSWER:").toJson()); + messageArray.add(new Message(ChatRole.USER, "QUESTION: " + question).toJson()); + messageArray.add(new Message(ChatRole.USER, "ANSWER:").toJson()); return messageArray.toString(); } @@ -114,8 +114,8 @@ public static Messages fromInteractions(final List interactions) { List messages = new ArrayList<>(); for (Interaction interaction : interactions) { - messages.add(new Message(Role.USER, interaction.getInput())); - messages.add(new Message(Role.ASSISTANT, interaction.getResponse())); + messages.add(new Message(ChatRole.USER, interaction.getInput())); + messages.add(new Message(ChatRole.ASSISTANT, interaction.getResponse())); } return new Messages(messages); @@ -128,7 +128,7 @@ static class Message { private final static String MESSAGE_FIELD_CONTENT = "content"; @Getter - private Role role; + private ChatRole chatRole; @Getter private String content; @@ -138,15 +138,15 @@ public Message() { json = new JsonObject(); } - public Message(Role role, String content) { + public Message(ChatRole chatRole, String content) { this(); - setRole(role); + setChatRole(chatRole); setContent(content); } - public void setRole(Role role) { + public void setChatRole(ChatRole chatRole) { json.remove(MESSAGE_FIELD_ROLE); - json.add(MESSAGE_FIELD_ROLE, new JsonPrimitive(role.getName())); + json.add(MESSAGE_FIELD_ROLE, new JsonPrimitive(chatRole.getName())); } public void setContent(String content) { this.content = StringEscapeUtils.escapeJson(content);