Skip to content

Commit

Permalink
WebSockets Next: introduce OpenConnections
Browse files Browse the repository at this point in the history
  • Loading branch information
mkouba committed Apr 3, 2024
1 parent e1cc9a2 commit 10933c0
Show file tree
Hide file tree
Showing 9 changed files with 205 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ public void registerRoutes(WebSocketServerRecorder recorder, HttpRootPathBuildIt
.route(httpRootPath.relativePath(endpoint.path))
.displayOnNotFoundPage("WebSocket Endpoint")
.handlerType(HandlerType.NORMAL)
.handler(recorder.createEndpointHandler(endpoint.generatedClassName));
.handler(recorder.createEndpointHandler(endpoint.generatedClassName, endpoint.endpointClassName));
routes.produce(builder.build());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,12 @@ public class ConnectionArgumentTest {
void testArgument() {
String message = "ok";
String header = "fool";
WSClient client = WSClient.create(vertx).connect(new WebSocketConnectOptions().addHeader("X-Test", header),
testUri);
JsonObject reply = client.sendAndAwaitReply(message).toJsonObject();
assertEquals(header, reply.getString("header"), reply.toString());
assertEquals(message, reply.getString("message"), reply.toString());
try (WSClient client = WSClient.create(vertx).connect(new WebSocketConnectOptions().addHeader("X-Test", header),
testUri)) {
JsonObject reply = client.sendAndAwaitReply(message).toJsonObject();
assertEquals(header, reply.getString("header"), reply.toString());
assertEquals(message, reply.getString("message"), reply.toString());
}
}

@WebSocket(path = "/echo")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package io.quarkus.websockets.next.test.openconnections;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.net.URI;
import java.util.Collection;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import jakarta.inject.Inject;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.test.QuarkusUnitTest;
import io.quarkus.test.common.http.TestHTTPResource;
import io.quarkus.websockets.next.OnClose;
import io.quarkus.websockets.next.OnOpen;
import io.quarkus.websockets.next.OpenConnections;
import io.quarkus.websockets.next.WebSocket;
import io.quarkus.websockets.next.WebSocketConnection;
import io.quarkus.websockets.next.test.utils.WSClient;
import io.vertx.core.Vertx;
import io.vertx.core.http.WebSocketConnectOptions;

public class OpenConnectionsTest {

@RegisterExtension
public static final QuarkusUnitTest test = new QuarkusUnitTest()
.withApplicationRoot(root -> {
root.addClasses(Endpoint.class, WSClient.class);
});

@Inject
Vertx vertx;

@TestHTTPResource("endpoint")
URI endUri;

@Inject
OpenConnections connections;

@Test
void testOpenConnections() throws Exception {
String headerName = "X-Test";
String header2 = "foo";
String header3 = "bar";

try (WSClient client1 = WSClient.create(vertx).connect(endUri);
WSClient client2 = WSClient.create(vertx).connect(new WebSocketConnectOptions().addHeader(headerName, header2),
endUri);
WSClient client3 = WSClient.create(vertx).connect(new WebSocketConnectOptions().addHeader(headerName, header3),
endUri)) {

client1.waitForMessages(1);
String client1Id = client1.getMessages().get(0).toString();

client2.waitForMessages(1);
String client2Id = client2.getMessages().get(0).toString();

client3.waitForMessages(1);
String client3Id = client3.getMessages().get(0).toString();

assertNotNull(connections.get(client1Id));
Collection<WebSocketConnection> found = connections.stream()
.filter(c -> header3.equals(c.handshakeRequest().header(headerName)))
.toList();
assertEquals(1, found.size());
assertEquals(client3Id, found.iterator().next().id());

found = connections.get();
assertEquals(3, found.size());
for (WebSocketConnection c : found) {
assertTrue(c.id().equals(client1Id) || c.id().equals(client2Id) || c.id().equals(client3Id));
}

client2.disconnect();
assertTrue(Endpoint.CLOSED_LATCH.await(5, TimeUnit.SECONDS));

assertEquals(2, connections.get().size());
assertNull(connections.stream().filter(c -> c.id().equals(client2Id)).findFirst().orElse(null));

found = connections.stream().filter(
c -> c.endpoint().equals("io.quarkus.websockets.next.test.openconnections.OpenConnectionsTest$Endpoint"))
.toList();
assertEquals(2, found.size());
}
}

@WebSocket(path = "/endpoint")
public static class Endpoint {

static final CountDownLatch CLOSED_LATCH = new CountDownLatch(1);

@OnOpen
String open(WebSocketConnection connection) {
return connection.id();
}

@OnClose
void close() {
CLOSED_LATCH.countDown();
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import io.vertx.core.http.WebSocketClient;
import io.vertx.core.http.WebSocketConnectOptions;

public class WSClient {
public class WSClient implements AutoCloseable {

private final WebSocketClient client;
private AtomicReference<WebSocket> socket = new AtomicReference<>();
Expand Down Expand Up @@ -124,4 +124,10 @@ public Buffer sendAndAwaitReply(String message) {
public boolean isClosed() {
return socket.get().isClosed();
}

@Override
public void close() {
disconnect();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package io.quarkus.websockets.next;

import java.util.Collection;
import java.util.stream.Stream;

import io.smallrye.common.annotation.Experimental;

/**
* Provides convenient access to all open connections from clients to {@link WebSocket} endpoints on the server.
*
* @see WebSocketConnection
*/
@Experimental("This API is experimental and may change in the future")
public interface OpenConnections {

/**
*
* @return an immutable collection of all open connections
*/
Collection<WebSocketConnection> get();

/**
*
* @param id
* @return the open connection or {@code null} if no open connection with the given id exists
* @see WebSocketConnection#id()
*/
WebSocketConnection get(String id);

/**
*
* @return the stream of open connections
*/
Stream<WebSocketConnection> stream();

}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ public interface WebSocketConnection extends Sender, BlockingSender {
*/
String id();

/**
*
* @return the fully qualified class name of the {@link WebSocket} endpoint
*/
String endpoint();

/**
*
* @param name
Expand Down
Original file line number Diff line number Diff line change
@@ -1,27 +1,46 @@
package io.quarkus.websockets.next.runtime;

import java.util.Collection;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import jakarta.annotation.PreDestroy;
import jakarta.inject.Singleton;

import org.jboss.logging.Logger;

import io.quarkus.websockets.next.OpenConnections;
import io.quarkus.websockets.next.WebSocketConnection;

@Singleton
public class ConnectionManager {
public class ConnectionManager implements OpenConnections {

private static final Logger LOG = Logger.getLogger(ConnectionManager.class);

private final ConcurrentMap<String, Set<WebSocketConnection>> endpointToConnections = new ConcurrentHashMap<>();

private final List<ConnectionListener> listeners = new CopyOnWriteArrayList<>();

@Override
public WebSocketConnection get(String id) {
return stream().filter(c -> c.id().equals(id)).findFirst().orElse(null);
}

@Override
public Collection<WebSocketConnection> get() {
return stream().collect(Collectors.toUnmodifiableList());
}

@Override
public Stream<WebSocketConnection> stream() {
return endpointToConnections.values().stream().flatMap(Set::stream).filter(WebSocketConnection::isOpen);
}

void add(String endpoint, WebSocketConnection connection) {
LOG.debugf("Add connection: %s", connection);
if (endpointToConnections.computeIfAbsent(endpoint, e -> ConcurrentHashMap.newKeySet()).add(connection)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@

class WebSocketConnectionImpl implements WebSocketConnection {

private final String endpoint;
private final String generatedEndpointClass;

private final String endpointClass;

private final String identifier;

Expand All @@ -44,9 +46,11 @@ class WebSocketConnectionImpl implements WebSocketConnection {

private final Instant creationTime;

WebSocketConnectionImpl(String endpoint, ServerWebSocket webSocket, ConnectionManager connectionManager,
WebSocketConnectionImpl(String generatedEndpointClass, String endpointClass, ServerWebSocket webSocket,
ConnectionManager connectionManager,
Codecs codecs, RoutingContext ctx) {
this.endpoint = endpoint;
this.generatedEndpointClass = generatedEndpointClass;
this.endpointClass = endpointClass;
this.identifier = UUID.randomUUID().toString();
this.webSocket = Objects.requireNonNull(webSocket);
this.connectionManager = Objects.requireNonNull(connectionManager);
Expand All @@ -62,6 +66,11 @@ public String id() {
return identifier;
}

@Override
public String endpoint() {
return endpointClass;
}

@Override
public String pathParam(String name) {
return pathParams.get(name);
Expand Down Expand Up @@ -124,7 +133,7 @@ public boolean isClosed() {

@Override
public Set<WebSocketConnection> getOpenConnections() {
return connectionManager.getConnections(endpoint).stream().filter(WebSocketConnection::isOpen)
return connectionManager.getConnections(generatedEndpointClass).stream().filter(WebSocketConnection::isOpen)
.collect(Collectors.toUnmodifiableSet());
}

Expand Down Expand Up @@ -292,7 +301,7 @@ public Uni<Void> sendPong(Buffer data) {
}

private <M> Uni<Void> doSend(BiFunction<WebSocketConnection, M, Uni<Void>> function, M message) {
Set<WebSocketConnection> connections = connectionManager.getConnections(endpoint);
Set<WebSocketConnection> connections = connectionManager.getConnections(generatedEndpointClass);
if (connections.isEmpty()) {
return Uni.createFrom().voidItem();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public Object get() {
};
}

public Handler<RoutingContext> createEndpointHandler(String endpointClass) {
public Handler<RoutingContext> createEndpointHandler(String generatedEndpointClass, String endpointClass) {
ArcContainer container = Arc.container();
ConnectionManager connectionManager = container.instance(ConnectionManager.class).get();
Codecs codecs = container.instance(Codecs.class).get();
Expand All @@ -69,9 +69,9 @@ public void handle(RoutingContext ctx) {
future.onSuccess(ws -> {
Context context = VertxCoreRecorder.getVertx().get().getOrCreateContext();

WebSocketConnection connection = new WebSocketConnectionImpl(endpointClass, ws,
WebSocketConnection connection = new WebSocketConnectionImpl(generatedEndpointClass, endpointClass, ws,
connectionManager, codecs, ctx);
connectionManager.add(endpointClass, connection);
connectionManager.add(generatedEndpointClass, connection);
LOG.debugf("Connnected: %s", connection);

// Initialize and capture the session context state that will be activated
Expand All @@ -83,7 +83,7 @@ public void handle(RoutingContext ctx) {
container.requestContext());

// Create an endpoint that delegates callbacks to the @WebSocket bean
WebSocketEndpoint endpoint = createEndpoint(endpointClass, context, connection, codecs, config,
WebSocketEndpoint endpoint = createEndpoint(generatedEndpointClass, context, connection, codecs, config,
contextSupport);

// A broadcast processor is only needed if Multi is consumed by the callback
Expand Down Expand Up @@ -214,7 +214,7 @@ public void handle(Void event) {
} else {
LOG.errorf(r.cause(), "Unable to complete @OnClose callback: %s", connection);
}
connectionManager.remove(endpointClass, connection);
connectionManager.remove(generatedEndpointClass, connection);
});
}
});
Expand Down

0 comments on commit 10933c0

Please sign in to comment.