Skip to content

Commit

Permalink
add searchInteractionsITTests
Browse files Browse the repository at this point in the history
Signed-off-by: HenryL27 <[email protected]>
  • Loading branch information
HenryL27 committed Nov 29, 2023
1 parent 95bb9f4 commit 16bd75b
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,12 @@ public void searchInteractions(String conversationId, SearchRequest request, Act
newQuery.must(originalQuery);
newQuery.must(new TermQueryBuilder(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, conversationId));
request.source().query(newQuery);
client.search(request, internalListener);
client.admin().indices().refresh(Requests.refreshRequest(indexName), ActionListener.wrap(refreshResponse -> {
client.search(request, internalListener);
}, e -> {
log.error("Failed to refresh interactions index during search interactions ", e);
internalListener.onFailure(e);
}));
} catch (Exception e) {
listener.onFailure(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,23 @@

import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;

import org.junit.Before;
import org.opensearch.action.LatchedActionListener;
import org.opensearch.action.StepListener;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.action.ActionListener;
import org.opensearch.index.query.MatchQueryBuilder;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.common.conversation.Interaction;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.test.OpenSearchIntegTestCase;

import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;
Expand Down Expand Up @@ -348,4 +355,82 @@ public void testDeleteConversation() {
log.error(e);
}
}

public void testSearchInteractions() {
final String conversation1 = "conversation1";
final String conversation2 = "conversation2";
CountDownLatch cdl = new CountDownLatch(1);
StepListener<String> iid1 = new StepListener<>();
index.createInteraction(conversation1, "input about fish", "pt", "response about fish", "origin1", "lots of information about fish", iid1);

StepListener<String> iid2 = new StepListener<>();
iid1
.whenComplete(
r -> { index.createInteraction(conversation1, "input about squash", "pt", "response about squash", "origin1", "lots of information about squash", iid2); },
e -> {
cdl.countDown();
log.error(e);
assert (false);
}
);

StepListener<String> iid3 = new StepListener<>();
iid2
.whenComplete(
r -> { index.createInteraction(conversation2, "input about fish", "pt2", "response about fish", "origin1", "lots of information about fish", iid3); },
e -> {
cdl.countDown();
log.error(e);
assert (false);
}
);

StepListener<String> iid4 = new StepListener<>();
iid3
.whenComplete(
r -> { index.createInteraction(conversation1, "input about france", "pt", "response about france", "origin1", "lots of information about france", iid4); },
e -> {
cdl.countDown();
log.error(e);
assert (false);
}
);

StepListener<SearchResponse> searchListener = new StepListener<>();
iid4.whenComplete(
r -> {
SearchRequest request = new SearchRequest();
request.source(new SearchSourceBuilder());
request.source().query(new MatchQueryBuilder(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, "fish input"));
index.searchInteractions(conversation1, request, searchListener);
}, e -> {
cdl.countDown();
log.error(e);
assert (false);
});

searchListener.whenComplete(
response -> {
cdl.countDown();
assert (response.getHits().getHits().length == 3);
// BM25 was being a little unpredictable here so I don't assert ordering
List<String> ids = new ArrayList<>(3);
for (SearchHit hit : response.getHits()) {
ids.add(hit.getId());
}
assert (ids.contains(iid1.result()));
assert (ids.contains(iid2.result()));
assert (ids.contains(iid4.result()));
}, e -> {
cdl.countDown();
log.error(e);
assert (false);
});

try {
cdl.await();
} catch (InterruptedException e) {
log.error(e);
}
}
}

0 comments on commit 16bd75b

Please sign in to comment.