From ba719d537661e03259a27d6e2a5507174b6ca360 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Wed, 17 Apr 2024 17:01:13 -0400 Subject: [PATCH] Adding in queuing logic --- .../AdjustableCapacityBlockingQueue.java | 4 + .../action/cohere/CohereActionCreator.java | 1 + .../action/cohere/CohereEmbeddingsAction.java | 1 + .../AzureOpenAiEmbeddingsRequestManager.java | 6 +- .../CohereEmbeddingsRequestManager.java | 6 +- .../sender/CohereRerankRequestManager.java | 6 +- .../sender/ExecutableInferenceRequest.java | 3 +- .../http/sender/HttpRequestSender.java | 20 +- .../sender/HuggingFaceRequestManager.java | 14 +- .../http/sender/InferenceRequest.java | 4 +- .../external/http/sender/NoopTask.java | 2 +- .../OpenAiCompletionRequestManager.java | 6 +- .../OpenAiEmbeddingsRequestManager.java | 6 +- .../http/sender/RequestExecutorService.java | 470 +++++++++------ .../http/sender/RequestExecutorService5.java | 419 -------------- .../sender/RequestExecutorServiceOld.java | 317 ++++++++++ .../RequestExecutorServiceSettings.java | 22 +- .../external/http/sender/RequestManager.java | 8 +- .../external/http/sender/RequestTask.java | 2 +- .../http/sender/SingleRequestManager.java | 11 +- .../inference/services/SenderService.java | 2 +- .../services/cohere/CohereService.java | 5 + .../AzureOpenAiActionCreatorTests.java | 13 +- .../AzureOpenAiEmbeddingsActionTests.java | 3 +- .../cohere/CohereActionCreatorTests.java | 3 +- .../cohere/CohereEmbeddingsActionTests.java | 4 +- .../HuggingFaceActionCreatorTests.java | 13 +- .../openai/OpenAiActionCreatorTests.java | 23 +- .../OpenAiChatCompletionActionTests.java | 5 +- .../openai/OpenAiEmbeddingsActionTests.java | 2 +- .../http/sender/HttpRequestSenderTests.java | 28 +- .../RequestExecutorServiceOldTests.java | 541 ++++++++++++++++++ .../sender/RequestExecutorServiceTests.java | 90 +-- ...torTests.java => RequestManagerTests.java} | 35 +- .../sender/SingleRequestManagerTests.java | 2 +- .../services/SenderServiceTests.java | 9 +- 
.../azureopenai/AzureOpenAiServiceTests.java | 5 +- .../services/cohere/CohereServiceTests.java | 5 +- .../HuggingFaceBaseServiceTests.java | 5 +- .../services/openai/OpenAiServiceTests.java | 5 +- 40 files changed, 1335 insertions(+), 791 deletions(-) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService5.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceOld.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceOldTests.java rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/{ExecutableRequestCreatorTests.java => RequestManagerTests.java} (57%) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/AdjustableCapacityBlockingQueue.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/AdjustableCapacityBlockingQueue.java index e73151b44a3e4..436af15158fe5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/AdjustableCapacityBlockingQueue.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/AdjustableCapacityBlockingQueue.java @@ -155,6 +155,10 @@ public int size() { return currentQueue.size() + prioritizedReadingQueue.size(); } + public E poll2() { + return null; + } + /** * The number of additional elements that his queue can accept without blocking. 
*/ diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreator.java index 9f54950dba2d3..3b582d950550b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreator.java @@ -25,6 +25,7 @@ public class CohereActionCreator implements CohereActionVisitor { private final ServiceComponents serviceComponents; public CohereActionCreator(Sender sender, ServiceComponents serviceComponents) { + // TODO Batching - accept a class that can handle batching this.sender = Objects.requireNonNull(sender); this.serviceComponents = Objects.requireNonNull(serviceComponents); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java index 63e51d99a8cee..b4815f8f0d1bf 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java @@ -36,6 +36,7 @@ public CohereEmbeddingsAction(Sender sender, CohereEmbeddingsModel model, Thread model.getServiceSettings().getCommonSettings().uri(), "Cohere embeddings" ); + // TODO - Batching pass the batching class on to the CohereEmbeddingsRequestManager requestCreator = CohereEmbeddingsRequestManager.of(model, threadPool); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsRequestManager.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsRequestManager.java index 06152b50822aa..e0fcee30e5af3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsRequestManager.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -55,16 +54,15 @@ public AzureOpenAiEmbeddingsRequestManager(AzureOpenAiEmbeddingsModel model, Tru } @Override - public Runnable create( + public void execute( String query, List input, RequestSender requestSender, Supplier hasRequestCompletedFunction, - HttpClientContext context, ActionListener listener ) { var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens()); AzureOpenAiEmbeddingsRequest request = new AzureOpenAiEmbeddingsRequest(truncator, truncatedInput, model); - return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsRequestManager.java index 0bf1c11285adb..a51910f1d0a67 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsRequestManager.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsRequestManager.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -44,16 +43,15 @@ private CohereEmbeddingsRequestManager(CohereEmbeddingsModel model, ThreadPool t } @Override - public Runnable create( + public void execute( String query, List input, RequestSender requestSender, Supplier hasRequestCompletedFunction, - HttpClientContext context, ActionListener listener ) { CohereEmbeddingsRequest request = new CohereEmbeddingsRequest(input, model); - return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java index 1778663a194e8..1351eec406569 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -44,16 +43,15 @@ private CohereRerankRequestManager(CohereRerankModel model, ThreadPool threadPoo } 
@Override - public Runnable create( + public void execute( String query, List input, RequestSender requestSender, Supplier hasRequestCompletedFunction, - HttpClientContext context, ActionListener listener ) { CohereRerankRequest request = new CohereRerankRequest(query, input, model); - return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ExecutableInferenceRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ExecutableInferenceRequest.java index 53f30773cbfe3..214eba4ee3485 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ExecutableInferenceRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ExecutableInferenceRequest.java @@ -23,7 +23,6 @@ record ExecutableInferenceRequest( RequestSender requestSender, Logger logger, Request request, - HttpClientContext context, ResponseHandler responseHandler, Supplier hasFinished, ActionListener listener @@ -34,7 +33,7 @@ public void run() { var inferenceEntityId = request.createHttpRequest().inferenceEntityId(); try { - requestSender.send(logger, request, context, hasFinished, responseHandler, listener); + requestSender.send(logger, request, HttpClientContext.create(), hasFinished, responseHandler, listener); } catch (Exception e) { var errorMessage = Strings.format("Failed to send request from inference entity id [%s]", inferenceEntityId); logger.warn(errorMessage, e); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java index d337860848160..1301edbb3e019 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java @@ -15,6 +15,8 @@ import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.RequestExecutor; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings; import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -39,30 +41,28 @@ public static class Factory { private final ServiceComponents serviceComponents; private final HttpClientManager httpClientManager; private final ClusterService clusterService; - private final SingleRequestManager requestManager; + private final RequestSender requestSender; public Factory(ServiceComponents serviceComponents, HttpClientManager httpClientManager, ClusterService clusterService) { this.serviceComponents = Objects.requireNonNull(serviceComponents); this.httpClientManager = Objects.requireNonNull(httpClientManager); this.clusterService = Objects.requireNonNull(clusterService); - var requestSender = new RetryingHttpSender( + requestSender = new RetryingHttpSender( this.httpClientManager.getHttpClient(), serviceComponents.throttlerManager(), new RetrySettings(serviceComponents.settings(), clusterService), serviceComponents.threadPool() ); - requestManager = new SingleRequestManager(requestSender); } - public Sender createSender(String serviceName) { + public Sender 
createSender() { return new HttpRequestSender( - serviceName, serviceComponents.threadPool(), httpClientManager, clusterService, serviceComponents.settings(), - requestManager + requestSender ); } } @@ -71,26 +71,24 @@ public Sender createSender(String serviceName) { private final ThreadPool threadPool; private final HttpClientManager manager; - private final RequestExecutorService service; + private final RequestExecutor service; private final AtomicBoolean started = new AtomicBoolean(false); private final CountDownLatch startCompleted = new CountDownLatch(2); private HttpRequestSender( - String serviceName, ThreadPool threadPool, HttpClientManager httpClientManager, ClusterService clusterService, Settings settings, - SingleRequestManager requestManager + RequestSender requestSender ) { this.threadPool = Objects.requireNonNull(threadPool); this.manager = Objects.requireNonNull(httpClientManager); service = new RequestExecutorService( - serviceName, threadPool, startCompleted, new RequestExecutorServiceSettings(settings, clusterService), - requestManager + requestSender ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HuggingFaceRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HuggingFaceRequestManager.java index 7c09e0c67c1c6..6c8fc446d5243 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HuggingFaceRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HuggingFaceRequestManager.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -55,26 +54,17 @@ private HuggingFaceRequestManager(HuggingFaceModel 
model, ResponseHandler respon } @Override - public Runnable create( + public void execute( String query, List input, RequestSender requestSender, Supplier hasRequestCompletedFunction, - HttpClientContext context, ActionListener listener ) { var truncatedInput = truncate(input, model.getTokenLimit()); var request = new HuggingFaceInferenceRequest(truncator, truncatedInput, model); - return new ExecutableInferenceRequest( - requestSender, - logger, - request, - context, - responseHandler, - hasRequestCompletedFunction, - listener - ); + execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener)); } record RateLimitGrouping(int accountHash) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceRequest.java index 3c711bb79717c..6199a75a41a7d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceRequest.java @@ -19,9 +19,9 @@ public interface InferenceRequest { /** - * Returns the creator that handles building an executable request based on the input provided. + * Returns the manager that handles building and executing an inference request. */ - RequestManager getRequestCreator(); + RequestManager getRequestManager(); /** * Returns the query associated with this request. Used for Rerank tasks. 
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/NoopTask.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/NoopTask.java index 0355880b3f714..3465ec18022bd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/NoopTask.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/NoopTask.java @@ -16,7 +16,7 @@ class NoopTask implements RejectableTask { @Override - public RequestManager getRequestCreator() { + public RequestManager getRequestManager() { return null; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java index 9c6c216c61272..7bc09fd76736b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -43,17 +42,16 @@ private OpenAiCompletionRequestManager(OpenAiChatCompletionModel model, ThreadPo } @Override - public Runnable create( + public void execute( @Nullable String query, List input, RequestSender requestSender, Supplier hasRequestCompletedFunction, - HttpClientContext context, ActionListener listener ) { OpenAiChatCompletionRequest request = new OpenAiChatCompletionRequest(input, model); - return new ExecutableInferenceRequest(requestSender, 
logger, request, context, HANDLER, hasRequestCompletedFunction, listener); + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } private static ResponseHandler createCompletionHandler() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsRequestManager.java index 3a0a8fd64a656..41f91d2b89ee5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsRequestManager.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -55,17 +54,16 @@ private OpenAiEmbeddingsRequestManager(OpenAiEmbeddingsModel model, Truncator tr } @Override - public Runnable create( + public void execute( String query, List input, RequestSender requestSender, Supplier hasRequestCompletedFunction, - HttpClientContext context, ActionListener listener ) { var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens()); OpenAiEmbeddingsRequest request = new OpenAiEmbeddingsRequest(truncator, truncatedInput, model); - return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java index d5a13c2e0771d..65969de73f973 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; @@ -17,34 +16,33 @@ import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.Scheduler; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.common.AdjustableCapacityBlockingQueue; +import org.elasticsearch.xpack.inference.common.RateLimiter; import org.elasticsearch.xpack.inference.external.http.RequestExecutor; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; import java.util.ArrayList; import java.util.List; import java.util.Objects; import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Consumer; import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; -/** - * A service for 
queuing and executing {@link RequestTask}. This class is useful because the - * {@link org.apache.http.impl.nio.conn.PoolingNHttpClientConnectionManager} will block when leasing a connection if no - * connections are available. To avoid blocking the inference transport threads, this executor will queue up the - * requests until connections are available. - * - * NOTE: It is the responsibility of the class constructing the - * {@link org.apache.http.client.methods.HttpUriRequest} to set a timeout for how long this executor will wait - * attempting to execute a task (aka waiting for the connection manager to lease a connection). See - * {@link org.apache.http.client.config.RequestConfig.Builder#setConnectionRequestTimeout} for more info. - */ class RequestExecutorService implements RequestExecutor { + private static final AdjustableCapacityBlockingQueue.QueueCreator QUEUE_CREATOR = new AdjustableCapacityBlockingQueue.QueueCreator<>() { @Override @@ -65,86 +63,96 @@ public BlockingQueue create() { } }; + private static final TimeValue DEFAULT_CLEANUP_INTERVAL = TimeValue.timeValueDays(10); + private static final Duration DEFAULT_STALE_DURATION = Duration.ofDays(10); + private static final Logger logger = LogManager.getLogger(RequestExecutorService.class); - private final String serviceName; - private final AdjustableCapacityBlockingQueue queue; - private final AtomicBoolean running = new AtomicBoolean(true); - private final CountDownLatch terminationLatch = new CountDownLatch(1); - private final HttpClientContext httpContext; + + private final ConcurrentMap inferenceEndpoints = new ConcurrentHashMap<>(); private final ThreadPool threadPool; private final CountDownLatch startupLatch; - private final BlockingQueue controlQueue = new LinkedBlockingQueue<>(); - private final SingleRequestManager requestManager; + private final CountDownLatch terminationLatch = new CountDownLatch(1); + private final RequestSender requestSender; + private final 
RequestExecutorServiceSettings settings; + private final TimeValue cleanUpInterval; + private final Duration staleEndpointDuration; + private final Clock clock; + private final AtomicBoolean shutdown = new AtomicBoolean(false); + private final AdjustableCapacityBlockingQueue.QueueCreator queueCreator; RequestExecutorService( - String serviceName, ThreadPool threadPool, @Nullable CountDownLatch startupLatch, RequestExecutorServiceSettings settings, - SingleRequestManager requestManager + RequestSender requestSender ) { - this(serviceName, threadPool, QUEUE_CREATOR, startupLatch, settings, requestManager); + this( + threadPool, + QUEUE_CREATOR, + startupLatch, + settings, + requestSender, + DEFAULT_CLEANUP_INTERVAL, + DEFAULT_STALE_DURATION, + Clock.systemUTC() + ); } - /** - * This constructor should only be used directly for testing. - */ RequestExecutorService( - String serviceName, ThreadPool threadPool, - AdjustableCapacityBlockingQueue.QueueCreator createQueue, + AdjustableCapacityBlockingQueue.QueueCreator queueCreator, @Nullable CountDownLatch startupLatch, RequestExecutorServiceSettings settings, - SingleRequestManager requestManager + RequestSender requestSender, + TimeValue cleanUpInterval, + Duration staleEndpointDuration, + Clock clock ) { - this.serviceName = Objects.requireNonNull(serviceName); this.threadPool = Objects.requireNonNull(threadPool); - this.httpContext = HttpClientContext.create(); - this.queue = new AdjustableCapacityBlockingQueue<>(createQueue, settings.getQueueCapacity()); + this.queueCreator = Objects.requireNonNull(queueCreator); this.startupLatch = startupLatch; - this.requestManager = Objects.requireNonNull(requestManager); + this.requestSender = Objects.requireNonNull(requestSender); + this.settings = Objects.requireNonNull(settings); + this.cleanUpInterval = Objects.requireNonNull(cleanUpInterval); + this.staleEndpointDuration = Objects.requireNonNull(staleEndpointDuration); + this.clock = Objects.requireNonNull(clock); + } - 
Objects.requireNonNull(settings); - settings.registerQueueCapacityCallback(this::onCapacityChange); + public void shutdown() { + if (shutdown.compareAndSet(false, true)) { + for (var endpoint : inferenceEndpoints.values()) { + endpoint.shutdown(); + } + } } - private void onCapacityChange(int capacity) { - logger.debug(() -> Strings.format("Setting queue capacity to [%s]", capacity)); + public boolean isShutdown() { + return shutdown.get(); + } - var enqueuedCapacityCommand = controlQueue.offer(() -> updateCapacity(capacity)); - if (enqueuedCapacityCommand == false) { - logger.warn("Failed to change request batching service queue capacity. Control queue was full, please try again later."); - } else { - // ensure that the task execution loop wakes up - queue.offer(new NoopTask()); - } + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + return terminationLatch.await(timeout, unit); } - private void updateCapacity(int newCapacity) { - try { - queue.setCapacity(newCapacity); - } catch (Exception e) { - logger.warn( - format("Failed to set the capacity of the task queue to [%s] for request batching service [%s]", newCapacity, serviceName), - e - ); - } + public boolean isTerminated() { + return terminationLatch.getCount() == 0; + } + + public int queueSize() { + return inferenceEndpoints.values().stream().mapToInt(RateLimitingEndpointHandler::queueSize).sum(); } - /** - * Begin servicing tasks. - */ public void start() { try { signalStartInitiated(); - while (running.get()) { + while (isShutdown() == false) { handleTasks(); } } catch (InterruptedException e) { Thread.currentThread().interrupt(); } finally { - running.set(false); + shutdown(); notifyRequestsOfShutdown(); terminationLatch.countDown(); } @@ -156,108 +164,33 @@ private void signalStartInitiated() { } } - /** - * Protects the task retrieval logic from an unexpected exception. 
- * - * @throws InterruptedException rethrows the exception if it occurred retrieving a task because the thread is likely attempting to - * shut down - */ private void handleTasks() throws InterruptedException { - try { - RejectableTask task = queue.take(); - - var command = controlQueue.poll(); - if (command != null) { - command.run(); - } - - // TODO add logic to complete pending items in the queue before shutting down - if (running.get() == false) { - logger.debug(() -> format("Http executor service [%s] exiting", serviceName)); - rejectTaskBecauseOfShutdown(task); - } else { - executeTask(task); - } - } catch (InterruptedException e) { - throw e; - } catch (Exception e) { - logger.warn(format("Http executor service [%s] failed while retrieving task for execution", serviceName), e); + boolean handledAtLeastOneTask = false; + for (var endpoint : inferenceEndpoints.values()) { + handledAtLeastOneTask |= endpoint.executeEnqueuedTask(); } - } - private void executeTask(RejectableTask task) { - try { - requestManager.execute(task, httpContext); - } catch (Exception e) { - logger.warn(format("Http executor service [%s] failed to execute request [%s]", serviceName, task), e); + if (handledAtLeastOneTask == false) { + sleep(settings.getTaskPollFrequency()); } } - private synchronized void notifyRequestsOfShutdown() { - assert isShutdown() : "Requests should only be notified if the executor is shutting down"; - - try { - List notExecuted = new ArrayList<>(); - queue.drainTo(notExecuted); - - rejectTasks(notExecuted, this::rejectTaskBecauseOfShutdown); - } catch (Exception e) { - logger.warn(format("Failed to notify tasks of queuing service [%s] shutdown", serviceName)); - } - } - - private void rejectTaskBecauseOfShutdown(RejectableTask task) { - try { - task.onRejection( - new EsRejectedExecutionException( - format("Failed to send request, queue service [%s] has shutdown prior to executing request", serviceName), - true - ) - ); - } catch (Exception e) { - logger.warn( 
- format("Failed to notify request [%s] for service [%s] of rejection after queuing service shutdown", task, serviceName) - ); - } + private void sleep(TimeValue sleepTime) throws InterruptedException { + sleepTime.timeUnit().sleep(sleepTime.duration()); } - private void rejectTasks(List tasks, Consumer rejectionFunction) { - for (var task : tasks) { - rejectionFunction.accept(task); - } - } - - public int queueSize() { - return queue.size(); - } + private void notifyRequestsOfShutdown() { + assert isShutdown() : "Requests should only be notified if the executor is shutting down"; - @Override - public void shutdown() { - if (running.compareAndSet(true, false)) { - // if this fails because the queue is full, that's ok, we just want to ensure that queue.take() returns - queue.offer(new NoopTask()); + for (var endpoint : inferenceEndpoints.values()) { + endpoint.notifyRequestsOfShutdown(); } } - @Override - public boolean isShutdown() { - return running.get() == false; - } - - @Override - public boolean isTerminated() { - return terminationLatch.getCount() == 0; - } - - @Override - public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { - return terminationLatch.await(timeout, unit); - } - /** * Execute the request at some point in the future. * - * @param requestCreator the http request to send + * @param requestManager the http request to send * @param inferenceInputs the inputs to send in the request * @param timeout the maximum time to wait for this request to complete (failing or succeeding). Once the time elapses, the * listener::onFailure is called with a {@link org.elasticsearch.ElasticsearchTimeoutException}. 
@@ -265,13 +198,13 @@ public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedE * @param listener an {@link ActionListener} for the response or failure */ public void execute( - RequestManager requestCreator, + RequestManager requestManager, InferenceInputs inferenceInputs, @Nullable TimeValue timeout, ActionListener listener ) { var task = new RequestTask( - requestCreator, + requestManager, inferenceInputs, timeout, threadPool, @@ -280,38 +213,229 @@ public void execute( ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext()) ); - completeExecution(task); + var endpoint = inferenceEndpoints.computeIfAbsent( + requestManager.rateLimitGrouping(), + key -> new RateLimitingEndpointHandler( + Integer.toString(requestManager.rateLimitGrouping().hashCode()), + queueCreator, + settings, + requestSender, + clock, + requestManager.rateLimitSettings() + ) + ); + + endpoint.enqueue(task); } - private void completeExecution(RequestTask task) { - if (isShutdown()) { - EsRejectedExecutionException rejected = new EsRejectedExecutionException( - format("Failed to enqueue task because the http executor service [%s] has already shutdown", serviceName), - true - ); + // TODO schedule a cleanup thread to run on an interval and remove entries from the map that are over a day old + private Scheduler.Cancellable startCleanUpThread(ThreadPool threadPool) { + logger.debug(() -> Strings.format("Clean up task scheduled with interval [%s]", cleanUpInterval)); - task.onRejection(rejected); - return; - } + return threadPool.scheduleWithFixedDelay(this::removeOldEndpoints, cleanUpInterval, threadPool.executor(UTILITY_THREAD_POOL_NAME)); + } - boolean added = queue.offer(task); - if (added == false) { - EsRejectedExecutionException rejected = new EsRejectedExecutionException( - format("Failed to execute task because the http executor service [%s] queue is full", serviceName), - false - ); - - task.onRejection(rejected); - } else if 
(isShutdown()) { - // It is possible that a shutdown and notification request occurred after we initially checked for shutdown above - // If the task was added after the queue was already drained it could sit there indefinitely. So let's check again if - // we shut down and if so we'll redo the notification - notifyRequestsOfShutdown(); - } + private void removeOldEndpoints() { + var now = Instant.now(clock); + // if the current time is after the last time the endpoint received a request + allowed stale period then we'll remove it + inferenceEndpoints.entrySet() + .removeIf(endpoint -> now.isAfter(endpoint.getValue().timeOfLastEnqueue().plus(staleEndpointDuration))); } // default for testing - int remainingQueueCapacity() { - return queue.remainingCapacity(); + Integer remainingQueueCapacity(RequestManager requestManager) { + var endpoint = inferenceEndpoints.get(requestManager.rateLimitGrouping()); + + if (endpoint == null) { + return null; + } + + return endpoint.remainingCapacity(); + } + + /** + * Provides a mechanism for ensuring that only a single thread is processing tasks from the queue at a time. + * As tasks are enqueued for execution, if a thread executing (or scheduled to execute a task in the future), + * a new one will not be started. 
+ */ + private static class RateLimitingEndpointHandler { + + private static final Logger logger = LogManager.getLogger(RateLimitingEndpointHandler.class); + + private final AdjustableCapacityBlockingQueue queue; + private final AtomicBoolean shutdown = new AtomicBoolean(); + private final RequestSender requestSender; + private final String id; + private Instant timeOfLastEnqueue; + private final Clock clock; + private final RateLimiter rateLimiter; + + RateLimitingEndpointHandler( + String id, + AdjustableCapacityBlockingQueue.QueueCreator createQueue, + RequestExecutorServiceSettings settings, + RequestSender requestSender, + Clock clock, + RateLimitSettings rateLimitSettings + ) { + this.id = Objects.requireNonNull(id); + this.queue = new AdjustableCapacityBlockingQueue<>(createQueue, settings.getQueueCapacity()); + this.requestSender = Objects.requireNonNull(requestSender); + this.clock = Objects.requireNonNull(clock); + + Objects.requireNonNull(rateLimitSettings); + // TODO figure out a good accumulatedTokensLimit + rateLimiter = new RateLimiter(1, rateLimitSettings.requestsPerTimeUnit(), rateLimitSettings.timeUnit()); + + settings.registerQueueCapacityCallback(this::onCapacityChange); + } + + private void onCapacityChange(int capacity) { + logger.debug(() -> Strings.format("Executor service [%s] setting queue capacity to [%s]", id, capacity)); + + try { + queue.setCapacity(capacity); + } catch (Exception e) { + logger.warn(format("Executor service [%s] failed to set the capacity of the task queue to [%s]", id, capacity), e); + } + } + + public int queueSize() { + return queue.size(); + } + + public void shutdown() { + shutdown.set(true); + } + + public boolean isShutdown() { + return shutdown.get(); + } + + public Instant timeOfLastEnqueue() { + return timeOfLastEnqueue; + } + + public synchronized boolean executeEnqueuedTask() { + var timeBeforeAvailableToken = rateLimiter.timeToReserve2(1); + var task = queue.poll2(); + + // TODO Batching - in a situation 
where no new tasks are queued we'll want to execute any prepared tasks + // check for null and call a helper method executePreparedTasks() + + if (shouldExecuteImmediately(timeBeforeAvailableToken) == false || task == null) { + return false; + } + + executeTask(task); + return true; + } + + private static boolean shouldExecuteImmediately(TimeValue delay) { + return delay.duration() == 0; + } + + public void enqueue(RequestTask task) { + timeOfLastEnqueue = Instant.now(clock); + + if (isShutdown()) { + EsRejectedExecutionException rejected = new EsRejectedExecutionException( + format( + "Failed to enqueue task because the executor service [%s] has already shutdown", + task.getRequestManager().inferenceEntityId() + ), + true + ); + + task.onRejection(rejected); + return; + } + + var addedToQueue = queue.offer(task); + + if (addedToQueue == false) { + EsRejectedExecutionException rejected = new EsRejectedExecutionException( + format( + "Failed to execute task because the executor service [%s] queue is full", + task.getRequestManager().inferenceEntityId() + ), + false + ); + + task.onRejection(rejected); + } else if (isShutdown()) { + notifyRequestsOfShutdown(); + } + } + + private void executeTask(RejectableTask task) { + try { + if (isNoopRequest(task) || task.hasCompleted()) { + return; + } + + task.getRequestManager() + .execute(task.getQuery(), task.getInput(), requestSender, task.getRequestCompletedFunction(), task.getListener()); + } catch (Exception e) { + logger.warn( + format( + "Executor service [%s] failed to execute request for inference endpoint id [%s]", + id, + task.getRequestManager().inferenceEntityId() + ), + e + ); + } + } + + private static boolean isNoopRequest(InferenceRequest inferenceRequest) { + return inferenceRequest.getRequestManager() == null + || inferenceRequest.getInput() == null + || inferenceRequest.getListener() == null; + } + + public synchronized void notifyRequestsOfShutdown() { + assert isShutdown() : "Requests should only be 
notified if the executor is shutting down"; + + try { + List notExecuted = new ArrayList<>(); + queue.drainTo(notExecuted); + + rejectTasks(notExecuted); + } catch (Exception e) { + logger.warn(format("Failed to notify tasks of queuing service [%s] shutdown", id)); + } + } + + private void rejectTasks(List tasks) { + for (var task : tasks) { + rejectTask(task); + } + } + + private void rejectTask(RejectableTask task) { + try { + task.onRejection( + new EsRejectedExecutionException( + format( + "Failed to send request, queue service for inference entity [%s] has shutdown prior to executing request", + task.getRequestManager().inferenceEntityId() + ), + true + ) + ); + } catch (Exception e) { + logger.warn( + format( + "Failed to notify request for inference endpoint [%s] of rejection after queuing service shutdown", + task.getRequestManager().inferenceEntityId() + ) + ); + } + } + + public int remainingCapacity() { + return queue.remainingCapacity(); + } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService5.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService5.java deleted file mode 100644 index eacfe7b036e54..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService5.java +++ /dev/null @@ -1,419 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.external.http.sender; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.support.ContextPreservingActionListener; -import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; -import org.elasticsearch.core.Nullable; -import org.elasticsearch.core.Strings; -import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.threadpool.Scheduler; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.inference.common.AdjustableCapacityBlockingQueue; -import org.elasticsearch.xpack.inference.common.RateLimiter; -import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; -import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; - -import java.time.Clock; -import java.time.Duration; -import java.time.Instant; -import java.util.ArrayList; -import java.util.List; -import java.util.Objects; -import java.util.concurrent.BlockingQueue; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; - -import static org.elasticsearch.core.Strings.format; -import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; - -class RequestExecutorService5 { - - private static final AdjustableCapacityBlockingQueue.QueueCreator QUEUE_CREATOR = - new AdjustableCapacityBlockingQueue.QueueCreator<>() { - @Override - public BlockingQueue create(int capacity) { - BlockingQueue queue; - if (capacity <= 0) { - queue = create(); - } else { - queue = new LinkedBlockingQueue<>(capacity); - } - - return queue; - } - - @Override - public BlockingQueue 
create() { - return new LinkedBlockingQueue<>(); - } - }; - - private static final TimeValue DEFAULT_CLEANUP_INTERVAL = TimeValue.timeValueDays(10); - private static final Duration DEFAULT_STALE_DURATION = Duration.ofDays(10); - - private static final Logger logger = LogManager.getLogger(RequestExecutorService5.class); - - private final ConcurrentMap inferenceEndpoints = new ConcurrentHashMap<>(); - private final ThreadPool threadPool; - private final CountDownLatch startupLatch; - private final CountDownLatch terminationLatch = new CountDownLatch(1); - private final RequestSender requestSender; - private final RequestExecutorServiceSettings settings; - private final TimeValue cleanUpInterval; - private final Duration staleEndpointDuration; - private final Clock clock; - private final AtomicBoolean shutdown = new AtomicBoolean(false); - private final AdjustableCapacityBlockingQueue.QueueCreator queueCreator; - - RequestExecutorService5( - ThreadPool threadPool, - @Nullable CountDownLatch startupLatch, - RequestExecutorServiceSettings settings, - RequestSender requestSender - ) { - this( - threadPool, - QUEUE_CREATOR, - startupLatch, - settings, - requestSender, - DEFAULT_CLEANUP_INTERVAL, - DEFAULT_STALE_DURATION, - Clock.systemUTC() - ); - } - - RequestExecutorService5( - ThreadPool threadPool, - AdjustableCapacityBlockingQueue.QueueCreator queueCreator, - @Nullable CountDownLatch startupLatch, - RequestExecutorServiceSettings settings, - RequestSender requestSender, - TimeValue cleanUpInterval, - Duration staleEndpointDuration, - Clock clock - ) { - this.threadPool = Objects.requireNonNull(threadPool); - this.queueCreator = Objects.requireNonNull(queueCreator); - this.startupLatch = startupLatch; - this.requestSender = Objects.requireNonNull(requestSender); - this.settings = Objects.requireNonNull(settings); - this.cleanUpInterval = Objects.requireNonNull(cleanUpInterval); - this.staleEndpointDuration = Objects.requireNonNull(staleEndpointDuration); - this.clock 
= Objects.requireNonNull(clock); - } - - public void shutdown() { - if (shutdown.compareAndSet(false, true)) { - for (var endpoint : inferenceEndpoints.values()) { - endpoint.shutdown(); - } - } - } - - public boolean isShutdown() { - return shutdown.get(); - } - - public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { - return terminationLatch.await(timeout, unit); - } - - public boolean isTerminated() { - return terminationLatch.getCount() == 0; - } - - public int queueSize() { - return inferenceEndpoints.values().stream().mapToInt(RateLimitingEndpointHandler::queueSize).sum(); - } - - public void start() { - try { - signalStartInitiated(); - - while (isShutdown() == false) { - handleTasks(); - } - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } finally { - shutdown(); - notifyRequestsOfShutdown(); - terminationLatch.countDown(); - } - } - - private void signalStartInitiated() { - if (startupLatch != null) { - startupLatch.countDown(); - } - } - - private void handleTasks() throws InterruptedException { - boolean handledAtLeastOneTask = false; - for (var endpoint : inferenceEndpoints.values()) { - handledAtLeastOneTask |= endpoint.executeEnqueuedTask(); - } - - if (handledAtLeastOneTask == false) { - // TODO make this configurable - Thread.sleep(50); - } - } - - private void notifyRequestsOfShutdown() { - assert isShutdown() : "Requests should only be notified if the executor is shutting down"; - - for (var endpoint : inferenceEndpoints.values()) { - endpoint.notifyRequestsOfShutdown(); - } - } - - /** - * Execute the request at some point in the future. - * - * @param requestCreator the http request to send - * @param inferenceInputs the inputs to send in the request - * @param timeout the maximum time to wait for this request to complete (failing or succeeding). Once the time elapses, the - * listener::onFailure is called with a {@link org.elasticsearch.ElasticsearchTimeoutException}. 
- * If null, then the request will wait forever - * @param listener an {@link ActionListener} for the response or failure - */ - public void execute( - RequestManager requestCreator, - InferenceInputs inferenceInputs, - @Nullable TimeValue timeout, - ActionListener listener - ) { - var task = new RequestTask( - requestCreator, - inferenceInputs, - timeout, - threadPool, - // TODO when multi-tenancy (as well as batching) is implemented we need to be very careful that we preserve - // the thread contexts correctly to avoid accidentally retrieving the credentials for the wrong user - ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext()) - ); - - var endpoint = inferenceEndpoints.computeIfAbsent( - requestCreator.rateLimitGrouping(), - key -> new RateLimitingEndpointHandler( - Integer.toString(requestCreator.rateLimitGrouping().hashCode()), - queueCreator, - settings, - requestSender, - clock, - requestCreator.rateLimitSettings() - ) - ); - - endpoint.enqueue(task); - } - - // TODO schedule a cleanup thread to run on an interval and remove entries from the map that are over a day old - private Scheduler.Cancellable startCleanUpThread(ThreadPool threadPool) { - logger.debug(() -> Strings.format("Clean up task scheduled with interval [%s]", cleanUpInterval)); - - return threadPool.scheduleWithFixedDelay(this::removeOldEndpoints, cleanUpInterval, threadPool.executor(UTILITY_THREAD_POOL_NAME)); - } - - private void removeOldEndpoints() { - var now = Instant.now(clock); - // if the current time is after the last time the endpoint received a request + allowed stale period then we'll remove it - inferenceEndpoints.entrySet() - .removeIf(endpoint -> now.isAfter(endpoint.getValue().timeOfLastEnqueue().plus(staleEndpointDuration))); - } - - /** - * Provides a mechanism for ensuring that only a single thread is processing tasks from the queue at a time. 
- * As tasks are enqueued for execution, if a thread executing (or scheduled to execute a task in the future), - * a new one will not be started. - */ - private static class RateLimitingEndpointHandler { - - private static final Logger logger = LogManager.getLogger(RateLimitingEndpointHandler.class); - - private final AdjustableCapacityBlockingQueue queue; - private final AtomicBoolean shutdown = new AtomicBoolean(); - private final RequestSender requestSender; - private final String id; - private Instant timeOfLastEnqueue; - private final Clock clock; - private final RateLimiter rateLimiter; - - RateLimitingEndpointHandler( - String id, - AdjustableCapacityBlockingQueue.QueueCreator createQueue, - RequestExecutorServiceSettings settings, - RequestSender requestSender, - Clock clock, - RateLimitSettings rateLimitSettings - ) { - this.id = Objects.requireNonNull(id); - this.queue = new AdjustableCapacityBlockingQueue<>(createQueue, settings.getQueueCapacity()); - this.requestSender = Objects.requireNonNull(requestSender); - this.clock = Objects.requireNonNull(clock); - - Objects.requireNonNull(rateLimitSettings); - // TODO figure out a good limit - rateLimiter = new RateLimiter(1, rateLimitSettings.requestsPerTimeUnit(), rateLimitSettings.timeUnit()); - - settings.registerQueueCapacityCallback(this::onCapacityChange); - } - - private void onCapacityChange(int capacity) { - logger.debug(() -> Strings.format("Executor service [%s] setting queue capacity to [%s]", id, capacity)); - - try { - queue.setCapacity(capacity); - } catch (Exception e) { - logger.warn(format("Executor service [%s] failed to set the capacity of the task queue to [%s]", id, capacity), e); - } - } - - public int queueSize() { - return queue.size(); - } - - public void shutdown() { - shutdown.set(true); - } - - public boolean isShutdown() { - return shutdown.get(); - } - - public Instant timeOfLastEnqueue() { - return timeOfLastEnqueue; - } - - public synchronized boolean executeEnqueuedTask() { - 
var timeBeforeAvailableToken = rateLimiter.timeToReserve2(1); - var task = queue.poll(); - - if (shouldExecuteImmediately(timeBeforeAvailableToken) == false || task == null) { - return false; - } - - executeTask(task); - return true; - } - - private static boolean shouldExecuteImmediately(TimeValue delay) { - return delay.duration() == 0; - } - - public void enqueue(RequestTask task) { - timeOfLastEnqueue = Instant.now(clock); - - if (isShutdown()) { - EsRejectedExecutionException rejected = new EsRejectedExecutionException( - format( - "Failed to enqueue task because the executor service [%s] has already shutdown", - task.getRequestManager().inferenceEntityId() - ), - true - ); - - task.onRejection(rejected); - return; - } - - var addedToQueue = queue.offer(task); - - if (addedToQueue == false) { - EsRejectedExecutionException rejected = new EsRejectedExecutionException( - format( - "Failed to execute task because the executor service [%s] queue is full", - task.getRequestManager().inferenceEntityId() - ), - false - ); - - task.onRejection(rejected); - } else if (isShutdown()) { - notifyRequestsOfShutdown(); - } - } - - private void executeTask(RejectableTask task) { - try { - if (isNoopRequest(task) || task.hasCompleted()) { - return; - } - - task.getRequestManager().execute(task.getInput(), requestSender, task.getRequestCompletedFunction(), task.getListener()); - } catch (Exception e) { - logger.warn( - format( - "Executor service [%s] failed to execute request for inference endpoint id [%s]", - id, - task.getRequestManager().inferenceEntityId() - ), - e - ); - } - } - - private static boolean isNoopRequest(InferenceRequest inferenceRequest) { - return inferenceRequest.getRequestManager() == null - || inferenceRequest.getInput() == null - || inferenceRequest.getListener() == null; - } - - public synchronized void notifyRequestsOfShutdown() { - assert isShutdown() : "Requests should only be notified if the executor is shutting down"; - - try { - List notExecuted 
= new ArrayList<>(); - queue.drainTo(notExecuted); - - rejectTasks(notExecuted); - } catch (Exception e) { - logger.warn(format("Failed to notify tasks of queuing service [%s] shutdown", id)); - } - } - - private void rejectTasks(List tasks) { - for (var task : tasks) { - rejectTask(task); - } - } - - private void rejectTask(RejectableTask task) { - try { - task.onRejection( - new EsRejectedExecutionException( - format( - "Failed to send request, queue service for inference entity [%s] has shutdown prior to executing request", - task.getRequestManager().inferenceEntityId() - ), - true - ) - ); - } catch (Exception e) { - logger.warn( - format( - "Failed to notify request for inference endpoint [%s] of rejection after queuing service shutdown", - task.getRequestManager().inferenceEntityId() - ) - ); - } - } - - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceOld.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceOld.java new file mode 100644 index 0000000000000..9068247a054ff --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceOld.java @@ -0,0 +1,317 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.http.client.protocol.HttpClientContext; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ContextPreservingActionListener; +import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Strings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.common.AdjustableCapacityBlockingQueue; +import org.elasticsearch.xpack.inference.external.http.RequestExecutor; + +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Consumer; + +import static org.elasticsearch.core.Strings.format; + +/** + * A service for queuing and executing {@link RequestTask}. This class is useful because the + * {@link org.apache.http.impl.nio.conn.PoolingNHttpClientConnectionManager} will block when leasing a connection if no + * connections are available. To avoid blocking the inference transport threads, this executor will queue up the + * requests until connections are available. + * + * NOTE: It is the responsibility of the class constructing the + * {@link org.apache.http.client.methods.HttpUriRequest} to set a timeout for how long this executor will wait + * attempting to execute a task (aka waiting for the connection manager to lease a connection). See + * {@link org.apache.http.client.config.RequestConfig.Builder#setConnectionRequestTimeout} for more info. 
+ */ +class RequestExecutorServiceOld implements RequestExecutor { + private static final AdjustableCapacityBlockingQueue.QueueCreator QUEUE_CREATOR = + new AdjustableCapacityBlockingQueue.QueueCreator<>() { + @Override + public BlockingQueue create(int capacity) { + BlockingQueue queue; + if (capacity <= 0) { + queue = create(); + } else { + queue = new LinkedBlockingQueue<>(capacity); + } + + return queue; + } + + @Override + public BlockingQueue create() { + return new LinkedBlockingQueue<>(); + } + }; + + private static final Logger logger = LogManager.getLogger(RequestExecutorServiceOld.class); + private final String serviceName; + private final AdjustableCapacityBlockingQueue queue; + private final AtomicBoolean running = new AtomicBoolean(true); + private final CountDownLatch terminationLatch = new CountDownLatch(1); + private final HttpClientContext httpContext; + private final ThreadPool threadPool; + private final CountDownLatch startupLatch; + private final BlockingQueue controlQueue = new LinkedBlockingQueue<>(); + private final SingleRequestManager requestManager; + + RequestExecutorServiceOld( + String serviceName, + ThreadPool threadPool, + @Nullable CountDownLatch startupLatch, + RequestExecutorServiceSettings settings, + SingleRequestManager requestManager + ) { + this(serviceName, threadPool, QUEUE_CREATOR, startupLatch, settings, requestManager); + } + + /** + * This constructor should only be used directly for testing. 
+ */ + RequestExecutorServiceOld( + String serviceName, + ThreadPool threadPool, + AdjustableCapacityBlockingQueue.QueueCreator createQueue, + @Nullable CountDownLatch startupLatch, + RequestExecutorServiceSettings settings, + SingleRequestManager requestManager + ) { + this.serviceName = Objects.requireNonNull(serviceName); + this.threadPool = Objects.requireNonNull(threadPool); + this.httpContext = HttpClientContext.create(); + this.queue = new AdjustableCapacityBlockingQueue<>(createQueue, settings.getQueueCapacity()); + this.startupLatch = startupLatch; + this.requestManager = Objects.requireNonNull(requestManager); + + Objects.requireNonNull(settings); + settings.registerQueueCapacityCallback(this::onCapacityChange); + } + + private void onCapacityChange(int capacity) { + logger.debug(() -> Strings.format("Setting queue capacity to [%s]", capacity)); + + var enqueuedCapacityCommand = controlQueue.offer(() -> updateCapacity(capacity)); + if (enqueuedCapacityCommand == false) { + logger.warn("Failed to change request batching service queue capacity. Control queue was full, please try again later."); + } else { + // ensure that the task execution loop wakes up + queue.offer(new NoopTask()); + } + } + + private void updateCapacity(int newCapacity) { + try { + queue.setCapacity(newCapacity); + } catch (Exception e) { + logger.warn( + format("Failed to set the capacity of the task queue to [%s] for request batching service [%s]", newCapacity, serviceName), + e + ); + } + } + + /** + * Begin servicing tasks. 
+ */ + public void start() { + try { + signalStartInitiated(); + + while (running.get()) { + handleTasks(); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } finally { + running.set(false); + notifyRequestsOfShutdown(); + terminationLatch.countDown(); + } + } + + private void signalStartInitiated() { + if (startupLatch != null) { + startupLatch.countDown(); + } + } + + /** + * Protects the task retrieval logic from an unexpected exception. + * + * @throws InterruptedException rethrows the exception if it occurred retrieving a task because the thread is likely attempting to + * shut down + */ + private void handleTasks() throws InterruptedException { + try { + RejectableTask task = queue.take(); + + var command = controlQueue.poll(); + if (command != null) { + command.run(); + } + + // TODO add logic to complete pending items in the queue before shutting down + if (running.get() == false) { + logger.debug(() -> format("Http executor service [%s] exiting", serviceName)); + rejectTaskBecauseOfShutdown(task); + } else { + executeTask(task); + } + } catch (InterruptedException e) { + throw e; + } catch (Exception e) { + logger.warn(format("Http executor service [%s] failed while retrieving task for execution", serviceName), e); + } + } + + private void executeTask(RejectableTask task) { + try { + requestManager.execute(task, httpContext); + } catch (Exception e) { + logger.warn(format("Http executor service [%s] failed to execute request [%s]", serviceName, task), e); + } + } + + private synchronized void notifyRequestsOfShutdown() { + assert isShutdown() : "Requests should only be notified if the executor is shutting down"; + + try { + List notExecuted = new ArrayList<>(); + queue.drainTo(notExecuted); + + rejectTasks(notExecuted, this::rejectTaskBecauseOfShutdown); + } catch (Exception e) { + logger.warn(format("Failed to notify tasks of queuing service [%s] shutdown", serviceName)); + } + } + + private void 
rejectTaskBecauseOfShutdown(RejectableTask task) { + try { + task.onRejection( + new EsRejectedExecutionException( + format("Failed to send request, queue service [%s] has shutdown prior to executing request", serviceName), + true + ) + ); + } catch (Exception e) { + logger.warn( + format("Failed to notify request [%s] for service [%s] of rejection after queuing service shutdown", task, serviceName) + ); + } + } + + private void rejectTasks(List tasks, Consumer rejectionFunction) { + for (var task : tasks) { + rejectionFunction.accept(task); + } + } + + public int queueSize() { + return queue.size(); + } + + @Override + public void shutdown() { + if (running.compareAndSet(true, false)) { + // if this fails because the queue is full, that's ok, we just want to ensure that queue.take() returns + queue.offer(new NoopTask()); + } + } + + @Override + public boolean isShutdown() { + return running.get() == false; + } + + @Override + public boolean isTerminated() { + return terminationLatch.getCount() == 0; + } + + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + return terminationLatch.await(timeout, unit); + } + + /** + * Execute the request at some point in the future. + * + * @param requestCreator the http request to send + * @param inferenceInputs the inputs to send in the request + * @param timeout the maximum time to wait for this request to complete (failing or succeeding). Once the time elapses, the + * listener::onFailure is called with a {@link org.elasticsearch.ElasticsearchTimeoutException}. 
+ * If null, then the request will wait forever + * @param listener an {@link ActionListener} for the response or failure + */ + public void execute( + RequestManager requestCreator, + InferenceInputs inferenceInputs, + @Nullable TimeValue timeout, + ActionListener listener + ) { + var task = new RequestTask( + requestCreator, + inferenceInputs, + timeout, + threadPool, + // TODO when multi-tenancy (as well as batching) is implemented we need to be very careful that we preserve + // the thread contexts correctly to avoid accidentally retrieving the credentials for the wrong user + ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext()) + ); + + completeExecution(task); + } + + private void completeExecution(RequestTask task) { + if (isShutdown()) { + EsRejectedExecutionException rejected = new EsRejectedExecutionException( + format("Failed to enqueue task because the http executor service [%s] has already shutdown", serviceName), + true + ); + + task.onRejection(rejected); + return; + } + + boolean added = queue.offer(task); + if (added == false) { + EsRejectedExecutionException rejected = new EsRejectedExecutionException( + format("Failed to execute task because the http executor service [%s] queue is full", serviceName), + false + ); + + task.onRejection(rejected); + } else if (isShutdown()) { + // It is possible that a shutdown and notification request occurred after we initially checked for shutdown above + // If the task was added after the queue was already drained it could sit there indefinitely. 
So let's check again if + // we shut down and if so we'll redo the notification + notifyRequestsOfShutdown(); + } + } + + // default for testing + int remainingQueueCapacity() { + return queue.remainingCapacity(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceSettings.java index 86825035f2d05..febf35488efb4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceSettings.java @@ -10,6 +10,7 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; import java.util.ArrayList; import java.util.List; @@ -29,11 +30,20 @@ public class RequestExecutorServiceSettings { Setting.Property.Dynamic ); + private static final TimeValue DEFAULT_TASK_POLL_FREQUENCY_TIME = TimeValue.timeValueMillis(50); + static final Setting TASK_POLL_FREQUENCY_SETTING = Setting.timeSetting( + "xpack.inference.http.request_executor.task_poll_frequency", + DEFAULT_TASK_POLL_FREQUENCY_TIME, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + public static List> getSettingsDefinitions() { - return List.of(TASK_QUEUE_CAPACITY_SETTING); + return List.of(TASK_QUEUE_CAPACITY_SETTING, TASK_POLL_FREQUENCY_SETTING); } private volatile int queueCapacity; + private volatile TimeValue taskPollFrequency; private final List> queueCapacityCallbacks = new ArrayList>(); public RequestExecutorServiceSettings(Settings settings, ClusterService clusterService) { @@ -44,6 +54,7 @@ public RequestExecutorServiceSettings(Settings settings, 
ClusterService clusterS private void addSettingsUpdateConsumers(ClusterService clusterService) { clusterService.getClusterSettings().addSettingsUpdateConsumer(TASK_QUEUE_CAPACITY_SETTING, this::setQueueCapacity); + clusterService.getClusterSettings().addSettingsUpdateConsumer(TASK_POLL_FREQUENCY_SETTING, this::setTaskPollFrequency); } // default for testing @@ -55,6 +66,10 @@ void setQueueCapacity(int queueCapacity) { } } + private void setTaskPollFrequency(TimeValue taskPollFrequency) { + this.taskPollFrequency = taskPollFrequency; + } + void registerQueueCapacityCallback(Consumer onChangeCapacityCallback) { queueCapacityCallbacks.add(onChangeCapacityCallback); } @@ -62,4 +77,9 @@ void registerQueueCapacityCallback(Consumer onChangeCapacityCallback) { int getQueueCapacity() { return queueCapacity; } + + TimeValue getTaskPollFrequency() { + return taskPollFrequency; + } + } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManager.java index 7d3cca596f1d0..79ef1b56ad231 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManager.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.http.sender; -import org.apache.http.client.protocol.HttpClientContext; import org.elasticsearch.action.ActionListener; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InferenceServiceResults; @@ -21,14 +20,17 @@ * A contract for constructing a {@link Runnable} to handle sending an inference request to a 3rd party service. 
*/ public interface RequestManager extends RateLimitable { - Runnable create( + void execute( @Nullable String query, List input, RequestSender requestSender, Supplier hasRequestCompletedFunction, - HttpClientContext context, ActionListener listener ); + // TODO For batching we'll add 2 new methods: prepare(query, input, ...) which will allow the individual + // managers to implement their own batching + // executePreparedRequest() which will execute all prepared requests aka sends the batch + String inferenceEntityId(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTask.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTask.java index 738592464232c..7a5f482412289 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTask.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTask.java @@ -111,7 +111,7 @@ public void onRejection(Exception e) { } @Override - public RequestManager getRequestCreator() { + public RequestManager getRequestManager() { return requestCreator; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/SingleRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/SingleRequestManager.java index 494c77964080f..e05a3248a13fe 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/SingleRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/SingleRequestManager.java @@ -12,6 +12,7 @@ import java.util.Objects; +// TODO remove /** * Handles executing a single inference request at a time. 
*/ @@ -28,20 +29,18 @@ public void execute(InferenceRequest inferenceRequest, HttpClientContext context return; } - inferenceRequest.getRequestCreator() - .create( + inferenceRequest.getRequestManager() + .execute( inferenceRequest.getQuery(), inferenceRequest.getInput(), requestSender, inferenceRequest.getRequestCompletedFunction(), - context, inferenceRequest.getListener() - ) - .run(); + ); } private static boolean isNoopRequest(InferenceRequest inferenceRequest) { - return inferenceRequest.getRequestCreator() == null + return inferenceRequest.getRequestManager() == null || inferenceRequest.getInput() == null || inferenceRequest.getListener() == null; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index 24c0ab2cd893e..1c64f505402d8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -31,7 +31,7 @@ public abstract class SenderService implements InferenceService { public SenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { Objects.requireNonNull(factory); - sender = factory.createSender(name()); + sender = factory.createSender(); this.serviceComponents = Objects.requireNonNull(serviceComponents); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index deb1cfb901602..96b5d2daee305 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -50,6 
+50,11 @@ public class CohereService extends SenderService { public static final String NAME = "cohere"; + // TODO Batching - We'll instantiate a batching class within the services that want to support it and pass it through to + // the Cohere*RequestManager via the CohereActionCreator class + // The reason it needs to be done here is that the batching logic needs to hold state but the *RequestManagers are instantiated + // on every request + public CohereService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { super(factory, serviceComponents); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java index 4bdba67beec17..d8a7e78b94911 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java @@ -43,6 +43,7 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields; +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSenderWithSingleRequestManager; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests.createModel; @@ -75,7 +76,7 @@ public void shutdown() throws 
IOException { public void testCreate_AzureOpenAiEmbeddingsModel() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSenderWithSingleRequestManager(senderFactory)) { sender.start(); String responseJson = """ @@ -125,7 +126,7 @@ public void testCreate_AzureOpenAiEmbeddingsModel() throws IOException { public void testCreate_AzureOpenAiEmbeddingsModel_WithoutUser() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSenderWithSingleRequestManager(senderFactory)) { sender.start(); String responseJson = """ @@ -181,7 +182,7 @@ public void testCreate_AzureOpenAiEmbeddingsModel_FailsFromInvalidResponseFormat ); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSenderWithSingleRequestManager(senderFactory)) { sender.start(); String responseJson = """ @@ -235,7 +236,7 @@ public void testCreate_AzureOpenAiEmbeddingsModel_FailsFromInvalidResponseFormat public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusCode() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSenderWithSingleRequestManager(senderFactory)) { sender.start(); // note - there is no complete documentation on Azure's error messages @@ -311,7 +312,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusCode() throws IOException { var senderFactory = 
HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSenderWithSingleRequestManager(senderFactory)) { sender.start(); // note - there is no complete documentation on Azure's error messages @@ -387,7 +388,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC public void testExecute_TruncatesInputBeforeSending() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSenderWithSingleRequestManager(senderFactory)) { sender.start(); String responseJson = """ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsActionTests.java index e8eac1a13b180..de8eecd8186f6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsActionTests.java @@ -43,6 +43,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSenderWithSingleRequestManager; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static 
org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests.createModel; @@ -81,7 +82,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { mockClusterServiceEmpty() ); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSenderWithSingleRequestManager(senderFactory)) { sender.start(); String responseJson = """ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java index 73b627742ab03..6500a7bf7f95f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java @@ -41,6 +41,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSenderWithSingleRequestManager; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.hamcrest.Matchers.equalTo; @@ -71,7 +72,7 @@ public void shutdown() throws IOException { public void testCreate_CohereEmbeddingsModel() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSenderWithSingleRequestManager(senderFactory)) { sender.start(); 
String responseJson = """ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java index 06cae11bc8d5d..ba04715f7c39f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java @@ -80,7 +80,7 @@ public void shutdown() throws IOException { public void testExecute_ReturnsSuccessfulResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory, "test_service")) { + try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory)) { sender.start(); String responseJson = """ @@ -161,7 +161,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory, "test_service")) { + try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory)) { sender.start(); String responseJson = """ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java index 3fc4e0ab390ae..bdd635467dc98 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java @@ -42,6 +42,7 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields; +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSenderWithSingleRequestManager; import static org.elasticsearch.xpack.inference.logging.ThrottlerManagerTests.mockThrottlerManager; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.hamcrest.Matchers.contains; @@ -75,7 +76,7 @@ public void shutdown() throws IOException { public void testExecute_ReturnsSuccessfulResponse_ForElserAction() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSenderWithSingleRequestManager(senderFactory)) { sender.start(); String responseJson = """ @@ -131,7 +132,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForElserAction() throws IOEx ); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSenderWithSingleRequestManager(senderFactory)) { sender.start(); String responseJson = """ @@ -187,7 +188,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForElserAction() throws IOEx public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction() throws IOException { var senderFactory = 
HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSenderWithSingleRequestManager(senderFactory)) { sender.start(); String responseJson = """ @@ -239,7 +240,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForEmbeddingsAction() throws ); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSenderWithSingleRequestManager(senderFactory)) { sender.start(); // this will fail because the only valid formats are {"embeddings": [[...]]} or [[...]] @@ -292,7 +293,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForEmbeddingsAction() throws public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSenderWithSingleRequestManager(senderFactory)) { sender.start(); String responseJsonContentTooLarge = """ @@ -357,7 +358,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOExc public void testExecute_TruncatesInputBeforeSending() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSenderWithSingleRequestManager(senderFactory)) { sender.start(); String responseJson = """ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java index 98eff32f72983..7bdc277790db6 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java @@ -39,6 +39,7 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields; +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSenderWithSingleRequestManager; import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; @@ -74,7 +75,7 @@ public void shutdown() throws IOException { public void testCreate_OpenAiEmbeddingsModel() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSenderWithSingleRequestManager(senderFactory)) { sender.start(); String responseJson = """ @@ -127,7 +128,7 @@ public void testCreate_OpenAiEmbeddingsModel() throws IOException { public void testCreate_OpenAiEmbeddingsModel_WithoutUser() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSenderWithSingleRequestManager(senderFactory)) { sender.start(); String responseJson = """ @@ -179,7 +180,7 @@ public void testCreate_OpenAiEmbeddingsModel_WithoutUser() throws IOException { public void 
testCreate_OpenAiEmbeddingsModel_WithoutOrganization() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSenderWithSingleRequestManager(senderFactory)) { sender.start(); String responseJson = """ @@ -238,7 +239,7 @@ public void testCreate_OpenAiEmbeddingsModel_FailsFromInvalidResponseFormat() th ); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSenderWithSingleRequestManager(senderFactory)) { sender.start(); String responseJson = """ @@ -292,7 +293,7 @@ public void testCreate_OpenAiEmbeddingsModel_FailsFromInvalidResponseFormat() th public void testCreate_OpenAiChatCompletionModel() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSenderWithSingleRequestManager(senderFactory)) { sender.start(); String responseJson = """ @@ -355,7 +356,7 @@ public void testCreate_OpenAiChatCompletionModel() throws IOException { public void testCreate_OpenAiChatCompletionModel_WithoutUser() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSenderWithSingleRequestManager(senderFactory)) { sender.start(); String responseJson = """ @@ -417,7 +418,7 @@ public void testCreate_OpenAiChatCompletionModel_WithoutUser() throws IOExceptio public void testCreate_OpenAiChatCompletionModel_WithoutOrganization() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var 
sender = createSenderWithSingleRequestManager(senderFactory)) { sender.start(); String responseJson = """ @@ -486,7 +487,7 @@ public void testCreate_OpenAiChatCompletionModel_FailsFromInvalidResponseFormat( ); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSenderWithSingleRequestManager(senderFactory)) { sender.start(); String responseJson = """ @@ -552,7 +553,7 @@ public void testCreate_OpenAiChatCompletionModel_FailsFromInvalidResponseFormat( public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusCode() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSenderWithSingleRequestManager(senderFactory)) { sender.start(); var contentTooLargeErrorMessage = @@ -635,7 +636,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusCode() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSenderWithSingleRequestManager(senderFactory)) { sender.start(); var contentTooLargeErrorMessage = @@ -718,7 +719,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC public void testExecute_TruncatesInputBeforeSending() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSenderWithSingleRequestManager(senderFactory)) { sender.start(); String responseJson = """ diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java index b802403dcd28d..4ceac8762ed55 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java @@ -45,6 +45,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSenderWithSingleRequestManager; import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createChatCompletionModel; @@ -80,7 +81,7 @@ public void shutdown() throws IOException { public void testExecute_ReturnsSuccessfulResponse() throws IOException { var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty()); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSenderWithSingleRequestManager(senderFactory)) { sender.start(); String responseJson = """ @@ -234,7 +235,7 @@ public void testExecute_ThrowsExceptionWithNullUrl() { public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException { var senderFactory = 
HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = createSenderWithSingleRequestManager(senderFactory)) { sender.start(); String responseJson = """ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java index 45c1fa276c69a..9c73af1610c1b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java @@ -78,7 +78,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { mockClusterServiceEmpty() ); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = senderFactory.createSender()) { sender.start(); String responseJson = """ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java index 395c046413504..18b0d8b94ef41 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java @@ -79,7 +79,7 @@ public void shutdown() throws IOException, InterruptedException { public void testCreateSender_SendsRequestAndReceivesResponse() throws Exception { var senderFactory = createSenderFactory(clientManager, threadRef); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = 
createSenderWithSingleRequestManager(senderFactory)) { sender.start(); String responseJson = """ @@ -135,11 +135,11 @@ public void testHttpRequestSender_Throws_WhenCallingSendBeforeStart() throws Exc mockClusterServiceEmpty() ); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = senderFactory.createSender()) { PlainActionFuture listener = new PlainActionFuture<>(); var thrownException = expectThrows( AssertionError.class, - () -> sender.send(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener) + () -> sender.send(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener) ); assertThat(thrownException.getMessage(), is("call start() before sending a request")); } @@ -155,17 +155,12 @@ public void testHttpRequestSender_Throws_WhenATimeoutOccurs() throws Exception { mockClusterServiceEmpty() ); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = senderFactory.createSender()) { assertThat(sender, instanceOf(HttpRequestSender.class)); sender.start(); PlainActionFuture listener = new PlainActionFuture<>(); - sender.send( - ExecutableRequestCreatorTests.createMock(), - new DocumentsOnlyInput(List.of()), - TimeValue.timeValueNanos(1), - listener - ); + sender.send(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), TimeValue.timeValueNanos(1), listener); var thrownException = expectThrows(ElasticsearchTimeoutException.class, () -> listener.actionGet(TIMEOUT)); @@ -186,16 +181,11 @@ public void testHttpRequestSenderWithTimeout_Throws_WhenATimeoutOccurs() throws mockClusterServiceEmpty() ); - try (var sender = senderFactory.createSender("test_service")) { + try (var sender = senderFactory.createSender()) { sender.start(); PlainActionFuture listener = new PlainActionFuture<>(); - sender.send( - ExecutableRequestCreatorTests.createMock(), - new DocumentsOnlyInput(List.of()), - TimeValue.timeValueNanos(1), - listener - ); + 
sender.send(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), TimeValue.timeValueNanos(1), listener); var thrownException = expectThrows(ElasticsearchTimeoutException.class, () -> listener.actionGet(TIMEOUT)); @@ -248,7 +238,7 @@ public static HttpRequestSender.Factory createSenderFactory( ); } - public static Sender createSenderWithSingleRequestManager(HttpRequestSender.Factory factory, String serviceName) { - return factory.createSender(serviceName); + public static Sender createSenderWithSingleRequestManager(HttpRequestSender.Factory factory) { + return factory.createSender(); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceOldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceOldTests.java new file mode 100644 index 0000000000000..72ae0120e5e00 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceOldTests.java @@ -0,0 +1,541 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchTimeoutException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender; +import org.junit.After; +import org.junit.Before; +import org.mockito.ArgumentCaptor; + +import java.io.IOException; +import java.util.List; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.common.AdjustableCapacityBlockingQueueTests.mockQueueCreator; +import static org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceSettingsTests.createRequestExecutorServiceSettings; +import static org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceSettingsTests.createRequestExecutorServiceSettingsEmpty; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static 
org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class RequestExecutorServiceOldTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private ThreadPool threadPool; + + @Before + public void init() { + threadPool = createThreadPool(inferenceUtilityPool()); + } + + @After + public void shutdown() { + terminate(threadPool); + } + + public void testQueueSize_IsEmpty() { + var service = createRequestExecutorServiceWithMocks(); + + assertThat(service.queueSize(), is(0)); + } + + public void testQueueSize_IsOne() { + var service = createRequestExecutorServiceWithMocks(); + service.execute(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>()); + + assertThat(service.queueSize(), is(1)); + } + + public void testIsTerminated_IsFalse() { + var service = createRequestExecutorServiceWithMocks(); + + assertFalse(service.isTerminated()); + } + + public void testIsTerminated_IsTrue() throws InterruptedException { + var latch = new CountDownLatch(1); + var service = createRequestExecutorService(latch, mock(RetryingHttpSender.class)); + + service.shutdown(); + service.start(); + latch.await(TIMEOUT.getSeconds(), TimeUnit.SECONDS); + + assertTrue(service.isTerminated()); + } + + public void testIsTerminated_AfterStopFromSeparateThread() throws Exception { + var waitToShutdown = new CountDownLatch(1); + var waitToReturnFromSend = new CountDownLatch(1); + + var requestSender = mock(RetryingHttpSender.class); + doAnswer(invocation -> { + waitToShutdown.countDown(); + waitToReturnFromSend.await(TIMEOUT.getSeconds(), TimeUnit.SECONDS); + return Void.TYPE; + }).when(requestSender).send(any(), any(), any(), any(), any(), any()); + + var service = createRequestExecutorService(null, requestSender); + + Future executorTermination = submitShutdownRequest(waitToShutdown, 
waitToReturnFromSend, service); + + PlainActionFuture listener = new PlainActionFuture<>(); + service.execute( + OpenAiEmbeddingsExecutableRequestCreatorTests.makeCreator("url", null, "key", "id", null, threadPool), + new DocumentsOnlyInput(List.of()), + null, + listener + ); + + service.start(); + + try { + executorTermination.get(1, TimeUnit.SECONDS); + } catch (Exception e) { + fail(Strings.format("Executor finished before it was signaled to shutdown: %s", e)); + } + + assertTrue(service.isShutdown()); + assertTrue(service.isTerminated()); + } + + public void testSend_AfterShutdown_Throws() { + var service = createRequestExecutorServiceWithMocks(); + + service.shutdown(); + + var listener = new PlainActionFuture(); + service.execute(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener); + + var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat( + thrownException.getMessage(), + is("Failed to enqueue task because the http executor service [test_service] has already shutdown") + ); + assertTrue(thrownException.isExecutorShutdown()); + } + + public void testSend_Throws_WhenQueueIsFull() { + var service = new RequestExecutorServiceOld( + "test_service", + threadPool, + null, + createRequestExecutorServiceSettings(1), + new SingleRequestManager(mock(RetryingHttpSender.class)) + ); + + service.execute(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>()); + var listener = new PlainActionFuture(); + service.execute(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener); + + var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat( + thrownException.getMessage(), + is("Failed to execute task because the http executor service [test_service] queue is full") + ); + assertFalse(thrownException.isExecutorShutdown()); + } + + public void 
testTaskThrowsError_CallsOnFailure() { + var requestSender = mock(RetryingHttpSender.class); + + var service = createRequestExecutorService(null, requestSender); + + doAnswer(invocation -> { + service.shutdown(); + throw new IllegalArgumentException("failed"); + }).when(requestSender).send(any(), any(), any(), any(), any(), any()); + + PlainActionFuture listener = new PlainActionFuture<>(); + + service.execute( + OpenAiEmbeddingsExecutableRequestCreatorTests.makeCreator("url", null, "key", "id", null, threadPool), + new DocumentsOnlyInput(List.of()), + null, + listener + ); + service.start(); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(thrownException.getMessage(), is(format("Failed to send request from inference entity id [%s]", "id"))); + assertThat(thrownException.getCause(), instanceOf(IllegalArgumentException.class)); + assertTrue(service.isTerminated()); + } + + public void testShutdown_AllowsMultipleCalls() { + var service = createRequestExecutorServiceWithMocks(); + + service.shutdown(); + service.shutdown(); + service.start(); + + assertTrue(service.isTerminated()); + assertTrue(service.isShutdown()); + } + + public void testSend_CallsOnFailure_WhenRequestTimesOut() { + var service = createRequestExecutorServiceWithMocks(); + + var listener = new PlainActionFuture(); + service.execute(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), TimeValue.timeValueNanos(1), listener); + + var thrownException = expectThrows(ElasticsearchTimeoutException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat( + thrownException.getMessage(), + is(format("Request timed out waiting to be sent after [%s]", TimeValue.timeValueNanos(1))) + ); + } + + public void testSend_PreservesThreadContext() throws InterruptedException, ExecutionException, TimeoutException { + var headerKey = "not empty"; + var headerValue = "value"; + + var service = createRequestExecutorServiceWithMocks(); + 
+ // starting this on a separate thread to ensure we aren't using the same thread context that the rest of the test will execute with + threadPool.generic().execute(service::start); + + ThreadContext threadContext = threadPool.getThreadContext(); + threadContext.putHeader(headerKey, headerValue); + + var requestSender = mock(RetryingHttpSender.class); + + var waitToShutdown = new CountDownLatch(1); + var waitToReturnFromSend = new CountDownLatch(1); + + // this code will be executed by the queue's thread + doAnswer(invocation -> { + var serviceThreadContext = threadPool.getThreadContext(); + // ensure that the spawned thread didn't pick up the header that was set initially on a separate thread + assertNull(serviceThreadContext.getHeader(headerKey)); + + @SuppressWarnings("unchecked") + ActionListener listener = (ActionListener) invocation.getArguments()[5]; + listener.onResponse(null); + + waitToShutdown.countDown(); + waitToReturnFromSend.await(TIMEOUT.getSeconds(), TimeUnit.SECONDS); + return Void.TYPE; + }).when(requestSender).send(any(), any(), any(), any(), any(), any()); + + var finishedOnResponse = new CountDownLatch(1); + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(InferenceServiceResults ignore) { + // if we've preserved the thread context correctly then the header should still exist + ThreadContext listenerContext = threadPool.getThreadContext(); + assertThat(listenerContext.getHeader(headerKey), is(headerValue)); + finishedOnResponse.countDown(); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException("onFailure shouldn't be called", e); + } + }; + + service.execute(RequestManagerTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener); + + Future executorTermination = submitShutdownRequest(waitToShutdown, waitToReturnFromSend, service); + + executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS); + assertTrue(service.isTerminated()); + + 
finishedOnResponse.await(TIMEOUT.getSeconds(), TimeUnit.SECONDS); + } + + public void testSend_NotifiesTasksOfShutdown() { + var service = createRequestExecutorServiceWithMocks(); + + var listener = new PlainActionFuture(); + service.execute(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener); + + service.shutdown(); + service.start(); + + var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat( + thrownException.getMessage(), + is("Failed to send request, queue service [test_service] has shutdown prior to executing request") + ); + assertTrue(thrownException.isExecutorShutdown()); + assertTrue(service.isTerminated()); + } + + public void testQueueTake_DoesNotCauseServiceToTerminate_WhenItThrows() throws InterruptedException { + @SuppressWarnings("unchecked") + BlockingQueue queue = mock(LinkedBlockingQueue.class); + + var service = new RequestExecutorServiceOld( + getTestName(), + threadPool, + mockQueueCreator(queue), + null, + createRequestExecutorServiceSettingsEmpty(), + new SingleRequestManager(mock(RetryingHttpSender.class)) + ); + + when(queue.take()).thenThrow(new ElasticsearchException("failed")).thenAnswer(invocation -> { + service.shutdown(); + return null; + }); + service.start(); + + assertTrue(service.isTerminated()); + verify(queue, times(2)).take(); + } + + public void testQueueTake_ThrowingInterruptedException_TerminatesService() throws Exception { + @SuppressWarnings("unchecked") + BlockingQueue queue = mock(LinkedBlockingQueue.class); + when(queue.take()).thenThrow(new InterruptedException("failed")); + + var service = new RequestExecutorServiceOld( + getTestName(), + threadPool, + mockQueueCreator(queue), + null, + createRequestExecutorServiceSettingsEmpty(), + new SingleRequestManager(mock(RetryingHttpSender.class)) + ); + + Future executorTermination = threadPool.generic().submit(() -> { + try { + service.start(); + } catch (Exception e) { + 
fail(Strings.format("Failed to shutdown executor: %s", e)); + } + }); + + executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS); + + assertTrue(service.isTerminated()); + verify(queue, times(1)).take(); + } + + public void testQueueTake_RejectsTask_WhenServiceShutsDown() throws Exception { + var mockTask = mock(RejectableTask.class); + @SuppressWarnings("unchecked") + BlockingQueue queue = mock(LinkedBlockingQueue.class); + + var service = new RequestExecutorServiceOld( + "test_service", + threadPool, + mockQueueCreator(queue), + null, + createRequestExecutorServiceSettingsEmpty(), + new SingleRequestManager(mock(RetryingHttpSender.class)) + ); + + doAnswer(invocation -> { + service.shutdown(); + return mockTask; + }).doReturn(new NoopTask()).when(queue).take(); + + service.start(); + + assertTrue(service.isTerminated()); + verify(queue, times(1)).take(); + + ArgumentCaptor argument = ArgumentCaptor.forClass(Exception.class); + verify(mockTask, times(1)).onRejection(argument.capture()); + assertThat(argument.getValue(), instanceOf(EsRejectedExecutionException.class)); + assertThat( + argument.getValue().getMessage(), + is("Failed to send request, queue service [test_service] has shutdown prior to executing request") + ); + + var rejectionException = (EsRejectedExecutionException) argument.getValue(); + assertTrue(rejectionException.isExecutorShutdown()); + } + + public void testChangingCapacity_SetsCapacityToTwo() throws ExecutionException, InterruptedException, TimeoutException { + var requestSender = mock(RetryingHttpSender.class); + + var settings = createRequestExecutorServiceSettings(1); + var service = new RequestExecutorServiceOld("test_service", threadPool, null, settings, new SingleRequestManager(requestSender)); + + service.execute(RequestManagerTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>()); + assertThat(service.queueSize(), is(1)); + + PlainActionFuture listener = new 
PlainActionFuture<>(); + service.execute(RequestManagerTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener); + + var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + thrownException.getMessage(), + is("Failed to execute task because the http executor service [test_service] queue is full") + ); + + settings.setQueueCapacity(2); + + var waitToShutdown = new CountDownLatch(1); + var waitToReturnFromSend = new CountDownLatch(1); + // There is a request already queued, and its execution path will initiate shutting down the service + doAnswer(invocation -> { + waitToShutdown.countDown(); + waitToReturnFromSend.await(TIMEOUT.getSeconds(), TimeUnit.SECONDS); + return Void.TYPE; + }).when(requestSender).send(any(), any(), any(), any(), any(), any()); + + Future executorTermination = submitShutdownRequest(waitToShutdown, waitToReturnFromSend, service); + + service.start(); + + executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS); + assertTrue(service.isTerminated()); + assertThat(service.remainingQueueCapacity(), is(2)); + } + + public void testChangingCapacity_DoesNotRejectsOverflowTasks_BecauseOfQueueFull() throws ExecutionException, InterruptedException, + TimeoutException { + var requestSender = mock(RetryingHttpSender.class); + + var settings = createRequestExecutorServiceSettings(3); + var service = new RequestExecutorServiceOld("test_service", threadPool, null, settings, new SingleRequestManager(requestSender)); + + service.execute(RequestManagerTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>()); + service.execute(RequestManagerTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>()); + + PlainActionFuture listener = new PlainActionFuture<>(); + service.execute(RequestManagerTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener); + 
assertThat(service.queueSize(), is(3)); + + settings.setQueueCapacity(1); + + var waitToShutdown = new CountDownLatch(1); + var waitToReturnFromSend = new CountDownLatch(1); + // There is a request already queued, and its execution path will initiate shutting down the service + doAnswer(invocation -> { + waitToShutdown.countDown(); + waitToReturnFromSend.await(TIMEOUT.getSeconds(), TimeUnit.SECONDS); + return Void.TYPE; + }).when(requestSender).send(any(), any(), any(), any(), any(), any()); + + Future executorTermination = submitShutdownRequest(waitToShutdown, waitToReturnFromSend, service); + + service.start(); + + executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS); + assertTrue(service.isTerminated()); + assertThat(service.remainingQueueCapacity(), is(1)); + assertThat(service.queueSize(), is(0)); + + var thrownException = expectThrows( + EsRejectedExecutionException.class, + () -> listener.actionGet(TIMEOUT.getSeconds(), TimeUnit.SECONDS) + ); + assertThat( + thrownException.getMessage(), + is("Failed to send request, queue service [test_service] has shutdown prior to executing request") + ); + assertTrue(thrownException.isExecutorShutdown()); + } + + public void testChangingCapacity_ToZero_SetsQueueCapacityToUnbounded() throws IOException, ExecutionException, InterruptedException, + TimeoutException { + var requestSender = mock(RetryingHttpSender.class); + + var settings = createRequestExecutorServiceSettings(1); + var service = new RequestExecutorServiceOld("test_service", threadPool, null, settings, new SingleRequestManager(requestSender)); + + service.execute(RequestManagerTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>()); + assertThat(service.queueSize(), is(1)); + + PlainActionFuture listener = new PlainActionFuture<>(); + service.execute(RequestManagerTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener); + + var thrownException = 
expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + thrownException.getMessage(), + is("Failed to execute task because the http executor service [test_service] queue is full") + ); + + settings.setQueueCapacity(0); + + var waitToShutdown = new CountDownLatch(1); + var waitToReturnFromSend = new CountDownLatch(1); + // There is a request already queued, and its execution path will initiate shutting down the service + doAnswer(invocation -> { + waitToShutdown.countDown(); + waitToReturnFromSend.await(TIMEOUT.getSeconds(), TimeUnit.SECONDS); + return Void.TYPE; + }).when(requestSender).send(any(), any(), any(), any(), any(), any()); + + Future executorTermination = submitShutdownRequest(waitToShutdown, waitToReturnFromSend, service); + + service.start(); + + executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS); + assertTrue(service.isTerminated()); + assertThat(service.remainingQueueCapacity(), is(Integer.MAX_VALUE)); + } + + private Future submitShutdownRequest( + CountDownLatch waitToShutdown, + CountDownLatch waitToReturnFromSend, + RequestExecutorServiceOld service + ) { + return threadPool.generic().submit(() -> { + try { + // wait for a task to be added to be executed before beginning shutdown + waitToShutdown.await(TIMEOUT.getSeconds(), TimeUnit.SECONDS); + service.shutdown(); + // tells send to return + waitToReturnFromSend.countDown(); + service.awaitTermination(TIMEOUT.getSeconds(), TimeUnit.SECONDS); + } catch (Exception e) { + fail(Strings.format("Failed to shutdown executor: %s", e)); + } + }); + } + + private RequestExecutorServiceOld createRequestExecutorServiceWithMocks() { + return createRequestExecutorService(null, mock(RetryingHttpSender.class)); + } + + private RequestExecutorServiceOld createRequestExecutorService( + @Nullable CountDownLatch startupLatch, + RetryingHttpSender requestSender + ) { + return new RequestExecutorServiceOld( + "test_service", + threadPool, + startupLatch, + 
createRequestExecutorServiceSettingsEmpty(), + new SingleRequestManager(requestSender) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java index ff88ba221d985..4d6068b0bc9be 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java @@ -70,7 +70,7 @@ public void testQueueSize_IsEmpty() { public void testQueueSize_IsOne() { var service = createRequestExecutorServiceWithMocks(); - service.execute(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>()); + service.execute(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>()); assertThat(service.queueSize(), is(1)); } @@ -133,7 +133,7 @@ public void testSend_AfterShutdown_Throws() { service.shutdown(); var listener = new PlainActionFuture(); - service.execute(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener); + service.execute(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener); var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT)); @@ -145,7 +145,7 @@ public void testSend_AfterShutdown_Throws() { } public void testSend_Throws_WhenQueueIsFull() { - var service = new RequestExecutorService( + var service = new RequestExecutorServiceOld( "test_service", threadPool, null, @@ -153,9 +153,9 @@ public void testSend_Throws_WhenQueueIsFull() { new SingleRequestManager(mock(RetryingHttpSender.class)) ); - service.execute(ExecutableRequestCreatorTests.createMock(), new 
DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>()); + service.execute(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>()); var listener = new PlainActionFuture(); - service.execute(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener); + service.execute(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener); var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT)); @@ -207,12 +207,7 @@ public void testSend_CallsOnFailure_WhenRequestTimesOut() { var service = createRequestExecutorServiceWithMocks(); var listener = new PlainActionFuture(); - service.execute( - ExecutableRequestCreatorTests.createMock(), - new DocumentsOnlyInput(List.of()), - TimeValue.timeValueNanos(1), - listener - ); + service.execute(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), TimeValue.timeValueNanos(1), listener); var thrownException = expectThrows(ElasticsearchTimeoutException.class, () -> listener.actionGet(TIMEOUT)); @@ -270,7 +265,7 @@ public void onFailure(Exception e) { } }; - service.execute(ExecutableRequestCreatorTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener); + service.execute(RequestManagerTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener); Future executorTermination = submitShutdownRequest(waitToShutdown, waitToReturnFromSend, service); @@ -284,7 +279,7 @@ public void testSend_NotifiesTasksOfShutdown() { var service = createRequestExecutorServiceWithMocks(); var listener = new PlainActionFuture(); - service.execute(ExecutableRequestCreatorTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener); + service.execute(RequestManagerTests.createMock(), new DocumentsOnlyInput(List.of()), null, listener); service.shutdown(); service.start(); @@ -303,7 +298,7 @@ public void 
testQueueTake_DoesNotCauseServiceToTerminate_WhenItThrows() throws I @SuppressWarnings("unchecked") BlockingQueue queue = mock(LinkedBlockingQueue.class); - var service = new RequestExecutorService( + var service = new RequestExecutorServiceOld( getTestName(), threadPool, mockQueueCreator(queue), @@ -327,7 +322,7 @@ public void testQueueTake_ThrowingInterruptedException_TerminatesService() throw BlockingQueue queue = mock(LinkedBlockingQueue.class); when(queue.take()).thenThrow(new InterruptedException("failed")); - var service = new RequestExecutorService( + var service = new RequestExecutorServiceOld( getTestName(), threadPool, mockQueueCreator(queue), @@ -355,7 +350,7 @@ public void testQueueTake_RejectsTask_WhenServiceShutsDown() throws Exception { @SuppressWarnings("unchecked") BlockingQueue queue = mock(LinkedBlockingQueue.class); - var service = new RequestExecutorService( + var service = new RequestExecutorServiceOld( "test_service", threadPool, mockQueueCreator(queue), @@ -390,30 +385,24 @@ public void testChangingCapacity_SetsCapacityToTwo() throws ExecutionException, var requestSender = mock(RetryingHttpSender.class); var settings = createRequestExecutorServiceSettings(1); - var service = new RequestExecutorService("test_service", threadPool, null, settings, new SingleRequestManager(requestSender)); + var service = new RequestExecutorService(threadPool, null, settings, requestSender); - service.execute( - ExecutableRequestCreatorTests.createMock(requestSender), - new DocumentsOnlyInput(List.of()), - null, - new PlainActionFuture<>() - ); + service.execute(RequestManagerTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>()); assertThat(service.queueSize(), is(1)); PlainActionFuture listener = new PlainActionFuture<>(); - service.execute(ExecutableRequestCreatorTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener); + var requestManager = 
RequestManagerTests.createMock(requestSender, "id"); + service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener); var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT)); - assertThat( - thrownException.getMessage(), - is("Failed to execute task because the http executor service [test_service] queue is full") - ); + assertThat(thrownException.getMessage(), is("Failed to execute task because the executor service [id] queue is full")); settings.setQueueCapacity(2); var waitToShutdown = new CountDownLatch(1); var waitToReturnFromSend = new CountDownLatch(1); // There is a request already queued, and its execution path will initiate shutting down the service + // TODO I think we need this to do the listener.onResponse? doAnswer(invocation -> { waitToShutdown.countDown(); waitToReturnFromSend.await(TIMEOUT.getSeconds(), TimeUnit.SECONDS); @@ -426,7 +415,7 @@ public void testChangingCapacity_SetsCapacityToTwo() throws ExecutionException, executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS); assertTrue(service.isTerminated()); - assertThat(service.remainingQueueCapacity(), is(2)); + assertThat(service.remainingQueueCapacity(requestManager), is(2)); } public void testChangingCapacity_DoesNotRejectsOverflowTasks_BecauseOfQueueFull() throws ExecutionException, InterruptedException, @@ -434,23 +423,14 @@ public void testChangingCapacity_DoesNotRejectsOverflowTasks_BecauseOfQueueFull( var requestSender = mock(RetryingHttpSender.class); var settings = createRequestExecutorServiceSettings(3); - var service = new RequestExecutorService("test_service", threadPool, null, settings, new SingleRequestManager(requestSender)); + var service = new RequestExecutorService(threadPool, null, settings, requestSender); - service.execute( - ExecutableRequestCreatorTests.createMock(requestSender), - new DocumentsOnlyInput(List.of()), - null, - new PlainActionFuture<>() - ); - service.execute( - 
ExecutableRequestCreatorTests.createMock(requestSender), - new DocumentsOnlyInput(List.of()), - null, - new PlainActionFuture<>() - ); + service.execute(RequestManagerTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>()); + service.execute(RequestManagerTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>()); PlainActionFuture listener = new PlainActionFuture<>(); - service.execute(ExecutableRequestCreatorTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener); + var requestManager = RequestManagerTests.createMock(requestSender); + service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, listener); assertThat(service.queueSize(), is(3)); settings.setQueueCapacity(1); @@ -470,7 +450,7 @@ public void testChangingCapacity_DoesNotRejectsOverflowTasks_BecauseOfQueueFull( executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS); assertTrue(service.isTerminated()); - assertThat(service.remainingQueueCapacity(), is(1)); + assertThat(service.remainingQueueCapacity(requestManager), is(1)); assertThat(service.queueSize(), is(0)); var thrownException = expectThrows( @@ -489,18 +469,14 @@ public void testChangingCapacity_ToZero_SetsQueueCapacityToUnbounded() throws IO var requestSender = mock(RetryingHttpSender.class); var settings = createRequestExecutorServiceSettings(1); - var service = new RequestExecutorService("test_service", threadPool, null, settings, new SingleRequestManager(requestSender)); + var service = new RequestExecutorService(threadPool, null, settings, requestSender); + var requestManager = RequestManagerTests.createMock(requestSender); - service.execute( - ExecutableRequestCreatorTests.createMock(requestSender), - new DocumentsOnlyInput(List.of()), - null, - new PlainActionFuture<>() - ); + service.execute(requestManager, new DocumentsOnlyInput(List.of()), null, new PlainActionFuture<>()); 
assertThat(service.queueSize(), is(1)); PlainActionFuture listener = new PlainActionFuture<>(); - service.execute(ExecutableRequestCreatorTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener); + service.execute(RequestManagerTests.createMock(requestSender), new DocumentsOnlyInput(List.of()), null, listener); var thrownException = expectThrows(EsRejectedExecutionException.class, () -> listener.actionGet(TIMEOUT)); assertThat( @@ -525,7 +501,7 @@ public void testChangingCapacity_ToZero_SetsQueueCapacityToUnbounded() throws IO executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS); assertTrue(service.isTerminated()); - assertThat(service.remainingQueueCapacity(), is(Integer.MAX_VALUE)); + assertThat(service.remainingQueueCapacity(requestManager), is(Integer.MAX_VALUE)); } private Future submitShutdownRequest( @@ -552,12 +528,6 @@ private RequestExecutorService createRequestExecutorServiceWithMocks() { } private RequestExecutorService createRequestExecutorService(@Nullable CountDownLatch startupLatch, RetryingHttpSender requestSender) { - return new RequestExecutorService( - "test_service", - threadPool, - startupLatch, - createRequestExecutorServiceSettingsEmpty(), - new SingleRequestManager(requestSender) - ); + return new RequestExecutorService(threadPool, startupLatch, createRequestExecutorServiceSettingsEmpty(), requestSender); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/ExecutableRequestCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManagerTests.java similarity index 57% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/ExecutableRequestCreatorTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManagerTests.java index 31297ed432ef5..70504d2f51b29 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/ExecutableRequestCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManagerTests.java @@ -10,10 +10,12 @@ import org.apache.http.client.protocol.HttpClientContext; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.request.RequestTests; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyList; @@ -21,34 +23,43 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -public class ExecutableRequestCreatorTests { +public class RequestManagerTests { public static RequestManager createMock() { - var mockCreator = mock(RequestManager.class); - when(mockCreator.create(any(), anyList(), any(), any(), any(), any())).thenReturn(() -> {}); - - return mockCreator; + return mock(RequestManager.class); } public static RequestManager createMock(RequestSender requestSender) { - return createMock(requestSender, "id"); + return createMock(requestSender, "id", new RateLimitSettings(TimeValue.timeValueSeconds(1))); + } + + public static RequestManager createMock(RequestSender requestSender, String inferenceEntityId) { + return createMock(requestSender, inferenceEntityId, new RateLimitSettings(TimeValue.timeValueSeconds(1))); } - public static RequestManager createMock(RequestSender requestSender, String modelId) { - var mockCreator = mock(RequestManager.class); + public static RequestManager createMock(RequestSender requestSender, String 
inferenceEntityId, RateLimitSettings settings) { + var mockManager = mock(RequestManager.class); doAnswer(invocation -> { @SuppressWarnings("unchecked") ActionListener listener = (ActionListener) invocation.getArguments()[5]; - return (Runnable) () -> requestSender.send( + requestSender.send( mock(Logger.class), - RequestTests.mockRequest(modelId), + RequestTests.mockRequest(inferenceEntityId), HttpClientContext.create(), () -> false, mock(ResponseHandler.class), listener ); - }).when(mockCreator).create(any(), anyList(), any(), any(), any(), any()); - return mockCreator; + return Void.TYPE; + }).when(mockManager).execute(any(), anyList(), any(), any(), any()); + + // just return something consistent so the hashing works + when(mockManager.rateLimitGrouping()).thenReturn(inferenceEntityId); + + when(mockManager.rateLimitSettings()).thenReturn(settings); + when(mockManager.inferenceEntityId()).thenReturn(inferenceEntityId); + + return mockManager; } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/SingleRequestManagerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/SingleRequestManagerTests.java index 55965bc2354d3..1fab17ea86465 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/SingleRequestManagerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/SingleRequestManagerTests.java @@ -19,7 +19,7 @@ public class SingleRequestManagerTests extends ESTestCase { public void testExecute_DoesNotCallRequestCreatorCreate_WhenInputIsNull() { var requestCreator = mock(RequestManager.class); var request = mock(InferenceRequest.class); - when(request.getRequestCreator()).thenReturn(requestCreator); + when(request.getRequestManager()).thenReturn(requestCreator); new SingleRequestManager(mock(RetryingHttpSender.class)).execute(mock(InferenceRequest.class), 
HttpClientContext.create()); verifyNoInteractions(requestCreator); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java index 672f186b37ceb..26d6803715fbc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java @@ -33,7 +33,6 @@ import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; -import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -59,7 +58,7 @@ public void testStart_InitializesTheSender() throws IOException { var sender = mock(Sender.class); var factory = mock(HttpRequestSender.Factory.class); - when(factory.createSender(anyString())).thenReturn(sender); + when(factory.createSender()).thenReturn(sender); try (var service = new TestSenderService(factory, createWithEmptySettings(threadPool))) { PlainActionFuture listener = new PlainActionFuture<>(); @@ -67,7 +66,7 @@ public void testStart_InitializesTheSender() throws IOException { listener.actionGet(TIMEOUT); verify(sender, times(1)).start(); - verify(factory, times(1)).createSender(anyString()); + verify(factory, times(1)).createSender(); } verify(sender, times(1)).close(); @@ -79,7 +78,7 @@ public void testStart_CallingStartTwiceKeepsSameSenderReference() throws IOExcep var sender = mock(Sender.class); var factory = mock(HttpRequestSender.Factory.class); - when(factory.createSender(anyString())).thenReturn(sender); + when(factory.createSender()).thenReturn(sender); try (var service = new TestSenderService(factory, 
createWithEmptySettings(threadPool))) { PlainActionFuture listener = new PlainActionFuture<>(); @@ -89,7 +88,7 @@ public void testStart_CallingStartTwiceKeepsSameSenderReference() throws IOExcep service.start(mock(Model.class), listener); listener.actionGet(TIMEOUT); - verify(factory, times(1)).createSender(anyString()); + verify(factory, times(1)).createSender(); verify(sender, times(2)).start(); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index 4e65d987a26ad..b3a18742201c8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -73,7 +73,6 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; -import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -594,7 +593,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAzureOpenAiModel() throws IOExcep var sender = mock(Sender.class); var factory = mock(HttpRequestSender.Factory.class); - when(factory.createSender(anyString())).thenReturn(sender); + when(factory.createSender()).thenReturn(sender); var mockModel = getInvalidModel("model_id", "service_name"); @@ -616,7 +615,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAzureOpenAiModel() throws IOExcep is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.") ); - verify(factory, times(1)).createSender(anyString()); + verify(factory, times(1)).createSender(); verify(sender, 
times(1)).start(); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index e75dfc4ec798e..27af0ce6cb06c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -72,7 +72,6 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; -import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -612,7 +611,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotCohereModel() throws IOException var sender = mock(Sender.class); var factory = mock(HttpRequestSender.Factory.class); - when(factory.createSender(anyString())).thenReturn(sender); + when(factory.createSender()).thenReturn(sender); var mockModel = getInvalidModel("model_id", "service_name"); @@ -634,7 +633,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotCohereModel() throws IOException is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.") ); - verify(factory, times(1)).createSender(anyString()); + verify(factory, times(1)).createSender(); verify(sender, times(1)).start(); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java index 73c013af7b117..ab53573a1c729 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java @@ -33,7 +33,6 @@ import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.Utils.getInvalidModel; import static org.hamcrest.CoreMatchers.is; -import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -59,7 +58,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotHuggingFaceModel() throws IOExcep var sender = mock(Sender.class); var factory = mock(HttpRequestSender.Factory.class); - when(factory.createSender(anyString())).thenReturn(sender); + when(factory.createSender()).thenReturn(sender); var mockModel = getInvalidModel("model_id", "service_name"); @@ -81,7 +80,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotHuggingFaceModel() throws IOExcep is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.") ); - verify(factory, times(1)).createSender(anyString()); + verify(factory, times(1)).createSender(); verify(sender, times(1)).start(); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index 70d7181106810..3f840a504c689 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -71,7 +71,6 @@ import static org.hamcrest.Matchers.equalTo; import static 
org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; -import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -674,7 +673,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotOpenAiModel() throws IOException var sender = mock(Sender.class); var factory = mock(HttpRequestSender.Factory.class); - when(factory.createSender(anyString())).thenReturn(sender); + when(factory.createSender()).thenReturn(sender); var mockModel = getInvalidModel("model_id", "service_name"); @@ -696,7 +695,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotOpenAiModel() throws IOException is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.") ); - verify(factory, times(1)).createSender(anyString()); + verify(factory, times(1)).createSender(); verify(sender, times(1)).start(); }