From 1acdb64020e1908cadce99769c9de859588f3908 Mon Sep 17 00:00:00 2001 From: Martin Kouba <mkouba@redhat.com> Date: Mon, 22 Apr 2024 12:04:05 +0200 Subject: [PATCH] WebSockets Next: avoid unnecessary bean lookups - i.e. obtain the contextual reference when endpoint is created --- .../deployment/WebSocketServerProcessor.java | 24 +++++++++++-------- .../next/runtime/WebSocketEndpoint.java | 7 ++++++ .../next/runtime/WebSocketEndpointBase.java | 21 ++++++++++++++++ 3 files changed, 42 insertions(+), 10 deletions(-) diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessor.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessor.java index 18df7a083df1d..1ccd5ae685638 100644 --- a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessor.java +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessor.java @@ -398,7 +398,7 @@ private void validateOnClose(Callback callback) { * } * * public Uni doOnTextMessage(String message) { - * Uni uni = ((Echo) super.beanInstance("MTd91f3oxHtG8gnznR7XcZBCLdE")).echo((String) message); + * Uni uni = ((Echo) super.beanInstance().echo((String) message); * if (uni != null) { * // The lambda is implemented as a generated function: Echo_WebSocketEndpoint$$function$$1 * return uni.chain(m -> sendText(m, false)); @@ -408,7 +408,7 @@ private void validateOnClose(Callback callback) { * } * * public Uni doOnTextMessage(Object message) { - * Object bean = super.beanInstance("egBJQ7_QAFkQlYXSTKE0XlN3wow"); + * Object bean = super.beanInstance(); * try { * String ret = ((EchoEndpoint) bean).echo((String) message); * return ret != null ? super.sendText(ret, false) : Uni.createFrom().voidItem(); @@ -430,6 +430,10 @@ private void validateOnClose(Callback callback) { * public WebSocketEndpoint.ExecutionModel onTextMessageExecutionModel() { * return ExecutionModel.EVENT_LOOP; * } + * + * public String beanIdentifier() { + * return "egBJQ7_QAFkQlYXSTKE0XlN3wow"; + * } * } * </pre> * @@ -470,13 +474,15 @@ private String generateEndpoint(WebSocketEndpointBuildItem endpoint, MethodCreator executionMode = endpointCreator.getMethodCreator("executionMode", WebSocket.ExecutionMode.class); executionMode.returnValue(executionMode.load(endpoint.executionMode)); + MethodCreator beanIdentifier = endpointCreator.getMethodCreator("beanIdentifier", String.class); + beanIdentifier.returnValue(beanIdentifier.load(endpoint.bean.getIdentifier())); + if (endpoint.onOpen != null) { Callback callback = endpoint.onOpen; MethodCreator doOnOpen = endpointCreator.getMethodCreator("doOnOpen", Uni.class, Object.class); - // Foo foo = beanInstance("foo"); + // Foo foo = beanInstance(); ResultHandle beanInstance = doOnOpen.invokeVirtualMethod( - MethodDescriptor.ofMethod(WebSocketEndpointBase.class, "beanInstance", Object.class, String.class), - doOnOpen.getThis(), doOnOpen.load(endpoint.bean.getIdentifier())); + MethodDescriptor.ofMethod(WebSocketEndpointBase.class, "beanInstance", Object.class), doOnOpen.getThis()); // Call the business method TryBlock tryBlock = onErrorTryBlock(doOnOpen, doOnOpen.getThis()); ResultHandle[] args = callback.generateArguments(tryBlock.getThis(), tryBlock, transformedAnnotations, index); @@ -500,8 +506,7 @@ private String generateEndpoint(WebSocketEndpointBuildItem endpoint, MethodCreator doOnClose = endpointCreator.getMethodCreator("doOnClose", Uni.class, Object.class); // Foo foo = beanInstance("foo"); ResultHandle beanInstance = doOnClose.invokeVirtualMethod( - MethodDescriptor.ofMethod(WebSocketEndpointBase.class, "beanInstance", Object.class, String.class), - doOnClose.getThis(), doOnClose.load(endpoint.bean.getIdentifier())); + MethodDescriptor.ofMethod(WebSocketEndpointBase.class, "beanInstance", Object.class), doOnClose.getThis()); // Call the business method TryBlock tryBlock = onErrorTryBlock(doOnClose, doOnClose.getThis()); ResultHandle[] args = callback.generateArguments(tryBlock.getThis(), tryBlock, transformedAnnotations, index); @@ -648,10 +653,9 @@ private void generateOnMessage(ClassCreator endpointCreator, WebSocketEndpointBu methodParameterType); TryBlock tryBlock = onErrorTryBlock(doOnMessage, doOnMessage.getThis()); - // Foo foo = beanInstance("foo"); + // Foo foo = beanInstance(); ResultHandle beanInstance = tryBlock.invokeVirtualMethod( - MethodDescriptor.ofMethod(WebSocketEndpointBase.class, "beanInstance", Object.class, String.class), - tryBlock.getThis(), tryBlock.load(endpoint.bean.getIdentifier())); + MethodDescriptor.ofMethod(WebSocketEndpointBase.class, "beanInstance", Object.class), tryBlock.getThis()); ResultHandle[] args = callback.generateArguments(tryBlock.getThis(), tryBlock, transformedAnnotations, index); // Call the business method ResultHandle ret = tryBlock.invokeVirtualMethod(MethodDescriptor.of(callback.method), beanInstance, diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpoint.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpoint.java index 5ad60e04a69dd..a5dfb70a86076 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpoint.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpoint.java @@ -79,6 +79,13 @@ default ExecutionModel onCloseExecutionModel() { Uni<Void> doOnError(Throwable t); + /** + * + * @return the identifier of the bean with callbacks + * @see io.quarkus.arc.InjectableBean#getIdentifier() + */ + String beanIdentifier(); + enum ExecutionModel { WORKER_THREAD, VIRTUAL_THREAD, diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java index 051362461babe..261de140f1683 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java @@ -5,10 +5,14 @@ import java.util.function.Consumer; import java.util.function.Function; +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.inject.Singleton; + import org.jboss.logging.Logger; import io.quarkus.arc.Arc; import io.quarkus.arc.ArcContainer; +import io.quarkus.arc.InjectableBean; import io.quarkus.arc.InjectableContext.ContextState; import io.quarkus.virtual.threads.VirtualThreadsRecorder; import io.quarkus.websockets.next.WebSocket.ExecutionMode; @@ -43,6 +47,9 @@ public abstract class WebSocketEndpointBase implements WebSocketEndpoint { private final ContextSupport contextSupport; + private final InjectableBean<?> bean; + private final Object beanInstance; + public WebSocketEndpointBase(WebSocketConnection connection, Codecs codecs, WebSocketsRuntimeConfig config, ContextSupport contextSupport) { this.connection = connection; @@ -51,6 +58,16 @@ public WebSocketEndpointBase(WebSocketConnection connection, Codecs codecs, this.config = config; this.container = Arc.container(); this.contextSupport = contextSupport; + InjectableBean<?> bean = container.bean(beanIdentifier()); + if (bean.getScope().equals(ApplicationScoped.class) + || bean.getScope().equals(Singleton.class)) { + // For certain scopes, we can optimize and obtain the contextual reference immediately + this.bean = null; + this.beanInstance = container.instance(bean).get(); + } else { + this.bean = bean; + this.beanInstance = null; + } } @Override @@ -238,6 +255,10 @@ public void handle(Void event) { return UniHelper.toUni(promise.future()); } + public Object beanInstance() { + return beanInstance != null ? beanInstance : container.instance(bean).get(); + } + public Object beanInstance(String identifier) { return container.instance(container.bean(identifier)).get(); }