Skip to content

Commit

Permalink
EIS integration (elastic#111154)
Browse files Browse the repository at this point in the history
* WIP

* Add ElasticInferenceServiceTests TODOs

* Add ElasticInferenceServiceActionCreatorTests TODOs

* Add ElasticInferenceServiceResponseHandlerTests TODOs

* Add ElasticInferenceServiceSparseEmbeddingsRequestTests TODOs

* Add ElasticInferenceServiceSparseEmbeddingsModelTests TODOs

* spotless apply

* Fix conflicts

* Add EmptySecretSettingsTests

* Add named writeables to InferenceNamedWriteablesProvider

* Remove addressed todos

* Translate model to correct endpoint

* Remove addressed TODO

* Add docs to ElasticInferenceServiceFeature

* Implement and test truncation/request

* Add some EIS tests

* Support chunked inference

* Check model config

* Add more tests

* Add response handler

* Add more tests + HTTP 413 handling

* Fix some tests

* Spotless

* Fixes

* Switch back to original response structure

* Implement pass-through chunking

* Spotless

* Fix after rebase

* Spotless

* Log error upon failing to parse error response

* Remove TODOs

* Update docs/changelog/111154.yaml

---------

Co-authored-by: Adam Demjen <[email protected]>
  • Loading branch information
2 people authored and cbuescher committed Sep 4, 2024
1 parent 598ae63 commit 7009152
Show file tree
Hide file tree
Showing 40 changed files with 2,894 additions and 21 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/111154.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 111154
summary: EIS integration
area: Inference
type: feature
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_NESTED_UNSUPPORTED = def(8_717_00_0);
public static final TransportVersion ESQL_SINGLE_VALUE_QUERY_SOURCE = def(8_718_00_0);
public static final TransportVersion ESQL_ORIGINAL_INDICES = def(8_719_00_0);
public static final TransportVersion ML_INFERENCE_EIS_INTEGRATION_ADDED = def(8_720_00_0);
/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.inference;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;

/**
 * A {@link SecretSettings} implementation with no components, intended for services that have no secret settings.
 * <p>
 * It renders as an empty object and writes nothing when serialized to a stream.
 */
public record EmptySecretSettings() implements SecretSettings {
    public static final String NAME = "empty_secret_settings";

    public static final EmptySecretSettings INSTANCE = new EmptySecretSettings();

    /**
     * Stream constructor. Nothing is read from {@code in}, mirroring {@link #writeTo(StreamOutput)} which writes nothing.
     */
    public EmptySecretSettings(StreamInput in) {
        this();
    }

    /**
     * Emits an empty object ({@code startObject} immediately followed by {@code endObject}) and returns the builder.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        return builder.startObject().endObject();
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    @Override
    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.ML_INFERENCE_EIS_INTEGRATION_ADDED;
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        // Intentionally empty: the record has no state to write.
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.inference;

import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.inference.EmptySecretSettings;
import org.elasticsearch.inference.EmptyTaskSettings;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.inference.InferenceServiceResults;
Expand Down Expand Up @@ -45,6 +46,7 @@
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalTextEmbeddingServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandRerankTaskSettings;
Expand Down Expand Up @@ -95,6 +97,9 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
// Empty default task settings
namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, EmptyTaskSettings.NAME, EmptyTaskSettings::new));

// Empty default secret settings
namedWriteables.add(new NamedWriteableRegistry.Entry(SecretSettings.class, EmptySecretSettings.NAME, EmptySecretSettings::new));

// Default secret settings
namedWriteables.add(new NamedWriteableRegistry.Entry(SecretSettings.class, DefaultSecretSettings.NAME, DefaultSecretSettings::new));

Expand All @@ -111,6 +116,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
addCustomElandWriteables(namedWriteables);
addAnthropicNamedWritables(namedWriteables);
addAmazonBedrockNamedWriteables(namedWriteables);
addEisNamedWriteables(namedWriteables);

return namedWriteables;
}
Expand Down Expand Up @@ -475,4 +481,14 @@ private static void addAnthropicNamedWritables(List<NamedWriteableRegistry.Entry
)
);
}

/**
 * Appends the named writeable entry for the Elastic Inference Service sparse embeddings service settings
 * to {@code namedWriteables}.
 */
private static void addEisNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
    var sparseEmbeddingsServiceSettingsEntry = new NamedWriteableRegistry.Entry(
        ServiceSettings.class,
        ElasticInferenceServiceSparseEmbeddingsServiceSettings.NAME,
        ElasticInferenceServiceSparseEmbeddingsServiceSettings::new
    );
    namedWriteables.add(sparseEmbeddingsServiceSettingsEntry);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService;
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
import org.elasticsearch.xpack.inference.services.cohere.CohereService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceFeature;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService;
import org.elasticsearch.xpack.inference.services.elser.ElserInternalService;
import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioService;
Expand Down Expand Up @@ -124,7 +128,7 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP
private final SetOnce<HttpRequestSender.Factory> httpFactory = new SetOnce<>();
private final SetOnce<AmazonBedrockRequestSender.Factory> amazonBedrockFactory = new SetOnce<>();
private final SetOnce<ServiceComponents> serviceComponents = new SetOnce<>();

private final SetOnce<ElasticInferenceServiceComponents> eisComponents = new SetOnce<>();
private final SetOnce<InferenceServiceRegistry> inferenceServiceRegistry = new SetOnce<>();
private final SetOnce<ShardBulkInferenceActionFilter> shardBulkInferenceActionFilter = new SetOnce<>();
private List<InferenceServiceExtension> inferenceServiceExtensions;
Expand Down Expand Up @@ -187,6 +191,15 @@ public Collection<?> createComponents(PluginServices services) {
var inferenceServices = new ArrayList<>(inferenceServiceExtensions);
inferenceServices.add(this::getInferenceServiceFactories);

if (ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()) {
ElasticInferenceServiceSettings eisSettings = new ElasticInferenceServiceSettings(settings);
eisComponents.set(new ElasticInferenceServiceComponents(eisSettings.getEisGatewayUrl()));

inferenceServices.add(
() -> List.of(context -> new ElasticInferenceService(httpFactory.get(), serviceComponents.get(), eisComponents.get()))
);
}

var factoryContext = new InferenceServiceExtension.InferenceServiceFactoryContext(services.client());
// This must be done after the HttpRequestSenderFactory is created so that the services can get the
// reference correctly
Expand Down Expand Up @@ -281,6 +294,7 @@ public List<Setting<?>> getSettings() {
HttpClientManager.getSettingsDefinitions(),
ThrottlerManager.getSettingsDefinitions(),
RetrySettings.getSettingsDefinitions(),
ElasticInferenceServiceSettings.getSettingsDefinitions(),
Truncator.getSettingsDefinitions(),
RequestExecutorServiceSettings.getSettingsDefinitions(),
List.of(SKIP_VALIDATE_AND_START)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.action.elastic;

import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.ElasticInferenceServiceSparseEmbeddingsRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel;

import java.util.Objects;

import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;

/**
 * Creates {@link ExecutableAction}s for Elastic Inference Service models, using the {@link Sender} and
 * {@link ServiceComponents} supplied at construction time.
 */
public class ElasticInferenceServiceActionCreator implements ElasticInferenceServiceActionVisitor {

    private final Sender sender;
    private final ServiceComponents serviceComponents;

    /**
     * @param sender            passed to every action this creator builds; must not be {@code null}
     * @param serviceComponents passed to the request manager of every action; must not be {@code null}
     * @throws NullPointerException if either argument is {@code null}
     */
    public ElasticInferenceServiceActionCreator(Sender sender, ServiceComponents serviceComponents) {
        this.sender = Objects.requireNonNull(sender);
        this.serviceComponents = Objects.requireNonNull(serviceComponents);
    }

    /**
     * Builds a {@link SenderExecutableAction} for the given sparse embeddings model. The action is given a new
     * request manager for the model and a failure message constructed from the model's URI.
     */
    @Override
    public ExecutableAction create(ElasticInferenceServiceSparseEmbeddingsModel model) {
        return new SenderExecutableAction(
            sender,
            new ElasticInferenceServiceSparseEmbeddingsRequestManager(model, serviceComponents),
            constructFailedToSendRequestMessage(model.uri(), "Elastic Inference Service sparse embeddings")
        );
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.action.elastic;

import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel;

/**
 * Visitor that turns an Elastic Inference Service model into the {@link ExecutableAction} for it.
 * Currently declares a single overload, for sparse embeddings models.
 */
public interface ElasticInferenceServiceActionVisitor {

/**
 * Creates the action for the given sparse embeddings model.
 *
 * @param model the sparse embeddings model to create an action for
 * @return the action created for {@code model}
 */
ExecutableAction create(ElasticInferenceServiceSparseEmbeddingsModel model);

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.elastic;

import org.apache.logging.log4j.Logger;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.ContentTooLargeException;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceErrorResponseEntity;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;

import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody;

/**
 * Response handler for the Elastic Inference Service. It validates a response by checking the HTTP status code
 * first and then checking for an empty body. Error responses are parsed with
 * {@link ElasticInferenceServiceErrorResponseEntity#fromResponse}.
 */
public class ElasticInferenceServiceResponseHandler extends BaseResponseHandler {

    public ElasticInferenceServiceResponseHandler(String requestType, ResponseParser parseFunction) {
        super(requestType, parseFunction, ElasticInferenceServiceErrorResponseEntity::fromResponse);
    }

    @Override
    public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result)
        throws RetryException {
        // The status code is checked before the body, so a failure status is reported even when the body is empty.
        checkForFailureStatusCode(request, result);
        checkForEmptyBody(throttlerManager, logger, request, result);
    }

    /**
     * Returns normally for any 2xx status and throws otherwise:
     * <ul>
     *   <li>500: {@link RetryException} constructed with {@code true}</li>
     *   <li>400 and 405: {@link RetryException} constructed with {@code false}</li>
     *   <li>413: {@link ContentTooLargeException}</li>
     *   <li>anything else: {@link RetryException} constructed with {@code false}</li>
     * </ul>
     * NOTE(review): the boolean presumably means "should retry" - confirm against {@link RetryException}.
     */
    void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException {
        final int statusCode = result.response().getStatusLine().getStatusCode();
        final boolean successful = statusCode >= 200 && statusCode < 300;
        if (successful) {
            return;
        }

        switch (statusCode) {
            case 500 -> throw new RetryException(true, buildError(SERVER_ERROR, request, result));
            case 400 -> throw new RetryException(false, buildError(BAD_REQUEST, request, result));
            case 405 -> throw new RetryException(false, buildError(METHOD_NOT_ALLOWED, request, result));
            case 413 -> throw new ContentTooLargeException(buildError(CONTENT_TOO_LARGE, request, result));
            default -> throw new RetryException(false, buildError(UNSUCCESSFUL, request, result));
        }
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ public abstract class BaseResponseHandler implements ResponseHandler {
public static final String REDIRECTION = "Unhandled redirection";
public static final String CONTENT_TOO_LARGE = "Received a content too large status code";
public static final String UNSUCCESSFUL = "Received an unsuccessful status code";
public static final String BAD_REQUEST = "Received a bad request status code";
public static final String METHOD_NOT_ALLOWED = "Received a method not allowed status code";

protected final String requestType;
private final ResponseParser parseFunction;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.http.sender;

import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel;

import java.util.Objects;

/**
 * Base request manager for Elastic Inference Service models. It passes the model's inference entity id, a
 * {@link RateLimitGrouping} derived from the model, and the model's rate limit settings to the superclass.
 */
public abstract class ElasticInferenceServiceRequestManager extends BaseRequestManager {

    protected ElasticInferenceServiceRequestManager(ThreadPool threadPool, ElasticInferenceServiceModel model) {
        super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
    }

    /**
     * Rate limit grouping value holding the hash code of the model id.
     */
    record RateLimitGrouping(int modelIdHash) {
        /**
         * Builds the grouping from the model id found in the model's rate limit service settings.
         *
         * @throws NullPointerException if {@code model} is {@code null}
         */
        public static RateLimitGrouping of(ElasticInferenceServiceModel model) {
            var rateLimitServiceSettings = Objects.requireNonNull(model).rateLimitServiceSettings();
            int modelIdHash = rateLimitServiceSettings.modelId().hashCode();
            return new RateLimitGrouping(modelIdHash);
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.elastic.ElasticInferenceServiceResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceSparseEmbeddingsRequest;
import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceSparseEmbeddingsResponseEntity;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel;

import java.util.List;
import java.util.function.Supplier;

import static org.elasticsearch.xpack.inference.common.Truncator.truncate;

/**
 * Request manager for Elastic Inference Service sparse embeddings. It truncates the input documents to the
 * model's configured maximum input tokens, wraps them in an
 * {@link ElasticInferenceServiceSparseEmbeddingsRequest}, and hands the request to the inherited {@code execute}.
 */
public class ElasticInferenceServiceSparseEmbeddingsRequestManager extends ElasticInferenceServiceRequestManager {

    private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceSparseEmbeddingsRequestManager.class);

    // One handler instance is shared by every request manager of this type.
    private static final ResponseHandler HANDLER = new ElasticInferenceServiceResponseHandler(
        "Elastic Inference Service sparse embeddings",
        ElasticInferenceServiceSparseEmbeddingsResponseEntity::fromResponse
    );

    private final ElasticInferenceServiceSparseEmbeddingsModel model;
    private final Truncator truncator;

    /**
     * @param model             the sparse embeddings model requests are built for
     * @param serviceComponents supplies the thread pool for the superclass and the truncator given to each request
     */
    public ElasticInferenceServiceSparseEmbeddingsRequestManager(
        ElasticInferenceServiceSparseEmbeddingsModel model,
        ServiceComponents serviceComponents
    ) {
        super(serviceComponents.threadPool(), model);
        this.model = model;
        this.truncator = serviceComponents.truncator();
    }

    @Override
    public void execute(
        InferenceInputs inferenceInputs,
        RequestSender requestSender,
        Supplier<Boolean> hasRequestCompletedFunction,
        ActionListener<InferenceServiceResults> listener
    ) {
        List<String> documents = DocumentsOnlyInput.of(inferenceInputs).getInputs();

        // The request receives the already-truncated input plus the truncator itself.
        var sparseEmbeddingsRequest = new ElasticInferenceServiceSparseEmbeddingsRequest(
            truncator,
            truncate(documents, model.getServiceSettings().maxInputTokens()),
            model
        );

        execute(
            new ExecutableInferenceRequest(requestSender, logger, sparseEmbeddingsRequest, HANDLER, hasRequestCompletedFunction, listener)
        );
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.request.elastic;

import org.elasticsearch.xpack.inference.external.request.Request;

/**
 * Marker interface for Elastic Inference Service requests. It extends {@link Request} without declaring
 * any additional members.
 */
public interface ElasticInferenceServiceRequest extends Request {}
Loading

0 comments on commit 7009152

Please sign in to comment.