Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

EIS integration #111154

Merged
merged 32 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
9de0660
WIP
timgrein Jul 22, 2024
02fc3d2
Add ElasticInferenceServiceTests TODOs
timgrein Jul 22, 2024
4991125
Add ElasticInferenceServiceActionCreatorTests TODOs
timgrein Jul 22, 2024
ec43b4a
Add ElasticInferenceServiceResponseHandlerTests TODOs
timgrein Jul 22, 2024
ebb16d9
Add ElasticInferenceServiceSparseEmbeddingsRequestTests TODOs
timgrein Jul 22, 2024
f0d564c
Add ElasticInferenceServiceSparseEmbeddingsModelTests TODOs
timgrein Jul 23, 2024
d588bea
spotless apply
timgrein Jul 23, 2024
1a4ffc1
Fix conflicts
timgrein Jul 23, 2024
7bdda39
Add EmptySecretSettingsTests
timgrein Jul 23, 2024
ddadf1d
Add named writeables to InferenceNamedWriteablesProvider
timgrein Jul 23, 2024
a8058dd
Remove addressed todos
timgrein Jul 23, 2024
95de891
Translate model to correct endpoint
timgrein Jul 23, 2024
218fb62
Remove addressed TODO
timgrein Jul 23, 2024
7a5439b
Add docs to ElasticInferenceServiceFeature
timgrein Jul 23, 2024
806053a
Implement and test truncation/request
timgrein Jul 23, 2024
63c3edf
Add some EIS tests
demjened Jul 25, 2024
cf6eb5a
Support chunked inference
demjened Jul 26, 2024
a3e37fb
Check model config
demjened Jul 26, 2024
0294624
Add more tests
demjened Jul 26, 2024
edf560f
Add response handler
demjened Jul 29, 2024
7f490e6
Add more tests + HTTP 413 handling
demjened Jul 30, 2024
02d99d9
Fix some tests
demjened Jul 30, 2024
7e0bcd1
Spotless
demjened Jul 30, 2024
a17d725
Fixes
demjened Aug 1, 2024
418241b
Switch back to original response structure
demjened Aug 5, 2024
9cbe85e
Implement pass-through chunking
demjened Aug 6, 2024
b627569
Spotless
demjened Aug 6, 2024
8437cab
Fix after rebase
demjened Aug 6, 2024
1dc6c20
Spotless
demjened Aug 6, 2024
f0e703f
Log error upon failing to parse error response
demjened Aug 9, 2024
f968a33
Remove TODOs
demjened Aug 9, 2024
d3ef457
Update docs/changelog/111154.yaml
timgrein Aug 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/111154.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 111154
summary: EIS integration
area: Inference
type: feature
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_NESTED_UNSUPPORTED = def(8_717_00_0);
public static final TransportVersion ESQL_SINGLE_VALUE_QUERY_SOURCE = def(8_718_00_0);
public static final TransportVersion ESQL_ORIGINAL_INDICES = def(8_719_00_0);
public static final TransportVersion ML_INFERENCE_EIS_INTEGRATION_ADDED = def(8_720_00_0);
/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.inference;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;

/**
 * A {@link SecretSettings} implementation that carries no data at all, intended for services
 * that have no secrets to store.
 * <p>
 * It renders as an empty object in XContent and reads/writes zero bytes on the transport stream.
 */
public record EmptySecretSettings() implements SecretSettings {

    /** Name under which this type is registered as a named writeable. */
    public static final String NAME = "empty_secret_settings";

    /** Shared instance; the record has no components, so every instance is equal to this one. */
    public static final EmptySecretSettings INSTANCE = new EmptySecretSettings();

    /**
     * Stream constructor. Nothing is read from {@code in} because nothing is ever written by
     * {@link #writeTo(StreamOutput)}.
     */
    public EmptySecretSettings(StreamInput in) {
        this();
    }

    /** Intentionally writes nothing: there is no state to serialize. */
    @Override
    public void writeTo(StreamOutput out) throws IOException {}

    @Override
    public String getWriteableName() {
        return NAME;
    }

    @Override
    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.ML_INFERENCE_EIS_INTEGRATION_ADDED;
    }

    /** Emits an empty object ({@code {}}) and returns the same builder. */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        return builder.startObject().endObject();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.inference;

