Skip to content

Commit

Permalink
Protobuf support for node-to-node communication
Browse files Browse the repository at this point in the history
Signed-off-by: Vacha Shah <[email protected]>
  • Loading branch information
VachaShah committed Jan 17, 2024
1 parent e8bfd09 commit 5c92d5c
Show file tree
Hide file tree
Showing 21 changed files with 566 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ public abstract class TransportMessage implements Writeable, ProtobufWriteable {

private TransportAddress remoteAddress;

private boolean isProtobuf;

public void remoteAddress(TransportAddress remoteAddress) {
this.remoteAddress = remoteAddress;
}
Expand All @@ -54,6 +56,10 @@ public TransportAddress remoteAddress() {
return remoteAddress;
}

public boolean isMessageProtobuf() {
return isProtobuf;
}

/**
* Constructs a new empty transport message
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ final class Netty4MessageChannelHandler extends ChannelDuplexHandler {
threadPool::relativeTimeInMillis,
transport.getInflightBreaker(),
requestHandlers::getHandler,
transport::inboundMessage
transport::inboundMessage,
transport::inboundMessageProtobuf
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ public TcpReadWriteHandler(NioTcpChannel channel, PageCacheRecycler recycler, Tc
threadPool::relativeTimeInMillis,
breaker,
requestHandlers::getHandler,
transport::inboundMessage
transport::inboundMessage,
transport::inboundMessageProtobuf
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,10 @@ public static BiFunction<Transport.Connection, SearchActionListener, ActionListe
@Override
public void onResponse(SearchPhaseResult response) {
if (response instanceof QueryFetchSearchResult) {
response.queryResult().getShardSearchRequest().setOutboundNetworkTime(0);
response.queryResult().getShardSearchRequest().setInboundNetworkTime(0);
if (response.queryResult().getShardSearchRequest() != null) {
response.queryResult().getShardSearchRequest().setOutboundNetworkTime(0);
response.queryResult().getShardSearchRequest().setInboundNetworkTime(0);
}
}
QuerySearchResult queryResult = response.queryResult();
if (response.getShardSearchRequest() != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,8 @@ public void sendExecuteQuery(
// we optimize this and expect a QueryFetchSearchResult if we only have a single shard in the search request
// this used to be the QUERY_AND_FETCH which doesn't exist anymore.
final boolean fetchDocuments = request.numberOfShards() == 1;
// System.setProperty("opensearch.experimental.feature.search_with_protobuf.enabled", "true");
// System.setProperty(FeatureFlags.PROTOBUF, "true");
if (FeatureFlags.isEnabled(FeatureFlags.PROTOBUF_SETTING)) {
// System.out.println("Feature flag enabled");
ProtobufWriteable.Reader<SearchPhaseResult> reader = fetchDocuments ? QueryFetchSearchResult::new : QuerySearchResult::new;

final ActionListener handler = responseWrapper.apply(connection, listener);
Expand Down
30 changes: 29 additions & 1 deletion server/src/main/java/org/opensearch/search/SearchHit.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.opensearch.core.common.ParsingException;
import org.opensearch.core.common.Strings;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.io.stream.ProtobufWriteable;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
Expand All @@ -66,9 +67,12 @@
import org.opensearch.index.seqno.SequenceNumbers;
import org.opensearch.search.fetch.subphase.highlight.HighlightField;
import org.opensearch.search.lookup.SourceLookup;
import org.opensearch.server.proto.FetchSearchResultProto;
import org.opensearch.transport.RemoteClusterAware;

import java.io.IOException;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
Expand Down Expand Up @@ -96,7 +100,7 @@
* @opensearch.api
*/
@PublicApi(since = "1.0.0")
public final class SearchHit implements Writeable, ToXContentObject, Iterable<DocumentField> {
public final class SearchHit implements Writeable, ToXContentObject, Iterable<DocumentField>, ProtobufWriteable {

private final transient int docId;

Expand Down Expand Up @@ -137,6 +141,8 @@ public final class SearchHit implements Writeable, ToXContentObject, Iterable<Do

private Map<String, SearchHits> innerHits;

private FetchSearchResultProto.SearchHit searchHitProto;

// used only in tests
public SearchHit(int docId) {
this(docId, null, null, null);
Expand Down Expand Up @@ -224,6 +230,23 @@ public SearchHit(StreamInput in) throws IOException {
}
}

public SearchHit(byte[] in) throws IOException {
this.searchHitProto = FetchSearchResultProto.SearchHit.parseFrom(in);
docId = -1;
score = this.searchHitProto.getScore();
id = new Text(this.searchHitProto.getId());
// Support for nestedIdentity to be added in the future
nestedIdentity = null;
version = this.searchHitProto.getVersion();
seqNo = this.searchHitProto.getSeqNo();
primaryTerm = this.searchHitProto.getPrimaryTerm();
source = BytesReference.fromByteBuffer(ByteBuffer.wrap(this.searchHitProto.getSource().toByteArray()));
if (source.length() == 0) {
source = null;
}
metaFields = new HashMap<>();
}

private Map<String, DocumentField> readFields(StreamInput in) throws IOException {
Map<String, DocumentField> fields;
int size = in.readVInt();
Expand Down Expand Up @@ -306,6 +329,11 @@ public void writeTo(StreamOutput out) throws IOException {
}
}

@Override
public void writeTo(OutputStream out) throws IOException {
out.write(this.searchHitProto.toByteArray());
}

public int docId() {
return this.docId;
}
Expand Down
28 changes: 27 additions & 1 deletion server/src/main/java/org/opensearch/search/SearchHits.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.opensearch.common.Nullable;
import org.opensearch.common.annotation.PublicApi;
import org.opensearch.common.lucene.Lucene;
import org.opensearch.core.common.io.stream.ProtobufWriteable;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
Expand All @@ -47,6 +48,7 @@
import org.opensearch.rest.action.search.RestSearchAction;

import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
Expand All @@ -61,7 +63,7 @@
* @opensearch.api
*/
@PublicApi(since = "1.0.0")
public final class SearchHits implements Writeable, ToXContentFragment, Iterable<SearchHit> {
public final class SearchHits implements Writeable, ToXContentFragment, Iterable<SearchHit>, ProtobufWriteable {
public static SearchHits empty() {
return empty(true);
}
Expand All @@ -82,6 +84,8 @@ public static SearchHits empty(boolean withTotalHits) {
@Nullable
private final Object[] collapseValues;

private org.opensearch.server.proto.FetchSearchResultProto.SearchHits searchHitsProto;

public SearchHits(SearchHit[] hits, @Nullable TotalHits totalHits, float maxScore) {
this(hits, totalHits, maxScore, null, null, null);
}
Expand Down Expand Up @@ -124,6 +128,23 @@ public SearchHits(StreamInput in) throws IOException {
collapseValues = in.readOptionalArray(Lucene::readSortValue, Object[]::new);
}

public SearchHits(byte[] in) throws IOException {
this.searchHitsProto = org.opensearch.server.proto.FetchSearchResultProto.SearchHits.parseFrom(in);
this.hits = new SearchHit[this.searchHitsProto.getHitsCount()];
for (int i = 0; i < this.searchHitsProto.getHitsCount(); i++) {
this.hits[i] = new SearchHit(this.searchHitsProto.getHits(i).toByteArray());
}
this.totalHits = new TotalHits(
this.searchHitsProto.getTotalHits().getValue(),
Relation.valueOf(this.searchHitsProto.getTotalHits().getRelation().toString())
);
this.maxScore = this.searchHitsProto.getMaxScore();
this.collapseField = this.searchHitsProto.getCollapseField();
// Below fields are set to null currently, support to be added in the future
this.collapseValues = null;
this.sortFields = null;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
final boolean hasTotalHits = totalHits != null;
Expand Down Expand Up @@ -342,4 +363,9 @@ private static Relation parseRelation(String relation) {
throw new IllegalArgumentException("invalid total hits relation: " + relation);
}
}

@Override
public void writeTo(OutputStream out) throws IOException {
out.write(searchHitsProto.toByteArray());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@

package org.opensearch.search.fetch;

import com.google.protobuf.ByteString;
import org.apache.lucene.search.TotalHits.Relation;
import org.opensearch.common.annotation.PublicApi;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.search.SearchHit;
Expand Down Expand Up @@ -76,6 +79,7 @@ public FetchSearchResult(byte[] in) throws IOException {
this.fetchSearchResultProto.getContextId().getSessionId(),
this.fetchSearchResultProto.getContextId().getId()
);
hits = new SearchHits(this.fetchSearchResultProto.getHits().toByteArray());
}

public FetchSearchResult(ShardSearchContextId id, SearchShardTarget shardTarget) {
Expand All @@ -101,6 +105,30 @@ public FetchSearchResult fetchResult() {
public void hits(SearchHits hits) {
assert assertNoSearchTarget(hits);
this.hits = hits;
if (this.fetchSearchResultProto != null) {
QuerySearchResultProto.TotalHits.Builder totalHitsBuilder = QuerySearchResultProto.TotalHits.newBuilder();
totalHitsBuilder.setValue(hits.getTotalHits().value);
totalHitsBuilder.setRelation(
hits.getTotalHits().relation == Relation.EQUAL_TO
? QuerySearchResultProto.TotalHits.Relation.EQUAL_TO
: QuerySearchResultProto.TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
);
FetchSearchResultProto.SearchHits.Builder searchHitsBuilder = FetchSearchResultProto.SearchHits.newBuilder();
searchHitsBuilder.setMaxScore(hits.getMaxScore());
searchHitsBuilder.setTotalHits(totalHitsBuilder.build());
for (SearchHit hit : hits.getHits()) {
FetchSearchResultProto.SearchHit.Builder searchHitBuilder = FetchSearchResultProto.SearchHit.newBuilder();
searchHitBuilder.setIndex(hit.getIndex());
searchHitBuilder.setId(hit.getId());
searchHitBuilder.setScore(hit.getScore());
searchHitBuilder.setSeqNo(hit.getSeqNo());
searchHitBuilder.setPrimaryTerm(hit.getPrimaryTerm());
searchHitBuilder.setVersion(hit.getVersion());
searchHitBuilder.setSource(ByteString.copyFrom(BytesReference.toBytes(hit.getSourceRef())));
searchHitsBuilder.addHits(searchHitBuilder.build());
}
this.fetchSearchResultProto = this.fetchSearchResultProto.toBuilder().setHits(searchHitsBuilder.build()).build();
}
}

private boolean assertNoSearchTarget(SearchHits hits) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

package org.opensearch.search.fetch;

import org.opensearch.common.util.FeatureFlags;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.search.SearchPhaseResult;
Expand Down Expand Up @@ -117,4 +118,23 @@ public void writeTo(StreamOutput out) throws IOException {
queryResult.writeTo(out);
fetchResult.writeTo(out);
}

@Override
public boolean isMessageProtobuf() {
// System.setProperty(FeatureFlags.PROTOBUF, "true");
if (FeatureFlags.isEnabled(FeatureFlags.PROTOBUF_SETTING)) {
return true;
}
return false;
}

public QueryFetchSearchResultProto.QueryFetchSearchResult response() {
return this.queryFetchSearchResultProto;
}

public QueryFetchSearchResult(QueryFetchSearchResultProto.QueryFetchSearchResult queryFetchSearchResult) {
this.queryFetchSearchResultProto = queryFetchSearchResult;
this.queryResult = new QuerySearchResult(queryFetchSearchResult.getQueryResult());
this.fetchResult = new FetchSearchResult(queryFetchSearchResult.getFetchResult());
}
}
Loading

0 comments on commit 5c92d5c

Please sign in to comment.