Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support regenerate for execute API #1709

Closed
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.ml.engine.algorithms.agent;

import static org.opensearch.core.action.ActionListener.wrap;
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX;

Expand All @@ -14,10 +15,12 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.get.GetResponse;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
Expand All @@ -33,6 +36,8 @@
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.agent.MLMemorySpec;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.exception.MLValidationException;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
import org.opensearch.ml.common.output.Output;
Expand All @@ -46,6 +51,9 @@
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
import org.opensearch.ml.engine.memory.ConversationIndexMessage;
import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse;
import org.opensearch.ml.memory.action.conversation.GetInteractionAction;
import org.opensearch.ml.memory.action.conversation.GetInteractionRequest;
import org.opensearch.ml.memory.action.conversation.GetInteractionResponse;

import com.google.gson.Gson;

Expand All @@ -62,6 +70,7 @@ public class MLAgentExecutor implements Executable {
public static final String MEMORY_ID = "memory_id";
public static final String QUESTION = "question";
public static final String PARENT_INTERACTION_ID = "parent_interaction_id";
public static final String REGENERATE_INTERACTION_ID = "regenerate_interaction_id";

private Client client;
private Settings settings;
Expand Down Expand Up @@ -113,41 +122,51 @@ public void execute(Input input, ActionListener<Output> listener) {
MLMemorySpec memorySpec = mlAgent.getMemory();
String memoryId = inputDataSet.getParameters().get(MEMORY_ID);
String parentInteractionId = inputDataSet.getParameters().get(PARENT_INTERACTION_ID);
String regenerateInteractionId = inputDataSet.getParameters().get(REGENERATE_INTERACTION_ID);
String appType = mlAgent.getAppType();
String question = inputDataSet.getParameters().get(QUESTION);

if (memoryId == null && regenerateInteractionId != null) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if both are null? Do we still want to continue?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if regenerateInteractionId is null, it's a normal flow, memory will automatically create if needed

listener.onFailure(new MLValidationException("Memory id must provide for regenerate"));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A memory ID must be provided to regenerate.?

btw, what do we mean by regenerate here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

regenerate means user want LLM to reprocess the chat question and generate answer, a picture may more clear

image

}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missed a return express here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, will add return here.


if (memorySpec != null
&& memorySpec.getType() != null
&& memoryFactoryMap.containsKey(memorySpec.getType())
&& (memoryId == null || parentInteractionId == null)) {
ConversationIndexMemory.Factory conversationIndexMemoryFactory =
(ConversationIndexMemory.Factory) memoryFactoryMap.get(memorySpec.getType());
conversationIndexMemoryFactory.create(question, memoryId, appType, ActionListener.wrap(memory -> {
conversationIndexMemoryFactory.create(question, memoryId, appType, wrap(memory -> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keep the existing style of using ActionListener.wrap()? You only changed this wrap, but there are still some ActionListener.wrap() and ActionListener.* somewhere else in this file.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is unintentional change, change it back.

inputDataSet.getParameters().put(MEMORY_ID, memory.getConversationId());
ConversationIndexMemory conversationIndexMemory = (ConversationIndexMemory) memory;

// get regenerate interaction question
Optional.ofNullable(regenerateInteractionId).ifPresent(interactionId -> {
log.info("Regenerate for existing interaction {}", regenerateInteractionId);
Hailong-am marked this conversation as resolved.
Show resolved Hide resolved
getQuestionFromInteraction(memoryId, interactionId, inputDataSet);
});

// Create root interaction ID
ConversationIndexMessage msg = ConversationIndexMessage
.conversationIndexMessageBuilder()
.type(appType)
.question(question)
.question(inputDataSet.getParameters().get(QUESTION))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not use question?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question will be override at line 204 for regenerate case

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add comment for this specific regenerate case? It takes a while to figure out why we can't use the question in request directly

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks, comments added.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I was confused too. I think you can also set question to the original question meanwhile updating inputDataSet.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

refactor to listener way and it's more clear now.

.response("")
.finalAnswer(true)
.sessionId(memory.getConversationId())
.build();
conversationIndexMemory
.save(msg, null, null, null, ActionListener.<CreateInteractionResponse>wrap(interaction -> {
log.info("Created parent interaction ID: " + interaction.getId());
inputDataSet.getParameters().put(PARENT_INTERACTION_ID, interaction.getId());
ActionListener<Object> agentActionListener = createAgentActionListener(
listener,
outputs,
modelTensors
);
executeAgent(inputDataSet, mlAgent, agentActionListener);
}, ex -> {
log.error("Failed to create parent interaction", ex);
listener.onFailure(ex);
}));
memory.save(msg, null, null, null, ActionListener.<CreateInteractionResponse>wrap(interaction -> {
log.info("Created parent interaction ID: " + interaction.getId());
inputDataSet.getParameters().put(PARENT_INTERACTION_ID, interaction.getId());
ActionListener<Object> agentActionListener = createAgentActionListener(
listener,
outputs,
modelTensors
);
executeAgent(inputDataSet, mlAgent, agentActionListener);
}, ex -> {
log.error("Failed to create parent interaction", ex);
listener.onFailure(ex);
}));
}, ex -> {
log.error("Failed to read conversation memory", ex);
listener.onFailure(ex);
Expand All @@ -156,6 +175,8 @@ public void execute(Input input, ActionListener<Output> listener) {
ActionListener<Object> agentActionListener = createAgentActionListener(listener, outputs, modelTensors);
executeAgent(inputDataSet, mlAgent, agentActionListener);
}
} catch (Exception ex) {
listener.onFailure(ex);
Hailong-am marked this conversation as resolved.
Show resolved Hide resolved
}
} else {
listener.onFailure(new ResourceNotFoundException("Agent not found"));
Expand All @@ -169,6 +190,24 @@ public void execute(Input input, ActionListener<Output> listener) {

}

/**
* Get question from existing interaction
* @param memoryId conversation id
* @param interactionId interaction id
* @param inputDataSet input parameters
*/
private void getQuestionFromInteraction(String memoryId, String interactionId, RemoteInferenceInputDataSet inputDataSet) {
PlainActionFuture<GetInteractionResponse> future = PlainActionFuture.newFuture();
client.execute(GetInteractionAction.INSTANCE, new GetInteractionRequest(memoryId, interactionId), future);
try {
GetInteractionResponse interactionResponse = future.get();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we should have any blocking invocation in an opensearch thread. Can you change to non-blocking way like listener?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, changed it to non-blocking way.

inputDataSet.getParameters().computeIfAbsent(QUESTION, key -> interactionResponse.getInteraction().getInput());
} catch (Exception ex) {
log.error("Can't get regenerate interaction {}", interactionId, ex);
throw new MLException(ex);
}
}

private void executeAgent(RemoteInferenceInputDataSet inputDataSet, MLAgent mlAgent, ActionListener<Object> agentActionListener) {
if ("flow".equals(mlAgent.getType())) {
MLFlowAgentRunner flowAgentExecutor = new MLFlowAgentRunner(
Expand Down Expand Up @@ -208,7 +247,7 @@ private ActionListener<Object> createAgentActionListener(
List<ModelTensors> outputs,
List<ModelTensor> modelTensors
) {
return ActionListener.wrap(output -> {
return wrap(output -> {
if (output != null) {
Gson gson = new Gson();
if (output instanceof ModelTensorOutput) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import static org.opensearch.ml.common.conversation.ActionConstants.ADDITIONAL_INFO_FIELD;
import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.QUESTION;
import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.REGENERATE_INTERACTION_ID;

import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
Expand Down Expand Up @@ -67,6 +69,7 @@
public class MLChatAgentRunner {

public static final String SESSION_ID = "session_id";
public static final String MEMORY_ID = "memory_id";
public static final String PROMPT_PREFIX = "prompt_prefix";
public static final String LLM_TOOL_PROMPT_PREFIX = "LanguageModelTool.prompt_prefix";
public static final String LLM_TOOL_PROMPT_SUFFIX = "LanguageModelTool.prompt_suffix";
Expand Down Expand Up @@ -109,8 +112,9 @@ public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Obje
List<MLToolSpec> toolSpecs = mlAgent.getTools();
String memoryType = mlAgent.getMemory().getType();
String memoryId = params.get(MLAgentExecutor.MEMORY_ID);
String regenerateInteractionId = params.get(REGENERATE_INTERACTION_ID);
String appType = mlAgent.getAppType();
String title = params.get(MLAgentExecutor.QUESTION);
String title = params.get(QUESTION);

ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) memoryFactoryMap.get(memoryType);
conversationIndexMemoryFactory.create(title, memoryId, appType, ActionListener.<ConversationIndexMemory>wrap(memory -> {
Expand All @@ -119,6 +123,10 @@ public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Obje
Iterator<Interaction> iterator = r.iterator();
while (iterator.hasNext()) {
Interaction next = iterator.next();
if (next.getId().equals(regenerateInteractionId)) {
// there may have other new interactions happened after this interaction
Hailong-am marked this conversation as resolved.
Show resolved Hide resolved
continue;
}
String question = next.getInput();
String response = next.getResponse();
// As we store the conversation with empty response first and then update when have final answer,
Expand Down Expand Up @@ -146,7 +154,14 @@ public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Obje
params.put(CHAT_HISTORY, chatHistoryBuilder.toString());
}

runAgent(mlAgent, params, listener, toolSpecs, memory, memory.getConversationId());
ActionListener<Object> finalListener = regenerateInteractionId != null ? ActionListener.runBefore(listener, () -> {
memory.getMemoryManager().deleteInteraction(regenerateInteractionId, ActionListener.wrap(deleted -> {}, e -> {
log.error("Failed to regenerate for interaction {}", regenerateInteractionId, e);
listener.onFailure(e);
}));
}) : listener;

runAgent(mlAgent, params, finalListener, toolSpecs, memory, memory.getConversationId());
}, e -> {
log.error("Failed to get chat history", e);
listener.onFailure(e);
Expand Down Expand Up @@ -203,7 +218,7 @@ private void runReAct(
String sessionId,
ActionListener<Object> listener
) {
String question = parameters.get(MLAgentExecutor.QUESTION);
String question = parameters.get(QUESTION);
String parentInteractionId = parameters.get(MLAgentExecutor.PARENT_INTERACTION_ID);
boolean verbose = parameters.containsKey("verbose") ? Boolean.parseBoolean(parameters.get("verbose")) : false;
Map<String, String> tmpParameters = new HashMap<>();
Expand Down Expand Up @@ -506,7 +521,7 @@ private void runReAct(
llmToolTmpParameters.putAll(tmpParameters);
llmToolTmpParameters.putAll(toolSpecMap.get(action).getParameters());
// TODO: support tool parameter override : langauge_model_tool.prompt
llmToolTmpParameters.put(MLAgentExecutor.QUESTION, actionInput);
llmToolTmpParameters.put(QUESTION, actionInput);
tools.get(action).run(llmToolTmpParameters, nextStepListener); // run tool
} else {
tools.get(action).run(toolParams, nextStepListener); // run tool
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,12 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.ExistsQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.index.reindex.BulkByScrollResponse;
import org.opensearch.index.reindex.DeleteByQueryAction;
import org.opensearch.index.reindex.DeleteByQueryRequest;
import org.opensearch.ml.common.conversation.ActionConstants;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.common.conversation.Interaction;
Expand Down Expand Up @@ -219,22 +223,10 @@ public void getTraces(String parentInteractionId, ActionListener<List<Interactio
@VisibleForTesting
void innerGetTraces(String parentInteractionId, ActionListener<List<Interaction>> listener) {
SearchRequest searchRequest = Requests.searchRequest(indexName);

// Build the query
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery();

// Add the ExistsQueryBuilder for checking null values
ExistsQueryBuilder existsQueryBuilder = QueryBuilders.existsQuery(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD);
boolQueryBuilder.must(existsQueryBuilder);

// Add the TermQueryBuilder for another field
TermQueryBuilder termQueryBuilder = QueryBuilders
.termQuery(ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, parentInteractionId);
boolQueryBuilder.must(termQueryBuilder);

QueryBuilder traceQueryBuilder = buildTraceQueryBuilder(parentInteractionId);
// Set the query to the search source
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.query(boolQueryBuilder);
searchSourceBuilder.query(traceQueryBuilder);
searchRequest.source(searchSourceBuilder);

searchRequest.source().sort(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD, SortOrder.ASC);
Expand All @@ -258,6 +250,21 @@ void innerGetTraces(String parentInteractionId, ActionListener<List<Interaction>
}
}

private QueryBuilder buildTraceQueryBuilder(String parentInteractionId) {
// Build the query
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery();

// Add the ExistsQueryBuilder for checking null values
ExistsQueryBuilder existsQueryBuilder = QueryBuilders.existsQuery(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD);
boolQueryBuilder.must(existsQueryBuilder);

// Add the TermQueryBuilder for another field
TermQueryBuilder termQueryBuilder = QueryBuilders
.termQuery(ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, parentInteractionId);
boolQueryBuilder.must(termQueryBuilder);
return boolQueryBuilder;
}

/**
* Get the interactions associate with this conversation, sorted by recency
* @param interactionId the parent interaction id whose traces to get
Expand Down Expand Up @@ -288,4 +295,47 @@ public void updateInteraction(String interactionId, Map<String, Object> updateCo
actionListener.onFailure(e);
}
}

/**
* Delete interaction with its trace data
* @param interactionId interaction id
* @param listener callback for delete result
*/
public void deleteInteraction(String interactionId, ActionListener<Boolean> listener) {
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery();
// interaction itself
boolQueryBuilder.should(QueryBuilders.idsQuery().addIds(interactionId));
// interaction trace
boolQueryBuilder.should(buildTraceQueryBuilder(interactionId));

DeleteByQueryRequest deleteByQueryRequest = new DeleteByQueryRequest(indexName);
deleteByQueryRequest.setQuery(boolQueryBuilder);
deleteByQueryRequest.setRefresh(true);

innerDeleteInteraction(deleteByQueryRequest, interactionId, listener);
}

@VisibleForTesting
void innerDeleteInteraction(DeleteByQueryRequest deleteByQueryRequest, String interactionId, ActionListener<Boolean> listener) {
try (ThreadContext.StoredContext ignored = client.threadPool().getThreadContext().stashContext()) {
ActionListener<BulkByScrollResponse> al = ActionListener.wrap(bulkResponse -> {
if (bulkResponse != null && (!bulkResponse.getBulkFailures().isEmpty() || !bulkResponse.getSearchFailures().isEmpty())) {
log.info("Failed to delete the interaction with ID: {}", interactionId);
listener.onResponse(false);
return;
}
log.info("Successfully delete the interaction with ID: {}", interactionId);
listener.onResponse(true);
}, exception -> {
log.error("Failed to delete interaction with ID {}. Details: {}", interactionId, exception);
listener.onFailure(exception);
});
// bulk delete interaction and its trace
client.execute(DeleteByQueryAction.INSTANCE, deleteByQueryRequest, al);
} catch (Exception e) {
log.error("Failed to delete interaction with ID {}. Details {}:", interactionId, e);
listener.onFailure(e);
}
}

}
Loading
Loading