diff --git a/TrafficCapture/captureOffloader/src/main/java/org/opensearch/migrations/trafficcapture/CodedOutputStreamSizeUtil.java b/TrafficCapture/captureOffloader/src/main/java/org/opensearch/migrations/trafficcapture/CodedOutputStreamSizeUtil.java
index e6335a2387..95706ba405 100644
--- a/TrafficCapture/captureOffloader/src/main/java/org/opensearch/migrations/trafficcapture/CodedOutputStreamSizeUtil.java
+++ b/TrafficCapture/captureOffloader/src/main/java/org/opensearch/migrations/trafficcapture/CodedOutputStreamSizeUtil.java
@@ -39,7 +39,7 @@ public static int maxBytesNeededForASegmentedObservation(Instant timestamp, int
         int tsTagAndContentSize = CodedOutputStream.computeInt32Size(TrafficObservation.TS_FIELD_NUMBER, tsContentSize) + tsContentSize;
 
         // Capture required bytes
-        int dataSize = CodedOutputStream.computeByteBufferSize(dataFieldNumber, buffer);
+        int dataSize = computeByteBufferRemainingSize(dataFieldNumber, buffer);
         int captureTagAndContentSize = CodedOutputStream.computeInt32Size(observationFieldNumber, dataSize) + dataSize;
 
         // Observation and closing index required bytes
@@ -47,6 +47,26 @@ public static int maxBytesNeededForASegmentedObservation(Instant timestamp, int
             Integer.MAX_VALUE);
     }
 
+    /**
+     * This function determines the number of bytes needed to write the remaining bytes in a byteBuffer and its tag.
+     * Use this over CodedOutputStream.computeByteBufferSize(int fieldNumber, ByteBuffer buffer) because the latter
+     * relies on the ByteBuffer capacity instead of the limit in its size calculation.
+     */
+    public static int computeByteBufferRemainingSize(int fieldNumber, ByteBuffer buffer) {
+        return CodedOutputStream.computeTagSize(fieldNumber) + computeByteBufferRemainingSizeNoTag(buffer);
+    }
+
+    /**
+     * This function determines the number of bytes needed to write the remaining bytes in a byteBuffer. Use this over
+     * CodedOutputStream.computeByteBufferSizeNoTag(ByteBuffer buffer) because the latter relies on the
+     * ByteBuffer capacity instead of the limit in its size calculation.
+     */
+    public static int computeByteBufferRemainingSizeNoTag(ByteBuffer buffer) {
+        int bufferSize = buffer.remaining();
+        return CodedOutputStream.computeUInt32SizeNoTag(bufferSize) + bufferSize;
+    }
+
+
     /**
      * This function determines the number of bytes needed to store a TrafficObservation and a closing index for a
      * TrafficStream, from the provided input.
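A minimal standalone sketch (not part of the patch) of the size discrepancy the helpers above address, assuming protobuf-java's CodedOutputStream is on the classpath; the field number 2 and the 100-capacity/50-limit buffer mirror the new unit tests further down, and the wrapper class name SizeSketch is purely illustrative:

import com.google.protobuf.CodedOutputStream;
import java.nio.ByteBuffer;

public class SizeSketch {
    public static void main(String[] args) {
        // Buffer with capacity 100 but only 50 readable bytes (limit = 50, position = 0).
        ByteBuffer buffer = ByteBuffer.allocate(100).limit(50);

        // Capacity-based estimate: per the javadoc above, computeByteBufferSize sizes the
        // full capacity (tag + length varint for 100 + 100 bytes).
        int capacityBased = CodedOutputStream.computeByteBufferSize(2, buffer);

        // Remaining-based estimate, the same arithmetic as computeByteBufferRemainingSize:
        // tag (1) + length varint for 50 (1) + 50 bytes = 52, matching the new test's expectation.
        int remainingBased = CodedOutputStream.computeTagSize(2)
            + CodedOutputStream.computeUInt32SizeNoTag(buffer.remaining())
            + buffer.remaining();

        System.out.println(capacityBased + " vs " + remainingBased);
    }
}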
diff --git a/TrafficCapture/captureOffloader/src/main/java/org/opensearch/migrations/trafficcapture/StreamChannelConnectionCaptureSerializer.java b/TrafficCapture/captureOffloader/src/main/java/org/opensearch/migrations/trafficcapture/StreamChannelConnectionCaptureSerializer.java
index fbe72df54f..288c7d54c0 100644
--- a/TrafficCapture/captureOffloader/src/main/java/org/opensearch/migrations/trafficcapture/StreamChannelConnectionCaptureSerializer.java
+++ b/TrafficCapture/captureOffloader/src/main/java/org/opensearch/migrations/trafficcapture/StreamChannelConnectionCaptureSerializer.java
@@ -3,6 +3,7 @@
 import com.google.protobuf.CodedOutputStream;
 import com.google.protobuf.Descriptors;
 import com.google.protobuf.Timestamp;
+import com.google.protobuf.WireFormat;
 import io.netty.buffer.ByteBuf;
 
 import java.util.function.IntSupplier;
@@ -182,8 +183,12 @@ private void writeTimestampForNowToCurrentStream(Instant timestamp) throws IOExc
     }
 
     private void writeByteBufferToCurrentStream(int fieldNum, ByteBuffer byteBuffer) throws IOException {
-        if (byteBuffer.remaining() > 0) {
-            getOrCreateCodedOutputStream().writeByteBuffer(fieldNum, byteBuffer);
+        if (byteBuffer.hasRemaining()) {
+            // CodedOutputStream.writeByteBuffer writes based on capacity and ignores the limit, so prefer writing the tag, length, and remaining bytes explicitly
+            getOrCreateCodedOutputStream().writeTag(fieldNum, WireFormat.WIRETYPE_LENGTH_DELIMITED);
+            getOrCreateCodedOutputStream().writeUInt32NoTag(byteBuffer.remaining());
+            getOrCreateCodedOutputStream().write(byteBuffer.duplicate());
+            assert byteBuffer.hasRemaining() : "byteBuffer position should not be modified when writing.";
         } else {
             getOrCreateCodedOutputStream().writeUInt32NoTag(0);
         }
@@ -279,8 +284,7 @@ void addDataMessage(int captureFieldNumber, int dataFieldNumber, Instant timesta
         addDataMessage(captureFieldNumber, dataFieldNumber, timestamp, buf.nioBuffer());
     }
 
-    void addDataMessage(int captureFieldNumber, int dataFieldNumber, Instant timestamp, ByteBuffer nioBuffer) throws IOException {
-        var readOnlyDataBuffer = nioBuffer.asReadOnlyBuffer();
+    void addDataMessage(int captureFieldNumber, int dataFieldNumber, Instant timestamp, ByteBuffer buffer) throws IOException {
         int segmentFieldNumber;
         int segmentDataFieldNumber;
         if (captureFieldNumber == TrafficObservation.READ_FIELD_NUMBER) {
@@ -296,40 +300,44 @@ void addDataMessage(int captureFieldNumber, int dataFieldNumber, Instant timesta
         // the potentially required bytes for simplicity. This could leave ~5 bytes of unused space in the CodedOutputStream
         // when considering the case of a message that does not need segments or for the case of a smaller segment created
        // from a much larger message
-        int messageAndOverheadBytesLeft = CodedOutputStreamSizeUtil.maxBytesNeededForASegmentedObservation(timestamp,
-            segmentFieldNumber, segmentDataFieldNumber, readOnlyDataBuffer);
-        int trafficStreamOverhead = messageAndOverheadBytesLeft - readOnlyDataBuffer.capacity();
+        final int messageAndOverheadBytesLeft = CodedOutputStreamSizeUtil.maxBytesNeededForASegmentedObservation(timestamp,
+            segmentFieldNumber, segmentDataFieldNumber, buffer);
+        final int dataSize = CodedOutputStreamSizeUtil.computeByteBufferRemainingSizeNoTag(buffer);
+        final int trafficStreamOverhead = messageAndOverheadBytesLeft - dataSize;
 
-        // Ensure that space for at least one data byte and overhead exists, otherwise a flush is necessary.
-        flushIfNeeded(() -> (trafficStreamOverhead + 1));
+        // Ensure that space for at least one data byte, one length byte, and overhead exists, otherwise a flush is necessary.
+        flushIfNeeded(() -> (trafficStreamOverhead + 2)).join();
+        assert getOrCreateCodedOutputStreamHolder().getOutputStreamSpaceLeft() == -1 ||
+            getOrCreateCodedOutputStreamHolder().getOutputStreamSpaceLeft() > trafficStreamOverhead
+            : "COS does not have space for data";
 
         // If our message is empty or can fit in the current CodedOutputStream no chunking is needed, and we can continue
         var spaceLeft = getOrCreateCodedOutputStreamHolder().getOutputStreamSpaceLeft();
-        if (readOnlyDataBuffer.limit() == 0 || spaceLeft == -1 || messageAndOverheadBytesLeft <= spaceLeft) {
+        if (!buffer.hasRemaining() || spaceLeft == -1 || messageAndOverheadBytesLeft <= spaceLeft) {
             int minExpectedSpaceAfterObservation = spaceLeft - messageAndOverheadBytesLeft;
-            addSubstreamMessage(captureFieldNumber, dataFieldNumber, timestamp, readOnlyDataBuffer);
+            addSubstreamMessage(captureFieldNumber, dataFieldNumber, timestamp, buffer);
             observationSizeSanityCheck(minExpectedSpaceAfterObservation, captureFieldNumber);
             return;
         }
 
-        while(readOnlyDataBuffer.position() < readOnlyDataBuffer.limit()) {
+        var readBuffer = buffer.duplicate();
+        while(readBuffer.hasRemaining()) {
+            flushIfNeeded(() -> (trafficStreamOverhead + 2)).join();
             // COS checked for unbounded limit above
-            int availableCOSSpace = getOrCreateCodedOutputStreamHolder().getOutputStreamSpaceLeft();
-            int chunkBytes = messageAndOverheadBytesLeft > availableCOSSpace ? availableCOSSpace - trafficStreamOverhead : readOnlyDataBuffer.limit() - readOnlyDataBuffer.position();
-            ByteBuffer bb = readOnlyDataBuffer.slice();
-            bb.limit(chunkBytes);
-            bb = bb.slice();
-            readOnlyDataBuffer.position(readOnlyDataBuffer.position() + chunkBytes);
-            addSubstreamMessage(segmentFieldNumber, segmentDataFieldNumber, timestamp, bb);
-            int minExpectedSpaceAfterObservation = availableCOSSpace - chunkBytes - trafficStreamOverhead;
+            final int availableCOSSpace = getOrCreateCodedOutputStreamHolder().getOutputStreamSpaceLeft();
+            final int maxLengthSpace = CodedOutputStream.computeUInt32SizeNoTag(readBuffer.remaining());
+            final int maxBytesSpace = availableCOSSpace - trafficStreamOverhead - maxLengthSpace;
+            final int nextChunkBytes = Math.min(maxBytesSpace, readBuffer.remaining());
+
+            var dataBytes = new byte[nextChunkBytes];
+            readBuffer.get(dataBytes, 0, nextChunkBytes);
+            addSubstreamMessage(segmentFieldNumber, segmentDataFieldNumber, timestamp, ByteBuffer.wrap(dataBytes));
+
+            final int minExpectedSpaceAfterObservation = maxBytesSpace - nextChunkBytes;
             observationSizeSanityCheck(minExpectedSpaceAfterObservation, segmentFieldNumber);
-            // 1 to N-1 chunked messages
-            if (readOnlyDataBuffer.position() < readOnlyDataBuffer.limit()) {
-                flushCommitAndResetStream(false);
-                messageAndOverheadBytesLeft = messageAndOverheadBytesLeft - chunkBytes;
-            }
         }
         writeEndOfSegmentMessage(timestamp);
+
     }
 
     void addSubstreamMessage(int captureFieldNumber, int dataFieldNumber, int dataCountFieldNumber, int dataCount,
@@ -342,7 +350,7 @@ void addSubstreamMessage(int captureFieldNumber, int dataFieldNumber, int dataCo
             segmentCountSize = CodedOutputStream.computeInt32Size(dataCountFieldNumber, dataCount);
         }
         if (byteBuffer.remaining() > 0) {
-            dataSize = CodedOutputStream.computeByteBufferSize(dataFieldNumber, byteBuffer);
+            dataSize = CodedOutputStreamSizeUtil.computeByteBufferRemainingSize(dataFieldNumber, byteBuffer);
             captureClosureLength = CodedOutputStream.computeInt32SizeNoTag(dataSize + segmentCountSize);
         }
         beginSubstreamObservation(timestamp, captureFieldNumber, captureClosureLength + dataSize +
             segmentCountSize);
diff --git a/TrafficCapture/captureOffloader/src/test/java/org/opensearch/migrations/trafficcapture/CodedOutputStreamSizeUtilTest.java b/TrafficCapture/captureOffloader/src/test/java/org/opensearch/migrations/trafficcapture/CodedOutputStreamSizeUtilTest.java
new file mode 100644
index 0000000000..ac40c7612a
--- /dev/null
+++ b/TrafficCapture/captureOffloader/src/test/java/org/opensearch/migrations/trafficcapture/CodedOutputStreamSizeUtilTest.java
@@ -0,0 +1,85 @@
+package org.opensearch.migrations.trafficcapture;
+
+import org.junit.jupiter.api.Test;
+import static org.junit.jupiter.api.Assertions.*;
+import java.nio.ByteBuffer;
+import java.time.Instant;
+
+class CodedOutputStreamSizeUtilTest {
+
+    @Test
+    void testGetSizeOfTimestamp() {
+        // Timestamp with only seconds (no explicit nanoseconds)
+        Instant timestampSecondsOnly = Instant.parse("2024-01-01T00:00:00Z");
+        int sizeSecondsOnly = CodedOutputStreamSizeUtil.getSizeOfTimestamp(timestampSecondsOnly);
+        assertEquals(6, sizeSecondsOnly);
+
+        // Timestamp with both seconds and nanoseconds
+        Instant timestampWithNanos = Instant.parse("2024-12-31T23:59:59.123456789Z");
+        int sizeWithNanos = CodedOutputStreamSizeUtil.getSizeOfTimestamp(timestampWithNanos);
+        assertEquals(11, sizeWithNanos);
+    }
+
+    @Test
+    void testMaxBytesNeededForASegmentedObservation() {
+        Instant timestamp = Instant.parse("2024-01-01T00:00:00Z");
+        ByteBuffer buffer = ByteBuffer.allocate(100).limit(50);
+        buffer.position(25);
+        int result = CodedOutputStreamSizeUtil.maxBytesNeededForASegmentedObservation(timestamp, 1, 2, buffer);
+        assertEquals(45, result);
+    }
+
+    @Test
+    void test_computeByteBufferRemainingSize() {
+        ByteBuffer buffer = ByteBuffer.allocate(100).limit(50);
+        int result = CodedOutputStreamSizeUtil.computeByteBufferRemainingSize(2, buffer);
+        assertEquals(52, result);
+    }
+
+    @Test
+    void test_computeByteBufferRemainingSize_ByteBufferAtCapacity() {
+        ByteBuffer buffer = ByteBuffer.allocate(200);
+        int result = CodedOutputStreamSizeUtil.computeByteBufferRemainingSize(2, buffer);
+        assertEquals(203, result);
+    }
+
+    @Test
+    void test_computeByteBufferRemainingSize_EmptyByteBuffer() {
+        ByteBuffer buffer = ByteBuffer.allocate(0);
+        int result = CodedOutputStreamSizeUtil.computeByteBufferRemainingSize(2, buffer);
+        assertEquals(2, result);
+    }
+
+    @Test
+    void testBytesNeededForObservationAndClosingIndex() {
+        int observationContentSize = 50;
+        int numberOfTrafficStreamsSoFar = 10;
+
+        int result = CodedOutputStreamSizeUtil.bytesNeededForObservationAndClosingIndex(observationContentSize, numberOfTrafficStreamsSoFar);
+        assertEquals(54, result);
+    }
+
+    @Test
+    void testBytesNeededForObservationAndClosingIndex_WithZeroContent() {
+        int observationContentSize = 0;
+        int numberOfTrafficStreamsSoFar = 0;
+
+        int result = CodedOutputStreamSizeUtil.bytesNeededForObservationAndClosingIndex(observationContentSize, numberOfTrafficStreamsSoFar);
+        assertEquals(4, result);
+    }
+
+    @Test
+    void testBytesNeededForObservationAndClosingIndex_VariousIndices() {
+        int observationContentSize = 20;
+
+        // Test with increasing indices to verify scaling of index size
+        int[] indices = new int[]{1, 1000, 100000};
+        int[] expectedResults = new int[]{24, 25, 26};
+
+        for (int i = 0; i < indices.length; i++) {
+            int result = CodedOutputStreamSizeUtil.bytesNeededForObservationAndClosingIndex(observationContentSize, indices[i]);
+            assertEquals(expectedResults[i], result);
+        }
+    }
+
+}
diff --git a/TrafficCapture/captureOffloader/src/test/java/org/opensearch/migrations/trafficcapture/StreamChannelConnectionCaptureSerializerTest.java b/TrafficCapture/captureOffloader/src/test/java/org/opensearch/migrations/trafficcapture/StreamChannelConnectionCaptureSerializerTest.java
index c33fd32194..04d2d02d07 100644
--- a/TrafficCapture/captureOffloader/src/test/java/org/opensearch/migrations/trafficcapture/StreamChannelConnectionCaptureSerializerTest.java
+++ b/TrafficCapture/captureOffloader/src/test/java/org/opensearch/migrations/trafficcapture/StreamChannelConnectionCaptureSerializerTest.java
@@ -204,18 +204,55 @@ public void testWriteIsHandledForBufferAllocatedLargerThanWritten()
         var serializer = createSerializerWithTestHandler(outputBuffersCreated, getEstimatedTrafficStreamByteSize(1, 200));
 
         ByteBuffer byteBuffer = ByteBuffer.allocateDirect(100);
-        byteBuffer.limit(50);
-        byteBuffer.putInt(1);
+        byteBuffer.put(FAKE_READ_PACKET_DATA.getBytes(StandardCharsets.UTF_8));
+        byteBuffer.flip();
 
         serializer.addDataMessage(TrafficObservation.WRITE_FIELD_NUMBER, WriteObservation.DATA_FIELD_NUMBER, REFERENCE_TIMESTAMP, byteBuffer);
         var future = serializer.flushCommitAndResetStream(true);
         future.get();
 
+        Assertions.assertEquals(0, byteBuffer.position());
+
         var outputBuffersList = new ArrayList<>(outputBuffersCreated);
         TrafficStream reconstitutedTrafficStream = TrafficStream.parseFrom(outputBuffersList.get(0));
-        Assertions.assertEquals(1, reconstitutedTrafficStream.getSubStream(0).getWrite().getData().size());
+        Assertions.assertEquals(FAKE_READ_PACKET_DATA, reconstitutedTrafficStream.getSubStream(0).getWrite().getData().toStringUtf8());
+    }
+
+    @Test
+    public void testWriteIsHandledForBufferAllocatedLargerThanWrittenWithChunking()
+        throws IOException, ExecutionException, InterruptedException {
+        var outputBuffersCreated = new ConcurrentLinkedQueue<ByteBuffer>();
+        var serializer = createSerializerWithTestHandler(outputBuffersCreated, getEstimatedTrafficStreamByteSize(1, 4));
+
+        ByteBuffer byteBuffer = ByteBuffer.allocate(100);
+        byteBuffer.put(FAKE_READ_PACKET_DATA.getBytes(StandardCharsets.UTF_8));
+        byteBuffer.flip();
+
+        Assertions.assertEquals(0, byteBuffer.position());
+        Assertions.assertEquals(100, byteBuffer.capacity());
+        Assertions.assertEquals(16, byteBuffer.limit());
+
+        serializer.addDataMessage(TrafficObservation.WRITE_FIELD_NUMBER, WriteObservation.DATA_FIELD_NUMBER, REFERENCE_TIMESTAMP, byteBuffer);
+        var future = serializer.flushCommitAndResetStream(true);
+        future.get();
+
+        Assertions.assertEquals(0, byteBuffer.position());
+
+        List<TrafficObservation> observations = new ArrayList<>();
+        for (ByteBuffer buffer : outputBuffersCreated) {
+            var trafficStream = TrafficStream.parseFrom(buffer);
+            observations.add(trafficStream.getSubStream(0));
+        }
+
+        StringBuilder reconstructedData = new StringBuilder();
+        for (TrafficObservation observation : observations) {
+            var stringChunk = observation.getWriteSegment().getData().toStringUtf8();
+            reconstructedData.append(stringChunk);
+        }
+        Assertions.assertEquals(FAKE_READ_PACKET_DATA, reconstructedData.toString());
     }
 
+    @Test
     public void testWithLimitlessCodedOutputStreamHolder()
         throws IOException, ExecutionException, InterruptedException {
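The new tests assert that byteBuffer.position() is still 0 after addDataMessage returns; below is a minimal standalone sketch (not part of the patch) of the ByteBuffer.duplicate() pattern the serializer change relies on, with a plain string literal standing in for the tests' FAKE_READ_PACKET_DATA constant and a purely illustrative class name:

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

public class DuplicateWriteSketch {
    public static void main(String[] args) {
        ByteBuffer byteBuffer = ByteBuffer.allocate(100);
        byteBuffer.put("sixteen byte str".getBytes(StandardCharsets.UTF_8));
        byteBuffer.flip(); // position = 0, limit = 16, capacity still 100

        // duplicate() shares the content but carries its own position and limit, so draining
        // the view leaves the caller's buffer untouched, which is the property the tests assert.
        ByteBuffer view = byteBuffer.duplicate();
        byte[] chunk = new byte[view.remaining()];
        view.get(chunk);

        System.out.println(byteBuffer.position()); // still 0
        System.out.println(new String(chunk, StandardCharsets.UTF_8));
    }
}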