Skip to content

Commit

Permalink
WebSocketServerProtocolHandler, WebSocketClientProtocolHandler: make …
Browse files Browse the repository at this point in the history
…WebSocketCallbacksHandler optional, instead provided via channel handler after websocket handshake completion
  • Loading branch information
mostroverkhov committed Jul 10, 2024
1 parent 5b71566 commit badceaf
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterEach;
Expand Down Expand Up @@ -172,26 +173,6 @@ void smallDecoderConfig() throws Exception {
client.close();
}

@Test
void clientBuilderMissingHandler() {
org.junit.jupiter.api.Assertions.assertThrows(
IllegalStateException.class,
() -> {
WebSocketClientProtocolHandler clientProtocolHandler =
WebSocketClientProtocolHandler.create().build();
});
}

@Test
void serverBuilderMissingHandler() {
org.junit.jupiter.api.Assertions.assertThrows(
IllegalStateException.class,
() -> {
WebSocketServerProtocolHandler serverProtocolHandler =
WebSocketServerProtocolHandler.create().build();
});
}

@Timeout(15)
@Test
void clientTimeout() throws InterruptedException {
Expand Down Expand Up @@ -309,6 +290,74 @@ protected void initChannel(SocketChannel ch) {
Assertions.assertThat(client.isOpen()).isFalse();
}

@Test
void noCallbackHandlerHandshake() throws Exception {
String path = "/test";
NoCallbackServerHandler noCallbackServerHandler = new NoCallbackServerHandler();
NoCallbackClientHandler noCallbackClientHandler = new NoCallbackClientHandler();

Channel s =
server =
new ServerBootstrap()
.group(new NioEventLoopGroup(1))
.channel(NioServerSocketChannel.class)
.childHandler(
new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel ch) {
HttpServerCodec http1Codec = new HttpServerCodec();
HttpObjectAggregator http1Aggregator = new HttpObjectAggregator(65536);
WebSocketServerProtocolHandler webSocketProtocolHandler =
WebSocketServerProtocolHandler.create()
.path(path)
.decoderConfig(webSocketDecoderConfig(true, true, 125))
.build();

ChannelPipeline pipeline = ch.pipeline();
pipeline.addLast(
http1Codec,
http1Aggregator,
webSocketProtocolHandler,
noCallbackServerHandler);
}
})
.bind("localhost", 0)
.sync()
.channel();

Channel client =
new Bootstrap()
.group(new NioEventLoopGroup(1))
.channel(NioSocketChannel.class)
.handler(
new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel ch) {
HttpClientCodec http1Codec = new HttpClientCodec();
HttpObjectAggregator http1Aggregator = new HttpObjectAggregator(65536);
WebSocketClientProtocolHandler webSocketProtocolHandler =
WebSocketClientProtocolHandler.create()
.path(path)
.allowMaskMismatch(true)
.maxFramePayloadLength(125)
.mask(true)
.build();

ChannelPipeline pipeline = ch.pipeline();
pipeline.addLast(
http1Codec,
http1Aggregator,
webSocketProtocolHandler,
noCallbackClientHandler);
}
})
.connect(s.localAddress())
.sync()
.channel();

noCallbackClientHandler.exchangeCompleted.get(5, TimeUnit.SECONDS);
}

@SuppressWarnings("deprecation")
@Timeout(15)
@Test
Expand Down Expand Up @@ -565,6 +614,118 @@ public void onClose(ChannelHandlerContext ctx) {
}
}

private static class NoCallbackClientHandler extends ChannelInboundHandlerAdapter {
Promise<Void> exchangeCompleted;

@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
exchangeCompleted = ctx.newPromise();
super.handlerAdded(ctx);
}

@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
exchangeCompleted.tryFailure(new ClosedChannelException());
}

@SuppressWarnings("Convert2Lambda")
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt
== io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler
.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) {
WebSocketCallbacksHandler.exchange(
ctx,
new WebSocketCallbacksHandler() {
@Override
public WebSocketFrameListener exchange(
ChannelHandlerContext ctx, WebSocketFrameFactory webSocketFrameFactory) {
ctx.writeAndFlush(
webSocketFrameFactory.mask(
webSocketFrameFactory.createBinaryFrame(ctx.alloc(), 1).writeByte(0xFE)));

return new WebSocketFrameListener() {
@Override
public void onChannelRead(
ChannelHandlerContext context,
boolean finalFragment,
int rsv,
int opcode,
ByteBuf payload) {
int readableBytes = payload.readableBytes();
if (readableBytes != 1) {
payload.release();
exchangeCompleted.setFailure(
new IllegalStateException("unexpected payload size: " + readableBytes));
return;
}
byte content = payload.readByte();
if (content != (byte) 0xFE) {
payload.release();
exchangeCompleted.setFailure(
new IllegalStateException(
"unexpected payload content: " + Integer.toHexString(content)));
return;
}
payload.release();
exchangeCompleted.setSuccess(null);
}
};
}
});
}
super.userEventTriggered(ctx, evt);
}

@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
ReferenceCountUtil.safeRelease(msg);
}
}

