Skip to content

Commit

Permalink
Add unit tests for MachineLearningPlugin
Browse files Browse the repository at this point in the history
Signed-off-by: Austin Lee <[email protected]>
  • Loading branch information
austintlee committed Sep 5, 2023
1 parent 8cbac4c commit 739d96a
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 13 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.ml.plugin;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;

import java.util.List;
import java.util.Map;

import org.junit.Test;
import org.opensearch.common.settings.Settings;
import org.opensearch.plugins.SearchPipelinePlugin;
import org.opensearch.plugins.SearchPlugin;
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants;
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQARequestProcessor;
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAResponseProcessor;
import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamExtBuilder;

public class MachineLearningPluginTests {

@Test
public void testGetSearchExtsFeatureDisabled() {
Settings settings = Settings.builder().build();
MachineLearningPlugin plugin = new MachineLearningPlugin(settings);
List<SearchPlugin.SearchExtSpec<?>> searchExts = plugin.getSearchExts();
assertEquals(0, searchExts.size());
}

@Test
public void testGetSearchExtsFeatureDisabledExplicit() {
Settings settings = Settings.builder().put("plugins.ml_commons.rag_pipeline_feature_enabled", "false").build();
MachineLearningPlugin plugin = new MachineLearningPlugin(settings);
List<SearchPlugin.SearchExtSpec<?>> searchExts = plugin.getSearchExts();
assertEquals(0, searchExts.size());
}

@Test
public void testGetRequestProcessorsFeatureDisabled() {
Settings settings = Settings.builder().build();
MachineLearningPlugin plugin = new MachineLearningPlugin(settings);
SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class);
Map<String, ?> requestProcessors = plugin.getRequestProcessors(parameters);
assertEquals(0, requestProcessors.size());
}

@Test
public void testGetResponseProcessorsFeatureDisabled() {
Settings settings = Settings.builder().build();
MachineLearningPlugin plugin = new MachineLearningPlugin(settings);
SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class);
Map<String, ?> responseProcessors = plugin.getResponseProcessors(parameters);
assertEquals(0, responseProcessors.size());
}

@Test
public void testGetSearchExts() {
Settings settings = Settings.builder().put("plugins.ml_commons.rag_pipeline_feature_enabled", "true").build();
MachineLearningPlugin plugin = new MachineLearningPlugin(settings);
List<SearchPlugin.SearchExtSpec<?>> searchExts = plugin.getSearchExts();
assertEquals(1, searchExts.size());
SearchPlugin.SearchExtSpec<?> spec = searchExts.get(0);
assertEquals(GenerativeQAParamExtBuilder.PARAMETER_NAME, spec.getName().getPreferredName());
}

@Test
public void testGetRequestProcessors() {
Settings settings = Settings.builder().put("plugins.ml_commons.rag_pipeline_feature_enabled", "true").build();
MachineLearningPlugin plugin = new MachineLearningPlugin(settings);
SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class);
Map<String, ?> requestProcessors = plugin.getRequestProcessors(parameters);
assertEquals(1, requestProcessors.size());
assertTrue(
requestProcessors.get(GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE) instanceof GenerativeQARequestProcessor.Factory
);
}

@Test
public void testGetResponseProcessors() {
Settings settings = Settings.builder().put("plugins.ml_commons.rag_pipeline_feature_enabled", "true").build();
MachineLearningPlugin plugin = new MachineLearningPlugin(settings);
SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class);
Map<String, ?> responseProcessors = plugin.getResponseProcessors(parameters);
assertEquals(1, responseProcessors.size());
assertTrue(
responseProcessors.get(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE) instanceof GenerativeQAResponseProcessor.Factory
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public static String getChatCompletionPrompt(String question, List<Interaction>
return buildMessageParameter(question, chatHistory, contexts);
}

enum Role {
enum ChatRole {
USER("user"),
ASSISTANT("assistant"),
SYSTEM("system");
Expand All @@ -69,7 +69,7 @@ enum Role {
@Getter
private String name;

Role(String name) {
ChatRole(String name) {
this.name = name;
}
}
Expand All @@ -80,15 +80,15 @@ static String buildMessageParameter(String question, List<Interaction> chatHisto
// TODO better prompt template management is needed here.

JsonArray messageArray = new JsonArray();
messageArray.add(new Message(Role.USER, DEFAULT_CHAT_COMPLETION_PROMPT_TEMPLATE).toJson());
messageArray.add(new Message(ChatRole.USER, DEFAULT_CHAT_COMPLETION_PROMPT_TEMPLATE).toJson());
for (String result : contexts) {
messageArray.add(new Message(Role.USER, "SEARCH RESULT: " + result).toJson());
messageArray.add(new Message(ChatRole.USER, "SEARCH RESULT: " + result).toJson());
}
if (!chatHistory.isEmpty()) {
Messages.fromInteractions(chatHistory).getMessages().forEach(m -> messageArray.add(m.toJson()));
}
messageArray.add(new Message(Role.USER, "QUESTION: " + question).toJson());
messageArray.add(new Message(Role.USER, "ANSWER:").toJson());
messageArray.add(new Message(ChatRole.USER, "QUESTION: " + question).toJson());
messageArray.add(new Message(ChatRole.USER, "ANSWER:").toJson());

return messageArray.toString();
}
Expand All @@ -114,8 +114,8 @@ public static Messages fromInteractions(final List<Interaction> interactions) {
List<Message> messages = new ArrayList<>();

for (Interaction interaction : interactions) {
messages.add(new Message(Role.USER, interaction.getInput()));
messages.add(new Message(Role.ASSISTANT, interaction.getResponse()));
messages.add(new Message(ChatRole.USER, interaction.getInput()));
messages.add(new Message(ChatRole.ASSISTANT, interaction.getResponse()));
}

return new Messages(messages);
Expand All @@ -128,7 +128,7 @@ static class Message {
private final static String MESSAGE_FIELD_CONTENT = "content";

@Getter
private Role role;
private ChatRole chatRole;
@Getter
private String content;

Expand All @@ -138,15 +138,15 @@ public Message() {
json = new JsonObject();
}

public Message(Role role, String content) {
public Message(ChatRole chatRole, String content) {
this();
setRole(role);
setChatRole(chatRole);
setContent(content);
}

public void setRole(Role role) {
public void setChatRole(ChatRole chatRole) {
json.remove(MESSAGE_FIELD_ROLE);
json.add(MESSAGE_FIELD_ROLE, new JsonPrimitive(role.getName()));
json.add(MESSAGE_FIELD_ROLE, new JsonPrimitive(chatRole.getName()));
}
public void setContent(String content) {
this.content = StringEscapeUtils.escapeJson(content);
Expand Down

0 comments on commit 739d96a

Please sign in to comment.