Skip to content

Commit

Permalink
Merge pull request #1162 from b2ihealthcare/issue/nested-document-sco…
Browse files Browse the repository at this point in the history
…remode

fix(index): scoring issue when searching for nested analyzed fields
  • Loading branch information
cmark authored May 15, 2023
2 parents a4564b5 + f58743a commit ca689b3
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 9 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2011-2022 B2i Healthcare Pte Ltd, http://b2i.sg
* Copyright 2011-2023 B2i Healthcare Pte Ltd, http://b2i.sg
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -319,19 +319,30 @@ public int hashCode() {
}

@Doc
public static class MultipleNestedData {
public static class MultipleNestedData implements WithScore {

@ID
String id;
String field1 = "field1";
Collection<NestedData> nestedDatas = newHashSet();
float score;

@JsonCreator
public MultipleNestedData(@JsonProperty("id") String id, @JsonProperty("nestedDatas") Collection<NestedData> nestedDatas) {
this.id = id;
this.nestedDatas.addAll(nestedDatas);
}

@Override
public float getScore() {
return score;
}

@Override
public void setScore(float score) {
this.score = score;
}

@Override
public boolean equals(Object obj) {
if (this == obj) return true;
Expand All @@ -352,10 +363,18 @@ public int hashCode() {
public static class NestedData {

String field2;

@Field(aliases = @FieldAlias(name = "text", analyzer = Analyzers.TOKENIZED, type = FieldAliasType.TEXT))
String analyzedField;

public NestedData(String field2) {
this(field2, null);
}

@JsonCreator
public NestedData(@JsonProperty("field2") String field2) {
public NestedData(@JsonProperty("field2") String field2, @JsonProperty("analyzedField") String analyzedField) {
this.field2 = field2;
this.analyzedField = analyzedField;
}

@Override
Expand All @@ -364,12 +383,12 @@ public boolean equals(Object obj) {
if (obj == null) return false;
if (getClass() != obj.getClass()) return false;
NestedData other = (NestedData) obj;
return Objects.equals(field2, other.field2);
return Objects.equals(field2, other.field2) && Objects.equals(analyzedField, other.analyzedField);
}

@Override
public int hashCode() {
return Objects.hash(field2);
return Objects.hash(field2, analyzedField);
}

}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2011-2022 B2i Healthcare Pte Ltd, http://b2i.sg
* Copyright 2011-2023 B2i Healthcare Pte Ltd, http://b2i.sg
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,21 +17,28 @@

import static com.google.common.collect.Lists.newArrayList;
import static com.google.common.collect.Sets.newTreeSet;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.assertArrayEquals;

import java.math.BigDecimal;
import java.util.*;
import java.util.function.Function;

import org.apache.commons.lang.RandomStringUtils;
import org.elasticsearch.common.UUIDs;
import org.junit.Test;

import com.b2international.index.Fixtures.Data;
import com.b2international.index.Fixtures.MultipleNestedData;
import com.b2international.index.Fixtures.NestedData;
import com.b2international.index.query.Expressions;
import com.b2international.index.query.Query;
import com.b2international.index.query.SortBy;
import com.b2international.index.query.SortBy.Order;
import com.google.common.collect.*;
import com.google.common.collect.FluentIterable;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;

/**
* @since 5.4
Expand All @@ -42,7 +49,7 @@ public class SortIndexTest extends BaseIndexTest {

@Override
protected Collection<Class<?>> getTypes() {
return List.<Class<?>>of(Data.class);
return List.<Class<?>>of(Data.class, MultipleNestedData.class);
}

@Test
Expand Down Expand Up @@ -333,6 +340,26 @@ public void sortScore() throws Exception {

checkDocumentOrder(ascendingQuery, data -> data.getId(), Sets.newLinkedHashSet(Lists.reverse(orderedKeys)), String.class);
}

@Test
public void sortNestedFieldWithScores() throws Exception {
indexDocuments(
new MultipleNestedData(KEY1, List.of(new NestedData("unused", "abdominal knee pain"))),
new MultipleNestedData(KEY2, List.of(new NestedData("unused", "knee pain"))),
new MultipleNestedData(UUIDs.randomBase64UUID(), List.of(new NestedData("unused", "pain")))
);

Hits<MultipleNestedData> hits = search(Query.select(MultipleNestedData.class)
.where(Expressions.nestedMatch("nestedDatas", Expressions.matchTextAll("analyzedField.text", "knee pain")))
.build());

assertThat(hits)
.extracting(m -> m.id)
.containsOnly(KEY1, KEY2);
assertThat(hits)
.extracting(m -> m.score)
.allSatisfy(score -> assertThat(score).isGreaterThan(0.0f));
}

private <T> void checkDocumentOrder(Query<Data> query, Function<? super Data, T> hitFunction, Set<T> keySet, Class<T> clazz) {
final Hits<Data> hits = search(query);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ private void visit(NestedPredicate predicate) {
nestedQueryBuilder.visit(predicate.getExpression());
needsScoring = nestedQueryBuilder.needsScoring;
final QueryBuilder nestedQuery = nestedQueryBuilder.deque.pop();
deque.push(QueryBuilders.nestedQuery(nestedPath, nestedQuery, ScoreMode.None));
deque.push(QueryBuilders.nestedQuery(nestedPath, nestedQuery, nestedQueryBuilder.needsScoring ? ScoreMode.Max : ScoreMode.None));
}

private String toFieldPath(Predicate predicate) {
Expand Down

0 comments on commit ca689b3

Please sign in to comment.