private static class NoCallbackServerHandler extends ChannelInboundHandlerAdapter {

@SuppressWarnings("Convert2Lambda")
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt
instanceof
io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler.HandshakeComplete) {
WebSocketCallbacksHandler.exchange(
ctx,
new WebSocketCallbacksHandler() {
@Override
public WebSocketFrameListener exchange(
ChannelHandlerContext ctx, WebSocketFrameFactory webSocketFrameFactory) {
return new WebSocketFrameListener() {
@Override
public void onChannelRead(
ChannelHandlerContext context,
boolean finalFragment,
int rsv,
int opcode,
ByteBuf payload) {
ByteBuf binaryFrame =
webSocketFrameFactory.mask(
webSocketFrameFactory.createBinaryFrame(
ctx.alloc(), payload.readableBytes()));
binaryFrame.writeBytes(payload);
payload.release();
ctx.writeAndFlush(binaryFrame);
}
};
}
});
}
super.userEventTriggered(ctx, evt);
}

@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
ReferenceCountUtil.safeRelease(msg);
}
}

static WebSocketDecoderConfig webSocketDecoderConfig(
boolean expectMasked, boolean allowMaskMismatch, int maxFramePayloadLength) {
return WebSocketDecoderConfig.newBuilder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ private WebSocketClientProtocolHandler(
boolean allowMaskMismatch,
int maxFramePayloadLength,
long handshakeTimeoutMillis,
WebSocketCallbacksHandler webSocketHandler) {
@Nullable WebSocketCallbacksHandler webSocketHandler) {
this.address = address;
this.path = path;
this.subprotocol = subprotocol;
Expand Down Expand Up @@ -183,7 +183,10 @@ private void completeHandshake(ChannelHandlerContext ctx, FullHttpResponse respo
cancelHandshakeTimeout();
}
ctx.pipeline().remove(this);
WebSocketCallbacksHandler.exchange(ctx, webSocketHandler);
WebSocketCallbacksHandler handler = webSocketHandler;
if (handler != null) {
WebSocketCallbacksHandler.exchange(ctx, handler);
}
handshakeCompleted.trySuccess();
ctx.fireUserEventTriggered(
io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler
Expand Down Expand Up @@ -343,16 +346,13 @@ public Builder handshakeTimeoutMillis(long handshakeTimeoutMillis) {
* @param webSocketHandler handler to process successfully handshaked webSocket
* @return this Builder instance
*/
public Builder webSocketHandler(WebSocketCallbacksHandler webSocketHandler) {
this.webSocketHandler = Objects.requireNonNull(webSocketHandler, "webSocketHandler");
public Builder webSocketHandler(@Nullable WebSocketCallbacksHandler webSocketHandler) {
this.webSocketHandler = webSocketHandler;
return this;
}

/** @return new WebSocketClientProtocolHandler instance */
public WebSocketClientProtocolHandler build() {
if (webSocketHandler == null) {
throw new IllegalStateException("webSocketHandler was not provided");
}
int maxPayloadLength = maxFramePayloadLength;
boolean maskMismatch = allowMaskMismatch;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ private WebSocketServerProtocolHandler(
String subprotocols,
WebSocketDecoderConfig webSocketDecoderConfig,
long handshakeTimeoutMillis,
WebSocketCallbacksHandler webSocketHandler) {
@Nullable WebSocketCallbacksHandler webSocketHandler) {
this.path = path;
this.subprotocols = subprotocols;
this.decoderConfig = webSocketDecoderConfig;
Expand Down Expand Up @@ -197,7 +197,10 @@ private void handleHandshakeResult(
ctx.close();
}
} else {
WebSocketCallbacksHandler.exchange(ctx, webSocketHandler);
WebSocketCallbacksHandler handler = webSocketHandler;
if (handler != null) {
WebSocketCallbacksHandler.exchange(ctx, handler);
}
handshake.trySuccess();
ChannelPipeline p = ctx.channel().pipeline();
p.fireUserEventTriggered(
Expand Down Expand Up @@ -303,19 +306,15 @@ public Builder handshakeTimeoutMillis(long handshakeTimeoutMillis) {
* @param webSocketHandler handler to process successfully handshaked webSocket
* @return this Builder instance
*/
public Builder webSocketCallbacksHandler(WebSocketCallbacksHandler webSocketHandler) {
this.webSocketCallbacksHandler = Objects.requireNonNull(webSocketHandler, "webSocketHandler");
public Builder webSocketCallbacksHandler(@Nullable WebSocketCallbacksHandler webSocketHandler) {
this.webSocketCallbacksHandler = webSocketHandler;
return this;
}

/** @return new WebSocketServerProtocolHandler instance */
public WebSocketServerProtocolHandler build() {
WebSocketCallbacksHandler handler = webSocketCallbacksHandler;
if (handler == null) {
throw new IllegalStateException("webSocketCallbacksHandler was not provided");
}
return new WebSocketServerProtocolHandler(
path, subprotocols, decoderConfig, handshakeTimeoutMillis, handler);
path, subprotocols, decoderConfig, handshakeTimeoutMillis, webSocketCallbacksHandler);
}

private static long requirePositive(long val, String desc) {
Expand Down

0 comments on commit badceaf

Please sign in to comment.