Skip to content

Commit

Permalink
WebSocketFrameFactory.BulkEncoder: add text frames support
Browse files Browse the repository at this point in the history
  • Loading branch information
mostroverkhov committed Jun 28, 2024
1 parent 98e2cc8 commit d3a489b
Show file tree
Hide file tree
Showing 4 changed files with 226 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,25 @@ void binaryFramesBulkEncoder(boolean mask) throws Exception {
client.close();
}

@Timeout(300)
@ValueSource(booleans = {true, false})
@ParameterizedTest
void textFramesBulkEncoder(boolean mask) throws Exception {
int maxFrameSize = 1000;
Channel s = server = nettyServer(new WebSocketFramesTestServerHandler(), mask, false);
TextFramesEncoderClientBulkHandler clientHandler =
new TextFramesEncoderClientBulkHandler(maxFrameSize, 'a');
Channel client =
webSocketCallbacksClient(s.localAddress(), mask, true, maxFrameSize, clientHandler);

WebSocketFrameFactory.BulkEncoder encoder = clientHandler.onHandshakeCompleted().join();
Assertions.assertThat(encoder).isNotNull();

CompletableFuture<Void> onComplete = clientHandler.startFramesExchange();
onComplete.join();
client.close();
}

@Timeout(300)
@ValueSource(booleans = {true, false})
@ParameterizedTest
Expand Down Expand Up @@ -648,6 +667,159 @@ private void sendFrames(ChannelHandlerContext c, int toSend) {
}
}

static class TextFramesEncoderClientBulkHandler
implements WebSocketCallbacksHandler, WebSocketFrameListener {
private final CompletableFuture<WebSocketFrameFactory.BulkEncoder> onHandshakeComplete =
new CompletableFuture<>();
private final CompletableFuture<Void> onFrameExchangeComplete = new CompletableFuture<>();
private final int framesCount;
private final char expectedAsciiChar;
private WebSocketFrameFactory.BulkEncoder textFrameEncoder;
private int receivedFrames;
private int sentFrames;
private ByteBuf outBuffer;
private volatile ChannelHandlerContext ctx;

TextFramesEncoderClientBulkHandler(int maxFrameSize, char expectedAsciiChar) {
this.framesCount = maxFrameSize;
this.expectedAsciiChar = expectedAsciiChar;
}

@Override
public WebSocketFrameListener exchange(
ChannelHandlerContext ctx, WebSocketFrameFactory webSocketFrameFactory) {
this.textFrameEncoder = webSocketFrameFactory.bulkEncoder();
return this;
}

@Override
public void onChannelRead(
ChannelHandlerContext ctx, boolean finalFragment, int rsv, int opcode, ByteBuf payload) {
if (!finalFragment) {
onFrameExchangeComplete.completeExceptionally(
new AssertionError("received non-final frame: " + finalFragment));
payload.release();
return;
}
if (rsv != 0) {
onFrameExchangeComplete.completeExceptionally(
new AssertionError("received frame with non-zero rsv: " + rsv));
payload.release();
return;
}
if (opcode != WebSocketProtocol.OPCODE_TEXT) {
onFrameExchangeComplete.completeExceptionally(
new AssertionError("received non-text frame: " + Long.toHexString(opcode)));
payload.release();
return;
}

int readableBytes = payload.readableBytes();

int expectedSize = receivedFrames;
if (expectedSize != readableBytes) {
onFrameExchangeComplete.completeExceptionally(
new AssertionError(
"received frame of unexpected size: "
+ expectedSize
+ ", actual: "
+ readableBytes));
payload.release();
return;
}

for (int i = 0; i < readableBytes; i++) {
char ch = (char) payload.readByte();
if (ch != expectedAsciiChar) {
onFrameExchangeComplete.completeExceptionally(
new AssertionError(
"received frame with unexpected content: "
+ ch
+ ", expected: "
+ expectedAsciiChar));
payload.release();
return;
}
}
payload.release();
if (++receivedFrames == framesCount) {
onFrameExchangeComplete.complete(null);
}
}

@Override
public void onOpen(ChannelHandlerContext ctx) {
this.ctx = ctx;
int bufferSize = 4 * framesCount;
this.outBuffer = ctx.alloc().buffer(bufferSize, bufferSize);
onHandshakeComplete.complete(textFrameEncoder);
}

@Override
public void onClose(ChannelHandlerContext ctx) {
ByteBuf out = outBuffer;
if (out != null) {
outBuffer = null;
out.release();
}
if (!onFrameExchangeComplete.isDone()) {
onFrameExchangeComplete.completeExceptionally(new ClosedChannelException());
}
}

@Override
public void onExceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
if (!onFrameExchangeComplete.isDone()) {
onFrameExchangeComplete.completeExceptionally(cause);
}
}

CompletableFuture<WebSocketFrameFactory.BulkEncoder> onHandshakeCompleted() {
return onHandshakeComplete;
}

CompletableFuture<Void> startFramesExchange() {
ChannelHandlerContext c = ctx;
c.executor().execute(() -> sendFrames(c, framesCount - sentFrames));
return onFrameExchangeComplete;
}

