Skip to content

Commit

Permalink
remove interaction and trace at last
Browse files Browse the repository at this point in the history
Signed-off-by: Hailong Cui <[email protected]>
  • Loading branch information
Hailong-am committed Dec 4, 2023
1 parent cb34b59 commit bbe79a7
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Obje
String appType = mlAgent.getAppType();
String title = params.get(QUESTION);
if (null != regenerateInteractionId && memoryId == null) {
listener.onFailure(new MLValidationException("memory Id must provide for regenerate"));
listener.onFailure(new MLValidationException("memory id must provide for regenerate"));
return;
}

Expand All @@ -136,7 +136,7 @@ public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Obje
interactionIdFound = true;
// if no new question provided, use original question
params.computeIfAbsent(QUESTION, key -> next.getInput());
// there may have other new interactions happened
// there may have other new interactions happened after this interaction
continue;
}
String question = next.getInput();
Expand Down Expand Up @@ -174,17 +174,14 @@ public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Obje
return;
}

if (interactionIdFound) {
memory.getMemoryManager().deleteInteraction(regenerateInteractionId, ActionListener.wrap(deleteResult -> {
runAgent(mlAgent, params, listener, toolSpecs, memory, memory.getConversationId());
}, e -> {
ActionListener<Object> finalListener = interactionIdFound ? ActionListener.runBefore(listener, () -> {
memory.getMemoryManager().deleteInteraction(regenerateInteractionId, ActionListener.wrap(deleted -> {}, e -> {
log.error("Failed to regenerate for interaction {}", regenerateInteractionId, e);
listener.onFailure(e);
}));
}) : listener;

} else {
runAgent(mlAgent, params, listener, toolSpecs, memory, memory.getConversationId());
}
runAgent(mlAgent, params, finalListener, toolSpecs, memory, memory.getConversationId());
}, e -> {
log.error("Failed to get chat history", e);
listener.onFailure(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@

import org.opensearch.OpenSearchSecurityException;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.bulk.BulkRequest;
import org.opensearch.action.bulk.BulkResponse;
import org.opensearch.action.delete.DeleteRequest;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.update.UpdateRequest;
Expand Down Expand Up @@ -289,26 +290,45 @@ public void updateInteraction(String interactionId, Map<String, Object> updateCo
}
}

public void deleteInteraction(String interactionId, ActionListener<Boolean> actionListener) {
DeleteRequest deleteRequest = new DeleteRequest(indexName, interactionId);
/**
* Delete interaction with its trace data
* @param interactionId interaction id
* @param listener callback for delete result
*/
public void deleteInteraction(String interactionId, ActionListener<Boolean> listener) {
BulkRequest bulkRequest = new BulkRequest(indexName);
bulkRequest.add(new DeleteRequest(indexName, interactionId));

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<DeleteResponse> al = ActionListener.runBefore(ActionListener.wrap(deleteResponse -> {
if (deleteResponse != null && deleteResponse.getResult() != DocWriteResponse.Result.DELETED) {
innerGetTraces(interactionId, ActionListener.wrap(traces -> {
traces.forEach(trace-> bulkRequest.add(new DeleteRequest(indexName, trace.getId())));

innerDeleteInteraction(bulkRequest, interactionId, listener);
}, e -> {
// delete interaction only if we can't get trace
innerDeleteInteraction(bulkRequest, interactionId, listener);
}));
}

@VisibleForTesting
void innerDeleteInteraction(BulkRequest bulkRequest, String interactionId, ActionListener<Boolean> listener) {
try (ThreadContext.StoredContext ignored = client.threadPool().getThreadContext().stashContext()) {
ActionListener<BulkResponse> al = ActionListener.wrap(bulkResponse -> {
if (bulkResponse != null && bulkResponse.hasFailures()) {
log.info("Failed to delete the interaction with ID: {}", interactionId);
actionListener.onResponse(true);
listener.onResponse(false);
return;
}
log.info("Successfully delete the interaction with ID: {}", interactionId);
actionListener.onResponse(true);
listener.onResponse(true);
}, exception -> {
log.error("Failed to delete interaction with ID {}. Details: {}", interactionId, exception);
actionListener.onFailure(exception);
}), context::restore);
client.delete(deleteRequest, al);
listener.onFailure(exception);
});
// bulk delete interaction and its trace
client.bulk(bulkRequest, al);
} catch (Exception e) {
log.error("Failed to delete interaction for interaction id {}. Details {}:", interactionId, e);
actionListener.onFailure(e);
log.error("Failed to delete interaction with ID {}. Details {}:", interactionId, e);
listener.onFailure(e);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ public void testRegenerateWithInvalidInput() {
ArgumentCaptor<MLValidationException> argumentCaptor = ArgumentCaptor.forClass(MLValidationException.class);
Mockito.verify(agentActionListener).onFailure(argumentCaptor.capture());
MLValidationException ex = argumentCaptor.getValue();
Assert.assertEquals(ex.getMessage(), "memory Id must provide for regenerate");
Assert.assertEquals(ex.getMessage(), "memory id must provide for regenerate");
}

@Test
Expand Down

0 comments on commit bbe79a7

Please sign in to comment.