Skip to content

Commit

Permalink
WebSocketFrameFactory.Encoder: add text frames support
Browse files Browse the repository at this point in the history
  • Loading branch information
mostroverkhov committed Jun 28, 2024
1 parent eca5134 commit 98e2cc8
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,25 @@ void binaryFramesEncoder(boolean mask) throws Exception {
client.close();
}

@Timeout(300)
@ValueSource(booleans = {true, false})
@ParameterizedTest
void textFramesEncoder(boolean mask) throws Exception {
int maxFrameSize = DEFAULT_CODEC_MAX_FRAME_SIZE;
Channel s = server = nettyServer(new WebSocketFramesTestServerHandler(), mask, false);
TextFramesEncoderClientHandler clientHandler =
new TextFramesEncoderClientHandler(maxFrameSize, 'a');
Channel client =
webSocketCallbacksClient(s.localAddress(), mask, true, maxFrameSize, clientHandler);

WebSocketFrameFactory.Encoder encoder = clientHandler.onHandshakeCompleted().join();
Assertions.assertThat(encoder).isNotNull();

CompletableFuture<Void> onComplete = clientHandler.startFramesExchange();
onComplete.join();
client.close();
}

@Timeout(300)
@ValueSource(booleans = {true, false})
@ParameterizedTest
Expand Down Expand Up @@ -778,6 +797,161 @@ private void sendFrames(ChannelHandlerContext c, int toSend) {
}
}

static class TextFramesEncoderClientHandler
implements WebSocketCallbacksHandler, WebSocketFrameListener {
private final CompletableFuture<WebSocketFrameFactory.Encoder> onHandshakeComplete =
new CompletableFuture<>();
private final CompletableFuture<Void> onFrameExchangeComplete = new CompletableFuture<>();
private WebSocketFrameFactory.Encoder textFrameEncoder;
private final int framesCount;
private final char expectedAsciiChar;
private int receivedFrames;
private int sentFrames;
private volatile ChannelHandlerContext ctx;

TextFramesEncoderClientHandler(int maxFrameSize, char expectedAsciiChar) {
this.framesCount = maxFrameSize;
this.expectedAsciiChar = expectedAsciiChar;
}

@Override
public WebSocketFrameListener exchange(
ChannelHandlerContext ctx, WebSocketFrameFactory webSocketFrameFactory) {
this.textFrameEncoder = webSocketFrameFactory.encoder();
return this;
}

@Override
public void onChannelRead(
ChannelHandlerContext ctx, boolean finalFragment, int rsv, int opcode, ByteBuf payload) {
if (!finalFragment) {
onFrameExchangeComplete.completeExceptionally(
new AssertionError("received non-final frame: " + finalFragment));
payload.release();
return;
}
if (rsv != 0) {
onFrameExchangeComplete.completeExceptionally(
new AssertionError("received frame with non-zero rsv: " + rsv));
payload.release();
return;
}
if (opcode != WebSocketProtocol.OPCODE_TEXT) {
onFrameExchangeComplete.completeExceptionally(
new AssertionError("received non-text frame: " + Long.toHexString(opcode)));
payload.release();
return;
}

int readableBytes = payload.readableBytes();

int expectedSize = receivedFrames;
if (expectedSize != readableBytes) {
onFrameExchangeComplete.completeExceptionally(
new AssertionError(
"received frame of unexpected size: "
+ expectedSize
+ ", actual: "
+ readableBytes));
payload.release();
return;
}

for (int i = 0; i < readableBytes; i++) {
char ch = (char) payload.readByte();
if (ch != expectedAsciiChar) {
onFrameExchangeComplete.completeExceptionally(
new AssertionError(
"received frame with unexpected content: "
+ ch
+ ", expected: "
+ expectedAsciiChar));
payload.release();
return;
}
}
payload.release();
if (++receivedFrames == framesCount) {
onFrameExchangeComplete.complete(null);
}
}

@Override
public void onChannelWritabilityChanged(ChannelHandlerContext ctx) {
boolean writable = ctx.channel().isWritable();
if (sentFrames > 0 && writable) {
int toSend = framesCount - sentFrames;
if (toSend > 0) {
sendFrames(ctx, toSend);
}
}
}

@Override
public void onOpen(ChannelHandlerContext ctx) {
this.ctx = ctx;
onHandshakeComplete.complete(textFrameEncoder);
}

@Override
public void onClose(ChannelHandlerContext ctx) {
if (!onFrameExchangeComplete.isDone()) {
onFrameExchangeComplete.completeExceptionally(new ClosedChannelException());
}
}

@Override
public void onExceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
if (!onFrameExchangeComplete.isDone()) {
onFrameExchangeComplete.completeExceptionally(cause);
}
}

CompletableFuture<WebSocketFrameFactory.Encoder> onHandshakeCompleted() {
return onHandshakeComplete;
}

CompletableFuture<Void> startFramesExchange() {
ChannelHandlerContext c = ctx;
c.executor().execute(() -> sendFrames(c, framesCount - sentFrames));
return onFrameExchangeComplete;
}

private void sendFrames(ChannelHandlerContext c, int toSend) {
Channel ch = c.channel();
WebSocketFrameFactory.Encoder frameEncoder = textFrameEncoder;
boolean pendingFlush = false;
ByteBufAllocator allocator = c.alloc();
for (int frameIdx = 0; frameIdx < toSend; frameIdx++) {
if (!c.channel().isOpen()) {
return;
}
int payloadSize = sentFrames;
int frameSize = frameEncoder.sizeofTextFrame(payloadSize);
ByteBuf textFrame = allocator.buffer(frameSize);
textFrame.writerIndex(frameSize - payloadSize);
for (int payloadIdx = 0; payloadIdx < payloadSize; payloadIdx++) {
textFrame.writeByte(expectedAsciiChar);
}
ByteBuf maskedTextFrame = frameEncoder.encodeTextFrame(textFrame);
sentFrames++;
if (ch.bytesBeforeUnwritable() < textFrame.capacity()) {
c.writeAndFlush(maskedTextFrame, c.voidPromise());
pendingFlush = false;
if (!ch.isWritable()) {
return;
}
} else {
c.write(maskedTextFrame, c.voidPromise());
pendingFlush = true;
}
}
if (pendingFlush) {
c.flush();
}
}
}

