Skip to content

Commit

Permalink
WebSocketFrameFactory: add text frames support
Browse files Browse the repository at this point in the history
  • Loading branch information
mostroverkhov committed Jun 28, 2024
1 parent 712b0b3 commit eca5134
Show file tree
Hide file tree
Showing 4 changed files with 218 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ void tearDown() {
@ParameterizedTest
void binaryFramesEncoder(boolean mask) throws Exception {
int maxFrameSize = DEFAULT_CODEC_MAX_FRAME_SIZE;
Channel s = server = nettyServer(new BinaryFramesTestServerHandler(), mask, false);
Channel s = server = nettyServer(new WebSocketFramesTestServerHandler(), mask, false);
BinaryFramesEncoderClientHandler clientHandler =
new BinaryFramesEncoderClientHandler(maxFrameSize);
Channel client =
Expand All @@ -103,7 +103,7 @@ void binaryFramesEncoder(boolean mask) throws Exception {
@ParameterizedTest
void binaryFramesBulkEncoder(boolean mask) throws Exception {
int maxFrameSize = 1000;
Channel s = server = nettyServer(new BinaryFramesTestServerHandler(), mask, false);
Channel s = server = nettyServer(new WebSocketFramesTestServerHandler(), mask, false);
BinaryFramesEncoderClientBulkHandler clientHandler =
new BinaryFramesEncoderClientBulkHandler(maxFrameSize);
Channel client =
Expand All @@ -117,14 +117,33 @@ void binaryFramesBulkEncoder(boolean mask) throws Exception {
client.close();
}

@Timeout(300)
@ValueSource(booleans = {true, false})
@ParameterizedTest
void textFramesFactory(boolean mask) throws Exception {
int maxFrameSize = DEFAULT_CODEC_MAX_FRAME_SIZE;
Channel s = server = nettyServer(new WebSocketFramesTestServerHandler(), mask, false);
TextFramesFactoryClientHandler clientHandler =
new TextFramesFactoryClientHandler(maxFrameSize, 'a');
Channel client =
webSocketCallbacksClient(s.localAddress(), mask, true, maxFrameSize, clientHandler);

WebSocketFrameFactory frameFactory = clientHandler.onHandshakeCompleted().join();
Assertions.assertThat(frameFactory).isNotNull();

CompletableFuture<Void> onComplete = clientHandler.startFramesExchange();
onComplete.join();
client.close();
}

@Timeout(300)
@MethodSource("maskingArgs")
@ParameterizedTest
void allSizeBinaryFramesDefaultDecoder(
boolean mask, Class<?> webSocketFrameFactoryType, Class<ChannelHandler> webSocketDecoderType)
throws Exception {
int maxFrameSize = DEFAULT_CODEC_MAX_FRAME_SIZE;
Channel s = server = nettyServer(new BinaryFramesTestServerHandler(), mask, false);
Channel s = server = nettyServer(new WebSocketFramesTestServerHandler(), mask, false);
BinaryFramesTestClientHandler clientHandler = new BinaryFramesTestClientHandler(maxFrameSize);
Channel client =
webSocketCallbacksClient(s.localAddress(), mask, true, maxFrameSize, clientHandler);
Expand All @@ -142,7 +161,7 @@ void allSizeBinaryFramesDefaultDecoder(
@Test
void binaryFramesSmallDecoder() throws Exception {
int maxFrameSize = SMALL_CODEC_MAX_FRAME_SIZE;
Channel s = server = nettyServer(new BinaryFramesTestServerHandler(), false, false);
Channel s = server = nettyServer(new WebSocketFramesTestServerHandler(), false, false);
BinaryFramesTestClientHandler clientHandler = new BinaryFramesTestClientHandler(maxFrameSize);
Channel client =
webSocketCallbacksClient(s.localAddress(), false, false, maxFrameSize, clientHandler);
Expand Down Expand Up @@ -450,7 +469,7 @@ protected void initChannel(SocketChannel ch) {
WebSocketDecoderConfig.newBuilder()
.expectMaskedFrames(expectMaskedFrames)
.allowMaskMismatch(allowMaskMismatch)
.withUTF8Validator(false)
.withUTF8Validator(true)
.allowExtensions(false)
.maxFramePayloadLength(65535)
.build();
Expand Down Expand Up @@ -759,6 +778,159 @@ private void sendFrames(ChannelHandlerContext c, int toSend) {
}
}

static class TextFramesFactoryClientHandler
implements WebSocketCallbacksHandler, WebSocketFrameListener {
private final CompletableFuture<WebSocketFrameFactory> onHandshakeComplete =
new CompletableFuture<>();
private final CompletableFuture<Void> onFrameExchangeComplete = new CompletableFuture<>();
private WebSocketFrameFactory frameFactory;
private final int framesCount;
private final char expectedAsciiChar;
private int receivedFrames;
private int sentFrames;
private volatile ChannelHandlerContext ctx;

TextFramesFactoryClientHandler(int maxFrameSize, char expectedAsciiChar) {
this.framesCount = maxFrameSize;
this.expectedAsciiChar = expectedAsciiChar;
}

@Override
public WebSocketFrameListener exchange(
ChannelHandlerContext ctx, WebSocketFrameFactory webSocketFrameFactory) {
this.frameFactory = webSocketFrameFactory;
return this;
}

@Override
public void onChannelRead(
ChannelHandlerContext ctx, boolean finalFragment, int rsv, int opcode, ByteBuf payload) {
if (!finalFragment) {
onFrameExchangeComplete.completeExceptionally(
new AssertionError("received non-final frame: " + finalFragment));
payload.release();
return;
}
if (rsv != 0) {
onFrameExchangeComplete.completeExceptionally(
new AssertionError("received frame with non-zero rsv: " + rsv));
payload.release();
return;
}
if (opcode != WebSocketProtocol.OPCODE_TEXT) {
onFrameExchangeComplete.completeExceptionally(
new AssertionError("received non-text frame: " + Long.toHexString(opcode)));
payload.release();
return;
}

int readableBytes = payload.readableBytes();

int expectedSize = receivedFrames;
if (expectedSize != readableBytes) {
onFrameExchangeComplete.completeExceptionally(
new AssertionError(
"received frame of unexpected size: "
+ expectedSize
+ ", actual: "
+ readableBytes));
payload.release();
return;
}

for (int i = 0; i < readableBytes; i++) {
char ch = (char) payload.readByte();
if (ch != expectedAsciiChar) {
onFrameExchangeComplete.completeExceptionally(
new AssertionError(
"received frame with unexpected content: "
+ ch
+ ", expected: "
+ expectedAsciiChar));
payload.release();
return;
}
}
payload.release();
if (++receivedFrames == framesCount) {
onFrameExchangeComplete.complete(null);
}
}

@Override
public void onChannelWritabilityChanged(ChannelHandlerContext ctx) {
boolean writable = ctx.channel().isWritable();
if (sentFrames > 0 && writable) {
int toSend = framesCount - sentFrames;
if (toSend > 0) {
sendFrames(ctx, toSend);
}
}
}

@Override
public void onOpen(ChannelHandlerContext ctx) {
this.ctx = ctx;
onHandshakeComplete.complete(frameFactory);
}

@Override
public void onClose(ChannelHandlerContext ctx) {
if (!onFrameExchangeComplete.isDone()) {
onFrameExchangeComplete.completeExceptionally(new ClosedChannelException());
}
}

@Override
public void onExceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
if (!onFrameExchangeComplete.isDone()) {
onFrameExchangeComplete.completeExceptionally(cause);
}
}

CompletableFuture<WebSocketFrameFactory> onHandshakeCompleted() {
return onHandshakeComplete;
}

CompletableFuture<Void> startFramesExchange() {
ChannelHandlerContext c = ctx;
c.executor().execute(() -> sendFrames(c, framesCount - sentFrames));
return onFrameExchangeComplete;
}

private void sendFrames(ChannelHandlerContext c, int toSend) {
Channel ch = c.channel();
WebSocketFrameFactory factory = frameFactory;
boolean pendingFlush = false;
ByteBufAllocator allocator = c.alloc();
for (int frameIdx = 0; frameIdx < toSend; frameIdx++) {
if (!c.channel().isOpen()) {
return;
}
int payloadSize = sentFrames;
ByteBuf textFrame = factory.createTextFrame(allocator, payloadSize);
for (int payloadIdx = 0; payloadIdx < payloadSize; payloadIdx++) {
textFrame.writeByte(expectedAsciiChar);
}
ByteBuf maskedTextFrame = factory.mask(textFrame);
sentFrames++;
if (ch.bytesBeforeUnwritable() < textFrame.capacity()) {
c.writeAndFlush(maskedTextFrame, c.voidPromise());
pendingFlush = false;
if (!ch.isWritable()) {
return;
}
} else {
c.write(maskedTextFrame, c.voidPromise());
pendingFlush = true;
}
}
if (pendingFlush) {
c.flush();
}
}
}

static class BinaryFramesTestClientHandler
implements WebSocketCallbacksHandler, WebSocketFrameListener {
private final CompletableFuture<WebSocketFrameFactory> onHandshakeComplete =
Expand Down Expand Up @@ -1186,7 +1358,7 @@ private void sendFrames(ChannelHandlerContext c, int toSend) {
}
}

static class BinaryFramesTestServerHandler extends ChannelInboundHandlerAdapter {
static class WebSocketFramesTestServerHandler extends ChannelInboundHandlerAdapter {
boolean ready = true;
boolean pendingFlush;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import static com.jauntsdn.netty.handler.codec.http.websocketx.WebSocketProtocol.OPCODE_CLOSE;
import static com.jauntsdn.netty.handler.codec.http.websocketx.WebSocketProtocol.OPCODE_PING;
import static com.jauntsdn.netty.handler.codec.http.websocketx.WebSocketProtocol.OPCODE_PONG;
import static com.jauntsdn.netty.handler.codec.http.websocketx.WebSocketProtocol.OPCODE_TEXT;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
Expand Down Expand Up @@ -55,6 +56,8 @@ static class FrameFactory
static final int PREFIX_SIZE_SMALL = 6;
static final int BINARY_FRAME_SMALL =
OPCODE_BINARY << 8 | /*FIN*/ (byte) 1 << 15 | /*MASK*/ (byte) 1 << 7;
static final int TEXT_FRAME_SMALL =
OPCODE_TEXT << 8 | /*FIN*/ (byte) 1 << 15 | /*MASK*/ (byte) 1 << 7;

static final int CLOSE_FRAME =
OPCODE_CLOSE << 8 | /*FIN*/ (byte) 1 << 15 | /*MASK*/ (byte) 1 << 7;
Expand All @@ -65,27 +68,38 @@ static class FrameFactory

static final int PREFIX_SIZE_MEDIUM = 8;
static final int BINARY_FRAME_MEDIUM = (BINARY_FRAME_SMALL | /*LEN*/ (byte) 126) << 16;
static final int TEXT_FRAME_MEDIUM = (TEXT_FRAME_SMALL | /*LEN*/ (byte) 126) << 16;

static final WebSocketFrameFactory INSTANCE = new FrameFactory();

@Override
public ByteBuf createBinaryFrame(ByteBufAllocator allocator, int payloadSize) {
static ByteBuf createDataFrame(
ByteBufAllocator allocator, int payloadSize, int prefixSmall, int prefixMedium) {
if (payloadSize <= 125) {
return allocator
.buffer(PREFIX_SIZE_SMALL + payloadSize)
.writeShort(BINARY_FRAME_SMALL | payloadSize)
.writeShort(prefixSmall | payloadSize)
.readerIndex(2)
.writeInt(mask());
} else if (payloadSize <= 65_535) {
return allocator
.buffer(PREFIX_SIZE_MEDIUM + payloadSize)
.writeLong((long) (BINARY_FRAME_MEDIUM | payloadSize) << 32 | mask())
.writeLong((long) (prefixMedium | payloadSize) << 32 | mask())
.readerIndex(4);
} else {
throw new IllegalArgumentException(payloadSizeLimit(payloadSize, 65_535));
}
}

@Override
public ByteBuf createBinaryFrame(ByteBufAllocator allocator, int payloadSize) {
return createDataFrame(allocator, payloadSize, BINARY_FRAME_SMALL, BINARY_FRAME_MEDIUM);
}

@Override
public ByteBuf createTextFrame(ByteBufAllocator allocator, int payloadSize) {
return createDataFrame(allocator, payloadSize, TEXT_FRAME_SMALL, TEXT_FRAME_MEDIUM);
}

@Override
public ByteBuf createCloseFrame(ByteBufAllocator allocator, int statusCode, String reason) {
if (!WebSocketCloseStatus.isValidStatusCode(statusCode)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import static com.jauntsdn.netty.handler.codec.http.websocketx.WebSocketProtocol.OPCODE_CLOSE;
import static com.jauntsdn.netty.handler.codec.http.websocketx.WebSocketProtocol.OPCODE_PING;
import static com.jauntsdn.netty.handler.codec.http.websocketx.WebSocketProtocol.OPCODE_PONG;
import static com.jauntsdn.netty.handler.codec.http.websocketx.WebSocketProtocol.OPCODE_TEXT;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
Expand Down Expand Up @@ -53,32 +54,44 @@ static class FrameFactory
WebSocketFrameFactory.BulkEncoder {
static final int PREFIX_SIZE_SMALL = 2;
static final int BINARY_FRAME_SMALL = OPCODE_BINARY << 8 | /*FIN*/ (byte) 1 << 15;
static final int TEXT_FRAME_SMALL = OPCODE_TEXT << 8 | /*FIN*/ (byte) 1 << 15;

static final int CLOSE_FRAME = OPCODE_CLOSE << 8 | /*FIN*/ (byte) 1 << 15;
static final int PING_FRAME = OPCODE_PING << 8 | /*FIN*/ (byte) 1 << 15;
static final int PONG_FRAME = OPCODE_PONG << 8 | /*FIN*/ (byte) 1 << 15;

static final int PREFIX_SIZE_MEDIUM = 4;
static final int BINARY_FRAME_MEDIUM = (BINARY_FRAME_SMALL | /*LEN*/ (byte) 126) << 16;
static final int TEXT_FRAME_MEDIUM = (TEXT_FRAME_SMALL | /*LEN*/ (byte) 126) << 16;

static final WebSocketFrameFactory INSTANCE = new FrameFactory();

@Override
public ByteBuf createBinaryFrame(ByteBufAllocator allocator, int payloadSize) {
static ByteBuf createDataFrame(
ByteBufAllocator allocator, int payloadSize, int prefixSmall, int prefixMedium) {
if (payloadSize <= 125) {
return allocator
.buffer(PREFIX_SIZE_SMALL + payloadSize)
.writeShort(BINARY_FRAME_SMALL | payloadSize);
.writeShort(prefixSmall | payloadSize);
}

if (payloadSize <= 65_535) {
return allocator
.buffer(PREFIX_SIZE_MEDIUM + payloadSize)
.writeInt(BINARY_FRAME_MEDIUM | payloadSize);
.writeInt(prefixMedium | payloadSize);
}
throw new IllegalArgumentException(payloadSizeLimit(payloadSize, 65_535));
}

@Override
public ByteBuf createBinaryFrame(ByteBufAllocator allocator, int payloadSize) {
return createDataFrame(allocator, payloadSize, BINARY_FRAME_SMALL, BINARY_FRAME_MEDIUM);
}

@Override
public ByteBuf createTextFrame(ByteBufAllocator allocator, int textDataSize) {
return createDataFrame(allocator, textDataSize, TEXT_FRAME_SMALL, TEXT_FRAME_MEDIUM);
}

@Override
public ByteBuf createCloseFrame(ByteBufAllocator allocator, int statusCode, String reason) {
if (!WebSocketCloseStatus.isValidStatusCode(statusCode)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ public interface WebSocketFrameFactory {

ByteBuf createBinaryFrame(ByteBufAllocator allocator, int binaryDataSize);

default ByteBuf createTextFrame(ByteBufAllocator allocator, int textDataSize) {
throw new UnsupportedOperationException(
"WebSocketFrameFactory.createTextFrame() not implemented");
}

ByteBuf createCloseFrame(ByteBufAllocator allocator, int statusCode, String reason);

ByteBuf createPingFrame(ByteBufAllocator allocator, int binaryDataSize);
Expand Down

0 comments on commit eca5134

Please sign in to comment.