From d3a489bb7f597b943e8d01e1c9ffaf23ec78a856 Mon Sep 17 00:00:00 2001 From: Maksym Ostroverkhov Date: Fri, 28 Jun 2024 06:55:57 +0300 Subject: [PATCH] WebSocketFrameFactory.BulkEncoder: add text frames support --- .../http/websocketx/WebSocketCodecTest.java | 172 ++++++++++++++++++ .../websocketx/MaskingWebSocketEncoder.java | 23 ++- .../NonMaskingWebSocketEncoder.java | 19 +- .../websocketx/WebSocketFrameFactory.java | 16 ++ 4 files changed, 226 insertions(+), 4 deletions(-) diff --git a/netty-websocket-http1-test/src/test/java/com/jauntsdn/netty/handler/codec/http/websocketx/WebSocketCodecTest.java b/netty-websocket-http1-test/src/test/java/com/jauntsdn/netty/handler/codec/http/websocketx/WebSocketCodecTest.java index 1ab1f2d..d725aac 100644 --- a/netty-websocket-http1-test/src/test/java/com/jauntsdn/netty/handler/codec/http/websocketx/WebSocketCodecTest.java +++ b/netty-websocket-http1-test/src/test/java/com/jauntsdn/netty/handler/codec/http/websocketx/WebSocketCodecTest.java @@ -136,6 +136,25 @@ void binaryFramesBulkEncoder(boolean mask) throws Exception { client.close(); } + @Timeout(300) + @ValueSource(booleans = {true, false}) + @ParameterizedTest + void textFramesBulkEncoder(boolean mask) throws Exception { + int maxFrameSize = 1000; + Channel s = server = nettyServer(new WebSocketFramesTestServerHandler(), mask, false); + TextFramesEncoderClientBulkHandler clientHandler = + new TextFramesEncoderClientBulkHandler(maxFrameSize, 'a'); + Channel client = + webSocketCallbacksClient(s.localAddress(), mask, true, maxFrameSize, clientHandler); + + WebSocketFrameFactory.BulkEncoder encoder = clientHandler.onHandshakeCompleted().join(); + Assertions.assertThat(encoder).isNotNull(); + + CompletableFuture onComplete = clientHandler.startFramesExchange(); + onComplete.join(); + client.close(); + } + @Timeout(300) @ValueSource(booleans = {true, false}) @ParameterizedTest @@ -648,6 +667,159 @@ private void sendFrames(ChannelHandlerContext c, int toSend) { } } + static class TextFramesEncoderClientBulkHandler + implements WebSocketCallbacksHandler, WebSocketFrameListener { + private final CompletableFuture onHandshakeComplete = + new CompletableFuture<>(); + private final CompletableFuture onFrameExchangeComplete = new CompletableFuture<>(); + private final int framesCount; + private final char expectedAsciiChar; + private WebSocketFrameFactory.BulkEncoder textFrameEncoder; + private int receivedFrames; + private int sentFrames; + private ByteBuf outBuffer; + private volatile ChannelHandlerContext ctx; + + TextFramesEncoderClientBulkHandler(int maxFrameSize, char expectedAsciiChar) { + this.framesCount = maxFrameSize; + this.expectedAsciiChar = expectedAsciiChar; + } + + @Override + public WebSocketFrameListener exchange( + ChannelHandlerContext ctx, WebSocketFrameFactory webSocketFrameFactory) { + this.textFrameEncoder = webSocketFrameFactory.bulkEncoder(); + return this; + } + + @Override + public void onChannelRead( + ChannelHandlerContext ctx, boolean finalFragment, int rsv, int opcode, ByteBuf payload) { + if (!finalFragment) { + onFrameExchangeComplete.completeExceptionally( + new AssertionError("received non-final frame: " + finalFragment)); + payload.release(); + return; + } + if (rsv != 0) { + onFrameExchangeComplete.completeExceptionally( + new AssertionError("received frame with non-zero rsv: " + rsv)); + payload.release(); + return; + } + if (opcode != WebSocketProtocol.OPCODE_TEXT) { + onFrameExchangeComplete.completeExceptionally( + new AssertionError("received non-text frame: " + Long.toHexString(opcode))); + payload.release(); + return; + } + + int readableBytes = payload.readableBytes(); + + int expectedSize = receivedFrames; + if (expectedSize != readableBytes) { + onFrameExchangeComplete.completeExceptionally( + new AssertionError( + "received frame of unexpected size: " + + expectedSize + + ", actual: " + + readableBytes)); + payload.release(); + return; + } + + for (int i = 0; i < readableBytes; i++) { + char ch = (char) payload.readByte(); + if (ch != expectedAsciiChar) { + onFrameExchangeComplete.completeExceptionally( + new AssertionError( + "received frame with unexpected content: " + + ch + + ", expected: " + + expectedAsciiChar)); + payload.release(); + return; + } + } + payload.release(); + if (++receivedFrames == framesCount) { + onFrameExchangeComplete.complete(null); + } + } + + @Override + public void onOpen(ChannelHandlerContext ctx) { + this.ctx = ctx; + int bufferSize = 4 * framesCount; + this.outBuffer = ctx.alloc().buffer(bufferSize, bufferSize); + onHandshakeComplete.complete(textFrameEncoder); + } + + @Override + public void onClose(ChannelHandlerContext ctx) { + ByteBuf out = outBuffer; + if (out != null) { + outBuffer = null; + out.release(); + } + if (!onFrameExchangeComplete.isDone()) { + onFrameExchangeComplete.completeExceptionally(new ClosedChannelException()); + } + } + + @Override + public void onExceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + if (!onFrameExchangeComplete.isDone()) { + onFrameExchangeComplete.completeExceptionally(cause); + } + } + + CompletableFuture onHandshakeCompleted() { + return onHandshakeComplete; + } + + CompletableFuture startFramesExchange() { + ChannelHandlerContext c = ctx; + c.executor().execute(() -> sendFrames(c, framesCount - sentFrames)); + return onFrameExchangeComplete; + } + + private void sendFrames(ChannelHandlerContext c, int toSend) { + WebSocketFrameFactory.BulkEncoder frameEncoder = textFrameEncoder; + for (int frameIdx = 0; frameIdx < toSend; frameIdx++) { + if (!c.channel().isOpen()) { + return; + } + int payloadSize = sentFrames; + int frameSize = frameEncoder.sizeofTextFrame(payloadSize); + ByteBuf out = outBuffer; + if (frameSize > out.capacity() - out.writerIndex()) { + int readableBytes = out.readableBytes(); + int bufferSize = 4 * framesCount; + outBuffer = c.alloc().buffer(bufferSize, bufferSize); + if (c.channel().bytesBeforeUnwritable() < readableBytes) { + c.writeAndFlush(out, c.voidPromise()); + } else { + c.write(out, c.voidPromise()); + } + out = outBuffer; + } + int mask = frameEncoder.encodeTextFramePrefix(out, payloadSize); + for (int payloadIdx = 0; payloadIdx < payloadSize; payloadIdx++) { + out.writeByte(expectedAsciiChar); + } + frameEncoder.maskTextFrame(out, mask, payloadSize); + sentFrames++; + } + ByteBuf out = outBuffer; + if (out.readableBytes() > 0) { + c.writeAndFlush(out, c.voidPromise()); + } else { + c.flush(); + } + } + } + static class BinaryFramesEncoderClientHandler implements WebSocketCallbacksHandler, WebSocketFrameListener { private final CompletableFuture onHandshakeComplete = diff --git a/netty-websocket-http1/src/main/java/com/jauntsdn/netty/handler/codec/http/websocketx/MaskingWebSocketEncoder.java b/netty-websocket-http1/src/main/java/com/jauntsdn/netty/handler/codec/http/websocketx/MaskingWebSocketEncoder.java index 86195aa..dde1868 100644 --- a/netty-websocket-http1/src/main/java/com/jauntsdn/netty/handler/codec/http/websocketx/MaskingWebSocketEncoder.java +++ b/netty-websocket-http1/src/main/java/com/jauntsdn/netty/handler/codec/http/websocketx/MaskingWebSocketEncoder.java @@ -201,8 +201,18 @@ static ByteBuf encodeDataFrame(ByteBuf binaryFrame, int prefixSmall, int prefixM @Override public int encodeBinaryFramePrefix(ByteBuf byteBuf, int payloadSize) { + return encodeDataFramePrefix(byteBuf, payloadSize, BINARY_FRAME_SMALL, BINARY_FRAME_MEDIUM); + } + + @Override + public int encodeTextFramePrefix(ByteBuf byteBuf, int textPayloadSize) { + return encodeDataFramePrefix(byteBuf, textPayloadSize, TEXT_FRAME_SMALL, TEXT_FRAME_MEDIUM); + } + + static int encodeDataFramePrefix( + ByteBuf byteBuf, int payloadSize, int prefixSmall, int prefixMedium) { if (payloadSize <= 125) { - byteBuf.writeShort(BINARY_FRAME_SMALL | payloadSize); + byteBuf.writeShort(prefixSmall | payloadSize); int mask = mask(); byteBuf.writeInt(mask); return mask; @@ -210,7 +220,7 @@ public int encodeBinaryFramePrefix(ByteBuf byteBuf, int payloadSize) { if (payloadSize <= 65_535) { int mask = mask(); - byteBuf.writeLong(((BINARY_FRAME_MEDIUM | (long) payloadSize) << 32) | mask); + byteBuf.writeLong(((prefixMedium | (long) payloadSize) << 32) | mask); return mask; } throw new IllegalArgumentException(payloadSizeLimit(payloadSize, 65_535)); @@ -218,6 +228,15 @@ public int encodeBinaryFramePrefix(ByteBuf byteBuf, int payloadSize) { @Override public ByteBuf maskBinaryFrame(ByteBuf byteBuf, int mask, int payloadSize) { + return maskDataFrame(byteBuf, mask, payloadSize); + } + + @Override + public ByteBuf maskTextFrame(ByteBuf byteBuf, int mask, int textPayloadSize) { + return maskDataFrame(byteBuf, mask, textPayloadSize); + } + + static ByteBuf maskDataFrame(ByteBuf byteBuf, int mask, int payloadSize) { int end = byteBuf.writerIndex(); int start = end - payloadSize; return mask(mask, byteBuf, start, end); diff --git a/netty-websocket-http1/src/main/java/com/jauntsdn/netty/handler/codec/http/websocketx/NonMaskingWebSocketEncoder.java b/netty-websocket-http1/src/main/java/com/jauntsdn/netty/handler/codec/http/websocketx/NonMaskingWebSocketEncoder.java index f1b6881..1734904 100644 --- a/netty-websocket-http1/src/main/java/com/jauntsdn/netty/handler/codec/http/websocketx/NonMaskingWebSocketEncoder.java +++ b/netty-websocket-http1/src/main/java/com/jauntsdn/netty/handler/codec/http/websocketx/NonMaskingWebSocketEncoder.java @@ -176,10 +176,20 @@ static ByteBuf encodeDataFrame(ByteBuf binaryFrame, int prefixSmall, int prefixM @Override public int encodeBinaryFramePrefix(ByteBuf byteBuf, int payloadSize) { + return encodeDataFramePrefix(byteBuf, payloadSize, BINARY_FRAME_SMALL, BINARY_FRAME_MEDIUM); + } + + @Override + public int encodeTextFramePrefix(ByteBuf byteBuf, int textPayloadSize) { + return encodeDataFramePrefix(byteBuf, textPayloadSize, TEXT_FRAME_SMALL, TEXT_FRAME_MEDIUM); + } + + static int encodeDataFramePrefix( + ByteBuf byteBuf, int payloadSize, int prefixSmall, int prefixMedium) { if (payloadSize <= 125) { - byteBuf.writeShort(BINARY_FRAME_SMALL | payloadSize); + byteBuf.writeShort(prefixSmall | payloadSize); } else if (payloadSize <= 65_535) { - byteBuf.writeInt(BINARY_FRAME_MEDIUM | payloadSize); + byteBuf.writeInt(prefixMedium | payloadSize); } else { throw new IllegalArgumentException(payloadSizeLimit(payloadSize, 65_535)); } @@ -191,6 +201,11 @@ public ByteBuf maskBinaryFrame(ByteBuf byteBuf, int mask, int payloadSize) { return byteBuf; } + @Override + public ByteBuf maskTextFrame(ByteBuf byteBuf, int mask, int textPayloadSize) { + return byteBuf; + } + @Override public int sizeofBinaryFrame(int payloadSize) { return sizeOfDataFrame(payloadSize); diff --git a/netty-websocket-http1/src/main/java/com/jauntsdn/netty/handler/codec/http/websocketx/WebSocketFrameFactory.java b/netty-websocket-http1/src/main/java/com/jauntsdn/netty/handler/codec/http/websocketx/WebSocketFrameFactory.java index 3fd5eb6..63d19a7 100644 --- a/netty-websocket-http1/src/main/java/com/jauntsdn/netty/handler/codec/http/websocketx/WebSocketFrameFactory.java +++ b/netty-websocket-http1/src/main/java/com/jauntsdn/netty/handler/codec/http/websocketx/WebSocketFrameFactory.java @@ -73,5 +73,21 @@ interface BulkEncoder { ByteBuf maskBinaryFrame(ByteBuf byteBuf, int mask, int payloadSize); int sizeofBinaryFrame(int payloadSize); + + /** @return frame mask, or -1 if masking not applicable */ + default int encodeTextFramePrefix(ByteBuf byteBuf, int textPayloadSize) { + throw new UnsupportedOperationException( + "WebSocketFrameFactory.BulkEncoder.encodeTextFramePrefix() not implemented"); + } + + default ByteBuf maskTextFrame(ByteBuf byteBuf, int mask, int textPayloadSize) { + throw new UnsupportedOperationException( + "WebSocketFrameFactory.BulkEncoder.maskTextFrame() not implemented"); + } + + default int sizeofTextFrame(int textPayloadSize) { + throw new UnsupportedOperationException( + "WebSocketFrameFactory.BulkEncoder.sizeofTextFrame() not implemented"); + } } }