From 1acdb64020e1908cadce99769c9de859588f3908 Mon Sep 17 00:00:00 2001
From: Martin Kouba <mkouba@redhat.com>
Date: Mon, 22 Apr 2024 12:04:05 +0200
Subject: [PATCH] WebSockets Next: avoid unnecessary bean lookups

- i.e. obtain the contextual reference when endpoint is created
---
 .../deployment/WebSocketServerProcessor.java  | 24 +++++++++++--------
 .../next/runtime/WebSocketEndpoint.java       |  7 ++++++
 .../next/runtime/WebSocketEndpointBase.java   | 21 ++++++++++++++++
 3 files changed, 42 insertions(+), 10 deletions(-)

diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessor.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessor.java
index 18df7a083df1d..1ccd5ae685638 100644
--- a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessor.java
+++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessor.java
@@ -398,7 +398,7 @@ private void validateOnClose(Callback callback) {
      *     }
      *
      *     public Uni doOnTextMessage(String message) {
-     *         Uni uni = ((Echo) super.beanInstance("MTd91f3oxHtG8gnznR7XcZBCLdE")).echo((String) message);
+     *         Uni uni = ((Echo) super.beanInstance().echo((String) message);
      *         if (uni != null) {
      *             // The lambda is implemented as a generated function: Echo_WebSocketEndpoint$$function$$1
      *             return uni.chain(m -> sendText(m, false));
@@ -408,7 +408,7 @@ private void validateOnClose(Callback callback) {
      *     }
      *
      *     public Uni doOnTextMessage(Object message) {
-     *         Object bean = super.beanInstance("egBJQ7_QAFkQlYXSTKE0XlN3wow");
+     *         Object bean = super.beanInstance();
      *         try {
      *             String ret = ((EchoEndpoint) bean).echo((String) message);
      *             return ret != null ? super.sendText(ret, false) : Uni.createFrom().voidItem();
@@ -430,6 +430,10 @@ private void validateOnClose(Callback callback) {
      *     public WebSocketEndpoint.ExecutionModel onTextMessageExecutionModel() {
      *         return ExecutionModel.EVENT_LOOP;
      *     }
+     *
+     *     public String beanIdentifier() {
+     *        return "egBJQ7_QAFkQlYXSTKE0XlN3wow";
+     *     }
      * }
      * </pre>
      *
@@ -470,13 +474,15 @@ private String generateEndpoint(WebSocketEndpointBuildItem endpoint,
         MethodCreator executionMode = endpointCreator.getMethodCreator("executionMode", WebSocket.ExecutionMode.class);
         executionMode.returnValue(executionMode.load(endpoint.executionMode));
 
+        MethodCreator beanIdentifier = endpointCreator.getMethodCreator("beanIdentifier", String.class);
+        beanIdentifier.returnValue(beanIdentifier.load(endpoint.bean.getIdentifier()));
+
         if (endpoint.onOpen != null) {
             Callback callback = endpoint.onOpen;
             MethodCreator doOnOpen = endpointCreator.getMethodCreator("doOnOpen", Uni.class, Object.class);
-            // Foo foo = beanInstance("foo");
+            // Foo foo = beanInstance();
             ResultHandle beanInstance = doOnOpen.invokeVirtualMethod(
-                    MethodDescriptor.ofMethod(WebSocketEndpointBase.class, "beanInstance", Object.class, String.class),
-                    doOnOpen.getThis(), doOnOpen.load(endpoint.bean.getIdentifier()));
+                    MethodDescriptor.ofMethod(WebSocketEndpointBase.class, "beanInstance", Object.class), doOnOpen.getThis());
             // Call the business method
             TryBlock tryBlock = onErrorTryBlock(doOnOpen, doOnOpen.getThis());
             ResultHandle[] args = callback.generateArguments(tryBlock.getThis(), tryBlock, transformedAnnotations, index);
@@ -500,8 +506,7 @@ private String generateEndpoint(WebSocketEndpointBuildItem endpoint,
             MethodCreator doOnClose = endpointCreator.getMethodCreator("doOnClose", Uni.class, Object.class);
             // Foo foo = beanInstance("foo");
             ResultHandle beanInstance = doOnClose.invokeVirtualMethod(
-                    MethodDescriptor.ofMethod(WebSocketEndpointBase.class, "beanInstance", Object.class, String.class),
-                    doOnClose.getThis(), doOnClose.load(endpoint.bean.getIdentifier()));
+                    MethodDescriptor.ofMethod(WebSocketEndpointBase.class, "beanInstance", Object.class), doOnClose.getThis());
             // Call the business method
             TryBlock tryBlock = onErrorTryBlock(doOnClose, doOnClose.getThis());
             ResultHandle[] args = callback.generateArguments(tryBlock.getThis(), tryBlock, transformedAnnotations, index);
@@ -648,10 +653,9 @@ private void generateOnMessage(ClassCreator endpointCreator, WebSocketEndpointBu
                 methodParameterType);
 
         TryBlock tryBlock = onErrorTryBlock(doOnMessage, doOnMessage.getThis());
-        // Foo foo = beanInstance("foo");
+        // Foo foo = beanInstance();
         ResultHandle beanInstance = tryBlock.invokeVirtualMethod(
-                MethodDescriptor.ofMethod(WebSocketEndpointBase.class, "beanInstance", Object.class, String.class),
-                tryBlock.getThis(), tryBlock.load(endpoint.bean.getIdentifier()));
+                MethodDescriptor.ofMethod(WebSocketEndpointBase.class, "beanInstance", Object.class), tryBlock.getThis());
         ResultHandle[] args = callback.generateArguments(tryBlock.getThis(), tryBlock, transformedAnnotations, index);
         // Call the business method
         ResultHandle ret = tryBlock.invokeVirtualMethod(MethodDescriptor.of(callback.method), beanInstance,
diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpoint.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpoint.java
index 5ad60e04a69dd..a5dfb70a86076 100644
--- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpoint.java
+++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpoint.java
@@ -79,6 +79,13 @@ default ExecutionModel onCloseExecutionModel() {
 
     Uni<Void> doOnError(Throwable t);
 
+    /**
+     *
+     * @return the identifier of the bean with callbacks
+     * @see io.quarkus.arc.InjectableBean#getIdentifier()
+     */
+    String beanIdentifier();
+
     enum ExecutionModel {
         WORKER_THREAD,
         VIRTUAL_THREAD,
diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java
index 051362461babe..261de140f1683 100644
--- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java
+++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java
@@ -5,10 +5,14 @@
 import java.util.function.Consumer;
 import java.util.function.Function;
 
+import jakarta.enterprise.context.ApplicationScoped;
+import jakarta.inject.Singleton;
+
 import org.jboss.logging.Logger;
 
 import io.quarkus.arc.Arc;
 import io.quarkus.arc.ArcContainer;
+import io.quarkus.arc.InjectableBean;
 import io.quarkus.arc.InjectableContext.ContextState;
 import io.quarkus.virtual.threads.VirtualThreadsRecorder;
 import io.quarkus.websockets.next.WebSocket.ExecutionMode;
@@ -43,6 +47,9 @@ public abstract class WebSocketEndpointBase implements WebSocketEndpoint {
 
     private final ContextSupport contextSupport;
 
+    private final InjectableBean<?> bean;
+    private final Object beanInstance;
+
     public WebSocketEndpointBase(WebSocketConnection connection, Codecs codecs,
             WebSocketsRuntimeConfig config, ContextSupport contextSupport) {
         this.connection = connection;
@@ -51,6 +58,16 @@ public WebSocketEndpointBase(WebSocketConnection connection, Codecs codecs,
         this.config = config;
         this.container = Arc.container();
         this.contextSupport = contextSupport;
+        InjectableBean<?> bean = container.bean(beanIdentifier());
+        if (bean.getScope().equals(ApplicationScoped.class)
+                || bean.getScope().equals(Singleton.class)) {
+            // For certain scopes, we can optimize and obtain the contextual reference immediately
+            this.bean = null;
+            this.beanInstance = container.instance(bean).get();
+        } else {
+            this.bean = bean;
+            this.beanInstance = null;
+        }
     }
 
     @Override
@@ -238,6 +255,10 @@ public void handle(Void event) {
         return UniHelper.toUni(promise.future());
     }
 
+    public Object beanInstance() {
+        return beanInstance != null ? beanInstance : container.instance(bean).get();
+    }
+
     public Object beanInstance(String identifier) {
         return container.instance(container.bean(identifier)).get();
     }