Skip to content

Commit

Permalink
Fix search_after field values (#2679)
Browse files Browse the repository at this point in the history
Closes #2678
  • Loading branch information
sothawo authored Aug 28, 2023
1 parent 922c7dd commit 9adc4d2
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,15 @@
import static org.springframework.data.elasticsearch.client.elc.TypeUtils.*;
import static org.springframework.util.CollectionUtils.*;

import co.elastic.clients.elasticsearch._types.*;
import co.elastic.clients.elasticsearch._types.Conflicts;
import co.elastic.clients.elasticsearch._types.ExpandWildcard;
import co.elastic.clients.elasticsearch._types.InlineScript;
import co.elastic.clients.elasticsearch._types.NestedSortValue;
import co.elastic.clients.elasticsearch._types.OpType;
import co.elastic.clients.elasticsearch._types.SortOptions;
import co.elastic.clients.elasticsearch._types.SortOrder;
import co.elastic.clients.elasticsearch._types.VersionType;
import co.elastic.clients.elasticsearch._types.WaitForActiveShardOptions;
import co.elastic.clients.elasticsearch._types.mapping.FieldType;
import co.elastic.clients.elasticsearch._types.mapping.RuntimeField;
import co.elastic.clients.elasticsearch._types.mapping.RuntimeFieldType;
Expand Down Expand Up @@ -81,7 +89,6 @@
import org.springframework.data.elasticsearch.core.mapping.ElasticsearchPersistentProperty;
import org.springframework.data.elasticsearch.core.mapping.IndexCoordinates;
import org.springframework.data.elasticsearch.core.query.*;
import org.springframework.data.elasticsearch.core.query.IndicesOptions;
import org.springframework.data.elasticsearch.core.reindex.ReindexRequest;
import org.springframework.data.elasticsearch.core.reindex.Remote;
import org.springframework.data.elasticsearch.core.script.Script;
Expand Down Expand Up @@ -1226,8 +1233,7 @@ public MsearchRequest searchMsearchRequest(
}

if (!isEmpty(query.getSearchAfter())) {
bb.searchAfter(query.getSearchAfter().stream().map(it -> FieldValue.of(it.toString()))
.collect(Collectors.toList()));
bb.searchAfter(query.getSearchAfter().stream().map(TypeUtils::toFieldValue).toList());
}

query.getRescorerQueries().forEach(rescorerQuery -> bb.rescore(getRescore(rescorerQuery)));
Expand Down Expand Up @@ -1391,8 +1397,7 @@ private <T> void prepareSearchRequest(Query query, @Nullable String routing, @Nu
}

if (!isEmpty(query.getSearchAfter())) {
builder.searchAfter(
query.getSearchAfter().stream().map(it -> FieldValue.of(it.toString())).collect(Collectors.toList()));
builder.searchAfter(query.getSearchAfter().stream().map(TypeUtils::toFieldValue).toList());
}

query.getRescorerQueries().forEach(rescorerQuery -> builder.rescore(getRescore(rescorerQuery)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,40 @@ static Object toObject(@Nullable FieldValue fieldValue) {
}
}

@Nullable
static FieldValue toFieldValue(@Nullable Object fieldValue) {

if (fieldValue == null) {
return FieldValue.NULL;
}

if (fieldValue instanceof Boolean b) {
return b ? FieldValue.TRUE : FieldValue.FALSE;
}

if (fieldValue instanceof String s) {
return FieldValue.of(s);
}

if (fieldValue instanceof Long l) {
return FieldValue.of(l);
}

if (fieldValue instanceof Integer i) {
return FieldValue.of((long) i);
}

if (fieldValue instanceof Double d) {
return FieldValue.of(d);
}

if (fieldValue instanceof Float f) {
return FieldValue.of((double) f);
}

return FieldValue.of(JsonData.of(fieldValue));
}

@Nullable
static GeoDistanceType geoDistanceType(GeoDistanceOrder.DistanceType distanceType) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ public void before() {
@Test
@Order(java.lang.Integer.MAX_VALUE)
void cleanup() {
operations.indexOps(IndexCoordinates.of(indexNameProvider.getPrefix() + "*")).delete();
operations.indexOps(IndexCoordinates.of(indexNameProvider.getPrefix() + '*')).delete();
}

@Test // #1143
Expand All @@ -85,11 +85,11 @@ void shouldReadPagesWithSearchAfter() {
query.setSearchAfter(searchAfter);
SearchHits<Entity> searchHits = operations.search(query, Entity.class);

if (searchHits.getSearchHits().size() == 0) {
if (searchHits.getSearchHits().isEmpty()) {
break;
}
foundEntities.addAll(searchHits.stream().map(SearchHit::getContent).collect(Collectors.toList()));
searchAfter = searchHits.getSearchHit((int) (searchHits.getSearchHits().size() - 1)).getSortValues();
foundEntities.addAll(searchHits.stream().map(SearchHit::getContent).toList());
searchAfter = searchHits.getSearchHit(searchHits.getSearchHits().size() - 1).getSortValues();

if (++loop > 10) {
fail("loop not terminating");
Expand All @@ -99,16 +99,69 @@ void shouldReadPagesWithSearchAfter() {
assertThat(foundEntities).containsExactlyElementsOf(entities);
}

@Test // #2678
@DisplayName("should be able to handle different search after type values including null")
void shouldBeAbleToHandleDifferentSearchAfterTypeValuesIncludingNull() {

List<Entity> entities = IntStream.rangeClosed(1, 10)
.mapToObj(i -> {
var message = (i % 2 == 0) ? null : "message " + i;
var value = (i % 3 == 0) ? null : (long) i;
return new Entity((long) i, message, value);
})
.collect(Collectors.toList());
operations.save(entities);

Query query = Query.findAll();
query.setPageable(PageRequest.of(0, 3));
query.addSort(Sort.by(Sort.Direction.ASC, "id"));
query.addSort(Sort.by(Sort.Direction.ASC, "keyword"));
query.addSort(Sort.by(Sort.Direction.ASC, "value"));

List<Object> searchAfter = null;
List<Entity> foundEntities = new ArrayList<>();

int loop = 0;
do {
query.setSearchAfter(searchAfter);
SearchHits<Entity> searchHits = operations.search(query, Entity.class);

if (searchHits.getSearchHits().isEmpty()) {
break;
}
foundEntities.addAll(searchHits.stream().map(SearchHit::getContent).toList());
searchAfter = searchHits.getSearchHit(searchHits.getSearchHits().size() - 1).getSortValues();

if (++loop > 10) {
fail("loop not terminating");
}
} while (true);

assertThat(foundEntities).containsExactlyElementsOf(entities);
}

@SuppressWarnings("unused")
@Document(indexName = "#{@indexNameProvider.indexName()}")
private static class Entity {
@Nullable
@Id private Long id;
@Nullable
@Field(type = FieldType.Text) private String message;
@Field(type = FieldType.Keyword) private String keyword;

@Nullable
@Field(type = FieldType.Long) private Long value;

public Entity() {}

public Entity(@Nullable Long id, @Nullable String message) {
public Entity(@Nullable Long id, @Nullable String keyword) {
this.id = id;
this.message = message;
this.keyword = keyword;
}

public Entity(@Nullable Long id, @Nullable String keyword, @Nullable Long value) {
this.id = id;
this.keyword = keyword;
this.value = value;
}

@Nullable
Expand All @@ -121,30 +174,44 @@ public void setId(@Nullable Long id) {
}

@Nullable
public String getMessage() {
return message;
public String getKeyword() {
return keyword;
}

public void setKeyword(@Nullable String keyword) {
this.keyword = keyword;
}

@Nullable
public Long getValue() {
return value;
}

public void setMessage(@Nullable String message) {
this.message = message;
public void setValue(@Nullable Long value) {
this.value = value;
}

@Override
public boolean equals(Object o) {
if (this == o)
return true;
if (!(o instanceof Entity entity))
if (o == null || getClass() != o.getClass())
return false;

Entity entity = (Entity) o;

if (!Objects.equals(id, entity.id))
return false;
return Objects.equals(message, entity.message);
if (!Objects.equals(keyword, entity.keyword))
return false;
return Objects.equals(value, entity.value);
}

@Override
public int hashCode() {
int result = id != null ? id.hashCode() : 0;
result = 31 * result + (message != null ? message.hashCode() : 0);
result = 31 * result + (keyword != null ? keyword.hashCode() : 0);
result = 31 * result + (value != null ? value.hashCode() : 0);
return result;
}
}
Expand Down

0 comments on commit 9adc4d2

Please sign in to comment.