private void sendFrames(ChannelHandlerContext c, int toSend) {
WebSocketFrameFactory.BulkEncoder frameEncoder = textFrameEncoder;
for (int frameIdx = 0; frameIdx < toSend; frameIdx++) {
if (!c.channel().isOpen()) {
return;
}
int payloadSize = sentFrames;
int frameSize = frameEncoder.sizeofTextFrame(payloadSize);
ByteBuf out = outBuffer;
if (frameSize > out.capacity() - out.writerIndex()) {
int readableBytes = out.readableBytes();
int bufferSize = 4 * framesCount;
outBuffer = c.alloc().buffer(bufferSize, bufferSize);
if (c.channel().bytesBeforeUnwritable() < readableBytes) {
c.writeAndFlush(out, c.voidPromise());
} else {
c.write(out, c.voidPromise());
}
out = outBuffer;
}
int mask = frameEncoder.encodeTextFramePrefix(out, payloadSize);
for (int payloadIdx = 0; payloadIdx < payloadSize; payloadIdx++) {
out.writeByte(expectedAsciiChar);
}
frameEncoder.maskTextFrame(out, mask, payloadSize);
sentFrames++;
}
ByteBuf out = outBuffer;
if (out.readableBytes() > 0) {
c.writeAndFlush(out, c.voidPromise());
} else {
c.flush();
}
}
}

static class BinaryFramesEncoderClientHandler
implements WebSocketCallbacksHandler, WebSocketFrameListener {
private final CompletableFuture<WebSocketFrameFactory.Encoder> onHandshakeComplete =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,23 +201,42 @@ static ByteBuf encodeDataFrame(ByteBuf binaryFrame, int prefixSmall, int prefixM

@Override
public int encodeBinaryFramePrefix(ByteBuf byteBuf, int payloadSize) {
return encodeDataFramePrefix(byteBuf, payloadSize, BINARY_FRAME_SMALL, BINARY_FRAME_MEDIUM);
}

@Override
public int encodeTextFramePrefix(ByteBuf byteBuf, int textPayloadSize) {
return encodeDataFramePrefix(byteBuf, textPayloadSize, TEXT_FRAME_SMALL, TEXT_FRAME_MEDIUM);
}

static int encodeDataFramePrefix(
ByteBuf byteBuf, int payloadSize, int prefixSmall, int prefixMedium) {
if (payloadSize <= 125) {
byteBuf.writeShort(BINARY_FRAME_SMALL | payloadSize);
byteBuf.writeShort(prefixSmall | payloadSize);
int mask = mask();
byteBuf.writeInt(mask);
return mask;
}

if (payloadSize <= 65_535) {
int mask = mask();
byteBuf.writeLong(((BINARY_FRAME_MEDIUM | (long) payloadSize) << 32) | mask);
byteBuf.writeLong(((prefixMedium | (long) payloadSize) << 32) | mask);
return mask;
}
throw new IllegalArgumentException(payloadSizeLimit(payloadSize, 65_535));
}

@Override
public ByteBuf maskBinaryFrame(ByteBuf byteBuf, int mask, int payloadSize) {
return maskDataFrame(byteBuf, mask, payloadSize);
}

@Override
public ByteBuf maskTextFrame(ByteBuf byteBuf, int mask, int textPayloadSize) {
return maskDataFrame(byteBuf, mask, textPayloadSize);
}

static ByteBuf maskDataFrame(ByteBuf byteBuf, int mask, int payloadSize) {
int end = byteBuf.writerIndex();
int start = end - payloadSize;
return mask(mask, byteBuf, start, end);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,20 @@ static ByteBuf encodeDataFrame(ByteBuf binaryFrame, int prefixSmall, int prefixM

@Override
public int encodeBinaryFramePrefix(ByteBuf byteBuf, int payloadSize) {
return encodeDataFramePrefix(byteBuf, payloadSize, BINARY_FRAME_SMALL, BINARY_FRAME_MEDIUM);
}

@Override
public int encodeTextFramePrefix(ByteBuf byteBuf, int textPayloadSize) {
return encodeDataFramePrefix(byteBuf, textPayloadSize, TEXT_FRAME_SMALL, TEXT_FRAME_MEDIUM);
}

static int encodeDataFramePrefix(
ByteBuf byteBuf, int payloadSize, int prefixSmall, int prefixMedium) {
if (payloadSize <= 125) {
byteBuf.writeShort(BINARY_FRAME_SMALL | payloadSize);
byteBuf.writeShort(prefixSmall | payloadSize);
} else if (payloadSize <= 65_535) {
byteBuf.writeInt(BINARY_FRAME_MEDIUM | payloadSize);
byteBuf.writeInt(prefixMedium | payloadSize);
} else {
throw new IllegalArgumentException(payloadSizeLimit(payloadSize, 65_535));
}
Expand All @@ -191,6 +201,11 @@ public ByteBuf maskBinaryFrame(ByteBuf byteBuf, int mask, int payloadSize) {
return byteBuf;
}

@Override
public ByteBuf maskTextFrame(ByteBuf byteBuf, int mask, int textPayloadSize) {
return byteBuf;
}

@Override
public int sizeofBinaryFrame(int payloadSize) {
return sizeOfDataFrame(payloadSize);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,5 +73,21 @@ interface BulkEncoder {
ByteBuf maskBinaryFrame(ByteBuf byteBuf, int mask, int payloadSize);

int sizeofBinaryFrame(int payloadSize);

/** @return frame mask, or -1 if masking not applicable */
default int encodeTextFramePrefix(ByteBuf byteBuf, int textPayloadSize) {
throw new UnsupportedOperationException(
"WebSocketFrameFactory.BulkEncoder.encodeTextFramePrefix() not implemented");
}

default ByteBuf maskTextFrame(ByteBuf byteBuf, int mask, int textPayloadSize) {
throw new UnsupportedOperationException(
"WebSocketFrameFactory.BulkEncoder.maskTextFrame() not implemented");
}

default int sizeofTextFrame(int textPayloadSize) {
throw new UnsupportedOperationException(
"WebSocketFrameFactory.BulkEncoder.sizeofTextFrame() not implemented");
}
}
}

0 comments on commit d3a489b

Please sign in to comment.