-
Notifications
You must be signed in to change notification settings - Fork 138
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support regenerate for execute API #1709
Changes from 15 commits
4ebed68
e5db137
e33b519
2ac5d61
59d1488
35759a2
cb34b59
bbe79a7
bc35a62
fc3227e
7ce8c8b
d4b2135
cfbee71
27a6f1b
6d3492b
29364aa
b0de5a9
6a26d05
4801775
fd52347
5c19d90
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,7 @@ | |
|
||
package org.opensearch.ml.engine.algorithms.agent; | ||
|
||
import static org.opensearch.core.action.ActionListener.wrap; | ||
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; | ||
import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX; | ||
|
||
|
@@ -14,10 +15,12 @@ | |
import java.util.ArrayList; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Optional; | ||
|
||
import org.opensearch.ResourceNotFoundException; | ||
import org.opensearch.action.get.GetRequest; | ||
import org.opensearch.action.get.GetResponse; | ||
import org.opensearch.action.support.PlainActionFuture; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.cluster.service.ClusterService; | ||
import org.opensearch.common.settings.Settings; | ||
|
@@ -33,6 +36,8 @@ | |
import org.opensearch.ml.common.agent.MLAgent; | ||
import org.opensearch.ml.common.agent.MLMemorySpec; | ||
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; | ||
import org.opensearch.ml.common.exception.MLException; | ||
import org.opensearch.ml.common.exception.MLValidationException; | ||
import org.opensearch.ml.common.input.Input; | ||
import org.opensearch.ml.common.input.execute.agent.AgentMLInput; | ||
import org.opensearch.ml.common.output.Output; | ||
|
@@ -46,6 +51,9 @@ | |
import org.opensearch.ml.engine.memory.ConversationIndexMemory; | ||
import org.opensearch.ml.engine.memory.ConversationIndexMessage; | ||
import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; | ||
import org.opensearch.ml.memory.action.conversation.GetInteractionAction; | ||
import org.opensearch.ml.memory.action.conversation.GetInteractionRequest; | ||
import org.opensearch.ml.memory.action.conversation.GetInteractionResponse; | ||
|
||
import com.google.gson.Gson; | ||
|
||
|
@@ -62,6 +70,7 @@ public class MLAgentExecutor implements Executable { | |
public static final String MEMORY_ID = "memory_id"; | ||
public static final String QUESTION = "question"; | ||
public static final String PARENT_INTERACTION_ID = "parent_interaction_id"; | ||
public static final String REGENERATE_INTERACTION_ID = "regenerate_interaction_id"; | ||
|
||
private Client client; | ||
private Settings settings; | ||
|
@@ -113,41 +122,51 @@ public void execute(Input input, ActionListener<Output> listener) { | |
MLMemorySpec memorySpec = mlAgent.getMemory(); | ||
String memoryId = inputDataSet.getParameters().get(MEMORY_ID); | ||
String parentInteractionId = inputDataSet.getParameters().get(PARENT_INTERACTION_ID); | ||
String regenerateInteractionId = inputDataSet.getParameters().get(REGENERATE_INTERACTION_ID); | ||
String appType = mlAgent.getAppType(); | ||
String question = inputDataSet.getParameters().get(QUESTION); | ||
|
||
if (memoryId == null && regenerateInteractionId != null) { | ||
listener.onFailure(new MLValidationException("Memory id must provide for regenerate")); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
btw, what do we mean by regenerate here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is a return statement missing here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good catch, I will add a return here. |
||
|
||
if (memorySpec != null | ||
&& memorySpec.getType() != null | ||
&& memoryFactoryMap.containsKey(memorySpec.getType()) | ||
&& (memoryId == null || parentInteractionId == null)) { | ||
ConversationIndexMemory.Factory conversationIndexMemoryFactory = | ||
(ConversationIndexMemory.Factory) memoryFactoryMap.get(memorySpec.getType()); | ||
conversationIndexMemoryFactory.create(question, memoryId, appType, ActionListener.wrap(memory -> { | ||
conversationIndexMemoryFactory.create(question, memoryId, appType, wrap(memory -> { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Keep the existing style of using ActionListener.wrap()? You only changed this wrap, but there are still other ActionListener.wrap() and ActionListener.* usages elsewhere in this file. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is an unintentional change; I will change it back. |
||
inputDataSet.getParameters().put(MEMORY_ID, memory.getConversationId()); | ||
ConversationIndexMemory conversationIndexMemory = (ConversationIndexMemory) memory; | ||
|
||
// get regenerate interaction question | ||
Optional.ofNullable(regenerateInteractionId).ifPresent(interactionId -> { | ||
log.info("Regenerate for existing interaction {}", regenerateInteractionId); | ||
Hailong-am marked this conversation as resolved.
Show resolved
Hide resolved
|
||
getQuestionFromInteraction(memoryId, interactionId, inputDataSet); | ||
}); | ||
|
||
// Create root interaction ID | ||
ConversationIndexMessage msg = ConversationIndexMessage | ||
.conversationIndexMessageBuilder() | ||
.type(appType) | ||
.question(question) | ||
.question(inputDataSet.getParameters().get(QUESTION)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not use question? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. question will be overridden at line 204 for the regenerate case. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you add a comment for this specific regenerate case? It takes a while to figure out why we can't use the question in the request directly. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks, comments added. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, I was confused too. I think you can also set question to the original question while also updating inputDataSet. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Refactored to the listener-based approach; it's clearer now. |
||
.response("") | ||
.finalAnswer(true) | ||
.sessionId(memory.getConversationId()) | ||
.build(); | ||
conversationIndexMemory | ||
.save(msg, null, null, null, ActionListener.<CreateInteractionResponse>wrap(interaction -> { | ||
log.info("Created parent interaction ID: " + interaction.getId()); | ||
inputDataSet.getParameters().put(PARENT_INTERACTION_ID, interaction.getId()); | ||
ActionListener<Object> agentActionListener = createAgentActionListener( | ||
listener, | ||
outputs, | ||
modelTensors | ||
); | ||
executeAgent(inputDataSet, mlAgent, agentActionListener); | ||
}, ex -> { | ||
log.error("Failed to create parent interaction", ex); | ||
listener.onFailure(ex); | ||
})); | ||
memory.save(msg, null, null, null, ActionListener.<CreateInteractionResponse>wrap(interaction -> { | ||
log.info("Created parent interaction ID: " + interaction.getId()); | ||
inputDataSet.getParameters().put(PARENT_INTERACTION_ID, interaction.getId()); | ||
ActionListener<Object> agentActionListener = createAgentActionListener( | ||
listener, | ||
outputs, | ||
modelTensors | ||
); | ||
executeAgent(inputDataSet, mlAgent, agentActionListener); | ||
}, ex -> { | ||
log.error("Failed to create parent interaction", ex); | ||
listener.onFailure(ex); | ||
})); | ||
}, ex -> { | ||
log.error("Failed to read conversation memory", ex); | ||
listener.onFailure(ex); | ||
|
@@ -156,6 +175,8 @@ public void execute(Input input, ActionListener<Output> listener) { | |
ActionListener<Object> agentActionListener = createAgentActionListener(listener, outputs, modelTensors); | ||
executeAgent(inputDataSet, mlAgent, agentActionListener); | ||
} | ||
} catch (Exception ex) { | ||
listener.onFailure(ex); | ||
Hailong-am marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
} else { | ||
listener.onFailure(new ResourceNotFoundException("Agent not found")); | ||
|
@@ -169,6 +190,24 @@ public void execute(Input input, ActionListener<Output> listener) { | |
|
||
} | ||
|
||
/** | ||
* Looks up an existing interaction and uses its input as the question when the request has none. | ||
* <p> | ||
* Executes {@code GetInteractionAction} through the client and waits on a | ||
* {@code PlainActionFuture} for the response, so the calling thread blocks until the call completes. | ||
* The interaction's input is stored under {@code QUESTION} via {@code computeIfAbsent}, | ||
* so a question already present in the parameters is left unchanged. | ||
* | ||
* @param memoryId conversation id | ||
* @param interactionId interaction id | ||
* @param inputDataSet input parameters | ||
* @throws MLException wrapping any exception raised while fetching or reading the interaction | ||
*/ | ||
private void getQuestionFromInteraction(String memoryId, String interactionId, RemoteInferenceInputDataSet inputDataSet) { | ||
PlainActionFuture<GetInteractionResponse> future = PlainActionFuture.newFuture(); | ||
client.execute(GetInteractionAction.INSTANCE, new GetInteractionRequest(memoryId, interactionId), future); | ||
try { | ||
// Blocking wait: nothing below runs until the get-interaction call has completed. | ||
GetInteractionResponse interactionResponse = future.get(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think we should have any blocking invocation in an opensearch thread. Can you change to non-blocking way like listener? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sure, changed it to non-blocking way. |
||
inputDataSet.getParameters().computeIfAbsent(QUESTION, key -> interactionResponse.getInteraction().getInput()); | ||
} catch (Exception ex) { | ||
log.error("Can't get regenerate interaction {}", interactionId, ex); | ||
throw new MLException(ex); | ||
} | ||
} | ||
|
||
private void executeAgent(RemoteInferenceInputDataSet inputDataSet, MLAgent mlAgent, ActionListener<Object> agentActionListener) { | ||
if ("flow".equals(mlAgent.getType())) { | ||
MLFlowAgentRunner flowAgentExecutor = new MLFlowAgentRunner( | ||
|
@@ -208,7 +247,7 @@ private ActionListener<Object> createAgentActionListener( | |
List<ModelTensors> outputs, | ||
List<ModelTensor> modelTensors | ||
) { | ||
return ActionListener.wrap(output -> { | ||
return wrap(output -> { | ||
if (output != null) { | ||
Gson gson = new Gson(); | ||
if (output instanceof ModelTensorOutput) { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What if both are null? Do we still want to continue?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If `regenerateInteractionId` is null, it's the normal flow; memory will be created automatically if needed.