From b13c3bad4b8f254b3a0bd74c1466d5b199747f9e Mon Sep 17 00:00:00 2001 From: Andre Kurait Date: Wed, 24 Apr 2024 20:32:05 +0000 Subject: [PATCH] Revert "Simplify Netty RefCounting and ByteBuf Consumption (#592)" This reverts commit 9df361480624c4eaed1c1da29b8e2db283c4ec51. --- TrafficCapture/build.gradle | 16 --- .../replay/HttpByteBufFormatter.java | 115 ++++++++++------- .../replay/HttpMessageAndTimestamp.java | 24 ++-- .../replay/ParsedHttpMessagesAsDicts.java | 68 +++++----- .../replay/SourceTargetCaptureTuple.java | 2 +- .../migrations/replay/util/NettyUtils.java | 22 ---- .../migrations/replay/util/RefSafeHolder.java | 27 ---- .../replay/util/RefSafeStreamUtils.java | 31 ----- .../replay/HttpByteBufFormatterTest.java | 11 +- .../replay/RequestSenderOrchestratorTest.java | 35 +++--- .../replay/util/RefSafeStreamUtilsTest.java | 119 ------------------ 11 files changed, 143 insertions(+), 327 deletions(-) delete mode 100644 TrafficCapture/trafficReplayer/src/main/java/org/opensearch/migrations/replay/util/NettyUtils.java delete mode 100644 TrafficCapture/trafficReplayer/src/main/java/org/opensearch/migrations/replay/util/RefSafeHolder.java delete mode 100644 TrafficCapture/trafficReplayer/src/main/java/org/opensearch/migrations/replay/util/RefSafeStreamUtils.java delete mode 100644 TrafficCapture/trafficReplayer/src/test/java/org/opensearch/migrations/replay/util/RefSafeStreamUtilsTest.java diff --git a/TrafficCapture/build.gradle b/TrafficCapture/build.gradle index 6ae9a33722..7cbcc5f653 100644 --- a/TrafficCapture/build.gradle +++ b/TrafficCapture/build.gradle @@ -12,22 +12,6 @@ allprojects { subprojects { apply plugin: 'java' apply plugin: 'maven-publish' - - // TODO: Expand to do more static checking in more projects - if (project.name == "trafficReplayer" || project.name == "trafficCaptureProxyServer") { - dependencies { - annotationProcessor group: 'com.google.errorprone', name: 'error_prone_core', version: '2.26.1' - } - tasks.named('compileJava', JavaCompile) { - if (project.name == "trafficReplayer" || project.name == "trafficCaptureProxyServer") { - options.compilerArgs += [ - "-XDcompilePolicy=simple", - "-Xplugin:ErrorProne -XepDisableAllChecks -Xep:MustBeClosed:ERROR -XepDisableWarningsInGeneratedCode", - ] - } - } - } - task javadocJar(type: Jar, dependsOn: javadoc) { archiveClassifier.set('javadoc') from javadoc.destinationDir diff --git a/TrafficCapture/trafficReplayer/src/main/java/org/opensearch/migrations/replay/HttpByteBufFormatter.java b/TrafficCapture/trafficReplayer/src/main/java/org/opensearch/migrations/replay/HttpByteBufFormatter.java index 272ea38e6a..a258d033ba 100644 --- a/TrafficCapture/trafficReplayer/src/main/java/org/opensearch/migrations/replay/HttpByteBufFormatter.java +++ b/TrafficCapture/trafficReplayer/src/main/java/org/opensearch/migrations/replay/HttpByteBufFormatter.java @@ -1,6 +1,8 @@ package org.opensearch.migrations.replay; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufHolder; +import io.netty.buffer.Unpooled; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpResponse; @@ -9,6 +11,9 @@ import io.netty.handler.codec.http.HttpMessage; import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.HttpServerCodec; +import lombok.SneakyThrows; +import lombok.extern.slf4j.Slf4j; + import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; @@ -19,10 +24,6 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.Stream; -import lombok.SneakyThrows; -import lombok.extern.slf4j.Slf4j; -import org.opensearch.migrations.replay.util.NettyUtils; -import org.opensearch.migrations.replay.util.RefSafeHolder; @Slf4j public class HttpByteBufFormatter { @@ -63,29 +64,38 @@ public static String httpPacketBytesToString(HttpMessageType msgType, List byteBufStream) { - return httpPacketBufsToString(msgType, byteBufStream, DEFAULT_LINE_DELIMITER); + public static String httpPacketBytesToString(HttpMessageType msgType, Stream byteArrStream) { + return httpPacketBytesToString(msgType, byteArrStream, DEFAULT_LINE_DELIMITER); + } + + public static String httpPacketBufsToString(HttpMessageType msgType, Stream byteBufStream, + boolean releaseByteBufs) { + return httpPacketBufsToString(msgType, byteBufStream, releaseByteBufs, DEFAULT_LINE_DELIMITER); + } + + public static String httpPacketBytesToString(HttpMessageType msgType, List byteArrStream, String lineDelimiter) { + return httpPacketBytesToString(msgType, + Optional.ofNullable(byteArrStream).map(p -> p.stream()).orElse(Stream.of()), lineDelimiter); } - public static String httpPacketBytesToString(HttpMessageType msgType, List byteArrs, String lineDelimiter) { + public static String httpPacketBytesToString(HttpMessageType msgType, Stream byteArrStream, String lineDelimiter) { // This isn't memory efficient, // but stringifying byte bufs through a full parse and reserializing them was already really slow! - try (var stream = NettyUtils.createRefCntNeutralCloseableByteBufStream(byteArrs)) { - return httpPacketBufsToString(msgType, stream, lineDelimiter); - } + return httpPacketBufsToString(msgType, byteArrStream.map(Unpooled::wrappedBuffer), true, lineDelimiter); } - public static String httpPacketBufsToString(HttpMessageType msgType, Stream byteBufStream, String lineDelimiter) { + public static String httpPacketBufsToString(HttpMessageType msgType, Stream byteBufStream, + boolean releaseByteBufs, String lineDelimiter) { switch (printStyle.get().orElse(PacketPrintFormat.TRUNCATED)) { case TRUNCATED: - return httpPacketBufsToString(byteBufStream, Utils.MAX_BYTES_SHOWN_FOR_TO_STRING); + return httpPacketBufsToString(byteBufStream, Utils.MAX_BYTES_SHOWN_FOR_TO_STRING, releaseByteBufs); case FULL_BYTES: - return httpPacketBufsToString(byteBufStream, Long.MAX_VALUE); + return httpPacketBufsToString(byteBufStream, Long.MAX_VALUE, releaseByteBufs); case PARSED_HTTP: - return httpPacketsToPrettyPrintedString(msgType, byteBufStream, false, + return httpPacketsToPrettyPrintedString(msgType, byteBufStream, false, releaseByteBufs, lineDelimiter); case PARSED_HTTP_SORTED_HEADERS: - return httpPacketsToPrettyPrintedString(msgType, byteBufStream, true, + return httpPacketsToPrettyPrintedString(msgType, byteBufStream, true, releaseByteBufs, lineDelimiter); default: throw new IllegalStateException("Unknown PacketPrintFormat: " + printStyle.get()); @@ -93,21 +103,22 @@ public static String httpPacketBufsToString(HttpMessageType msgType, Stream byteBufStream, - boolean sortHeaders, String lineDelimiter) { - try(var messageHolder = RefSafeHolder.create(parseHttpMessageFromBufs(msgType, byteBufStream))) { - final HttpMessage httpMessage = messageHolder.get(); - if (httpMessage != null) { - if (httpMessage instanceof FullHttpRequest) { - return prettyPrintNettyRequest((FullHttpRequest) httpMessage, sortHeaders, lineDelimiter); - } else if (httpMessage instanceof FullHttpResponse) { - return prettyPrintNettyResponse((FullHttpResponse) httpMessage, sortHeaders, lineDelimiter); - } else { - throw new IllegalStateException("Embedded channel with an HttpObjectAggregator returned an " + - "unexpected object of type " + httpMessage.getClass() + ": " + httpMessage); - } - } else { + boolean sortHeaders, boolean releaseByteBufs, String lineDelimiter) { + HttpMessage httpMessage = parseHttpMessageFromBufs(msgType, byteBufStream, releaseByteBufs); + var holderOp = Optional.ofNullable((httpMessage instanceof ByteBufHolder) ? (ByteBufHolder) httpMessage : null); + try { + if (httpMessage instanceof FullHttpRequest) { + return prettyPrintNettyRequest((FullHttpRequest) httpMessage, sortHeaders, lineDelimiter); + } else if (httpMessage instanceof FullHttpResponse) { + return prettyPrintNettyResponse((FullHttpResponse) httpMessage, sortHeaders, lineDelimiter); + } else if (httpMessage == null) { return "[NULL]"; + } else { + throw new IllegalStateException("Embedded channel with an HttpObjectAggregator returned an " + + "unexpected object of type " + httpMessage.getClass() + ": " + httpMessage); } + } finally { + holderOp.ifPresent(bbh->bbh.content().release()); } } @@ -142,40 +153,58 @@ private static String prettyPrintNettyMessage(StringJoiner sj, boolean sorted, H * @param byteBufStream * @return */ - public static HttpMessage parseHttpMessageFromBufs(HttpMessageType msgType, Stream byteBufStream) { + public static HttpMessage parseHttpMessageFromBufs(HttpMessageType msgType, Stream byteBufStream, + boolean releaseByteBufs) { EmbeddedChannel channel = new EmbeddedChannel( msgType == HttpMessageType.REQUEST ? new HttpServerCodec() : new HttpClientCodec(), new HttpContentDecompressor(), new HttpObjectAggregator(Utils.MAX_PAYLOAD_SIZE_TO_PRINT) // Set max content length if needed ); + + byteBufStream.forEach(b -> { + try { + channel.writeInbound(b.retainedDuplicate()); + } finally { + if (releaseByteBufs) { + b.release(); + } + } + }); + try { - byteBufStream.forEachOrdered(b -> channel.writeInbound(b.retainedDuplicate())); return channel.readInbound(); } finally { channel.finishAndReleaseAll(); } } - public static FullHttpRequest parseHttpRequestFromBufs(Stream byteBufStream) { - return (FullHttpRequest) parseHttpMessageFromBufs(HttpMessageType.REQUEST, byteBufStream); + public static FullHttpRequest parseHttpRequestFromBufs(Stream byteBufStream, boolean releaseByteBufs) { + return (FullHttpRequest) parseHttpMessageFromBufs(HttpMessageType.REQUEST, byteBufStream, releaseByteBufs); } - public static FullHttpResponse parseHttpResponseFromBufs(Stream byteBufStream) { - return (FullHttpResponse) parseHttpMessageFromBufs(HttpMessageType.RESPONSE, byteBufStream); + public static FullHttpResponse parseHttpResponseFromBufs(Stream byteBufStream, boolean releaseByteBufs) { + return (FullHttpResponse) parseHttpMessageFromBufs(HttpMessageType.RESPONSE, byteBufStream, releaseByteBufs); } - public static String httpPacketBufsToString(Stream byteBufStream, long maxBytesToShow) { + public static String httpPacketBufsToString(Stream byteBufStream, long maxBytesToShow, + boolean releaseByteBufs) { if (byteBufStream == null) { return "null"; } return byteBufStream.map(originalByteBuf -> { - var bb = originalByteBuf.duplicate(); - var length = bb.readableBytes(); - var str = IntStream.range(0, length).map(idx -> bb.readByte()) - .limit(maxBytesToShow) - .mapToObj(b -> "" + (char) b) - .collect(Collectors.joining()); - return "[" + (length > maxBytesToShow ? str + "..." : str) + "]"; - }).collect(Collectors.joining(",")); + try { + var bb = originalByteBuf.duplicate(); + var length = bb.readableBytes(); + var str = IntStream.range(0, length).map(idx -> bb.readByte()) + .limit(maxBytesToShow) + .mapToObj(b -> "" + (char) b) + .collect(Collectors.joining()); + return "[" + (length > maxBytesToShow ? str + "..." : str) + "]"; + } finally { + if (releaseByteBufs) { + originalByteBuf.release(); + } + }}) + .collect(Collectors.joining(",")); } } diff --git a/TrafficCapture/trafficReplayer/src/main/java/org/opensearch/migrations/replay/HttpMessageAndTimestamp.java b/TrafficCapture/trafficReplayer/src/main/java/org/opensearch/migrations/replay/HttpMessageAndTimestamp.java index 0529b2b273..cb9089c3df 100644 --- a/TrafficCapture/trafficReplayer/src/main/java/org/opensearch/migrations/replay/HttpMessageAndTimestamp.java +++ b/TrafficCapture/trafficReplayer/src/main/java/org/opensearch/migrations/replay/HttpMessageAndTimestamp.java @@ -1,5 +1,6 @@ package org.opensearch.migrations.replay; +import io.netty.buffer.Unpooled; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.Lombok; @@ -12,7 +13,6 @@ import java.time.Instant; import java.util.Optional; import java.util.stream.Stream; -import org.opensearch.migrations.replay.util.NettyUtils; @Slf4j @EqualsAndHashCode(exclude = "currentSegmentBytes") @@ -65,17 +65,17 @@ public Stream stream() { } public String format(Optional messageTypeOp) { - try (var bufStream = NettyUtils.createRefCntNeutralCloseableByteBufStream(packetBytes)) { - var packetBytesAsStr = messageTypeOp.map(mt-> HttpByteBufFormatter.httpPacketBytesToString(mt, packetBytes, - HttpByteBufFormatter.LF_LINE_DELIMITER)) - .orElseGet(()-> HttpByteBufFormatter.httpPacketBufsToString(bufStream, Utils.MAX_PAYLOAD_SIZE_TO_PRINT)); - final StringBuilder sb = new StringBuilder("HttpMessageAndTimestamp{"); - sb.append("firstPacketTimestamp=").append(firstPacketTimestamp); - sb.append(", lastPacketTimestamp=").append(lastPacketTimestamp); - sb.append(", message=[").append(packetBytesAsStr); - sb.append("]}"); - return sb.toString(); - } + var packetBytesAsStr = messageTypeOp.map(mt-> HttpByteBufFormatter.httpPacketBytesToString(mt, packetBytes, + HttpByteBufFormatter.LF_LINE_DELIMITER)) + .orElseGet(()-> HttpByteBufFormatter.httpPacketBufsToString( + packetBytes.stream().map(Unpooled::wrappedBuffer), + Utils.MAX_PAYLOAD_SIZE_TO_PRINT, true)); + final StringBuilder sb = new StringBuilder("HttpMessageAndTimestamp{"); + sb.append("firstPacketTimestamp=").append(firstPacketTimestamp); + sb.append(", lastPacketTimestamp=").append(lastPacketTimestamp); + sb.append(", message=[").append(packetBytesAsStr); + sb.append("]}"); + return sb.toString(); } public void addSegment(byte[] data) { diff --git a/TrafficCapture/trafficReplayer/src/main/java/org/opensearch/migrations/replay/ParsedHttpMessagesAsDicts.java b/TrafficCapture/trafficReplayer/src/main/java/org/opensearch/migrations/replay/ParsedHttpMessagesAsDicts.java index 66567a7dca..0c78c5a07e 100644 --- a/TrafficCapture/trafficReplayer/src/main/java/org/opensearch/migrations/replay/ParsedHttpMessagesAsDicts.java +++ b/TrafficCapture/trafficReplayer/src/main/java/org/opensearch/migrations/replay/ParsedHttpMessagesAsDicts.java @@ -1,7 +1,14 @@ package org.opensearch.migrations.replay; import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; import io.netty.handler.codec.http.HttpHeaders; +import io.netty.util.ReferenceCounted; +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import org.opensearch.migrations.replay.datatypes.TransformedPackets; +import org.opensearch.migrations.replay.tracing.IReplayContexts; + import java.time.Duration; import java.util.Base64; import java.util.LinkedHashMap; @@ -10,12 +17,7 @@ import java.util.Optional; import java.util.concurrent.Callable; import java.util.stream.Collectors; -import lombok.NonNull; -import lombok.extern.slf4j.Slf4j; -import org.opensearch.migrations.replay.datatypes.TransformedPackets; -import org.opensearch.migrations.replay.tracing.IReplayContexts; -import org.opensearch.migrations.replay.util.NettyUtils; -import org.opensearch.migrations.replay.util.RefSafeHolder; +import java.util.stream.Stream; /** * TODO - This class will pull all bodies in as a byte[], even if that byte[] isn't @@ -100,6 +102,11 @@ public static void fillStatusCodeMetrics(@NonNull IReplayContexts.ITupleHandling targetResponseOp.ifPresent(r -> context.setTargetStatus((Integer) r.get(STATUS_CODE_KEY))); } + + private static Stream byteToByteBufStream(List incoming) { + return incoming.stream().map(Unpooled::wrappedBuffer); + } + private static byte[] getBytesFromByteBuf(ByteBuf buf) { var bytes = new byte[buf.readableBytes()]; buf.getBytes(buf.readerIndex(), bytes); @@ -131,20 +138,17 @@ private static Map convertRequest(@NonNull IReplayContexts.ITupl @NonNull List data) { return makeSafeMap(context, () -> { var map = new LinkedHashMap(); - try (var bufStream = NettyUtils.createRefCntNeutralCloseableByteBufStream(data); - var messageHolder = RefSafeHolder.create(HttpByteBufFormatter.parseHttpRequestFromBufs(bufStream))) { - var message = messageHolder.get(); - if (message != null) { - map.put("Request-URI", message.uri()); - map.put("Method", message.method().toString()); - map.put("HTTP-Version", message.protocolVersion().toString()); - context.setMethod(message.method().toString()); - context.setEndpoint(message.uri()); - context.setHttpVersion(message.protocolVersion().toString()); - return fillMap(map, message.headers(), message.content()); - } else { - return Map.of("Exception", "Message couldn't be parsed as a full http message"); - } + var message = HttpByteBufFormatter.parseHttpRequestFromBufs(byteToByteBufStream(data), true); + try { + map.put("Request-URI", message.uri()); + map.put("Method", message.method().toString()); + map.put("HTTP-Version", message.protocolVersion().toString()); + context.setMethod(message.method().toString()); + context.setEndpoint(message.uri()); + context.setHttpVersion(message.protocolVersion().toString()); + return fillMap(map, message.headers(), message.content()); + } finally { + Optional.ofNullable(message).ifPresent(ReferenceCounted::release); } }); } @@ -153,18 +157,18 @@ private static Map convertResponse(@NonNull IReplayContexts.ITup @NonNull List data, Duration latency) { return makeSafeMap(context, () -> { var map = new LinkedHashMap(); - try (var bufStream = NettyUtils.createRefCntNeutralCloseableByteBufStream(data); - var messageHolder = RefSafeHolder.create(HttpByteBufFormatter.parseHttpResponseFromBufs(bufStream))) { - var message = messageHolder.get(); - if (message != null) { - map.put("HTTP-Version", message.protocolVersion()); - map.put(STATUS_CODE_KEY, message.status().code()); - map.put("Reason-Phrase", message.status().reasonPhrase()); - map.put(RESPONSE_TIME_MS_KEY, latency.toMillis()); - return fillMap(map, message.headers(), message.content()); - } else { - return Map.of("Exception", "Message couldn't be parsed as a full http message"); - } + var message = HttpByteBufFormatter.parseHttpResponseFromBufs(byteToByteBufStream(data), true); + if (message == null) { + return Map.of("Exception", "Message couldn't be parsed as a full http message"); + } + try { + map.put("HTTP-Version", message.protocolVersion()); + map.put(STATUS_CODE_KEY, message.status().code()); + map.put("Reason-Phrase", message.status().reasonPhrase()); + map.put(RESPONSE_TIME_MS_KEY, latency.toMillis()); + return fillMap(map, message.headers(), message.content()); + } finally { + Optional.ofNullable(message).ifPresent(ReferenceCounted::release); } }); } diff --git a/TrafficCapture/trafficReplayer/src/main/java/org/opensearch/migrations/replay/SourceTargetCaptureTuple.java b/TrafficCapture/trafficReplayer/src/main/java/org/opensearch/migrations/replay/SourceTargetCaptureTuple.java index 0067ed78c7..b63f1aa150 100644 --- a/TrafficCapture/trafficReplayer/src/main/java/org/opensearch/migrations/replay/SourceTargetCaptureTuple.java +++ b/TrafficCapture/trafficReplayer/src/main/java/org/opensearch/migrations/replay/SourceTargetCaptureTuple.java @@ -54,7 +54,7 @@ public String toString() { if (targetResponseDuration != null) { sj.add("targetResponseDuration=").add(targetResponseDuration+""); } Optional.ofNullable(targetRequestData).ifPresent(d-> sj.add("targetRequestData=") .add(d.isClosed() ? "CLOSED" : HttpByteBufFormatter.httpPacketBufsToString( - HttpByteBufFormatter.HttpMessageType.REQUEST, d.streamUnretained(), LF_LINE_DELIMITER))); + HttpByteBufFormatter.HttpMessageType.REQUEST, d.streamUnretained(), false, LF_LINE_DELIMITER))); Optional.ofNullable(targetResponseData).filter(d->!d.isEmpty()).ifPresent(d -> sj.add("targetResponseData=") .add(HttpByteBufFormatter.httpPacketBytesToString(HttpByteBufFormatter.HttpMessageType.RESPONSE, d, LF_LINE_DELIMITER))); sj.add("transformStatus=").add(transformationStatus+""); diff --git a/TrafficCapture/trafficReplayer/src/main/java/org/opensearch/migrations/replay/util/NettyUtils.java b/TrafficCapture/trafficReplayer/src/main/java/org/opensearch/migrations/replay/util/NettyUtils.java deleted file mode 100644 index 60b97df6f2..0000000000 --- a/TrafficCapture/trafficReplayer/src/main/java/org/opensearch/migrations/replay/util/NettyUtils.java +++ /dev/null @@ -1,22 +0,0 @@ -package org.opensearch.migrations.replay.util; - -import com.google.errorprone.annotations.MustBeClosed; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import java.util.Collection; - -import java.util.stream.Stream; - -public final class NettyUtils { - @MustBeClosed - public static Stream createRefCntNeutralCloseableByteBufStream(Stream byteArrStream) { - return RefSafeStreamUtils.refSafeMap(byteArrStream, Unpooled::wrappedBuffer); - } - - @MustBeClosed - public static Stream createRefCntNeutralCloseableByteBufStream(Collection byteArrCollection) { - return createRefCntNeutralCloseableByteBufStream(byteArrCollection.stream()); - } - - private NettyUtils() {} -} \ No newline at end of file diff --git a/TrafficCapture/trafficReplayer/src/main/java/org/opensearch/migrations/replay/util/RefSafeHolder.java b/TrafficCapture/trafficReplayer/src/main/java/org/opensearch/migrations/replay/util/RefSafeHolder.java deleted file mode 100644 index af943bfc51..0000000000 --- a/TrafficCapture/trafficReplayer/src/main/java/org/opensearch/migrations/replay/util/RefSafeHolder.java +++ /dev/null @@ -1,27 +0,0 @@ -package org.opensearch.migrations.replay.util; - -import com.google.errorprone.annotations.MustBeClosed; -import io.netty.util.ReferenceCountUtil; -import javax.annotation.Nullable; - -public class RefSafeHolder implements AutoCloseable { - private final T resource; - - private RefSafeHolder(@Nullable T resource) { - this.resource = resource; - } - - @MustBeClosed - static public RefSafeHolder create(@Nullable T resource) { - return new RefSafeHolder<>(resource); - } - - public @Nullable T get() { - return resource; - } - - @Override - public void close() { - ReferenceCountUtil.release(resource); - } -} diff --git a/TrafficCapture/trafficReplayer/src/main/java/org/opensearch/migrations/replay/util/RefSafeStreamUtils.java b/TrafficCapture/trafficReplayer/src/main/java/org/opensearch/migrations/replay/util/RefSafeStreamUtils.java deleted file mode 100644 index 2254f3c38e..0000000000 --- a/TrafficCapture/trafficReplayer/src/main/java/org/opensearch/migrations/replay/util/RefSafeStreamUtils.java +++ /dev/null @@ -1,31 +0,0 @@ -package org.opensearch.migrations.replay.util; - -import com.google.errorprone.annotations.MustBeClosed; -import io.netty.util.ReferenceCounted; -import java.util.ArrayDeque; -import java.util.Deque; -import java.util.function.Function; -import java.util.stream.Stream; - -public final class RefSafeStreamUtils { - @MustBeClosed - public static Stream refSafeMap(Stream inputStream, - Function referenceTrackedMappingFunction) { - final Deque refCountedTracker = new ArrayDeque<>(); - return inputStream.map(t -> { - var resource = referenceTrackedMappingFunction.apply(t); - refCountedTracker.add(resource); - return resource; - }).onClose(() -> refCountedTracker.forEach(ReferenceCounted::release)); - } - - public static U refSafeTransform(Stream inputStream, - Function transformCreatingReferenceTrackedObjects, - Function, U> streamApplication) { - try (var mappedStream = refSafeMap(inputStream, transformCreatingReferenceTrackedObjects)) { - return streamApplication.apply(mappedStream); - } - } - - private RefSafeStreamUtils() {} -} diff --git a/TrafficCapture/trafficReplayer/src/test/java/org/opensearch/migrations/replay/HttpByteBufFormatterTest.java b/TrafficCapture/trafficReplayer/src/test/java/org/opensearch/migrations/replay/HttpByteBufFormatterTest.java index 822c9d8eed..c895bc3676 100644 --- a/TrafficCapture/trafficReplayer/src/test/java/org/opensearch/migrations/replay/HttpByteBufFormatterTest.java +++ b/TrafficCapture/trafficReplayer/src/test/java/org/opensearch/migrations/replay/HttpByteBufFormatterTest.java @@ -1,13 +1,11 @@ package org.opensearch.migrations.replay; - import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; -import org.opensearch.migrations.replay.util.RefSafeStreamUtils; import org.opensearch.migrations.testutils.CountingNettyResourceLeakDetector; import org.opensearch.migrations.testutils.TestUtilities; import org.opensearch.migrations.testutils.WrapWithNettyLeakDetection; @@ -16,6 +14,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.stream.Collectors; import java.util.stream.Stream; @WrapWithNettyLeakDetection @@ -146,10 +145,10 @@ private static String prettyPrint(List byteArrays, private static String prettyPrintByteBufs(List byteArrays, HttpByteBufFormatter.HttpMessageType messageType, boolean usePooled) { - return RefSafeStreamUtils.refSafeTransform(byteArrays.stream(), - b->TestUtilities.getByteBuf(b,usePooled), - bbs -> HttpByteBufFormatter.httpPacketBufsToString(messageType, bbs)); - + var bbList = byteArrays.stream().map(b->TestUtilities.getByteBuf(b,usePooled)).collect(Collectors.toList()); + var formattedString = HttpByteBufFormatter.httpPacketBufsToString(messageType, bbList.stream(), false); + bbList.forEach(bb->bb.release()); + return formattedString; } static String getExpectedResult(HttpByteBufFormatter.PacketPrintFormat format, BufferContent content) { diff --git a/TrafficCapture/trafficReplayer/src/test/java/org/opensearch/migrations/replay/RequestSenderOrchestratorTest.java b/TrafficCapture/trafficReplayer/src/test/java/org/opensearch/migrations/replay/RequestSenderOrchestratorTest.java index 8921555272..661d76ba31 100644 --- a/TrafficCapture/trafficReplayer/src/test/java/org/opensearch/migrations/replay/RequestSenderOrchestratorTest.java +++ b/TrafficCapture/trafficReplayer/src/test/java/org/opensearch/migrations/replay/RequestSenderOrchestratorTest.java @@ -1,15 +1,9 @@ package org.opensearch.migrations.replay; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufHolder; import io.netty.buffer.Unpooled; import io.netty.handler.codec.http.FullHttpResponse; -import java.nio.charset.StandardCharsets; -import java.time.Duration; -import java.time.Instant; -import java.util.AbstractMap.SimpleEntry; -import java.util.ArrayList; -import java.util.List; -import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Tag; @@ -17,12 +11,18 @@ import org.junit.jupiter.api.parallel.Execution; import org.junit.jupiter.api.parallel.ExecutionMode; import org.opensearch.migrations.replay.util.DiagnosticTrackableCompletableFuture; -import org.opensearch.migrations.replay.util.NettyUtils; -import org.opensearch.migrations.replay.util.RefSafeHolder; import org.opensearch.migrations.testutils.SimpleHttpServer; import org.opensearch.migrations.testutils.WrapWithNettyLeakDetection; import org.opensearch.migrations.tracing.InstrumentationTest; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + @Slf4j @WrapWithNettyLeakDetection(repetitions = 1) class RequestSenderOrchestratorTest extends InstrumentationTest { @@ -67,19 +67,18 @@ public void testThatSchedulingWorks() throws Exception { var arr = cf.get(); Assertions.assertNull(arr.error); Assertions.assertTrue(arr.responseSizeInBytes > 0); - var packetBytesArr = arr.responsePackets.stream().map(SimpleEntry::getValue).collect(Collectors.toList()); - try (var bufStream = NettyUtils.createRefCntNeutralCloseableByteBufStream(packetBytesArr); - var messageHolder = RefSafeHolder.create( - HttpByteBufFormatter.parseHttpMessageFromBufs(HttpByteBufFormatter.HttpMessageType.RESPONSE, - bufStream))) { - var message = messageHolder.get(); - Assertions.assertNotNull(message); - var response = (FullHttpResponse) message; + var httpMessage = HttpByteBufFormatter.parseHttpMessageFromBufs(HttpByteBufFormatter.HttpMessageType.RESPONSE, + arr.responsePackets.stream().map(kvp -> Unpooled.wrappedBuffer(kvp.getValue())), false); + try { + var response = (FullHttpResponse) httpMessage; Assertions.assertEquals(200, response.status().code()); var body = response.content(); Assertions.assertEquals(TestHttpServerContext.SERVER_RESPONSE_BODY_PREFIX + TestHttpServerContext.getUriForIthRequest(i / NUM_REPEATS), - body.duplicate().toString(StandardCharsets.UTF_8)); + new String(body.duplicate().toString(StandardCharsets.UTF_8))); + } finally { + Optional.ofNullable((httpMessage instanceof ByteBufHolder) ? (ByteBufHolder) httpMessage : null) + .ifPresent(bbh -> bbh.content().release()); } } closeFuture.get(); diff --git a/TrafficCapture/trafficReplayer/src/test/java/org/opensearch/migrations/replay/util/RefSafeStreamUtilsTest.java b/TrafficCapture/trafficReplayer/src/test/java/org/opensearch/migrations/replay/util/RefSafeStreamUtilsTest.java deleted file mode 100644 index 2dd3ca7296..0000000000 --- a/TrafficCapture/trafficReplayer/src/test/java/org/opensearch/migrations/replay/util/RefSafeStreamUtilsTest.java +++ /dev/null @@ -1,119 +0,0 @@ -package org.opensearch.migrations.replay.util; - -import io.netty.util.AbstractReferenceCounted; -import io.netty.util.ReferenceCounted; - -import java.util.function.Predicate; -import org.junit.jupiter.api.Test; - -import java.util.ArrayList; -import java.util.List; -import java.util.stream.Collectors; -import java.util.stream.Stream; - -import static org.junit.jupiter.api.Assertions.*; - -class RefSafeStreamUtilsTest { - - @Test - void refSafeMap_shouldMapAndReleaseResources() { - Stream inputStream = Stream.of("a", "b", "c"); - - List result; - try (Stream mappedStream = RefSafeStreamUtils.refSafeMap(inputStream, TestReferenceCounted::new)) { - result = mappedStream.collect(Collectors.toList()); - } - - assertEquals(3, result.size()); - assertTrue(result.stream().allMatch(TestReferenceCounted::isReleased)); - } - - @Test - void refSafeTransform_shouldTransformAndReleaseResources() { - Stream inputStream = Stream.of("a", "b", "c"); - - List refCountedObjects = new ArrayList<>(); - List result = RefSafeStreamUtils.refSafeTransform( - inputStream, - value -> { - TestReferenceCounted refCounted = new TestReferenceCounted(value); - refCountedObjects.add(refCounted); - return refCounted; - }, - stream -> stream.map(TestReferenceCounted::getValue).collect(Collectors.toList()) - ); - - assertEquals(List.of("a", "b", "c"), result); - assertTrue(refCountedObjects.stream().allMatch(TestReferenceCounted::isReleased)); - } - - @Test - void refSafeMap_shouldHandleExceptionDuringMapping() { - List inputStreamConsumedObjects = new ArrayList<>(); - Stream inputStream = Stream.of("a", "b", "c", "d", "e") - .peek(inputStreamConsumedObjects::add); - - List refCountedObjects = new ArrayList<>(); - assertThrows(RuntimeException.class, () -> { - try (Stream mappedStream = RefSafeStreamUtils.refSafeMap(inputStream, - value -> { - if (value.equals("d")) { - throw new RuntimeException("Simulated exception"); - } - TestReferenceCounted refCounted = new TestReferenceCounted(value); - refCountedObjects.add(refCounted); - return refCounted; - })) { - try { - mappedStream.collect(Collectors.toList()); - } finally { - // Expect no release until try-with-resources close - assertEquals(3, refCountedObjects.size()); - assertTrue(refCountedObjects.stream().allMatch(Predicate.not(TestReferenceCounted::isReleased))); - } - } - }); - assertEquals(4, inputStreamConsumedObjects.size()); - assertEquals(3, refCountedObjects.size()); - assertTrue(refCountedObjects.stream().allMatch(TestReferenceCounted::isReleased)); - } - - private static class TestReferenceCounted extends AbstractReferenceCounted { - private final String value; - private boolean released; - - TestReferenceCounted(String value) { - this.value = value; - } - - String getValue() { - return value; - } - - boolean isReleased() { - return released; - } - - @Override - public boolean release() { - if (released) { - throw new AssertionError("TestReferenceCounted object released twice"); - } - try { - return super.release(); - } finally { - released = true; - } - } - - @Override - protected void deallocate() { - // No-op - } - - @Override - public ReferenceCounted touch(Object hint) { - return this; - } - } -} \ No newline at end of file