Skip to content

Commit

Permalink
Added spotless, removed unused code, added more comments.
Browse files Browse the repository at this point in the history
Signed-off-by: Austin Lee <[email protected]>
  • Loading branch information
austintlee committed Oct 2, 2023
1 parent e499ec5 commit 857f436
Show file tree
Hide file tree
Showing 30 changed files with 685 additions and 399 deletions.
10 changes: 10 additions & 0 deletions search-processors/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ plugins {
id 'java'
id 'jacoco'
id "io.freefair.lombok"
id 'com.diffplug.spotless' version '6.18.0'
}

repositories {
Expand Down Expand Up @@ -73,3 +74,12 @@ jacocoTestCoverageVerification {
}

check.dependsOn jacocoTestCoverageVerification

spotless {
java {
removeUnusedImports()
importOrder 'java', 'javax', 'org', 'com'

eclipse().configFile rootProject.file('.eclipseformat.xml')
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@
*/
package org.opensearch.searchpipelines.questionanswering.generative;

import java.util.Map;
import java.util.function.BooleanSupplier;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.search.pipeline.AbstractProcessor;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchRequestProcessor;

import java.util.Map;
import java.util.function.BooleanSupplier;

/**
* Defines the request processor for generative QA search pipelines.
*/
Expand All @@ -35,7 +35,13 @@ public class GenerativeQARequestProcessor extends AbstractProcessor implements S
private String modelId;
private final BooleanSupplier featureFlagSupplier;

protected GenerativeQARequestProcessor(String tag, String description, boolean ignoreFailure, String modelId, BooleanSupplier supplier) {
protected GenerativeQARequestProcessor(
String tag,
String description,
boolean ignoreFailure,
String modelId,
BooleanSupplier supplier
) {
super(tag, description, ignoreFailure);
this.modelId = modelId;
this.featureFlagSupplier = supplier;
Expand Down Expand Up @@ -76,12 +82,17 @@ public SearchRequestProcessor create(
PipelineContext pipelineContext
) throws Exception {
if (featureFlagSupplier.getAsBoolean()) {
return new GenerativeQARequestProcessor(tag, description, ignoreFailure,
ConfigurationUtils.readStringProperty(GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE,
tag,
config,
GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID
),
return new GenerativeQARequestProcessor(
tag,
description,
ignoreFailure,
ConfigurationUtils
.readStringProperty(
GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE,
tag,
config,
GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID
),
this.featureFlagSupplier
);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,16 @@
*/
package org.opensearch.searchpipelines.questionanswering.generative;

import com.google.gson.JsonArray;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;
import static org.opensearch.ingest.ConfigurationUtils.newConfigurationException;

import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.BooleanSupplier;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
Expand All @@ -31,25 +37,20 @@
import org.opensearch.search.pipeline.AbstractProcessor;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchResponseProcessor;
import org.opensearch.searchpipelines.questionanswering.generative.client.ConversationalMemoryClient;
import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamUtil;
import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParameters;
import org.opensearch.searchpipelines.questionanswering.generative.client.ConversationalMemoryClient;
import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionOutput;
import org.opensearch.searchpipelines.questionanswering.generative.llm.Llm;
import org.opensearch.searchpipelines.questionanswering.generative.llm.LlmIOUtil;
import org.opensearch.searchpipelines.questionanswering.generative.llm.ModelLocator;
import org.opensearch.searchpipelines.questionanswering.generative.prompt.PromptUtil;

import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.BooleanSupplier;
import com.google.gson.JsonArray;

import static org.opensearch.ingest.ConfigurationUtils.newConfigurationException;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;

/**
* Defines the response processor for generative QA search pipelines.
Expand All @@ -60,9 +61,9 @@ public class GenerativeQAResponseProcessor extends AbstractProcessor implements

private static final int DEFAULT_CHAT_HISTORY_WINDOW = 10;

private static final int MAX_PROCESSOR_TIME_IN_SECONDS = 60;
private static final int DEFAULT_PROCESSOR_TIME_IN_SECONDS = 30;

// TODO Add "interaction_count". This is how far back in chat history we want to go back when calling LLM.
// TODO Add "interaction_count". This is how far back in chat history we want to go back when calling LLM.

private final String llmModel;
private final List<String> contextFields;
Expand All @@ -80,8 +81,18 @@ public class GenerativeQAResponseProcessor extends AbstractProcessor implements

private final BooleanSupplier featureFlagSupplier;

protected GenerativeQAResponseProcessor(Client client, String tag, String description, boolean ignoreFailure,
Llm llm, String llmModel, List<String> contextFields, String systemPrompt, String userInstructions, BooleanSupplier supplier) {
protected GenerativeQAResponseProcessor(
Client client,
String tag,
String description,
boolean ignoreFailure,
Llm llm,
String llmModel,
List<String> contextFields,
String systemPrompt,
String userInstructions,
BooleanSupplier supplier
) {
super(tag, description, ignoreFailure);
this.llmModel = llmModel;
this.contextFields = contextFields;
Expand All @@ -105,7 +116,7 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp

Integer timeout = params.getTimeout();
if (timeout == null || timeout == GenerativeQAParameters.SIZE_NULL_VALUE) {
timeout = MAX_PROCESSOR_TIME_IN_SECONDS;
timeout = DEFAULT_PROCESSOR_TIME_IN_SECONDS;
}
log.info("Timeout for this request: {} seconds.", timeout);

Expand All @@ -122,7 +133,9 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
interactionSize = DEFAULT_CHAT_HISTORY_WINDOW;
}
log.info("Using interaction size of {}", interactionSize);
List<Interaction> chatHistory = (conversationId == null) ? Collections.emptyList() : memoryClient.getInteractions(conversationId, interactionSize);
List<Interaction> chatHistory = (conversationId == null)
? Collections.emptyList()
: memoryClient.getInteractions(conversationId, interactionSize);
log.info("Retrieved chat history. ({})", getDuration(start));

Integer topN = params.getContextSize();
Expand All @@ -134,8 +147,11 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
log.info("system_prompt: {}", systemPrompt);
log.info("user_instructions: {}", userInstructions);
start = Instant.now();
ChatCompletionOutput output = llm.doChatCompletion(LlmIOUtil.createChatCompletionInput(systemPrompt, userInstructions, llmModel,
llmQuestion, chatHistory, searchResults, timeout));
ChatCompletionOutput output = llm
.doChatCompletion(
LlmIOUtil
.createChatCompletionInput(systemPrompt, userInstructions, llmModel, llmQuestion, chatHistory, searchResults, timeout)
);
log.info("doChatCompletion complete. ({})", getDuration(start));

String answer = null;
Expand All @@ -148,13 +164,15 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp

if (conversationId != null) {
start = Instant.now();
interactionId = memoryClient.createInteraction(conversationId,
llmQuestion,
PromptUtil.getPromptTemplate(systemPrompt, userInstructions),
answer,
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
jsonArrayToString(searchResults)
);
interactionId = memoryClient
.createInteraction(
conversationId,
llmQuestion,
PromptUtil.getPromptTemplate(systemPrompt, userInstructions),
answer,
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
jsonArrayToString(searchResults)
);
log.info("Created a new interaction: {} ({})", interactionId, getDuration(start));
}
}
Expand All @@ -175,8 +193,19 @@ private SearchResponse insertAnswer(SearchResponse response, String answer, Stri

// TODO return the interaction id in the response.

return new GenerativeSearchResponse(answer, errorMessage, response.getInternalResponse(), response.getScrollId(), response.getTotalShards(), response.getSuccessfulShards(),
response.getSkippedShards(), response.getSuccessfulShards(), response.getShardFailures(), response.getClusters(), interactionId);
return new GenerativeSearchResponse(
answer,
errorMessage,
response.getInternalResponse(),
response.getScrollId(),
response.getTotalShards(),
response.getSuccessfulShards(),
response.getSkippedShards(),
response.getSuccessfulShards(),
response.getShardFailures(),
response.getClusters(),
interactionId
);
}

private List<String> getSearchResults(SearchResponse response, Integer topN) {
Expand Down Expand Up @@ -225,41 +254,60 @@ public SearchResponseProcessor create(
PipelineContext pipelineContext
) throws Exception {
if (this.featureFlagSupplier.getAsBoolean()) {
String modelId = ConfigurationUtils.readOptionalStringProperty(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
tag,
config,
GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID
);
String llmModel = ConfigurationUtils.readOptionalStringProperty(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
tag,
config,
GenerativeQAProcessorConstants.CONFIG_NAME_LLM_MODEL
);
List<String> contextFields = ConfigurationUtils.readList(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
tag,
config,
GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST
);
String modelId = ConfigurationUtils
.readOptionalStringProperty(
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
tag,
config,
GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID
);
String llmModel = ConfigurationUtils
.readOptionalStringProperty(
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
tag,
config,
GenerativeQAProcessorConstants.CONFIG_NAME_LLM_MODEL
);
List<String> contextFields = ConfigurationUtils
.readList(
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
tag,
config,
GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST
);
if (contextFields.isEmpty()) {
throw newConfigurationException(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
throw newConfigurationException(
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
tag,
GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST,
"required property can't be empty."
);
}
String systemPrompt = ConfigurationUtils.readOptionalStringProperty(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
tag,
config,
GenerativeQAProcessorConstants.CONFIG_NAME_SYSTEM_PROMPT
);
String userInstructions = ConfigurationUtils.readOptionalStringProperty(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
tag,
config,
GenerativeQAProcessorConstants.CONFIG_NAME_USER_INSTRUCTIONS
);
log.info("model_id {}, llm_model {}, context_field_list {}, system_prompt {}, user_instructions {}",
modelId, llmModel, contextFields, systemPrompt, userInstructions);
return new GenerativeQAResponseProcessor(client,
String systemPrompt = ConfigurationUtils
.readOptionalStringProperty(
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
tag,
config,
GenerativeQAProcessorConstants.CONFIG_NAME_SYSTEM_PROMPT
);
String userInstructions = ConfigurationUtils
.readOptionalStringProperty(
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
tag,
config,
GenerativeQAProcessorConstants.CONFIG_NAME_USER_INSTRUCTIONS
);
log
.info(
"model_id {}, llm_model {}, context_field_list {}, system_prompt {}, user_instructions {}",
modelId,
llmModel,
contextFields,
systemPrompt,
userInstructions
);
return new GenerativeQAResponseProcessor(
client,
tag,
description,
ignoreFailure,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
*/
package org.opensearch.searchpipelines.questionanswering.generative;

import java.io.IOException;
import java.util.Objects;

import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.SearchResponseSections;
import org.opensearch.action.search.ShardSearchFailure;
import org.opensearch.core.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Objects;

/**
* This is an extension of SearchResponse that adds LLM-generated answers to search responses in a dedicated "ext" section.
*
Expand Down
Loading

0 comments on commit 857f436

Please sign in to comment.