From 16bd75b14c5f81c4032da5aab197624f6dffcbfc Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Thu, 5 Oct 2023 09:39:26 -0700 Subject: [PATCH] add searchInteractionsITTests Signed-off-by: HenryL27 --- .../ml/memory/index/InteractionsIndex.java | 7 +- .../index/InteractionsIndexITTests.java | 85 +++++++++++++++++++ 2 files changed, 91 insertions(+), 1 deletion(-) diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java index 8d492c1207..1f273678f7 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java @@ -360,7 +360,12 @@ public void searchInteractions(String conversationId, SearchRequest request, Act newQuery.must(originalQuery); newQuery.must(new TermQueryBuilder(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, conversationId)); request.source().query(newQuery); - client.search(request, internalListener); + client.admin().indices().refresh(Requests.refreshRequest(indexName), ActionListener.wrap(refreshResponse -> { + client.search(request, internalListener); + }, e -> { + log.error("Failed to refresh interactions index during search interactions ", e); + internalListener.onFailure(e); + })); } catch (Exception e) { listener.onFailure(e); } diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java index c23177bc2f..0bbeda03d1 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java @@ -19,16 +19,23 @@ import java.time.Instant; import java.time.temporal.ChronoUnit; +import java.util.ArrayList; import java.util.List; import java.util.concurrent.CountDownLatch; import org.junit.Before; import org.opensearch.action.LatchedActionListener; import org.opensearch.action.StepListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.action.ActionListener; +import org.opensearch.index.query.MatchQueryBuilder; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.test.OpenSearchIntegTestCase; import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; @@ -348,4 +355,82 @@ public void testDeleteConversation() { log.error(e); } } + + public void testSearchInteractions() { + final String conversation1 = "conversation1"; + final String conversation2 = "conversation2"; + CountDownLatch cdl = new CountDownLatch(1); + StepListener iid1 = new StepListener<>(); + index.createInteraction(conversation1, "input about fish", "pt", "response about fish", "origin1", "lots of information about fish", iid1); + + StepListener iid2 = new StepListener<>(); + iid1 + .whenComplete( + r -> { index.createInteraction(conversation1, "input about squash", "pt", "response about squash", "origin1", "lots of information about squash", iid2); }, + e -> { + cdl.countDown(); + log.error(e); + assert (false); + } + ); + + StepListener iid3 = new StepListener<>(); + iid2 + .whenComplete( + r -> { index.createInteraction(conversation2, "input about fish", "pt2", "response about fish", "origin1", "lots of information about fish", iid3); }, + e -> { + cdl.countDown(); + log.error(e); + assert (false); + } + ); + + StepListener iid4 = new StepListener<>(); + iid3 + .whenComplete( + r -> { index.createInteraction(conversation1, "input about france", "pt", "response about france", "origin1", "lots of information about france", iid4); }, + e -> { + cdl.countDown(); + log.error(e); + assert (false); + } + ); + + StepListener searchListener = new StepListener<>(); + iid4.whenComplete( + r -> { + SearchRequest request = new SearchRequest(); + request.source(new SearchSourceBuilder()); + request.source().query(new MatchQueryBuilder(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, "fish input")); + index.searchInteractions(conversation1, request, searchListener); + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + searchListener.whenComplete( + response -> { + cdl.countDown(); + assert (response.getHits().getHits().length == 3); + // BM25 was being a little unpredictable here so I don't assert ordering + List ids = new ArrayList<>(3); + for (SearchHit hit : response.getHits()) { + ids.add(hit.getId()); + } + assert (ids.contains(iid1.result())); + assert (ids.contains(iid2.result())); + assert (ids.contains(iid4.result())); + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + try { + cdl.await(); + } catch (InterruptedException e) { + log.error(e); + } + } }