diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/CallbackArgument.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/CallbackArgument.java index 4981903e4b0979..a69c38e953ad26 100644 --- a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/CallbackArgument.java +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/CallbackArgument.java @@ -3,12 +3,14 @@ import java.util.Set; import org.jboss.jandex.AnnotationInstance; +import org.jboss.jandex.IndexView; import org.jboss.jandex.MethodParameterInfo; import org.jboss.jandex.Type; import io.quarkus.gizmo.BytecodeCreator; import io.quarkus.gizmo.ResultHandle; import io.quarkus.websockets.next.OnClose; +import io.quarkus.websockets.next.OnError; import io.quarkus.websockets.next.OnOpen; import io.quarkus.websockets.next.WebSocketConnection; import io.quarkus.websockets.next.WebSocketServerException; @@ -53,6 +55,12 @@ interface ParameterContext { */ String endpointPath(); + /** + * + * @return the index that can be used to inspect parameter types + */ + IndexView index(); + /** * * @return the callback marker annotation @@ -88,17 +96,17 @@ interface InvocationBytecodeContext extends ParameterContext { BytecodeCreator bytecode(); /** - * Obtains the message directly in the bytecode. + * Obtains the message or error directly in the bytecode. * - * @return the message object or {@code null} for {@link OnOpen} and {@link OnClose} callbacks + * @return the message/error object or {@code null} for {@link OnOpen} and {@link OnClose} callbacks */ - ResultHandle getMessage(); + ResultHandle getPayload(); /** * Attempts to obtain the decoded message directly in the bytecode. * * @param parameterType - * @return the decoded message object or {@code null} for {@link OnOpen} and {@link OnClose} callbacks + * @return the decoded message object or {@code null} for {@link OnOpen}, {@link OnClose} and {@link OnError} callbacks */ ResultHandle getDecodedMessage(Type parameterType); diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/ErrorCallbackArgument.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/ErrorCallbackArgument.java new file mode 100644 index 00000000000000..31f26f3f5f0800 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/ErrorCallbackArgument.java @@ -0,0 +1,45 @@ +package io.quarkus.websockets.next.deployment; + +import org.jboss.jandex.ClassInfo; +import org.jboss.jandex.DotName; +import org.jboss.jandex.IndexView; + +import io.quarkus.gizmo.ResultHandle; + +class ErrorCallbackArgument implements CallbackArgument { + + @Override + public boolean matches(ParameterContext context) { + return context.callbackAnnotation().name().equals(WebSocketDotNames.ON_ERROR) + && isThrowable(context.index(), context.parameter().type().name()); + } + + @Override + public ResultHandle get(InvocationBytecodeContext context) { + return context.getPayload(); + } + + boolean isThrowable(IndexView index, DotName clazzName) { + if (clazzName.equals(WebSocketDotNames.THROWABLE)) { + return true; + } + ClassInfo clazz = index.getClassByName(clazzName); + if (clazz == null) { + throw new IllegalArgumentException("The class " + clazzName + " not found in the index"); + } + if (clazz.superName().equals(DotName.OBJECT_NAME) + || clazz.superName().equals(DotName.RECORD_NAME) + || clazz.superName().equals(DotName.ENUM_NAME)) { + return false; + } + if (clazz.superName().equals(WebSocketDotNames.THROWABLE)) { + return true; + } + return isThrowable(index, clazz.superName()); + } + + public static boolean isError(CallbackArgument callbackArgument) { + return callbackArgument instanceof ErrorCallbackArgument; + } + +} diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/MessageCallbackArgument.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/MessageCallbackArgument.java index 45e92e857cf391..ae32e79da87b12 100644 --- a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/MessageCallbackArgument.java +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/MessageCallbackArgument.java @@ -19,4 +19,8 @@ public int priotity() { return DEFAULT_PRIORITY - 1; } + public static boolean isMessage(CallbackArgument callbackArgument) { + return callbackArgument instanceof MessageCallbackArgument; + } + } diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketDotNames.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketDotNames.java index c6803b6d7baf13..72b18bc9f29cb8 100644 --- a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketDotNames.java +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketDotNames.java @@ -4,6 +4,7 @@ import io.quarkus.websockets.next.OnBinaryMessage; import io.quarkus.websockets.next.OnClose; +import io.quarkus.websockets.next.OnError; import io.quarkus.websockets.next.OnOpen; import io.quarkus.websockets.next.OnPongMessage; import io.quarkus.websockets.next.OnTextMessage; @@ -27,6 +28,7 @@ final class WebSocketDotNames { static final DotName ON_BINARY_MESSAGE = DotName.createSimple(OnBinaryMessage.class); static final DotName ON_PONG_MESSAGE = DotName.createSimple(OnPongMessage.class); static final DotName ON_CLOSE = DotName.createSimple(OnClose.class); + static final DotName ON_ERROR = DotName.createSimple(OnError.class); static final DotName UNI = DotName.createSimple(Uni.class); static final DotName MULTI = DotName.createSimple(Multi.class); static final DotName RUN_ON_VIRTUAL_THREAD = DotName.createSimple(RunOnVirtualThread.class); @@ -38,4 +40,5 @@ final class WebSocketDotNames { static final DotName VOID = DotName.createSimple(Void.class); static final DotName PATH_PARAM = DotName.createSimple(PathParam.class); static final DotName HANDSHAKE_REQUEST = DotName.createSimple(WebSocketConnection.HandshakeRequest.class); + static final DotName THROWABLE = DotName.createSimple(Throwable.class); } diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketEndpointBuildItem.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketEndpointBuildItem.java index 90064c89aa95f1..4d703d8ce4f258 100644 --- a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketEndpointBuildItem.java +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketEndpointBuildItem.java @@ -3,10 +3,12 @@ import java.util.ArrayList; import java.util.List; import java.util.Set; +import java.util.function.Predicate; import org.jboss.jandex.AnnotationInstance; import org.jboss.jandex.AnnotationValue; import org.jboss.jandex.DotName; +import org.jboss.jandex.IndexView; import org.jboss.jandex.MethodInfo; import org.jboss.jandex.MethodParameterInfo; import org.jboss.jandex.Type; @@ -41,9 +43,11 @@ public final class WebSocketEndpointBuildItem extends MultiBuildItem { public final Callback onBinaryMessage; public final Callback onPongMessage; public final Callback onClose; + public final List onErrors; - public WebSocketEndpointBuildItem(BeanInfo bean, String path, WebSocket.ExecutionMode executionMode, Callback onOpen, - Callback onTextMessage, Callback onBinaryMessage, Callback onPongMessage, Callback onClose) { + WebSocketEndpointBuildItem(BeanInfo bean, String path, WebSocket.ExecutionMode executionMode, Callback onOpen, + Callback onTextMessage, Callback onBinaryMessage, Callback onPongMessage, Callback onClose, + List onErrors) { this.bean = bean; this.path = path; this.executionMode = executionMode; @@ -52,6 +56,7 @@ public WebSocketEndpointBuildItem(BeanInfo bean, String path, WebSocket.Executio this.onBinaryMessage = onBinaryMessage; this.onPongMessage = onPongMessage; this.onClose = onClose; + this.onErrors = onErrors; } public static class Callback { @@ -64,7 +69,7 @@ public static class Callback { public Callback(AnnotationInstance annotation, MethodInfo method, ExecutionModel executionModel, CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations, - String endpointPath) { + String endpointPath, IndexView index) { this.method = method; this.annotation = annotation; this.executionModel = executionModel; @@ -77,7 +82,8 @@ public Callback(AnnotationInstance annotation, MethodInfo method, ExecutionModel } else { this.messageType = MessageType.NONE; } - this.arguments = collectArguments(annotation, method, callbackArguments, transformedAnnotations, endpointPath); + this.arguments = collectArguments(annotation, method, callbackArguments, transformedAnnotations, endpointPath, + index); } public boolean isOnOpen() { @@ -88,6 +94,10 @@ public boolean isOnClose() { return annotation.name().equals(WebSocketDotNames.ON_CLOSE); } + public boolean isOnError() { + return annotation.name().equals(WebSocketDotNames.ON_ERROR); + } + public Type returnType() { return method.returnType(); } @@ -153,21 +163,8 @@ public enum MessageType { BINARY } - public List messageArguments() { - if (arguments.isEmpty()) { - return List.of(); - } - List ret = new ArrayList<>(); - for (CallbackArgument arg : arguments) { - if (arg instanceof MessageCallbackArgument) { - ret.add(arg); - } - } - return ret; - } - public ResultHandle[] generateArguments(BytecodeCreator bytecode, - TransformedAnnotationsBuildItem transformedAnnotations, String endpointPath) { + TransformedAnnotationsBuildItem transformedAnnotations, String endpointPath, IndexView index) { if (arguments.isEmpty()) { return new ResultHandle[] {}; } @@ -176,7 +173,7 @@ public ResultHandle[] generateArguments(BytecodeCreator bytecode, for (CallbackArgument argument : arguments) { resultHandles[idx] = argument.get( invocationBytecodeContext(annotation, method.parameters().get(idx), transformedAnnotations, - endpointPath, bytecode)); + endpointPath, index, bytecode)); idx++; } return resultHandles; @@ -184,7 +181,7 @@ public ResultHandle[] generateArguments(BytecodeCreator bytecode, static List collectArguments(AnnotationInstance annotation, MethodInfo method, CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations, - String endpointPath) { + String endpointPath, IndexView index) { List parameters = method.parameters(); if (parameters.isEmpty()) { return List.of(); @@ -192,7 +189,7 @@ static List collectArguments(AnnotationInstance annotation, Me List arguments = new ArrayList<>(parameters.size()); for (MethodParameterInfo parameter : parameters) { List found = callbackArguments - .findMatching(parameterContext(annotation, parameter, transformedAnnotations, endpointPath)); + .findMatching(parameterContext(annotation, parameter, transformedAnnotations, endpointPath, index)); if (found.isEmpty()) { String msg = String.format("Unable to inject @%s callback parameter '%s' declared on %s: no injector found", DotNames.simpleName(annotation.name()), @@ -210,11 +207,21 @@ static List collectArguments(AnnotationInstance annotation, Me } arguments.add(found.get(0)); } - return arguments; + return List.copyOf(arguments); + } + + Type argumentType(Predicate filter) { + int idx = 0; + for (int i = 0; i < arguments.size(); i++) { + if (filter.test(arguments.get(idx))) { + return method.parameterType(i); + } + } + return null; } static ParameterContext parameterContext(AnnotationInstance callbackAnnotation, MethodParameterInfo parameter, - TransformedAnnotationsBuildItem transformedAnnotations, String endpointPath) { + TransformedAnnotationsBuildItem transformedAnnotations, String endpointPath, IndexView index) { return new ParameterContext() { @Override @@ -238,11 +245,17 @@ public String endpointPath() { return endpointPath; } + @Override + public IndexView index() { + return index; + } + }; } private InvocationBytecodeContext invocationBytecodeContext(AnnotationInstance callbackAnnotation, MethodParameterInfo parameter, TransformedAnnotationsBuildItem transformedAnnotations, String endpointPath, + IndexView index, BytecodeCreator bytecode) { return new InvocationBytecodeContext() { @@ -267,21 +280,28 @@ public String endpointPath() { return endpointPath; } + @Override + public IndexView index() { + return index; + } + @Override public BytecodeCreator bytecode() { return bytecode; } @Override - public ResultHandle getMessage() { - return acceptsMessage() ? bytecode.getMethodParam(0) : null; + public ResultHandle getPayload() { + return acceptsMessage() || callbackAnnotation.name().equals(WebSocketDotNames.ON_ERROR) + ? bytecode.getMethodParam(0) + : null; } @Override public ResultHandle getDecodedMessage(Type parameterType) { return acceptsMessage() ? WebSocketServerProcessor.decodeMessage(bytecode, acceptsBinaryMessage(), parameterType, - getMessage(), Callback.this) + getPayload(), Callback.this) : null; } diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessor.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessor.java index 9654f58d5a1414..c0080fe79aadbe 100644 --- a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessor.java +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessor.java @@ -153,7 +153,8 @@ public void collectEndpoints(BeanArchiveIndexBuildItem beanArchiveIndex, onTextMessage, onBinaryMessage, onPongMessage, - onClose)); + onClose, + findErrorHandlers(index, beanClass, argumentProviders, transformedAnnotations, path))); } } } @@ -169,7 +170,7 @@ CallbackArgumentsBuildItem collectCallbackArguments(List endpoints, + public void generateEndpoints(BeanArchiveIndexBuildItem index, List endpoints, CallbackArgumentsBuildItem argumentProviders, TransformedAnnotationsBuildItem transformedAnnotations, BuildProducer generatedClasses, @@ -193,7 +194,8 @@ public String apply(String name) { // A new instance of this generated endpoint is created for each client connection // The generated endpoint ensures the correct execution model is used // and delegates callback invocations to the endpoint bean - String generatedName = generateEndpoint(endpoint, argumentProviders, transformedAnnotations, classOutput); + String generatedName = generateEndpoint(endpoint, argumentProviders, transformedAnnotations, + index.getIndex(), classOutput); reflectiveClasses.produce(ReflectiveClassBuildItem.builder(generatedName).constructors().build()); generatedEndpoints.produce(new GeneratedEndpointBuildItem(endpoint.bean.getImplClazz().name().toString(), generatedName, endpoint.path)); @@ -254,6 +256,7 @@ void builtinCallbackArguments(BuildProducer providers providers.produce(new CallbackArgumentBuildItem(new ConnectionCallbackArgument())); providers.produce(new CallbackArgumentBuildItem(new PathParamCallbackArgument())); providers.produce(new CallbackArgumentBuildItem(new HandshakeRequestCallbackArgument())); + providers.produce(new CallbackArgumentBuildItem(new ErrorCallbackArgument())); } static String mergePath(String prefix, String path) { @@ -312,13 +315,12 @@ private void validateOnPongMessage(Callback callback) { throw new WebSocketServerException( "@OnPongMessage callback must return void or Uni: " + callbackToString(callback.method)); } - // TODO validate message arguments - // List> messageArguments = getMessageArguments(providers); - // if (messageArguments.size() != 1 || !messageArguments.get(0).getKey().type().name().equals(WebSocketDotNames.BUFFER)) { - // throw new WebSocketServerException( - // "@OnPongMessage callback must accept exactly one message parameter of type io.vertx.core.buffer.Buffer: " - // + callbackToString(callback.method)); - // } + Type messageType = callback.argumentType(MessageCallbackArgument::isMessage); + if (!messageType.name().equals(WebSocketDotNames.BUFFER)) { + throw new WebSocketServerException( + "@OnPongMessage callback must accept exactly one message parameter of type io.vertx.core.buffer.Buffer: " + + callbackToString(callback.method)); + } } private void validateOnClose(Callback callback) { @@ -366,6 +368,7 @@ private void validateOnClose(Callback callback) { private String generateEndpoint(WebSocketEndpointBuildItem endpoint, CallbackArgumentsBuildItem argumentProviders, TransformedAnnotationsBuildItem transformedAnnotations, + IndexView index, ClassOutput classOutput) { ClassInfo implClazz = endpoint.bean.getImplClazz(); String baseName; @@ -402,8 +405,8 @@ private String generateEndpoint(WebSocketEndpointBuildItem endpoint, MethodDescriptor.ofMethod(WebSocketEndpointBase.class, "beanInstance", Object.class, String.class), doOnOpen.getThis(), doOnOpen.load(endpoint.bean.getIdentifier())); // Call the business method - TryBlock tryBlock = uniFailureTryBlock(doOnOpen); - ResultHandle[] args = callback.generateArguments(tryBlock, transformedAnnotations, endpoint.path); + TryBlock tryBlock = onErrorTryBlock(doOnOpen); + ResultHandle[] args = callback.generateArguments(tryBlock, transformedAnnotations, endpoint.path, index); ResultHandle ret = tryBlock.invokeVirtualMethod(MethodDescriptor.of(callback.method), beanInstance, args); encodeAndReturnResult(tryBlock, callback, ret); @@ -412,9 +415,10 @@ private String generateEndpoint(WebSocketEndpointBuildItem endpoint, onOpenExecutionModel.returnValue(onOpenExecutionModel.load(callback.executionModel)); } - generateOnMessage(endpointCreator, endpoint, endpoint.onBinaryMessage, argumentProviders, transformedAnnotations); - generateOnMessage(endpointCreator, endpoint, endpoint.onTextMessage, argumentProviders, transformedAnnotations); - generateOnMessage(endpointCreator, endpoint, endpoint.onPongMessage, argumentProviders, transformedAnnotations); + generateOnMessage(endpointCreator, endpoint, endpoint.onBinaryMessage, argumentProviders, transformedAnnotations, + index); + generateOnMessage(endpointCreator, endpoint, endpoint.onTextMessage, argumentProviders, transformedAnnotations, index); + generateOnMessage(endpointCreator, endpoint, endpoint.onPongMessage, argumentProviders, transformedAnnotations, index); if (endpoint.onClose != null) { Callback callback = endpoint.onClose; @@ -424,8 +428,8 @@ private String generateEndpoint(WebSocketEndpointBuildItem endpoint, MethodDescriptor.ofMethod(WebSocketEndpointBase.class, "beanInstance", Object.class, String.class), doOnClose.getThis(), doOnClose.load(endpoint.bean.getIdentifier())); // Call the business method - TryBlock tryBlock = uniFailureTryBlock(doOnClose); - ResultHandle[] args = callback.generateArguments(tryBlock, transformedAnnotations, endpoint.path); + TryBlock tryBlock = onErrorTryBlock(doOnClose); + ResultHandle[] args = callback.generateArguments(tryBlock, transformedAnnotations, endpoint.path, index); ResultHandle ret = tryBlock.invokeVirtualMethod(MethodDescriptor.of(callback.method), beanInstance, args); encodeAndReturnResult(tryBlock, callback, ret); @@ -434,13 +438,87 @@ private String generateEndpoint(WebSocketEndpointBuildItem endpoint, onCloseExecutionModel.returnValue(onCloseExecutionModel.load(callback.executionModel)); } + generateOnError(endpointCreator, endpoint, argumentProviders, transformedAnnotations, index); + endpointCreator.close(); return generatedName.replace('/', '.'); } + private void generateOnError(ClassCreator endpointCreator, WebSocketEndpointBuildItem endpoint, + CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations, + IndexView index) { + if (endpoint.onErrors.isEmpty()) { + return; + } + MethodCreator doOnError = endpointCreator.getMethodCreator("doOnError", Uni.class, Throwable.class); + + Map errors = new HashMap<>(); + List throwableInfos = new ArrayList<>(); + for (Callback callback : endpoint.onErrors) { + DotName errorTypeName = callback.argumentType(ErrorCallbackArgument::isError).name(); + if (errors.containsKey(errorTypeName)) { + throw new WebSocketServerException(String.format( + "Multiple @OnError callback may not accept the same error parameter: %s\n\t- %s\n\t- %s", errorTypeName, + callbackToString(callback.method), callbackToString(errors.get(errorTypeName).method))); + } + errors.put(errorTypeName, callback); + throwableInfos.add(new ThrowableInfo(callback, throwableHierarchy(errorTypeName, index))); + } + // Most specific errors go first + throwableInfos.sort(Comparator.comparingInt(ThrowableInfo::level).reversed()); + + for (ThrowableInfo throwableInfo : throwableInfos) { + BytecodeCreator isInstanceOfThrowable = doOnError + .ifTrue(doOnError.instanceOf(doOnError.getMethodParam(0), throwableInfo.hierarchy.get(0).toString())) + .trueBranch(); + Callback callback = throwableInfo.callback; + ResultHandle beanInstance = isInstanceOfThrowable.invokeSpecialMethod( + MethodDescriptor.ofMethod(WebSocketEndpointBase.class, "beanInstance", Object.class, String.class), + isInstanceOfThrowable.getThis(), isInstanceOfThrowable.load(endpoint.bean.getIdentifier())); + // Call the business method + TryBlock tryBlock = uniFailureTryBlock(isInstanceOfThrowable); + ResultHandle[] args = callback.generateArguments(tryBlock, transformedAnnotations, endpoint.path, index); + ResultHandle ret = tryBlock.invokeVirtualMethod(MethodDescriptor.of(callback.method), beanInstance, args); + encodeAndReturnResult(tryBlock, callback, ret); + } + + ResultHandle uniCreate = doOnError + .invokeStaticInterfaceMethod(MethodDescriptor.ofMethod(Uni.class, "createFrom", UniCreate.class)); + doOnError.returnValue(doOnError.invokeVirtualMethod( + MethodDescriptor.ofMethod(UniCreate.class, "failure", Uni.class, Throwable.class), uniCreate, + doOnError.getMethodParam(0))); + } + + private List throwableHierarchy(DotName throwableName, IndexView index) { + // TextDecodeException -> [TextDecodeException, WebSocketServerException, RuntimeException, Exception, Throwable] + List ret = new ArrayList<>(); + addToThrowableHierarchy(throwableName, index, ret); + return ret; + } + + private void addToThrowableHierarchy(DotName throwableName, IndexView index, List hierarchy) { + hierarchy.add(throwableName); + ClassInfo errorClass = index.getClassByName(throwableName); + if (errorClass == null) { + throw new IllegalArgumentException("The class " + throwableName + " not found in the index"); + } + if (errorClass.superName().equals(DotName.OBJECT_NAME)) { + return; + } + addToThrowableHierarchy(errorClass.superName(), index, hierarchy); + } + + record ThrowableInfo(Callback callback, List hierarchy) { + + public int level() { + return hierarchy.size(); + } + + } + private void generateOnMessage(ClassCreator endpointCreator, WebSocketEndpointBuildItem endpoint, Callback callback, - CallbackArgumentsBuildItem paramInjectors, - TransformedAnnotationsBuildItem transformedAnnotations) { + CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations, + IndexView index) { if (callback == null) { return; } @@ -469,8 +547,8 @@ private void generateOnMessage(ClassCreator endpointCreator, WebSocketEndpointBu MethodDescriptor.ofMethod(WebSocketEndpointBase.class, "beanInstance", Object.class, String.class), doOnMessage.getThis(), doOnMessage.load(endpoint.bean.getIdentifier())); // Call the business method - TryBlock tryBlock = uniFailureTryBlock(doOnMessage); - ResultHandle[] args = callback.generateArguments(tryBlock, transformedAnnotations, endpoint.path); + TryBlock tryBlock = onErrorTryBlock(doOnMessage); + ResultHandle[] args = callback.generateArguments(tryBlock, transformedAnnotations, endpoint.path, index); ResultHandle ret = tryBlock.invokeVirtualMethod(MethodDescriptor.of(callback.method), beanInstance, args); encodeAndReturnResult(tryBlock, callback, ret); @@ -504,6 +582,16 @@ private TryBlock uniFailureTryBlock(BytecodeCreator method) { return tryBlock; } + private TryBlock onErrorTryBlock(BytecodeCreator method) { + TryBlock tryBlock = method.tryBlock(); + CatchBlockCreator catchBlock = tryBlock.addCatch(Throwable.class); + // return doOnError(t); + catchBlock.returnValue(catchBlock.invokeVirtualMethod( + MethodDescriptor.ofMethod(WebSocketEndpointBase.class, "doOnError", Uni.class, Throwable.class), + catchBlock.getThis(), catchBlock.getCaughtException())); + return tryBlock; + } + static ResultHandle decodeMessage(BytecodeCreator method, boolean binaryMessage, Type valueType, ResultHandle value, Callback callback) { if (WebSocketDotNames.MULTI.equals(valueType.name())) { @@ -762,16 +850,32 @@ private void encodeAndReturnResult(BytecodeCreator method, Callback callback, Re } } - private Callback findCallback(IndexView index, ClassInfo beanClass, DotName annotationName, - CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations, + private List findErrorHandlers(IndexView index, ClassInfo beanClass, CallbackArgumentsBuildItem callbackArguments, + TransformedAnnotationsBuildItem transformedAnnotations, String endpointPath) { - return findCallback(index, beanClass, annotationName, callbackArguments, transformedAnnotations, endpointPath, null); + List annotations = findCallbackAnnotations(index, beanClass, WebSocketDotNames.ON_ERROR); + if (annotations.isEmpty()) { + return List.of(); + } + List errorHandlers = new ArrayList<>(); + for (AnnotationInstance annotation : annotations) { + MethodInfo method = annotation.target().asMethod(); + Callback callback = new Callback(annotation, method, executionModel(method), callbackArguments, + transformedAnnotations, endpointPath, index); + long errorArguments = callback.arguments.stream().filter(ca -> ca instanceof ErrorCallbackArgument).count(); + if (errorArguments != 1) { + throw new WebSocketServerException( + String.format("@OnError callback must accept exactly one error parameter; found %s: %s", + DotNames.simpleName(callback.annotation.name()), + errorArguments, + callbackToString(callback.method))); + } + errorHandlers.add(callback); + } + return errorHandlers; } - private Callback findCallback(IndexView index, ClassInfo beanClass, DotName annotationName, - CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations, - String endpointPath, - Consumer validator) { + private List findCallbackAnnotations(IndexView index, ClassInfo beanClass, DotName annotationName) { ClassInfo aClass = beanClass; List annotations = new ArrayList<>(); while (aClass != null) { @@ -784,15 +888,28 @@ private Callback findCallback(IndexView index, ClassInfo beanClass, DotName anno ? index.getClassByName(superName) : null; } + return annotations; + } + private Callback findCallback(IndexView index, ClassInfo beanClass, DotName annotationName, + CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations, + String endpointPath) { + return findCallback(index, beanClass, annotationName, callbackArguments, transformedAnnotations, endpointPath, null); + } + + private Callback findCallback(IndexView index, ClassInfo beanClass, DotName annotationName, + CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations, + String endpointPath, + Consumer validator) { + List annotations = findCallbackAnnotations(index, beanClass, annotationName); if (annotations.isEmpty()) { return null; } else if (annotations.size() == 1) { AnnotationInstance annotation = annotations.get(0); MethodInfo method = annotation.target().asMethod(); Callback callback = new Callback(annotation, method, executionModel(method), callbackArguments, - transformedAnnotations, endpointPath); - int messageArguments = callback.messageArguments().size(); + transformedAnnotations, endpointPath, index); + long messageArguments = callback.arguments.stream().filter(ca -> ca instanceof MessageCallbackArgument).count(); if (callback.acceptsMessage()) { if (messageArguments > 1) { throw new WebSocketServerException( diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/BinaryDecodeErrorTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/BinaryDecodeErrorTest.java new file mode 100644 index 00000000000000..43d8c7035eda27 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/BinaryDecodeErrorTest.java @@ -0,0 +1,60 @@ +package io.quarkus.websockets.next.test.errors; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.net.URI; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.BinaryDecodeException; +import io.quarkus.websockets.next.OnBinaryMessage; +import io.quarkus.websockets.next.OnError; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketConnection; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.vertx.core.Vertx; +import io.vertx.core.buffer.Buffer; + +public class BinaryDecodeErrorTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(Echo.class, WSClient.class); + }); + + @Inject + Vertx vertx; + + @TestHTTPResource("echo") + URI testUri; + + @Test + void testError() { + WSClient client = WSClient.create(vertx).connect(testUri); + client.send(Buffer.buffer("1")); + client.waitForMessages(1); + assertEquals("Problem decoding: 1", client.getLastMessage().toString()); + } + + @WebSocket(path = "/echo") + public static class Echo { + + @OnBinaryMessage + void process(WebSocketConnection connection, Integer message) { + throw new IllegalStateException(); + } + + @OnError + String decodingError(BinaryDecodeException e) { + return "Problem decoding: " + e.getBytes().toString(); + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/BinaryEncodeErrorTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/BinaryEncodeErrorTest.java new file mode 100644 index 00000000000000..f0d6e0f480a017 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/BinaryEncodeErrorTest.java @@ -0,0 +1,60 @@ +package io.quarkus.websockets.next.test.errors; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.net.URI; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.BinaryEncodeException; +import io.quarkus.websockets.next.OnBinaryMessage; +import io.quarkus.websockets.next.OnError; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketConnection; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.vertx.core.Vertx; +import io.vertx.core.buffer.Buffer; + +public class BinaryEncodeErrorTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(Echo.class, WSClient.class); + }); + + @Inject + Vertx vertx; + + @TestHTTPResource("echo") + URI testUri; + + @Test + void testError() { + WSClient client = WSClient.create(vertx).connect(testUri); + client.send(Buffer.buffer("1")); + client.waitForMessages(1); + assertEquals("Problem encoding: 1", client.getLastMessage().toString()); + } + + @WebSocket(path = "/echo") + public static class Echo { + + @OnBinaryMessage + Integer process(WebSocketConnection connection, Buffer message) { + return Integer.parseInt(message.toString()); + } + + @OnError + String encodingError(BinaryEncodeException e) { + return "Problem encoding: " + e.getEncodedObject().toString(); + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/MultipleAmbiguousErrorHandlersTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/MultipleAmbiguousErrorHandlersTest.java new file mode 100644 index 00000000000000..cbd863f0babcb6 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/MultipleAmbiguousErrorHandlersTest.java @@ -0,0 +1,45 @@ +package io.quarkus.websockets.next.test.errors; + +import static org.junit.jupiter.api.Assertions.fail; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.websockets.next.OnError; +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketServerException; + +public class MultipleAmbiguousErrorHandlersTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(Endpoint.class); + }) + .setExpectedException(WebSocketServerException.class); + + @Test + void testMultipleAmbiguousErrorHandlers() { + fail(); + } + + @WebSocket(path = "/end") + public static class Endpoint { + + @OnOpen + void open() { + } + + @OnError + void onError1(IllegalStateException ise) { + } + + @OnError + void onError2(IllegalStateException ise) { + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/RuntimeErrorCloseConnectionTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/RuntimeErrorCloseConnectionTest.java new file mode 100644 index 00000000000000..b19fc4c9a0597b --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/RuntimeErrorCloseConnectionTest.java @@ -0,0 +1,59 @@ +package io.quarkus.websockets.next.test.errors; + +import java.net.URI; +import java.time.Duration; + +import jakarta.inject.Inject; + +import org.awaitility.Awaitility; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnBinaryMessage; +import io.quarkus.websockets.next.OnError; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketConnection; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.smallrye.mutiny.Uni; +import io.vertx.core.Vertx; +import io.vertx.core.buffer.Buffer; + +public class RuntimeErrorCloseConnectionTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(Echo.class, WSClient.class); + }); + + @Inject + Vertx vertx; + + @TestHTTPResource("echo") + URI testUri; + + @Test + void testError() { + WSClient client = WSClient.create(vertx).connect(testUri); + client.sendAndAwait(Buffer.buffer("1")); + Awaitility.await().atMost(Duration.ofSeconds(5)).until(() -> client.isClosed()); + } + + @WebSocket(path = "/echo") + public static class Echo { + + @OnBinaryMessage + void process(Buffer message) { + throw new IllegalStateException("Something went wrong"); + } + + @OnError + Uni runtimeProblem(RuntimeException e, WebSocketConnection connection) { + return connection.close(); + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/RuntimeErrorTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/RuntimeErrorTest.java new file mode 100644 index 00000000000000..e875ffc14b28ee --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/RuntimeErrorTest.java @@ -0,0 +1,77 @@ +package io.quarkus.websockets.next.test.errors; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.net.URI; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.BinaryDecodeException; +import io.quarkus.websockets.next.BinaryEncodeException; +import io.quarkus.websockets.next.OnBinaryMessage; +import io.quarkus.websockets.next.OnError; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketConnection; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.smallrye.mutiny.Uni; +import io.vertx.core.Vertx; +import io.vertx.core.buffer.Buffer; + +public class RuntimeErrorTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(Echo.class, WSClient.class); + }); + + @Inject + Vertx vertx; + + @TestHTTPResource("echo") + URI testUri; + + @Test + void testError() { + WSClient client = WSClient.create(vertx).connect(testUri); + client.send(Buffer.buffer("1")); + client.waitForMessages(1); + assertEquals("Something went wrong", client.getLastMessage().toString()); + } + + @WebSocket(path = "/echo") + public static class Echo { + + @OnBinaryMessage + void process(WebSocketConnection connection, Buffer message) { + throw new IllegalStateException("Something went wrong"); + } + + @OnError + String encodingError(BinaryEncodeException e) { + return "Problem encoding: " + e.getEncodedObject().toString(); + } + + @OnError + String decodingError(BinaryDecodeException e) { + return "Problem decoding: " + e.getBytes().toString(); + } + + @OnError + Uni runtimeProblem(RuntimeException e, WebSocketConnection connection) { + return connection.sendText(e.getMessage()); + } + + @OnError + String catchAll(Throwable e) { + return "Ooops!"; + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/TextDecodeErrorTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/TextDecodeErrorTest.java new file mode 100644 index 00000000000000..ad415548ab286d --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/TextDecodeErrorTest.java @@ -0,0 +1,71 @@ +package io.quarkus.websockets.next.test.errors; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.net.URI; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnError; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.TextDecodeException; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.vertx.core.Vertx; + +public class TextDecodeErrorTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(Echo.class, WSClient.class); + }); + + @Inject + Vertx vertx; + + @TestHTTPResource("echo") + URI testUri; + + @Test + void testError() { + WSClient client = WSClient.create(vertx).connect(testUri); + client.send("not a json"); + client.waitForMessages(1); + assertEquals("Problem decoding: not a json", client.getLastMessage().toString()); + } + + @WebSocket(path = "/echo") + public static class Echo { + + @OnTextMessage + void process(Pojo pojo) { + } + + @OnError + String decodingError(TextDecodeException e) { + return "Problem decoding: " + e.getText(); + } + + } + + public static class Pojo { + + private String name; + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/TextEncodeErrorTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/TextEncodeErrorTest.java new file mode 100644 index 00000000000000..3b76f8a540bfa1 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/TextEncodeErrorTest.java @@ -0,0 +1,103 @@ +package io.quarkus.websockets.next.test.errors; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.lang.reflect.Type; +import java.net.URI; + +import jakarta.annotation.Priority; +import jakarta.inject.Inject; +import jakarta.inject.Singleton; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnError; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.TextEncodeException; +import io.quarkus.websockets.next.TextMessageCodec; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.vertx.core.Vertx; +import io.vertx.core.json.JsonObject; + +public class TextEncodeErrorTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(Echo.class, WSClient.class); + }); + + @Inject + Vertx vertx; + + @TestHTTPResource("echo") + URI testUri; + + @Test + void testError() { + WSClient client = WSClient.create(vertx).connect(testUri); + client.send(new JsonObject().put("name", "Fixa").encode()); + client.waitForMessages(1); + assertEquals("java.lang.IllegalArgumentException:Fixa", client.getLastMessage().toString()); + } + + @WebSocket(path = "/echo") + public static class Echo { + + @OnTextMessage(outputCodec = BadCodec.class) + Pojo process(Pojo pojo) { + return pojo; + } + + @OnError + String encodingError(TextEncodeException e) { + return e.getCause().toString() + ":" + e.getEncodedObject().toString(); + } + + } + + @Priority(-1) // Let the JsonTextMessageCodec decode the pojo + @Singleton + public static class BadCodec implements TextMessageCodec { + + @Override + public boolean supports(Type type) { + return type.equals(Pojo.class); + } + + @Override + public String encode(Pojo value) { + throw new IllegalArgumentException(); + } + + @Override + public Pojo decode(Type type, String value) { + throw new UnsupportedOperationException(); + } + + } + + public static class Pojo { + + private String name; + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + @Override + public String toString() { + return name; + } + + } + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/BinaryDecodeException.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/BinaryDecodeException.java new file mode 100644 index 00000000000000..897092042bfb19 --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/BinaryDecodeException.java @@ -0,0 +1,30 @@ +package io.quarkus.websockets.next; + +import io.smallrye.common.annotation.Experimental; +import io.vertx.core.buffer.Buffer; + +/** + * + * @see BinaryMessageCodec + */ +@Experimental("This API is experimental and may change in the future") +public class BinaryDecodeException extends WebSocketServerException { + + private static final long serialVersionUID = 6814319993301938091L; + + private final Buffer bytes; + + public BinaryDecodeException(Buffer bytes, String message) { + this(bytes, message, null); + } + + public BinaryDecodeException(Buffer bytes, String message, Throwable cause) { + super(message, cause); + this.bytes = bytes; + } + + public Buffer getBytes() { + return bytes; + } + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/BinaryEncodeException.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/BinaryEncodeException.java new file mode 100644 index 00000000000000..74eb0a425a7a96 --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/BinaryEncodeException.java @@ -0,0 +1,29 @@ +package io.quarkus.websockets.next; + +import io.smallrye.common.annotation.Experimental; + +/** + * + * @see BinaryMessageCodec + */ +@Experimental("This API is experimental and may change in the future") +public class BinaryEncodeException extends WebSocketServerException { + + private static final long serialVersionUID = -8042792962717461873L; + + private final Object encodedObject; + + public BinaryEncodeException(Object encodedObject, String message) { + this(encodedObject, message, null); + } + + public BinaryEncodeException(Object encodedObject, String message, Throwable cause) { + super(message, cause); + this.encodedObject = encodedObject; + } + + public Object getEncodedObject() { + return encodedObject; + } + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnBinaryMessage.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnBinaryMessage.java index de0102513dc513..ee8957a0632f68 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnBinaryMessage.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnBinaryMessage.java @@ -6,18 +6,17 @@ import java.lang.annotation.Retention; import java.lang.annotation.Target; +import io.quarkus.websockets.next.WebSocketConnection.HandshakeRequest; import io.smallrye.common.annotation.Experimental; /** * A {@link WebSocket} endpoint method annotated with this annotation consumes binary messages. *

- * An endpoint may declare at most one method annotated with this annotation. - *

- * A binary message is always represented as a {@link io.vertx.core.buffer.Buffer}. Therefore, the following conversion rules + * The method must accept exactly one message parameter. A binary message is always represented as a + * {@link io.vertx.core.buffer.Buffer}. Therefore, the following conversion rules * apply. The types listed below are handled specifically. For all other types a {@link BinaryMessageCodec} is used to encode * and decode input and output messages. By default, the first input codec that supports the message type is used; codecs with * higher priority go first. However, a specific codec can be selected with {@link #codec()} and {@link #outputCodec()}. - * *

    *
  • {@code java.lang.Buffer} is used as is,
  • *
  • {@code byte[]} is encoded with {@link io.vertx.core.buffer.Buffer#buffer(byte[])} and decoded with @@ -28,7 +27,16 @@ * {@link io.vertx.core.json.JsonObject#JsonObject(io.vertx.core.buffer.Buffer)}.
  • *
  • {@code io.vertx.core.json.JsonArray} is encoded with {@link io.vertx.core.json.JsonArray#toBuffer()} and decoded with * {@link io.vertx.core.json.JsonArray#JsonArray(io.vertx.core.buffer.Buffer)}.
  • + *
*

+ * The method may also accept the following parameters: + *

    + *
  • {@link WebSocketConnection}
  • + *
  • {@link HandshakeRequest}
  • + *
  • {@link String} parameters annotated with {@link PathParam}
  • + *
+ *

+ * An endpoint may declare at most one method annotated with this annotation. */ @Retention(RUNTIME) @Target(METHOD) diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnClose.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnClose.java index 5755baa631e6b7..49707928265d6f 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnClose.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnClose.java @@ -9,7 +9,7 @@ import io.smallrye.common.annotation.Experimental; /** - * A method of an {@link WebSocket} endpoint annotated with this annotation is invoked when the client disconnects from the + * A {@link WebSocket} endpoint method annotated with this annotation is invoked when the client disconnects from the * socket. *

* An endpoint may declare at most one method annotated with this annotation. diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnError.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnError.java new file mode 100644 index 00000000000000..6e226283498740 --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnError.java @@ -0,0 +1,31 @@ +package io.quarkus.websockets.next; + +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import io.quarkus.websockets.next.WebSocketConnection.HandshakeRequest; +import io.smallrye.common.annotation.Experimental; + +/** + * A {@link WebSocket} endpoint method annotated with this annotation is invoked when an error occurs. + *

+ * The method must accept exactly one error parameter, i.e. a parameter that is assignable from {@link java.lang.Throwable}. + * The method may also accept the following parameters: + *

    + *
  • {@link WebSocketConnection}
  • + *
  • {@link HandshakeRequest}
  • + *
  • {@link String} parameters annotated with {@link PathParam}
  • + *
+ *

+ * An endpoint may declare multiple methods annotated with this annotation. However, each method must declare a unique error + * parameter. + */ +@Retention(RUNTIME) +@Target(METHOD) +@Experimental("This API is experimental and may change in the future") +public @interface OnError { + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnOpen.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnOpen.java index 913c6b3fc36df7..6aff4281fbdab6 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnOpen.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnOpen.java @@ -6,12 +6,20 @@ import java.lang.annotation.Retention; import java.lang.annotation.Target; +import io.quarkus.websockets.next.WebSocketConnection.HandshakeRequest; import io.smallrye.common.annotation.Experimental; /** - * A method of an {@link WebSocket} endpoint annotated with this annotation is invoked when the client connects to a web socket + * A {@link WebSocket} endpoint method annotated with this annotation is invoked when the client connects to a web socket * endpoint. *

+ * The method may accept the following parameters: + *

    + *
  • {@link WebSocketConnection}
  • + *
  • {@link HandshakeRequest}
  • + *
  • {@link String} parameters annotated with {@link PathParam}
  • + *
+ *

* An endpoint may declare at most one method annotated with this annotation. */ @Retention(RUNTIME) diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnPongMessage.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnPongMessage.java index e5253e83357a27..e628954dafec42 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnPongMessage.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnPongMessage.java @@ -6,14 +6,21 @@ import java.lang.annotation.Retention; import java.lang.annotation.Target; +import io.quarkus.websockets.next.WebSocketConnection.HandshakeRequest; import io.smallrye.common.annotation.Experimental; /** * A {@link WebSocket} endpoint method annotated with this annotation consumes pong messages. + * + * The method must accept exactly one pong message parameter represented as a {@link io.vertx.core.buffer.Buffer}. The method + * may also accept the following parameters: + *

    + *
  • {@link WebSocketConnection}
  • + *
  • {@link HandshakeRequest}
  • + *
  • {@link String} parameters annotated with {@link PathParam}
  • + *
*

* An endpoint may declare at most one method annotated with this annotation. - *

- * A pong message is always represented as a {@link io.vertx.core.buffer.Buffer}. */ @Retention(RUNTIME) @Target(METHOD) diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnTextMessage.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnTextMessage.java index d7a0d546c5683f..922101c3dda0f3 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnTextMessage.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnTextMessage.java @@ -6,18 +6,17 @@ import java.lang.annotation.Retention; import java.lang.annotation.Target; +import io.quarkus.websockets.next.WebSocketConnection.HandshakeRequest; import io.smallrye.common.annotation.Experimental; /** * A {@link WebSocket} endpoint method annotated with this annotation consumes text messages. *

- * An endpoint may declare at most one method annotated with this annotation. - *

- * A text message is always represented as a {@link String}. Therefore, the following conversion rules apply. The types listed + * The method must accept exactly one message parameter. A text message is always represented as a {@link String}. Therefore, + * the following conversion rules apply. The types listed * below are handled specifically. For all other types a {@link TextMessageCodec} is used to encode and decode input and * output messages. By default, the first input codec that supports the message type is used; codecs with higher priority go * first. However, a specific codec can be selected with {@link #codec()} and {@link #outputCodec()}. - * *

    *
  • {@code java.lang.String} is used as is,
  • *
  • {@code io.vertx.core.json.JsonObject} is encoded with {@link io.vertx.core.json.JsonObject#encode()} and decoded with @@ -27,7 +26,16 @@ *
  • {@code java.lang.Buffer} is encoded with {@link io.vertx.core.buffer.Buffer#toString()} and decoded with * {@link io.vertx.core.buffer.Buffer#buffer(String)},
  • *
  • {@code byte[]} is first converted to {@link io.vertx.core.buffer.Buffer} and then converted as defined above.
  • + *
*

+ * The method may also accept the following parameters: + *

    + *
  • {@link WebSocketConnection}
  • + *
  • {@link HandshakeRequest}
  • + *
  • {@link String} parameters annotated with {@link PathParam}
  • + *
+ *

+ * An endpoint may declare at most one method annotated with this annotation. */ @Retention(RUNTIME) @Target(METHOD) diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/TextDecodeException.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/TextDecodeException.java new file mode 100644 index 00000000000000..62e49f246946f8 --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/TextDecodeException.java @@ -0,0 +1,29 @@ +package io.quarkus.websockets.next; + +import io.smallrye.common.annotation.Experimental; + +/** + * + * @see TextMessageCodec + */ +@Experimental("This API is experimental and may change in the future") +public class TextDecodeException extends WebSocketServerException { + + private static final long serialVersionUID = 6814319993301938091L; + + private final String text; + + public TextDecodeException(String text, String message) { + this(text, message, null); + } + + public TextDecodeException(String text, String message, Throwable cause) { + super(message, cause); + this.text = text; + } + + public String getText() { + return text; + } + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/TextEncodeException.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/TextEncodeException.java new file mode 100644 index 00000000000000..47a74133f1779f --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/TextEncodeException.java @@ -0,0 +1,29 @@ +package io.quarkus.websockets.next; + +import io.smallrye.common.annotation.Experimental; + +/** + * + * @see TextMessageCodec + */ +@Experimental("This API is experimental and may change in the future") +public class TextEncodeException extends WebSocketServerException { + + private static final long serialVersionUID = 837621296462089705L; + + private final Object encodedObject; + + public TextEncodeException(Object encodedObject, String message) { + this(encodedObject, message, null); + } + + public TextEncodeException(Object encodedObject, String message, Throwable cause) { + super(message, cause); + this.encodedObject = encodedObject; + } + + public Object getEncodedObject() { + return encodedObject; + } + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/Codecs.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/Codecs.java index 4bb5d61c9a8fd8..0b3b7f3abaec11 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/Codecs.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/Codecs.java @@ -5,19 +5,20 @@ import jakarta.inject.Singleton; -import org.jboss.logging.Logger; - import io.quarkus.arc.All; +import io.quarkus.websockets.next.BinaryDecodeException; +import io.quarkus.websockets.next.BinaryEncodeException; import io.quarkus.websockets.next.BinaryMessageCodec; import io.quarkus.websockets.next.MessageCodec; +import io.quarkus.websockets.next.TextDecodeException; +import io.quarkus.websockets.next.TextEncodeException; import io.quarkus.websockets.next.TextMessageCodec; +import io.quarkus.websockets.next.WebSocketServerException; import io.vertx.core.buffer.Buffer; @Singleton public class Codecs { - private static final Logger LOG = Logger.getLogger(Codecs.class); - @All List> textCodecs; @@ -29,12 +30,12 @@ public Object textDecode(Type type, String value, Class codecBeanClass) { for (TextMessageCodec codec : textCodecs) { if (codec.getClass().equals(codecBeanClass)) { if (!codec.supports(type)) { - throw forcedCannotHandle(false, codec, type); + throw forcedCannotDecode(value, null, codec, type); } try { return codec.decode(type, value); } catch (Exception e) { - throw unableToDecode(false, codec, e); + throw unableToDecode(value, null, codec, e); } } } @@ -44,13 +45,12 @@ public Object textDecode(Type type, String value, Class codecBeanClass) { try { return codec.decode(type, value); } catch (Exception e) { - throw unableToDecode(false, codec, e); + throw unableToDecode(value, null, codec, e); } } } } - - throw noCodec(false, type); + throw noCodecToDecode(value, null, type); } public String textEncode(T message, Class codecBeanClass) { @@ -59,12 +59,12 @@ public String textEncode(T message, Class codecBeanClass) { for (TextMessageCodec codec : textCodecs) { if (codec.getClass().equals(codecBeanClass)) { if (!codec.supports(type)) { - throw forcedCannotHandle(false, codec, type); + throw forcedCannotEncode(false, codec, message); } try { return codec.encode(cast(message)); } catch (Exception e) { - throw unableToEncode(false, codec, e); + throw unableToEncode(false, codec, message, e); } } } @@ -74,12 +74,12 @@ public String textEncode(T message, Class codecBeanClass) { try { return codec.encode(cast(message)); } catch (Exception e) { - throw unableToEncode(false, codec, e); + throw unableToEncode(false, codec, message, e); } } } } - throw noCodec(false, type); + throw noCodecToEncode(false, message, type); } public Object binaryDecode(Type type, Buffer value, Class codecBeanClass) { @@ -87,12 +87,12 @@ public Object binaryDecode(Type type, Buffer value, Class codecBeanClass) { for (BinaryMessageCodec codec : binaryCodecs) { if (codec.getClass().equals(codecBeanClass)) { if (!codec.supports(type)) { - throw forcedCannotHandle(false, codec, type); + throw forcedCannotDecode(null, value, codec, type); } try { return codec.decode(type, value); } catch (Exception e) { - throw unableToDecode(false, codec, e); + throw unableToDecode(null, value, codec, e); } } } @@ -102,12 +102,12 @@ public Object binaryDecode(Type type, Buffer value, Class codecBeanClass) { try { return codec.decode(type, value); } catch (Exception e) { - LOG.errorf(e, "Unable to decode binary message with %s", codec.getClass().getName()); + throw unableToDecode(null, value, codec, e); } } } } - throw noCodec(true, type); + throw noCodecToDecode(null, value, type); } public Buffer binaryEncode(T message, Class codecBeanClass) { @@ -116,12 +116,12 @@ public Buffer binaryEncode(T message, Class codecBeanClass) { for (BinaryMessageCodec codec : binaryCodecs) { if (codec.getClass().equals(codecBeanClass)) { if (!codec.supports(type)) { - throw forcedCannotHandle(false, codec, type); + throw forcedCannotEncode(true, codec, message); } try { return codec.encode(cast(message)); } catch (Exception e) { - throw unableToEncode(false, codec, e); + throw unableToEncode(true, codec, message, e); } } } @@ -131,35 +131,70 @@ public Buffer binaryEncode(T message, Class codecBeanClass) { try { return codec.encode(cast(message)); } catch (Exception e) { - throw unableToEncode(true, codec, e); + throw unableToEncode(true, codec, message, e); } } } } - throw noCodec(true, type); + throw noCodecToEncode(true, message, type); + } + + WebSocketServerException noCodecToDecode(String text, Buffer bytes, Type type) { + String message = String.format("No %s codec handles the type %s", bytes != null ? "binary" : "text", type); + if (bytes != null) { + return new BinaryDecodeException(bytes, message); + } else { + return new TextDecodeException(text, message); + } } - IllegalStateException noCodec(boolean binary, Type type) { + WebSocketServerException noCodecToEncode(boolean binary, Object encodedObject, Type type) { String message = String.format("No %s codec handles the type %s", binary ? "binary" : "text", type); - throw new IllegalStateException(message); + if (binary) { + return new BinaryEncodeException(encodedObject, message); + } else { + return new TextEncodeException(encodedObject, message); + } } - IllegalStateException unableToEncode(boolean binary, MessageCodec codec, Exception e) { + WebSocketServerException unableToEncode(boolean binary, MessageCodec codec, Object encodedObject, Exception e) { String message = String.format("Unable to encode %s message with %s", binary ? "binary" : "text", codec.getClass().getName()); - throw new IllegalStateException(message, e); + if (binary) { + return new BinaryEncodeException(encodedObject, message, e); + } else { + return new TextEncodeException(encodedObject, message, e); + } } - IllegalStateException unableToDecode(boolean binary, MessageCodec codec, Exception e) { - String message = String.format("Unable to decode %s message with %s", binary ? "binary" : "text", + WebSocketServerException unableToDecode(String text, Buffer bytes, MessageCodec codec, Exception e) { + String message = String.format("Unable to decode %s message with %s", bytes != null ? "binary" : "text", codec.getClass().getName()); - throw new IllegalStateException(message, e); + if (bytes != null) { + return new BinaryDecodeException(bytes, message, e); + } else { + return new TextDecodeException(text, message, e); + } } - IllegalStateException forcedCannotHandle(boolean binary, MessageCodec codec, Type type) { - throw new IllegalStateException( - String.format("Forced %s codec [%s] cannot handle the type %s", binary ? "binary" : "text", - codec.getClass().getName(), type)); + WebSocketServerException forcedCannotEncode(boolean binary, MessageCodec codec, Object encodedObject) { + String message = String.format("Forced %s codec [%s] cannot handle the type %s", binary ? "binary" : "text", + codec.getClass().getName(), encodedObject.getClass()); + if (binary) { + return new BinaryEncodeException(encodedObject, message); + } else { + return new TextEncodeException(encodedObject, message); + } + } + + WebSocketServerException forcedCannotDecode(String text, Buffer bytes, MessageCodec codec, Type type) { + String message = String.format("Forced %s codec [%s] cannot decode the type %s", bytes != null ? "binary" : "text", + codec.getClass().getName(), type); + if (bytes != null) { + return new BinaryDecodeException(bytes, message); + } else { + return new TextDecodeException(text, message); + } } @SuppressWarnings("unchecked") diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java index 8d9620f09c10a2..8fdd551ded9503 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java @@ -166,16 +166,6 @@ public Void call() { return null; } - // TODO This implementation of timeout does not help a lot - // Should we emit on the current context? - // io.smallrye.mutiny.vertx.core.ContextAwareScheduler - // private Uni withTimeout(Uni action) { - // if (config.timeout().isEmpty()) { - // return action; - // } - // return action.ifNoItem().after(config.timeout().get()).fail(); - // } - protected Object beanInstance(String identifier) { return container.instance(container.bean(identifier)).get(); } @@ -200,6 +190,12 @@ protected Uni doOnClose(Object message) { return Uni.createFrom().voidItem(); } + // Keep this method public - there is a problem with invoking the protected methods in the test mode + public Uni doOnError(Throwable t) { + // This method is overriden if there is at least one error handler defined + return Uni.createFrom().failure(t); + } + protected Object decodeText(Type type, String value, Class codecBeanClass) { return codecs.textDecode(type, value, codecBeanClass); }