Adding in queuing logic
jonathan-buttner committed Apr 17, 2024
1 parent 2291514 commit ba719d5
Showing 40 changed files with 1,335 additions and 791 deletions.

@@ -155,6 +155,10 @@ public int size() {
         return currentQueue.size() + prioritizedReadingQueue.size();
     }
 
+    public E poll2() {
+        return null;
+    }
+
     /**
      * The number of additional elements that his queue can accept without blocking.
      */
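The new `poll2()` is only a stub at this point in the commit: it returns null and touches neither internal queue. Purely as a hedged sketch of where the queuing logic appears to be heading, a poll that prefers the prioritized queue might look like the following; the field types and the ordering preference are assumptions inferred from the field names, not something this diff confirms.

```java
// Hypothetical sketch only. Assumes both fields are BlockingQueue<E> and that
// prioritized entries should be drained before the current queue; the commit
// itself leaves poll2() as a stub that returns null.
public E poll2() {
    E task = prioritizedReadingQueue.poll();
    if (task != null) {
        return task;
    }
    return currentQueue.poll();
}
```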
CohereActionCreator.java

@@ -25,6 +25,7 @@ public class CohereActionCreator implements CohereActionVisitor {
     private final ServiceComponents serviceComponents;
 
     public CohereActionCreator(Sender sender, ServiceComponents serviceComponents) {
+        // TODO Batching - accept a class that can handle batching
         this.sender = Objects.requireNonNull(sender);
         this.serviceComponents = Objects.requireNonNull(serviceComponents);
     }
CohereEmbeddingsAction.java

@@ -36,6 +36,7 @@ public CohereEmbeddingsAction(Sender sender, CohereEmbeddingsModel model, Thread
             model.getServiceSettings().getCommonSettings().uri(),
             "Cohere embeddings"
         );
+        // TODO - Batching pass the batching class on to the CohereEmbeddingsRequestManager
         requestCreator = CohereEmbeddingsRequestManager.of(model, threadPool);
     }
 
AzureOpenAiEmbeddingsRequestManager.java

@@ -7,7 +7,6 @@
 
 package org.elasticsearch.xpack.inference.external.http.sender;
 
-import org.apache.http.client.protocol.HttpClientContext;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.ActionListener;
@@ -55,16 +54,15 @@ public AzureOpenAiEmbeddingsRequestManager(AzureOpenAiEmbeddingsModel model, Tru
     }
 
     @Override
-    public Runnable create(
+    public void execute(
         String query,
         List<String> input,
         RequestSender requestSender,
         Supplier<Boolean> hasRequestCompletedFunction,
-        HttpClientContext context,
         ActionListener<InferenceServiceResults> listener
     ) {
         var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens());
         AzureOpenAiEmbeddingsRequest request = new AzureOpenAiEmbeddingsRequest(truncator, truncatedInput, model);
-        return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
+        execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
     }
 }
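Across the request managers in this commit, the `create(...)` method that returned a `Runnable` is replaced by `execute(...)`, which builds the `ExecutableInferenceRequest` itself and hands it to an `execute` overload that is not shown in these hunks (presumably inherited from a shared base class that owns the queuing). The interface definition is not part of the visible diff, so the sketch below is an inferred reconstruction of the new contract, using only types that appear in the diff context.

```java
// Inferred sketch of the execute-based RequestManager contract implied by the
// @Override signatures in this commit; the real interface is not shown here.
package org.elasticsearch.xpack.inference.external.http.sender;

import java.util.List;
import java.util.function.Supplier;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;

public interface RequestManager {
    void execute(
        String query,
        List<String> input,
        RequestSender requestSender,
        Supplier<Boolean> hasRequestCompletedFunction,
        ActionListener<InferenceServiceResults> listener
    );
}
```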
CohereEmbeddingsRequestManager.java

@@ -7,7 +7,6 @@
 
 package org.elasticsearch.xpack.inference.external.http.sender;
 
-import org.apache.http.client.protocol.HttpClientContext;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.ActionListener;
@@ -44,16 +43,15 @@ private CohereEmbeddingsRequestManager(CohereEmbeddingsModel model, ThreadPool t
     }
 
     @Override
-    public Runnable create(
+    public void execute(
         String query,
         List<String> input,
         RequestSender requestSender,
         Supplier<Boolean> hasRequestCompletedFunction,
-        HttpClientContext context,
         ActionListener<InferenceServiceResults> listener
     ) {
         CohereEmbeddingsRequest request = new CohereEmbeddingsRequest(input, model);
 
-        return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
+        execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
     }
 }
CohereRerankRequestManager.java

@@ -7,7 +7,6 @@
 
 package org.elasticsearch.xpack.inference.external.http.sender;
 
-import org.apache.http.client.protocol.HttpClientContext;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.ActionListener;
@@ -44,16 +43,15 @@ private CohereRerankRequestManager(CohereRerankModel model, ThreadPool threadPoo
     }
 
     @Override
-    public Runnable create(
+    public void execute(
         String query,
         List<String> input,
         RequestSender requestSender,
         Supplier<Boolean> hasRequestCompletedFunction,
-        HttpClientContext context,
         ActionListener<InferenceServiceResults> listener
     ) {
         CohereRerankRequest request = new CohereRerankRequest(query, input, model);
 
-        return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
+        execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
     }
 }
ExecutableInferenceRequest.java

@@ -23,7 +23,6 @@ record ExecutableInferenceRequest(
     RequestSender requestSender,
     Logger logger,
     Request request,
-    HttpClientContext context,
     ResponseHandler responseHandler,
     Supplier<Boolean> hasFinished,
     ActionListener<InferenceServiceResults> listener
@@ -34,7 +33,7 @@ public void run() {
         var inferenceEntityId = request.createHttpRequest().inferenceEntityId();
 
         try {
-            requestSender.send(logger, request, context, hasFinished, responseHandler, listener);
+            requestSender.send(logger, request, HttpClientContext.create(), hasFinished, responseHandler, listener);
         } catch (Exception e) {
             var errorMessage = Strings.format("Failed to send request from inference entity id [%s]", inferenceEntityId);
             logger.warn(errorMessage, e);
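Putting the two hunks together, the record loses its `HttpClientContext` component and instead creates a fresh context at send time. A hedged reconstruction of the record after this change is shown below; the imports and the tail of the catch block are outside the visible hunks, so those parts are assumed.

```java
// Hedged reconstruction from the visible hunks; the end of the catch block is
// not shown by the diff and is left as a comment rather than guessed.
record ExecutableInferenceRequest(
    RequestSender requestSender,
    Logger logger,
    Request request,
    ResponseHandler responseHandler,
    Supplier<Boolean> hasFinished,
    ActionListener<InferenceServiceResults> listener
) implements Runnable {

    @Override
    public void run() {
        var inferenceEntityId = request.createHttpRequest().inferenceEntityId();

        try {
            // A per-request context is now created here instead of being passed in.
            requestSender.send(logger, request, HttpClientContext.create(), hasFinished, responseHandler, listener);
        } catch (Exception e) {
            var errorMessage = Strings.format("Failed to send request from inference entity id [%s]", inferenceEntityId);
            logger.warn(errorMessage, e);
            // Remaining error handling is outside the visible hunk.
        }
    }
}
```

One apparent consequence is that queued requests no longer need to carry an Apache `HttpClientContext` while they wait to be executed.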
HttpRequestSender.java

@@ -15,6 +15,8 @@
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
+import org.elasticsearch.xpack.inference.external.http.RequestExecutor;
+import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
 import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings;
 import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
@@ -39,30 +41,28 @@ public static class Factory {
         private final ServiceComponents serviceComponents;
         private final HttpClientManager httpClientManager;
         private final ClusterService clusterService;
-        private final SingleRequestManager requestManager;
+        private final RequestSender requestSender;
 
         public Factory(ServiceComponents serviceComponents, HttpClientManager httpClientManager, ClusterService clusterService) {
             this.serviceComponents = Objects.requireNonNull(serviceComponents);
             this.httpClientManager = Objects.requireNonNull(httpClientManager);
             this.clusterService = Objects.requireNonNull(clusterService);
 
-            var requestSender = new RetryingHttpSender(
+            requestSender = new RetryingHttpSender(
                 this.httpClientManager.getHttpClient(),
                 serviceComponents.throttlerManager(),
                 new RetrySettings(serviceComponents.settings(), clusterService),
                 serviceComponents.threadPool()
             );
-            requestManager = new SingleRequestManager(requestSender);
         }
 
-        public Sender createSender(String serviceName) {
+        public Sender createSender() {
             return new HttpRequestSender(
-                serviceName,
                 serviceComponents.threadPool(),
                 httpClientManager,
                 clusterService,
                 serviceComponents.settings(),
-                requestManager
+                requestSender
             );
         }
     }
@@ -71,26 +71,24 @@ public Sender createSender(String serviceName) {
 
     private final ThreadPool threadPool;
     private final HttpClientManager manager;
-    private final RequestExecutorService service;
+    private final RequestExecutor service;
     private final AtomicBoolean started = new AtomicBoolean(false);
     private final CountDownLatch startCompleted = new CountDownLatch(2);
 
     private HttpRequestSender(
-        String serviceName,
         ThreadPool threadPool,
         HttpClientManager httpClientManager,
         ClusterService clusterService,
         Settings settings,
-        SingleRequestManager requestManager
+        RequestSender requestSender
     ) {
         this.threadPool = Objects.requireNonNull(threadPool);
         this.manager = Objects.requireNonNull(httpClientManager);
         service = new RequestExecutorService(
-            serviceName,
             threadPool,
             startCompleted,
             new RequestExecutorServiceSettings(settings, clusterService),
-            requestManager
+            requestSender
         );
     }
 
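With this change the `Factory` builds one `RetryingHttpSender` up front, and `createSender()` no longer takes a per-service name, which suggests a single shared, queue-backed sender rather than one executor per service. The sketch below shows how a caller might use the new signature; the calling context and the import paths for `HttpRequestSender` and `Sender` are assumptions based on the package visible elsewhere in this diff.

```java
// Hedged usage sketch of the new no-argument createSender(); only the Factory
// calls come from this commit, the surrounding wiring is assumed.
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.services.ServiceComponents;

class SenderWiring {
    static Sender buildSharedSender(
        ServiceComponents serviceComponents,
        HttpClientManager httpClientManager,
        ClusterService clusterService
    ) {
        var factory = new HttpRequestSender.Factory(serviceComponents, httpClientManager, clusterService);
        // Previously callers passed a service name, e.g. createSender("openai");
        // after this commit a single shared sender is created instead.
        return factory.createSender();
    }
}
```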
HuggingFaceRequestManager.java

@@ -7,7 +7,6 @@
 
 package org.elasticsearch.xpack.inference.external.http.sender;
 
-import org.apache.http.client.protocol.HttpClientContext;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.ActionListener;
@@ -55,26 +54,17 @@ private HuggingFaceRequestManager(HuggingFaceModel model, ResponseHandler respon
     }
 
     @Override
-    public Runnable create(
+    public void execute(
         String query,
         List<String> input,
         RequestSender requestSender,
         Supplier<Boolean> hasRequestCompletedFunction,
-        HttpClientContext context,
         ActionListener<InferenceServiceResults> listener
     ) {
         var truncatedInput = truncate(input, model.getTokenLimit());
         var request = new HuggingFaceInferenceRequest(truncator, truncatedInput, model);
 
-        return new ExecutableInferenceRequest(
-            requestSender,
-            logger,
-            request,
-            context,
-            responseHandler,
-            hasRequestCompletedFunction,
-            listener
-        );
+        execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener));
     }
 
     record RateLimitGrouping(int accountHash) {
InferenceRequest.java

@@ -19,9 +19,9 @@
 public interface InferenceRequest {
 
     /**
-     * Returns the creator that handles building an executable request based on the input provided.
+     * Returns the manager that handles building and executing an inference request.
      */
-    RequestManager getRequestCreator();
+    RequestManager getRequestManager();
 
     /**
      * Returns the query associated with this request. Used for Rerank tasks.
NoopTask.java

@@ -16,7 +16,7 @@
 class NoopTask implements RejectableTask {
 
     @Override
-    public RequestManager getRequestCreator() {
+    public RequestManager getRequestManager() {
         return null;
     }
 
OpenAiCompletionRequestManager.java

@@ -7,7 +7,6 @@
 
 package org.elasticsearch.xpack.inference.external.http.sender;
 
-import org.apache.http.client.protocol.HttpClientContext;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.ActionListener;
@@ -43,17 +42,16 @@ private OpenAiCompletionRequestManager(OpenAiChatCompletionModel model, ThreadPo
     }
 
     @Override
-    public Runnable create(
+    public void execute(
         @Nullable String query,
         List<String> input,
         RequestSender requestSender,
         Supplier<Boolean> hasRequestCompletedFunction,
-        HttpClientContext context,
         ActionListener<InferenceServiceResults> listener
     ) {
         OpenAiChatCompletionRequest request = new OpenAiChatCompletionRequest(input, model);
 
-        return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
+        execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
     }
 
     private static ResponseHandler createCompletionHandler() {
OpenAiEmbeddingsRequestManager.java

@@ -7,7 +7,6 @@
 
 package org.elasticsearch.xpack.inference.external.http.sender;
 
-import org.apache.http.client.protocol.HttpClientContext;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.ActionListener;
@@ -55,17 +54,16 @@ private OpenAiEmbeddingsRequestManager(OpenAiEmbeddingsModel model, Truncator tr
     }
 
     @Override
-    public Runnable create(
+    public void execute(
         String query,
         List<String> input,
         RequestSender requestSender,
         Supplier<Boolean> hasRequestCompletedFunction,
-        HttpClientContext context,
         ActionListener<InferenceServiceResults> listener
     ) {
         var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens());
         OpenAiEmbeddingsRequest request = new OpenAiEmbeddingsRequest(truncator, truncatedInput, model);
 
-        return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
+        execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
     }
 }
(The remaining changed files in this commit are not shown in this view.)
