Skip to content

Commit

Permalink
WebSockets Next: error handlers part 1
Browse files Browse the repository at this point in the history
  • Loading branch information
mkouba committed Mar 26, 2024
1 parent 316f8b7 commit 986667d
Show file tree
Hide file tree
Showing 25 changed files with 996 additions and 114 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import java.util.Set;

import org.jboss.jandex.AnnotationInstance;
import org.jboss.jandex.IndexView;
import org.jboss.jandex.MethodParameterInfo;
import org.jboss.jandex.Type;

import io.quarkus.gizmo.BytecodeCreator;
import io.quarkus.gizmo.ResultHandle;
import io.quarkus.websockets.next.OnClose;
import io.quarkus.websockets.next.OnError;
import io.quarkus.websockets.next.OnOpen;
import io.quarkus.websockets.next.WebSocketConnection;
import io.quarkus.websockets.next.WebSocketServerException;
Expand Down Expand Up @@ -53,6 +55,12 @@ interface ParameterContext {
*/
String endpointPath();

/**
*
* @return the index that can be used to inspect parameter types
*/
IndexView index();

/**
*
* @return the callback marker annotation
Expand Down Expand Up @@ -88,17 +96,17 @@ interface InvocationBytecodeContext extends ParameterContext {
BytecodeCreator bytecode();

/**
* Obtains the message directly in the bytecode.
* Obtains the message or error directly in the bytecode.
*
* @return the message object or {@code null} for {@link OnOpen} and {@link OnClose} callbacks
* @return the message/error object or {@code null} for {@link OnOpen} and {@link OnClose} callbacks
*/
ResultHandle getMessage();
ResultHandle getPayload();

/**
* Attempts to obtain the decoded message directly in the bytecode.
*
* @param parameterType
* @return the decoded message object or {@code null} for {@link OnOpen} and {@link OnClose} callbacks
* @return the decoded message object or {@code null} for {@link OnOpen}, {@link OnClose} and {@link OnError} callbacks
*/
ResultHandle getDecodedMessage(Type parameterType);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package io.quarkus.websockets.next.deployment;

import org.jboss.jandex.ClassInfo;
import org.jboss.jandex.DotName;
import org.jboss.jandex.IndexView;

import io.quarkus.gizmo.ResultHandle;

class ErrorCallbackArgument implements CallbackArgument {

@Override
public boolean matches(ParameterContext context) {
return context.callbackAnnotation().name().equals(WebSocketDotNames.ON_ERROR)
&& isThrowable(context.index(), context.parameter().type().name());
}

@Override
public ResultHandle get(InvocationBytecodeContext context) {
return context.getPayload();
}

boolean isThrowable(IndexView index, DotName clazzName) {
if (clazzName.equals(WebSocketDotNames.THROWABLE)) {
return true;
}
ClassInfo clazz = index.getClassByName(clazzName);
if (clazz == null) {
throw new IllegalArgumentException("The class " + clazzName + " not found in the index");
}
if (clazz.superName().equals(DotName.OBJECT_NAME)
|| clazz.superName().equals(DotName.RECORD_NAME)
|| clazz.superName().equals(DotName.ENUM_NAME)) {
return false;
}
if (clazz.superName().equals(WebSocketDotNames.THROWABLE)) {
return true;
}
return isThrowable(index, clazz.superName());
}

public static boolean isError(CallbackArgument callbackArgument) {
return callbackArgument instanceof ErrorCallbackArgument;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,8 @@ public int priotity() {
return DEFAULT_PRIORITY - 1;
}

public static boolean isMessage(CallbackArgument callbackArgument) {
return callbackArgument instanceof MessageCallbackArgument;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import io.quarkus.websockets.next.OnBinaryMessage;
import io.quarkus.websockets.next.OnClose;
import io.quarkus.websockets.next.OnError;
import io.quarkus.websockets.next.OnOpen;
import io.quarkus.websockets.next.OnPongMessage;
import io.quarkus.websockets.next.OnTextMessage;
Expand All @@ -27,6 +28,7 @@ final class WebSocketDotNames {
static final DotName ON_BINARY_MESSAGE = DotName.createSimple(OnBinaryMessage.class);
static final DotName ON_PONG_MESSAGE = DotName.createSimple(OnPongMessage.class);
static final DotName ON_CLOSE = DotName.createSimple(OnClose.class);
static final DotName ON_ERROR = DotName.createSimple(OnError.class);
static final DotName UNI = DotName.createSimple(Uni.class);
static final DotName MULTI = DotName.createSimple(Multi.class);
static final DotName RUN_ON_VIRTUAL_THREAD = DotName.createSimple(RunOnVirtualThread.class);
Expand All @@ -38,4 +40,5 @@ final class WebSocketDotNames {
static final DotName VOID = DotName.createSimple(Void.class);
static final DotName PATH_PARAM = DotName.createSimple(PathParam.class);
static final DotName HANDSHAKE_REQUEST = DotName.createSimple(WebSocketConnection.HandshakeRequest.class);
static final DotName THROWABLE = DotName.createSimple(Throwable.class);
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.function.Predicate;

import org.jboss.jandex.AnnotationInstance;
import org.jboss.jandex.AnnotationValue;
import org.jboss.jandex.DotName;
import org.jboss.jandex.IndexView;
import org.jboss.jandex.MethodInfo;
import org.jboss.jandex.MethodParameterInfo;
import org.jboss.jandex.Type;
Expand Down Expand Up @@ -41,9 +43,11 @@ public final class WebSocketEndpointBuildItem extends MultiBuildItem {
public final Callback onBinaryMessage;
public final Callback onPongMessage;
public final Callback onClose;
public final List<Callback> onErrors;

public WebSocketEndpointBuildItem(BeanInfo bean, String path, WebSocket.ExecutionMode executionMode, Callback onOpen,
Callback onTextMessage, Callback onBinaryMessage, Callback onPongMessage, Callback onClose) {
WebSocketEndpointBuildItem(BeanInfo bean, String path, WebSocket.ExecutionMode executionMode, Callback onOpen,
Callback onTextMessage, Callback onBinaryMessage, Callback onPongMessage, Callback onClose,
List<Callback> onErrors) {
this.bean = bean;
this.path = path;
this.executionMode = executionMode;
Expand All @@ -52,6 +56,7 @@ public WebSocketEndpointBuildItem(BeanInfo bean, String path, WebSocket.Executio
this.onBinaryMessage = onBinaryMessage;
this.onPongMessage = onPongMessage;
this.onClose = onClose;
this.onErrors = onErrors;
}

public static class Callback {
Expand All @@ -64,7 +69,7 @@ public static class Callback {

public Callback(AnnotationInstance annotation, MethodInfo method, ExecutionModel executionModel,
CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations,
String endpointPath) {
String endpointPath, IndexView index) {
this.method = method;
this.annotation = annotation;
this.executionModel = executionModel;
Expand All @@ -77,7 +82,8 @@ public Callback(AnnotationInstance annotation, MethodInfo method, ExecutionModel
} else {
this.messageType = MessageType.NONE;
}
this.arguments = collectArguments(annotation, method, callbackArguments, transformedAnnotations, endpointPath);
this.arguments = collectArguments(annotation, method, callbackArguments, transformedAnnotations, endpointPath,
index);
}

public boolean isOnOpen() {
Expand All @@ -88,6 +94,10 @@ public boolean isOnClose() {
return annotation.name().equals(WebSocketDotNames.ON_CLOSE);
}

public boolean isOnError() {
return annotation.name().equals(WebSocketDotNames.ON_ERROR);
}

public Type returnType() {
return method.returnType();
}
Expand Down Expand Up @@ -153,21 +163,8 @@ public enum MessageType {
BINARY
}

public List<CallbackArgument> messageArguments() {
if (arguments.isEmpty()) {
return List.of();
}
List<CallbackArgument> ret = new ArrayList<>();
for (CallbackArgument arg : arguments) {
if (arg instanceof MessageCallbackArgument) {
ret.add(arg);
}
}
return ret;
}

public ResultHandle[] generateArguments(BytecodeCreator bytecode,
TransformedAnnotationsBuildItem transformedAnnotations, String endpointPath) {
TransformedAnnotationsBuildItem transformedAnnotations, String endpointPath, IndexView index) {
if (arguments.isEmpty()) {
return new ResultHandle[] {};
}
Expand All @@ -176,23 +173,23 @@ public ResultHandle[] generateArguments(BytecodeCreator bytecode,
for (CallbackArgument argument : arguments) {
resultHandles[idx] = argument.get(
invocationBytecodeContext(annotation, method.parameters().get(idx), transformedAnnotations,
endpointPath, bytecode));
endpointPath, index, bytecode));
idx++;
}
return resultHandles;
}

static List<CallbackArgument> collectArguments(AnnotationInstance annotation, MethodInfo method,
CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations,
String endpointPath) {
String endpointPath, IndexView index) {
List<MethodParameterInfo> parameters = method.parameters();
if (parameters.isEmpty()) {
return List.of();
}
List<CallbackArgument> arguments = new ArrayList<>(parameters.size());
for (MethodParameterInfo parameter : parameters) {
List<CallbackArgument> found = callbackArguments
.findMatching(parameterContext(annotation, parameter, transformedAnnotations, endpointPath));
.findMatching(parameterContext(annotation, parameter, transformedAnnotations, endpointPath, index));
if (found.isEmpty()) {
String msg = String.format("Unable to inject @%s callback parameter '%s' declared on %s: no injector found",
DotNames.simpleName(annotation.name()),
Expand All @@ -210,11 +207,21 @@ static List<CallbackArgument> collectArguments(AnnotationInstance annotation, Me
}
arguments.add(found.get(0));
}
return arguments;
return List.copyOf(arguments);
}

Type argumentType(Predicate<CallbackArgument> filter) {
int idx = 0;
for (int i = 0; i < arguments.size(); i++) {
if (filter.test(arguments.get(idx))) {
return method.parameterType(i);
}
}
return null;
}

static ParameterContext parameterContext(AnnotationInstance callbackAnnotation, MethodParameterInfo parameter,
TransformedAnnotationsBuildItem transformedAnnotations, String endpointPath) {
TransformedAnnotationsBuildItem transformedAnnotations, String endpointPath, IndexView index) {
return new ParameterContext() {

@Override
Expand All @@ -238,11 +245,17 @@ public String endpointPath() {
return endpointPath;
}

@Override
public IndexView index() {
return index;
}

};
}

private InvocationBytecodeContext invocationBytecodeContext(AnnotationInstance callbackAnnotation,
MethodParameterInfo parameter, TransformedAnnotationsBuildItem transformedAnnotations, String endpointPath,
IndexView index,
BytecodeCreator bytecode) {
return new InvocationBytecodeContext() {

Expand All @@ -267,21 +280,28 @@ public String endpointPath() {
return endpointPath;
}

@Override
public IndexView index() {
return index;
}

@Override
public BytecodeCreator bytecode() {
return bytecode;
}

@Override
public ResultHandle getMessage() {
return acceptsMessage() ? bytecode.getMethodParam(0) : null;
public ResultHandle getPayload() {
return acceptsMessage() || callbackAnnotation.name().equals(WebSocketDotNames.ON_ERROR)
? bytecode.getMethodParam(0)
: null;
}

@Override
public ResultHandle getDecodedMessage(Type parameterType) {
return acceptsMessage()
? WebSocketServerProcessor.decodeMessage(bytecode, acceptsBinaryMessage(), parameterType,
getMessage(), Callback.this)
getPayload(), Callback.this)
: null;
}

Expand Down
Loading

0 comments on commit 986667d

Please sign in to comment.