diff --git a/bom/application/pom.xml b/bom/application/pom.xml index ba6824f7c91f6..a39fde04b1b1f 100644 --- a/bom/application/pom.xml +++ b/bom/application/pom.xml @@ -2075,6 +2075,16 @@ quarkus-websockets-client-deployment ${project.version} + + io.quarkus + quarkus-websockets-next + ${project.version} + + + io.quarkus + quarkus-websockets-next-deployment + ${project.version} + io.quarkus quarkus-undertow-spi diff --git a/devtools/bom-descriptor-json/pom.xml b/devtools/bom-descriptor-json/pom.xml index bdc9ab6b66b46..dcc5e857a1f9c 100644 --- a/devtools/bom-descriptor-json/pom.xml +++ b/devtools/bom-descriptor-json/pom.xml @@ -2891,6 +2891,19 @@ + + io.quarkus + quarkus-websockets-next + ${project.version} + pom + test + + + * + * + + + diff --git a/docs/pom.xml b/docs/pom.xml index b4ee3031cba71..a3b12150163da 100644 --- a/docs/pom.xml +++ b/docs/pom.xml @@ -2907,6 +2907,19 @@ + + io.quarkus + quarkus-websockets-next-deployment + ${project.version} + pom + test + + + * + * + + + diff --git a/extensions/pom.xml b/extensions/pom.xml index a9ae0043e8504..5dd8c802a2f68 100644 --- a/extensions/pom.xml +++ b/extensions/pom.xml @@ -38,6 +38,7 @@ vertx-http undertow websockets + websockets-next webjars-locator resteasy-reactive reactive-routes diff --git a/extensions/websockets-next/pom.xml b/extensions/websockets-next/pom.xml new file mode 100644 index 0000000000000..6eb2495866185 --- /dev/null +++ b/extensions/websockets-next/pom.xml @@ -0,0 +1,20 @@ + + + + quarkus-extensions-parent + io.quarkus + 999-SNAPSHOT + ../pom.xml + + 4.0.0 + + quarkus-websockets-next-aggregator + Quarkus - WebSockets Next Aggregator + pom + + server + + + diff --git a/extensions/websockets-next/server/deployment/pom.xml b/extensions/websockets-next/server/deployment/pom.xml new file mode 100644 index 0000000000000..f2405b4fc35c9 --- /dev/null +++ b/extensions/websockets-next/server/deployment/pom.xml @@ -0,0 +1,72 @@ + + + + quarkus-websockets-next-parent + io.quarkus + 999-SNAPSHOT + + 4.0.0 + + quarkus-websockets-next-deployment + Quarkus - WebSockets Next - Deployment + + + + io.quarkus + quarkus-core-deployment + + + io.quarkus + quarkus-vertx-http-deployment + + + io.quarkus + quarkus-jackson-deployment + + + io.quarkus + quarkus-websockets-next + + + io.quarkus + quarkus-junit5-internal + test + + + org.assertj + assertj-core + test + + + + + + + maven-compiler-plugin + + + + io.quarkus + quarkus-extension-processor + ${project.version} + + + + + + maven-surefire-plugin + + + + org.jboss.logmanager.LogManager + INFO + ${maven.home} + + + + + + + diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/GeneratedEndpointBuildItem.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/GeneratedEndpointBuildItem.java new file mode 100644 index 0000000000000..7e4d72cc1b6d1 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/GeneratedEndpointBuildItem.java @@ -0,0 +1,18 @@ +package io.quarkus.websockets.next.deployment; + +import io.quarkus.builder.item.MultiBuildItem; + +/** + * A generated {@link io.quarkus.websockets.next.runtime.WebSocketEndpoint}. + */ +final class GeneratedEndpointBuildItem extends MultiBuildItem { + + final String className; + final String path; + + GeneratedEndpointBuildItem(String className, String path) { + this.className = className; + this.path = path; + } + +} diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketDotNames.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketDotNames.java new file mode 100644 index 0000000000000..d3158709bec12 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketDotNames.java @@ -0,0 +1,38 @@ +package io.quarkus.websockets.next.deployment; + +import org.jboss.jandex.DotName; + +import io.quarkus.websockets.next.BinaryMessage; +import io.quarkus.websockets.next.OnClose; +import io.quarkus.websockets.next.OnMessage; +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.TextMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketServerConnection; +import io.smallrye.common.annotation.Blocking; +import io.smallrye.common.annotation.RunOnVirtualThread; +import io.smallrye.mutiny.Multi; +import io.smallrye.mutiny.Uni; +import io.vertx.core.buffer.Buffer; +import io.vertx.core.json.JsonArray; +import io.vertx.core.json.JsonObject; + +final class WebSocketDotNames { + + static final DotName WEB_SOCKET = DotName.createSimple(WebSocket.class); + static final DotName WEB_SOCKET_CONNECTION = DotName.createSimple(WebSocketServerConnection.class); + static final DotName ON_OPEN = DotName.createSimple(OnOpen.class); + static final DotName ON_MESSAGE = DotName.createSimple(OnMessage.class); + static final DotName ON_CLOSE = DotName.createSimple(OnClose.class); + static final DotName UNI = DotName.createSimple(Uni.class); + static final DotName MULTI = DotName.createSimple(Multi.class); + static final DotName RUN_ON_VIRTUAL_THREAD = DotName.createSimple(RunOnVirtualThread.class); + static final DotName BLOCKING = DotName.createSimple(Blocking.class); + static final DotName STRING = DotName.createSimple(String.class); + static final DotName BUFFER = DotName.createSimple(Buffer.class); + static final DotName JSON_OBJECT = DotName.createSimple(JsonObject.class); + static final DotName JSON_ARRAY = DotName.createSimple(JsonArray.class); + static final DotName VOID = DotName.createSimple(Void.class); + static final DotName BINARY_MESSAGE = DotName.createSimple(BinaryMessage.class); + static final DotName TEXT_MESSAGE = DotName.createSimple(TextMessage.class); +} diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketEndpointBuildItem.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketEndpointBuildItem.java new file mode 100644 index 0000000000000..1c1fcb112f22c --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketEndpointBuildItem.java @@ -0,0 +1,153 @@ +package io.quarkus.websockets.next.deployment; + +import org.jboss.jandex.AnnotationInstance; +import org.jboss.jandex.AnnotationValue; +import org.jboss.jandex.DotName; +import org.jboss.jandex.MethodInfo; +import org.jboss.jandex.PrimitiveType; +import org.jboss.jandex.Type; +import org.jboss.jandex.Type.Kind; + +import io.quarkus.arc.processor.BeanInfo; +import io.quarkus.builder.item.MultiBuildItem; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.runtime.WebSocketEndpoint; +import io.quarkus.websockets.next.runtime.WebSocketEndpoint.ExecutionModel; +import io.quarkus.websockets.next.runtime.WebSocketEndpoint.MessageType; + +/** + * This build item represents a WebSocket endpoint class. + */ +public final class WebSocketEndpointBuildItem extends MultiBuildItem { + + public final BeanInfo bean; + public final String path; + public final WebSocket.ExecutionMode executionMode; + public final Callback onOpen; + public final Callback onMessage; + public final Callback onClose; + + public WebSocketEndpointBuildItem(BeanInfo bean, String path, WebSocket.ExecutionMode executionMode, Callback onOpen, + Callback onMessage, Callback onClose) { + this.bean = bean; + this.path = path; + this.executionMode = executionMode; + this.onOpen = onOpen; + this.onMessage = onMessage; + this.onClose = onClose; + } + + public static class Callback { + + public final AnnotationInstance annotation; + public final MethodInfo method; + public final ExecutionModel executionModel; + public final MessageType consumedMessageType; + public final MessageType producedMessageType; + + public Callback(AnnotationInstance annotation, MethodInfo method, ExecutionModel executionModel) { + this.method = method; + this.annotation = annotation; + this.executionModel = executionModel; + this.consumedMessageType = initMessageType(method.parameters().isEmpty() ? null : method.parameterType(0)); + this.producedMessageType = initMessageType(method.returnType()); + } + + public Type returnType() { + return method.returnType(); + } + + public Type messageParamType() { + return acceptsMessage() ? method.parameterType(0) : null; + } + + public boolean isReturnTypeVoid() { + return returnType().kind() == Kind.VOID; + } + + public boolean isReturnTypeUni() { + return WebSocketDotNames.UNI.equals(returnType().name()); + } + + public boolean isReturnTypeMulti() { + return WebSocketDotNames.MULTI.equals(returnType().name()); + } + + public boolean acceptsMessage() { + return consumedMessageType != MessageType.NONE; + } + + public boolean acceptsBinaryMessage() { + return consumedMessageType == MessageType.BINARY; + } + + public boolean acceptsMulti() { + return acceptsMessage() && method.parameterType(0).name().equals(WebSocketDotNames.MULTI); + } + + public WebSocketEndpoint.MessageType consumedMessageType() { + return consumedMessageType; + } + + public WebSocketEndpoint.MessageType producedMessageType() { + return producedMessageType; + } + + public boolean broadcast() { + AnnotationValue broadcastValue = annotation.value("broadcast"); + return broadcastValue != null && broadcastValue.asBoolean(); + } + + public DotName getInputCodec() { + return getCodec("inputCodec"); + } + + public DotName getOutputCodec() { + DotName output = getCodec("outputCodec"); + return output != null ? output : getInputCodec(); + } + + private DotName getCodec(String valueName) { + AnnotationInstance messageAnnotation = method.declaredAnnotation(WebSocketDotNames.BINARY_MESSAGE); + if (messageAnnotation == null) { + messageAnnotation = method.declaredAnnotation(WebSocketDotNames.TEXT_MESSAGE); + } + if (messageAnnotation != null) { + AnnotationValue codecValue = messageAnnotation.value(valueName); + if (codecValue != null) { + return codecValue.asClass().name(); + } + } + return null; + } + + MessageType initMessageType(Type messageType) { + MessageType ret = MessageType.NONE; + if (messageType != null && !messageType.name().equals(WebSocketDotNames.VOID)) { + if (method.hasDeclaredAnnotation(WebSocketDotNames.BINARY_MESSAGE)) { + ret = MessageType.BINARY; + } else if (method.hasDeclaredAnnotation(WebSocketDotNames.TEXT_MESSAGE)) { + ret = MessageType.TEXT; + } else { + if (isByteArray(messageType) || WebSocketDotNames.BUFFER.equals(messageType.name())) { + ret = MessageType.BINARY; + } else { + ret = MessageType.TEXT; + } + } + } + return ret; + } + + static boolean isByteArray(Type type) { + return type.kind() == Kind.ARRAY && PrimitiveType.BYTE.equals(type.asArrayType().constituent()); + } + + static boolean isUniVoid(Type type) { + return WebSocketDotNames.UNI.equals(type.name()) + && type.asParameterizedType().arguments().get(0).name().equals(WebSocketDotNames.VOID); + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessor.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessor.java new file mode 100644 index 0000000000000..f2403c2f19569 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessor.java @@ -0,0 +1,743 @@ +package io.quarkus.websockets.next.deployment; + +import static io.quarkus.deployment.annotations.ExecutionTime.RUNTIME_INIT; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import jakarta.enterprise.context.SessionScoped; + +import org.jboss.jandex.AnnotationInstance; +import org.jboss.jandex.AnnotationValue; +import org.jboss.jandex.ClassInfo; +import org.jboss.jandex.ClassInfo.NestingType; +import org.jboss.jandex.DotName; +import org.jboss.jandex.IndexView; +import org.jboss.jandex.MethodInfo; +import org.jboss.jandex.Type; +import org.jboss.jandex.Type.Kind; + +import io.quarkus.arc.deployment.AdditionalBeanBuildItem; +import io.quarkus.arc.deployment.BeanArchiveIndexBuildItem; +import io.quarkus.arc.deployment.BeanDefiningAnnotationBuildItem; +import io.quarkus.arc.deployment.BeanDiscoveryFinishedBuildItem; +import io.quarkus.arc.deployment.ContextRegistrationPhaseBuildItem; +import io.quarkus.arc.deployment.ContextRegistrationPhaseBuildItem.ContextConfiguratorBuildItem; +import io.quarkus.arc.deployment.CustomScopeBuildItem; +import io.quarkus.arc.deployment.SyntheticBeanBuildItem; +import io.quarkus.arc.deployment.UnremovableBeanBuildItem; +import io.quarkus.arc.processor.BeanInfo; +import io.quarkus.arc.processor.DotNames; +import io.quarkus.arc.processor.Types; +import io.quarkus.deployment.GeneratedClassGizmoAdaptor; +import io.quarkus.deployment.annotations.BuildProducer; +import io.quarkus.deployment.annotations.BuildStep; +import io.quarkus.deployment.annotations.Record; +import io.quarkus.deployment.builditem.FeatureBuildItem; +import io.quarkus.deployment.builditem.GeneratedClassBuildItem; +import io.quarkus.deployment.builditem.nativeimage.ReflectiveClassBuildItem; +import io.quarkus.gizmo.BytecodeCreator; +import io.quarkus.gizmo.ClassCreator; +import io.quarkus.gizmo.ClassOutput; +import io.quarkus.gizmo.FunctionCreator; +import io.quarkus.gizmo.MethodCreator; +import io.quarkus.gizmo.MethodDescriptor; +import io.quarkus.gizmo.ResultHandle; +import io.quarkus.vertx.http.deployment.HttpRootPathBuildItem; +import io.quarkus.vertx.http.deployment.RouteBuildItem; +import io.quarkus.vertx.http.runtime.HandlerType; +import io.quarkus.websockets.next.TextMessageCodec; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketServerConnection; +import io.quarkus.websockets.next.WebSocketServerException; +import io.quarkus.websockets.next.WebSocketsRuntimeConfig; +import io.quarkus.websockets.next.deployment.WebSocketEndpointBuildItem.Callback; +import io.quarkus.websockets.next.runtime.Codecs; +import io.quarkus.websockets.next.runtime.ConnectionManager; +import io.quarkus.websockets.next.runtime.ContextSupport; +import io.quarkus.websockets.next.runtime.JsonTextMessageCodec; +import io.quarkus.websockets.next.runtime.WebSocketEndpoint; +import io.quarkus.websockets.next.runtime.WebSocketEndpoint.ExecutionModel; +import io.quarkus.websockets.next.runtime.WebSocketEndpoint.MessageType; +import io.quarkus.websockets.next.runtime.WebSocketEndpointBase; +import io.quarkus.websockets.next.runtime.WebSocketServerRecorder; +import io.quarkus.websockets.next.runtime.WebSocketSessionContext; +import io.smallrye.mutiny.Multi; +import io.smallrye.mutiny.Uni; +import io.smallrye.mutiny.groups.UniCreate; +import io.vertx.core.buffer.Buffer; +import io.vertx.core.json.JsonArray; +import io.vertx.core.json.JsonObject; + +public class WebSocketServerProcessor { + + static final String ENDPOINT_SUFFIX = "_WebSocketEndpoint"; + static final String NESTED_SEPARATOR = "$_"; + + private static final Pattern PATH_PARAM_PATTERN = Pattern.compile("\\{[a-zA-Z0-9_]+\\}"); + + @BuildStep + FeatureBuildItem feature() { + return new FeatureBuildItem("websockets-next"); + } + + @BuildStep + BeanDefiningAnnotationBuildItem beanDefiningAnnotation() { + return new BeanDefiningAnnotationBuildItem(WebSocketDotNames.WEB_SOCKET, DotNames.SINGLETON); + } + + @BuildStep + void unremovableBeans(BuildProducer unremovableBeans) { + unremovableBeans.produce(UnremovableBeanBuildItem.beanTypes(TextMessageCodec.class)); + } + + @BuildStep + public void collectEndpoints(BeanArchiveIndexBuildItem beanArchiveIndex, + BeanDiscoveryFinishedBuildItem beanDiscoveryFinished, + BuildProducer endpoints) { + + IndexView index = beanArchiveIndex.getIndex(); + Map pathToEndpoint = new HashMap<>(); + + for (BeanInfo bean : beanDiscoveryFinished.beanStream().classBeans()) { + ClassInfo beanClass = bean.getTarget().get().asClass(); + AnnotationInstance webSocketAnnotation = beanClass.annotation(WebSocketDotNames.WEB_SOCKET); + if (webSocketAnnotation != null) { + String path = getPath(webSocketAnnotation.value("path").asString()); + if (beanClass.nestingType() == NestingType.INNER) { + // Sub-websocket - merge the path from the enclosing classes + path = mergePath(getPathPrefix(index, beanClass.enclosingClass()), path); + } + DotName previous = pathToEndpoint.put(path, beanClass.name()); + if (previous != null) { + throw new WebSocketServerException( + String.format("Multiple endpoints [%s, %s] define the same path: %s", previous, beanClass, path)); + } + Callback onOpen = findCallback(beanArchiveIndex.getIndex(), beanClass, WebSocketDotNames.ON_OPEN, + this::validateOnOpen); + Callback onMessage = findCallback(beanArchiveIndex.getIndex(), beanClass, WebSocketDotNames.ON_MESSAGE, + this::validateOnMessage); + Callback onClose = findCallback(beanArchiveIndex.getIndex(), beanClass, WebSocketDotNames.ON_CLOSE, + this::validateOnClose); + if (onOpen == null && onMessage == null) { + throw new WebSocketServerException( + "The endpoint must declare at least one method annotated with @OnMessage or @OnOpen: " + beanClass); + } + AnnotationValue executionMode = webSocketAnnotation.value("executionMode"); + endpoints.produce(new WebSocketEndpointBuildItem(bean, path, + executionMode != null ? WebSocket.ExecutionMode.valueOf(executionMode.asEnum()) + : WebSocket.ExecutionMode.SERIAL, + onOpen, + onMessage, onClose)); + } + } + } + + @BuildStep + public void generateEndpoints(List endpoints, + BuildProducer generatedClasses, + BuildProducer generatedEndpoints, + BuildProducer reflectiveClasses) { + ClassOutput classOutput = new GeneratedClassGizmoAdaptor(generatedClasses, new Function() { + @Override + public String apply(String name) { + int idx = name.indexOf(ENDPOINT_SUFFIX); + if (idx != -1) { + name = name.substring(0, idx); + } + if (name.contains(NESTED_SEPARATOR)) { + name = name.replace(NESTED_SEPARATOR, "$"); + } + return name; + } + }); + for (WebSocketEndpointBuildItem endpoint : endpoints) { + // For each WebSocket endpoint bean generate an implementation of WebSocketEndpoint + // A new instance of this generated endpoint is created for each client connection + // The generated endpoint ensures the correct execution model is used + // and delegates callback invocations to the endpoint bean + String generatedName = generateEndpoint(endpoint, classOutput); + reflectiveClasses.produce(ReflectiveClassBuildItem.builder(generatedName).constructors().build()); + generatedEndpoints.produce(new GeneratedEndpointBuildItem(generatedName, endpoint.path)); + } + } + + @Record(RUNTIME_INIT) + @BuildStep + public void registerRoutes(WebSocketServerRecorder recorder, HttpRootPathBuildItem httpRootPath, + List generatedEndpoints, + BuildProducer routes) { + + for (GeneratedEndpointBuildItem endpoint : generatedEndpoints) { + RouteBuildItem.Builder builder = RouteBuildItem.builder() + .route(httpRootPath.relativePath(endpoint.path)) + .handlerType(HandlerType.NORMAL) + .handler(recorder.createEndpointHandler(endpoint.className)); + routes.produce(builder.build()); + } + } + + @BuildStep + AdditionalBeanBuildItem additionalBeans() { + return AdditionalBeanBuildItem.builder().setUnremovable() + .addBeanClasses(Codecs.class, JsonTextMessageCodec.class, ConnectionManager.class).build(); + } + + @BuildStep + @Record(RUNTIME_INIT) + void syntheticBeans(WebSocketServerRecorder recorder, BuildProducer syntheticBeans) { + syntheticBeans.produce(SyntheticBeanBuildItem.configure(WebSocketServerConnection.class) + .scope(SessionScoped.class) + .setRuntimeInit() + .supplier(recorder.connectionSupplier()) + .unremovable() + .done()); + } + + @BuildStep + ContextConfiguratorBuildItem registerSessionContext(ContextRegistrationPhaseBuildItem phase) { + return new ContextConfiguratorBuildItem(phase.getContext() + .configure(SessionScoped.class) + .normal() + .contextClass(WebSocketSessionContext.class)); + } + + @BuildStep + CustomScopeBuildItem registerSessionScope() { + return new CustomScopeBuildItem(DotName.createSimple(SessionScoped.class.getName())); + } + + static String mergePath(String prefix, String path) { + if (prefix.endsWith("/")) { + prefix = prefix.substring(0, prefix.length() - 1); + } + if (!path.startsWith("/")) { + path = "/" + path; + } + return prefix + path; + } + + static String getPath(String path) { + StringBuilder sb = new StringBuilder(); + Matcher m = PATH_PARAM_PATTERN.matcher(path); + while (m.find()) { + // Replace {foo} with :foo + String match = m.group(); + m.appendReplacement(sb, ":" + match.subSequence(1, match.length() - 1)); + } + m.appendTail(sb); + return sb.toString(); + } + + private void validateCallback(MethodInfo callback) { + if (callback.hasDeclaredAnnotation(WebSocketDotNames.BINARY_MESSAGE) + && callback.hasDeclaredAnnotation(WebSocketDotNames.TEXT_MESSAGE)) { + throw new WebSocketServerException( + "Either @BinaryMessage or @TextMessage can be declared on a callback: " + callbackToString(callback)); + } + } + + private void validateOnMessage(MethodInfo callback) { + if (callback.parameters().size() != 1) { + throw new WebSocketServerException( + "@OnMessage callback must accept exactly one parameter: " + callbackToString(callback)); + } + } + + private String callbackToString(MethodInfo callback) { + return callback.declaringClass().name() + "#" + callback.name() + "()"; + } + + private String getPathPrefix(IndexView index, DotName enclosingClassName) { + ClassInfo enclosingClass = index.getClassByName(enclosingClassName); + if (enclosingClass == null) { + throw new WebSocketServerException("Enclosing class not found in index: " + enclosingClass); + } + AnnotationInstance webSocketAnnotation = enclosingClass.annotation(WebSocketDotNames.WEB_SOCKET); + if (webSocketAnnotation != null) { + String path = getPath(webSocketAnnotation.value("path").asString()); + if (enclosingClass.nestingType() == NestingType.INNER) { + return mergePath(getPathPrefix(index, enclosingClass.enclosingClass()), path); + } else { + return path.endsWith("/") ? path.substring(path.length() - 1) : path; + } + } + return ""; + } + + private void validateOnOpen(MethodInfo callback) { + if (!callback.parameters().isEmpty()) { + throw new WebSocketServerException( + "@OnOpen callback must not accept any parameters: " + callbackToString(callback)); + } + } + + private void validateOnClose(MethodInfo callback) { + if (callback.returnType().kind() != Kind.VOID && !Callback.isUniVoid(callback.returnType())) { + throw new WebSocketServerException( + "@OnClose callback must return void or Uni: " + callbackToString(callback)); + } + if (!callback.parameters().isEmpty()) { + throw new WebSocketServerException( + "@OnClose callback must not accept any parameters: " + callbackToString(callback)); + } + } + + /** + * The generated endpoint class looks like: + * + *
+     * public class Echo_WebSocketEndpoint extends WebSocketEndpointBase {
+     *
+     *     public WebSocket.ExecutionMode executionMode() {
+     *         return WebSocket.ExecutionMode.SERIAL;
+     *     }
+     *
+     *     public Echo_WebSocketEndpoint(WebSocketServerConnection connection, Codecs codecs,
+     *             WebSocketRuntimeConfig config, ContextSupport contextSupport) {
+     *         super(context, connection, codecs, config, contextActivator);
+     *     }
+     *
+     *     public WebSocketEndpoint.MessageType consumedMessageType() {
+     *         return MessageType.TEXT;
+     *     }
+     *
+     *     public Uni doOnMessage(Object message) {
+     *         Uni uni = ((Echo) super.beanInstance("MTd91f3oxHtG8gnznR7XcZBCLdE")).echo((String) message);
+     *         if (uni != null) {
+     *             // The lambda is implemented as a generated function: Echo_WebSocketEndpoint$$function$$1
+     *             return uni.chain(m -> sendText(m, false));
+     *         } else {
+     *             return Uni.createFrom().voidItem();
+     *         }
+     *     }
+     *
+     *     public WebSocketEndpoint.ExecutionModel onMessageExecutionModel() {
+     *         return ExecutionModel.EVENT_LOOP;
+     *     }
+     * }
+     * 
+ * + * @param endpoint + * @param classOutput + * @return the name of the generated class + */ + private String generateEndpoint(WebSocketEndpointBuildItem endpoint, ClassOutput classOutput) { + ClassInfo implClazz = endpoint.bean.getImplClazz(); + String baseName; + if (implClazz.enclosingClass() != null) { + baseName = DotNames.simpleName(implClazz.enclosingClass()) + NESTED_SEPARATOR + + DotNames.simpleName(implClazz); + } else { + baseName = DotNames.simpleName(implClazz.name()); + } + String generatedName = DotNames.internalPackageNameWithTrailingSlash(implClazz.name()) + baseName + + ENDPOINT_SUFFIX; + + ClassCreator endpointCreator = ClassCreator.builder().classOutput(classOutput).className(generatedName) + .superClass(WebSocketEndpointBase.class) + .build(); + + MethodCreator constructor = endpointCreator.getConstructorCreator(WebSocketServerConnection.class, + Codecs.class, WebSocketsRuntimeConfig.class, ContextSupport.class); + constructor.invokeSpecialMethod( + MethodDescriptor.ofConstructor(WebSocketEndpointBase.class, WebSocketServerConnection.class, + Codecs.class, WebSocketsRuntimeConfig.class, ContextSupport.class), + constructor.getThis(), constructor.getMethodParam(0), constructor.getMethodParam(1), + constructor.getMethodParam(2), constructor.getMethodParam(3)); + constructor.returnNull(); + + MethodCreator executionMode = endpointCreator.getMethodCreator("executionMode", WebSocket.ExecutionMode.class); + executionMode.returnValue(executionMode.load(endpoint.executionMode)); + + if (endpoint.onMessage != null && endpoint.onMessage.acceptsMessage()) { + MethodCreator messageType = endpointCreator.getMethodCreator("consumedMessageType", + WebSocketEndpoint.MessageType.class); + messageType.returnValue(messageType.load(endpoint.onMessage.consumedMessageType())); + } + + if (endpoint.onMessage != null && endpoint.onMessage.acceptsMulti()) { + Type multiItemType = endpoint.onMessage.messageParamType().asParameterizedType().arguments().get(0); + MethodCreator consumedMultiType = endpointCreator.getMethodCreator("consumedMultiType", + java.lang.reflect.Type.class); + consumedMultiType.returnValue(Types.getTypeHandle(consumedMultiType, multiItemType)); + + MethodCreator decodeMultiItem = endpointCreator.getMethodCreator("decodeMultiItem", + Object.class, Object.class); + decodeMultiItem.returnValue(decodeMessage(decodeMultiItem, endpoint.onMessage.acceptsBinaryMessage(), + multiItemType, decodeMultiItem.getMethodParam(0), endpoint.onMessage)); + } + + if (endpoint.onOpen != null) { + MethodCreator doOnOpen = endpointCreator.getMethodCreator("doOnOpen", Uni.class, Object.class); + // Foo foo = beanInstance("foo"); + ResultHandle beanInstance = doOnOpen.invokeSpecialMethod( + MethodDescriptor.ofMethod(WebSocketEndpointBase.class, "beanInstance", Object.class, String.class), + doOnOpen.getThis(), doOnOpen.load(endpoint.bean.getIdentifier())); + // Call the business method + ResultHandle ret = doOnOpen.invokeVirtualMethod(MethodDescriptor.of(endpoint.onOpen.method), beanInstance); + encodeAndReturnResult(doOnOpen, endpoint.onOpen, ret); + + MethodCreator onOpenExecutionModel = endpointCreator.getMethodCreator("onOpenExecutionModel", + ExecutionModel.class); + onOpenExecutionModel.returnValue(onOpenExecutionModel.load(endpoint.onOpen.executionModel)); + } + + if (endpoint.onMessage != null) { + MethodCreator doOnMessage = endpointCreator.getMethodCreator("doOnMessage", Uni.class, Object.class); + // Foo foo = beanInstance("foo"); + ResultHandle beanInstance = doOnMessage.invokeSpecialMethod( + MethodDescriptor.ofMethod(WebSocketEndpointBase.class, "beanInstance", Object.class, String.class), + doOnMessage.getThis(), doOnMessage.load(endpoint.bean.getIdentifier())); + ResultHandle[] args; + if (endpoint.onMessage.acceptsMessage()) { + args = new ResultHandle[] { decodeMessage(doOnMessage, endpoint.onMessage.acceptsBinaryMessage(), + endpoint.onMessage.method.parameterType(0), doOnMessage.getMethodParam(0), endpoint.onMessage) }; + } else { + args = new ResultHandle[] {}; + } + // Call the business method + ResultHandle ret = doOnMessage.invokeVirtualMethod(MethodDescriptor.of(endpoint.onMessage.method), beanInstance, + args); + encodeAndReturnResult(doOnMessage, endpoint.onMessage, ret); + + MethodCreator onMessageExecutionModel = endpointCreator.getMethodCreator("onMessageExecutionModel", + ExecutionModel.class); + onMessageExecutionModel.returnValue(onMessageExecutionModel.load(endpoint.onMessage.executionModel)); + } + + if (endpoint.onClose != null) { + MethodCreator doOnClose = endpointCreator.getMethodCreator("doOnClose", Uni.class, Object.class); + // Foo foo = beanInstance("foo"); + ResultHandle beanInstance = doOnClose.invokeSpecialMethod( + MethodDescriptor.ofMethod(WebSocketEndpointBase.class, "beanInstance", Object.class, String.class), + doOnClose.getThis(), doOnClose.load(endpoint.bean.getIdentifier())); + // Call the business method + ResultHandle ret = doOnClose.invokeVirtualMethod(MethodDescriptor.of(endpoint.onClose.method), beanInstance); + encodeAndReturnResult(doOnClose, endpoint.onClose, ret); + + MethodCreator onCloseExecutionModel = endpointCreator.getMethodCreator("onCloseExecutionModel", + ExecutionModel.class); + onCloseExecutionModel.returnValue(onCloseExecutionModel.load(endpoint.onClose.executionModel)); + } + + endpointCreator.close(); + return generatedName.replace('/', '.'); + } + + private ResultHandle decodeMessage(MethodCreator method, boolean binaryMessage, Type valueType, ResultHandle value, + Callback callback) { + if (WebSocketDotNames.MULTI.equals(valueType.name())) { + // Multi is decoded at runtime in the recorder + return value; + } else if (binaryMessage) { + // Binary message + if (WebSocketDotNames.BUFFER.equals(valueType.name())) { + return value; + } else if (Callback.isByteArray(valueType)) { + // byte[] message = buffer.getBytes(); + return method.invokeInterfaceMethod( + MethodDescriptor.ofMethod(Buffer.class, "getBytes", byte[].class), value); + } else if (WebSocketDotNames.STRING.equals(valueType.name())) { + // String message = buffer.toString(); + return method.invokeInterfaceMethod( + MethodDescriptor.ofMethod(Buffer.class, "toString", String.class), value); + } else if (WebSocketDotNames.JSON_OBJECT.equals(valueType.name())) { + // JsonObject message = new JsonObject(buffer); + return method.newInstance( + MethodDescriptor.ofConstructor(JsonObject.class, Buffer.class), value); + } else if (WebSocketDotNames.JSON_ARRAY.equals(valueType.name())) { + // JsonArray message = new JsonArray(buffer); + return method.newInstance( + MethodDescriptor.ofConstructor(JsonArray.class, Buffer.class), value); + } else { + // Try to use codecs + DotName inputCodec = callback.getInputCodec(); + ResultHandle type = Types.getTypeHandle(method, valueType); + ResultHandle decoded = method.invokeSpecialMethod(MethodDescriptor.ofMethod(WebSocketEndpointBase.class, + "decodeBinary", Object.class, java.lang.reflect.Type.class, Buffer.class, Class.class), + method.getThis(), type, + value, inputCodec != null ? method.loadClass(inputCodec.toString()) : method.loadNull()); + return decoded; + } + } else { + // Text message + if (WebSocketDotNames.STRING.equals(valueType.name())) { + // String message = string; + return value; + } else if (WebSocketDotNames.JSON_OBJECT.equals(valueType.name())) { + // JsonObject message = new JsonObject(string); + return method.newInstance( + MethodDescriptor.ofConstructor(JsonObject.class, String.class), value); + } else if (WebSocketDotNames.JSON_ARRAY.equals(valueType.name())) { + // JsonArray message = new JsonArray(string); + return method.newInstance( + MethodDescriptor.ofConstructor(JsonArray.class, String.class), value); + } else if (WebSocketDotNames.BUFFER.equals(valueType.name())) { + // Buffer message = Buffer.buffer(string); + return method.invokeStaticInterfaceMethod( + MethodDescriptor.ofMethod(Buffer.class, "buffer", Buffer.class, String.class), value); + } else if (Callback.isByteArray(valueType)) { + // byte[] message = Buffer.buffer(string).getBytes(); + ResultHandle buffer = method.invokeStaticInterfaceMethod( + MethodDescriptor.ofMethod(Buffer.class, "buffer", Buffer.class, byte[].class), value); + return method.invokeInterfaceMethod( + MethodDescriptor.ofMethod(Buffer.class, "getBytes", byte[].class), buffer); + } else { + // Try to use codecs + DotName inputCodec = callback.getInputCodec(); + ResultHandle type = Types.getTypeHandle(method, valueType); + ResultHandle decoded = method.invokeSpecialMethod(MethodDescriptor.ofMethod(WebSocketEndpointBase.class, + "decodeText", Object.class, java.lang.reflect.Type.class, String.class, Class.class), method.getThis(), + type, value, inputCodec != null ? method.loadClass(inputCodec.toString()) : method.loadNull()); + return decoded; + } + } + } + + private ResultHandle encodeMessage(MethodCreator method, Callback callback, ResultHandle value) { + if (callback.producedMessageType == MessageType.BINARY) { + // ---------------------- + // === Binary message === + // ---------------------- + if (callback.isReturnTypeUni()) { + Type messageType = callback.returnType().asParameterizedType().arguments().get(0); + if (messageType.name().equals(WebSocketDotNames.VOID)) { + // Uni + return value; + } else { + // return uniMessage.chain(m -> { + // Buffer buffer = encodeBuffer(m); + // return sendBinary(buffer,broadcast); + // }); + FunctionCreator fun = method.createFunction(Function.class); + BytecodeCreator funBytecode = fun.getBytecode(); + ResultHandle buffer = encodeBuffer(funBytecode, + callback.returnType().asParameterizedType().arguments().get(0), + funBytecode.getMethodParam(0), method.getThis(), callback); + funBytecode.returnValue(funBytecode.invokeSpecialMethod( + MethodDescriptor.ofMethod(WebSocketEndpointBase.class, + "sendBinary", Uni.class, Buffer.class, boolean.class), + method.getThis(), buffer, + funBytecode.load(callback.broadcast()))); + ResultHandle uniChain = method.invokeInterfaceMethod( + MethodDescriptor.ofMethod(Uni.class, "chain", Uni.class, Function.class), value, + fun.getInstance()); + return uniChain; + } + } else if (callback.isReturnTypeMulti()) { + // return multiBinary(multi, broadcast, m -> { + // Buffer buffer = encodeBuffer(m); + // return sendBinary(buffer,broadcast); + //}); + FunctionCreator fun = method.createFunction(Function.class); + BytecodeCreator funBytecode = fun.getBytecode(); + ResultHandle buffer = encodeBuffer(funBytecode, callback.returnType().asParameterizedType().arguments().get(0), + funBytecode.getMethodParam(0), method.getThis(), callback); + funBytecode.returnValue(funBytecode.invokeSpecialMethod( + MethodDescriptor.ofMethod(WebSocketEndpointBase.class, + "sendBinary", Uni.class, Buffer.class, boolean.class), + method.getThis(), buffer, + funBytecode.load(callback.broadcast()))); + return method.invokeSpecialMethod(MethodDescriptor.ofMethod(WebSocketEndpointBase.class, + "multiBinary", Uni.class, Multi.class, boolean.class, Function.class), method.getThis(), + value, + method.load(callback.broadcast()), + fun.getInstance()); + } else { + // return sendBinary(buffer,broadcast); + ResultHandle buffer = encodeBuffer(method, callback.returnType(), value, method.getThis(), callback); + return method.invokeSpecialMethod(MethodDescriptor.ofMethod(WebSocketEndpointBase.class, + "sendBinary", Uni.class, Buffer.class, boolean.class), method.getThis(), buffer, + method.load(callback.broadcast())); + } + } else { + // ---------------------- + // === Text message === + // ---------------------- + if (callback.isReturnTypeUni()) { + Type messageType = callback.returnType().asParameterizedType().arguments().get(0); + if (messageType.name().equals(WebSocketDotNames.VOID)) { + // Uni + return value; + } else { + // return uniMessage.chain(m -> { + // String text = encodeText(m); + // return sendText(string,broadcast); + // }); + FunctionCreator fun = method.createFunction(Function.class); + BytecodeCreator funBytecode = fun.getBytecode(); + ResultHandle text = encodeText(funBytecode, callback.returnType().asParameterizedType().arguments().get(0), + funBytecode.getMethodParam(0), method.getThis(), callback); + funBytecode.returnValue(funBytecode.invokeSpecialMethod( + MethodDescriptor.ofMethod(WebSocketEndpointBase.class, + "sendText", Uni.class, String.class, boolean.class), + method.getThis(), text, + funBytecode.load(callback.broadcast()))); + ResultHandle uniChain = method.invokeInterfaceMethod( + MethodDescriptor.ofMethod(Uni.class, "chain", Uni.class, Function.class), value, + fun.getInstance()); + return uniChain; + } + } else if (callback.isReturnTypeMulti()) { + // return multiText(multi, broadcast, m -> { + // String text = encodeText(m); + // return sendText(buffer,broadcast); + //}); + FunctionCreator fun = method.createFunction(Function.class); + BytecodeCreator funBytecode = fun.getBytecode(); + ResultHandle text = encodeText(funBytecode, callback.returnType().asParameterizedType().arguments().get(0), + funBytecode.getMethodParam(0), method.getThis(), callback); + funBytecode.returnValue(funBytecode.invokeSpecialMethod( + MethodDescriptor.ofMethod(WebSocketEndpointBase.class, + "sendText", Uni.class, String.class, boolean.class), + method.getThis(), text, + funBytecode.load(callback.broadcast()))); + return method.invokeSpecialMethod(MethodDescriptor.ofMethod(WebSocketEndpointBase.class, + "multiText", Uni.class, Multi.class, boolean.class, Function.class), method.getThis(), + value, + method.load(callback.broadcast()), + fun.getInstance()); + } else { + // return sendText(text,broadcast); + ResultHandle text = encodeText(method, callback.returnType(), value, method.getThis(), callback); + return method.invokeSpecialMethod(MethodDescriptor.ofMethod(WebSocketEndpointBase.class, + "sendText", Uni.class, String.class, boolean.class), method.getThis(), text, + method.load(callback.broadcast())); + } + } + } + + private ResultHandle encodeBuffer(BytecodeCreator method, Type messageType, ResultHandle value, + ResultHandle defaultWebSocketEndpoint, Callback callback) { + ResultHandle buffer; + if (messageType.name().equals(WebSocketDotNames.BUFFER)) { + buffer = value; + } else if (Callback.isByteArray(messageType)) { + buffer = method.invokeStaticInterfaceMethod( + MethodDescriptor.ofMethod(Buffer.class, "buffer", Buffer.class, byte[].class), value); + } else if (messageType.name().equals(WebSocketDotNames.STRING)) { + buffer = method.invokeStaticInterfaceMethod( + MethodDescriptor.ofMethod(Buffer.class, "buffer", Buffer.class, String.class), value); + } else if (messageType.name().equals(WebSocketDotNames.JSON_OBJECT)) { + buffer = method.invokeVirtualMethod(MethodDescriptor.ofMethod(JsonObject.class, "toBuffer", Buffer.class), + value); + } else if (messageType.name().equals(WebSocketDotNames.JSON_ARRAY)) { + buffer = method.invokeVirtualMethod(MethodDescriptor.ofMethod(JsonArray.class, "toBuffer", Buffer.class), + value); + } else { + // Try to use codecs + DotName outputCodec = callback.getOutputCodec(); + buffer = method.invokeSpecialMethod(MethodDescriptor.ofMethod(WebSocketEndpointBase.class, + "encodeBinary", Buffer.class, Object.class, Class.class), defaultWebSocketEndpoint, value, + outputCodec != null ? method.loadClass(outputCodec.toString()) : method.loadNull()); + } + return buffer; + } + + private ResultHandle encodeText(BytecodeCreator method, Type messageType, ResultHandle value, + ResultHandle defaultWebSocketEndpoint, Callback callback) { + ResultHandle text; + if (messageType.name().equals(WebSocketDotNames.BUFFER)) { + text = method.invokeInterfaceMethod( + MethodDescriptor.ofMethod(Buffer.class, "toString", String.class), value); + } else if (Callback.isByteArray(messageType)) { + ResultHandle buffer = method.invokeStaticInterfaceMethod( + MethodDescriptor.ofMethod(Buffer.class, "buffer", Buffer.class, byte[].class), value); + text = method.invokeInterfaceMethod( + MethodDescriptor.ofMethod(Buffer.class, "toString", String.class), buffer); + } else if (messageType.name().equals(WebSocketDotNames.STRING)) { + text = value; + } else if (messageType.name().equals(WebSocketDotNames.JSON_OBJECT)) { + text = method.invokeVirtualMethod(MethodDescriptor.ofMethod(JsonObject.class, "encode", String.class), + value); + } else if (messageType.name().equals(WebSocketDotNames.JSON_ARRAY)) { + text = method.invokeVirtualMethod(MethodDescriptor.ofMethod(JsonArray.class, "encode", String.class), + value); + } else { + // Try to use codecs + DotName outputCodec = callback.getOutputCodec(); + text = method.invokeSpecialMethod(MethodDescriptor.ofMethod(WebSocketEndpointBase.class, + "encodeText", String.class, Object.class, Class.class), defaultWebSocketEndpoint, value, + outputCodec != null ? method.loadClass(outputCodec.toString()) : method.loadNull()); + } + return text; + } + + private ResultHandle uniVoid(BytecodeCreator method) { + ResultHandle uniCreate = method + .invokeStaticInterfaceMethod(MethodDescriptor.ofMethod(Uni.class, "createFrom", UniCreate.class)); + return method.invokeVirtualMethod(MethodDescriptor.ofMethod(UniCreate.class, "voidItem", Uni.class), uniCreate); + } + + private void encodeAndReturnResult(MethodCreator method, Callback callback, ResultHandle result) { + // The result must be always Uni + if (callback.isReturnTypeVoid()) { + // return Uni.createFrom().void() + method.returnValue(uniVoid(method)); + } else { + // Skip response + BytecodeCreator isNull = method.ifNull(result).trueBranch(); + isNull.returnValue(uniVoid(isNull)); + method.returnValue(encodeMessage(method, callback, result)); + } + } + + private Callback findCallback(IndexView index, ClassInfo beanClass, DotName annotationName, + Consumer validator) { + ClassInfo aClass = beanClass; + List annotations = new ArrayList<>(); + while (aClass != null) { + List declared = aClass.annotationsMap().get(annotationName); + if (declared != null) { + annotations.addAll(declared); + } + DotName superName = aClass.superName(); + aClass = superName != null && !superName.equals(DotNames.OBJECT) + ? index.getClassByName(superName) + : null; + } + + if (annotations.isEmpty()) { + return null; + } else if (annotations.size() == 1) { + AnnotationInstance annotation = annotations.get(0); + MethodInfo method = annotation.target().asMethod(); + validateCallback(method); + validator.accept(method); + return new Callback(annotation, method, executionModel(method)); + } + throw new WebSocketServerException( + String.format("There can be only one callback annotated with %s declared on %s", annotationName, beanClass)); + } + + ExecutionModel executionModel(MethodInfo method) { + if (hasBlockingSignature(method)) { + return method.hasDeclaredAnnotation(WebSocketDotNames.RUN_ON_VIRTUAL_THREAD) ? ExecutionModel.VIRTUAL_THREAD + : ExecutionModel.WORKER_THREAD; + } + return method.hasDeclaredAnnotation(WebSocketDotNames.BLOCKING) ? ExecutionModel.WORKER_THREAD + : ExecutionModel.EVENT_LOOP; + } + + boolean hasBlockingSignature(MethodInfo method) { + switch (method.returnType().kind()) { + case VOID: + case CLASS: + return true; + case PARAMETERIZED_TYPE: + // Uni, Multi -> non-blocking + DotName name = method.returnType().asParameterizedType().name(); + return !name.equals(WebSocketDotNames.UNI) && !name.equals(WebSocketDotNames.MULTI); + default: + throw new WebSocketServerException("Unsupported return type:" + callbackToString(method)); + } + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessorTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessorTest.java new file mode 100644 index 0000000000000..6c774f82c21ce --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessorTest.java @@ -0,0 +1,25 @@ +package io.quarkus.websockets.next.deployment; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +public class WebSocketServerProcessorTest { + + @Test + public void testGetPath() { + assertEquals("/foo/:id", WebSocketServerProcessor.getPath("/foo/{id}")); + assertEquals("/foo/:bar-:baz", WebSocketServerProcessor.getPath("/foo/{bar}-{baz}")); + assertEquals("/ws/v:version", WebSocketServerProcessor.getPath("/ws/v{version}")); + assertEquals("/foo/v:bar/:bazand:alpha_1-:name", + WebSocketServerProcessor.getPath("/foo/v{bar}/{baz}and{alpha_1}-{name}")); + } + + @Test + public void testMergePath() { + assertEquals("foo/bar", WebSocketServerProcessor.mergePath("foo/", "/bar")); + assertEquals("foo/bar", WebSocketServerProcessor.mergePath("foo", "/bar")); + assertEquals("foo/bar", WebSocketServerProcessor.mergePath("foo/", "bar")); + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/Echo.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/Echo.java new file mode 100644 index 0000000000000..d7d39893b6883 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/Echo.java @@ -0,0 +1,24 @@ +package io.quarkus.websockets.next.test; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +import jakarta.inject.Inject; + +import io.quarkus.websockets.next.OnMessage; +import io.quarkus.websockets.next.WebSocket; +import io.smallrye.mutiny.Uni; +import io.vertx.core.Context; + +@WebSocket(path = "/echo") +public class Echo { + + @Inject + EchoService echoService; + + @OnMessage + Uni echo(String msg) { + assertTrue(Context.isOnEventLoopThread()); + return Uni.createFrom().item(echoService.echo(msg)); + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/EchoBlocking.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/EchoBlocking.java new file mode 100644 index 0000000000000..9b7dafccc4459 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/EchoBlocking.java @@ -0,0 +1,23 @@ +package io.quarkus.websockets.next.test; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +import jakarta.inject.Inject; + +import io.quarkus.websockets.next.OnMessage; +import io.quarkus.websockets.next.WebSocket; +import io.vertx.core.Context; + +@WebSocket(path = "/echo-blocking") +public class EchoBlocking { + + @Inject + EchoService echoService; + + @OnMessage + String echo(String msg) { + assertTrue(Context.isOnWorkerThread()); + return echoService.echo(msg); + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/EchoBlockingAndAwait.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/EchoBlockingAndAwait.java new file mode 100644 index 0000000000000..0023287fbc143 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/EchoBlockingAndAwait.java @@ -0,0 +1,27 @@ +package io.quarkus.websockets.next.test; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +import jakarta.inject.Inject; + +import io.quarkus.websockets.next.OnMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketServerConnection; +import io.vertx.core.Context; + +@WebSocket(path = "/echo-blocking-await") +public class EchoBlockingAndAwait { + + @Inject + EchoService echoService; + + @Inject + WebSocketServerConnection connection; + + @OnMessage + void echo(String msg) { + assertTrue(Context.isOnWorkerThread()); + connection.sendTextAndAwait(echoService.echo(msg)); + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/EchoBlockingPojo.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/EchoBlockingPojo.java new file mode 100644 index 0000000000000..f3f5edac8bb4a --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/EchoBlockingPojo.java @@ -0,0 +1,39 @@ +package io.quarkus.websockets.next.test; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +import jakarta.inject.Inject; + +import io.quarkus.websockets.next.OnMessage; +import io.quarkus.websockets.next.WebSocket; +import io.vertx.core.Context; + +@WebSocket(path = "/echo-blocking-pojo") +public class EchoBlockingPojo { + + @Inject + EchoService echoService; + + @OnMessage + Message echo(Message msg) { + assertTrue(Context.isOnWorkerThread()); + Message ret = new Message(); + ret.setMsg(echoService.echo(msg.getMsg())); + return ret; + } + + public static class Message { + + private String msg; + + public String getMsg() { + return msg; + } + + public void setMsg(String value) { + this.msg = value; + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/EchoJson.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/EchoJson.java new file mode 100644 index 0000000000000..36f2bcbb76972 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/EchoJson.java @@ -0,0 +1,25 @@ +package io.quarkus.websockets.next.test; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +import jakarta.inject.Inject; + +import io.quarkus.websockets.next.OnMessage; +import io.quarkus.websockets.next.WebSocket; +import io.smallrye.mutiny.Uni; +import io.vertx.core.Context; +import io.vertx.core.json.JsonObject; + +@WebSocket(path = "/echo-json") +public class EchoJson { + + @Inject + EchoService echoService; + + @OnMessage + Uni echo(JsonObject msg) { + assertTrue(Context.isOnEventLoopThread()); + return Uni.createFrom().item(new JsonObject().put("msg", echoService.echo(msg.getString("msg")))); + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/EchoJsonArray.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/EchoJsonArray.java new file mode 100644 index 0000000000000..ef696f439a628 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/EchoJsonArray.java @@ -0,0 +1,28 @@ +package io.quarkus.websockets.next.test; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +import jakarta.inject.Inject; + +import io.quarkus.websockets.next.OnMessage; +import io.quarkus.websockets.next.WebSocket; +import io.smallrye.mutiny.Uni; +import io.vertx.core.Context; +import io.vertx.core.json.JsonArray; +import io.vertx.core.json.JsonObject; + +@WebSocket(path = "/echo-json-array") +public class EchoJsonArray { + + @Inject + EchoService echoService; + + @OnMessage + Uni echo(JsonArray msg) { + assertTrue(Context.isOnEventLoopThread()); + return Uni.createFrom().item( + new JsonArray().add( + new JsonObject().put("msg", echoService.echo(msg.getJsonObject(0).getString("msg"))))); + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/EchoMultiBidi.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/EchoMultiBidi.java new file mode 100644 index 0000000000000..28deaf1a5d236 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/EchoMultiBidi.java @@ -0,0 +1,15 @@ +package io.quarkus.websockets.next.test; + +import io.quarkus.websockets.next.OnMessage; +import io.quarkus.websockets.next.WebSocket; +import io.smallrye.mutiny.Multi; + +@WebSocket(path = "/echo-multi-bidi") +public class EchoMultiBidi { + + @OnMessage + Multi echo(Multi multi) { + return multi; + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/EchoMultiConsume.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/EchoMultiConsume.java new file mode 100644 index 0000000000000..eb89fef2b5e66 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/EchoMultiConsume.java @@ -0,0 +1,26 @@ +package io.quarkus.websockets.next.test; + +import jakarta.inject.Inject; + +import io.quarkus.websockets.next.OnMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketServerConnection; +import io.smallrye.mutiny.Multi; +import io.smallrye.mutiny.Uni; + +@WebSocket(path = "/echo-multi-consume") +public class EchoMultiConsume { + + @Inject + WebSocketServerConnection connection; + + @OnMessage + Uni echo(Multi multi) { + multi.subscribe().with(msg -> { + connection.sendText(msg).subscribe().with(v -> { + }); + }); + return connection.sendText("subscribed"); + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/EchoMultiProduce.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/EchoMultiProduce.java new file mode 100644 index 0000000000000..3192accc4106a --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/EchoMultiProduce.java @@ -0,0 +1,15 @@ +package io.quarkus.websockets.next.test; + +import io.quarkus.websockets.next.OnMessage; +import io.quarkus.websockets.next.WebSocket; +import io.smallrye.mutiny.Multi; + +@WebSocket(path = "/echo-multi-produce") +public class EchoMultiProduce { + + @OnMessage + Multi echo(String msg) { + return Multi.createFrom().item(msg); + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/EchoPojo.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/EchoPojo.java new file mode 100644 index 0000000000000..6402192c03db6 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/EchoPojo.java @@ -0,0 +1,40 @@ +package io.quarkus.websockets.next.test; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +import jakarta.inject.Inject; + +import io.quarkus.websockets.next.OnMessage; +import io.quarkus.websockets.next.WebSocket; +import io.smallrye.mutiny.Uni; +import io.vertx.core.Context; + +@WebSocket(path = "/echo-pojo") +public class EchoPojo { + + @Inject + EchoService echoService; + + @OnMessage + Uni echo(Message msg) { + assertTrue(Context.isOnEventLoopThread()); + Message ret = new Message(); + ret.setMsg(echoService.echo(msg.getMsg())); + return Uni.createFrom().item(ret); + } + + public static class Message { + + private String msg; + + public String getMsg() { + return msg; + } + + public void setMsg(String value) { + this.msg = value; + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/EchoService.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/EchoService.java new file mode 100644 index 0000000000000..a88e6fac3f071 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/EchoService.java @@ -0,0 +1,11 @@ +package io.quarkus.websockets.next.test; + +import jakarta.enterprise.context.ApplicationScoped; + +@ApplicationScoped +public class EchoService { + + public String echo(String msg) { + return msg; + } +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/EchoWebSocketTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/EchoWebSocketTest.java new file mode 100644 index 0000000000000..dddd23741b5a4 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/EchoWebSocketTest.java @@ -0,0 +1,149 @@ +package io.quarkus.websockets.next.test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.net.URI; +import java.util.Queue; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.TimeUnit; +import java.util.function.BiConsumer; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.vertx.core.Vertx; +import io.vertx.core.http.WebSocket; +import io.vertx.core.http.WebSocketClient; +import io.vertx.core.json.JsonArray; +import io.vertx.core.json.JsonObject; + +public class EchoWebSocketTest { + + @TestHTTPResource("echo") + URI echoUri; + + @TestHTTPResource("echo-blocking") + URI echoBlockingUri; + + @TestHTTPResource("echo-blocking-await") + URI echoBlockingAwaitUri; + + @TestHTTPResource("echo-json") + URI echoJson; + + @TestHTTPResource("echo-json-array") + URI echoJsonArray; + + @TestHTTPResource("echo-pojo") + URI echoPojo; + + @TestHTTPResource("echo-blocking-pojo") + URI echoBlockingPojo; + + @TestHTTPResource("echo-multi-consume") + URI echoMultiConsume; + + @TestHTTPResource("echo-multi-produce") + URI echoMultiProduce; + + @TestHTTPResource("echo-multi-bidi") + URI echoMultiBidi; + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(Echo.class, EchoBlocking.class, EchoBlockingAndAwait.class, EchoService.class, EchoJson.class, + EchoJsonArray.class, EchoPojo.class, EchoBlockingPojo.class, EchoMultiConsume.class, + EchoMultiProduce.class, EchoMultiBidi.class); + }); + + @Test + public void testEcho() throws Exception { + assertEcho(echoUri, "hello"); + } + + @Test + public void testEchoBlocking() throws Exception { + assertEcho(echoBlockingUri, "hello"); + } + + @Test + public void testEchoBlockingAndAwait() throws Exception { + assertEcho(echoBlockingAwaitUri, "hello"); + } + + @Test + public void testEchoJson() throws Exception { + assertEcho(echoJson, new JsonObject().put("msg", "hello").encode()); + } + + @Test + public void testEchoJsonArray() throws Exception { + assertEcho(echoJsonArray, new JsonArray().add(new JsonObject().put("msg", "hello")).encode()); + } + + @Test + public void testEchoPojo() throws Exception { + assertEcho(echoPojo, new JsonObject().put("msg", "hello").toString()); + } + + @Test + public void testEchoBlockingPojo() throws Exception { + assertEcho(echoBlockingPojo, new JsonObject().put("msg", "hello").toString()); + } + + @Test + public void testEchoMultiConsume() throws Exception { + assertEcho(echoMultiConsume, "hello", (ws, queue) -> { + ws.textMessageHandler(msg -> { + if ("subscribed".equals(msg)) { + ws.writeTextMessage("hello"); + } else { + queue.add(msg); + } + }); + }); + } + + @Test + public void testEchoMultiProduce() throws Exception { + assertEcho(echoMultiProduce, "hello"); + } + + @Test + public void testEchoMultiBidi() throws Exception { + assertEcho(echoMultiBidi, "hello"); + } + + public void assertEcho(URI testUri, String payload) throws Exception { + assertEcho(testUri, payload, (ws, queue) -> { + ws.textMessageHandler(msg -> { + queue.add(msg); + }); + ws.writeTextMessage(payload); + }); + } + + public void assertEcho(URI testUri, String payload, BiConsumer> action) throws Exception { + Vertx vertx = Vertx.vertx(); + WebSocketClient client = vertx.createWebSocketClient(); + try { + LinkedBlockingDeque message = new LinkedBlockingDeque<>(); + client + .connect(testUri.getPort(), testUri.getHost(), testUri.getPath()) + .onComplete(r -> { + if (r.succeeded()) { + WebSocket ws = r.result(); + action.accept(ws, message); + } else { + throw new IllegalStateException(r.cause()); + } + }); + assertEquals(payload, message.poll(10, TimeUnit.SECONDS)); + } finally { + client.close().toCompletionStage().toCompletableFuture().get(); + } + } +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/broadcast/BroadcastOnMessageTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/broadcast/BroadcastOnMessageTest.java new file mode 100644 index 0000000000000..e985cee6dd142 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/broadcast/BroadcastOnMessageTest.java @@ -0,0 +1,107 @@ +package io.quarkus.websockets.next.test.broadcast; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.net.URI; +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.vertx.core.Vertx; +import io.vertx.core.http.WebSocket; +import io.vertx.core.http.WebSocketClient; + +public class BroadcastOnMessageTest { + + @TestHTTPResource("up") + URI upUri; + + @TestHTTPResource("up-blocking") + URI upBlockingUri; + + @TestHTTPResource("up-multi-bidi") + URI upMultiBidiUri; + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(Up.class, UpBlocking.class, UpMultiBidi.class); + }); + + @Test + public void testUp() throws Exception { + assertBroadcast(upUri); + } + + @Test + public void testUpBlocking() throws Exception { + assertBroadcast(upBlockingUri); + } + + @Test + public void testUpMultiBidi() throws Exception { + assertBroadcast(upMultiBidiUri); + } + + public void assertBroadcast(URI testUri) throws Exception { + Vertx vertx = Vertx.vertx(); + WebSocketClient client1 = vertx.createWebSocketClient(); + WebSocketClient client2 = vertx.createWebSocketClient(); + try { + CountDownLatch connectedLatch = new CountDownLatch(2); + CountDownLatch onMessageLatch = new CountDownLatch(2); + AtomicReference ws1 = new AtomicReference<>(); + + List messages = new CopyOnWriteArrayList<>(); + client1 + .connect(testUri.getPort(), testUri.getHost(), testUri.getPath() + "/1") + .onComplete(r -> { + if (r.succeeded()) { + WebSocket ws = r.result(); + ws.textMessageHandler(msg -> { + messages.add(msg); + onMessageLatch.countDown(); + }); + // We will use this socket to write a message later on + ws1.set(ws); + connectedLatch.countDown(); + } else { + throw new IllegalStateException(r.cause()); + } + }); + client2 + .connect(testUri.getPort(), testUri.getHost(), testUri.getPath() + "/2") + .onComplete(r -> { + if (r.succeeded()) { + WebSocket ws = r.result(); + ws.textMessageHandler(msg -> { + messages.add(msg); + onMessageLatch.countDown(); + }); + connectedLatch.countDown(); + } else { + throw new IllegalStateException(r.cause()); + } + }); + assertTrue(connectedLatch.await(5, TimeUnit.SECONDS)); + ws1.get().writeTextMessage("hello"); + assertTrue(onMessageLatch.await(5, TimeUnit.SECONDS)); + assertEquals(2, messages.size()); + // Both messages come from the first client + assertEquals("1:HELLO", messages.get(0)); + assertEquals("1:HELLO", messages.get(1)); + } finally { + client1.close().toCompletionStage().toCompletableFuture().get(); + client2.close().toCompletionStage().toCompletableFuture().get(); + } + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/broadcast/BroadcastOnOpenTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/broadcast/BroadcastOnOpenTest.java new file mode 100644 index 0000000000000..95814bfb36d91 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/broadcast/BroadcastOnOpenTest.java @@ -0,0 +1,115 @@ +package io.quarkus.websockets.next.test.broadcast; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.net.URI; +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.vertx.core.Vertx; +import io.vertx.core.http.WebSocket; +import io.vertx.core.http.WebSocketClient; + +public class BroadcastOnOpenTest { + + @TestHTTPResource("lo") + URI loUri; + + @TestHTTPResource("lo-blocking") + URI loBlockingUri; + + @TestHTTPResource("lo-multi-produce") + URI loMultiProduceUri; + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(Lo.class, LoBlocking.class, LoMultiProduce.class); + }); + + @Test + public void testLo() throws Exception { + assertBroadcast(loUri); + } + + @Test + public void testLoBlocking() throws Exception { + assertBroadcast(loBlockingUri); + } + + @Test + public void testLoMultiBidi() throws Exception { + assertBroadcast(loMultiProduceUri); + } + + public void assertBroadcast(URI testUri) throws Exception { + Vertx vertx = Vertx.vertx(); + WebSocketClient client1 = vertx.createWebSocketClient(); + WebSocketClient client2 = vertx.createWebSocketClient(); + try { + CountDownLatch c1ConnectedLatch = new CountDownLatch(1); + CountDownLatch c1MessageLatch = new CountDownLatch(1); + CountDownLatch c2MessageLatch = new CountDownLatch(2); + List messages = new CopyOnWriteArrayList<>(); + client1 + .connect(testUri.getPort(), testUri.getHost(), testUri.getPath() + "/C1") + .onComplete(r -> { + if (r.succeeded()) { + WebSocket ws = r.result(); + ws.textMessageHandler(msg -> { + messages.add(msg); + if (msg.equals("c1")) { + c1MessageLatch.countDown(); + } else if (msg.equals("c2")) { + // onOpen callback from the second client + c2MessageLatch.countDown(); + } + + }); + c1ConnectedLatch.countDown(); + } else { + throw new IllegalStateException(r.cause()); + } + }); + assertTrue(c1ConnectedLatch.await(5, TimeUnit.SECONDS)); + assertTrue(c1MessageLatch.await(5, TimeUnit.SECONDS)); + assertEquals(1, messages.size()); + assertEquals("c1", messages.get(0)); + messages.clear(); + // Now connect the second client + CountDownLatch c2ConnectedLatch = new CountDownLatch(1); + client2 + .connect(testUri.getPort(), testUri.getHost(), testUri.getPath() + "/C2") + .onComplete(r -> { + if (r.succeeded()) { + WebSocket ws = r.result(); + ws.textMessageHandler(msg -> { + messages.add(msg); + c2MessageLatch.countDown(); + }); + c2ConnectedLatch.countDown(); + } else { + throw new IllegalStateException(r.cause()); + } + }); + assertTrue(c2ConnectedLatch.await(5, TimeUnit.SECONDS)); + assertTrue(c2MessageLatch.await(5, TimeUnit.SECONDS)); + // onOpen should be broadcasted to both clients + assertEquals(2, messages.size()); + assertEquals("c2", messages.get(0)); + assertEquals("c2", messages.get(1)); + } finally { + client1.close().toCompletionStage().toCompletableFuture().get(); + client2.close().toCompletionStage().toCompletableFuture().get(); + } + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/broadcast/Lo.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/broadcast/Lo.java new file mode 100644 index 0000000000000..9aad5c7e475f7 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/broadcast/Lo.java @@ -0,0 +1,25 @@ +package io.quarkus.websockets.next.test.broadcast; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +import jakarta.inject.Inject; + +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketServerConnection; +import io.smallrye.mutiny.Uni; +import io.vertx.core.Context; + +@WebSocket(path = "/lo/{client}") +public class Lo { + + @Inject + WebSocketServerConnection connection; + + @OnOpen(broadcast = true) + Uni open() { + assertTrue(Context.isOnEventLoopThread()); + return Uni.createFrom().item(connection.pathParam("client").toLowerCase()); + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/broadcast/LoBlocking.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/broadcast/LoBlocking.java new file mode 100644 index 0000000000000..457fb01f8b8f0 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/broadcast/LoBlocking.java @@ -0,0 +1,24 @@ +package io.quarkus.websockets.next.test.broadcast; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +import jakarta.inject.Inject; + +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketServerConnection; +import io.vertx.core.Context; + +@WebSocket(path = "/lo-blocking/{client}") +public class LoBlocking { + + @Inject + WebSocketServerConnection connection; + + @OnOpen(broadcast = true) + String open() { + assertTrue(Context.isOnWorkerThread()); + return connection.pathParam("client").toLowerCase(); + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/broadcast/LoMultiProduce.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/broadcast/LoMultiProduce.java new file mode 100644 index 0000000000000..94440a76402ff --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/broadcast/LoMultiProduce.java @@ -0,0 +1,25 @@ +package io.quarkus.websockets.next.test.broadcast; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +import jakarta.inject.Inject; + +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketServerConnection; +import io.smallrye.mutiny.Multi; +import io.vertx.core.Context; + +@WebSocket(path = "/lo-multi-produce/{client}") +public class LoMultiProduce { + + @Inject + WebSocketServerConnection connection; + + @OnOpen(broadcast = true) + Multi open() { + assertTrue(Context.isOnEventLoopThread()); + return Multi.createFrom().item(connection.pathParam("client").toLowerCase()); + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/broadcast/Up.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/broadcast/Up.java new file mode 100644 index 0000000000000..6c7c029a39d19 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/broadcast/Up.java @@ -0,0 +1,27 @@ +package io.quarkus.websockets.next.test.broadcast; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import jakarta.inject.Inject; + +import io.quarkus.websockets.next.OnMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketServerConnection; +import io.smallrye.mutiny.Uni; +import io.vertx.core.Context; + +@WebSocket(path = "/up/{client}") +public class Up { + + @Inject + WebSocketServerConnection connection; + + @OnMessage(broadcast = true) + Uni echo(String msg) { + assertTrue(Context.isOnEventLoopThread()); + assertEquals(2, connection.getOpenConnections().size()); + return Uni.createFrom().item(connection.pathParam("client") + ":" + msg.toUpperCase()); + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/broadcast/UpBlocking.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/broadcast/UpBlocking.java new file mode 100644 index 0000000000000..c75e0c1855510 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/broadcast/UpBlocking.java @@ -0,0 +1,26 @@ +package io.quarkus.websockets.next.test.broadcast; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import jakarta.inject.Inject; + +import io.quarkus.websockets.next.OnMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketServerConnection; +import io.vertx.core.Context; + +@WebSocket(path = "/up-blocking/{client}") +public class UpBlocking { + + @Inject + WebSocketServerConnection connection; + + @OnMessage(broadcast = true) + String echo(String msg) { + assertTrue(Context.isOnWorkerThread()); + assertEquals(2, connection.getOpenConnections().size()); + return connection.pathParam("client") + ":" + msg.toUpperCase(); + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/broadcast/UpMultiBidi.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/broadcast/UpMultiBidi.java new file mode 100644 index 0000000000000..c2622c0ad24ca --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/broadcast/UpMultiBidi.java @@ -0,0 +1,27 @@ +package io.quarkus.websockets.next.test.broadcast; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import jakarta.inject.Inject; + +import io.quarkus.websockets.next.OnMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketServerConnection; +import io.smallrye.mutiny.Multi; +import io.vertx.core.Context; + +@WebSocket(path = "/up-multi-bidi/{client}") +public class UpMultiBidi { + + @Inject + WebSocketServerConnection connection; + + @OnMessage(broadcast = true) + Multi echo(Multi multi) { + assertTrue(Context.isOnEventLoopThread()); + assertEquals(2, connection.getOpenConnections().size()); + return multi.map(m -> connection.pathParam("client") + ":" + m.toUpperCase()); + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/codec/AbstractFind.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/codec/AbstractFind.java new file mode 100644 index 0000000000000..9efa8dc2542af --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/codec/AbstractFind.java @@ -0,0 +1,12 @@ +package io.quarkus.websockets.next.test.codec; + +import java.util.Comparator; +import java.util.List; + +public abstract class AbstractFind { + + Item find(List items) { + return items.stream().sorted(Comparator.comparingInt(Item::getCount)).findFirst().orElse(null); + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/codec/BinaryCodecTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/codec/BinaryCodecTest.java new file mode 100644 index 0000000000000..67f834b7f2ea2 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/codec/BinaryCodecTest.java @@ -0,0 +1,66 @@ +package io.quarkus.websockets.next.test.codec; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.net.URI; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.TimeUnit; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.vertx.core.Vertx; +import io.vertx.core.buffer.Buffer; +import io.vertx.core.http.WebSocket; +import io.vertx.core.http.WebSocketClient; +import io.vertx.core.json.JsonArray; +import io.vertx.core.json.JsonObject; + +public class BinaryCodecTest { + + @TestHTTPResource("find-binary") + URI findBinaryUri; + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(FindBinary.class, AbstractFind.class, Item.class, FindBinary.ItemBinaryMessageCodec.class, + FindBinary.ListItemBinaryMessageCodec.class); + }); + + @Test + public void testCodec() throws Exception { + JsonArray items = new JsonArray(); + items.add(new JsonObject().put("name", "foo").put("count", 10)); + items.add(new JsonObject().put("name", "bar").put("count", 1)); + items.add(new JsonObject().put("name", "baz").put("count", 100)); + assertCodec(findBinaryUri, items.toBuffer(), Buffer.buffer("Item [count=2]")); + } + + public void assertCodec(URI testUri, Buffer payload, Buffer expected) + throws Exception { + Vertx vertx = Vertx.vertx(); + WebSocketClient client = vertx.createWebSocketClient(); + try { + LinkedBlockingDeque message = new LinkedBlockingDeque<>(); + client + .connect(testUri.getPort(), testUri.getHost(), testUri.getPath()) + .onComplete(r -> { + if (r.succeeded()) { + WebSocket ws = r.result(); + ws.binaryMessageHandler(msg -> { + message.add(msg); + }); + ws.writeBinaryMessage(payload); + } else { + throw new IllegalStateException(r.cause()); + } + }); + assertEquals(expected, message.poll(10, TimeUnit.SECONDS)); + } finally { + client.close().toCompletionStage().toCompletableFuture().get(); + } + } +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/codec/CustomCodecTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/codec/CustomCodecTest.java new file mode 100644 index 0000000000000..8fccf5a957f6e --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/codec/CustomCodecTest.java @@ -0,0 +1,64 @@ +package io.quarkus.websockets.next.test.codec; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.net.URI; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.TimeUnit; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.vertx.core.Vertx; +import io.vertx.core.http.WebSocket; +import io.vertx.core.http.WebSocketClient; +import io.vertx.core.json.JsonArray; +import io.vertx.core.json.JsonObject; + +public class CustomCodecTest { + + @TestHTTPResource("find") + URI findUri; + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(Find.class, Item.class, AbstractFind.class, MyItemCodec.class); + }); + + @Test + public void testCodec() throws Exception { + JsonArray items = new JsonArray(); + items.add(new JsonObject().put("name", "foo").put("count", 10)); + items.add(new JsonObject().put("name", "bar").put("count", 1)); + items.add(new JsonObject().put("name", "baz").put("count", 100)); + assertCodec(findUri, items.encode(), new JsonObject().put("count", 1).encode()); + } + + public void assertCodec(URI testUri, String payload, String expected) + throws Exception { + Vertx vertx = Vertx.vertx(); + WebSocketClient client = vertx.createWebSocketClient(); + try { + LinkedBlockingDeque message = new LinkedBlockingDeque<>(); + client + .connect(testUri.getPort(), testUri.getHost(), testUri.getPath()) + .onComplete(r -> { + if (r.succeeded()) { + WebSocket ws = r.result(); + ws.textMessageHandler(msg -> { + message.add(msg); + }); + ws.writeTextMessage(payload); + } else { + throw new IllegalStateException(r.cause()); + } + }); + assertEquals(expected, message.poll(10, TimeUnit.SECONDS)); + } finally { + client.close().toCompletionStage().toCompletableFuture().get(); + } + } +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/codec/DefaultTextCodecTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/codec/DefaultTextCodecTest.java new file mode 100644 index 0000000000000..866172e1bf296 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/codec/DefaultTextCodecTest.java @@ -0,0 +1,64 @@ +package io.quarkus.websockets.next.test.codec; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.net.URI; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.TimeUnit; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.vertx.core.Vertx; +import io.vertx.core.http.WebSocket; +import io.vertx.core.http.WebSocketClient; +import io.vertx.core.json.JsonArray; +import io.vertx.core.json.JsonObject; + +public class DefaultTextCodecTest { + + @TestHTTPResource("find") + URI findUri; + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(Find.class, AbstractFind.class, Item.class); + }); + + @Test + public void testCodec() throws Exception { + JsonArray items = new JsonArray(); + items.add(new JsonObject().put("name", "foo").put("count", 10)); + items.add(new JsonObject().put("name", "bar").put("count", 1)); + items.add(new JsonObject().put("name", "baz").put("count", 100)); + assertCodec(findUri, items.encode(), new JsonObject().put("name", "bar").put("count", 1).encode()); + } + + public void assertCodec(URI testUri, String payload, String expected) + throws Exception { + Vertx vertx = Vertx.vertx(); + WebSocketClient client = vertx.createWebSocketClient(); + try { + LinkedBlockingDeque message = new LinkedBlockingDeque<>(); + client + .connect(testUri.getPort(), testUri.getHost(), testUri.getPath()) + .onComplete(r -> { + if (r.succeeded()) { + WebSocket ws = r.result(); + ws.textMessageHandler(msg -> { + message.add(msg); + }); + ws.writeTextMessage(payload); + } else { + throw new IllegalStateException(r.cause()); + } + }); + assertEquals(expected, message.poll(10, TimeUnit.SECONDS)); + } finally { + client.close().toCompletionStage().toCompletableFuture().get(); + } + } +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/codec/Find.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/codec/Find.java new file mode 100644 index 0000000000000..6aba96f3f9ab9 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/codec/Find.java @@ -0,0 +1,16 @@ +package io.quarkus.websockets.next.test.codec; + +import java.util.List; + +import io.quarkus.websockets.next.OnMessage; +import io.quarkus.websockets.next.WebSocket; + +@WebSocket(path = "/find") +public class Find extends AbstractFind { + + @OnMessage + Item find(List items) { + return super.find(items); + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/codec/FindBinary.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/codec/FindBinary.java new file mode 100644 index 0000000000000..c44e3d89573a1 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/codec/FindBinary.java @@ -0,0 +1,78 @@ +package io.quarkus.websockets.next.test.codec; + +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; +import java.util.ArrayList; +import java.util.List; + +import jakarta.inject.Singleton; + +import io.quarkus.websockets.next.BinaryMessage; +import io.quarkus.websockets.next.BinaryMessageCodec; +import io.quarkus.websockets.next.OnMessage; +import io.quarkus.websockets.next.WebSocket; +import io.vertx.core.buffer.Buffer; +import io.vertx.core.json.JsonArray; +import io.vertx.core.json.JsonObject; + +@WebSocket(path = "/find-binary") +public class FindBinary extends AbstractFind { + + // Force binary message format + // There's no binary codec available out of the box so the codecs below are needed + @BinaryMessage + @OnMessage + Item find(List items) { + return super.find(items); + } + + @Singleton + public static class ItemBinaryMessageCodec implements BinaryMessageCodec { + + @Override + public boolean supports(Type type) { + return type.equals(Item.class); + } + + @Override + public Buffer encode(Item value) { + return Buffer.buffer(value.toString()); + } + + @Override + public Item decode(Type type, Buffer value) { + throw new UnsupportedOperationException(); + } + + } + + @Singleton + public static class ListItemBinaryMessageCodec implements BinaryMessageCodec> { + + @Override + public boolean supports(Type type) { + return type instanceof ParameterizedType && ((ParameterizedType) type).getRawType().equals(List.class); + } + + @Override + public Buffer encode(List value) { + throw new UnsupportedOperationException(); + } + + @Override + public List decode(Type type, Buffer value) { + JsonArray json = value.toJsonArray(); + List items = new ArrayList<>(); + for (int i = 0; i < 3; i++) { + Item item = new Item(); + JsonObject jsonObject = json.getJsonObject(i); + // Intentionally skip the name + item.setCount(2 * jsonObject.getInteger("count")); + items.add(item); + } + return items; + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/codec/FindInputCodec.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/codec/FindInputCodec.java new file mode 100644 index 0000000000000..24cf369f52e3f --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/codec/FindInputCodec.java @@ -0,0 +1,56 @@ +package io.quarkus.websockets.next.test.codec; + +import java.lang.reflect.Type; +import java.util.ArrayList; +import java.util.List; + +import jakarta.annotation.Priority; +import jakarta.inject.Singleton; + +import io.quarkus.websockets.next.OnMessage; +import io.quarkus.websockets.next.TextMessage; +import io.quarkus.websockets.next.TextMessageCodec; +import io.quarkus.websockets.next.WebSocket; +import io.vertx.core.json.JsonArray; +import io.vertx.core.json.JsonObject; + +@WebSocket(path = "/find-input-codec") +public class FindInputCodec extends AbstractFind { + + // The codec is used for both input/output + @TextMessage(inputCodec = MyInputCodec.class) + @OnMessage + Item find(List items) { + return super.find(items); + } + + @Singleton + @Priority(-10) + public static class MyInputCodec implements TextMessageCodec { + + @Override + public boolean supports(Type type) { + return true; + } + + @Override + public String encode(Object value) { + return value.toString(); + } + + @Override + public Object decode(Type type, String value) { + JsonArray json = new JsonArray(value); + List items = new ArrayList<>(); + for (int i = 0; i < 3; i++) { + Item item = new Item(); + JsonObject jsonObject = json.getJsonObject(i); + item.setCount(2 * jsonObject.getInteger("count")); + items.add(item); + } + return items; + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/codec/FindOutputCodec.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/codec/FindOutputCodec.java new file mode 100644 index 0000000000000..10b5cbc659f6c --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/codec/FindOutputCodec.java @@ -0,0 +1,48 @@ +package io.quarkus.websockets.next.test.codec; + +import java.lang.reflect.Type; +import java.util.List; + +import jakarta.annotation.Priority; +import jakarta.inject.Singleton; + +import io.quarkus.websockets.next.OnMessage; +import io.quarkus.websockets.next.TextMessage; +import io.quarkus.websockets.next.TextMessageCodec; +import io.quarkus.websockets.next.WebSocket; +import io.vertx.core.json.JsonObject; + +@WebSocket(path = "/find-output-codec") +public class FindOutputCodec extends AbstractFind { + + // The codec is only used for output + @TextMessage(outputCodec = MyOutputCodec.class) + @OnMessage + Item find(List items) { + return super.find(items); + } + + @Singleton + @Priority(-10) + public static class MyOutputCodec implements TextMessageCodec { + + @Override + public boolean supports(Type type) { + return type.equals(Item.class); + } + + @Override + public String encode(Item value) { + JsonObject json = JsonObject.mapFrom(value); + json.remove("count"); // intentionally remove the "count" + return json.encode(); + } + + @Override + public Item decode(Type type, String value) { + throw new UnsupportedOperationException(); + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/codec/Item.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/codec/Item.java new file mode 100644 index 0000000000000..ad557730cea52 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/codec/Item.java @@ -0,0 +1,29 @@ +package io.quarkus.websockets.next.test.codec; + +public class Item { + + private String name; + private int count; + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public int getCount() { + return count; + } + + public void setCount(int count) { + this.count = count; + } + + @Override + public String toString() { + return "Item [" + (name != null ? "name=" + name + ", " : "") + "count=" + count + "]"; + } + +} \ No newline at end of file diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/codec/MyItemCodec.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/codec/MyItemCodec.java new file mode 100644 index 0000000000000..e2ce78019099c --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/codec/MyItemCodec.java @@ -0,0 +1,31 @@ +package io.quarkus.websockets.next.test.codec; + +import java.lang.reflect.Type; + +import jakarta.annotation.Priority; +import jakarta.inject.Singleton; + +import io.quarkus.websockets.next.TextMessageCodec; +import io.vertx.core.json.JsonObject; + +@Singleton +@Priority(10) +public class MyItemCodec implements TextMessageCodec { + + @Override + public boolean supports(Type type) { + return type.equals(Item.class); + } + + @Override + public String encode(Item value) { + // Intentionally skip the name + return new JsonObject().put("count", value.getCount()).encode(); + } + + @Override + public Item decode(Type type, String value) { + throw new UnsupportedOperationException(); + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/codec/TextInputCodecTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/codec/TextInputCodecTest.java new file mode 100644 index 0000000000000..572c777810a52 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/codec/TextInputCodecTest.java @@ -0,0 +1,64 @@ +package io.quarkus.websockets.next.test.codec; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.net.URI; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.TimeUnit; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.vertx.core.Vertx; +import io.vertx.core.http.WebSocket; +import io.vertx.core.http.WebSocketClient; +import io.vertx.core.json.JsonArray; +import io.vertx.core.json.JsonObject; + +public class TextInputCodecTest { + + @TestHTTPResource("find-input-codec") + URI itemCodecUri; + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(FindInputCodec.class, FindInputCodec.MyInputCodec.class, AbstractFind.class, Item.class); + }); + + @Test + public void testCodec() throws Exception { + JsonArray items = new JsonArray(); + items.add(new JsonObject().put("name", "foo").put("count", 10)); + items.add(new JsonObject().put("name", "bar").put("count", 1)); + items.add(new JsonObject().put("name", "baz").put("count", 100)); + assertCodec(itemCodecUri, items.encode(), "Item [count=2]"); + } + + public void assertCodec(URI testUri, String payload, String expected) + throws Exception { + Vertx vertx = Vertx.vertx(); + WebSocketClient client = vertx.createWebSocketClient(); + try { + LinkedBlockingDeque message = new LinkedBlockingDeque<>(); + client + .connect(testUri.getPort(), testUri.getHost(), testUri.getPath()) + .onComplete(r -> { + if (r.succeeded()) { + WebSocket ws = r.result(); + ws.textMessageHandler(msg -> { + message.add(msg); + }); + ws.writeTextMessage(payload); + } else { + throw new IllegalStateException(r.cause()); + } + }); + assertEquals(expected, message.poll(10, TimeUnit.SECONDS)); + } finally { + client.close().toCompletionStage().toCompletableFuture().get(); + } + } +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/codec/TextOutputCodecTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/codec/TextOutputCodecTest.java new file mode 100644 index 0000000000000..66b04e7c2b277 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/codec/TextOutputCodecTest.java @@ -0,0 +1,64 @@ +package io.quarkus.websockets.next.test.codec; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.net.URI; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.TimeUnit; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.vertx.core.Vertx; +import io.vertx.core.http.WebSocket; +import io.vertx.core.http.WebSocketClient; +import io.vertx.core.json.JsonArray; +import io.vertx.core.json.JsonObject; + +public class TextOutputCodecTest { + + @TestHTTPResource("find-output-codec") + URI itemCodecUri; + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(FindOutputCodec.class, FindOutputCodec.MyOutputCodec.class, AbstractFind.class, Item.class); + }); + + @Test + public void testCodec() throws Exception { + JsonArray items = new JsonArray(); + items.add(new JsonObject().put("name", "foo").put("count", 10)); + items.add(new JsonObject().put("name", "bar").put("count", 1)); + items.add(new JsonObject().put("name", "baz").put("count", 100)); + assertCodec(itemCodecUri, items.encode(), "{\"name\":\"bar\"}"); + } + + public void assertCodec(URI testUri, String payload, String expected) + throws Exception { + Vertx vertx = Vertx.vertx(); + WebSocketClient client = vertx.createWebSocketClient(); + try { + LinkedBlockingDeque message = new LinkedBlockingDeque<>(); + client + .connect(testUri.getPort(), testUri.getHost(), testUri.getPath()) + .onComplete(r -> { + if (r.succeeded()) { + WebSocket ws = r.result(); + ws.textMessageHandler(msg -> { + message.add(msg); + }); + ws.writeTextMessage(payload); + } else { + throw new IllegalStateException(r.cause()); + } + }); + assertEquals(expected, message.poll(10, TimeUnit.SECONDS)); + } finally { + client.close().toCompletionStage().toCompletableFuture().get(); + } + } +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/connection/ConnectionCloseTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/connection/ConnectionCloseTest.java new file mode 100644 index 0000000000000..d97124e4d42ca --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/connection/ConnectionCloseTest.java @@ -0,0 +1,100 @@ +package io.quarkus.websockets.next.test.connection; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.net.URI; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import jakarta.inject.Inject; + +import org.awaitility.Awaitility; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnClose; +import io.quarkus.websockets.next.OnMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketServerConnection; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.smallrye.mutiny.Uni; +import io.vertx.core.Vertx; + +public class ConnectionCloseTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(Closing.class, ClosingBlocking.class, WSClient.class); + }); + + @Inject + Vertx vertx; + + @TestHTTPResource("closing") + URI closingUri; + + @TestHTTPResource("closing-blocking") + URI closingBlockingUri; + + @Test + public void testClosed() throws InterruptedException { + assertClosed(closingUri); + assertTrue(Closing.CLOSED.await(5, TimeUnit.SECONDS)); + } + + @Test + public void testClosedBlocking() throws InterruptedException { + assertClosed(closingBlockingUri); + assertTrue(ClosingBlocking.CLOSED.await(5, TimeUnit.SECONDS)); + } + + private void assertClosed(URI testUri) throws InterruptedException { + WSClient client = WSClient.create(vertx).connect(testUri); + client.sendAndAwait("foo"); + Awaitility.await().atMost(5, TimeUnit.SECONDS).until(() -> client.isClosed()); + } + + @WebSocket(path = "/closing") + public static class Closing { + + static final CountDownLatch CLOSED = new CountDownLatch(1); + + @Inject + WebSocketServerConnection connection; + + @OnMessage + public Uni onMessage(String message) { + return connection.close(); + } + + @OnClose + void onClose() { + CLOSED.countDown(); + } + + } + + @WebSocket(path = "/closing-blocking") + public static class ClosingBlocking { + + static final CountDownLatch CLOSED = new CountDownLatch(1); + + @Inject + WebSocketServerConnection connection; + + @OnMessage + public void onMessage(String message) { + connection.closeAndAwait(); + } + + @OnClose + void onClose() { + CLOSED.countDown(); + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/connection/ServiceConnectionScopeTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/connection/ServiceConnectionScopeTest.java new file mode 100644 index 0000000000000..fd9c2d9608588 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/connection/ServiceConnectionScopeTest.java @@ -0,0 +1,82 @@ +package io.quarkus.websockets.next.test.connection; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.net.URI; + +import jakarta.enterprise.context.ContextNotActiveException; +import jakarta.enterprise.context.RequestScoped; +import jakarta.enterprise.context.SessionScoped; +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.arc.Arc; +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketServerConnection; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.vertx.core.Vertx; + +public class ServiceConnectionScopeTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(MyEndpoint.class, WSClient.class); + }); + + @Inject + MyEndpoint endpoint; + + @Inject + Vertx vertx; + + @TestHTTPResource("/") + URI baseUri; + + @Test + void verifyThatConnectionIsNotAccessibleOutsideOfTheSessionScope() { + endpoint.testConnectionNotAccessibleOutsideOfWsMethods(); + } + + @Test + void verifyThatConnectionIsAccessibleInSessionScope() { + WSClient client = WSClient.create(vertx); + var resp = client.connect(WSClient.toWS(baseUri, "/ws")) + .sendAndAwaitReply("hello"); + assertThat(resp.toString()).isEqualTo("HELLO"); + } + + @WebSocket(path = "/ws") + public static class MyEndpoint { + + @Inject + WebSocketServerConnection connection; + + @OnMessage + public String onMessage(String message) { + assertNotNull(Arc.container().getActiveContext(SessionScoped.class)); + assertNotNull(Arc.container().getActiveContext(RequestScoped.class)); + assertNotNull(connection.id()); + return message.toUpperCase(); + } + + @ActivateRequestContext + void testConnectionNotAccessibleOutsideOfWsMethods() { + assertNull(Arc.container().getActiveContext(SessionScoped.class)); + assertNotNull(Arc.container().getActiveContext(RequestScoped.class)); + // WebSocketServerConnection is @SessionScoped + assertThrows(ContextNotActiveException.class, () -> connection.id()); + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/EmptyEndpointTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/EmptyEndpointTest.java new file mode 100644 index 0000000000000..eb13b8cc5ff2e --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/EmptyEndpointTest.java @@ -0,0 +1,31 @@ +package io.quarkus.websockets.next.test.endpoints; + +import static org.junit.jupiter.api.Assertions.fail; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketServerException; + +public class EmptyEndpointTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(EmptyEndpoint.class); + }) + .setExpectedException(WebSocketServerException.class); + + @Test + void verifyThatEndpointWithoutAnyMethodFailsToDeploy() { + fail(); + } + + @WebSocket(path = "/ws") + public static class EmptyEndpoint { + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/EmptySubEndpointTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/EmptySubEndpointTest.java new file mode 100644 index 0000000000000..be1e60bdc3d0b --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/EmptySubEndpointTest.java @@ -0,0 +1,40 @@ +package io.quarkus.websockets.next.test.endpoints; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.websockets.next.OnMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketServerException; + +public class EmptySubEndpointTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(ParentEndpoint.class, ParentEndpoint.EmptySubEndpoint.class); + }) + .setExpectedException(WebSocketServerException.class); + + @Test + void verifyThatSubEndpointWithoutAnyMethodFailsToDeploy() { + + } + + @WebSocket(path = "/ws") + public static class ParentEndpoint { + + @OnMessage + public void onMessage(String message) { + // Ignored. + } + + @WebSocket(path = "/sub") + public static class EmptySubEndpoint { + + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/NoOnOpenOrOnMessageInSubEndpointTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/NoOnOpenOrOnMessageInSubEndpointTest.java new file mode 100644 index 0000000000000..7e694f667f845 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/NoOnOpenOrOnMessageInSubEndpointTest.java @@ -0,0 +1,46 @@ +package io.quarkus.websockets.next.test.endpoints; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.websockets.next.OnClose; +import io.quarkus.websockets.next.OnMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketServerException; + +public class NoOnOpenOrOnMessageInSubEndpointTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(ParentEndpoint.class, ParentEndpoint.SubEndpointWithoutOnOpenAndOnMessage.class); + }) + .setExpectedException(WebSocketServerException.class); + + @Test + void verifyThatSubEndpointWithoutOnOpenOrOnMessageFailsToDeploy() { + + } + + @WebSocket(path = "/ws") + public static class ParentEndpoint { + + @OnMessage + public void onMessage(String message) { + // Ignored. + } + + @WebSocket(path = "/sub") + public static class SubEndpointWithoutOnOpenAndOnMessage { + + @OnClose + public void onClose() { + // Ignored. + } + + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/NoOnOpenOrOnMessageTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/NoOnOpenOrOnMessageTest.java new file mode 100644 index 0000000000000..7133c3d1cae60 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/NoOnOpenOrOnMessageTest.java @@ -0,0 +1,36 @@ +package io.quarkus.websockets.next.test.endpoints; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.websockets.next.OnClose; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketServerException; + +public class NoOnOpenOrOnMessageTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(NoOnOpenOrOnMessage.class); + }) + .setExpectedException(WebSocketServerException.class); + + @Test + void verifyThatEndpointWithoutOnMessageOrOnOpenFailsToDeploy() { + + } + + @WebSocket(path = "/ws") + public static class NoOnOpenOrOnMessage { + + // Invalid endpoint, must have at least one @OnOpen or @OnMessage method. + + @OnClose + public void onClose() { + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnCloseInSubEndpointTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnCloseInSubEndpointTest.java new file mode 100644 index 0000000000000..b42bc821913d0 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnCloseInSubEndpointTest.java @@ -0,0 +1,56 @@ +package io.quarkus.websockets.next.test.endpoints; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.websockets.next.OnClose; +import io.quarkus.websockets.next.OnMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketServerException; + +public class TooManyOnCloseInSubEndpointTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(ParentEndpoint.class, ParentEndpoint.SubEndpointWithTooManyOnClose.class); + }) + .setExpectedException(WebSocketServerException.class); + + @Test + void verifyThatSubEndpointWithoutTooManyOnCloseFailsToDeploy() { + + } + + @WebSocket(path = "/ws") + public static class ParentEndpoint { + + @OnMessage + public void onMessage(String message) { + // Ignored. + } + + @WebSocket(path = "/sub") + public static class SubEndpointWithTooManyOnClose { + + @OnMessage + public void onMessage(String message) { + // Ignored. + } + + @OnClose + public void onClose() { + // Ignored. + } + + @OnClose + public void onClose2() { + // Ignored. + } + + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnCloseTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnCloseTest.java new file mode 100644 index 0000000000000..275c70c6230b7 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnCloseTest.java @@ -0,0 +1,41 @@ +package io.quarkus.websockets.next.test.endpoints; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.websockets.next.OnClose; +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketServerException; + +public class TooManyOnCloseTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(TooManyOnClose.class); + }) + .setExpectedException(WebSocketServerException.class); + + @Test + void verifyThatEndpointWithMultipleOnCloseMethodsFailsToDeploy() { + + } + + @WebSocket(path = "/ws") + public static class TooManyOnClose { + @OnOpen + public void onOpen() { + } + + @OnClose + public void onClose() { + } + + @OnClose + public void onClose2() { + } + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnMessageInSubEndpointTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnMessageInSubEndpointTest.java new file mode 100644 index 0000000000000..055112d15b002 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnMessageInSubEndpointTest.java @@ -0,0 +1,49 @@ +package io.quarkus.websockets.next.test.endpoints; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.websockets.next.OnMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketServerException; + +public class TooManyOnMessageInSubEndpointTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(ParentEndpoint.class, ParentEndpoint.SubEndpointWithTooManyOnMessage.class); + }) + .setExpectedException(WebSocketServerException.class); + + @Test + void verifyThatSubEndpointWithoutTooManyOnMessageFailsToDeploy() { + + } + + @WebSocket(path = "/ws") + public static class ParentEndpoint { + + @OnMessage + public void onMessage(String message) { + // Ignored. + } + + @WebSocket(path = "/sub") + public static class SubEndpointWithTooManyOnMessage { + + @OnMessage + public void onMessage(String message) { + // Ignored. + } + + @OnMessage + public void onMessage2(String message) { + // Ignored. + } + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnMessageTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnMessageTest.java new file mode 100644 index 0000000000000..cfeba5c20456f --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnMessageTest.java @@ -0,0 +1,36 @@ +package io.quarkus.websockets.next.test.endpoints; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.websockets.next.OnMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketServerException; + +public class TooManyOnMessageTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(TooManyOnMessage.class); + }) + .setExpectedException(WebSocketServerException.class); + + @Test + void verifyThatEndpointWithMultipleOnMessageMethodsFailsToDeploy() { + + } + + @WebSocket(path = "/ws") + public static class TooManyOnMessage { + @OnMessage + public void onMessage(String message) { + } + + @OnMessage + public void onMessage2(String message) { + } + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnOpenInSubEndpointTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnOpenInSubEndpointTest.java new file mode 100644 index 0000000000000..db96dccea7d46 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnOpenInSubEndpointTest.java @@ -0,0 +1,56 @@ +package io.quarkus.websockets.next.test.endpoints; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.websockets.next.OnMessage; +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketServerException; + +public class TooManyOnOpenInSubEndpointTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(ParentEndpoint.class, ParentEndpoint.SubEndpointWithTooManyOnOpen.class); + }) + .setExpectedException(WebSocketServerException.class); + + @Test + void verifyThatSubEndpointWithoutTooManyOnOpenFailsToDeploy() { + + } + + @WebSocket(path = "/ws") + public static class ParentEndpoint { + + @OnMessage + public void onMessage(String message) { + // Ignored. + } + + @WebSocket(path = "/sub") + public static class SubEndpointWithTooManyOnOpen { + + @OnMessage + public void onMessage(String message) { + // Ignored. + } + + @OnOpen + public void onOpen() { + // Ignored. + } + + @OnOpen + public void onOpen2() { + // Ignored. + } + + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnOpenTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnOpenTest.java new file mode 100644 index 0000000000000..5fd41b498f817 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnOpenTest.java @@ -0,0 +1,36 @@ +package io.quarkus.websockets.next.test.endpoints; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketServerException; + +public class TooManyOnOpenTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(TooManyOnOpen.class); + }) + .setExpectedException(WebSocketServerException.class); + + @Test + void verifyThatEndpointWithMultipleOnOpenMethodsFailsToDeploy() { + + } + + @WebSocket(path = "/ws") + public static class TooManyOnOpen { + @OnOpen + public void onOpen() { + } + + @OnOpen + public void onOpen2() { + } + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/executionmode/ConcurrentExecutionModeTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/executionmode/ConcurrentExecutionModeTest.java new file mode 100644 index 0000000000000..7591417786dbe --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/executionmode/ConcurrentExecutionModeTest.java @@ -0,0 +1,67 @@ +package io.quarkus.websockets.next.test.executionmode; + +import static io.quarkus.websockets.next.WebSocket.ExecutionMode.CONCURRENT; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.net.URI; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.vertx.core.Vertx; + +public class ConcurrentExecutionModeTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(Sim.class, WSClient.class); + }); + + @Inject + Vertx vertx; + + @TestHTTPResource("sim") + URI simUri; + + @Test + void testSimultaneousExecution() { + WSClient client = WSClient.create(vertx).connect(simUri); + client.send("1"); + client.send("2"); + client.send("3"); + client.send("4"); + client.waitForMessages(4); + for (int i = 0; i < 4; i++) { + assertEquals("ok", client.getMessages().get(i).toString()); + } + } + + @WebSocket(path = "/sim", executionMode = CONCURRENT) + public static class Sim { + + private final CountDownLatch latch = new CountDownLatch(4); + + @OnMessage + String process(String message) throws InterruptedException { + latch.countDown(); + // Now wait for other messages to arrive + if (latch.await(10, TimeUnit.SECONDS)) { + return "ok"; + } else { + return "" + latch.getCount(); + } + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/handshake/HandshakeRequestTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/handshake/HandshakeRequestTest.java new file mode 100644 index 0000000000000..794348ab39bf4 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/handshake/HandshakeRequestTest.java @@ -0,0 +1,75 @@ +package io.quarkus.websockets.next.test.handshake; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.net.URI; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketServerConnection; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.vertx.core.Vertx; +import io.vertx.core.http.WebSocketConnectOptions; +import io.vertx.core.json.JsonObject; + +public class HandshakeRequestTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(Head.class, WSClient.class); + }); + + @Inject + Vertx vertx; + + @TestHTTPResource("/") + URI baseUri; + + @Test + void testHandshake() { + String header = "fool"; + String query = "name=Lu"; + WSClient client = WSClient.create(vertx).connect(new WebSocketConnectOptions().addHeader("X-Test", header), + WSClient.toWS(baseUri, "/head?" + query)); + JsonObject reply = client.sendAndAwaitReply("1").toJsonObject(); + assertEquals(header, reply.getString("header")); + assertEquals(header, reply.getJsonObject("headers").getString("X-Test".toLowerCase()), + reply.getJsonObject("headers").toString()); + assertEquals(baseUri.getScheme(), reply.getString("scheme")); + assertEquals(baseUri.getHost(), reply.getString("host")); + assertEquals(baseUri.getPort(), reply.getInteger("port")); + assertEquals("/head", reply.getString("path")); + assertEquals(query, reply.getString("query")); + } + + @WebSocket(path = "/head") + public static class Head { + + @Inject + WebSocketServerConnection connection; + + @OnMessage + JsonObject process(String message) throws InterruptedException { + JsonObject headers = new JsonObject(); + connection.handshakeRequest().headers().forEach((k, v) -> headers.put(k, v.get(0))); + return new JsonObject() + .put("header", connection.handshakeRequest().header("X-Test")) + .put("headers", headers) + .put("scheme", connection.handshakeRequest().scheme()) + .put("host", connection.handshakeRequest().host()) + .put("port", connection.handshakeRequest().port()) + .put("path", connection.handshakeRequest().path()) + .put("query", connection.handshakeRequest().query()); + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/requestcontext/RequestContextListener.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/requestcontext/RequestContextListener.java new file mode 100644 index 0000000000000..4d7e7b10eb1d1 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/requestcontext/RequestContextListener.java @@ -0,0 +1,40 @@ +package io.quarkus.websockets.next.test.requestcontext; + +import java.lang.annotation.Annotation; +import java.util.List; +import java.util.Set; +import java.util.concurrent.CopyOnWriteArrayList; + +import jakarta.enterprise.context.BeforeDestroyed; +import jakarta.enterprise.context.Destroyed; +import jakarta.enterprise.context.Initialized; +import jakarta.enterprise.context.RequestScoped; +import jakarta.enterprise.event.Observes; +import jakarta.enterprise.inject.spi.EventMetadata; +import jakarta.inject.Singleton; + +@Singleton +public class RequestContextListener { + + final List events = new CopyOnWriteArrayList<>(); + + void clear() { + events.clear(); + } + + void init(@Observes @Initialized(RequestScoped.class) Object event, EventMetadata metadata) { + events.add(new ContextEvent(metadata.getQualifiers(), event.toString())); + } + + void beforeDestroy(@Observes @BeforeDestroyed(RequestScoped.class) Object event, EventMetadata metadata) { + events.add(new ContextEvent(metadata.getQualifiers(), event.toString())); + } + + void destroy(@Observes @Destroyed(RequestScoped.class) Object event, EventMetadata metadata) { + events.add(new ContextEvent(metadata.getQualifiers(), event.toString())); + } + + record ContextEvent(Set qualifiers, String payload) { + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/requestcontext/RequestContextTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/requestcontext/RequestContextTest.java new file mode 100644 index 0000000000000..dd44d4c5c8dc9 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/requestcontext/RequestContextTest.java @@ -0,0 +1,109 @@ +package io.quarkus.websockets.next.test.requestcontext; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.net.URI; +import java.util.concurrent.TimeUnit; + +import jakarta.enterprise.context.BeforeDestroyed; +import jakarta.enterprise.context.Destroyed; +import jakarta.enterprise.context.Initialized; +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.smallrye.mutiny.Uni; +import io.vertx.core.Vertx; + +public class RequestContextTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(AppendBlocking.class, WSClient.class, RequestScopedBean.class, RequestContextListener.class); + }) + // Disable SR Context Propagation for ArC, otherwise there could be a mess in context lifecycle events + .overrideConfigKey("quarkus.arc.context-propagation.enabled", "false"); + + @Inject + Vertx vertx; + + @TestHTTPResource("append") + URI appendUri; + + @TestHTTPResource("append-blocking") + URI appendBlockingUri; + + @Inject + RequestContextListener listener; + + @Test + void testRequestContext() throws InterruptedException { + assertRequestContext(appendUri); + } + + @Test + void testRequestContextBlocking() throws InterruptedException { + assertRequestContext(appendBlockingUri); + } + + private void assertRequestContext(URI testUri) throws InterruptedException { + // Remove all events that could be fired due to startup observers + listener.clear(); + RequestScopedBean.COUNTER.set(0); + + WSClient client = WSClient.create(vertx).connect(testUri); + client.send("foo"); + client.send("bar"); + client.send("baz"); + client.waitForMessages(3); + assertEquals("foo:1", client.getMessages().get(0).toString()); + assertEquals("bar:2", client.getMessages().get(1).toString()); + assertEquals("baz:3", client.getMessages().get(2).toString()); + client.disconnect(); + assertTrue(RequestScopedBean.DESTROYED_LATCH.await(5, TimeUnit.SECONDS), + "Latch count: " + RequestScopedBean.DESTROYED_LATCH.getCount()); + assertEquals(9, listener.events.size()); + assertTrue(listener.events.get(0).qualifiers().contains(Initialized.Literal.REQUEST)); + assertTrue(listener.events.get(1).qualifiers().contains(BeforeDestroyed.Literal.REQUEST)); + assertTrue(listener.events.get(2).qualifiers().contains(Destroyed.Literal.REQUEST)); + assertTrue(listener.events.get(3).qualifiers().contains(Initialized.Literal.REQUEST)); + assertTrue(listener.events.get(4).qualifiers().contains(BeforeDestroyed.Literal.REQUEST)); + assertTrue(listener.events.get(5).qualifiers().contains(Destroyed.Literal.REQUEST)); + assertTrue(listener.events.get(6).qualifiers().contains(Initialized.Literal.REQUEST)); + assertTrue(listener.events.get(7).qualifiers().contains(BeforeDestroyed.Literal.REQUEST)); + assertTrue(listener.events.get(8).qualifiers().contains(Destroyed.Literal.REQUEST)); + } + + @WebSocket(path = "/append-blocking") + public static class AppendBlocking { + + @Inject + RequestScopedBean bean; + + @OnMessage + String process(String message) throws InterruptedException { + return bean.appendId(message); + } + } + + @WebSocket(path = "/append") + public static class Append { + + @Inject + RequestScopedBean bean; + + @OnMessage + Uni process(String message) throws InterruptedException { + return Uni.createFrom().item(() -> bean.appendId(message)); + } + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/requestcontext/RequestScopedBean.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/requestcontext/RequestScopedBean.java new file mode 100644 index 0000000000000..8d816d6c14d7c --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/requestcontext/RequestScopedBean.java @@ -0,0 +1,32 @@ +package io.quarkus.websockets.next.test.requestcontext; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicInteger; + +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; +import jakarta.enterprise.context.RequestScoped; + +@RequestScoped +public class RequestScopedBean { + + static final CountDownLatch DESTROYED_LATCH = new CountDownLatch(3); + static final AtomicInteger COUNTER = new AtomicInteger(); + + private int id; + + @PostConstruct + void init() { + id = COUNTER.incrementAndGet(); + } + + public String appendId(String message) { + return message + ":" + id; + } + + @PreDestroy + void destroy() { + DESTROYED_LATCH.countDown(); + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/sessioncontext/SessionContextListener.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/sessioncontext/SessionContextListener.java new file mode 100644 index 0000000000000..325dbdfd15824 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/sessioncontext/SessionContextListener.java @@ -0,0 +1,39 @@ +package io.quarkus.websockets.next.test.sessioncontext; + +import java.lang.annotation.Annotation; +import java.util.List; +import java.util.Set; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; + +import jakarta.enterprise.context.BeforeDestroyed; +import jakarta.enterprise.context.Destroyed; +import jakarta.enterprise.context.Initialized; +import jakarta.enterprise.context.SessionScoped; +import jakarta.enterprise.event.Observes; +import jakarta.enterprise.inject.spi.EventMetadata; +import jakarta.inject.Singleton; + +@Singleton +public class SessionContextListener { + + final CountDownLatch destroyLatch = new CountDownLatch(1); + final List events = new CopyOnWriteArrayList<>(); + + void init(@Observes @Initialized(SessionScoped.class) Object event, EventMetadata metadata) { + events.add(new ContextEvent(metadata.getQualifiers(), event.toString())); + } + + void beforeDestroy(@Observes @BeforeDestroyed(SessionScoped.class) Object event, EventMetadata metadata) { + events.add(new ContextEvent(metadata.getQualifiers(), event.toString())); + } + + void destroy(@Observes @Destroyed(SessionScoped.class) Object event, EventMetadata metadata) { + events.add(new ContextEvent(metadata.getQualifiers(), event.toString())); + destroyLatch.countDown(); + } + + record ContextEvent(Set qualifiers, String payload) { + }; + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/sessioncontext/SessionContextTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/sessioncontext/SessionContextTest.java new file mode 100644 index 0000000000000..c855c26452a1e --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/sessioncontext/SessionContextTest.java @@ -0,0 +1,74 @@ +package io.quarkus.websockets.next.test.sessioncontext; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.net.URI; +import java.util.concurrent.TimeUnit; + +import jakarta.enterprise.context.BeforeDestroyed; +import jakarta.enterprise.context.Destroyed; +import jakarta.enterprise.context.Initialized; +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.vertx.core.Vertx; + +public class SessionContextTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(Append.class, WSClient.class, SessionScopedBean.class, SessionContextListener.class); + }); + + @Inject + Vertx vertx; + + @TestHTTPResource("append") + URI appendUri; + + @Inject + SessionContextListener listener; + + @Test + void testSessionContext() throws InterruptedException { + WSClient client = WSClient.create(vertx).connect(appendUri); + client.send("foo"); + client.send("bar"); + client.send("baz"); + client.waitForMessages(3); + assertEquals("foo", client.getMessages().get(0).toString()); + assertEquals("foobar", client.getMessages().get(1).toString()); + assertEquals("foobarbaz", client.getMessages().get(2).toString()); + client.disconnect(); + assertTrue(listener.destroyLatch.await(5, TimeUnit.SECONDS)); + assertTrue(SessionScopedBean.DESTROYED.get()); + assertEquals(3, listener.events.size()); + assertEquals(listener.events.get(0).payload(), listener.events.get(1).payload()); + assertEquals(listener.events.get(1).payload(), listener.events.get(2).payload()); + assertTrue(listener.events.get(0).qualifiers().contains(Initialized.Literal.SESSION)); + assertTrue(listener.events.get(1).qualifiers().contains(BeforeDestroyed.Literal.SESSION)); + assertTrue(listener.events.get(2).qualifiers().contains(Destroyed.Literal.SESSION)); + } + + @WebSocket(path = "/append") + public static class Append { + + @Inject + SessionScopedBean bean; + + @OnMessage + String process(String message) throws InterruptedException { + return bean.appendAndGet(message); + } + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/sessioncontext/SessionScopedBean.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/sessioncontext/SessionScopedBean.java new file mode 100644 index 0000000000000..0fc8d5a0541a5 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/sessioncontext/SessionScopedBean.java @@ -0,0 +1,25 @@ +package io.quarkus.websockets.next.test.sessioncontext; + +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +import jakarta.annotation.PreDestroy; +import jakarta.enterprise.context.SessionScoped; + +@SessionScoped +public class SessionScopedBean { + + static final AtomicBoolean DESTROYED = new AtomicBoolean(); + + private final AtomicReference lastMessage = new AtomicReference<>(""); + + public String appendAndGet(String message) { + return lastMessage.accumulateAndGet(message, (s1, s2) -> s1 + s2); + } + + @PreDestroy + void destroy() { + DESTROYED.set(true); + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/signatures/SignatureConsumingMultiTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/signatures/SignatureConsumingMultiTest.java new file mode 100644 index 0000000000000..1f7f1819e56aa --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/signatures/SignatureConsumingMultiTest.java @@ -0,0 +1,112 @@ +package io.quarkus.websockets.next.test.signatures; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; +import static org.junit.jupiter.api.Assertions.assertNotEquals; + +import java.net.URI; +import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import jakarta.annotation.PostConstruct; +import jakarta.enterprise.context.RequestScoped; +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketServerConnection; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.smallrye.common.vertx.VertxContext; +import io.smallrye.mutiny.Multi; +import io.vertx.core.Context; +import io.vertx.core.Vertx; +import io.vertx.core.buffer.Buffer; + +public class SignatureConsumingMultiTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(BiDirectional.class, WSClient.class, RequestScopedBean.class); + }); + + @Inject + Vertx vertx; + + @TestHTTPResource("/") + URI uri; + + @Test + void verifyExecutionOfOnMessageWhenConsumingAndReturningMultis() { + WSClient client = WSClient.create(vertx).connect(WSClient.toWS(uri, "/ws/%s/%d".formatted("bi-directional", 3))); + + for (int i = 0; i < 10; i++) { + client.sendAndAwait("hello" + i); + } + + await().until(() -> client.getMessages().size() == 10); + assertThat(client.getMessages().stream().map(Buffer::toString).collect(Collectors.toList())) + .containsExactlyInAnyOrderElementsOf( + IntStream.range(0, 10).mapToObj(id -> "WS " + 3 + " received: hello" + id) + .collect(Collectors.toList())); + } + + @WebSocket(path = "/ws/bi-directional/{id}") + public static class BiDirectional { + + @Inject + WebSocketServerConnection connection; + + @Inject + RequestScopedBean requestScopedBean; + + volatile Context context = null; + + @OnMessage + Multi process(Multi multi) { + assertThat(Context.isOnEventLoopThread()).isTrue(); + Context context = Vertx.currentContext(); + assertThat(context).isNotNull(); + assertThat(VertxContext.isOnDuplicatedContext()).isTrue(); + String requestScopedBeanId = requestScopedBean.getId(); + + assertThat(connection).isNotNull(); + if (this.context == null) { + this.context = context; + } + + return multi.map(s -> { + assertThat(this.context).isSameAs(Vertx.currentContext()); + // For bi-directional streams we subscribe to the returned Multi during onOpen + // As a result the same duplicated context is used but a new request context is activated/terminated per each message processing + assertNotEquals(requestScopedBeanId, requestScopedBean.getId()); + String id = connection.pathParam("id"); + return "WS " + id + " received: " + s; + }); + } + + } + + @RequestScoped + public static class RequestScopedBean { + + private String id; + + @PostConstruct + void init() { + id = UUID.randomUUID().toString(); + } + + String getId() { + return id; + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/signatures/SignatureTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/signatures/SignatureTest.java new file mode 100644 index 0000000000000..a40d8cdf8d1f3 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/signatures/SignatureTest.java @@ -0,0 +1,123 @@ +package io.quarkus.websockets.next.test.signatures; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.net.URI; +import java.util.stream.Stream; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketServerConnection; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.smallrye.common.vertx.VertxContext; +import io.smallrye.mutiny.Multi; +import io.smallrye.mutiny.Uni; +import io.smallrye.mutiny.infrastructure.Infrastructure; +import io.vertx.core.Context; +import io.vertx.core.Vertx; +import io.vertx.core.buffer.Buffer; + +public class SignatureTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(MethodReturningString.class, UniWs.class, MultiWs.class, WSClient.class); + }); + + @Inject + Vertx vertx; + + @TestHTTPResource("/") + URI uri; + + private static Stream methods() { + return Stream.of( + Arguments.of("string", 1), + Arguments.of("uni", 2), + Arguments.of("multi", 3)); + } + + @ParameterizedTest(name = "{index} Checking the reception of message for method returning {0}") + @MethodSource("methods") + void verifyExecutionOfOnMessage(String path, int id) { + WSClient client = WSClient.create(vertx).connect(WSClient.toWS(uri, "/ws/%s/%d".formatted(path, id))); + Buffer resp = client.sendAndAwaitReply("hello"); + assertThat(resp.toString()).isEqualTo("WS " + id + " received: hello"); + } + + @WebSocket(path = "/ws/string/{id}") + public static class MethodReturningString { + + @Inject + WebSocketServerConnection connection; + + @OnMessage + String process(String message) { + assertThat(Context.isOnEventLoopThread()).isFalse(); + assertThat(Vertx.currentContext()).isNotNull(); + assertThat(VertxContext.isOnDuplicatedContext()).isTrue(); + + assertThat(connection).isNotNull(); + + String id = connection.pathParam("id"); + return "WS " + id + " received: " + message; + } + + } + + @WebSocket(path = "/ws/uni/{id}") + public static class UniWs { + + @Inject + WebSocketServerConnection connection; + + @OnMessage + Uni process(String message) { + assertThat(Context.isOnEventLoopThread()).isTrue(); + Context context = Vertx.currentContext(); + assertThat(context).isNotNull(); + assertThat(VertxContext.isOnDuplicatedContext()).isTrue(); + + assertThat(connection).isNotNull(); + + return Uni.createFrom().item(() -> { + assertThat(context).isSameAs(Vertx.currentContext()); + String id = connection.pathParam("id"); + return "WS " + id + " received: " + message; + }).emitOn(Infrastructure.getDefaultExecutor()); + } + } + + @WebSocket(path = "/ws/multi/{id}") + public static class MultiWs { + + @Inject + WebSocketServerConnection connection; + + @OnMessage + Multi process(String message) { + assertThat(Context.isOnEventLoopThread()).isTrue(); + Context context = Vertx.currentContext(); + assertThat(context).isNotNull(); + assertThat(VertxContext.isOnDuplicatedContext()).isTrue(); + + assertThat(connection).isNotNull(); + + return Multi.createFrom().item(() -> { + assertThat(context).isSameAs(Vertx.currentContext()); + String id = connection.pathParam("id"); + return "WS " + id + " received: " + message; + }).emitOn(Infrastructure.getDefaultExecutor()); + } + } +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/subsocket/Sub.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/subsocket/Sub.java new file mode 100644 index 0000000000000..f4f45f004f043 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/subsocket/Sub.java @@ -0,0 +1,49 @@ +package io.quarkus.websockets.next.test.subsocket; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +import jakarta.inject.Inject; + +import io.quarkus.websockets.next.OnMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketServerConnection; +import io.smallrye.mutiny.Uni; +import io.vertx.core.Context; + +@WebSocket(path = "/sub") +public class Sub { + + @OnMessage + Uni echo(String msg) { + assertTrue(Context.isOnEventLoopThread()); + return Uni.createFrom().item(msg); + } + + @WebSocket(path = "/sub/{id}") + public static class SubSub { + + @Inject + WebSocketServerConnection connection; + + @OnMessage + Uni echo(String msg) { + assertTrue(Context.isOnEventLoopThread()); + return Uni.createFrom().item(connection.pathParam("id") + ":" + msg); + } + + @WebSocket(path = "/sub/{name}") + public static class SubSubSub { + + @Inject + WebSocketServerConnection connection; + + @OnMessage + Uni echo(String msg) { + assertTrue(Context.isOnEventLoopThread()); + return Uni.createFrom().item(connection.pathParam("id") + ":" + connection.pathParam("name") + ":" + msg); + } + + } + + } +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/subsocket/SubWebSocketTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/subsocket/SubWebSocketTest.java new file mode 100644 index 0000000000000..e2663b3813c09 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/subsocket/SubWebSocketTest.java @@ -0,0 +1,69 @@ +package io.quarkus.websockets.next.test.subsocket; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.net.URI; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.TimeUnit; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.test.subsocket.Sub.SubSub; +import io.quarkus.websockets.next.test.subsocket.Sub.SubSub.SubSubSub; +import io.vertx.core.Vertx; +import io.vertx.core.http.WebSocket; +import io.vertx.core.http.WebSocketClient; + +public class SubWebSocketTest { + + @TestHTTPResource("sub") + URI echoUri; + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(Sub.class, SubSub.class, SubSubSub.class); + }); + + @Test + public void testSub() throws Exception { + assertEcho(echoUri, "", "hello", "hello"); + } + + @Test + public void testSubSub() throws Exception { + assertEcho(echoUri, "/sub/1", "hello", "1:hello"); + } + + @Test + public void testSubSubSub() throws Exception { + assertEcho(echoUri, "/sub/1/sub/foo", "hello", "1:foo:hello"); + } + + public void assertEcho(URI testUri, String path, String payload, String expected) throws Exception { + Vertx vertx = Vertx.vertx(); + WebSocketClient client = vertx.createWebSocketClient(); + try { + LinkedBlockingDeque message = new LinkedBlockingDeque<>(); + client + .connect(testUri.getPort(), testUri.getHost(), testUri.getPath() + path) + .onComplete(r -> { + if (r.succeeded()) { + WebSocket ws = r.result(); + ws.textMessageHandler(msg -> { + message.add(msg); + }); + ws.writeTextMessage(payload); + } else { + throw new IllegalStateException(r.cause()); + } + }); + assertEquals(expected, message.poll(10, TimeUnit.SECONDS)); + } finally { + client.close().toCompletionStage().toCompletableFuture().get(); + } + } +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/utils/WSClient.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/utils/WSClient.java new file mode 100644 index 0000000000000..d0fd4fad46c08 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/utils/WSClient.java @@ -0,0 +1,127 @@ +package io.quarkus.websockets.next.test.utils; + +import java.net.URI; +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.atomic.AtomicReference; + +import org.awaitility.Awaitility; + +import io.vertx.core.Future; +import io.vertx.core.Vertx; +import io.vertx.core.buffer.Buffer; +import io.vertx.core.http.WebSocket; +import io.vertx.core.http.WebSocketClient; +import io.vertx.core.http.WebSocketConnectOptions; + +public class WSClient { + + private final WebSocketClient client; + private AtomicReference socket = new AtomicReference<>(); + private List messages = new CopyOnWriteArrayList<>(); + + public WSClient(Vertx vertx) { + this.client = vertx.createWebSocketClient(); + } + + public static WSClient create(Vertx vertx) { + return new WSClient(vertx); + } + + public static URI toWS(URI uri, String path) { + String result = "ws://"; + if (uri.getScheme().equals("https")) { + result = "wss://"; + } + if (path.startsWith("/")) { + path = path.substring(1); + } + result += uri.getHost() + ":" + uri.getPort() + "/" + path; + try { + return new URI(result); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + public WSClient connect(WebSocketConnectOptions options, URI url) { + StringBuilder uri = new StringBuilder(); + uri.append(url.getPath()); + if (url.getQuery() != null) { + uri.append("?").append(url.getQuery()); + } + WebSocket webSocket = await( + client.connect(options.setPort(url.getPort()).setHost(url.getHost()).setURI(uri.toString()))); + var prev = socket.getAndSet(webSocket); + if (prev != null) { + messages.clear(); + await(prev.close()); + } + webSocket.handler(b -> messages.add(b)); + return this; + } + + public WSClient connect(URI url) { + return connect(new WebSocketConnectOptions(), url); + } + + public Future send(String message) { + return socket.get().writeTextMessage(message); + } + + public void sendAndAwait(String message) { + await(send(message)); + } + + public Future send(Buffer message) { + return socket.get().writeBinaryMessage(message); + } + + public void sendAndAwait(Buffer message) { + await(send(message)); + } + + public List getMessages() { + return messages; + } + + public Buffer getLastMessage() { + if (messages.isEmpty()) { + return null; + } + return messages.get(messages.size() - 1); + } + + public Buffer waitForNextMessage() { + var c = messages.size(); + Awaitility.await().until(() -> messages.size() > c); + return messages.get(c); + } + + public void waitForMessages(int count) { + Awaitility.await().until(() -> messages.size() >= count); + } + + public void disconnect() { + WebSocket current = socket.getAndSet(null); + if (current != null) { + await(current.close()); + } + messages.clear(); + } + + private T await(Future future) { + return future.toCompletionStage().toCompletableFuture().join(); + } + + public Buffer sendAndAwaitReply(String message) { + var c = messages.size(); + sendAndAwait(message); + Awaitility.await().until(() -> messages.size() > c); + return messages.get(c); + } + + public boolean isClosed() { + return socket.get().isClosed(); + } +} diff --git a/extensions/websockets-next/server/pom.xml b/extensions/websockets-next/server/pom.xml new file mode 100644 index 0000000000000..6b8892d9e8458 --- /dev/null +++ b/extensions/websockets-next/server/pom.xml @@ -0,0 +1,21 @@ + + + + quarkus-extensions-parent + io.quarkus + 999-SNAPSHOT + ../../pom.xml + + 4.0.0 + + quarkus-websockets-next-parent + Quarkus - WebSockets Next + pom + + deployment + runtime + + + diff --git a/extensions/websockets-next/server/runtime/pom.xml b/extensions/websockets-next/server/runtime/pom.xml new file mode 100644 index 0000000000000..db665dca03aea --- /dev/null +++ b/extensions/websockets-next/server/runtime/pom.xml @@ -0,0 +1,55 @@ + + + + quarkus-websockets-next-parent + io.quarkus + 999-SNAPSHOT + + 4.0.0 + + quarkus-websockets-next + Quarkus - WebSockets Next - Runtime + + + + io.quarkus + quarkus-core + + + io.quarkus + quarkus-vertx-http + + + io.quarkus + quarkus-jackson + + + + + + + io.quarkus + quarkus-extension-maven-plugin + + + + + + maven-compiler-plugin + + + + io.quarkus + quarkus-extension-processor + ${project.version} + + + + + + + diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/BinaryMessage.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/BinaryMessage.java new file mode 100644 index 0000000000000..0d1a069ce2eae --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/BinaryMessage.java @@ -0,0 +1,58 @@ +package io.quarkus.websockets.next; + +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import io.smallrye.common.annotation.Experimental; + +/** + * The annotated method consumes/produces binary messages. + *

+ * A binary message is always represented as a {@link io.vertx.core.buffer.Buffer}. Therefore, the following conversion rules + * apply. The types listed below are handled specifically. For all other types a {@link BinaryMessageCodec} is used to encode + * and decode input and + * output messages. By default, the first input codec that supports the message type is used; codecs with higher priority go + * first. However, a specific codec can be selected with {@link #inputCodec()} and {@link #outputCodec()}. + * + *

    + *
  • {@code java.lang.Buffer} is used as is,
  • + *
  • {@code byte[]} is encoded with {@link io.vertx.core.buffer.Buffer#buffer(byte[])} and decoded with + * {@link io.vertx.core.buffer.Buffer#getBytes()},
  • + *
  • {@code java.lang.String} is encoded with {@link io.vertx.core.buffer.Buffer#buffer(String)} and decoded with + * {@link io.vertx.core.buffer.Buffer#toString()},
  • + *
  • {@code io.vertx.core.json.JsonObject} is encoded with {@link io.vertx.core.json.JsonObject#toBuffer()} and decoded with + * {@link io.vertx.core.json.JsonObject#JsonObject(io.vertx.core.buffer.Buffer)}.
  • + *
  • {@code io.vertx.core.json.JsonArray} is encoded with {@link io.vertx.core.json.JsonArray#toBuffer()} and decoded with + * {@link io.vertx.core.json.JsonArray#JsonArray(io.vertx.core.buffer.Buffer)}.
  • + *

    + * + * @see BinaryMessageCodec + */ +@Retention(RUNTIME) +@Target(METHOD) +@Experimental("This API is experimental and may change in the future") +public @interface BinaryMessage { + + /** + * The codec used for input messages. + *

    + * By default, the first codec that supports the message type is used; codecs with higher priority go first. + *

    + * Note that, if specified, the codec is also used for output messages unless {@link #outputCodec()} returns a non-default + * value. + */ + @SuppressWarnings("rawtypes") + Class inputCodec() default BinaryMessageCodec.class; + + /** + * The codec used for output messages. + *

    + * By default, the same codec as for the input message is used. + */ + @SuppressWarnings("rawtypes") + Class outputCodec() default BinaryMessageCodec.class; + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/BinaryMessageCodec.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/BinaryMessageCodec.java new file mode 100644 index 0000000000000..2755e92239ed3 --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/BinaryMessageCodec.java @@ -0,0 +1,15 @@ +package io.quarkus.websockets.next; + +import io.smallrye.common.annotation.Experimental; +import io.vertx.core.buffer.Buffer; + +/** + * Used to encode and decode binary messages. + * + * @param + * @see BinaryMessage + */ +@Experimental("This API is experimental and may change in the future") +public interface BinaryMessageCodec extends MessageCodec { + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/BlockingSender.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/BlockingSender.java new file mode 100644 index 0000000000000..67f0d1a471def --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/BlockingSender.java @@ -0,0 +1,50 @@ +package io.quarkus.websockets.next; + +import io.smallrye.common.annotation.Experimental; +import io.vertx.core.buffer.Buffer; + +/** + * Sends a message to the connected WebSocket client and waits for the completion. + *

    + * Note that blocking sender methods should never be called on an event loop thread. + */ +@Experimental("This API is experimental and may change in the future") +public interface BlockingSender extends Sender { + + /** + * Sends a text message and waits for the completion. + * + * @param message + */ + default void sendTextAndAwait(String message) { + sendText(message).await().indefinitely(); + } + + /** + * Sends a text message and waits for the completion. + * + * @param + * @param message + */ + default void sendTextAndAwait(M message) { + sendText(message).await().indefinitely(); + } + + /** + * Sends a binary message and waits for the completion. + * + * @param message + */ + default void sendBinaryAndAwait(Buffer message) { + sendBinary(message).await().indefinitely(); + } + + /** + * Sends a binary message and waits for the completion. + * + * @param message + */ + default void sendBinaryAndAwait(byte[] message) { + sendBinary(message).await().indefinitely(); + } +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/MessageCodec.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/MessageCodec.java new file mode 100644 index 0000000000000..eaf84d311e8c4 --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/MessageCodec.java @@ -0,0 +1,56 @@ +package io.quarkus.websockets.next; + +import java.lang.reflect.Type; + +import io.smallrye.common.annotation.Experimental; + +/** + * Used to encode and decode messages. + * + *

    Special types of messages

    + * Some types of messages bypass the encoding/decoding process: + *
      + *
    • {@code java.lang.Buffer},
    • + *
    • {@code byte[]},
    • + *
    • {@code java.lang.String},
    • + *
    • {@code io.vertx.core.json.JsonObject}.
    • + *
    • {@code io.vertx.core.json.JsonArray}.
    • + *
    + * The encoding/decoding details are described in {@link BinaryMessage} and {@link TextMessage}. + * + *

    CDI beans

    + * Implementation classes must be CDI beans. Qualifiers are ignored. {@link jakarta.enterprise.context.Dependent} beans are + * reused during encoding/decoding. + * + *

    Lifecycle and concurrency

    + * Codecs are shared accross all WebSocket connections. Therefore, implementations should be either stateless or thread-safe. + * + * @param + * @param + */ +@Experimental("This API is experimental and may change in the future") +public interface MessageCodec { + + /** + * + * @param type the type to handle, must not be {@code null} + * @return {@code true} if this codec can encode/decode the provided type, {@code false} otherwise + */ + boolean supports(Type type); + + /** + * + * @param value the value to encode, must not be {@code null} + * @return the encoded representation of the value + */ + MESSAGE encode(T value); + + /** + * + * @param type the type of the object to decode, must not be {@code null} + * @param value the value to decode, must not be {@code null} + * @return the decoded representation of the value + */ + T decode(Type type, MESSAGE value); + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnClose.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnClose.java new file mode 100644 index 0000000000000..b370a2a747f13 --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnClose.java @@ -0,0 +1,22 @@ +package io.quarkus.websockets.next; + +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import io.smallrye.common.annotation.Experimental; + +/** + * A method of an {@link WebSocket} endpoint annotated with this annotation is invoked when the client disconnects from the + * socket. + *

    + * An endpoint may only declare one method annotated with this annotation. + */ +@Retention(RUNTIME) +@Target(METHOD) +@Experimental("This API is experimental and may change in the future") +public @interface OnClose { + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnMessage.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnMessage.java new file mode 100644 index 0000000000000..7079ab430ea32 --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnMessage.java @@ -0,0 +1,28 @@ +package io.quarkus.websockets.next; + +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import io.smallrye.common.annotation.Experimental; + +/** + * A method of an {@link WebSocket} endpoint annotated with this annotation is invoked when an incoming message is received. + *

    + * An endpoint may only declare one method annotated with this annotation. + */ +@Retention(RUNTIME) +@Target(METHOD) +@Experimental("This API is experimental and may change in the future") +public @interface OnMessage { + + /** + * + * @return {@code true} if all the connected clients should receive the objects emitted by the annotated method + * @see WebSocketServerConnection#broadcast() + */ + public boolean broadcast() default false; + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnOpen.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnOpen.java new file mode 100644 index 0000000000000..a16d1ec35feb8 --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnOpen.java @@ -0,0 +1,28 @@ +package io.quarkus.websockets.next; + +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import io.smallrye.common.annotation.Experimental; + +/** + * A method of an {@link WebSocket} endpoint annotated with this annotation is invoked when the client connects to a web socket + * endpoint. + *

    + * An endpoint may only declare one method annotated with this annotation. + */ +@Retention(RUNTIME) +@Target(METHOD) +@Experimental("This API is experimental and may change in the future") +public @interface OnOpen { + + /** + * @return {@code true} if all the connected clients should receive the objects emitted by the annotated method + * @see WebSocketServerConnection#broadcast() + */ + public boolean broadcast() default false; + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/Sender.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/Sender.java new file mode 100644 index 0000000000000..b09795f239c4e --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/Sender.java @@ -0,0 +1,52 @@ +package io.quarkus.websockets.next; + +import io.smallrye.common.annotation.CheckReturnValue; +import io.smallrye.common.annotation.Experimental; +import io.smallrye.mutiny.Uni; +import io.vertx.core.buffer.Buffer; + +/** + * Sends a message to the connected WebSocket client. + */ +@Experimental("This API is experimental and may change in the future") +public interface Sender { + + /** + * Send a text message. + * + * @param message + * @return a new {@link Uni} with a {@code null} item + */ + @CheckReturnValue + Uni sendText(String message); + + /** + * Send a text message. + * + * @param + * @param message + * @return a new {@link Uni} with a {@code null} item + */ + @CheckReturnValue + Uni sendText(M message); + + /** + * Send a binary message. + * + * @param message + * @return a new {@link Uni} with a {@code null} item + */ + @CheckReturnValue + Uni sendBinary(Buffer message); + + /** + * Send a binary message. + * + * @param message + */ + @CheckReturnValue + default Uni sendBinary(byte[] message) { + return sendBinary(Buffer.buffer(message)); + } + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/TextMessage.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/TextMessage.java new file mode 100644 index 0000000000000..25599a36e587b --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/TextMessage.java @@ -0,0 +1,56 @@ +package io.quarkus.websockets.next; + +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import io.smallrye.common.annotation.Experimental; + +/** + * The annotated method consumes/produces text messages. + *

    + * A text message is always represented as a {@link String}. Therefore, the following conversion rules apply. The types listed + * below are handled specifically. For all other types a {@link TextMessageCodec} is used to encode and decode input and + * output messages. By default, the first input codec that supports the message type is used; codecs with higher priority go + * first. However, a specific codec can be selected with {@link #inputCodec()} and {@link #outputCodec()}. + * + *

      + *
    • {@code java.lang.String} is used as is,
    • + *
    • {@code io.vertx.core.json.JsonObject} is encoded with {@link io.vertx.core.json.JsonObject#encode()} and decoded with + * {@link io.vertx.core.json.JsonObject#JsonObject(String))}.
    • + *
    • {@code io.vertx.core.json.JsonArray} is encoded with {@link io.vertx.core.json.JsonArray#encode()} and decoded with + * {@link io.vertx.core.json.JsonArray#JsonArray(String))}.
    • + *
    • {@code java.lang.Buffer} is encoded with {@link io.vertx.core.buffer.Buffer#toString()} and decoded with + * {@link io.vertx.core.buffer.Buffer#buffer(String)},
    • + *
    • {@code byte[]} is first converted to {@link io.vertx.core.buffer.Buffer} and then converted as defined above,
    • + *

      + * + * @see TextMessageCodec + */ +@Retention(RUNTIME) +@Target(METHOD) +@Experimental("This API is experimental and may change in the future") +public @interface TextMessage { + + /** + * The codec used for input messages. + *

      + * By default, the first codec that supports the message type is used; codecs with higher priority go first. + *

      + * Note that, if specified, the codec is also used for output messages unless {@link #outputCodec()} returns a non-default + * value. + */ + @SuppressWarnings("rawtypes") + Class inputCodec() default TextMessageCodec.class; + + /** + * The codec used for output messages. + *

      + * By default, the same codec as for the input message is used. + */ + @SuppressWarnings("rawtypes") + Class outputCodec() default TextMessageCodec.class; + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/TextMessageCodec.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/TextMessageCodec.java new file mode 100644 index 0000000000000..1146f686c3553 --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/TextMessageCodec.java @@ -0,0 +1,14 @@ +package io.quarkus.websockets.next; + +import io.smallrye.common.annotation.Experimental; + +/** + * Used to encode and decode text messages. + * + * @param + * @see TextMessage + */ +@Experimental("This API is experimental and may change in the future") +public interface TextMessageCodec extends MessageCodec { + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocket.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocket.java new file mode 100644 index 0000000000000..78f1849b6e2d3 --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocket.java @@ -0,0 +1,55 @@ +package io.quarkus.websockets.next; + +import static java.lang.annotation.ElementType.TYPE; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import io.smallrye.common.annotation.Experimental; + +/** + * Denotes a WebSocket endpoint. + */ +@Retention(RUNTIME) +@Target(TYPE) +@Experimental("This API is experimental and may change in the future") +public @interface WebSocket { + + /** + * The path of the endpoint. + *

      + * It is possible to match path parameters. The placeholder of a path parameter consists of the parameter name surrounded by + * curly brackets. The actual value of a path parameter can be obtained using + * {@link WebSocketServerConnection#pathParam(String)}. For example, the path /foo/{bar} defines the path + * parameter {@code bar}. + * + * @see WebSocketServerConnection#pathParam(String) + */ + public String path(); + + /** + * The execution mode used to process incoming messages for a specific connection. + */ + public ExecutionMode executionMode() default ExecutionMode.SERIAL; + + /** + * Defines the execution mode used to process incoming messages for a specific connection. + * + * @see WebSocketServerConnection + */ + enum ExecutionMode { + + /** + * Messages are processed serially, ordering is guaranteed. + */ + SERIAL, + + /** + * Messages are processed concurrently, there are no ordering guarantees. + */ + CONCURRENT, + + } + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketServerConnection.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketServerConnection.java new file mode 100644 index 0000000000000..78e631307e27a --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketServerConnection.java @@ -0,0 +1,171 @@ +package io.quarkus.websockets.next; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Predicate; + +import io.smallrye.common.annotation.CheckReturnValue; +import io.smallrye.common.annotation.Experimental; +import io.smallrye.mutiny.Uni; + +/** + * This interface represents a connection from a client to a specific {@link WebSocket} endpoint on the server. + *

      + * Quarkus provides a built-in CDI bean of type {@code WebSocketServerConnection} that can be injected in a {@link WebSocket} + * endpoint and used to interact with the connected client, or all clients connected to the endpoint respectively + * (broadcasting). + *

      + * Specifically, it is possible to send messages using blocking and non-blocking methods, declared on + * {@link BlockingSender} and {@link Sender} respectively. + */ +@Experimental("This API is experimental and may change in the future") +public interface WebSocketServerConnection extends Sender, BlockingSender { + + /** + * + * @return the unique identifier assigned to this connection + */ + String id(); + + /** + * + * @param name + * @return the actual value of the path parameter or null + * @see WebSocket#path() + */ + String pathParam(String name); + + /** + * Sends messages to all open clients connected to the same WebSocket endpoint. + * + * @return the broadcast sender + * @see #getOpenConnections() + */ + BroadcastSender broadcast(); + + /** + * Sends messages to all open clients connected to the same WebSocket endpoint and matching the given filter predicate. + * + * @param filter + * @return the broadcast sender + * @see #getOpenConnections() + */ + BroadcastSender broadcast(Predicate filter); + + /** + * The returned set also includes the connection this method is called upon. + * + * @return the set of open connections to the same endpoint + */ + Set getOpenConnections(); + + /** + * @return {@code true} if the HTTP connection is encrypted via SSL/TLS + */ + boolean isSecure(); + + /** + * @return {@code true} if the WebSocket is closed + */ + boolean isClosed(); + + /** + * + * @return {@code true} if the WebSocket is open + */ + default boolean isOpen() { + return !isClosed(); + } + + /** + * Close the connection. + * + * @return a new {@link Uni} with a {@code null} item + */ + @CheckReturnValue + Uni close(); + + /** + * Close the connection. + */ + default void closeAndAwait() { + close().await().indefinitely(); + } + + /** + * + * @return the handshake request + */ + HandshakeRequest handshakeRequest(); + + /** + * Makes it possible to send messages to all clients connected to the same WebSocket endpoint. + * + * @see WebSocketServerConnection#getOpenConnections() + */ + interface BroadcastSender extends Sender, BlockingSender { + + } + + /** + * Provides some useful information about the initial handshake request. + */ + interface HandshakeRequest { + + /** + * The name is case insensitive. + * + * @param name + * @return the first header value for the given header name, or {@code null} + */ + String header(String name); + + /** + * The name is case insensitive. + * + * @param name + * @return an immutable list of header values for the given header name, never {@code null} + */ + List headers(String name); + + /** + * Returned header names are lower case. + * + * @return an immutable map of header names to header values + */ + Map> headers(); + + /** + * + * @return the scheme + */ + String scheme(); + + /** + * + * @return the host + */ + String host(); + + /** + * + * @return the port + */ + int port(); + + /** + * + * @return the path + */ + String path(); + + /** + * + * @return the query string + */ + String query(); + + } + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketServerException.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketServerException.java new file mode 100644 index 0000000000000..c226d78983d47 --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketServerException.java @@ -0,0 +1,22 @@ +package io.quarkus.websockets.next; + +import io.smallrye.common.annotation.Experimental; + +@Experimental("This API is experimental and may change in the future") +public class WebSocketServerException extends RuntimeException { + + private static final long serialVersionUID = 903932032264812404L; + + public WebSocketServerException(String message, Throwable cause) { + super(message, cause); + } + + public WebSocketServerException(String message) { + super(message); + } + + public WebSocketServerException(Throwable cause) { + super(cause); + } + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsRuntimeConfig.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsRuntimeConfig.java new file mode 100644 index 0000000000000..6ef568f6345f7 --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsRuntimeConfig.java @@ -0,0 +1,21 @@ +package io.quarkus.websockets.next; + +import java.time.Duration; +import java.util.Optional; + +import io.quarkus.runtime.annotations.ConfigPhase; +import io.quarkus.runtime.annotations.ConfigRoot; +import io.smallrye.config.ConfigMapping; + +@ConfigMapping(prefix = "quarkus.websockets-next") +@ConfigRoot(phase = ConfigPhase.RUN_TIME) +public interface WebSocketsRuntimeConfig { + + /** + * TODO Not implemented yet. + * + * The default timeout to complete processing of a message. + */ + Optional timeout(); + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/Codecs.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/Codecs.java new file mode 100644 index 0000000000000..4bb5d61c9a8fd --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/Codecs.java @@ -0,0 +1,169 @@ +package io.quarkus.websockets.next.runtime; + +import java.lang.reflect.Type; +import java.util.List; + +import jakarta.inject.Singleton; + +import org.jboss.logging.Logger; + +import io.quarkus.arc.All; +import io.quarkus.websockets.next.BinaryMessageCodec; +import io.quarkus.websockets.next.MessageCodec; +import io.quarkus.websockets.next.TextMessageCodec; +import io.vertx.core.buffer.Buffer; + +@Singleton +public class Codecs { + + private static final Logger LOG = Logger.getLogger(Codecs.class); + + @All + List> textCodecs; + + @All + List> binaryCodecs; + + public Object textDecode(Type type, String value, Class codecBeanClass) { + if (codecBeanClass != null) { + for (TextMessageCodec codec : textCodecs) { + if (codec.getClass().equals(codecBeanClass)) { + if (!codec.supports(type)) { + throw forcedCannotHandle(false, codec, type); + } + try { + return codec.decode(type, value); + } catch (Exception e) { + throw unableToDecode(false, codec, e); + } + } + } + } else { + for (TextMessageCodec codec : textCodecs) { + if (codec.supports(type)) { + try { + return codec.decode(type, value); + } catch (Exception e) { + throw unableToDecode(false, codec, e); + } + } + } + } + + throw noCodec(false, type); + } + + public String textEncode(T message, Class codecBeanClass) { + Class type = message.getClass(); + if (codecBeanClass != null) { + for (TextMessageCodec codec : textCodecs) { + if (codec.getClass().equals(codecBeanClass)) { + if (!codec.supports(type)) { + throw forcedCannotHandle(false, codec, type); + } + try { + return codec.encode(cast(message)); + } catch (Exception e) { + throw unableToEncode(false, codec, e); + } + } + } + } else { + for (TextMessageCodec codec : textCodecs) { + if (codec.supports(type)) { + try { + return codec.encode(cast(message)); + } catch (Exception e) { + throw unableToEncode(false, codec, e); + } + } + } + } + throw noCodec(false, type); + } + + public Object binaryDecode(Type type, Buffer value, Class codecBeanClass) { + if (codecBeanClass != null) { + for (BinaryMessageCodec codec : binaryCodecs) { + if (codec.getClass().equals(codecBeanClass)) { + if (!codec.supports(type)) { + throw forcedCannotHandle(false, codec, type); + } + try { + return codec.decode(type, value); + } catch (Exception e) { + throw unableToDecode(false, codec, e); + } + } + } + } else { + for (BinaryMessageCodec codec : binaryCodecs) { + if (codec.supports(type)) { + try { + return codec.decode(type, value); + } catch (Exception e) { + LOG.errorf(e, "Unable to decode binary message with %s", codec.getClass().getName()); + } + } + } + } + throw noCodec(true, type); + } + + public Buffer binaryEncode(T message, Class codecBeanClass) { + Class type = message.getClass(); + if (codecBeanClass != null) { + for (BinaryMessageCodec codec : binaryCodecs) { + if (codec.getClass().equals(codecBeanClass)) { + if (!codec.supports(type)) { + throw forcedCannotHandle(false, codec, type); + } + try { + return codec.encode(cast(message)); + } catch (Exception e) { + throw unableToEncode(false, codec, e); + } + } + } + } else { + for (BinaryMessageCodec codec : binaryCodecs) { + if (codec.supports(type)) { + try { + return codec.encode(cast(message)); + } catch (Exception e) { + throw unableToEncode(true, codec, e); + } + } + } + } + throw noCodec(true, type); + } + + IllegalStateException noCodec(boolean binary, Type type) { + String message = String.format("No %s codec handles the type %s", binary ? "binary" : "text", type); + throw new IllegalStateException(message); + } + + IllegalStateException unableToEncode(boolean binary, MessageCodec codec, Exception e) { + String message = String.format("Unable to encode %s message with %s", binary ? "binary" : "text", + codec.getClass().getName()); + throw new IllegalStateException(message, e); + } + + IllegalStateException unableToDecode(boolean binary, MessageCodec codec, Exception e) { + String message = String.format("Unable to decode %s message with %s", binary ? "binary" : "text", + codec.getClass().getName()); + throw new IllegalStateException(message, e); + } + + IllegalStateException forcedCannotHandle(boolean binary, MessageCodec codec, Type type) { + throw new IllegalStateException( + String.format("Forced %s codec [%s] cannot handle the type %s", binary ? "binary" : "text", + codec.getClass().getName(), type)); + } + + @SuppressWarnings("unchecked") + static T cast(Object obj) { + return (T) obj; + } +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/ConcurrencyLimiter.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/ConcurrencyLimiter.java new file mode 100644 index 0000000000000..67ba402aab090 --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/ConcurrencyLimiter.java @@ -0,0 +1,102 @@ +package io.quarkus.websockets.next.runtime; + +import java.util.Queue; +import java.util.concurrent.atomic.AtomicLong; + +import org.jboss.logging.Logger; + +import io.quarkus.websockets.next.WebSocketServerConnection; +import io.smallrye.mutiny.helpers.queues.Queues; +import io.vertx.core.Context; +import io.vertx.core.Handler; +import io.vertx.core.Promise; + +/** + * Used to limit concurrent invocations. + */ +class ConcurrencyLimiter { + + private static final Logger LOG = Logger.getLogger(ConcurrencyLimiter.class); + + private final WebSocketServerConnection connection; + private final Queue queue; + private final AtomicLong uncompleted; + private final AtomicLong queueCounter; + + ConcurrencyLimiter(WebSocketServerConnection connection) { + this.connection = connection; + this.uncompleted = new AtomicLong(); + this.queueCounter = new AtomicLong(); + this.queue = Queues.createMpscQueue(); + } + + /** + * This method must be always used before {@link #run(Runnable)} and the returned callback must be always invoked when an + * async computation completes. + * + * @param promise + * @return a new callback to complete the given promise + */ + PromiseComplete newComplete(Promise promise) { + return new PromiseComplete(promise); + } + + /** + * Run or queue up the given action. + * + * @param action + * @param context + */ + void run(Context context, Runnable action) { + if (uncompleted.compareAndSet(0, 1)) { + LOG.debugf("Run action: %s", connection); + action.run(); + } else { + long queueIndex = queueCounter.incrementAndGet(); + LOG.debugf("Action queued as %s: %s", queueIndex, connection); + queue.offer(new Action(queueIndex, action, context)); + // We need to make sure that at least one completion is in flight + if (uncompleted.getAndIncrement() == 0) { + Action queuedAction = queue.poll(); + assert queuedAction != null; + LOG.debugf("Run action %s from queue: %s", queuedAction.queueIndex, connection); + queuedAction.runnable.run(); + } + } + } + + class PromiseComplete { + + final Promise promise; + + private PromiseComplete(Promise promise) { + this.promise = promise; + } + + void failure(Throwable t) { + complete(); + } + + void complete() { + try { + promise.complete(); + } finally { + if (uncompleted.decrementAndGet() == 0) { + return; + } + Action queuedAction = queue.poll(); + assert queuedAction != null; + LOG.debugf("Run action %s from queue: %s", queuedAction.queueIndex, connection); + queuedAction.context.runOnContext(new Handler() { + @Override + public void handle(Void event) { + queuedAction.runnable.run(); + } + }); + } + } + } + + record Action(long queueIndex, Runnable runnable, Context context) { + } +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/ConnectionManager.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/ConnectionManager.java new file mode 100644 index 0000000000000..71f368e6f5188 --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/ConnectionManager.java @@ -0,0 +1,37 @@ +package io.quarkus.websockets.next.runtime; + +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +import jakarta.annotation.PreDestroy; +import jakarta.inject.Singleton; + +import io.quarkus.websockets.next.WebSocketServerConnection; + +@Singleton +public class ConnectionManager { + + private final ConcurrentMap> endpointToConnections = new ConcurrentHashMap<>(); + + void add(String endpoint, WebSocketServerConnection connection) { + endpointToConnections.computeIfAbsent(endpoint, e -> ConcurrentHashMap.newKeySet()).add(connection); + } + + void remove(String endpoint, WebSocketServerConnection connection) { + Set connections = endpointToConnections.get(endpoint); + if (connections != null) { + connections.remove(connection); + } + } + + Set getConnections(String endpoint) { + return endpointToConnections.get(endpoint); + } + + @PreDestroy + void destroy() { + endpointToConnections.clear(); + } + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/ContextSupport.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/ContextSupport.java new file mode 100644 index 0000000000000..091f939c4de10 --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/ContextSupport.java @@ -0,0 +1,67 @@ +package io.quarkus.websockets.next.runtime; + +import org.jboss.logging.Logger; + +import io.quarkus.arc.ManagedContext; +import io.quarkus.vertx.core.runtime.context.VertxContextSafetyToggle; +import io.quarkus.websockets.next.WebSocketServerConnection; +import io.quarkus.websockets.next.runtime.WebSocketSessionContext.SessionContextState; +import io.smallrye.common.vertx.VertxContext; +import io.vertx.core.Context; + +public class ContextSupport { + + private static final Logger LOG = Logger.getLogger(ContextSupport.class); + + private final WebSocketServerConnection connection; + private final SessionContextState sessionContextState; + private final WebSocketSessionContext sessionContext; + private final ManagedContext requestContext; + + ContextSupport(WebSocketServerConnection connection, SessionContextState sessionContextState, + WebSocketSessionContext sessionContext, + ManagedContext requestContext) { + this.connection = connection; + this.sessionContextState = sessionContextState; + this.sessionContext = sessionContext; + this.requestContext = requestContext; + } + + void start() { + LOG.debugf("Start contexts: %s", connection); + startSession(); + // Activate a new request context + requestContext.activate(); + } + + void startSession() { + // Activate the captured session context + sessionContext.activate(sessionContextState); + } + + void end(boolean terminateSession) { + LOG.debugf("End contexts: %s", connection); + requestContext.terminate(); + if (terminateSession) { + // OnClose - terminate the session context + endSession(); + } else { + sessionContext.deactivate(); + } + } + + void endSession() { + sessionContext.terminate(); + } + + static Context createNewDuplicatedContext(Context context, WebSocketServerConnection connection) { + Context duplicated = VertxContext.createNewDuplicatedContext(context); + VertxContextSafetyToggle.setContextSafe(duplicated, true); + // We need to store the connection in the duplicated context + // It's used to initialize the synthetic bean later on + duplicated.putLocal(WebSocketServerRecorder.WEB_SOCKET_CONN_KEY, connection); + LOG.debugf("New vertx duplicated context [%s] created: %s", duplicated, connection); + return duplicated; + } + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/JsonTextMessageCodec.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/JsonTextMessageCodec.java new file mode 100644 index 0000000000000..52ae78ffca171 --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/JsonTextMessageCodec.java @@ -0,0 +1,53 @@ +package io.quarkus.websockets.next.runtime; + +import java.lang.reflect.Type; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +import jakarta.annotation.Priority; +import jakarta.inject.Inject; +import jakarta.inject.Singleton; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.quarkus.websockets.next.TextMessageCodec; + +@Singleton +@Priority(0) +public class JsonTextMessageCodec implements TextMessageCodec { + + private final ConcurrentMap types = new ConcurrentHashMap<>(); + + @Inject + ObjectMapper mapper; + + @Override + public String encode(Object value) { + try { + return mapper.writeValueAsString(value); + } catch (JsonProcessingException e) { + throw new IllegalStateException(e); + } + } + + @Override + public Object decode(Type type, String value) { + try { + return mapper.readValue(value, types.computeIfAbsent(type, this::computeJavaType)); + } catch (JsonProcessingException e) { + throw new IllegalStateException(e); + } + } + + private JavaType computeJavaType(Type type) { + return mapper.getTypeFactory().constructType(type); + } + + @Override + public boolean supports(Type type) { + return true; + } + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpoint.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpoint.java new file mode 100644 index 0000000000000..6c767595b0bb2 --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpoint.java @@ -0,0 +1,65 @@ +package io.quarkus.websockets.next.runtime; + +import java.lang.reflect.Type; + +import io.quarkus.websockets.next.OnClose; +import io.quarkus.websockets.next.WebSocket; +import io.vertx.core.Future; + +/** + * Internal representation of a WebSocket endpoint. + *

      + * A new instance is created for each client connection. {@link #onOpen()}, {@link #onMessage(Object)} and {@link OnClose} are + * always executed on a new vertx duplicated context. + */ +public interface WebSocketEndpoint { + + WebSocket.ExecutionMode executionMode(); + + Future onOpen(); + + default ExecutionModel onOpenExecutionModel() { + return ExecutionModel.NONE; + } + + Future onMessage(Object message); + + default ExecutionModel onMessageExecutionModel() { + return ExecutionModel.NONE; + } + + Future onClose(); + + default ExecutionModel onCloseExecutionModel() { + return ExecutionModel.NONE; + } + + default MessageType consumedMessageType() { + return MessageType.NONE; + } + + default Type consumedMultiType() { + return null; + } + + default Object decodeMultiItem(Object message) { + throw new UnsupportedOperationException(); + } + + enum ExecutionModel { + WORKER_THREAD, + VIRTUAL_THREAD, + EVENT_LOOP, + NONE; + + boolean isBlocking() { + return this == WORKER_THREAD || this == VIRTUAL_THREAD; + } + } + + enum MessageType { + NONE, + TEXT, + BINARY + } +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java new file mode 100644 index 0000000000000..1f65e10b36203 --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java @@ -0,0 +1,231 @@ +package io.quarkus.websockets.next.runtime; + +import java.lang.reflect.Type; +import java.util.concurrent.Callable; +import java.util.function.Consumer; +import java.util.function.Function; + +import org.jboss.logging.Logger; + +import io.quarkus.arc.Arc; +import io.quarkus.arc.ArcContainer; +import io.quarkus.virtual.threads.VirtualThreadsRecorder; +import io.quarkus.websockets.next.WebSocket.ExecutionMode; +import io.quarkus.websockets.next.WebSocketServerConnection; +import io.quarkus.websockets.next.WebSocketsRuntimeConfig; +import io.quarkus.websockets.next.runtime.ConcurrencyLimiter.PromiseComplete; +import io.smallrye.mutiny.Multi; +import io.smallrye.mutiny.Uni; +import io.vertx.core.Context; +import io.vertx.core.Future; +import io.vertx.core.Handler; +import io.vertx.core.Promise; +import io.vertx.core.Vertx; +import io.vertx.core.buffer.Buffer; + +public abstract class WebSocketEndpointBase implements WebSocketEndpoint { + + private static final Logger LOG = Logger.getLogger(WebSocketEndpointBase.class); + + protected final WebSocketServerConnection connection; + + protected final Codecs codecs; + + private final ConcurrencyLimiter limiter; + + @SuppressWarnings("unused") + private final WebSocketsRuntimeConfig config; + + private final ArcContainer container; + + private final ContextSupport contextSupport; + + public WebSocketEndpointBase(WebSocketServerConnection connection, Codecs codecs, + WebSocketsRuntimeConfig config, ContextSupport contextSupport) { + this.connection = connection; + this.codecs = codecs; + this.limiter = executionMode() == ExecutionMode.SERIAL ? new ConcurrencyLimiter(connection) : null; + this.config = config; + this.container = Arc.container(); + this.contextSupport = contextSupport; + } + + @Override + public Future onOpen() { + return execute(null, onOpenExecutionModel(), this::doOnOpen, false); + } + + @Override + public Future onMessage(Object message) { + return execute(message, onMessageExecutionModel(), this::doOnMessage, false); + } + + @Override + public Future onClose() { + return execute(null, onCloseExecutionModel(), this::doOnClose, true); + } + + private Future execute(Object message, ExecutionModel executionModel, + Function> action, boolean terminateSession) { + if (executionModel == ExecutionModel.NONE) { + if (terminateSession) { + // Just start and terminate the session context + contextSupport.startSession(); + contextSupport.endSession(); + } + return Future.succeededFuture(); + } + Promise promise = Promise.promise(); + Context context = Vertx.currentContext(); + if (limiter != null) { + PromiseComplete complete = limiter.newComplete(promise); + limiter.run(context, new Runnable() { + @Override + public void run() { + doExecute(context, promise, message, executionModel, action, terminateSession, complete::complete, + complete::failure); + } + }); + } else { + // No need to limit the concurrency + doExecute(context, promise, message, executionModel, action, terminateSession, promise::complete, promise::fail); + } + return promise.future(); + } + + private Future doExecute(Context context, Promise promise, Object message, ExecutionModel executionModel, + Function> action, boolean terminateSession, Runnable onComplete, + Consumer onFailure) { + Handler contextSupportEnd = executionModel.isBlocking() ? new Handler() { + + @Override + public void handle(Void event) { + contextSupport.end(terminateSession); + } + } : null; + + if (executionModel == ExecutionModel.VIRTUAL_THREAD) { + VirtualThreadsRecorder.getCurrent().execute(new Runnable() { + @Override + public void run() { + Context context = Vertx.currentContext(); + contextSupport.start(); + action.apply(message).subscribe().with( + v -> { + context.runOnContext(contextSupportEnd); + onComplete.run(); + }, + t -> { + context.runOnContext(contextSupportEnd); + onFailure.accept(t); + }); + } + }); + } else if (executionModel == ExecutionModel.WORKER_THREAD) { + context.executeBlocking(new Callable() { + @Override + public Void call() { + Context context = Vertx.currentContext(); + contextSupport.start(); + action.apply(message).subscribe().with( + v -> { + context.runOnContext(contextSupportEnd); + onComplete.run(); + }, + t -> { + context.runOnContext(contextSupportEnd); + onFailure.accept(t); + }); + return null; + } + }, false); + } else { + // Event loop + contextSupport.start(); + action.apply(message).subscribe().with( + v -> { + contextSupport.end(terminateSession); + onComplete.run(); + }, + t -> { + contextSupport.end(terminateSession); + onFailure.accept(t); + }); + } + return null; + } + + // TODO This implementation of timeout does not help a lot + // Should we emit on the current context? + // io.smallrye.mutiny.vertx.core.ContextAwareScheduler + // private Uni withTimeout(Uni action) { + // if (config.timeout().isEmpty()) { + // return action; + // } + // return action.ifNoItem().after(config.timeout().get()).fail(); + // } + + protected Object beanInstance(String identifier) { + return container.instance(container.bean(identifier)).get(); + } + + protected Uni doOnOpen(Object message) { + return Uni.createFrom().voidItem(); + } + + protected Uni doOnMessage(Object message) { + return Uni.createFrom().voidItem(); + } + + protected Uni doOnClose(Object message) { + return Uni.createFrom().voidItem(); + } + + protected Object decodeText(Type type, String value, Class codecBeanClass) { + return codecs.textDecode(type, value, codecBeanClass); + } + + protected String encodeText(Object value, Class codecBeanClass) { + if (value == null) { + return null; + } + return codecs.textEncode(value, codecBeanClass); + } + + protected Object decodeBinary(Type type, Buffer value, Class codecBeanClass) { + return codecs.binaryDecode(type, value, codecBeanClass); + } + + protected Buffer encodeBinary(Object value, Class codecBeanClass) { + if (value == null) { + return null; + } + return codecs.binaryEncode(value, codecBeanClass); + } + + protected Uni sendText(String message, boolean broadcast) { + return broadcast ? connection.broadcast().sendText(message) : connection.sendText(message); + } + + protected Uni multiText(Multi multi, boolean broadcast, Function> itemFun) { + multi.onFailure().call(connection::close).subscribe().with( + m -> { + itemFun.apply(m).subscribe().with(v -> LOG.debugf("Multi >> text message: %s", connection), + t -> LOG.errorf(t, "Unable to send text message from Multi: %s", connection)); + }); + return Uni.createFrom().voidItem(); + } + + protected Uni sendBinary(Buffer message, boolean broadcast) { + return broadcast ? connection.broadcast().sendBinary(message) : connection.sendBinary(message); + } + + protected Uni multiBinary(Multi multi, boolean broadcast, Function> itemFun) { + multi.onFailure().call(connection::close).subscribe().with( + m -> { + itemFun.apply(m).subscribe().with(v -> LOG.debugf("Multi >> binary message: %s", connection), + t -> LOG.errorf(t, "Unable to send binary message from Multi: %s", connection)); + }); + return Uni.createFrom().voidItem(); + } +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerConnectionImpl.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerConnectionImpl.java new file mode 100644 index 0000000000000..b09fa85fcabee --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerConnectionImpl.java @@ -0,0 +1,257 @@ +package io.quarkus.websockets.next.runtime; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Objects; +import java.util.Set; +import java.util.UUID; +import java.util.function.Function; +import java.util.function.Predicate; +import java.util.stream.Collectors; + +import io.quarkus.websockets.next.WebSocketServerConnection; +import io.smallrye.mutiny.Uni; +import io.smallrye.mutiny.vertx.UniHelper; +import io.vertx.core.buffer.Buffer; +import io.vertx.core.http.ServerWebSocket; +import io.vertx.ext.web.RoutingContext; + +class WebSocketServerConnectionImpl implements WebSocketServerConnection { + + private final String endpoint; + + private final String identifier; + + private final ServerWebSocket webSocket; + + private final ConnectionManager connectionManager; + + private final Codecs codecs; + + private final Map pathParams; + + private final HandshakeRequest handshakeRequest; + + private final BroadcastSender defaultBroadcast; + + WebSocketServerConnectionImpl(String endpoint, ServerWebSocket webSocket, ConnectionManager connectionManager, + Codecs codecs, RoutingContext ctx) { + this.endpoint = endpoint; + this.identifier = UUID.randomUUID().toString(); + this.webSocket = Objects.requireNonNull(webSocket); + this.connectionManager = Objects.requireNonNull(connectionManager); + this.pathParams = Map.copyOf(ctx.pathParams()); + this.defaultBroadcast = new BroadcastImpl(null); + this.codecs = codecs; + this.handshakeRequest = new HandshakeRequestImpl(ctx); + } + + @Override + public String id() { + return identifier; + } + + @Override + public String pathParam(String name) { + return pathParams.get(name); + } + + @Override + public Uni sendText(String message) { + return UniHelper.toUni(webSocket.writeTextMessage(message)); + } + + @Override + public Uni sendBinary(Buffer message) { + return UniHelper.toUni(webSocket.writeBinaryMessage(message)); + } + + @Override + public Uni sendText(M message) { + return UniHelper.toUni(webSocket.writeTextMessage(codecs.textEncode(message, null).toString())); + } + + @Override + public BroadcastSender broadcast() { + return defaultBroadcast; + } + + @Override + public BroadcastSender broadcast(Predicate filter) { + return new BroadcastImpl(Objects.requireNonNull(filter)); + } + + @Override + public Uni close() { + return UniHelper.toUni(webSocket.close()); + } + + @Override + public boolean isSecure() { + return webSocket.isSsl(); + } + + @Override + public boolean isClosed() { + return webSocket.isClosed(); + } + + @Override + public Set getOpenConnections() { + return connectionManager.getConnections(endpoint).stream().filter(WebSocketServerConnection::isOpen) + .collect(Collectors.toUnmodifiableSet()); + } + + @Override + public HandshakeRequest handshakeRequest() { + return handshakeRequest; + } + + @Override + public String toString() { + return "WebSocket connection [id=" + identifier + ", path=" + webSocket.path() + "]"; + } + + @Override + public int hashCode() { + return Objects.hash(identifier); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + WebSocketServerConnectionImpl other = (WebSocketServerConnectionImpl) obj; + return Objects.equals(identifier, other.identifier); + } + + private class HandshakeRequestImpl implements HandshakeRequest { + + private final Map> headers; + + HandshakeRequestImpl(RoutingContext ctx) { + this.headers = initHeaders(ctx); + } + + @Override + public String header(String name) { + List values = headers(name); + return values.isEmpty() ? null : values.get(0); + } + + @Override + public List headers(String name) { + return headers.getOrDefault(Objects.requireNonNull(name).toLowerCase(), List.of()); + } + + @Override + public Map> headers() { + return headers; + } + + @Override + public String scheme() { + return webSocket.scheme(); + } + + @Override + public String host() { + return webSocket.authority().host(); + } + + @Override + public int port() { + return webSocket.authority().port(); + } + + @Override + public String path() { + return webSocket.path(); + } + + @Override + public String query() { + return webSocket.query(); + } + + static Map> initHeaders(RoutingContext ctx) { + Map> headers = new HashMap<>(); + for (Entry e : ctx.request().headers()) { + String key = e.getKey().toLowerCase(); + List values = headers.get(key); + if (values == null) { + values = new ArrayList<>(); + headers.put(key, values); + } + values.add(e.getValue()); + } + for (Entry> e : headers.entrySet()) { + // Make the list of values immutable + e.setValue(List.copyOf(e.getValue())); + } + return Map.copyOf(headers); + } + + } + + private class BroadcastImpl implements WebSocketServerConnection.BroadcastSender { + + private final Predicate filter; + + BroadcastImpl(Predicate filter) { + this.filter = filter; + } + + @Override + public Uni sendText(String message) { + return doSend(new Function>() { + + @Override + public Uni apply(WebSocketServerConnection c) { + return c.sendText(message); + } + }); + } + + @Override + public Uni sendText(M message) { + return doSend(new Function>() { + + @Override + public Uni apply(WebSocketServerConnection c) { + return c.sendText(message); + } + }); + } + + @Override + public Uni sendBinary(Buffer message) { + return doSend(new Function>() { + + @Override + public Uni apply(WebSocketServerConnection c) { + return c.sendBinary(message); + } + }); + } + + private Uni doSend(Function> function) { + List> unis = new ArrayList<>(); + for (WebSocketServerConnection connection : connectionManager.getConnections(endpoint)) { + if (connection.isOpen() && (filter == null || filter.test(connection))) { + unis.add(function.apply(connection)); + } + } + return unis.isEmpty() ? Uni.createFrom().voidItem() : Uni.join().all(unis).andFailFast().replaceWithVoid(); + } + + } + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java new file mode 100644 index 0000000000000..7367c55a97395 --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java @@ -0,0 +1,255 @@ +package io.quarkus.websockets.next.runtime; + +import java.util.function.Consumer; +import java.util.function.Supplier; + +import jakarta.enterprise.context.SessionScoped; + +import org.jboss.logging.Logger; + +import io.quarkus.arc.Arc; +import io.quarkus.arc.ArcContainer; +import io.quarkus.arc.InjectableContext; +import io.quarkus.runtime.annotations.Recorder; +import io.quarkus.vertx.core.runtime.VertxCoreRecorder; +import io.quarkus.websockets.next.WebSocketServerConnection; +import io.quarkus.websockets.next.WebSocketServerException; +import io.quarkus.websockets.next.WebSocketsRuntimeConfig; +import io.quarkus.websockets.next.runtime.WebSocketEndpoint.MessageType; +import io.quarkus.websockets.next.runtime.WebSocketSessionContext.SessionContextState; +import io.smallrye.common.vertx.VertxContext; +import io.smallrye.mutiny.Multi; +import io.smallrye.mutiny.operators.multi.processors.BroadcastProcessor; +import io.vertx.core.Context; +import io.vertx.core.Future; +import io.vertx.core.Handler; +import io.vertx.core.Vertx; +import io.vertx.core.buffer.Buffer; +import io.vertx.core.http.ServerWebSocket; +import io.vertx.ext.web.RoutingContext; + +@Recorder +public class WebSocketServerRecorder { + + private static final Logger LOG = Logger.getLogger(WebSocketServerRecorder.class); + + static final String WEB_SOCKET_CONN_KEY = WebSocketServerConnection.class.getName(); + + private final WebSocketsRuntimeConfig config; + + public WebSocketServerRecorder(WebSocketsRuntimeConfig config) { + this.config = config; + } + + public Supplier connectionSupplier() { + return new Supplier() { + + @Override + public Object get() { + Context context = Vertx.currentContext(); + if (context != null && VertxContext.isDuplicatedContext(context)) { + Object connection = context.getLocal(WEB_SOCKET_CONN_KEY); + if (connection != null) { + return connection; + } + } + throw new WebSocketServerException("Unable to obtain the connection from the Vert.x duplicated context"); + } + }; + } + + public Handler createEndpointHandler(String endpointClass) { + ArcContainer container = Arc.container(); + ConnectionManager connectionManager = container.instance(ConnectionManager.class).get(); + Codecs codecs = container.instance(Codecs.class).get(); + return new Handler() { + + @Override + public void handle(RoutingContext ctx) { + Future future = ctx.request().toWebSocket(); + future.onSuccess(ws -> { + Context context = VertxCoreRecorder.getVertx().get().getOrCreateContext(); + + WebSocketServerConnection connection = new WebSocketServerConnectionImpl(endpointClass, ws, + connectionManager, codecs, ctx); + connectionManager.add(endpointClass, connection); + LOG.debugf("Connnected: %s", connection); + + // Initialize and capture the session context state that will be activated + // during message processing + WebSocketSessionContext sessionContext = sessionContext(container); + SessionContextState sessionContextState = sessionContext.initializeContextState(); + ContextSupport contextSupport = new ContextSupport(connection, sessionContextState, + sessionContext(container), + container.requestContext()); + + // Create an endpoint that delegates callbacks to the @WebSocket bean + WebSocketEndpoint endpoint = createEndpoint(endpointClass, context, connection, codecs, config, + contextSupport); + + // The processor is only needed if Multi is consumed by the @OnMessage callback + BroadcastProcessor broadcastProcessor = endpoint.consumedMultiType() != null + ? BroadcastProcessor.create() + : null; + + // NOTE: we always invoke callbacks (onOpen, onMessage, onClose) on a new duplicated context + // and the endpoint is responsible to make the switch if blocking/virtualThread + + Context onOpenContext = ContextSupport.createNewDuplicatedContext(context, connection); + onOpenContext.runOnContext(new Handler() { + @Override + public void handle(Void event) { + endpoint.onOpen().onComplete(r -> { + if (r.succeeded()) { + LOG.debugf("@OnOpen callback completed: %s", connection); + if (broadcastProcessor != null) { + // If Multi is consumed we need to invoke @OnMessage callback eagerly + // but after @OnOpen completes + Multi multi = broadcastProcessor.onCancellation().call(connection::close); + onOpenContext.runOnContext(new Handler() { + @Override + public void handle(Void event) { + endpoint.onMessage(multi).onComplete(r -> { + if (r.succeeded()) { + LOG.debugf("@OnMessage callback consuming Multi completed: %s", + connection); + } else { + LOG.errorf(r.cause(), + "Unable to complete @OnMessage callback consuming Multi: %s", + connection); + } + }); + } + }); + } + } else { + LOG.errorf(r.cause(), "Unable to complete @OnOpen callback: %s", connection); + } + }); + } + }); + + if (broadcastProcessor == null) { + // Multi not consumed - invoke @OnMessage callback for each message received + messageHandlers(connection, endpoint, ws, context, m -> { + endpoint.onMessage(m).onComplete(r -> { + if (r.succeeded()) { + LOG.debugf("@OnMessage callback consumed binary message: %s", connection); + } else { + LOG.errorf(r.cause(), "Unable to consume binary message in @OnMessage callback: %s", + connection); + } + }); + }, m -> { + endpoint.onMessage(m).onComplete(r -> { + if (r.succeeded()) { + LOG.debugf("@OnMessage callback consumed text message: %s", connection); + } else { + LOG.errorf(r.cause(), "Unable to consume text message in @OnMessage callback: %s", + connection); + } + }); + }, true); + } else { + // Multi consumed - forward message to subcribers + messageHandlers(connection, endpoint, ws, onOpenContext, m -> { + contextSupport.start(); + broadcastProcessor.onNext(endpoint.decodeMultiItem(m)); + LOG.debugf("Binary message >> Multi: %s", connection); + contextSupport.end(false); + }, m -> { + contextSupport.start(); + broadcastProcessor.onNext(endpoint.decodeMultiItem(m)); + LOG.debugf("Text message >> Multi: %s", connection); + contextSupport.end(false); + }, false); + } + + ws.closeHandler(new Handler() { + @Override + public void handle(Void event) { + ContextSupport.createNewDuplicatedContext(context, connection).runOnContext(new Handler() { + @Override + public void handle(Void event) { + endpoint.onClose().onComplete(r -> { + if (r.succeeded()) { + LOG.debugf("@OnClose callback completed: %s", connection); + } else { + LOG.errorf(r.cause(), "Unable to complete @OnClose callback: %s", connection); + } + connectionManager.remove(endpointClass, connection); + }); + } + }); + } + }); + }); + } + }; + } + + private void messageHandlers(WebSocketServerConnection connection, WebSocketEndpoint endpoint, ServerWebSocket ws, + Context context, Consumer binaryAction, Consumer textAction, boolean newDuplicatedContext) { + if (endpoint.consumedMessageType() == MessageType.BINARY) { + ws.binaryMessageHandler(new Handler() { + @Override + public void handle(Buffer message) { + Context duplicatedContext = newDuplicatedContext + ? ContextSupport.createNewDuplicatedContext(context, connection) + : context; + duplicatedContext.runOnContext(new Handler() { + @Override + public void handle(Void event) { + binaryAction.accept(message); + } + }); + } + }); + } else if (endpoint.consumedMessageType() == MessageType.TEXT) { + ws.textMessageHandler(new Handler() { + @Override + public void handle(String message) { + Context duplicatedContext = newDuplicatedContext + ? ContextSupport.createNewDuplicatedContext(context, connection) + : context; + duplicatedContext.runOnContext(new Handler() { + @Override + public void handle(Void event) { + textAction.accept(message); + } + }); + } + }); + } + } + + private WebSocketEndpoint createEndpoint(String endpointClassName, Context context, WebSocketServerConnection connection, + Codecs codecs, WebSocketsRuntimeConfig config, ContextSupport contextSupport) { + try { + ClassLoader cl = Thread.currentThread().getContextClassLoader(); + if (cl == null) { + cl = WebSocketServerRecorder.class.getClassLoader(); + } + @SuppressWarnings("unchecked") + Class endpointClazz = (Class) cl + .loadClass(endpointClassName); + WebSocketEndpoint endpoint = (WebSocketEndpoint) endpointClazz + .getDeclaredConstructor(WebSocketServerConnection.class, Codecs.class, + WebSocketsRuntimeConfig.class, ContextSupport.class) + .newInstance(connection, codecs, config, contextSupport); + return endpoint; + } catch (Exception e) { + throw new WebSocketServerException("Unable to create endpoint instance: " + endpointClassName, e); + } + } + + private static WebSocketSessionContext sessionContext(ArcContainer container) { + for (InjectableContext injectableContext : container.getContexts(SessionScoped.class)) { + if (WebSocketSessionContext.class.equals(injectableContext.getClass())) { + return (WebSocketSessionContext) injectableContext; + } + } + throw new WebSocketServerException("CDI session context not registered"); + } + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketSessionContext.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketSessionContext.java new file mode 100644 index 0000000000000..b34796e516653 --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketSessionContext.java @@ -0,0 +1,266 @@ +package io.quarkus.websockets.next.runtime; + +import java.lang.annotation.Annotation; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.VarHandle; +import java.util.Map; +import java.util.Objects; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +import jakarta.enterprise.context.BeforeDestroyed; +import jakarta.enterprise.context.ContextNotActiveException; +import jakarta.enterprise.context.Destroyed; +import jakarta.enterprise.context.Initialized; +import jakarta.enterprise.context.SessionScoped; +import jakarta.enterprise.context.spi.Contextual; +import jakarta.enterprise.context.spi.CreationalContext; +import jakarta.enterprise.event.Event; +import jakarta.enterprise.inject.Any; + +import org.jboss.logging.Logger; + +import io.quarkus.arc.Arc; +import io.quarkus.arc.ArcContainer; +import io.quarkus.arc.ContextInstanceHandle; +import io.quarkus.arc.CurrentContext; +import io.quarkus.arc.InjectableBean; +import io.quarkus.arc.ManagedContext; +import io.quarkus.arc.impl.ComputingCacheContextInstances; +import io.quarkus.arc.impl.ContextInstanceHandleImpl; +import io.quarkus.arc.impl.ContextInstances; +import io.quarkus.arc.impl.LazyValue; + +public class WebSocketSessionContext implements ManagedContext { + + private static final Logger LOG = Logger.getLogger(WebSocketSessionContext.class); + + private final LazyValue> currentContext; + private final LazyValue> initializedEvent; + private final LazyValue> beforeDestroyEvent; + private final LazyValue> destroyEvent; + + public WebSocketSessionContext() { + // Use lazy value because no-args constructor is needed + this.currentContext = new LazyValue<>(new Supplier>() { + @Override + public CurrentContext get() { + return Arc.container().getCurrentContextFactory().create(SessionScoped.class); + } + }); + this.initializedEvent = newEvent(Initialized.Literal.SESSION, Any.Literal.INSTANCE); + this.beforeDestroyEvent = newEvent(BeforeDestroyed.Literal.SESSION, Any.Literal.INSTANCE); + this.destroyEvent = newEvent(Destroyed.Literal.SESSION, Any.Literal.INSTANCE); + } + + @Override + public Class getScope() { + return SessionScoped.class; + } + + @Override + public ContextState getState() { + SessionContextState state = currentState(); + if (state == null) { + // Thread local not set - context is not active! + throw notActive(); + } + return state; + } + + @Override + public ContextState activate(ContextState initialState) { + if (initialState == null) { + SessionContextState state = initializeContextState(); + currentContext().set(state); + return state; + } else { + if (initialState instanceof SessionContextState) { + currentContext().set((SessionContextState) initialState); + return initialState; + } else { + throw new IllegalArgumentException("Invalid initial state: " + initialState.getClass().getName()); + } + } + } + + @Override + public void deactivate() { + currentContext().remove(); + } + + @SuppressWarnings("unchecked") + @Override + public T get(Contextual contextual, CreationalContext creationalContext) { + Objects.requireNonNull(contextual, "Contextual must not be null"); + Objects.requireNonNull(creationalContext, "CreationalContext must not be null"); + InjectableBean bean = (InjectableBean) contextual; + if (!SessionScoped.class.getName().equals(bean.getScope().getName())) { + throw invalidScope(); + } + SessionContextState state = currentState(); + if (state == null || !state.isValid()) { + throw notActive(); + } + return (T) state.contextInstances.computeIfAbsent(bean.getIdentifier(), new Supplier>() { + + @Override + public ContextInstanceHandle get() { + return new ContextInstanceHandleImpl<>(bean, contextual.create(creationalContext), creationalContext); + } + }).get(); + } + + @Override + public T get(Contextual contextual) { + Objects.requireNonNull(contextual, "Contextual must not be null"); + InjectableBean bean = (InjectableBean) contextual; + if (!SessionScoped.class.getName().equals(bean.getScope().getName())) { + throw invalidScope(); + } + SessionContextState state = currentState(); + if (state == null || !state.isValid()) { + throw notActive(); + } + @SuppressWarnings("unchecked") + ContextInstanceHandle instance = (ContextInstanceHandle) state.contextInstances + .getIfPresent(bean.getIdentifier()); + return instance == null ? null : instance.get(); + } + + @Override + public boolean isActive() { + SessionContextState contextState = currentState(); + return contextState == null ? false : contextState.isValid(); + } + + @Override + public void destroy() { + destroy(currentState()); + } + + @Override + public void destroy(Contextual contextual) { + SessionContextState state = currentState(); + if (state == null || !state.isValid()) { + throw notActive(); + } + InjectableBean bean = (InjectableBean) contextual; + ContextInstanceHandle instance = state.contextInstances.remove(bean.getIdentifier()); + if (instance != null) { + instance.destroy(); + } + } + + @Override + public void destroy(ContextState state) { + if (state == null) { + // nothing to destroy + return; + } + if (state instanceof SessionContextState) { + SessionContextState sessionState = ((SessionContextState) state); + if (sessionState.invalidate()) { + fireIfNotNull(beforeDestroyEvent.get()); + sessionState.contextInstances.removeEach(ContextInstanceHandle::destroy); + fireIfNotNull(destroyEvent.get()); + } + } else { + throw new IllegalArgumentException("Invalid state implementation: " + state.getClass().getName()); + } + } + + SessionContextState initializeContextState() { + SessionContextState state = new SessionContextState(new ComputingCacheContextInstances()); + fireIfNotNull(initializedEvent.get()); + return state; + } + + private CurrentContext currentContext() { + return currentContext.get(); + } + + private SessionContextState currentState() { + return currentContext().get(); + } + + private IllegalArgumentException invalidScope() { + throw new IllegalArgumentException("The bean does not declare @SessionScoped"); + } + + private ContextNotActiveException notActive() { + return new ContextNotActiveException("Session context is not active"); + } + + private void fireIfNotNull(Event event) { + if (event != null) { + try { + event.fire(toString()); + } catch (Exception e) { + LOG.warn("An error occurred during delivery of the context lifecycle event for " + toString(), e); + } + } + } + + private static LazyValue> newEvent(Annotation... qualifiers) { + return new LazyValue<>(new Supplier>() { + @Override + public Event get() { + ArcContainer container = Arc.container(); + if (container.resolveObserverMethods(Object.class, qualifiers).isEmpty()) { + return null; + } + return container.beanManager().getEvent().select(qualifiers); + } + }); + } + + class SessionContextState implements ContextState { + + // Using 0 as default value enable removing an initialization + // in the constructor, piggybacking on the default value. + // As per https://docs.oracle.com/javase/specs/jls/se8/html/jls-12.html#jls-12.5 + // the default field values are set before 'this' is accessible, hence + // they should be the very first value observable even in presence of + // unsafe publication of this object. + private static final int VALID = 0; + private static final int INVALID = 1; + private static final VarHandle IS_VALID; + + static { + try { + IS_VALID = MethodHandles.lookup().findVarHandle(SessionContextState.class, "isValid", int.class); + } catch (ReflectiveOperationException e) { + throw new Error(e); + } + } + + private final ContextInstances contextInstances; + private volatile int isValid; + + SessionContextState(ContextInstances contextInstances) { + this.contextInstances = contextInstances; + } + + @Override + public Map, Object> getContextualInstances() { + return contextInstances.getAllPresent().stream() + .collect(Collectors.toUnmodifiableMap(ContextInstanceHandle::getBean, ContextInstanceHandle::get)); + } + + /** + * @return {@code true} if the state was successfully invalidated, {@code false} otherwise + */ + boolean invalidate() { + // Atomically sets the value just like AtomicBoolean.compareAndSet(boolean, boolean) + return IS_VALID.compareAndSet(this, VALID, INVALID); + } + + @Override + public boolean isValid() { + return isValid == VALID; + } + + } + +} diff --git a/extensions/websockets-next/server/runtime/src/main/resources/META-INF/quarkus-extension.yaml b/extensions/websockets-next/server/runtime/src/main/resources/META-INF/quarkus-extension.yaml new file mode 100644 index 0000000000000..3efa575b24631 --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/resources/META-INF/quarkus-extension.yaml @@ -0,0 +1,15 @@ +name: "WebSockets Next" +artifact: ${project.groupId}:${project.artifactId}:${project.version} +metadata: + short-name: "websockets" + keywords: + - "websocket" + - "websockets" + - "web-socket" + - "web-sockets" + - "http" + categories: + - "web" + status: "experimental" + config: + - "quarkus.websockets.next" \ No newline at end of file diff --git a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/ArcContainer.java b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/ArcContainer.java index 01af040bd7399..7881a05414726 100644 --- a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/ArcContainer.java +++ b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/ArcContainer.java @@ -230,4 +230,12 @@ public interface ArcContainer { * @return true is strict mode is enabled, false otherwise. */ boolean strictCompatibility(); + + /** + * + * @param eventType + * @param eventQualifiers + * @return an ordered list of observer methods + */ + List> resolveObserverMethods(Type eventType, Annotation... eventQualifiers); } diff --git a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/ArcContainerImpl.java b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/ArcContainerImpl.java index 31972ac95db20..94d6c17e59d12 100644 --- a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/ArcContainerImpl.java +++ b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/ArcContainerImpl.java @@ -494,6 +494,11 @@ public List> getObservers() { return observers; } + @Override + public List> resolveObserverMethods(Type eventType, Annotation... eventQualifiers) { + return resolveObserverMethods(eventType, Set.of(eventQualifiers)); + } + InstanceHandle getResource(Type type, Set annotations) { for (ResourceReferenceProvider resourceProvider : resourceProviders) { InstanceHandle ret = resourceProvider.get(type, annotations); @@ -871,7 +876,8 @@ private static int compareDefaultBeans(InjectableBean bean1, InjectableBean List> resolveObservers(Type eventType, Set eventQualifiers) { + List> resolveObserverMethods(Type eventType, + Set eventQualifiers) { registeredQualifiers.verify(eventQualifiers); if (observers.isEmpty()) { return Collections.emptyList(); diff --git a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/BeanManagerImpl.java b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/BeanManagerImpl.java index 4e6aeb0bfa9ed..72d4853101f83 100644 --- a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/BeanManagerImpl.java +++ b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/BeanManagerImpl.java @@ -135,7 +135,7 @@ public Set> resolveObserverMethods(T event, Annota throw new IllegalArgumentException("The runtime type of the event object contains a type variable: " + eventType); } Set eventQualifiers = new HashSet<>(Arrays.asList(qualifiers)); - return new LinkedHashSet<>(ArcContainerImpl.instance().resolveObservers(eventType, eventQualifiers)); + return new LinkedHashSet<>(ArcContainerImpl.instance().resolveObserverMethods(eventType, eventQualifiers)); } @Override diff --git a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/ComputingCacheContextInstances.java b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/ComputingCacheContextInstances.java index a875191831f7d..89f02ef769030 100644 --- a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/ComputingCacheContextInstances.java +++ b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/ComputingCacheContextInstances.java @@ -6,11 +6,11 @@ import io.quarkus.arc.ContextInstanceHandle; -class ComputingCacheContextInstances implements ContextInstances { +public class ComputingCacheContextInstances implements ContextInstances { protected final ComputingCache> instances; - ComputingCacheContextInstances() { + public ComputingCacheContextInstances() { instances = new ComputingCache<>(); } diff --git a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/EventImpl.java b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/EventImpl.java index cd4a70d3f406b..fb7fa562013d0 100644 --- a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/EventImpl.java +++ b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/EventImpl.java @@ -179,7 +179,7 @@ static Notifier createNotifier(Class runtimeType, Type eventType, Set< normalizedQualifiers.add(Any.Literal.INSTANCE); EventMetadata metadata = new EventMetadataImpl(normalizedQualifiers, eventType, injectionPoint); List> notifierObserverMethods = new ArrayList<>( - container.resolveObservers(eventType, normalizedQualifiers)); + container.resolveObserverMethods(eventType, normalizedQualifiers)); return new Notifier<>(runtimeType, notifierObserverMethods, metadata, activateRequestContext); }