diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessor.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessor.java index 6e58aabf0436be..64c430a41c071b 100644 --- a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessor.java +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessor.java @@ -463,7 +463,7 @@ private String generateEndpoint(WebSocketEndpointBuildItem endpoint, MethodDescriptor.ofMethod(WebSocketEndpointBase.class, "beanInstance", Object.class, String.class), doOnOpen.getThis(), doOnOpen.load(endpoint.bean.getIdentifier())); // Call the business method - TryBlock tryBlock = onErrorTryBlock(doOnOpen); + TryBlock tryBlock = onErrorTryBlock(doOnOpen, doOnOpen.getThis()); ResultHandle[] args = callback.generateArguments(tryBlock.getThis(), tryBlock, transformedAnnotations, index); ResultHandle ret = tryBlock.invokeVirtualMethod(MethodDescriptor.of(callback.method), beanInstance, args); encodeAndReturnResult(tryBlock.getThis(), tryBlock, callback, globalErrorHandlers, endpoint, ret); @@ -488,7 +488,7 @@ private String generateEndpoint(WebSocketEndpointBuildItem endpoint, MethodDescriptor.ofMethod(WebSocketEndpointBase.class, "beanInstance", Object.class, String.class), doOnClose.getThis(), doOnClose.load(endpoint.bean.getIdentifier())); // Call the business method - TryBlock tryBlock = onErrorTryBlock(doOnClose); + TryBlock tryBlock = onErrorTryBlock(doOnClose, doOnClose.getThis()); ResultHandle[] args = callback.generateArguments(tryBlock.getThis(), tryBlock, transformedAnnotations, index); ResultHandle ret = tryBlock.invokeVirtualMethod(MethodDescriptor.of(callback.method), beanInstance, args); encodeAndReturnResult(tryBlock.getThis(), tryBlock, callback, globalErrorHandlers, endpoint, ret); @@ -632,7 +632,7 @@ private void generateOnMessage(ClassCreator endpointCreator, WebSocketEndpointBu MethodCreator doOnMessage = endpointCreator.getMethodCreator("doOn" + messageType + "Message", Uni.class, methodParameterType); - TryBlock tryBlock = onErrorTryBlock(doOnMessage); + TryBlock tryBlock = onErrorTryBlock(doOnMessage, doOnMessage.getThis()); // Foo foo = beanInstance("foo"); ResultHandle beanInstance = tryBlock.invokeVirtualMethod( MethodDescriptor.ofMethod(WebSocketEndpointBase.class, "beanInstance", Object.class, String.class), @@ -673,13 +673,13 @@ private TryBlock uniFailureTryBlock(BytecodeCreator method) { return tryBlock; } - private TryBlock onErrorTryBlock(BytecodeCreator method) { + private TryBlock onErrorTryBlock(BytecodeCreator method, ResultHandle endpointThis) { TryBlock tryBlock = method.tryBlock(); CatchBlockCreator catchBlock = tryBlock.addCatch(Throwable.class); // return doOnError(t); catchBlock.returnValue(catchBlock.invokeVirtualMethod( MethodDescriptor.ofMethod(WebSocketEndpointBase.class, "doOnError", Uni.class, Throwable.class), - catchBlock.getThis(), catchBlock.getCaughtException())); + endpointThis, catchBlock.getCaughtException())); return tryBlock; } @@ -810,23 +810,28 @@ private ResultHandle encodeMessage(ResultHandle endpointThis, BytecodeCreator me return uniOnFailureDoOnError(endpointThis, method, callback, uniChain, endpoint, globalErrorHandlers); } } else if (callback.isReturnTypeMulti()) { - // return multiBinary(multi, broadcast, m -> { - // Buffer buffer = encodeBuffer(m); - // return sendBinary(buffer,broadcast); - //}); + // try { + // Buffer buffer = encodeBuffer(m); + // return sendBinary(buffer,broadcast); + // } catch(Throwable t) { + // return doOnError(t); + // } FunctionCreator fun = method.createFunction(Function.class); BytecodeCreator funBytecode = fun.getBytecode(); - ResultHandle buffer = encodeBuffer(funBytecode, callback.returnType().asParameterizedType().arguments().get(0), - funBytecode.getMethodParam(0), endpointThis, callback); - funBytecode.returnValue(funBytecode.invokeVirtualMethod( + // This checkcast should not be necessary but we need to use the endpoint in the function bytecode + // otherwise gizmo does not access the endpoint reference correcly + ResultHandle endpointBase = funBytecode.checkCast(endpointThis, WebSocketEndpointBase.class); + TryBlock tryBlock = onErrorTryBlock(fun.getBytecode(), endpointBase); + ResultHandle buffer = encodeBuffer(tryBlock, callback.returnType().asParameterizedType().arguments().get(0), + tryBlock.getMethodParam(0), endpointThis, callback); + tryBlock.returnValue(tryBlock.invokeVirtualMethod( MethodDescriptor.ofMethod(WebSocketEndpointBase.class, "sendBinary", Uni.class, Buffer.class, boolean.class), endpointThis, buffer, - funBytecode.load(callback.broadcast()))); + tryBlock.load(callback.broadcast()))); return method.invokeVirtualMethod(MethodDescriptor.ofMethod(WebSocketEndpointBase.class, - "multiBinary", Uni.class, Multi.class, boolean.class, Function.class), endpointThis, + "multiBinary", Uni.class, Multi.class, Function.class), endpointThis, value, - method.load(callback.broadcast()), fun.getInstance()); } else { // return sendBinary(buffer,broadcast); @@ -865,22 +870,29 @@ private ResultHandle encodeMessage(ResultHandle endpointThis, BytecodeCreator me } } else if (callback.isReturnTypeMulti()) { // return multiText(multi, broadcast, m -> { - // String text = encodeText(m); - // return sendText(buffer,broadcast); + // try { + // String text = encodeText(m); + // return sendText(buffer,broadcast); + // } catch(Throwable t) { + // return doOnError(t); + // } //}); FunctionCreator fun = method.createFunction(Function.class); BytecodeCreator funBytecode = fun.getBytecode(); - ResultHandle text = encodeText(funBytecode, callback.returnType().asParameterizedType().arguments().get(0), - funBytecode.getMethodParam(0), endpointThis, callback); - funBytecode.returnValue(funBytecode.invokeVirtualMethod( + // This checkcast should not be necessary but we need to use the endpoint in the function bytecode + // otherwise gizmo does not access the endpoint reference correcly + ResultHandle endpointBase = funBytecode.checkCast(endpointThis, WebSocketEndpointBase.class); + TryBlock tryBlock = onErrorTryBlock(fun.getBytecode(), endpointBase); + ResultHandle text = encodeText(tryBlock, callback.returnType().asParameterizedType().arguments().get(0), + tryBlock.getMethodParam(0), endpointThis, callback); + tryBlock.returnValue(tryBlock.invokeVirtualMethod( MethodDescriptor.ofMethod(WebSocketEndpointBase.class, "sendText", Uni.class, String.class, boolean.class), endpointThis, text, - funBytecode.load(callback.broadcast()))); + tryBlock.load(callback.broadcast()))); return method.invokeVirtualMethod(MethodDescriptor.ofMethod(WebSocketEndpointBase.class, - "multiText", Uni.class, Multi.class, boolean.class, Function.class), endpointThis, + "multiText", Uni.class, Multi.class, Function.class), endpointThis, value, - method.load(callback.broadcast()), fun.getInstance()); } else { // return sendText(text,broadcast); diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/MultiBinaryDecodeErrorTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/MultiBinaryDecodeErrorTest.java new file mode 100644 index 00000000000000..e14a37c3c990c6 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/MultiBinaryDecodeErrorTest.java @@ -0,0 +1,63 @@ +package io.quarkus.websockets.next.test.errors; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.net.URI; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.BinaryDecodeException; +import io.quarkus.websockets.next.OnBinaryMessage; +import io.quarkus.websockets.next.OnError; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.smallrye.mutiny.Multi; +import io.vertx.core.Vertx; +import io.vertx.core.buffer.Buffer; +import io.vertx.mutiny.core.Context; + +public class MultiBinaryDecodeErrorTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(Echo.class, WSClient.class); + }); + + @Inject + Vertx vertx; + + @TestHTTPResource("echo") + URI testUri; + + @Test + void testError() { + WSClient client = WSClient.create(vertx).connect(testUri); + client.send(Buffer.buffer("1")); + client.waitForMessages(1); + assertEquals("Problem decoding: 1", client.getLastMessage().toString()); + } + + @WebSocket(path = "/echo") + public static class Echo { + + @OnBinaryMessage + Multi process(Multi messages) { + return messages; + } + + @OnError + String decodingError(BinaryDecodeException e) { + assertTrue(Context.isOnWorkerThread()); + return "Problem decoding: " + e.getBytes().toString(); + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/MultiBinaryEncodeErrorTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/MultiBinaryEncodeErrorTest.java new file mode 100644 index 00000000000000..979feeb4dc7371 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/MultiBinaryEncodeErrorTest.java @@ -0,0 +1,63 @@ +package io.quarkus.websockets.next.test.errors; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.net.URI; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.BinaryEncodeException; +import io.quarkus.websockets.next.OnBinaryMessage; +import io.quarkus.websockets.next.OnError; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.smallrye.mutiny.Multi; +import io.vertx.core.Vertx; +import io.vertx.core.buffer.Buffer; +import io.vertx.mutiny.core.Context; + +public class MultiBinaryEncodeErrorTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(Echo.class, WSClient.class); + }); + + @Inject + Vertx vertx; + + @TestHTTPResource("echo") + URI testUri; + + @Test + void testError() { + WSClient client = WSClient.create(vertx).connect(testUri); + client.send(Buffer.buffer("1")); + client.waitForMessages(1); + assertEquals("Problem encoding: 1", client.getLastMessage().toString()); + } + + @WebSocket(path = "/echo") + public static class Echo { + + @OnBinaryMessage + Multi process(Buffer message) { + return Multi.createFrom().item(Integer.parseInt(message.toString())); + } + + @OnError + String encodingError(BinaryEncodeException e) { + assertTrue(Context.isOnWorkerThread()); + return "Problem encoding: " + e.getEncodedObject().toString(); + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/MultiFailureCloseConnectionTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/MultiFailureCloseConnectionTest.java new file mode 100644 index 00000000000000..660e9192f1e02e --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/MultiFailureCloseConnectionTest.java @@ -0,0 +1,70 @@ +package io.quarkus.websockets.next.test.errors; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.net.URI; +import java.time.Duration; + +import jakarta.inject.Inject; + +import org.awaitility.Awaitility; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnError; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketConnection; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.smallrye.mutiny.Multi; +import io.smallrye.mutiny.Uni; +import io.vertx.core.Vertx; + +public class MultiFailureCloseConnectionTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(Echo.class, WSClient.class); + }); + + @Inject + Vertx vertx; + + @TestHTTPResource("echo") + URI testUri; + + @Test + void testError() { + WSClient client = WSClient.create(vertx).connect(testUri); + client.sendAndAwait("bar,foo,baz"); + // "bar" should be sent back + client.waitForMessages(1); + // "foo" results in a failure -> connection closed + Awaitility.await().atMost(Duration.ofSeconds(5)).until(() -> client.isClosed()); + // "foo" and "baz" should never be sent back + assertEquals(1, client.getMessages().size()); + } + + @WebSocket(path = "/echo") + public static class Echo { + + @OnTextMessage + Multi process(String message) { + return Multi.createFrom().items(message.split(",")).invoke(s -> { + if (s.equals("foo")) { + throw new IllegalArgumentException(); + } + }); + } + + @OnError + Uni runtimeProblem(IllegalArgumentException e, WebSocketConnection connection) { + return connection.close(); + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/MultiFailureTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/MultiFailureTest.java new file mode 100644 index 00000000000000..6abf64c6194dc8 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/MultiFailureTest.java @@ -0,0 +1,63 @@ +package io.quarkus.websockets.next.test.errors; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.net.URI; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnError; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.smallrye.mutiny.Multi; +import io.vertx.core.Vertx; + +public class MultiFailureTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(Echo.class, WSClient.class); + }); + + @Inject + Vertx vertx; + + @TestHTTPResource("echo") + URI testUri; + + @Test + void testError() { + WSClient client = WSClient.create(vertx).connect(testUri); + client.sendAndAwait("bar,foo,baz"); + client.waitForMessages(2); + assertEquals("bar", client.getMessages().get(0).toString()); + assertEquals("foo detected", client.getMessages().get(1).toString()); + } + + @WebSocket(path = "/echo") + public static class Echo { + + @OnTextMessage + Multi process(String message) { + return Multi.createFrom().items(message.split(",")).invoke(s -> { + if (s.equals("foo")) { + throw new IllegalArgumentException(); + } + }); + } + + @OnError + String runtimeProblem(IllegalArgumentException e) { + return "foo detected"; + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/MultiTextDecodeErrorTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/MultiTextDecodeErrorTest.java new file mode 100644 index 00000000000000..049e6ffc06a5af --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/MultiTextDecodeErrorTest.java @@ -0,0 +1,76 @@ +package io.quarkus.websockets.next.test.errors; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.net.URI; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnError; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.TextDecodeException; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.smallrye.mutiny.Multi; +import io.vertx.core.Vertx; +import io.vertx.mutiny.core.Context; + +public class MultiTextDecodeErrorTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(Echo.class, WSClient.class); + }); + + @Inject + Vertx vertx; + + @TestHTTPResource("echo") + URI testUri; + + @Test + void testError() { + WSClient client = WSClient.create(vertx).connect(testUri); + client.send("not a json"); + client.waitForMessages(1); + assertEquals("Problem decoding: not a json", client.getLastMessage().toString()); + } + + @WebSocket(path = "/echo") + public static class Echo { + + @OnTextMessage + Multi process(Multi pojos) { + return pojos; + } + + @OnError + String decodingError(TextDecodeException e) { + assertTrue(Context.isOnWorkerThread()); + return "Problem decoding: " + e.getText(); + } + + } + + public static class Pojo { + + private String name; + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/MultiTextEncodeErrorTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/MultiTextEncodeErrorTest.java new file mode 100644 index 00000000000000..7df746edd1662a --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/MultiTextEncodeErrorTest.java @@ -0,0 +1,107 @@ +package io.quarkus.websockets.next.test.errors; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.lang.reflect.Type; +import java.net.URI; + +import jakarta.annotation.Priority; +import jakarta.inject.Inject; +import jakarta.inject.Singleton; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnError; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.TextEncodeException; +import io.quarkus.websockets.next.TextMessageCodec; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.smallrye.mutiny.Multi; +import io.vertx.core.Vertx; +import io.vertx.core.json.JsonObject; +import io.vertx.mutiny.core.Context; + +public class MultiTextEncodeErrorTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(Echo.class, WSClient.class); + }); + + @Inject + Vertx vertx; + + @TestHTTPResource("echo") + URI testUri; + + @Test + void testError() { + WSClient client = WSClient.create(vertx).connect(testUri); + client.send(new JsonObject().put("name", "Fixa").encode()); + client.waitForMessages(1); + assertEquals("java.lang.IllegalArgumentException:Fixa", client.getLastMessage().toString()); + } + + @WebSocket(path = "/echo") + public static class Echo { + + @OnTextMessage(outputCodec = BadCodec.class) + Multi process(Pojo pojo) { + return Multi.createFrom().item(pojo); + } + + @OnError + String encodingError(TextEncodeException e) { + assertTrue(Context.isOnWorkerThread()); + return e.getCause().toString() + ":" + e.getEncodedObject().toString(); + } + + } + + @Priority(-1) // Let the JsonTextMessageCodec decode the pojo + @Singleton + public static class BadCodec implements TextMessageCodec { + + @Override + public boolean supports(Type type) { + return type.equals(Pojo.class); + } + + @Override + public String encode(Pojo value) { + throw new IllegalArgumentException(); + } + + @Override + public Pojo decode(Type type, String value) { + throw new UnsupportedOperationException(); + } + + } + + public static class Pojo { + + private String name; + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + @Override + public String toString() { + return name; + } + + } + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpoint.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpoint.java index 3a0d543ae426a2..5ad60e04a69ddc 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpoint.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpoint.java @@ -3,6 +3,7 @@ import java.lang.reflect.Type; import io.quarkus.websockets.next.WebSocket; +import io.smallrye.mutiny.Uni; import io.vertx.core.Future; import io.vertx.core.buffer.Buffer; @@ -76,6 +77,8 @@ default ExecutionModel onCloseExecutionModel() { return ExecutionModel.NONE; } + Uni doOnError(Throwable t); + enum ExecutionModel { WORKER_THREAD, VIRTUAL_THREAD, diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java index 4a0df4119c9cc1..051362461babe2 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java @@ -262,6 +262,7 @@ protected Uni doOnClose(Object message) { return Uni.createFrom().voidItem(); } + @Override public Uni doOnError(Throwable t) { // This method is overriden if there is at least one error handler defined return Uni.createFrom().failure(t); @@ -293,18 +294,19 @@ public Uni sendText(String message, boolean broadcast) { return broadcast ? connection.broadcast().sendText(message) : connection.sendText(message); } - public Uni multiText(Multi multi, boolean broadcast, Function> itemFun) { - multi.onFailure() - .call(connection::close) + public Uni multiText(Multi multi, Function> action) { + multi.onFailure().recoverWithMulti(t -> doOnError(t).toMulti()) .subscribe().with( m -> { - itemFun.apply(m) + // Encode and send message + action.apply(m) + .onFailure().recoverWithUni(this::doOnError) .subscribe() .with(v -> LOG.debugf("Multi >> text message: %s", connection), t -> LOG.errorf(t, "Unable to send text message from Multi: %s", connection)); }, t -> { - LOG.errorf(t, "Unable to send text message from Multi - connection was closed: %s ", connection); + LOG.errorf(t, "Unable to send text message from Multi: %s ", connection); }); return Uni.createFrom().voidItem(); } @@ -313,18 +315,19 @@ public Uni sendBinary(Buffer message, boolean broadcast) { return broadcast ? connection.broadcast().sendBinary(message) : connection.sendBinary(message); } - public Uni multiBinary(Multi multi, boolean broadcast, Function> itemFun) { - multi.onFailure() - .call(connection::close) + public Uni multiBinary(Multi multi, Function> action) { + multi.onFailure().recoverWithMulti(t -> doOnError(t).toMulti()) .subscribe().with( m -> { - itemFun.apply(m) + // Encode and send message + action.apply(m) + .onFailure().recoverWithUni(this::doOnError) .subscribe() .with(v -> LOG.debugf("Multi >> binary message: %s", connection), t -> LOG.errorf(t, "Unable to send binary message from Multi: %s", connection)); }, t -> { - LOG.errorf(t, "Unable to send text message from Multi - connection was closed: %s ", connection); + LOG.errorf(t, "Unable to send text message from Multi: %s ", connection); }); return Uni.createFrom().voidItem(); } diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java index 83b1a934f79831..c97dfc8107630c 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java @@ -164,9 +164,16 @@ public void handle(Void event) { } else { textMessageHandler(connection, endpoint, ws, onOpenContext, m -> { contextSupport.start(); - textBroadcastProcessor.onNext(endpoint.decodeTextMultiItem(m)); - LOG.debugf("Text message >> Multi: %s", connection); - contextSupport.end(false); + try { + textBroadcastProcessor.onNext(endpoint.decodeTextMultiItem(m)); + LOG.debugf("Text message >> Multi: %s", connection); + } catch (Throwable throwable) { + endpoint.doOnError(throwable).subscribe().with( + v -> LOG.debugf("Text message >> Multi: %s", connection), + t -> LOG.errorf(t, "Unable to send text message to Multi: %s", connection)); + } finally { + contextSupport.end(false); + } }, false); } @@ -185,9 +192,16 @@ public void handle(Void event) { } else { binaryMessageHandler(connection, endpoint, ws, onOpenContext, m -> { contextSupport.start(); - binaryBroadcastProcessor.onNext(endpoint.decodeBinaryMultiItem(m)); - LOG.debugf("Binary message >> Multi: %s", connection); - contextSupport.end(false); + try { + binaryBroadcastProcessor.onNext(endpoint.decodeBinaryMultiItem(m)); + LOG.debugf("Binary message >> Multi: %s", connection); + } catch (Throwable throwable) { + endpoint.doOnError(throwable).subscribe().with( + v -> LOG.debugf("Binary message >> Multi: %s", connection), + t -> LOG.errorf(t, "Unable to send binary message to Multi: %s", connection)); + } finally { + contextSupport.end(false); + } }, false); }