Skip to content

Commit

Permalink
Fix test failure after lucene version upgraded to 10
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Jan 23, 2025
1 parent 686350b commit 88b0e08
Show file tree
Hide file tree
Showing 16 changed files with 32 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -909,7 +909,7 @@ public void searchTask() {

verify(client).execute(eq(MLTaskSearchAction.INSTANCE), isA(SearchRequest.class), any());
verify(searchTaskActionListener).onResponse(argumentCaptor.capture());
assertEquals(1, argumentCaptor.getValue().getHits().getTotalHits().value);
assertEquals(1, argumentCaptor.getValue().getHits().getTotalHits().value());
Map<String, Object> source = argumentCaptor.getValue().getHits().getAt(0).getSourceAsMap();
assertEquals(taskId, source.get(MLTask.TASK_ID_FIELD));
assertEquals(modelId, source.get(MLTask.MODEL_ID_FIELD));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ public Boolean validateStopWordsSingleIndex(String input, String indexName, List
context = client.threadPool().getThreadContext().stashContext();
ThreadContext.StoredContext finalContext = context;
client.search(searchRequest, ActionListener.runBefore(new LatchedActionListener(ActionListener.<SearchResponse>wrap(r -> {
if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) {
if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value() == 0) {
hitStopWords.set(true);
}
}, e -> {
Expand All @@ -244,7 +244,7 @@ public Boolean validateStopWordsSingleIndex(String input, String indexName, List
}), latch), () -> finalContext.restore()));
} else {
client.search(searchRequest, new LatchedActionListener(ActionListener.<SearchResponse>wrap(r -> {
if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) {
if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value() == 0) {
hitStopWords.set(true);
}
}, e -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public void parseSearchQueryInput(MLInputDataset mlInputDataset, ActionListener<
searchRequest.indices(indices);

client.search(searchRequest, ActionListener.wrap(r -> {
if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) {
if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value() == 0) {
listener.onFailure(new IllegalArgumentException("No document found"));
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ public void onResponse(SearchResponse searchResponse) {
SearchHits hits = searchResponse.getHits();
StringBuilder visBuilder = new StringBuilder();
visBuilder.append("Title,Id\n");
if (hits.getTotalHits().value > 0) {
if (hits.getTotalHits().value() > 0) {
Arrays.stream(hits.getHits()).forEach(h -> {
String id = trimIdPrefix(h.getId());
Map<String, String> visMap = (Map<String, String>) h.getSourceAsMap().get(SAVED_OBJECT_TYPE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ private void updateModelGroup(
ActionListener.wrap(modelGroups -> {
if (modelGroups != null
&& modelGroups.getHits().getTotalHits() != null
&& modelGroups.getHits().getTotalHits().value != 0) {
&& modelGroups.getHits().getTotalHits().value() != 0) {
for (SearchHit documentFields : modelGroups.getHits()) {
String id = documentFields.getId();
listener
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLRegi
ActionListener.wrap(modelGroups -> {
if (modelGroups != null
&& modelGroups.getHits().getTotalHits() != null
&& modelGroups.getHits().getTotalHits().value != 0) {
&& modelGroups.getHits().getTotalHits().value() != 0) {
String modelGroupIdOfTheNameProvided = modelGroups.getHits().getAt(0).getId();
registerModelInput.setModelGroupId(modelGroupIdOfTheNameProvided);
checkUserAccess(registerModelInput, listener, true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLUnde
searchHiddenModels(modelIds, ActionListener.wrap(hiddenModels -> {
if (hiddenModels != null
&& hiddenModels.getHits().getTotalHits() != null
&& hiddenModels.getHits().getTotalHits().value != 0
&& hiddenModels.getHits().getTotalHits().value() != 0
&& !isSuperAdminUserWrapper(clusterService, client)) {
List<String> hiddenModelIds = Arrays
.stream(hiddenModels.getHits().getHits())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLRegi
mlModelGroupManager.validateUniqueModelGroupName(mlUploadInput.getName(), null, ActionListener.wrap(modelGroups -> {
if (modelGroups != null
&& modelGroups.getHits().getTotalHits() != null
&& modelGroups.getHits().getTotalHits().value != 0) {
&& modelGroups.getHits().getTotalHits().value() != 0) {
String modelGroupIdOfTheNameProvided = modelGroups.getHits().getAt(0).getId();
mlUploadInput.setModelGroupId(modelGroupIdOfTheNameProvided);
checkUserAccess(mlUploadInput, listener, true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ private void triggerAutoDeployModels(List<String> addedNodes) {
private void triggerUndeployModelsOnDataNodes(List<String> dataNodeIds) {
List<String> modelIds = new ArrayList<>();
ActionListener<SearchResponse> listener = ActionListener.wrap(res -> {
if (res != null && res.getHits() != null && res.getHits().getTotalHits() != null && res.getHits().getTotalHits().value > 0) {
if (res != null && res.getHits() != null && res.getHits().getTotalHits() != null && res.getHits().getTotalHits().value() > 0) {
Arrays.stream(res.getHits().getHits()).forEach(x -> modelIds.add(x.getId()));
if (!modelIds.isEmpty()) {
ActionListener<MLUndeployModelNodesResponse> undeployModelListener = ActionListener.wrap(r -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ public void createModelGroup(MLRegisterModelGroupInput input, ActionListener<Str
validateUniqueModelGroupName(input.getName(), input.getTenantId(), ActionListener.wrap(modelGroups -> {
if (modelGroups != null
&& modelGroups.getHits().getTotalHits() != null
&& modelGroups.getHits().getTotalHits().value != 0) {
&& modelGroups.getHits().getTotalHits().value() != 0) {
for (SearchHit documentFields : modelGroups.getHits()) {
String id = documentFields.getId();
wrappedListener
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ public void getNumberOfDocumentsInIndex(
searchRequest.source(builder).indices(indexName);

client.search(searchRequest, ActionListener.runBefore(ActionListener.wrap(r -> {
long count = r.getHits().getTotalHits().value;
long count = r.getHits().getTotalHits().value();
listener.onResponse(count);
}, e -> { listener.onFailure(e); }), () -> context.restore()));
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ protected void testTextEmbeddingModel(Set<String> modelWorkerNodes) throws Inter
SearchResponse response = searchModelChunks(modelId.get());
AtomicBoolean modelChunksReady = new AtomicBoolean(false);
if (response != null) {
long totalHits = response.getHits().getTotalHits().value;
long totalHits = response.getHits().getTotalHits().value();
if (totalHits == 9) {
modelChunksReady.set(true);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public void test_empty_body_search() {
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchRequest.source(searchSourceBuilder);
SearchResponse response = client().execute(MLModelGroupSearchAction.INSTANCE, searchRequest).actionGet();
assertEquals(1, response.getHits().getTotalHits().value);
assertEquals(1, response.getHits().getTotalHits().value());
assertEquals(modelGroupId, response.getHits().getHits()[0].getId());
}

Expand All @@ -62,7 +62,7 @@ public void test_matchAll_search() {
searchRequest.source(searchSourceBuilder);
searchRequest.source().query(QueryBuilders.matchAllQuery());
SearchResponse response = client().execute(MLModelGroupSearchAction.INSTANCE, searchRequest).actionGet();
assertEquals(1, response.getHits().getTotalHits().value);
assertEquals(1, response.getHits().getTotalHits().value());
assertEquals(modelGroupId, response.getHits().getHits()[0].getId());
}

Expand All @@ -72,7 +72,7 @@ public void test_bool_search() {
searchRequest.source(searchSourceBuilder);
searchRequest.source().query(QueryBuilders.boolQuery().must(QueryBuilders.termQuery("name.keyword", "mock_model_group_name")));
SearchResponse response = client().execute(MLModelGroupSearchAction.INSTANCE, searchRequest).actionGet();
assertEquals(1, response.getHits().getTotalHits().value);
assertEquals(1, response.getHits().getTotalHits().value());
assertEquals(modelGroupId, response.getHits().getHits()[0].getId());
}

Expand All @@ -82,7 +82,7 @@ public void test_term_search() {
searchRequest.source(searchSourceBuilder);
searchRequest.source().query(QueryBuilders.termQuery("name.keyword", "mock_model_group_name"));
SearchResponse response = client().execute(MLModelGroupSearchAction.INSTANCE, searchRequest).actionGet();
assertEquals(1, response.getHits().getTotalHits().value);
assertEquals(1, response.getHits().getTotalHits().value());
assertEquals(modelGroupId, response.getHits().getHits()[0].getId());
}

Expand All @@ -92,7 +92,7 @@ public void test_terms_search() {
searchRequest.source(searchSourceBuilder);
searchRequest.source().query(QueryBuilders.termsQuery("name.keyword", "mock_model_group_name", "test_model_group_name"));
SearchResponse response = client().execute(MLModelGroupSearchAction.INSTANCE, searchRequest).actionGet();
assertEquals(1, response.getHits().getTotalHits().value);
assertEquals(1, response.getHits().getTotalHits().value());
assertEquals(modelGroupId, response.getHits().getHits()[0].getId());
}

Expand All @@ -102,7 +102,7 @@ public void test_range_search() {
searchRequest.source(searchSourceBuilder);
searchRequest.source().query(QueryBuilders.rangeQuery("created_time").gte("now-1d"));
SearchResponse response = client().execute(MLModelGroupSearchAction.INSTANCE, searchRequest).actionGet();
assertEquals(1, response.getHits().getTotalHits().value);
assertEquals(1, response.getHits().getTotalHits().value());
assertEquals(modelGroupId, response.getHits().getHits()[0].getId());
}

Expand All @@ -112,7 +112,7 @@ public void test_matchPhrase_search() {
searchRequest.source(searchSourceBuilder);
searchRequest.source().query(QueryBuilders.matchPhraseQuery("description", "desc"));
SearchResponse response = client().execute(MLModelGroupSearchAction.INSTANCE, searchRequest).actionGet();
assertEquals(1, response.getHits().getTotalHits().value);
assertEquals(1, response.getHits().getTotalHits().value());
assertEquals(modelGroupId, response.getHits().getHits()[0].getId());
}

Expand All @@ -122,7 +122,7 @@ public void test_queryString_search() {
searchRequest.source(searchSourceBuilder);
searchRequest.source().query(QueryBuilders.queryStringQuery("name: mock_model_group_*"));
SearchResponse response = client().execute(MLModelGroupSearchAction.INSTANCE, searchRequest).actionGet();
assertEquals(1, response.getHits().getTotalHits().value);
assertEquals(1, response.getHits().getTotalHits().value());
assertEquals(modelGroupId, response.getHits().getHits()[0].getId());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ private void test_empty_body_search() {
searchRequest.source(searchSourceBuilder);
searchRequest.source().query(QueryBuilders.boolQuery().mustNot(QueryBuilders.existsQuery(CHUNK_NUMBER)));
SearchResponse response = client().execute(MLModelSearchAction.INSTANCE, searchRequest).actionGet();
assertEquals(1, response.getHits().getTotalHits().value);
assertEquals(1, response.getHits().getTotalHits().value());
}

private void test_matchAll_search() {
Expand All @@ -122,7 +122,7 @@ private void test_matchAll_search() {
.source()
.query(QueryBuilders.boolQuery().mustNot(QueryBuilders.existsQuery(CHUNK_NUMBER)).must(QueryBuilders.matchAllQuery()));
SearchResponse response = client().execute(MLModelSearchAction.INSTANCE, searchRequest).actionGet();
assertEquals(1, response.getHits().getTotalHits().value);
assertEquals(1, response.getHits().getTotalHits().value());
}

private void test_bool_search() {
Expand All @@ -142,7 +142,7 @@ private void test_bool_search() {
)
);
SearchResponse response = client().execute(MLModelSearchAction.INSTANCE, searchRequest).actionGet();
assertEquals(1, response.getHits().getTotalHits().value);
assertEquals(1, response.getHits().getTotalHits().value());
}

private void test_term_search() {
Expand All @@ -155,7 +155,7 @@ private void test_term_search() {
.must(QueryBuilders.termQuery("name.keyword", "msmarco-distilbert-base-tas-b-pt"));
searchRequest.source().query(boolQueryBuilder);
SearchResponse response = client().execute(MLModelSearchAction.INSTANCE, searchRequest).actionGet();
assertEquals(1, response.getHits().getTotalHits().value);
assertEquals(1, response.getHits().getTotalHits().value());
}

private void test_terms_search() {
Expand All @@ -168,7 +168,7 @@ private void test_terms_search() {
.must(QueryBuilders.termsQuery("name.keyword", "msmarco-distilbert-base-tas-b-pt", "test_model_group_name"));
searchRequest.source().query(boolQueryBuilder);
SearchResponse response = client().execute(MLModelSearchAction.INSTANCE, searchRequest).actionGet();
assertEquals(1, response.getHits().getTotalHits().value);
assertEquals(1, response.getHits().getTotalHits().value());
}

private void test_range_search() {
Expand All @@ -181,7 +181,7 @@ private void test_range_search() {
.must(QueryBuilders.rangeQuery("created_time").gte("now-1d"));
searchRequest.source().query(boolQueryBuilder);
SearchResponse response = client().execute(MLModelSearchAction.INSTANCE, searchRequest).actionGet();
assertEquals(1, response.getHits().getTotalHits().value);
assertEquals(1, response.getHits().getTotalHits().value());
}

private void test_matchPhrase_search() {
Expand All @@ -194,7 +194,7 @@ private void test_matchPhrase_search() {
.must(QueryBuilders.matchPhraseQuery("description", "desc"));
searchRequest.source().query(boolQueryBuilder);
SearchResponse response = client().execute(MLModelSearchAction.INSTANCE, searchRequest).actionGet();
assertEquals(1, response.getHits().getTotalHits().value);
assertEquals(1, response.getHits().getTotalHits().value());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ public static SearchResponse waitModelAvailable1(String taskId) throws Interrupt
SearchRequest modelSearchRequest = new SearchRequest(new String[] { ML_MODEL_INDEX }, modelSearchSourceBuilder);
SearchResponse modelSearchResponse = null;
int i = 0;
while ((modelSearchResponse == null || modelSearchResponse.getHits().getTotalHits().value == 0) && i < 500) {
while ((modelSearchResponse == null || modelSearchResponse.getHits().getTotalHits().value() == 0) && i < 500) {
try {
ActionFuture<SearchResponse> searchFuture = client().execute(SearchAction.INSTANCE, modelSearchRequest);
modelSearchResponse = searchFuture.actionGet();
Expand All @@ -159,7 +159,7 @@ public static SearchResponse waitModelAvailable1(String taskId) throws Interrupt
i++;
}
assertNotNull(modelSearchResponse);
assertTrue(modelSearchResponse.getHits().getTotalHits().value > 0);
assertTrue(modelSearchResponse.getHits().getTotalHits().value() > 0);
return modelSearchResponse;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ public void testReplaceHitsWithSearchHits() throws IOException {

assertNotNull(newResponse);
assertEquals(newHits.length, newResponse.getHits().getHits().length);
assertEquals(15, newResponse.getHits().getTotalHits().value);
assertEquals(15, newResponse.getHits().getTotalHits().value());
assertEquals(TotalHits.Relation.EQUAL_TO, newResponse.getHits().getTotalHits().relation);
assertEquals(0.7f, newResponse.getHits().getMaxScore(), 0.0001f);
}
Expand All @@ -131,7 +131,7 @@ public void testReplaceHitsWithNonWriteableAggregations() {

assertNotNull(newResponse);
assertEquals(newHits.length, newResponse.getHits().getHits().length);
assertEquals(15, newResponse.getHits().getTotalHits().value);
assertEquals(15, newResponse.getHits().getTotalHits().value());
assertEquals(TotalHits.Relation.EQUAL_TO, newResponse.getHits().getTotalHits().relation);
assertEquals(0.7f, newResponse.getHits().getMaxScore(), 0.0001f);
}
Expand Down

0 comments on commit 88b0e08

Please sign in to comment.