diff --git a/commons/com.b2international.index.tests/src/com/b2international/index/Fixtures.java b/commons/com.b2international.index.tests/src/com/b2international/index/Fixtures.java index 2ba9763c165..bda47bdf643 100644 --- a/commons/com.b2international.index.tests/src/com/b2international/index/Fixtures.java +++ b/commons/com.b2international.index.tests/src/com/b2international/index/Fixtures.java @@ -1,5 +1,5 @@ /* - * Copyright 2011-2022 B2i Healthcare Pte Ltd, http://b2i.sg + * Copyright 2011-2023 B2i Healthcare Pte Ltd, http://b2i.sg * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -319,12 +319,13 @@ public int hashCode() { } @Doc - public static class MultipleNestedData { + public static class MultipleNestedData implements WithScore { @ID String id; String field1 = "field1"; Collection nestedDatas = newHashSet(); + float score; @JsonCreator public MultipleNestedData(@JsonProperty("id") String id, @JsonProperty("nestedDatas") Collection nestedDatas) { @@ -332,6 +333,16 @@ public MultipleNestedData(@JsonProperty("id") String id, @JsonProperty("nestedDa this.nestedDatas.addAll(nestedDatas); } + @Override + public float getScore() { + return score; + } + + @Override + public void setScore(float score) { + this.score = score; + } + @Override public boolean equals(Object obj) { if (this == obj) return true; @@ -352,10 +363,18 @@ public int hashCode() { public static class NestedData { String field2; + + @Field(aliases = @FieldAlias(name = "text", analyzer = Analyzers.TOKENIZED, type = FieldAliasType.TEXT)) + String analyzedField; + public NestedData(String field2) { + this(field2, null); + } + @JsonCreator - public NestedData(@JsonProperty("field2") String field2) { + public NestedData(@JsonProperty("field2") String field2, @JsonProperty("analyzedField") String analyzedField) { this.field2 = field2; + this.analyzedField = analyzedField; } @Override @@ -364,12 +383,12 @@ public boolean equals(Object obj) { if (obj == null) return false; if (getClass() != obj.getClass()) return false; NestedData other = (NestedData) obj; - return Objects.equals(field2, other.field2); + return Objects.equals(field2, other.field2) && Objects.equals(analyzedField, other.analyzedField); } @Override public int hashCode() { - return Objects.hash(field2); + return Objects.hash(field2, analyzedField); } } diff --git a/commons/com.b2international.index.tests/src/com/b2international/index/SortIndexTest.java b/commons/com.b2international.index.tests/src/com/b2international/index/SortIndexTest.java index a8aea5f0982..567b700bde6 100644 --- a/commons/com.b2international.index.tests/src/com/b2international/index/SortIndexTest.java +++ b/commons/com.b2international.index.tests/src/com/b2international/index/SortIndexTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2011-2022 B2i Healthcare Pte Ltd, http://b2i.sg + * Copyright 2011-2023 B2i Healthcare Pte Ltd, http://b2i.sg * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ import static com.google.common.collect.Lists.newArrayList; import static com.google.common.collect.Sets.newTreeSet; +import static org.assertj.core.api.Assertions.assertThat; import static org.junit.Assert.assertArrayEquals; import java.math.BigDecimal; @@ -24,14 +25,20 @@ import java.util.function.Function; import org.apache.commons.lang.RandomStringUtils; +import org.elasticsearch.common.UUIDs; import org.junit.Test; import com.b2international.index.Fixtures.Data; +import com.b2international.index.Fixtures.MultipleNestedData; +import com.b2international.index.Fixtures.NestedData; import com.b2international.index.query.Expressions; import com.b2international.index.query.Query; import com.b2international.index.query.SortBy; import com.b2international.index.query.SortBy.Order; -import com.google.common.collect.*; +import com.google.common.collect.FluentIterable; +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; /** * @since 5.4 @@ -42,7 +49,7 @@ public class SortIndexTest extends BaseIndexTest { @Override protected Collection> getTypes() { - return List.>of(Data.class); + return List.>of(Data.class, MultipleNestedData.class); } @Test @@ -333,6 +340,26 @@ public void sortScore() throws Exception { checkDocumentOrder(ascendingQuery, data -> data.getId(), Sets.newLinkedHashSet(Lists.reverse(orderedKeys)), String.class); } + + @Test + public void sortNestedFieldWithScores() throws Exception { + indexDocuments( + new MultipleNestedData(KEY1, List.of(new NestedData("unused", "abdominal knee pain"))), + new MultipleNestedData(KEY2, List.of(new NestedData("unused", "knee pain"))), + new MultipleNestedData(UUIDs.randomBase64UUID(), List.of(new NestedData("unused", "pain"))) + ); + + Hits hits = search(Query.select(MultipleNestedData.class) + .where(Expressions.nestedMatch("nestedDatas", Expressions.matchTextAll("analyzedField.text", "knee pain"))) + .build()); + + assertThat(hits) + .extracting(m -> m.id) + .containsOnly(KEY1, KEY2); + assertThat(hits) + .extracting(m -> m.score) + .allSatisfy(score -> assertThat(score).isGreaterThan(0.0f)); + } private void checkDocumentOrder(Query query, Function hitFunction, Set keySet, Class clazz) { final Hits hits = search(query); diff --git a/commons/com.b2international.index/src/com/b2international/index/es/query/EsQueryBuilder.java b/commons/com.b2international.index/src/com/b2international/index/es/query/EsQueryBuilder.java index a05a8d4c930..0eb9be6ec7b 100644 --- a/commons/com.b2international.index/src/com/b2international/index/es/query/EsQueryBuilder.java +++ b/commons/com.b2international.index/src/com/b2international/index/es/query/EsQueryBuilder.java @@ -261,7 +261,7 @@ private void visit(NestedPredicate predicate) { nestedQueryBuilder.visit(predicate.getExpression()); needsScoring = nestedQueryBuilder.needsScoring; final QueryBuilder nestedQuery = nestedQueryBuilder.deque.pop(); - deque.push(QueryBuilders.nestedQuery(nestedPath, nestedQuery, ScoreMode.None)); + deque.push(QueryBuilders.nestedQuery(nestedPath, nestedQuery, nestedQueryBuilder.needsScoring ? ScoreMode.Max : ScoreMode.None)); } private String toFieldPath(Predicate predicate) {