import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.inference.EmptySecretSettings;
import org.elasticsearch.inference.EmptyTaskSettings;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.inference.InferenceServiceResults;
Expand Down Expand Up @@ -45,6 +46,7 @@
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalTextEmbeddingServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandRerankTaskSettings;
Expand Down Expand Up @@ -95,6 +97,9 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
// Empty default task settings
namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, EmptyTaskSettings.NAME, EmptyTaskSettings::new));

// Empty default secret settings
namedWriteables.add(new NamedWriteableRegistry.Entry(SecretSettings.class, EmptySecretSettings.NAME, EmptySecretSettings::new));

// Default secret settings
namedWriteables.add(new NamedWriteableRegistry.Entry(SecretSettings.class, DefaultSecretSettings.NAME, DefaultSecretSettings::new));

Expand All @@ -111,6 +116,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
addCustomElandWriteables(namedWriteables);
addAnthropicNamedWritables(namedWriteables);
addAmazonBedrockNamedWriteables(namedWriteables);
addEisNamedWriteables(namedWriteables);

return namedWriteables;
}
Expand Down Expand Up @@ -475,4 +481,14 @@ private static void addAnthropicNamedWritables(List<NamedWriteableRegistry.Entry
)
);
}

/**
 * Appends the Elastic Inference Service (EIS) named writeable entries to the given list.
 * Currently this is only the sparse embeddings service settings, registered under the
 * {@code ServiceSettings} category.
 */
private static void addEisNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
    NamedWriteableRegistry.Entry sparseEmbeddingsServiceSettingsEntry = new NamedWriteableRegistry.Entry(
        ServiceSettings.class,
        ElasticInferenceServiceSparseEmbeddingsServiceSettings.NAME,
        ElasticInferenceServiceSparseEmbeddingsServiceSettings::new
    );
    namedWriteables.add(sparseEmbeddingsServiceSettingsEntry);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService;
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
import org.elasticsearch.xpack.inference.services.cohere.CohereService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceFeature;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService;
import org.elasticsearch.xpack.inference.services.elser.ElserInternalService;
import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioService;
Expand Down Expand Up @@ -124,7 +128,7 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP
private final SetOnce<HttpRequestSender.Factory> httpFactory = new SetOnce<>();
private final SetOnce<AmazonBedrockRequestSender.Factory> amazonBedrockFactory = new SetOnce<>();
private final SetOnce<ServiceComponents> serviceComponents = new SetOnce<>();

