diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/ArgumentProvider.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/ArgumentProvider.java new file mode 100644 index 00000000000000..e9e00b470d6523 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/ArgumentProvider.java @@ -0,0 +1,100 @@ +package io.quarkus.websockets.next.deployment; + +import java.util.Set; + +import org.jboss.jandex.AnnotationInstance; +import org.jboss.jandex.MethodParameterInfo; +import org.jboss.jandex.Type; + +import io.quarkus.gizmo.BytecodeCreator; +import io.quarkus.gizmo.ResultHandle; +import io.quarkus.websockets.next.OnClose; +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.WebSocketConnection; +import io.quarkus.websockets.next.deployment.WebSocketEndpointBuildItem.Callback; + +/** + * Provides arguments for method parameters of a callback method declared on a WebSocket endpoint. + */ +interface ArgumentProvider { + + /** + * + * @param parameterContext + * @return {@code true} if this provider matches the given parameter context, {@code false} otherwise + */ + boolean matches(ParameterContext parameterContext); + + /** + * + * @param parameterContext + * @param callbackContext + * @return the result handle to be passed as an argument to a callback method + */ + ResultHandle get(ParameterContext parameterContext, CallbackContext callbackContext); + + static final int DEFAULT_PRIORITY = 0; + + /** + * + * @return the priority + */ + default int priotity() { + return DEFAULT_PRIORITY; + } + + interface ParameterContext { + + /** + * + * @return the callback + */ + Callback callback(); + + /** + * + * @return the Java method parameter + */ + MethodParameterInfo parameter(); + + /** + * + * @return the set of annotations, potentially transformed + */ + Set paramAnnotations(); + + } + + interface CallbackContext { + + /** + * + * @return the bytecode + */ + BytecodeCreator bytecode(); + + /** + * Obtains the message directly in the bytecode. + * + * @return the message object or {@code null} for {@link OnOpen} and {@link OnClose} callbacks + */ + ResultHandle message(); + + /** + * Attempts to decode the message directly in the bytecode. + * + * @param parameterType + * @return the decoded message object or {@code null} for {@link OnOpen} and {@link OnClose} callbacks + */ + ResultHandle decodeMessage(Type parameterType); + + /** + * Obtains the current connection directly in the bytecode. + * + * @return the current {@link WebSocketConnection}, never {@code null} + */ + ResultHandle connection(); + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/ArgumentProviderBuildItem.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/ArgumentProviderBuildItem.java new file mode 100644 index 00000000000000..9993e4cbeffcc9 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/ArgumentProviderBuildItem.java @@ -0,0 +1,17 @@ +package io.quarkus.websockets.next.deployment; + +import io.quarkus.builder.item.MultiBuildItem; + +final class ArgumentProviderBuildItem extends MultiBuildItem { + + private final ArgumentProvider provider; + + ArgumentProviderBuildItem(ArgumentProvider provider) { + this.provider = provider; + } + + ArgumentProvider getProvider() { + return provider; + } + +} diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/ArgumentProvidersBuildItem.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/ArgumentProvidersBuildItem.java new file mode 100644 index 00000000000000..f995a1b474609c --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/ArgumentProvidersBuildItem.java @@ -0,0 +1,46 @@ +package io.quarkus.websockets.next.deployment; + +import java.util.ArrayList; +import java.util.List; + +import io.quarkus.builder.item.SimpleBuildItem; +import io.quarkus.websockets.next.deployment.ArgumentProvider.ParameterContext; + +final class ArgumentProvidersBuildItem extends SimpleBuildItem { + + final List sortedProviders; + + ArgumentProvidersBuildItem(List injectors) { + this.sortedProviders = injectors; + } + + /** + * + * @param context + * @return the first matching provider or {@code null} + */ + ArgumentProvider findMatching(ParameterContext context) { + for (ArgumentProvider provider : sortedProviders) { + if (provider.matches(context)) { + return provider; + } + } + return null; + } + + /** + * + * @param context + * @return all matching providers, never {@code null} + */ + List findAllMatching(ParameterContext context) { + List matching = new ArrayList<>(); + for (ArgumentProvider provider : sortedProviders) { + if (provider.matches(context)) { + matching.add(provider); + } + } + return matching; + } + +} diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/ConnectionArgumentProvider.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/ConnectionArgumentProvider.java new file mode 100644 index 00000000000000..23e2af22b34c77 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/ConnectionArgumentProvider.java @@ -0,0 +1,22 @@ +package io.quarkus.websockets.next.deployment; + +import io.quarkus.gizmo.ResultHandle; + +class ConnectionArgumentProvider implements ArgumentProvider { + + @Override + public boolean matches(ParameterContext parameterContext) { + return parameterContext.parameter().type().name().equals(WebSocketDotNames.WEB_SOCKET_CONNECTION); + } + + @Override + public ResultHandle get(ParameterContext parameterContext, CallbackContext callbackContext) { + return callbackContext.connection(); + } + + @Override + public int priotity() { + return DEFAULT_PRIORITY + 1; + } + +} diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/MessageArgumentProvider.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/MessageArgumentProvider.java new file mode 100644 index 00000000000000..a2b3c3418349c0 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/MessageArgumentProvider.java @@ -0,0 +1,17 @@ +package io.quarkus.websockets.next.deployment; + +import io.quarkus.gizmo.ResultHandle; + +class MessageArgumentProvider implements ArgumentProvider { + + @Override + public boolean matches(ParameterContext parameterContext) { + return parameterContext.callback().acceptsMessage() && parameterContext.paramAnnotations().isEmpty(); + } + + @Override + public ResultHandle get(ParameterContext parameterContext, CallbackContext callbackContext) { + return callbackContext.decodeMessage(parameterContext.parameter().type()); + } + +} diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketEndpointBuildItem.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketEndpointBuildItem.java index 58f7920d933b8d..2512578bbce9ef 100644 --- a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketEndpointBuildItem.java +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketEndpointBuildItem.java @@ -60,6 +60,14 @@ public Callback(AnnotationInstance annotation, MethodInfo method, ExecutionModel } } + public boolean isOnOpen() { + return annotation.name().equals(WebSocketDotNames.ON_OPEN); + } + + public boolean isOnClose() { + return annotation.name().equals(WebSocketDotNames.ON_CLOSE); + } + public Type returnType() { return method.returnType(); } diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessor.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessor.java index a73cb0531725af..6d2259de2b07f0 100644 --- a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessor.java +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessor.java @@ -3,9 +3,11 @@ import static io.quarkus.deployment.annotations.ExecutionTime.RUNTIME_INIT; import java.util.ArrayList; +import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.function.Consumer; import java.util.function.Function; import java.util.regex.Matcher; @@ -20,6 +22,7 @@ import org.jboss.jandex.DotName; import org.jboss.jandex.IndexView; import org.jboss.jandex.MethodInfo; +import org.jboss.jandex.MethodParameterInfo; import org.jboss.jandex.PrimitiveType; import org.jboss.jandex.Type; import org.jboss.jandex.Type.Kind; @@ -32,7 +35,9 @@ import io.quarkus.arc.deployment.ContextRegistrationPhaseBuildItem.ContextConfiguratorBuildItem; import io.quarkus.arc.deployment.CustomScopeBuildItem; import io.quarkus.arc.deployment.SyntheticBeanBuildItem; +import io.quarkus.arc.deployment.TransformedAnnotationsBuildItem; import io.quarkus.arc.deployment.UnremovableBeanBuildItem; +import io.quarkus.arc.processor.Annotations; import io.quarkus.arc.processor.BeanInfo; import io.quarkus.arc.processor.DotNames; import io.quarkus.arc.processor.Types; @@ -47,6 +52,7 @@ import io.quarkus.gizmo.CatchBlockCreator; import io.quarkus.gizmo.ClassCreator; import io.quarkus.gizmo.ClassOutput; +import io.quarkus.gizmo.FieldDescriptor; import io.quarkus.gizmo.FunctionCreator; import io.quarkus.gizmo.MethodCreator; import io.quarkus.gizmo.MethodDescriptor; @@ -60,6 +66,8 @@ import io.quarkus.websockets.next.WebSocketConnection; import io.quarkus.websockets.next.WebSocketServerException; import io.quarkus.websockets.next.WebSocketsRuntimeConfig; +import io.quarkus.websockets.next.deployment.ArgumentProvider.CallbackContext; +import io.quarkus.websockets.next.deployment.ArgumentProvider.ParameterContext; import io.quarkus.websockets.next.deployment.WebSocketEndpointBuildItem.Callback; import io.quarkus.websockets.next.deployment.WebSocketEndpointBuildItem.Callback.MessageType; import io.quarkus.websockets.next.runtime.Codecs; @@ -104,6 +112,8 @@ void unremovableBeans(BuildProducer unremovableBeans) @BuildStep public void collectEndpoints(BeanArchiveIndexBuildItem beanArchiveIndex, BeanDiscoveryFinishedBuildItem beanDiscoveryFinished, + ArgumentProvidersBuildItem argumentProviders, + TransformedAnnotationsBuildItem transformedAnnotations, BuildProducer endpoints) { IndexView index = beanArchiveIndex.getIndex(); @@ -124,15 +134,16 @@ public void collectEndpoints(BeanArchiveIndexBuildItem beanArchiveIndex, String.format("Multiple endpoints [%s, %s] define the same path: %s", previous, beanClass, path)); } Callback onOpen = findCallback(beanArchiveIndex.getIndex(), beanClass, WebSocketDotNames.ON_OPEN, - this::validateOnOpen); + argumentProviders, transformedAnnotations); Callback onTextMessage = findCallback(beanArchiveIndex.getIndex(), beanClass, WebSocketDotNames.ON_TEXT_MESSAGE, - this::validateOnTextMessage); + argumentProviders, transformedAnnotations); Callback onBinaryMessage = findCallback(beanArchiveIndex.getIndex(), beanClass, - WebSocketDotNames.ON_BINARY_MESSAGE, - this::validateOnBinaryMessage); + WebSocketDotNames.ON_BINARY_MESSAGE, argumentProviders, transformedAnnotations); Callback onPongMessage = findCallback(beanArchiveIndex.getIndex(), beanClass, WebSocketDotNames.ON_PONG_MESSAGE, + argumentProviders, transformedAnnotations, this::validateOnPongMessage); Callback onClose = findCallback(beanArchiveIndex.getIndex(), beanClass, WebSocketDotNames.ON_CLOSE, + argumentProviders, transformedAnnotations, this::validateOnClose); if (onOpen == null && onTextMessage == null && onBinaryMessage == null && onPongMessage == null) { throw new WebSocketServerException( @@ -152,8 +163,20 @@ public void collectEndpoints(BeanArchiveIndexBuildItem beanArchiveIndex, } } + @BuildStep + ArgumentProvidersBuildItem collectParamInjectors(List providers) { + List sorted = new ArrayList<>(); + for (ArgumentProviderBuildItem provider : providers) { + sorted.add(provider.getProvider()); + } + sorted.sort(Comparator.comparingInt(ArgumentProvider::priotity).reversed()); + return new ArgumentProvidersBuildItem(sorted); + } + @BuildStep public void generateEndpoints(List endpoints, + ArgumentProvidersBuildItem argumentProviders, + TransformedAnnotationsBuildItem transformedAnnotations, BuildProducer generatedClasses, BuildProducer generatedEndpoints, BuildProducer reflectiveClasses) { @@ -175,7 +198,7 @@ public String apply(String name) { // A new instance of this generated endpoint is created for each client connection // The generated endpoint ensures the correct execution model is used // and delegates callback invocations to the endpoint bean - String generatedName = generateEndpoint(endpoint, classOutput); + String generatedName = generateEndpoint(endpoint, argumentProviders, transformedAnnotations, classOutput); reflectiveClasses.produce(ReflectiveClassBuildItem.builder(generatedName).constructors().build()); generatedEndpoints.produce(new GeneratedEndpointBuildItem(endpoint.bean.getImplClazz().name().toString(), generatedName, endpoint.path)); @@ -230,6 +253,12 @@ CustomScopeBuildItem registerSessionScope() { return new CustomScopeBuildItem(DotName.createSimple(SessionScoped.class.getName())); } + @BuildStep + void builtinArgumentProviders(BuildProducer providers) { + providers.produce(new ArgumentProviderBuildItem(new MessageArgumentProvider())); + providers.produce(new ArgumentProviderBuildItem(new ConnectionArgumentProvider())); + } + static String mergePath(String prefix, String path) { if (prefix.endsWith("/")) { prefix = prefix.substring(0, prefix.length() - 1); @@ -281,27 +310,6 @@ private String getPathPrefix(IndexView index, DotName enclosingClassName) { return ""; } - private void validateOnOpen(MethodInfo callback) { - if (!callback.parameters().isEmpty()) { - throw new WebSocketServerException( - "@OnOpen callback must not accept any parameters: " + callbackToString(callback)); - } - } - - private void validateOnTextMessage(MethodInfo callback) { - if (callback.parameters().size() != 1) { - throw new WebSocketServerException( - "@OnTextMessage callback must accept exactly one parameter: " + callbackToString(callback)); - } - } - - private void validateOnBinaryMessage(MethodInfo callback) { - if (callback.parameters().size() != 1) { - throw new WebSocketServerException( - "@OnTextMessage callback must accept exactly one parameter: " + callbackToString(callback)); - } - } - private void validateOnPongMessage(MethodInfo callback) { if (callback.returnType().kind() != Kind.VOID && !WebSocketServerProcessor.isUniVoid(callback.returnType())) { throw new WebSocketServerException( @@ -319,10 +327,6 @@ private void validateOnClose(MethodInfo callback) { throw new WebSocketServerException( "@OnClose callback must return void or Uni: " + callbackToString(callback)); } - if (!callback.parameters().isEmpty()) { - throw new WebSocketServerException( - "@OnClose callback must not accept any parameters: " + callbackToString(callback)); - } } /** @@ -360,7 +364,10 @@ private void validateOnClose(MethodInfo callback) { * @param classOutput * @return the name of the generated class */ - private String generateEndpoint(WebSocketEndpointBuildItem endpoint, ClassOutput classOutput) { + private String generateEndpoint(WebSocketEndpointBuildItem endpoint, + ArgumentProvidersBuildItem argumentProviders, + TransformedAnnotationsBuildItem transformedAnnotations, + ClassOutput classOutput) { ClassInfo implClazz = endpoint.bean.getImplClazz(); String baseName; if (implClazz.enclosingClass() != null) { @@ -389,6 +396,7 @@ private String generateEndpoint(WebSocketEndpointBuildItem endpoint, ClassOutput executionMode.returnValue(executionMode.load(endpoint.executionMode)); if (endpoint.onOpen != null) { + Callback callback = endpoint.onOpen; MethodCreator doOnOpen = endpointCreator.getMethodCreator("doOnOpen", Uni.class, Object.class); // Foo foo = beanInstance("foo"); ResultHandle beanInstance = doOnOpen.invokeSpecialMethod( @@ -396,19 +404,22 @@ private String generateEndpoint(WebSocketEndpointBuildItem endpoint, ClassOutput doOnOpen.getThis(), doOnOpen.load(endpoint.bean.getIdentifier())); // Call the business method TryBlock tryBlock = uniFailureTryBlock(doOnOpen); - ResultHandle ret = tryBlock.invokeVirtualMethod(MethodDescriptor.of(endpoint.onOpen.method), beanInstance); - encodeAndReturnResult(tryBlock, endpoint.onOpen, ret); + ResultHandle[] args = collectArguments(tryBlock, callback, argumentProviders, transformedAnnotations, + callbackContext(tryBlock, callback)); + ResultHandle ret = tryBlock.invokeVirtualMethod(MethodDescriptor.of(callback.method), beanInstance, args); + encodeAndReturnResult(tryBlock, callback, ret); MethodCreator onOpenExecutionModel = endpointCreator.getMethodCreator("onOpenExecutionModel", ExecutionModel.class); - onOpenExecutionModel.returnValue(onOpenExecutionModel.load(endpoint.onOpen.executionModel)); + onOpenExecutionModel.returnValue(onOpenExecutionModel.load(callback.executionModel)); } - generateOnMessage(endpointCreator, endpoint, endpoint.onBinaryMessage); - generateOnMessage(endpointCreator, endpoint, endpoint.onTextMessage); - generateOnMessage(endpointCreator, endpoint, endpoint.onPongMessage); + generateOnMessage(endpointCreator, endpoint, endpoint.onBinaryMessage, argumentProviders, transformedAnnotations); + generateOnMessage(endpointCreator, endpoint, endpoint.onTextMessage, argumentProviders, transformedAnnotations); + generateOnMessage(endpointCreator, endpoint, endpoint.onPongMessage, argumentProviders, transformedAnnotations); if (endpoint.onClose != null) { + Callback callback = endpoint.onClose; MethodCreator doOnClose = endpointCreator.getMethodCreator("doOnClose", Uni.class, Object.class); // Foo foo = beanInstance("foo"); ResultHandle beanInstance = doOnClose.invokeSpecialMethod( @@ -416,19 +427,69 @@ private String generateEndpoint(WebSocketEndpointBuildItem endpoint, ClassOutput doOnClose.getThis(), doOnClose.load(endpoint.bean.getIdentifier())); // Call the business method TryBlock tryBlock = uniFailureTryBlock(doOnClose); - ResultHandle ret = tryBlock.invokeVirtualMethod(MethodDescriptor.of(endpoint.onClose.method), beanInstance); - encodeAndReturnResult(tryBlock, endpoint.onClose, ret); + ResultHandle[] args = collectArguments(tryBlock, callback, argumentProviders, transformedAnnotations, + callbackContext(tryBlock, callback)); + ResultHandle ret = tryBlock.invokeVirtualMethod(MethodDescriptor.of(callback.method), beanInstance, args); + encodeAndReturnResult(tryBlock, callback, ret); MethodCreator onCloseExecutionModel = endpointCreator.getMethodCreator("onCloseExecutionModel", ExecutionModel.class); - onCloseExecutionModel.returnValue(onCloseExecutionModel.load(endpoint.onClose.executionModel)); + onCloseExecutionModel.returnValue(onCloseExecutionModel.load(callback.executionModel)); } endpointCreator.close(); return generatedName.replace('/', '.'); } - private void generateOnMessage(ClassCreator endpointCreator, WebSocketEndpointBuildItem endpoint, Callback callback) { + private void validateParameters(Callback callback, + ArgumentProvidersBuildItem providers, TransformedAnnotationsBuildItem transformedAnnotations) { + // TODO how do we validate that at most one message argument is defined? + for (MethodParameterInfo parameter : callback.method.parameters()) { + List found = providers + .findAllMatching(parameterContext(callback, parameter, transformedAnnotations)); + if (found.isEmpty()) { + String msg = String.format("Unable to inject @%s callback parameter '%s' declared on %s: no injector found", + DotNames.simpleName(callback.annotation.name()), + parameter.name() != null ? parameter.name() : "#" + parameter.position(), + callbackToString(callback.method)); + throw new WebSocketServerException(msg); + } else if (found.size() > 1) { + if (found.get(0).priotity() == found.get(1).priotity()) { + String msg = String.format( + "Unable to inject @%s callback parameter '%s' declared on %s: ambiguous injectors found: %s", + DotNames.simpleName(callback.annotation.name()), + parameter.name() != null ? parameter.name() : "#" + parameter.position(), + callbackToString(callback.method), + found.stream().map(p -> p.getClass().getSimpleName() + ":" + p.priotity())); + throw new WebSocketServerException(msg); + } + } + } + } + + private ResultHandle[] collectArguments(BytecodeCreator method, Callback callback, + ArgumentProvidersBuildItem providers, TransformedAnnotationsBuildItem transformedAnnotations, + CallbackContext callbackContext) { + List parameters = callback.method.parameters(); + if (parameters.isEmpty()) { + return new ResultHandle[] {}; + } + ResultHandle[] resultHandles = new ResultHandle[parameters.size()]; + int idx = 0; + for (MethodParameterInfo parameter : parameters) { + // At this point we can be sure there's exactly one injector matching + ParameterContext parameterContext = parameterContext(callback, parameter, transformedAnnotations); + resultHandles[idx] = providers + .findMatching(parameterContext) + .get(parameterContext, callbackContext); + idx++; + } + return resultHandles; + } + + private void generateOnMessage(ClassCreator endpointCreator, WebSocketEndpointBuildItem endpoint, Callback callback, + ArgumentProvidersBuildItem paramInjectors, + TransformedAnnotationsBuildItem transformedAnnotations) { if (callback == null) { return; } @@ -456,15 +517,10 @@ private void generateOnMessage(ClassCreator endpointCreator, WebSocketEndpointBu ResultHandle beanInstance = doOnMessage.invokeSpecialMethod( MethodDescriptor.ofMethod(WebSocketEndpointBase.class, "beanInstance", Object.class, String.class), doOnMessage.getThis(), doOnMessage.load(endpoint.bean.getIdentifier())); - ResultHandle[] args; - if (callback.acceptsMessage()) { - args = new ResultHandle[] { decodeMessage(doOnMessage, callback.acceptsBinaryMessage(), - callback.method.parameterType(0), doOnMessage.getMethodParam(0), callback) }; - } else { - args = new ResultHandle[] {}; - } // Call the business method TryBlock tryBlock = uniFailureTryBlock(doOnMessage); + ResultHandle[] args = collectArguments(tryBlock, callback, paramInjectors, transformedAnnotations, + callbackContext(tryBlock, callback)); ResultHandle ret = tryBlock.invokeVirtualMethod(MethodDescriptor.of(callback.method), beanInstance, args); encodeAndReturnResult(tryBlock, callback, ret); @@ -486,6 +542,57 @@ private void generateOnMessage(ClassCreator endpointCreator, WebSocketEndpointBu } } + private ParameterContext parameterContext(Callback callback, MethodParameterInfo parameter, + TransformedAnnotationsBuildItem transformedAnnotations) { + return new ParameterContext() { + + @Override + public MethodParameterInfo parameter() { + return parameter; + } + + @Override + public Set paramAnnotations() { + return Annotations.getParameterAnnotations( + transformedAnnotations::getAnnotations, callback.method, parameter.position()); + } + + @Override + public Callback callback() { + return callback; + } + }; + } + + private CallbackContext callbackContext(BytecodeCreator bytecode, Callback callback) { + return new CallbackContext() { + @Override + public BytecodeCreator bytecode() { + return bytecode; + } + + @Override + public ResultHandle message() { + return callback.acceptsMessage() ? bytecode.getMethodParam(0) : null; + } + + @Override + public ResultHandle decodeMessage(Type parameterType) { + return callback.acceptsMessage() + ? WebSocketServerProcessor.this.decodeMessage(bytecode, callback.acceptsBinaryMessage(), parameterType, + message(), callback) + : null; + } + + @Override + public ResultHandle connection() { + return bytecode.readInstanceField( + FieldDescriptor.of(WebSocketEndpointBase.class, "connection", WebSocketConnection.class), + bytecode.getThis()); + } + }; + } + private TryBlock uniFailureTryBlock(BytecodeCreator method) { TryBlock tryBlock = method.tryBlock(); CatchBlockCreator catchBlock = tryBlock.addCatch(Throwable.class); @@ -498,7 +605,7 @@ private TryBlock uniFailureTryBlock(BytecodeCreator method) { return tryBlock; } - private ResultHandle decodeMessage(MethodCreator method, boolean binaryMessage, Type valueType, ResultHandle value, + private ResultHandle decodeMessage(BytecodeCreator method, boolean binaryMessage, Type valueType, ResultHandle value, Callback callback) { if (WebSocketDotNames.MULTI.equals(valueType.name())) { // Multi is decoded at runtime in the recorder @@ -757,6 +864,12 @@ private void encodeAndReturnResult(BytecodeCreator method, Callback callback, Re } private Callback findCallback(IndexView index, ClassInfo beanClass, DotName annotationName, + ArgumentProvidersBuildItem paramInjectors, TransformedAnnotationsBuildItem transformedAnnotations) { + return findCallback(index, beanClass, annotationName, paramInjectors, transformedAnnotations, null); + } + + private Callback findCallback(IndexView index, ClassInfo beanClass, DotName annotationName, + ArgumentProvidersBuildItem paramInjectors, TransformedAnnotationsBuildItem transformedAnnotations, Consumer validator) { ClassInfo aClass = beanClass; List annotations = new ArrayList<>(); @@ -776,8 +889,12 @@ private Callback findCallback(IndexView index, ClassInfo beanClass, DotName anno } else if (annotations.size() == 1) { AnnotationInstance annotation = annotations.get(0); MethodInfo method = annotation.target().asMethod(); - validator.accept(method); - return new Callback(annotation, method, executionModel(method)); + if (validator != null) { + validator.accept(method); + } + Callback callback = new Callback(annotation, method, executionModel(method)); + validateParameters(callback, paramInjectors, transformedAnnotations); + return callback; } throw new WebSocketServerException( String.format("There can be only one callback annotated with %s declared on %s", annotationName, beanClass)); diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/ConnectionArgumentTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/ConnectionArgumentTest.java new file mode 100644 index 00000000000000..357ddfcb78d15d --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/ConnectionArgumentTest.java @@ -0,0 +1,62 @@ +package io.quarkus.websockets.next.test.args; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.net.URI; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketConnection; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.smallrye.mutiny.Uni; +import io.vertx.core.Vertx; +import io.vertx.core.http.WebSocketConnectOptions; +import io.vertx.core.json.JsonObject; + +public class ConnectionArgumentTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(Echo.class, WSClient.class); + }); + + @Inject + Vertx vertx; + + @TestHTTPResource("echo") + URI testUri; + + @Test + void testArgument() { + String message = "ok"; + String header = "fool"; + WSClient client = WSClient.create(vertx).connect(new WebSocketConnectOptions().addHeader("X-Test", header), + testUri); + JsonObject reply = client.sendAndAwaitReply(message).toJsonObject(); + assertEquals(header, reply.getString("header"), reply.toString()); + assertEquals(message, reply.getString("message"), reply.toString()); + } + + @WebSocket(path = "/echo") + public static class Echo { + + @OnTextMessage + Uni process(WebSocketConnection connection, String message) throws InterruptedException { + return connection.sendText( + new JsonObject() + .put("id", connection.id()) + .put("message", message) + .put("header", connection.handshakeRequest().header("X-Test"))); + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/OnCloseInvalidArgumentTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/OnCloseInvalidArgumentTest.java new file mode 100644 index 00000000000000..c5a934eeb6cdfe --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/OnCloseInvalidArgumentTest.java @@ -0,0 +1,38 @@ +package io.quarkus.websockets.next.test.args; + +import static org.junit.jupiter.api.Assertions.fail; + +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.websockets.next.OnClose; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketServerException; + +public class OnCloseInvalidArgumentTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(Endpoint.class); + }) + .setExpectedException(WebSocketServerException.class); + + @Test + void testInvalidArgument() { + fail(); + } + + @WebSocket(path = "/end") + public static class Endpoint { + + @OnClose + void close(List unsupported) { + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/OnOpenInvalidArgumentTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/OnOpenInvalidArgumentTest.java new file mode 100644 index 00000000000000..5f3b9071cf546d --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/OnOpenInvalidArgumentTest.java @@ -0,0 +1,38 @@ +package io.quarkus.websockets.next.test.args; + +import static org.junit.jupiter.api.Assertions.fail; + +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketServerException; + +public class OnOpenInvalidArgumentTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(Endpoint.class); + }) + .setExpectedException(WebSocketServerException.class); + + @Test + void testInvalidArgument() { + fail(); + } + + @WebSocket(path = "/end") + public static class Endpoint { + + @OnOpen + void open(List unsupported) { + } + + } + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectionImpl.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectionImpl.java index 29b7d925fc25d2..aa8ac393960319 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectionImpl.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectionImpl.java @@ -13,11 +13,15 @@ import java.util.function.Predicate; import java.util.stream.Collectors; +import io.quarkus.vertx.core.runtime.VertxBufferImpl; import io.quarkus.websockets.next.WebSocketConnection; import io.smallrye.mutiny.Uni; import io.smallrye.mutiny.vertx.UniHelper; import io.vertx.core.buffer.Buffer; +import io.vertx.core.buffer.impl.BufferImpl; import io.vertx.core.http.ServerWebSocket; +import io.vertx.core.json.JsonArray; +import io.vertx.core.json.JsonObject; import io.vertx.ext.web.RoutingContext; class WebSocketConnectionImpl implements WebSocketConnection { @@ -75,7 +79,17 @@ public Uni sendBinary(Buffer message) { @Override public Uni sendText(M message) { - return UniHelper.toUni(webSocket.writeTextMessage(codecs.textEncode(message, null).toString())); + String text; + // Use the same conversion rules as defined for the OnTextMessage + if (message instanceof JsonObject || message instanceof JsonArray || message instanceof BufferImpl + || message instanceof VertxBufferImpl) { + text = message.toString(); + } else if (message.getClass().isArray() && message.getClass().arrayType().equals(byte.class)) { + text = Buffer.buffer((byte[]) message).toString(); + } else { + text = codecs.textEncode(message, null); + } + return sendText(text); } @Override diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java index fdb032ae57ca48..8d9620f09c10a2 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java @@ -27,7 +27,8 @@ public abstract class WebSocketEndpointBase implements WebSocketEndpoint { private static final Logger LOG = Logger.getLogger(WebSocketEndpointBase.class); - protected final WebSocketConnection connection; + // Keep this field public - there's a problem with ConnectionArgumentProvider reading the protected field in the test mode + public final WebSocketConnection connection; protected final Codecs codecs;