diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/broadcast/BroadcastOnMessageTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/broadcast/BroadcastOnMessageTest.java index e985cee6dd142..89e918bb0ebb7 100644 --- a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/broadcast/BroadcastOnMessageTest.java +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/broadcast/BroadcastOnMessageTest.java @@ -57,7 +57,7 @@ public void assertBroadcast(URI testUri) throws Exception { WebSocketClient client2 = vertx.createWebSocketClient(); try { CountDownLatch connectedLatch = new CountDownLatch(2); - CountDownLatch onMessageLatch = new CountDownLatch(2); + CountDownLatch messagesLatch = new CountDownLatch(2); AtomicReference ws1 = new AtomicReference<>(); List messages = new CopyOnWriteArrayList<>(); @@ -68,7 +68,7 @@ public void assertBroadcast(URI testUri) throws Exception { WebSocket ws = r.result(); ws.textMessageHandler(msg -> { messages.add(msg); - onMessageLatch.countDown(); + messagesLatch.countDown(); }); // We will use this socket to write a message later on ws1.set(ws); @@ -84,7 +84,7 @@ public void assertBroadcast(URI testUri) throws Exception { WebSocket ws = r.result(); ws.textMessageHandler(msg -> { messages.add(msg); - onMessageLatch.countDown(); + messagesLatch.countDown(); }); connectedLatch.countDown(); } else { @@ -93,8 +93,8 @@ public void assertBroadcast(URI testUri) throws Exception { }); assertTrue(connectedLatch.await(5, TimeUnit.SECONDS)); ws1.get().writeTextMessage("hello"); - assertTrue(onMessageLatch.await(5, TimeUnit.SECONDS)); - assertEquals(2, messages.size()); + assertTrue(messagesLatch.await(5, TimeUnit.SECONDS), "Messages: " + messages); + assertEquals(2, messages.size(), "Messages: " + messages); // Both messages come from the first client assertEquals("1:HELLO", messages.get(0)); assertEquals("1:HELLO", messages.get(1)); diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/broadcast/BroadcastOnOpenTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/broadcast/BroadcastOnOpenTest.java index f69285d26fe0f..cebc41a1e7ced 100644 --- a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/broadcast/BroadcastOnOpenTest.java +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/broadcast/BroadcastOnOpenTest.java @@ -55,7 +55,6 @@ public void assertBroadcast(URI testUri) throws Exception { WebSocketClient client1 = vertx.createWebSocketClient(); WebSocketClient client2 = vertx.createWebSocketClient(); try { - CountDownLatch c1ConnectedLatch = new CountDownLatch(1); CountDownLatch c1MessageLatch = new CountDownLatch(1); CountDownLatch c2MessageLatch = new CountDownLatch(2); List messages = new CopyOnWriteArrayList<>(); @@ -74,18 +73,15 @@ public void assertBroadcast(URI testUri) throws Exception { } }); - c1ConnectedLatch.countDown(); } else { throw new IllegalStateException(r.cause()); } }); - assertTrue(c1ConnectedLatch.await(5, TimeUnit.SECONDS)); assertTrue(c1MessageLatch.await(5, TimeUnit.SECONDS)); assertEquals(1, messages.size()); assertEquals("c1", messages.get(0)); messages.clear(); // Now connect the second client - CountDownLatch c2ConnectedLatch = new CountDownLatch(1); client2 .connect(testUri.getPort(), testUri.getHost(), testUri.getPath() + "/C2") .onComplete(r -> { @@ -95,13 +91,11 @@ public void assertBroadcast(URI testUri) throws Exception { messages.add(msg); c2MessageLatch.countDown(); }); - c2ConnectedLatch.countDown(); } else { throw new IllegalStateException(r.cause()); } }); - assertTrue(c2ConnectedLatch.await(5, TimeUnit.SECONDS)); - assertTrue(c2MessageLatch.await(10, TimeUnit.SECONDS)); + assertTrue(c2MessageLatch.await(5, TimeUnit.SECONDS), "Messages: " + messages); // onOpen should be broadcasted to both clients assertEquals(2, messages.size()); assertEquals("c2", messages.get(0)); diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/broadcast/UpMultiBidi.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/broadcast/UpMultiBidi.java index 607edd6b0a6dd..bc46c76be1312 100644 --- a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/broadcast/UpMultiBidi.java +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/broadcast/UpMultiBidi.java @@ -1,6 +1,5 @@ package io.quarkus.websockets.next.test.broadcast; -import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; import jakarta.inject.Inject; @@ -17,10 +16,11 @@ public class UpMultiBidi { @Inject WebSocketConnection connection; + // Keep in mind that this callback is invoked eagerly immediately after @OnOpen - due to consumed Multi + // That's why we cannot assert the number of open connections inside the callback @OnMessage(broadcast = true) Multi echo(Multi multi) { assertTrue(Context.isOnEventLoopThread()); - assertEquals(2, connection.getOpenConnections().size()); return multi.map(m -> connection.pathParam("client") + ":" + m.toUpperCase()); } diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/ConnectionManager.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/ConnectionManager.java index 4afad77fdc5be..cbcd3824fc650 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/ConnectionManager.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/ConnectionManager.java @@ -7,18 +7,24 @@ import jakarta.annotation.PreDestroy; import jakarta.inject.Singleton; +import org.jboss.logging.Logger; + import io.quarkus.websockets.next.WebSocketConnection; @Singleton public class ConnectionManager { + private static final Logger LOG = Logger.getLogger(ConnectionManager.class); + private final ConcurrentMap> endpointToConnections = new ConcurrentHashMap<>(); void add(String endpoint, WebSocketConnection connection) { + LOG.debugf("Add connection: %s", connection); endpointToConnections.computeIfAbsent(endpoint, e -> ConcurrentHashMap.newKeySet()).add(connection); } void remove(String endpoint, WebSocketConnection connection) { + LOG.debugf("Remove connection: %s", connection); Set connections = endpointToConnections.get(endpoint); if (connections != null) { connections.remove(connection); diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java index 218c848fc3879..417b801b19899 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java @@ -208,11 +208,18 @@ protected Uni sendText(String message, boolean broadcast) { } protected Uni multiText(Multi multi, boolean broadcast, Function> itemFun) { - multi.onFailure().call(connection::close).subscribe().with( - m -> { - itemFun.apply(m).subscribe().with(v -> LOG.debugf("Multi >> text message: %s", connection), - t -> LOG.errorf(t, "Unable to send text message from Multi: %s", connection)); - }); + multi.onFailure() + .call(connection::close) + .subscribe().with( + m -> { + itemFun.apply(m) + .subscribe() + .with(v -> LOG.debugf("Multi >> text message: %s", connection), + t -> LOG.errorf(t, "Unable to send text message from Multi: %s", connection)); + }, + t -> { + LOG.errorf(t, "Unable to send text message from Multi - connection was closed: %s ", connection); + }); return Uni.createFrom().voidItem(); } @@ -221,11 +228,18 @@ protected Uni sendBinary(Buffer message, boolean broadcast) { } protected Uni multiBinary(Multi multi, boolean broadcast, Function> itemFun) { - multi.onFailure().call(connection::close).subscribe().with( - m -> { - itemFun.apply(m).subscribe().with(v -> LOG.debugf("Multi >> binary message: %s", connection), - t -> LOG.errorf(t, "Unable to send binary message from Multi: %s", connection)); - }); + multi.onFailure() + .call(connection::close) + .subscribe().with( + m -> { + itemFun.apply(m) + .subscribe() + .with(v -> LOG.debugf("Multi >> binary message: %s", connection), + t -> LOG.errorf(t, "Unable to send binary message from Multi: %s", connection)); + }, + t -> { + LOG.errorf(t, "Unable to send text message from Multi - connection was closed: %s ", connection); + }); return Uni.createFrom().voidItem(); } }