static class TextFramesFactoryClientHandler
implements WebSocketCallbacksHandler, WebSocketFrameListener {
private final CompletableFuture<WebSocketFrameFactory> onHandshakeComplete =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,20 @@ public BulkEncoder bulkEncoder() {

@Override
public ByteBuf encodeBinaryFrame(ByteBuf binaryFrame) {
return encodeDataFrame(binaryFrame, BINARY_FRAME_SMALL, BINARY_FRAME_MEDIUM);
}

@Override
public ByteBuf encodeTextFrame(ByteBuf textFrame) {
return encodeDataFrame(textFrame, TEXT_FRAME_SMALL, TEXT_FRAME_MEDIUM);
}

static ByteBuf encodeDataFrame(ByteBuf binaryFrame, int prefixSmall, int prefixMedium) {
int frameSize = binaryFrame.readableBytes();
int smallPrefixSize = 6;
if (frameSize <= 125 + smallPrefixSize) {
int payloadSize = frameSize - smallPrefixSize;
binaryFrame.setShort(0, BINARY_FRAME_SMALL | payloadSize);
binaryFrame.setShort(0, prefixSmall | payloadSize);
int mask = mask();
binaryFrame.setInt(2, mask);
return mask(mask, binaryFrame, smallPrefixSize, binaryFrame.writerIndex());
Expand All @@ -183,7 +192,7 @@ public ByteBuf encodeBinaryFrame(ByteBuf binaryFrame) {
if (frameSize <= 65_535 + mediumPrefixSize) {
int payloadSize = frameSize - mediumPrefixSize;
int mask = mask();
binaryFrame.setLong(0, ((BINARY_FRAME_MEDIUM | (long) payloadSize) << 32) | mask);
binaryFrame.setLong(0, ((prefixMedium | (long) payloadSize) << 32) | mask);
return mask(mask, binaryFrame, mediumPrefixSize, binaryFrame.writerIndex());
}
int payloadSize = frameSize - 12;
Expand Down Expand Up @@ -216,6 +225,15 @@ public ByteBuf maskBinaryFrame(ByteBuf byteBuf, int mask, int payloadSize) {

@Override
public int sizeofBinaryFrame(int payloadSize) {
return sizeOfDataFrame(payloadSize);
}

@Override
public int sizeofTextFrame(int textPayloadSize) {
return sizeOfDataFrame(textPayloadSize);
}

static int sizeOfDataFrame(int payloadSize) {
if (payloadSize <= 125) {
return payloadSize + 6;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,17 +149,26 @@ public BulkEncoder bulkEncoder() {

@Override
public ByteBuf encodeBinaryFrame(ByteBuf binaryFrame) {
return encodeDataFrame(binaryFrame, BINARY_FRAME_SMALL, BINARY_FRAME_MEDIUM);
}

@Override
public ByteBuf encodeTextFrame(ByteBuf textFrame) {
return encodeDataFrame(textFrame, TEXT_FRAME_SMALL, TEXT_FRAME_MEDIUM);
}

static ByteBuf encodeDataFrame(ByteBuf binaryFrame, int prefixSmall, int prefixMedium) {
int frameSize = binaryFrame.readableBytes();
int smallPrefixSize = 2;
if (frameSize <= 125 + smallPrefixSize) {
int payloadSize = frameSize - smallPrefixSize;
return binaryFrame.setShort(0, BINARY_FRAME_SMALL | payloadSize);
return binaryFrame.setShort(0, prefixSmall | payloadSize);
}

int mediumPrefixSize = 4;
if (frameSize <= 65_535 + mediumPrefixSize) {
int payloadSize = frameSize - mediumPrefixSize;
return binaryFrame.setInt(0, BINARY_FRAME_MEDIUM | payloadSize);
return binaryFrame.setInt(0, prefixMedium | payloadSize);
}
int payloadSize = frameSize - 8;
throw new IllegalArgumentException(payloadSizeLimit(payloadSize, 65_535));
Expand All @@ -184,6 +193,15 @@ public ByteBuf maskBinaryFrame(ByteBuf byteBuf, int mask, int payloadSize) {

@Override
public int sizeofBinaryFrame(int payloadSize) {
return sizeOfDataFrame(payloadSize);
}

@Override
public int sizeofTextFrame(int textPayloadSize) {
return sizeOfDataFrame(textPayloadSize);
}

static int sizeOfDataFrame(int payloadSize) {
if (payloadSize <= 125) {
return payloadSize + 2;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,16 @@ interface Encoder {
ByteBuf encodeBinaryFrame(ByteBuf binaryFrame);

int sizeofBinaryFrame(int payloadSize);

default ByteBuf encodeTextFrame(ByteBuf textFrame) {
throw new UnsupportedOperationException(
"WebSocketFrameFactory.Encoder.encodeTextFrame() not implemented");
}

default int sizeofTextFrame(int textPayloadSize) {
throw new UnsupportedOperationException(
"WebSocketFrameFactory.Encoder.sizeofTextFrame() not implemented");
}
}

/** Encodes prefixes of multiple binary websocket frames into provided bytebuffer. */
Expand Down

0 comments on commit 98e2cc8

Please sign in to comment.