From 3ca72f8ee5aaaa3aebd06791a7a76c60c99a5223 Mon Sep 17 00:00:00 2001 From: franz1981 Date: Wed, 9 Oct 2024 00:00:32 +0200 Subject: [PATCH] Save recursive locks on websocket writeFrame --- .../core/http/impl/ServerWebSocketImpl.java | 35 ++++++++++++------- .../core/http/impl/WebSocketImplBase.java | 26 +++++++++----- 2 files changed, 41 insertions(+), 20 deletions(-) diff --git a/src/main/java/io/vertx/core/http/impl/ServerWebSocketImpl.java b/src/main/java/io/vertx/core/http/impl/ServerWebSocketImpl.java index da50efbd838..cded5f32b2d 100644 --- a/src/main/java/io/vertx/core/http/impl/ServerWebSocketImpl.java +++ b/src/main/java/io/vertx/core/http/impl/ServerWebSocketImpl.java @@ -21,14 +21,15 @@ import io.vertx.core.Future; import io.vertx.core.Handler; import io.vertx.core.Promise; -import io.vertx.codegen.annotations.Nullable; -import io.vertx.core.*; import io.vertx.core.http.ServerWebSocket; import io.vertx.core.http.WebSocketFrame; +import io.vertx.core.http.impl.ws.WebSocketFrameImpl; import io.vertx.core.impl.ContextInternal; import io.vertx.core.net.HostAndPort; import io.vertx.core.spi.metrics.HttpServerMetrics; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; + import static io.netty.handler.codec.http.HttpResponseStatus.*; import static io.vertx.core.http.impl.HttpUtils.*; import static io.vertx.core.spi.metrics.Metrics.*; @@ -44,6 +45,8 @@ */ public class ServerWebSocketImpl extends WebSocketImplBase implements ServerWebSocket { + private static final AtomicReferenceFieldUpdater STATUS_UPDATER = + AtomicReferenceFieldUpdater.newUpdater(ServerWebSocketImpl.class, Integer.class, "status"); private final Http1xServerConnection conn; private final long closingTimeoutMS; private final String scheme; @@ -54,7 +57,7 @@ public class ServerWebSocketImpl extends WebSocketImplBase private final String query; private final WebSocketServerHandshaker handshaker; private Http1xServerRequest request; - private Integer status; + private volatile Integer status; private Promise handshakePromise; ServerWebSocketImpl(ContextInternal context, @@ -158,28 +161,26 @@ public Future close(short statusCode, String reason) { @Override public Future writeFrame(WebSocketFrame frame) { synchronized (conn) { - Boolean check = checkAccept(); + // if lucky, tryHandshake will return true without any need of synchronization lock on connection + final Boolean check = tryHandshake(SC_SWITCHING_PROTOCOLS); if (check == null) { throw new IllegalStateException("Cannot write to WebSocket, it is pending accept or reject"); } if (!check) { throw new IllegalStateException("Cannot write to WebSocket, it has been rejected"); } - return super.writeFrame(frame); + // this is not going through super.writeFrame as we want to avoid synchronizing against the connection + return super.unsafeWriteFrame((WebSocketFrameImpl) frame); } } - private Boolean checkAccept() { - return tryHandshake(SC_SWITCHING_PROTOCOLS); - } - private void handleHandshake(int sc) { synchronized (conn) { if (status == null) { if (sc == SC_SWITCHING_PROTOCOLS) { doHandshake(); } else { - status = sc; + STATUS_UPDATER.lazySet(this, sc); HttpUtils.sendError(conn.channel(), HttpResponseStatus.valueOf(sc)); } } @@ -198,7 +199,7 @@ private void doHandshake() { request = null; } response.completeHandshake(); - status = SWITCHING_PROTOCOLS.code(); + STATUS_UPDATER.lazySet(this, SWITCHING_PROTOCOLS.code()); subProtocol(handshaker.selectedSubprotocol()); // remove compressor as its not needed anymore once connection was upgraded to websockets ChannelPipeline pipeline = channel.pipeline(); @@ -210,9 +211,19 @@ private void doHandshake() { } Boolean tryHandshake(int sc) { + Integer status = this.status; + if (status != null) { + return status == sc; + } synchronized (conn) { - if (status == null && handshakePromise == null) { + assert status == null; + status = this.status; + if (status != null) { + return status == sc; + } + if (handshakePromise == null) { setHandshake(Future.succeededFuture(sc)); + status = this.status; } return status == null ? null : status == sc; } diff --git a/src/main/java/io/vertx/core/http/impl/WebSocketImplBase.java b/src/main/java/io/vertx/core/http/impl/WebSocketImplBase.java index 326cf12a9e0..241991f8373 100644 --- a/src/main/java/io/vertx/core/http/impl/WebSocketImplBase.java +++ b/src/main/java/io/vertx/core/http/impl/WebSocketImplBase.java @@ -399,13 +399,18 @@ private Future writePartialMessage(WebSocketFrameType frameType, Buffer da @Override public Future writeFrame(WebSocketFrame frame) { synchronized (conn) { - if (isClosed()) { - return context.failedFuture("WebSocket is closed"); - } - PromiseInternal promise = context.promise(); - conn.writeToChannel(encodeFrame((WebSocketFrameImpl) frame), promise); - return promise.future(); + return unsafeWriteFrame((WebSocketFrameImpl) frame); + } + } + + protected final Future unsafeWriteFrame(WebSocketFrameImpl frame) { + assert Thread.holdsLock(conn); + if (unsafeIsClosed()) { + return context.failedFuture("WebSocket is closed"); } + PromiseInternal promise = context.promise(); + conn.writeToChannel(encodeFrame(frame), promise); + return promise.future(); } public final S writeFrame(WebSocketFrame frame, Handler> handler) { @@ -424,7 +429,7 @@ private void writeTextFrameInternal(String str) { writeFrame(new WebSocketFrameImpl(str)); } - private io.netty.handler.codec.http.websocketx.WebSocketFrame encodeFrame(WebSocketFrameImpl frame) { + private static io.netty.handler.codec.http.websocketx.WebSocketFrame encodeFrame(WebSocketFrameImpl frame) { ByteBuf buf = safeBuffer(frame.getBinaryData()); switch (frame.type()) { case BINARY: @@ -452,10 +457,15 @@ void checkClosed() { public boolean isClosed() { synchronized (conn) { - return closed || closeStatusCode != null; + return unsafeIsClosed(); } } + private boolean unsafeIsClosed() { + assert Thread.holdsLock(conn); + return closed || closeStatusCode != null; + } + void handleFrame(WebSocketFrameInternal frame) { switch (frame.type()) { case PING: