Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support regenerate for execute API #1709

Closed
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
import org.opensearch.ml.engine.memory.ConversationIndexMessage;
import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse;
import org.opensearch.ml.memory.action.conversation.GetInteractionAction;
import org.opensearch.ml.memory.action.conversation.GetInteractionRequest;

import com.google.gson.Gson;

Expand All @@ -62,6 +64,7 @@ public class MLAgentExecutor implements Executable {
public static final String MEMORY_ID = "memory_id";
public static final String QUESTION = "question";
public static final String PARENT_INTERACTION_ID = "parent_interaction_id";
public static final String REGENERATE_INTERACTION_ID = "regenerate_interaction_id";

private Client client;
private Settings settings;
Expand Down Expand Up @@ -113,9 +116,14 @@ public void execute(Input input, ActionListener<Output> listener) {
MLMemorySpec memorySpec = mlAgent.getMemory();
String memoryId = inputDataSet.getParameters().get(MEMORY_ID);
String parentInteractionId = inputDataSet.getParameters().get(PARENT_INTERACTION_ID);
String regenerateInteractionId = inputDataSet.getParameters().get(REGENERATE_INTERACTION_ID);
String appType = mlAgent.getAppType();
String question = inputDataSet.getParameters().get(QUESTION);

if (memoryId == null && regenerateInteractionId != null) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if both are null? Do we still want to continue?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if regenerateInteractionId is null, it's a normal flow, memory will automatically create if needed

throw new IllegalArgumentException("A memory ID must be provided to regenerate.");
}

if (memorySpec != null
&& memorySpec.getType() != null
&& memoryFactoryMap.containsKey(memorySpec.getType())
Expand All @@ -124,30 +132,29 @@ public void execute(Input input, ActionListener<Output> listener) {
(ConversationIndexMemory.Factory) memoryFactoryMap.get(memorySpec.getType());
conversationIndexMemoryFactory.create(question, memoryId, appType, ActionListener.wrap(memory -> {
inputDataSet.getParameters().put(MEMORY_ID, memory.getConversationId());
ConversationIndexMemory conversationIndexMemory = (ConversationIndexMemory) memory;
// Create root interaction ID
ConversationIndexMessage msg = ConversationIndexMessage
.conversationIndexMessageBuilder()
.type(appType)
.question(question)
.response("")
.finalAnswer(true)
.sessionId(memory.getConversationId())
.build();
conversationIndexMemory
.save(msg, null, null, null, ActionListener.<CreateInteractionResponse>wrap(interaction -> {
log.info("Created parent interaction ID: " + interaction.getId());
inputDataSet.getParameters().put(PARENT_INTERACTION_ID, interaction.getId());
ActionListener<Object> agentActionListener = createAgentActionListener(
listener,
outputs,
modelTensors

ActionListener<Object> agentActionListener = createAgentActionListener(listener, outputs, modelTensors);

// get question for regenerate
if (regenerateInteractionId != null) {
log.info("Regenerate for existing interaction {}", regenerateInteractionId);
client
.execute(
GetInteractionAction.INSTANCE,
new GetInteractionRequest(memoryId, regenerateInteractionId),
ActionListener.wrap(interactionRes -> {
inputDataSet
.getParameters()
.computeIfAbsent(QUESTION, (key) -> interactionRes.getInteraction().getInput());
saveRootInteractionAndExecute(agentActionListener, memory, inputDataSet, mlAgent);
}, e -> {
log.error("Failed to get existing interaction for regeneration", e);
listener.onFailure(e);
})
);
executeAgent(inputDataSet, mlAgent, agentActionListener);
}, ex -> {
log.error("Failed to create parent interaction", ex);
listener.onFailure(ex);
}));
} else {
saveRootInteractionAndExecute(agentActionListener, memory, inputDataSet, mlAgent);
}
}, ex -> {
log.error("Failed to read conversation memory", ex);
listener.onFailure(ex);
Expand All @@ -156,6 +163,9 @@ public void execute(Input input, ActionListener<Output> listener) {
ActionListener<Object> agentActionListener = createAgentActionListener(listener, outputs, modelTensors);
executeAgent(inputDataSet, mlAgent, agentActionListener);
}
} catch (Exception ex) {
log.error("Failed to execute agent", ex);
listener.onFailure(ex);
}
} else {
listener.onFailure(new ResourceNotFoundException("Agent not found"));
Expand All @@ -169,6 +179,40 @@ public void execute(Input input, ActionListener<Output> listener) {

}

/**
* save root interaction and start execute the agent
* @param listener callback listener
* @param memory memory instance
* @param inputDataSet input
* @param mlAgent agent to run
*/
private void saveRootInteractionAndExecute(
ActionListener<Object> listener,
ConversationIndexMemory memory,
RemoteInferenceInputDataSet inputDataSet,
MLAgent mlAgent
) {
String appType = mlAgent.getAppType();
String question = inputDataSet.getParameters().get(QUESTION);
// Create root interaction ID
ConversationIndexMessage msg = ConversationIndexMessage
.conversationIndexMessageBuilder()
.type(appType)
.question(question)
.response("")
.finalAnswer(true)
.sessionId(memory.getConversationId())
.build();
memory.save(msg, null, null, null, ActionListener.<CreateInteractionResponse>wrap(interaction -> {
log.info("Created parent interaction ID: " + interaction.getId());
inputDataSet.getParameters().put(PARENT_INTERACTION_ID, interaction.getId());
executeAgent(inputDataSet, mlAgent, listener);
}, ex -> {
log.error("Failed to create parent interaction", ex);
listener.onFailure(ex);
}));
}

private void executeAgent(RemoteInferenceInputDataSet inputDataSet, MLAgent mlAgent, ActionListener<Object> agentActionListener) {
if ("flow".equals(mlAgent.getType())) {
MLFlowAgentRunner flowAgentExecutor = new MLFlowAgentRunner(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import static org.opensearch.ml.common.conversation.ActionConstants.ADDITIONAL_INFO_FIELD;
import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.QUESTION;
import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.REGENERATE_INTERACTION_ID;

import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
Expand Down Expand Up @@ -109,8 +111,9 @@ public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Obje
List<MLToolSpec> toolSpecs = mlAgent.getTools();
String memoryType = mlAgent.getMemory().getType();
String memoryId = params.get(MLAgentExecutor.MEMORY_ID);
String regenerateInteractionId = params.get(REGENERATE_INTERACTION_ID);
String appType = mlAgent.getAppType();
String title = params.get(MLAgentExecutor.QUESTION);
String title = params.get(QUESTION);

ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) memoryFactoryMap.get(memoryType);
conversationIndexMemoryFactory.create(title, memoryId, appType, ActionListener.<ConversationIndexMemory>wrap(memory -> {
Expand All @@ -119,6 +122,12 @@ public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Obje
Iterator<Interaction> iterator = r.iterator();
while (iterator.hasNext()) {
Interaction next = iterator.next();
// ignore regenerate interaction question/answer as chat history context to send to LLM
if (next.getId().equals(regenerateInteractionId)) {
// there may have other new interactions happened after this interaction,
// include them as well for chat history context
continue;
}
String question = next.getInput();
String response = next.getResponse();
// As we store the conversation with empty response first and then update when have final answer,
Expand Down Expand Up @@ -146,7 +155,14 @@ public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Obje
params.put(CHAT_HISTORY, chatHistoryBuilder.toString());
}

runAgent(mlAgent, params, listener, toolSpecs, memory, memory.getConversationId());
ActionListener<Object> finalListener = regenerateInteractionId != null ? ActionListener.runBefore(listener, () -> {
memory.getMemoryManager().deleteInteraction(regenerateInteractionId, ActionListener.wrap(deleted -> {}, e -> {
log.error("Failed to regenerate for interaction {}", regenerateInteractionId, e);
listener.onFailure(e);
}));
}) : listener;

runAgent(mlAgent, params, finalListener, toolSpecs, memory, memory.getConversationId());
}, e -> {
log.error("Failed to get chat history", e);
listener.onFailure(e);
Expand Down Expand Up @@ -203,7 +219,7 @@ private void runReAct(
String sessionId,
ActionListener<Object> listener
) {
String question = parameters.get(MLAgentExecutor.QUESTION);
String question = parameters.get(QUESTION);
String parentInteractionId = parameters.get(MLAgentExecutor.PARENT_INTERACTION_ID);
boolean verbose = parameters.containsKey("verbose") ? Boolean.parseBoolean(parameters.get("verbose")) : false;
Map<String, String> tmpParameters = new HashMap<>();
Expand Down Expand Up @@ -506,7 +522,7 @@ private void runReAct(
llmToolTmpParameters.putAll(tmpParameters);
llmToolTmpParameters.putAll(toolSpecMap.get(action).getParameters());
// TODO: support tool parameter override : langauge_model_tool.prompt
llmToolTmpParameters.put(MLAgentExecutor.QUESTION, actionInput);
llmToolTmpParameters.put(QUESTION, actionInput);
tools.get(action).run(llmToolTmpParameters, nextStepListener); // run tool
} else {
tools.get(action).run(toolParams, nextStepListener); // run tool
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,12 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.ExistsQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.index.reindex.BulkByScrollResponse;
import org.opensearch.index.reindex.DeleteByQueryAction;
import org.opensearch.index.reindex.DeleteByQueryRequest;
import org.opensearch.ml.common.conversation.ActionConstants;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.common.conversation.Interaction;
Expand Down Expand Up @@ -219,22 +223,10 @@ public void getTraces(String parentInteractionId, ActionListener<List<Interactio
@VisibleForTesting
void innerGetTraces(String parentInteractionId, ActionListener<List<Interaction>> listener) {
SearchRequest searchRequest = Requests.searchRequest(indexName);

// Build the query
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery();

// Add the ExistsQueryBuilder for checking null values
ExistsQueryBuilder existsQueryBuilder = QueryBuilders.existsQuery(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD);
boolQueryBuilder.must(existsQueryBuilder);

// Add the TermQueryBuilder for another field
TermQueryBuilder termQueryBuilder = QueryBuilders
.termQuery(ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, parentInteractionId);
boolQueryBuilder.must(termQueryBuilder);

QueryBuilder traceQueryBuilder = buildTraceQueryBuilder(parentInteractionId);
// Set the query to the search source
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.query(boolQueryBuilder);
searchSourceBuilder.query(traceQueryBuilder);
searchRequest.source(searchSourceBuilder);

searchRequest.source().sort(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD, SortOrder.ASC);
Expand All @@ -258,6 +250,21 @@ void innerGetTraces(String parentInteractionId, ActionListener<List<Interaction>
}
}

private QueryBuilder buildTraceQueryBuilder(String parentInteractionId) {
// Build the query
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery();

// Add the ExistsQueryBuilder for checking null values
ExistsQueryBuilder existsQueryBuilder = QueryBuilders.existsQuery(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD);
boolQueryBuilder.must(existsQueryBuilder);

// Add the TermQueryBuilder for another field
TermQueryBuilder termQueryBuilder = QueryBuilders
.termQuery(ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, parentInteractionId);
boolQueryBuilder.must(termQueryBuilder);
return boolQueryBuilder;
}

/**
* Get the interactions associate with this conversation, sorted by recency
* @param interactionId the parent interaction id whose traces to get
Expand Down Expand Up @@ -288,4 +295,47 @@ public void updateInteraction(String interactionId, Map<String, Object> updateCo
actionListener.onFailure(e);
}
}

/**
* Delete interaction with its trace data
* @param interactionId interaction id
* @param listener callback for delete result
*/
public void deleteInteraction(String interactionId, ActionListener<Boolean> listener) {
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery();
// interaction itself
boolQueryBuilder.should(QueryBuilders.idsQuery().addIds(interactionId));
// interaction trace
boolQueryBuilder.should(buildTraceQueryBuilder(interactionId));

DeleteByQueryRequest deleteByQueryRequest = new DeleteByQueryRequest(indexName);
deleteByQueryRequest.setQuery(boolQueryBuilder);
deleteByQueryRequest.setRefresh(true);

innerDeleteInteraction(deleteByQueryRequest, interactionId, listener);
}

@VisibleForTesting
void innerDeleteInteraction(DeleteByQueryRequest deleteByQueryRequest, String interactionId, ActionListener<Boolean> listener) {
try (ThreadContext.StoredContext ignored = client.threadPool().getThreadContext().stashContext()) {
ActionListener<BulkByScrollResponse> al = ActionListener.wrap(bulkResponse -> {
if (bulkResponse != null && (!bulkResponse.getBulkFailures().isEmpty() || !bulkResponse.getSearchFailures().isEmpty())) {
log.info("Failed to delete the interaction with ID: {}", interactionId);
listener.onResponse(false);
return;
}
log.info("Successfully delete the interaction with ID: {}", interactionId);
listener.onResponse(true);
}, exception -> {
log.error("Failed to delete interaction with ID {}. Details: {}", interactionId, exception);
listener.onFailure(exception);
});
// bulk delete interaction and its trace
client.execute(DeleteByQueryAction.INSTANCE, deleteByQueryRequest, al);
} catch (Exception e) {
log.error("Failed to delete interaction with ID {}. Details {}:", interactionId, e);
listener.onFailure(e);
}
}

}
Loading