private final SetOnce<ElasticInferenceServiceComponents> eisComponents = new SetOnce<>();
private final SetOnce<InferenceServiceRegistry> inferenceServiceRegistry = new SetOnce<>();
private final SetOnce<ShardBulkInferenceActionFilter> shardBulkInferenceActionFilter = new SetOnce<>();
private List<InferenceServiceExtension> inferenceServiceExtensions;
Expand Down Expand Up @@ -187,6 +191,15 @@ public Collection<?> createComponents(PluginServices services) {
var inferenceServices = new ArrayList<>(inferenceServiceExtensions);
inferenceServices.add(this::getInferenceServiceFactories);

if (ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()) {
ElasticInferenceServiceSettings eisSettings = new ElasticInferenceServiceSettings(settings);
eisComponents.set(new ElasticInferenceServiceComponents(eisSettings.getEisGatewayUrl()));

inferenceServices.add(
() -> List.of(context -> new ElasticInferenceService(httpFactory.get(), serviceComponents.get(), eisComponents.get()))
);
}

var factoryContext = new InferenceServiceExtension.InferenceServiceFactoryContext(services.client());
// This must be done after the HttpRequestSenderFactory is created so that the services can get the
// reference correctly
Expand Down Expand Up @@ -281,6 +294,7 @@ public List<Setting<?>> getSettings() {
HttpClientManager.getSettingsDefinitions(),
ThrottlerManager.getSettingsDefinitions(),
RetrySettings.getSettingsDefinitions(),
ElasticInferenceServiceSettings.getSettingsDefinitions(),
Truncator.getSettingsDefinitions(),
RequestExecutorServiceSettings.getSettingsDefinitions(),
List.of(SKIP_VALIDATE_AND_START)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.action.elastic;

import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.ElasticInferenceServiceSparseEmbeddingsRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel;

import java.util.Objects;

import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;

/**
 * {@link ElasticInferenceServiceActionVisitor} implementation that builds {@link ExecutableAction}s
 * which dispatch requests through a {@link Sender}.
 */
public class ElasticInferenceServiceActionCreator implements ElasticInferenceServiceActionVisitor {

    private final Sender sender;
    private final ServiceComponents serviceComponents;

    /**
     * @param sender            used by the created actions to send requests; must not be {@code null}
     * @param serviceComponents handed to the request managers created here; must not be {@code null}
     * @throws NullPointerException if either argument is {@code null}
     */
    public ElasticInferenceServiceActionCreator(Sender sender, ServiceComponents serviceComponents) {
        this.sender = Objects.requireNonNull(sender);
        this.serviceComponents = Objects.requireNonNull(serviceComponents);
    }

    /**
     * Builds the action for a sparse embeddings model: a request manager bound to the model plus a
     * failure message derived from the model's URI.
     */
    @Override
    public ExecutableAction create(ElasticInferenceServiceSparseEmbeddingsModel model) {
        // Arguments are evaluated left to right, so the request manager is still created before the message.
        return new SenderExecutableAction(
            sender,
            new ElasticInferenceServiceSparseEmbeddingsRequestManager(model, serviceComponents),
            constructFailedToSendRequestMessage(model.uri(), "Elastic Inference Service sparse embeddings")
        );
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.action.elastic;

import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel;

/**
 * Creates {@link ExecutableAction}s for Elastic Inference Service models.
 * One {@code create} overload is declared per supported model type; at present only the
 * sparse embeddings model is covered.
 */
public interface ElasticInferenceServiceActionVisitor {

/**
 * Creates the action for the given sparse embeddings model.
 *
 * @param model the sparse embeddings model the action is created for
 * @return the action to execute for {@code model}
 */
ExecutableAction create(ElasticInferenceServiceSparseEmbeddingsModel model);

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.elastic;

import org.apache.logging.log4j.Logger;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.ContentTooLargeException;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceErrorResponseEntity;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;

import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody;

/**
 * Response handler for Elastic Inference Service calls. Error bodies are parsed with
 * {@link ElasticInferenceServiceErrorResponseEntity#fromResponse}.
 */
public class ElasticInferenceServiceResponseHandler extends BaseResponseHandler {

    /**
     * @param requestType   description of the request kind, passed through to the base handler
     * @param parseFunction parser passed through to the base handler
     */
    public ElasticInferenceServiceResponseHandler(String requestType, ResponseParser parseFunction) {
        super(requestType, parseFunction, ElasticInferenceServiceErrorResponseEntity::fromResponse);
    }

    /**
     * Validates a response: the status code is checked first, then the body is checked for emptiness.
     *
     * @throws RetryException if the status code is outside the 2xx range
     */
    @Override
    public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result)
        throws RetryException {
        checkForFailureStatusCode(request, result);
        checkForEmptyBody(throttlerManager, logger, request, result);
    }

    /**
     * Returns normally for 2xx status codes and throws for everything else.
     * NOTE(review): the boolean passed to {@link RetryException} is presumably the "should retry" flag;
     * if so, only HTTP 500 is treated as retryable here - confirm against RetryException.
     *
     * @throws ContentTooLargeException for HTTP 413
     * @throws RetryException           for every other non-2xx status code
     */
    void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException {
        final int code = result.response().getStatusLine().getStatusCode();
        final boolean successful = code >= 200 && code < 300;
        if (successful) {
            return;
        }

        switch (code) {
            case 500 -> throw new RetryException(true, buildError(SERVER_ERROR, request, result));
            case 400 -> throw new RetryException(false, buildError(BAD_REQUEST, request, result));
            case 405 -> throw new RetryException(false, buildError(METHOD_NOT_ALLOWED, request, result));
            case 413 -> throw new ContentTooLargeException(buildError(CONTENT_TOO_LARGE, request, result));
            default -> throw new RetryException(false, buildError(UNSUCCESSFUL, request, result));
        }
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ public abstract class BaseResponseHandler implements ResponseHandler {
public static final String REDIRECTION = "Unhandled redirection";
public static final String CONTENT_TOO_LARGE = "Received a content too large status code";
public static final String UNSUCCESSFUL = "Received an unsuccessful status code";
public static final String BAD_REQUEST = "Received a bad request status code";
public static final String METHOD_NOT_ALLOWED = "Received a method not allowed status code";

protected final String requestType;
private final ResponseParser parseFunction;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.http.sender;

import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel;

import java.util.Objects;

/**
 * Base class for Elastic Inference Service request managers. It forwards the model's inference
 * entity id, a {@link RateLimitGrouping} derived from the model, and the model's rate limit settings
 * to {@link BaseRequestManager}.
 */
public abstract class ElasticInferenceServiceRequestManager extends BaseRequestManager {

    /**
     * Grouping key built from the hash code of the model id.
     *
     * @param modelIdHash {@code hashCode()} of the model id taken from the rate limit service settings
     */
    record RateLimitGrouping(int modelIdHash) {

        /**
         * Builds the grouping for a model.
         *
         * @throws NullPointerException if {@code model} is {@code null}
         */
        public static RateLimitGrouping of(ElasticInferenceServiceModel model) {
            var checkedModel = Objects.requireNonNull(model);
            var modelId = checkedModel.rateLimitServiceSettings().modelId();
            return new RateLimitGrouping(modelId.hashCode());
        }
    }

    protected ElasticInferenceServiceRequestManager(ThreadPool threadPool, ElasticInferenceServiceModel model) {
        super(
            threadPool,
            model.getInferenceEntityId(),
            RateLimitGrouping.of(model),
            model.rateLimitServiceSettings().rateLimitSettings()
        );
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.elastic.ElasticInferenceServiceResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceSparseEmbeddingsRequest;
import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceSparseEmbeddingsResponseEntity;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel;

import java.util.List;
import java.util.function.Supplier;

import static org.elasticsearch.xpack.inference.common.Truncator.truncate;

/**
 * Request manager for Elastic Inference Service sparse embeddings. It truncates the incoming
 * documents to the model's configured maximum input tokens, wraps them in an
 * {@link ElasticInferenceServiceSparseEmbeddingsRequest} and hands that off for execution.
 */
public class ElasticInferenceServiceSparseEmbeddingsRequestManager extends ElasticInferenceServiceRequestManager {

    private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceSparseEmbeddingsRequestManager.class);

    /** Single handler instance shared by every request created by this manager type. */
    private static final ResponseHandler HANDLER = new ElasticInferenceServiceResponseHandler(
        "Elastic Inference Service sparse embeddings",
        ElasticInferenceServiceSparseEmbeddingsResponseEntity::fromResponse
    );

    private final ElasticInferenceServiceSparseEmbeddingsModel model;
    private final Truncator truncator;

    /**
     * @param model             the model requests are built for
     * @param serviceComponents supplies the thread pool (passed to the superclass) and the truncator
     */
    public ElasticInferenceServiceSparseEmbeddingsRequestManager(
        ElasticInferenceServiceSparseEmbeddingsModel model,
        ServiceComponents serviceComponents
    ) {
        super(serviceComponents.threadPool(), model);
        this.model = model;
        this.truncator = serviceComponents.truncator();
    }

    /**
     * Builds a sparse embeddings request from the document inputs and submits it.
     *
     * @param inferenceInputs             inputs; converted via {@code DocumentsOnlyInput.of}
     * @param requestSender               sender handed to the executable request
     * @param hasRequestCompletedFunction completion check handed to the executable request
     * @param listener                    receives the inference results
     */
    @Override
    public void execute(
        InferenceInputs inferenceInputs,
        RequestSender requestSender,
        Supplier<Boolean> hasRequestCompletedFunction,
        ActionListener<InferenceServiceResults> listener
    ) {
        List<String> documents = DocumentsOnlyInput.of(inferenceInputs).getInputs();
        var request = new ElasticInferenceServiceSparseEmbeddingsRequest(
            truncator,
            truncate(documents, model.getServiceSettings().maxInputTokens()),
            model
        );

        var executableRequest = new ExecutableInferenceRequest(
            requestSender,
            logger,
            request,
            HANDLER,
            hasRequestCompletedFunction,
            listener
        );
        execute(executableRequest);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.request.elastic;

import org.elasticsearch.xpack.inference.external.request.Request;

/**
 * Marker type for requests sent to the Elastic Inference Service. It extends {@link Request}
 * and declares no additional members.
 */
public interface ElasticInferenceServiceRequest extends Request {}
Loading