diff --git a/server/src/main/java/org/opensearch/search/fetch/FetchSearchResult.java b/server/src/main/java/org/opensearch/search/fetch/FetchSearchResult.java index b46b2c5ca12d8..681398d36e07b 100644 --- a/server/src/main/java/org/opensearch/search/fetch/FetchSearchResult.java +++ b/server/src/main/java/org/opensearch/search/fetch/FetchSearchResult.java @@ -75,7 +75,6 @@ public FetchSearchResult(StreamInput in) throws IOException { public FetchSearchResult(InputStream in) throws IOException { super(in); - assert FeatureFlags.isEnabled(FeatureFlags.PROTOBUF) : "protobuf feature flag is not enabled"; this.fetchSearchResultProto = FetchSearchResultProto.FetchSearchResult.parseFrom(in); contextId = new ShardSearchContextId( this.fetchSearchResultProto.getContextId().getSessionId(), diff --git a/server/src/main/java/org/opensearch/search/fetch/QueryFetchSearchResult.java b/server/src/main/java/org/opensearch/search/fetch/QueryFetchSearchResult.java index b5e0a820da9af..8531fe027abd4 100644 --- a/server/src/main/java/org/opensearch/search/fetch/QueryFetchSearchResult.java +++ b/server/src/main/java/org/opensearch/search/fetch/QueryFetchSearchResult.java @@ -40,7 +40,8 @@ import org.opensearch.search.internal.ShardSearchContextId; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.server.proto.QueryFetchSearchResultProto; -import org.opensearch.transport.BaseInboundMessage; +import org.opensearch.transport.nativeprotocol.NativeInboundMessage; +import org.opensearch.transport.protobufprotocol.ProtobufInboundMessage; import java.io.IOException; import java.io.InputStream; @@ -65,7 +66,6 @@ public QueryFetchSearchResult(StreamInput in) throws IOException { public QueryFetchSearchResult(InputStream in) throws IOException { super(in); - assert FeatureFlags.isEnabled(FeatureFlags.PROTOBUF) : "protobuf feature flag is not enabled"; this.queryFetchSearchResultProto = QueryFetchSearchResultProto.QueryFetchSearchResult.parseFrom(in); queryResult = new QuerySearchResult(in); fetchResult = new FetchSearchResult(in); @@ -125,9 +125,9 @@ public void writeTo(StreamOutput out) throws IOException { @Override public String getProtocol() { if (FeatureFlags.isEnabled(FeatureFlags.PROTOBUF_SETTING)) { - return BaseInboundMessage.PROTOBUF_PROTOCOL; + return ProtobufInboundMessage.PROTOBUF_PROTOCOL; } - return BaseInboundMessage.NATIVE_PROTOCOL; + return NativeInboundMessage.NATIVE_PROTOCOL; } public QueryFetchSearchResultProto.QueryFetchSearchResult response() { diff --git a/server/src/main/java/org/opensearch/search/internal/ShardSearchRequest.java b/server/src/main/java/org/opensearch/search/internal/ShardSearchRequest.java index dde9f7130afa5..a42224b5d94de 100644 --- a/server/src/main/java/org/opensearch/search/internal/ShardSearchRequest.java +++ b/server/src/main/java/org/opensearch/search/internal/ShardSearchRequest.java @@ -272,7 +272,6 @@ public ShardSearchRequest(StreamInput in) throws IOException { } public ShardSearchRequest(byte[] in) throws IOException { - assert FeatureFlags.isEnabled(FeatureFlags.PROTOBUF) : "protobuf feature flag is not enabled"; ShardSearchRequestProto.ShardSearchRequest searchRequestProto = ShardSearchRequestProto.ShardSearchRequest.parseFrom(in); this.clusterAlias = searchRequestProto.getClusterAlias(); shardId = new ShardId( diff --git a/server/src/main/java/org/opensearch/search/query/QuerySearchResult.java b/server/src/main/java/org/opensearch/search/query/QuerySearchResult.java index 7ae21e6667caf..0cc151766084c 100644 --- a/server/src/main/java/org/opensearch/search/query/QuerySearchResult.java +++ b/server/src/main/java/org/opensearch/search/query/QuerySearchResult.java @@ -60,7 +60,8 @@ import org.opensearch.server.proto.ShardSearchRequestProto; import org.opensearch.server.proto.ShardSearchRequestProto.AliasFilter; import org.opensearch.server.proto.ShardSearchRequestProto.ShardSearchRequest.SearchType; -import org.opensearch.transport.BaseInboundMessage; +import org.opensearch.transport.nativeprotocol.NativeInboundMessage; +import org.opensearch.transport.protobufprotocol.ProtobufInboundMessage; import java.io.IOException; import java.io.InputStream; @@ -124,7 +125,6 @@ public QuerySearchResult(StreamInput in) throws IOException { public QuerySearchResult(InputStream in) throws IOException { super(in); - assert FeatureFlags.isEnabled(FeatureFlags.PROTOBUF) : "protobuf feature flag is not enabled"; this.querySearchResultProto = QuerySearchResultProto.QuerySearchResult.parseFrom(in); isNull = this.querySearchResultProto.getIsNull(); if (!isNull) { @@ -628,8 +628,8 @@ public QuerySearchResult(QuerySearchResultProto.QuerySearchResult querySearchRes @Override public String getProtocol() { if (FeatureFlags.isEnabled(FeatureFlags.PROTOBUF_SETTING)) { - return BaseInboundMessage.PROTOBUF_PROTOCOL; + return ProtobufInboundMessage.PROTOBUF_PROTOCOL; } - return BaseInboundMessage.NATIVE_PROTOCOL; + return NativeInboundMessage.NATIVE_PROTOCOL; } } diff --git a/server/src/main/java/org/opensearch/transport/BaseInboundMessage.java b/server/src/main/java/org/opensearch/transport/BaseInboundMessage.java deleted file mode 100644 index db0ddfc3f0cd0..0000000000000 --- a/server/src/main/java/org/opensearch/transport/BaseInboundMessage.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.transport; - -import org.opensearch.common.annotation.ExperimentalApi; - -/** - * Base class for inbound data as a message. - * Different implementations are used for different protocols. - * - * @opensearch.internal - */ -@ExperimentalApi -public interface BaseInboundMessage { - - /** - * The protocol used to encode this message - */ - static String NATIVE_PROTOCOL = "native"; - static String PROTOBUF_PROTOCOL = "protobuf"; - - /** - * @return the protocol used to encode this message - */ - public String getProtocol(); - - /** - * Set the protocol used to encode this message - */ - public void setProtocol(); -} diff --git a/server/src/main/java/org/opensearch/transport/InboundHandler.java b/server/src/main/java/org/opensearch/transport/InboundHandler.java index 6492900c49a0e..d5d7f614c37e6 100644 --- a/server/src/main/java/org/opensearch/transport/InboundHandler.java +++ b/server/src/main/java/org/opensearch/transport/InboundHandler.java @@ -33,10 +33,13 @@ package org.opensearch.transport; import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.FeatureFlags; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.telemetry.tracing.Tracer; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.nativeprotocol.NativeInboundMessage; +import org.opensearch.transport.protobufprotocol.ProtobufInboundMessage; +import org.opensearch.transport.protobufprotocol.ProtobufMessageHandler; import java.io.IOException; import java.util.Map; @@ -80,6 +83,12 @@ public class InboundHandler { keepAlive ) ); + if (FeatureFlags.isEnabled(FeatureFlags.PROTOBUF_SETTING)) { + this.protocolMessageHandlers.put( + ProtobufInboundMessage.PROTOBUF_PROTOCOL, + new ProtobufMessageHandler(threadPool, responseHandlers) + ); + } } void setMessageListener(TransportMessageListener listener) { diff --git a/server/src/main/java/org/opensearch/transport/InboundPipeline.java b/server/src/main/java/org/opensearch/transport/InboundPipeline.java index f4c410671b08a..7f23c9e3db1a9 100644 --- a/server/src/main/java/org/opensearch/transport/InboundPipeline.java +++ b/server/src/main/java/org/opensearch/transport/InboundPipeline.java @@ -36,12 +36,12 @@ import org.opensearch.common.bytes.ReleasableBytesReference; import org.opensearch.common.lease.Releasable; import org.opensearch.common.lease.Releasables; +import org.opensearch.common.util.FeatureFlags; import org.opensearch.common.util.PageCacheRecycler; import org.opensearch.core.common.breaker.CircuitBreaker; -import org.opensearch.core.common.bytes.CompositeBytesReference; import org.opensearch.transport.nativeprotocol.NativeInboundBytesHandler; +import org.opensearch.transport.protobufprotocol.ProtobufInboundBytesHandler; -import java.io.ByteArrayInputStream; import java.io.IOException; import java.util.ArrayDeque; import java.util.List; @@ -99,6 +99,9 @@ public InboundPipeline( this.aggregator = aggregator; this.protocolBytesHandlers = List.of(new NativeInboundBytesHandler(pending, decoder, aggregator, statsTracker)); this.messageHandler = messageHandler; + if (FeatureFlags.isEnabled(FeatureFlags.PROTOBUF_SETTING)) { + protocolBytesHandlers.add(new ProtobufInboundBytesHandler()); + } } @Override diff --git a/server/src/main/java/org/opensearch/transport/OutboundHandler.java b/server/src/main/java/org/opensearch/transport/OutboundHandler.java index ec3a0efd3c834..e5bb7764c70dd 100644 --- a/server/src/main/java/org/opensearch/transport/OutboundHandler.java +++ b/server/src/main/java/org/opensearch/transport/OutboundHandler.java @@ -54,6 +54,7 @@ import org.opensearch.search.fetch.QueryFetchSearchResult; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.protobufprotocol.ProtobufInboundMessage; import java.io.IOException; import java.util.Set; @@ -149,13 +150,13 @@ void sendResponse( ) throws IOException { Version version = Version.min(this.version, nodeVersion); ActionListener listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, response)); - if ((response.getProtocol()).equals(BaseInboundMessage.PROTOBUF_PROTOCOL) && version.onOrAfter(Version.V_3_0_0)) { + if ((response.getProtocol()).equals(ProtobufInboundMessage.PROTOBUF_PROTOCOL) && version.onOrAfter(Version.V_3_0_0)) { if (response instanceof QueryFetchSearchResult) { QueryFetchSearchResult queryFetchSearchResult = (QueryFetchSearchResult) response; if (queryFetchSearchResult.response() != null) { byte[] bytes = new byte[1]; bytes[0] = 1; - NodeToNodeMessage protobufMessage = new NodeToNodeMessage( + ProtobufInboundMessage protobufMessage = new ProtobufInboundMessage( requestId, bytes, Version.CURRENT, @@ -171,7 +172,7 @@ void sendResponse( if (querySearchResult.response() != null) { byte[] bytes = new byte[1]; bytes[0] = 1; - NodeToNodeMessage protobufMessage = new NodeToNodeMessage( + ProtobufInboundMessage protobufMessage = new ProtobufInboundMessage( requestId, bytes, Version.CURRENT, @@ -231,7 +232,7 @@ private void sendMessage(TcpChannel channel, OutboundMessage networkMessage, Act internalSend(channel, sendContext); } - private void sendProtobufMessage(TcpChannel channel, NodeToNodeMessage message, ActionListener listener) throws IOException { + private void sendProtobufMessage(TcpChannel channel, ProtobufInboundMessage message, ActionListener listener) throws IOException { ProtobufMessageSerializer serializer = new ProtobufMessageSerializer(message, bigArrays); SendContext sendContext = new SendContext(channel, serializer, listener, serializer); internalSend(channel, sendContext); @@ -288,11 +289,11 @@ public void close() { private static class ProtobufMessageSerializer implements CheckedSupplier, Releasable { - private final NodeToNodeMessage message; + private final ProtobufInboundMessage message; private final BigArrays bigArrays; private volatile ReleasableBytesStreamOutput bytesStreamOutput; - private ProtobufMessageSerializer(NodeToNodeMessage message, BigArrays bigArrays) { + private ProtobufMessageSerializer(ProtobufInboundMessage message, BigArrays bigArrays) { this.message = message; this.bigArrays = bigArrays; } diff --git a/server/src/main/java/org/opensearch/transport/TcpTransport.java b/server/src/main/java/org/opensearch/transport/TcpTransport.java index 7fb600a484035..e32bba5e836d3 100644 --- a/server/src/main/java/org/opensearch/transport/TcpTransport.java +++ b/server/src/main/java/org/opensearch/transport/TcpTransport.java @@ -806,14 +806,6 @@ public static int readMessageLength(BytesReference networkBytes) throws IOExcept } } - public static String determineTransportProtocol(BytesReference headerBuffer) { - if (headerBuffer.get(0) == 'O' && headerBuffer.get(1) == 'S' && headerBuffer.get(2) == 'P') { - return BaseInboundMessage.PROTOBUF_PROTOCOL; - } else { - return BaseInboundMessage.NATIVE_PROTOCOL; - } - } - private static int readHeaderBuffer(BytesReference headerBuffer) throws IOException { if (headerBuffer.get(0) != 'E' || headerBuffer.get(1) != 'S') { if (appearsToBeHTTPRequest(headerBuffer)) { diff --git a/server/src/main/java/org/opensearch/transport/protobufprotocol/ProtobufInboundBytesHandler.java b/server/src/main/java/org/opensearch/transport/protobufprotocol/ProtobufInboundBytesHandler.java new file mode 100644 index 0000000000000..a0f54d645a378 --- /dev/null +++ b/server/src/main/java/org/opensearch/transport/protobufprotocol/ProtobufInboundBytesHandler.java @@ -0,0 +1,53 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.transport.protobufprotocol; + +import org.opensearch.common.bytes.ReleasableBytesReference; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.transport.InboundBytesHandler; +import org.opensearch.transport.ProtocolInboundMessage; +import org.opensearch.transport.TcpChannel; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.util.function.BiConsumer; + +/** + * Handler for inbound bytes for the protobuf protocol. + */ +public class ProtobufInboundBytesHandler implements InboundBytesHandler { + + public void ProtobufInboundBytesHandler() {} + + @Override + public void doHandleBytes( + TcpChannel channel, + ReleasableBytesReference reference, + BiConsumer messageHandler + ) throws IOException { + // removing the first byte we added for protobuf message + byte[] incomingBytes = BytesReference.toBytes(reference.slice(3, reference.length() - 3)); + ProtobufInboundMessage protobufMessage = new ProtobufInboundMessage(new ByteArrayInputStream(incomingBytes)); + messageHandler.accept(channel, protobufMessage); + } + + @Override + public boolean canHandleBytes(ReleasableBytesReference reference) { + if (reference.get(0) == 'O' && reference.get(1) == 'S' && reference.get(2) == 'P') { + return true; + } + return false; + } + + @Override + public void close() { + // no-op + } + +} diff --git a/server/src/main/java/org/opensearch/transport/NodeToNodeMessage.java b/server/src/main/java/org/opensearch/transport/protobufprotocol/ProtobufInboundMessage.java similarity index 90% rename from server/src/main/java/org/opensearch/transport/NodeToNodeMessage.java rename to server/src/main/java/org/opensearch/transport/protobufprotocol/ProtobufInboundMessage.java index 943007f7913c6..3f650ff61d3ab 100644 --- a/server/src/main/java/org/opensearch/transport/NodeToNodeMessage.java +++ b/server/src/main/java/org/opensearch/transport/protobufprotocol/ProtobufInboundMessage.java @@ -6,7 +6,7 @@ * compatible open source license. */ -package org.opensearch.transport; +package org.opensearch.transport.protobufprotocol; import com.google.protobuf.ByteString; import org.opensearch.Version; @@ -18,6 +18,8 @@ import org.opensearch.server.proto.NodeToNodeMessageProto.NodeToNodeMessage.ResponseHandlersList; import org.opensearch.server.proto.QueryFetchSearchResultProto.QueryFetchSearchResult; import org.opensearch.server.proto.QuerySearchResultProto.QuerySearchResult; +import org.opensearch.transport.ProtocolInboundMessage; +import org.opensearch.transport.TcpHeader; import java.io.IOException; import java.io.InputStream; @@ -34,13 +36,17 @@ * * @opensearch.internal */ -public class NodeToNodeMessage implements BaseInboundMessage { +public class ProtobufInboundMessage implements ProtocolInboundMessage { + + /** + * The protocol used to encode this message + */ + public static String PROTOBUF_PROTOCOL = "protobuf"; private final NodeToNodeMessageProto.NodeToNodeMessage message; private static final byte[] PREFIX = { (byte) 'E', (byte) 'S' }; - private String protocol; - public NodeToNodeMessage( + public ProtobufInboundMessage( long requestId, byte[] status, Version version, @@ -77,7 +83,7 @@ public NodeToNodeMessage( .build(); } - public NodeToNodeMessage( + public ProtobufInboundMessage( long requestId, byte[] status, Version version, @@ -114,7 +120,7 @@ public NodeToNodeMessage( .build(); } - public NodeToNodeMessage(InputStream in) throws IOException { + public ProtobufInboundMessage(InputStream in) throws IOException { this.message = NodeToNodeMessageProto.NodeToNodeMessage.parseFrom(in); } @@ -122,7 +128,7 @@ public void writeTo(OutputStream out) throws IOException { this.message.writeTo(out); } - BytesReference serialize(BytesStreamOutput bytesStream) throws IOException { + public BytesReference serialize(BytesStreamOutput bytesStream) throws IOException { NodeToNodeMessageProto.NodeToNodeMessage message = getMessage(); TcpHeader.writeHeaderForProtobuf(bytesStream); message.writeTo(bytesStream); @@ -135,7 +141,7 @@ public NodeToNodeMessageProto.NodeToNodeMessage getMessage() { @Override public String toString() { - return "NodeToNodeMessage [message=" + message + "]"; + return "ProtobufInboundMessage [message=" + message + "]"; } public org.opensearch.server.proto.NodeToNodeMessageProto.NodeToNodeMessage.Header getHeader() { @@ -163,8 +169,4 @@ public String getProtocol() { return PROTOBUF_PROTOCOL; } - @Override - public void setProtocol() { - this.protocol = PROTOBUF_PROTOCOL; - } } diff --git a/server/src/main/java/org/opensearch/transport/ProtobufMessageHandler.java b/server/src/main/java/org/opensearch/transport/protobufprotocol/ProtobufMessageHandler.java similarity index 83% rename from server/src/main/java/org/opensearch/transport/ProtobufMessageHandler.java rename to server/src/main/java/org/opensearch/transport/protobufprotocol/ProtobufMessageHandler.java index 3aa8be4d094e2..a945bab7a345a 100644 --- a/server/src/main/java/org/opensearch/transport/ProtobufMessageHandler.java +++ b/server/src/main/java/org/opensearch/transport/protobufprotocol/ProtobufMessageHandler.java @@ -6,7 +6,7 @@ * compatible open source license. */ -package org.opensearch.transport; +package org.opensearch.transport.protobufprotocol; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -19,6 +19,15 @@ import org.opensearch.search.query.QuerySearchResult; import org.opensearch.server.proto.QueryFetchSearchResultProto.QueryFetchSearchResult; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.ProtocolInboundMessage; +import org.opensearch.transport.ProtocolMessageHandler; +import org.opensearch.transport.RemoteTransportException; +import org.opensearch.transport.ResponseHandlerFailureTransportException; +import org.opensearch.transport.TcpChannel; +import org.opensearch.transport.Transport; +import org.opensearch.transport.TransportMessageListener; +import org.opensearch.transport.TransportResponseHandler; +import org.opensearch.transport.TransportSerializationException; import java.io.IOException; import java.net.InetSocketAddress; @@ -30,7 +39,7 @@ * * @opensearch.internal */ -public class ProtobufMessageHandler { +public class ProtobufMessageHandler implements ProtocolMessageHandler { private static final Logger logger = LogManager.getLogger(ProtobufMessageHandler.class); @@ -41,7 +50,7 @@ public class ProtobufMessageHandler { private volatile long slowLogThresholdMs = Long.MAX_VALUE; - ProtobufMessageHandler(ThreadPool threadPool, Transport.ResponseHandlers responseHandlers) { + public ProtobufMessageHandler(ThreadPool threadPool, Transport.ResponseHandlers responseHandlers) { this.threadPool = threadPool; this.responseHandlers = responseHandlers; } @@ -58,16 +67,24 @@ void setSlowLogThreshold(TimeValue slowLogThreshold) { this.slowLogThresholdMs = slowLogThreshold.getMillis(); } - public void messageReceivedProtobuf(TcpChannel channel, NodeToNodeMessage message, long startTime) throws IOException { + @Override + public void messageReceived( + TcpChannel channel, + ProtocolInboundMessage message, + long startTime, + long slowLogThresholdMs, + TransportMessageListener messageListener + ) throws IOException { + ProtobufInboundMessage nodeToNodeMessage = (ProtobufInboundMessage) message; final InetSocketAddress remoteAddress = channel.getRemoteAddress(); - final org.opensearch.server.proto.NodeToNodeMessageProto.NodeToNodeMessage.Header header = message.getHeader(); + final org.opensearch.server.proto.NodeToNodeMessageProto.NodeToNodeMessage.Header header = nodeToNodeMessage.getHeader(); ThreadContext threadContext = threadPool.getThreadContext(); try (ThreadContext.StoredContext existing = threadContext.stashContext()) { // Place the context with the headers from the message final Tuple, Map>> headers = new Tuple, Map>>( - message.getRequestHeaders(), - message.getResponseHandlers() + nodeToNodeMessage.getRequestHeaders(), + nodeToNodeMessage.getResponseHandlers() ); threadContext.setHeaders(headers); threadContext.putTransient("_remote_address", remoteAddress); @@ -75,7 +92,7 @@ public void messageReceivedProtobuf(TcpChannel channel, NodeToNodeMessage messag long requestId = header.getRequestId(); TransportResponseHandler handler = responseHandlers.onResponseReceived(requestId, messageListener); if (handler != null) { - handleProtobufResponse(requestId, remoteAddress, message, handler); + handleProtobufResponse(requestId, remoteAddress, nodeToNodeMessage, handler); } } finally { final long took = threadPool.relativeTimeInMillis() - startTime; @@ -94,7 +111,7 @@ public void messageReceivedProtobuf(TcpChannel channel, NodeToNodeMessage messag private void handleProtobufResponse( final long requestId, InetSocketAddress remoteAddress, - final NodeToNodeMessage message, + final ProtobufInboundMessage message, final TransportResponseHandler handler ) throws IOException { try { diff --git a/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java b/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java index 892909b094eeb..d78de1d85648b 100644 --- a/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java +++ b/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java @@ -61,6 +61,7 @@ import org.opensearch.test.VersionUtils; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.protobufprotocol.ProtobufInboundMessage; import org.junit.After; import org.junit.Before; @@ -328,7 +329,7 @@ public QueryFetchSearchResult read(InputStream in) throws IOException { BytesReference fullResponseBytes = channel.getMessageCaptor().get(); byte[] incomingBytes = BytesReference.toBytes(fullResponseBytes.slice(3, fullResponseBytes.length() - 3)); - NodeToNodeMessage nodeToNodeMessage = new NodeToNodeMessage(new ByteArrayInputStream(incomingBytes)); + ProtobufInboundMessage nodeToNodeMessage = new ProtobufInboundMessage(new ByteArrayInputStream(incomingBytes)); handler.inboundMessage(channel, nodeToNodeMessage); QueryFetchSearchResult result = responseCaptor.get(); assertNotNull(result); diff --git a/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java b/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java index 8e0880ea85326..03f10bb702144 100644 --- a/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java +++ b/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java @@ -104,13 +104,7 @@ public void testPipelineHandling() throws IOException { final TestCircuitBreaker circuitBreaker = new TestCircuitBreaker(); circuitBreaker.startBreaking(); final InboundAggregator aggregator = new InboundAggregator(() -> circuitBreaker, canTripBreaker); - final InboundPipeline pipeline = new InboundPipeline( - statsTracker, - millisSupplier, - decoder, - aggregator, - messageHandler - ); + final InboundPipeline pipeline = new InboundPipeline(statsTracker, millisSupplier, decoder, aggregator, messageHandler); final FakeTcpChannel channel = new FakeTcpChannel(); final int iterations = randomIntBetween(5, 10); @@ -226,13 +220,7 @@ public void testDecodeExceptionIsPropagated() throws IOException { final InboundDecoder decoder = new InboundDecoder(Version.CURRENT, PageCacheRecycler.NON_RECYCLING_INSTANCE); final Supplier breaker = () -> new NoopCircuitBreaker("test"); final InboundAggregator aggregator = new InboundAggregator(breaker, (Predicate) action -> true); - final InboundPipeline pipeline = new InboundPipeline( - statsTracker, - millisSupplier, - decoder, - aggregator, - messageHandler - ); + final InboundPipeline pipeline = new InboundPipeline(statsTracker, millisSupplier, decoder, aggregator, messageHandler); try (BytesStreamOutput streamOutput = new BytesStreamOutput()) { String actionName = "actionName"; @@ -286,13 +274,7 @@ public void testEnsureBodyIsNotPrematurelyReleased() throws IOException { final InboundDecoder decoder = new InboundDecoder(Version.CURRENT, PageCacheRecycler.NON_RECYCLING_INSTANCE); final Supplier breaker = () -> new NoopCircuitBreaker("test"); final InboundAggregator aggregator = new InboundAggregator(breaker, (Predicate) action -> true); - final InboundPipeline pipeline = new InboundPipeline( - statsTracker, - millisSupplier, - decoder, - aggregator, - messageHandler - ); + final InboundPipeline pipeline = new InboundPipeline(statsTracker, millisSupplier, decoder, aggregator, messageHandler); try (BytesStreamOutput streamOutput = new BytesStreamOutput()) { String actionName = "actionName"; diff --git a/server/src/test/java/org/opensearch/transport/OutboundHandlerTests.java b/server/src/test/java/org/opensearch/transport/OutboundHandlerTests.java index ec40e95fe45c1..a9d8d3c45b9f9 100644 --- a/server/src/test/java/org/opensearch/transport/OutboundHandlerTests.java +++ b/server/src/test/java/org/opensearch/transport/OutboundHandlerTests.java @@ -68,6 +68,7 @@ import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.protobufprotocol.ProtobufInboundMessage; import org.junit.After; import org.junit.Before; @@ -121,7 +122,7 @@ public void setUp() throws Exception { } catch (IOException e) { throw new AssertionError(e); } - }, Version.CURRENT); + }); } @After @@ -294,7 +295,7 @@ public void testSendProtobufResponse() throws IOException { FetchSearchResult fetchResult = createFetchSearchResult(); QueryFetchSearchResult response = new QueryFetchSearchResult(queryResult, fetchResult); System.setProperty(FeatureFlags.PROTOBUF, "true"); - assertTrue((response.getProtocol()).equals(BaseInboundMessage.PROTOBUF_PROTOCOL)); + assertTrue((response.getProtocol()).equals(ProtobufInboundMessage.PROTOBUF_PROTOCOL)); AtomicLong requestIdRef = new AtomicLong(); AtomicReference actionRef = new AtomicReference<>(); @@ -315,9 +316,9 @@ public void onResponseSent(long requestId, String action, TransportResponse resp final Supplier breaker = () -> new NoopCircuitBreaker("test"); final InboundAggregator aggregator = new InboundAggregator(breaker, (Predicate) requestCanTripBreaker -> true); InboundPipeline inboundPipeline = new InboundPipeline(statsTracker, millisSupplier, decoder, aggregator, (c, m) -> { - NodeToNodeMessage m1 = (NodeToNodeMessage) m; + ProtobufInboundMessage m1 = (ProtobufInboundMessage) m; protobufMessage.set(BytesReference.fromByteBuffer(ByteBuffer.wrap(m1.getMessage().toByteArray()))); - }, Version.CURRENT); + }); BytesReference reference = channel.getMessageCaptor().get(); ActionListener sendListener = channel.getListenerCaptor().get(); if (randomBoolean()) { @@ -331,7 +332,7 @@ public void onResponseSent(long requestId, String action, TransportResponse resp inboundPipeline.handleBytes(channel, new ReleasableBytesReference(reference, () -> {})); final BytesReference responseBytes = protobufMessage.get(); - final NodeToNodeMessage message = new NodeToNodeMessage(new ByteArrayInputStream(responseBytes.toBytesRef().bytes)); + final ProtobufInboundMessage message = new ProtobufInboundMessage(new ByteArrayInputStream(responseBytes.toBytesRef().bytes)); assertEquals(version.toString(), message.getMessage().getVersion()); assertEquals(requestId, message.getHeader().getRequestId()); assertNotNull(message.getRequestHeaders());