Skip to content

Commit

Permalink
Merging latest changes from main
Browse files Browse the repository at this point in the history
Signed-off-by: Vacha Shah <[email protected]>
  • Loading branch information
VachaShah committed Apr 9, 2024
1 parent 46eeaf6 commit ed00f89
Show file tree
Hide file tree
Showing 15 changed files with 133 additions and 111 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ public FetchSearchResult(StreamInput in) throws IOException {

public FetchSearchResult(InputStream in) throws IOException {
super(in);
assert FeatureFlags.isEnabled(FeatureFlags.PROTOBUF) : "protobuf feature flag is not enabled";
this.fetchSearchResultProto = FetchSearchResultProto.FetchSearchResult.parseFrom(in);
contextId = new ShardSearchContextId(
this.fetchSearchResultProto.getContextId().getSessionId(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@
import org.opensearch.search.internal.ShardSearchContextId;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.server.proto.QueryFetchSearchResultProto;
import org.opensearch.transport.BaseInboundMessage;
import org.opensearch.transport.nativeprotocol.NativeInboundMessage;
import org.opensearch.transport.protobufprotocol.ProtobufInboundMessage;

import java.io.IOException;
import java.io.InputStream;
Expand All @@ -65,7 +66,6 @@ public QueryFetchSearchResult(StreamInput in) throws IOException {

public QueryFetchSearchResult(InputStream in) throws IOException {
super(in);
assert FeatureFlags.isEnabled(FeatureFlags.PROTOBUF) : "protobuf feature flag is not enabled";
this.queryFetchSearchResultProto = QueryFetchSearchResultProto.QueryFetchSearchResult.parseFrom(in);
queryResult = new QuerySearchResult(in);
fetchResult = new FetchSearchResult(in);
Expand Down Expand Up @@ -125,9 +125,9 @@ public void writeTo(StreamOutput out) throws IOException {
@Override
public String getProtocol() {
if (FeatureFlags.isEnabled(FeatureFlags.PROTOBUF_SETTING)) {
return BaseInboundMessage.PROTOBUF_PROTOCOL;
return ProtobufInboundMessage.PROTOBUF_PROTOCOL;
}
return BaseInboundMessage.NATIVE_PROTOCOL;
return NativeInboundMessage.NATIVE_PROTOCOL;
}

public QueryFetchSearchResultProto.QueryFetchSearchResult response() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,6 @@ public ShardSearchRequest(StreamInput in) throws IOException {
}

public ShardSearchRequest(byte[] in) throws IOException {
assert FeatureFlags.isEnabled(FeatureFlags.PROTOBUF) : "protobuf feature flag is not enabled";
ShardSearchRequestProto.ShardSearchRequest searchRequestProto = ShardSearchRequestProto.ShardSearchRequest.parseFrom(in);
this.clusterAlias = searchRequestProto.getClusterAlias();
shardId = new ShardId(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@
import org.opensearch.server.proto.ShardSearchRequestProto;
import org.opensearch.server.proto.ShardSearchRequestProto.AliasFilter;
import org.opensearch.server.proto.ShardSearchRequestProto.ShardSearchRequest.SearchType;
import org.opensearch.transport.BaseInboundMessage;
import org.opensearch.transport.nativeprotocol.NativeInboundMessage;
import org.opensearch.transport.protobufprotocol.ProtobufInboundMessage;

import java.io.IOException;
import java.io.InputStream;
Expand Down Expand Up @@ -124,7 +125,6 @@ public QuerySearchResult(StreamInput in) throws IOException {

public QuerySearchResult(InputStream in) throws IOException {
super(in);
assert FeatureFlags.isEnabled(FeatureFlags.PROTOBUF) : "protobuf feature flag is not enabled";
this.querySearchResultProto = QuerySearchResultProto.QuerySearchResult.parseFrom(in);
isNull = this.querySearchResultProto.getIsNull();
if (!isNull) {
Expand Down Expand Up @@ -628,8 +628,8 @@ public QuerySearchResult(QuerySearchResultProto.QuerySearchResult querySearchRes
@Override
public String getProtocol() {
if (FeatureFlags.isEnabled(FeatureFlags.PROTOBUF_SETTING)) {
return BaseInboundMessage.PROTOBUF_PROTOCOL;
return ProtobufInboundMessage.PROTOBUF_PROTOCOL;
}
return BaseInboundMessage.NATIVE_PROTOCOL;
return NativeInboundMessage.NATIVE_PROTOCOL;
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,13 @@
package org.opensearch.transport;

import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.FeatureFlags;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.telemetry.tracing.Tracer;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.nativeprotocol.NativeInboundMessage;
import org.opensearch.transport.protobufprotocol.ProtobufInboundMessage;
import org.opensearch.transport.protobufprotocol.ProtobufMessageHandler;

import java.io.IOException;
import java.util.Map;
Expand Down Expand Up @@ -80,6 +83,12 @@ public class InboundHandler {
keepAlive
)
);
if (FeatureFlags.isEnabled(FeatureFlags.PROTOBUF_SETTING)) {
this.protocolMessageHandlers.put(
ProtobufInboundMessage.PROTOBUF_PROTOCOL,
new ProtobufMessageHandler(threadPool, responseHandlers)
);
}
}

void setMessageListener(TransportMessageListener listener) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@
import org.opensearch.common.bytes.ReleasableBytesReference;
import org.opensearch.common.lease.Releasable;
import org.opensearch.common.lease.Releasables;
import org.opensearch.common.util.FeatureFlags;
import org.opensearch.common.util.PageCacheRecycler;
import org.opensearch.core.common.breaker.CircuitBreaker;
import org.opensearch.core.common.bytes.CompositeBytesReference;
import org.opensearch.transport.nativeprotocol.NativeInboundBytesHandler;
import org.opensearch.transport.protobufprotocol.ProtobufInboundBytesHandler;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.ArrayDeque;
import java.util.List;
Expand Down Expand Up @@ -99,6 +99,9 @@ public InboundPipeline(
this.aggregator = aggregator;
this.protocolBytesHandlers = List.of(new NativeInboundBytesHandler(pending, decoder, aggregator, statsTracker));
this.messageHandler = messageHandler;
if (FeatureFlags.isEnabled(FeatureFlags.PROTOBUF_SETTING)) {
protocolBytesHandlers.add(new ProtobufInboundBytesHandler());
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import org.opensearch.search.fetch.QueryFetchSearchResult;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.protobufprotocol.ProtobufInboundMessage;

import java.io.IOException;
import java.util.Set;
Expand Down Expand Up @@ -149,13 +150,13 @@ void sendResponse(
) throws IOException {
Version version = Version.min(this.version, nodeVersion);
ActionListener<Void> listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, response));
if ((response.getProtocol()).equals(BaseInboundMessage.PROTOBUF_PROTOCOL) && version.onOrAfter(Version.V_3_0_0)) {
if ((response.getProtocol()).equals(ProtobufInboundMessage.PROTOBUF_PROTOCOL) && version.onOrAfter(Version.V_3_0_0)) {
if (response instanceof QueryFetchSearchResult) {
QueryFetchSearchResult queryFetchSearchResult = (QueryFetchSearchResult) response;
if (queryFetchSearchResult.response() != null) {
byte[] bytes = new byte[1];
bytes[0] = 1;
NodeToNodeMessage protobufMessage = new NodeToNodeMessage(
ProtobufInboundMessage protobufMessage = new ProtobufInboundMessage(
requestId,
bytes,
Version.CURRENT,
Expand All @@ -171,7 +172,7 @@ void sendResponse(
if (querySearchResult.response() != null) {
byte[] bytes = new byte[1];
bytes[0] = 1;
NodeToNodeMessage protobufMessage = new NodeToNodeMessage(
ProtobufInboundMessage protobufMessage = new ProtobufInboundMessage(
requestId,
bytes,
Version.CURRENT,
Expand Down Expand Up @@ -231,7 +232,7 @@ private void sendMessage(TcpChannel channel, OutboundMessage networkMessage, Act
internalSend(channel, sendContext);
}

private void sendProtobufMessage(TcpChannel channel, NodeToNodeMessage message, ActionListener<Void> listener) throws IOException {
private void sendProtobufMessage(TcpChannel channel, ProtobufInboundMessage message, ActionListener<Void> listener) throws IOException {
ProtobufMessageSerializer serializer = new ProtobufMessageSerializer(message, bigArrays);
SendContext sendContext = new SendContext(channel, serializer, listener, serializer);
internalSend(channel, sendContext);
Expand Down Expand Up @@ -288,11 +289,11 @@ public void close() {

private static class ProtobufMessageSerializer implements CheckedSupplier<BytesReference, IOException>, Releasable {

private final NodeToNodeMessage message;
private final ProtobufInboundMessage message;
private final BigArrays bigArrays;
private volatile ReleasableBytesStreamOutput bytesStreamOutput;

private ProtobufMessageSerializer(NodeToNodeMessage message, BigArrays bigArrays) {
private ProtobufMessageSerializer(ProtobufInboundMessage message, BigArrays bigArrays) {
this.message = message;
this.bigArrays = bigArrays;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -806,14 +806,6 @@ public static int readMessageLength(BytesReference networkBytes) throws IOExcept
}
}

public static String determineTransportProtocol(BytesReference headerBuffer) {
if (headerBuffer.get(0) == 'O' && headerBuffer.get(1) == 'S' && headerBuffer.get(2) == 'P') {
return BaseInboundMessage.PROTOBUF_PROTOCOL;
} else {
return BaseInboundMessage.NATIVE_PROTOCOL;
}
}

private static int readHeaderBuffer(BytesReference headerBuffer) throws IOException {
if (headerBuffer.get(0) != 'E' || headerBuffer.get(1) != 'S') {
if (appearsToBeHTTPRequest(headerBuffer)) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.transport.protobufprotocol;

import org.opensearch.common.bytes.ReleasableBytesReference;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.transport.InboundBytesHandler;
import org.opensearch.transport.ProtocolInboundMessage;
import org.opensearch.transport.TcpChannel;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.function.BiConsumer;

/**
* Handler for inbound bytes for the protobuf protocol.
*/
public class ProtobufInboundBytesHandler implements InboundBytesHandler {

public void ProtobufInboundBytesHandler() {}

@Override
public void doHandleBytes(
TcpChannel channel,
ReleasableBytesReference reference,
BiConsumer<TcpChannel, ProtocolInboundMessage> messageHandler
) throws IOException {
// removing the first byte we added for protobuf message
byte[] incomingBytes = BytesReference.toBytes(reference.slice(3, reference.length() - 3));
ProtobufInboundMessage protobufMessage = new ProtobufInboundMessage(new ByteArrayInputStream(incomingBytes));
messageHandler.accept(channel, protobufMessage);
}

@Override
public boolean canHandleBytes(ReleasableBytesReference reference) {
if (reference.get(0) == 'O' && reference.get(1) == 'S' && reference.get(2) == 'P') {
return true;
}
return false;
}

@Override
public void close() {
// no-op
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* compatible open source license.
*/

package org.opensearch.transport;
package org.opensearch.transport.protobufprotocol;

import com.google.protobuf.ByteString;
import org.opensearch.Version;
Expand All @@ -18,6 +18,8 @@
import org.opensearch.server.proto.NodeToNodeMessageProto.NodeToNodeMessage.ResponseHandlersList;
import org.opensearch.server.proto.QueryFetchSearchResultProto.QueryFetchSearchResult;
import org.opensearch.server.proto.QuerySearchResultProto.QuerySearchResult;
import org.opensearch.transport.ProtocolInboundMessage;
import org.opensearch.transport.TcpHeader;

import java.io.IOException;
import java.io.InputStream;
Expand All @@ -34,13 +36,17 @@
*
* @opensearch.internal
*/
public class NodeToNodeMessage implements BaseInboundMessage {
public class ProtobufInboundMessage implements ProtocolInboundMessage {

/**
* The protocol used to encode this message
*/
public static String PROTOBUF_PROTOCOL = "protobuf";

private final NodeToNodeMessageProto.NodeToNodeMessage message;
private static final byte[] PREFIX = { (byte) 'E', (byte) 'S' };
private String protocol;

public NodeToNodeMessage(
public ProtobufInboundMessage(
long requestId,
byte[] status,
Version version,
Expand Down Expand Up @@ -77,7 +83,7 @@ public NodeToNodeMessage(
.build();
}

public NodeToNodeMessage(
public ProtobufInboundMessage(
long requestId,
byte[] status,
Version version,
Expand Down Expand Up @@ -114,15 +120,15 @@ public NodeToNodeMessage(
.build();
}

public NodeToNodeMessage(InputStream in) throws IOException {
public ProtobufInboundMessage(InputStream in) throws IOException {
this.message = NodeToNodeMessageProto.NodeToNodeMessage.parseFrom(in);
}

public void writeTo(OutputStream out) throws IOException {
this.message.writeTo(out);
}

BytesReference serialize(BytesStreamOutput bytesStream) throws IOException {
public BytesReference serialize(BytesStreamOutput bytesStream) throws IOException {
NodeToNodeMessageProto.NodeToNodeMessage message = getMessage();
TcpHeader.writeHeaderForProtobuf(bytesStream);
message.writeTo(bytesStream);
Expand All @@ -135,7 +141,7 @@ public NodeToNodeMessageProto.NodeToNodeMessage getMessage() {

@Override
public String toString() {
return "NodeToNodeMessage [message=" + message + "]";
return "ProtobufInboundMessage [message=" + message + "]";
}

public org.opensearch.server.proto.NodeToNodeMessageProto.NodeToNodeMessage.Header getHeader() {
Expand Down Expand Up @@ -163,8 +169,4 @@ public String getProtocol() {
return PROTOBUF_PROTOCOL;
}

@Override
public void setProtocol() {
this.protocol = PROTOBUF_PROTOCOL;
}
}
Loading

0 comments on commit ed00f89

Please sign in to comment.