From 70091526d1daf1ccc868b920171020e4087d59be Mon Sep 17 00:00:00 2001 From: Tim Grein Date: Fri, 9 Aug 2024 17:22:59 +0200 Subject: [PATCH] EIS integration (#111154) * WIP * Add ElasticInferenceServiceTests TODOs * Add ElasticInferenceServiceActionCreatorTests TODOs * Add ElasticInferenceServiceResponseHandlerTests TODOs * Add ElasticInferenceServiceSparseEmbeddingsRequestTests TODOs * Add ElasticInferenceServiceSparseEmbeddingsModelTests TODOs * spotless apply * Fix conflicts * Add EmptySecretSettingsTests * Add named writeables to InferenceNamedWriteablesProvider * Remove addressed todos * Translate model to correct endpoint * Remove addressed TODO * Add docs to ElasticInferenceServiceFeature * Implement and test truncation/request * Add some EIS tests * Support chunked inference * Check model config * Add more tests * Add response handler * Add more tests + HTTP 413 handling * Fix some tests * Spotless * Fixes * Switch back to original response structure * Implement pass-through chunking * Spotless * Fix after rebase * Spotless * Log error upon failing to parse error response * Remove TODOs * Update docs/changelog/111154.yaml --------- Co-authored-by: Adam Demjen --- docs/changelog/111154.yaml | 5 + .../org/elasticsearch/TransportVersions.java | 1 + .../inference/EmptySecretSettings.java | 50 ++ .../InferenceNamedWriteablesProvider.java | 16 + .../xpack/inference/InferencePlugin.java | 16 +- .../ElasticInferenceServiceActionCreator.java | 38 ++ .../ElasticInferenceServiceActionVisitor.java | 17 + ...lasticInferenceServiceResponseHandler.java | 54 ++ .../http/retry/BaseResponseHandler.java | 2 + ...ElasticInferenceServiceRequestManager.java | 28 + ...ServiceSparseEmbeddingsRequestManager.java | 71 +++ .../ElasticInferenceServiceRequest.java | 12 + ...ferenceServiceSparseEmbeddingsRequest.java | 80 +++ ...eServiceSparseEmbeddingsRequestEntity.java | 41 ++ ...icInferenceServiceErrorResponseEntity.java | 64 +++ ...ServiceSparseEmbeddingsResponseEntity.java | 121 ++++ .../elastic/ElasticInferenceService.java | 274 +++++++++ .../ElasticInferenceServiceComponents.java | 10 + .../ElasticInferenceServiceFeature.java | 20 + .../elastic/ElasticInferenceServiceModel.java | 55 ++ ...erenceServiceRateLimitServiceSettings.java | 18 + .../ElasticInferenceServiceSettings.java | 33 ++ ...InferenceServiceSparseEmbeddingsModel.java | 113 ++++ ...erviceSparseEmbeddingsServiceSettings.java | 162 ++++++ .../services/elser/ElserInternalService.java | 13 +- .../elser/ElserInternalServiceSettings.java | 6 +- .../inference/services/elser/ElserModels.java | 33 ++ .../inference/EmptySecretSettingsTests.java | 35 ++ ...ticInferenceServiceActionCreatorTests.java | 289 ++++++++++ ...cInferenceServiceResponseHandlerTests.java | 116 ++++ ...iceSparseEmbeddingsRequestEntityTests.java | 50 ++ ...ceServiceSparseEmbeddingsRequestTests.java | 81 +++ ...erenceServiceErrorResponseEntityTests.java | 61 ++ ...ceSparseEmbeddingsResponseEntityTests.java | 241 ++++++++ ...enceServiceSparseEmbeddingsModelTests.java | 33 ++ ...eSparseEmbeddingsServiceSettingsTests.java | 90 +++ .../elastic/ElasticInferenceServiceTests.java | 523 ++++++++++++++++++ .../ElserInternalServiceSettingsTests.java | 8 +- .../elser/ElserInternalServiceTests.java | 2 +- .../services/elser/ElserModelsTests.java | 33 ++ 40 files changed, 2894 insertions(+), 21 deletions(-) create mode 100644 docs/changelog/111154.yaml create mode 100644 server/src/main/java/org/elasticsearch/inference/EmptySecretSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/elastic/ElasticInferenceServiceActionCreator.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/elastic/ElasticInferenceServiceActionVisitor.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/ElasticInferenceServiceResponseHandler.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceRequestManager.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceSparseEmbeddingsRequestManager.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRequest.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequest.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequestEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceErrorResponseEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceSparseEmbeddingsResponseEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceComponents.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceFeature.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceRateLimitServiceSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModel.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserModels.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/EmptySecretSettingsTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/elastic/ElasticInferenceServiceActionCreatorTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/elastic/ElasticInferenceServiceResponseHandlerTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequestEntityTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequestTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceErrorResponseEntityTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceSparseEmbeddingsResponseEntityTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettingsTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserModelsTests.java diff --git a/docs/changelog/111154.yaml b/docs/changelog/111154.yaml new file mode 100644 index 0000000000000..3297f5005a811 --- /dev/null +++ b/docs/changelog/111154.yaml @@ -0,0 +1,5 @@ +pr: 111154 +summary: EIS integration +area: Inference +type: feature +issues: [] diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 2427a2fe72ac6..2b579894ca521 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -187,6 +187,7 @@ static TransportVersion def(int id) { public static final TransportVersion ESQL_NESTED_UNSUPPORTED = def(8_717_00_0); public static final TransportVersion ESQL_SINGLE_VALUE_QUERY_SOURCE = def(8_718_00_0); public static final TransportVersion ESQL_ORIGINAL_INDICES = def(8_719_00_0); + public static final TransportVersion ML_INFERENCE_EIS_INTEGRATION_ADDED = def(8_720_00_0); /* * STOP! READ THIS FIRST! No, really, * ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _ diff --git a/server/src/main/java/org/elasticsearch/inference/EmptySecretSettings.java b/server/src/main/java/org/elasticsearch/inference/EmptySecretSettings.java new file mode 100644 index 0000000000000..5c6acb78a91e3 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/inference/EmptySecretSettings.java @@ -0,0 +1,50 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.inference; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; + +/** + * This class defines an empty secret settings object. This is useful for services that do not have any secret settings. + */ +public record EmptySecretSettings() implements SecretSettings { + public static final String NAME = "empty_secret_settings"; + + public static final EmptySecretSettings INSTANCE = new EmptySecretSettings(); + + public EmptySecretSettings(StreamInput in) { + this(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_EIS_INTEGRATION_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException {} +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 476ab3355a0b8..489a81b642492 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.inference.EmptySecretSettings; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.inference.InferenceServiceResults; @@ -45,6 +46,7 @@ import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalTextEmbeddingServiceSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandRerankTaskSettings; @@ -95,6 +97,9 @@ public static List getNamedWriteables() { // Empty default task settings namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, EmptyTaskSettings.NAME, EmptyTaskSettings::new)); + // Empty default secret settings + namedWriteables.add(new NamedWriteableRegistry.Entry(SecretSettings.class, EmptySecretSettings.NAME, EmptySecretSettings::new)); + // Default secret settings namedWriteables.add(new NamedWriteableRegistry.Entry(SecretSettings.class, DefaultSecretSettings.NAME, DefaultSecretSettings::new)); @@ -111,6 +116,7 @@ public static List getNamedWriteables() { addCustomElandWriteables(namedWriteables); addAnthropicNamedWritables(namedWriteables); addAmazonBedrockNamedWriteables(namedWriteables); + addEisNamedWriteables(namedWriteables); return namedWriteables; } @@ -475,4 +481,14 @@ private static void addAnthropicNamedWritables(List namedWriteables) { + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + ElasticInferenceServiceSparseEmbeddingsServiceSettings.NAME, + ElasticInferenceServiceSparseEmbeddingsServiceSettings::new + ) + ); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index ec9398358d180..f6d4a9f774a91 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -77,6 +77,10 @@ import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService; import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService; import org.elasticsearch.xpack.inference.services.cohere.CohereService; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceFeature; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; import org.elasticsearch.xpack.inference.services.elser.ElserInternalService; import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioService; @@ -124,7 +128,7 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP private final SetOnce httpFactory = new SetOnce<>(); private final SetOnce amazonBedrockFactory = new SetOnce<>(); private final SetOnce serviceComponents = new SetOnce<>(); - + private final SetOnce eisComponents = new SetOnce<>(); private final SetOnce inferenceServiceRegistry = new SetOnce<>(); private final SetOnce shardBulkInferenceActionFilter = new SetOnce<>(); private List inferenceServiceExtensions; @@ -187,6 +191,15 @@ public Collection createComponents(PluginServices services) { var inferenceServices = new ArrayList<>(inferenceServiceExtensions); inferenceServices.add(this::getInferenceServiceFactories); + if (ElasticInferenceServiceFeature.ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG.isEnabled()) { + ElasticInferenceServiceSettings eisSettings = new ElasticInferenceServiceSettings(settings); + eisComponents.set(new ElasticInferenceServiceComponents(eisSettings.getEisGatewayUrl())); + + inferenceServices.add( + () -> List.of(context -> new ElasticInferenceService(httpFactory.get(), serviceComponents.get(), eisComponents.get())) + ); + } + var factoryContext = new InferenceServiceExtension.InferenceServiceFactoryContext(services.client()); // This must be done after the HttpRequestSenderFactory is created so that the services can get the // reference correctly @@ -281,6 +294,7 @@ public List> getSettings() { HttpClientManager.getSettingsDefinitions(), ThrottlerManager.getSettingsDefinitions(), RetrySettings.getSettingsDefinitions(), + ElasticInferenceServiceSettings.getSettingsDefinitions(), Truncator.getSettingsDefinitions(), RequestExecutorServiceSettings.getSettingsDefinitions(), List.of(SKIP_VALIDATE_AND_START) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/elastic/ElasticInferenceServiceActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/elastic/ElasticInferenceServiceActionCreator.java new file mode 100644 index 0000000000000..ea2295979c480 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/elastic/ElasticInferenceServiceActionCreator.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.elastic; + +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.sender.ElasticInferenceServiceSparseEmbeddingsRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel; + +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; + +public class ElasticInferenceServiceActionCreator implements ElasticInferenceServiceActionVisitor { + + private final Sender sender; + + private final ServiceComponents serviceComponents; + + public ElasticInferenceServiceActionCreator(Sender sender, ServiceComponents serviceComponents) { + this.sender = Objects.requireNonNull(sender); + this.serviceComponents = Objects.requireNonNull(serviceComponents); + } + + @Override + public ExecutableAction create(ElasticInferenceServiceSparseEmbeddingsModel model) { + var requestManager = new ElasticInferenceServiceSparseEmbeddingsRequestManager(model, serviceComponents); + var errorMessage = constructFailedToSendRequestMessage(model.uri(), "Elastic Inference Service sparse embeddings"); + return new SenderExecutableAction(sender, requestManager, errorMessage); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/elastic/ElasticInferenceServiceActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/elastic/ElasticInferenceServiceActionVisitor.java new file mode 100644 index 0000000000000..99985e50b2538 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/elastic/ElasticInferenceServiceActionVisitor.java @@ -0,0 +1,17 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.elastic; + +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel; + +public interface ElasticInferenceServiceActionVisitor { + + ExecutableAction create(ElasticInferenceServiceSparseEmbeddingsModel model); + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/ElasticInferenceServiceResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/ElasticInferenceServiceResponseHandler.java new file mode 100644 index 0000000000000..15e543fadad71 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/ElasticInferenceServiceResponseHandler.java @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.elastic; + +import org.apache.logging.log4j.Logger; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.ContentTooLargeException; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceErrorResponseEntity; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; + +import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody; + +public class ElasticInferenceServiceResponseHandler extends BaseResponseHandler { + + public ElasticInferenceServiceResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, ElasticInferenceServiceErrorResponseEntity::fromResponse); + } + + @Override + public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result) + throws RetryException { + checkForFailureStatusCode(request, result); + checkForEmptyBody(throttlerManager, logger, request, result); + } + + void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException { + int statusCode = result.response().getStatusLine().getStatusCode(); + if (statusCode >= 200 && statusCode < 300) { + return; + } + + if (statusCode == 500) { + throw new RetryException(true, buildError(SERVER_ERROR, request, result)); + } else if (statusCode == 400) { + throw new RetryException(false, buildError(BAD_REQUEST, request, result)); + } else if (statusCode == 405) { + throw new RetryException(false, buildError(METHOD_NOT_ALLOWED, request, result)); + } else if (statusCode == 413) { + throw new ContentTooLargeException(buildError(CONTENT_TOO_LARGE, request, result)); + } + + throw new RetryException(false, buildError(UNSUCCESSFUL, request, result)); + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java index f793cb3586924..c9cbe169ec03d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java @@ -27,6 +27,8 @@ public abstract class BaseResponseHandler implements ResponseHandler { public static final String REDIRECTION = "Unhandled redirection"; public static final String CONTENT_TOO_LARGE = "Received a content too large status code"; public static final String UNSUCCESSFUL = "Received an unsuccessful status code"; + public static final String BAD_REQUEST = "Received a bad request status code"; + public static final String METHOD_NOT_ALLOWED = "Received a method not allowed status code"; protected final String requestType; private final ResponseParser parseFunction; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceRequestManager.java new file mode 100644 index 0000000000000..c857a481f8f04 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceRequestManager.java @@ -0,0 +1,28 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel; + +import java.util.Objects; + +public abstract class ElasticInferenceServiceRequestManager extends BaseRequestManager { + + protected ElasticInferenceServiceRequestManager(ThreadPool threadPool, ElasticInferenceServiceModel model) { + super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings()); + } + + record RateLimitGrouping(int modelIdHash) { + public static RateLimitGrouping of(ElasticInferenceServiceModel model) { + Objects.requireNonNull(model); + + return new RateLimitGrouping(model.rateLimitServiceSettings().modelId().hashCode()); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceSparseEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceSparseEmbeddingsRequestManager.java new file mode 100644 index 0000000000000..b59ac54d5cbb6 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceSparseEmbeddingsRequestManager.java @@ -0,0 +1,71 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.external.elastic.ElasticInferenceServiceResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceSparseEmbeddingsRequest; +import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceSparseEmbeddingsResponseEntity; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel; + +import java.util.List; +import java.util.function.Supplier; + +import static org.elasticsearch.xpack.inference.common.Truncator.truncate; + +public class ElasticInferenceServiceSparseEmbeddingsRequestManager extends ElasticInferenceServiceRequestManager { + + private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceSparseEmbeddingsRequestManager.class); + + private static final ResponseHandler HANDLER = createSparseEmbeddingsHandler(); + + private final ElasticInferenceServiceSparseEmbeddingsModel model; + + private final Truncator truncator; + + private static ResponseHandler createSparseEmbeddingsHandler() { + return new ElasticInferenceServiceResponseHandler( + "Elastic Inference Service sparse embeddings", + ElasticInferenceServiceSparseEmbeddingsResponseEntity::fromResponse + ); + } + + public ElasticInferenceServiceSparseEmbeddingsRequestManager( + ElasticInferenceServiceSparseEmbeddingsModel model, + ServiceComponents serviceComponents + ) { + super(serviceComponents.threadPool(), model); + this.model = model; + this.truncator = serviceComponents.truncator(); + } + + @Override + public void execute( + InferenceInputs inferenceInputs, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); + + ElasticInferenceServiceSparseEmbeddingsRequest request = new ElasticInferenceServiceSparseEmbeddingsRequest( + truncator, + truncatedInput, + model + ); + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRequest.java new file mode 100644 index 0000000000000..03eec913a265f --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRequest.java @@ -0,0 +1,12 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.elastic; + +import org.elasticsearch.xpack.inference.external.request.Request; + +public interface ElasticInferenceServiceRequest extends Request {} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequest.java new file mode 100644 index 0000000000000..41a2ef1c3ccda --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequest.java @@ -0,0 +1,80 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.elastic; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.apache.http.message.BasicHeader; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.Objects; + +public class ElasticInferenceServiceSparseEmbeddingsRequest implements ElasticInferenceServiceRequest { + + private final URI uri; + + private final ElasticInferenceServiceSparseEmbeddingsModel model; + + private final Truncator.TruncationResult truncationResult; + private final Truncator truncator; + + public ElasticInferenceServiceSparseEmbeddingsRequest( + Truncator truncator, + Truncator.TruncationResult truncationResult, + ElasticInferenceServiceSparseEmbeddingsModel model + ) { + this.truncator = truncator; + this.truncationResult = truncationResult; + this.model = Objects.requireNonNull(model); + this.uri = model.uri(); + } + + @Override + public HttpRequest createHttpRequest() { + var httpPost = new HttpPost(uri); + var requestEntity = Strings.toString(new ElasticInferenceServiceSparseEmbeddingsRequestEntity(truncationResult.input())); + + ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8)); + httpPost.setEntity(byteEntity); + + httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType())); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public String getInferenceEntityId() { + return model.getInferenceEntityId(); + } + + @Override + public URI getURI() { + return this.uri; + } + + @Override + public Request truncate() { + var truncatedInput = truncator.truncate(truncationResult.input()); + + return new ElasticInferenceServiceSparseEmbeddingsRequest(truncator, truncatedInput, model); + } + + @Override + public boolean[] getTruncationInfo() { + return truncationResult.truncated().clone(); + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequestEntity.java new file mode 100644 index 0000000000000..301bbf0146c14 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequestEntity.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.elastic; + +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public record ElasticInferenceServiceSparseEmbeddingsRequestEntity(List inputs) implements ToXContentObject { + + private static final String INPUT_FIELD = "input"; + + public ElasticInferenceServiceSparseEmbeddingsRequestEntity { + Objects.requireNonNull(inputs); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.startArray(INPUT_FIELD); + + { + for (String input : inputs) { + builder.value(input); + } + } + + builder.endArray(); + builder.endObject(); + + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceErrorResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceErrorResponseEntity.java new file mode 100644 index 0000000000000..c860821c81bbf --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceErrorResponseEntity.java @@ -0,0 +1,64 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.elastic; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorMessage; + +public class ElasticInferenceServiceErrorResponseEntity implements ErrorMessage { + + private final String errorMessage; + + private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceErrorResponseEntity.class); + + private ElasticInferenceServiceErrorResponseEntity(String errorMessage) { + this.errorMessage = errorMessage; + } + + @Override + public String getErrorMessage() { + return errorMessage; + } + + /** + * An example error response would look like + * + * + * { + * "error": "some error" + * } + * + * + * @param response The error response + * @return An error entity if the response is JSON with the above structure + * or null if the response does not contain the error field + */ + public static @Nullable ElasticInferenceServiceErrorResponseEntity fromResponse(HttpResult response) { + try ( + XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON) + .createParser(XContentParserConfiguration.EMPTY, response.body()) + ) { + var responseMap = jsonParser.map(); + var error = (String) responseMap.get("error"); + if (error != null) { + return new ElasticInferenceServiceErrorResponseEntity(error); + } + } catch (Exception e) { + logger.debug("Failed to parse error response", e); + } + + return null; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceSparseEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceSparseEmbeddingsResponseEntity.java new file mode 100644 index 0000000000000..2b36cc5d22cd4 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceSparseEmbeddingsResponseEntity.java @@ -0,0 +1,121 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.elastic; + +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.core.ml.search.WeightedToken; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; + +public class ElasticInferenceServiceSparseEmbeddingsResponseEntity { + + private static final String FAILED_TO_FIND_FIELD_TEMPLATE = + "Failed to find required field [%s] in Elastic Inference Service embeddings response"; + + /** + * Parses the EIS json response. + * + * For a request like: + * + *
+     *     
+     *         {
+     *             "inputs": ["Embed this text", "Embed this text, too"]
+     *         }
+     *     
+     * 
+ * + * The response would look like: + * + *
+     *     
+     *         {
+     *           "data": [
+     *                     {
+     *                       "Embed": 2.1259406,
+     *                       "this": 1.7073475,
+     *                       "text": 0.9020516
+     *                     },
+     *                    (...)
+     *                  ],
+     *           "meta": {
+     *               "processing_latency": ...,
+     *               "request_time": ...
+     *           }
+     *     
+     * 
+ */ + + public static SparseEmbeddingResults fromResponse(Request request, HttpResult response) throws IOException { + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { + moveToFirstToken(jsonParser); + + XContentParser.Token token = jsonParser.currentToken(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); + + positionParserAtTokenAfterField(jsonParser, "data", FAILED_TO_FIND_FIELD_TEMPLATE); + + var truncationResults = request.getTruncationInfo(); + List parsedEmbeddings = parseList( + jsonParser, + (parser, index) -> ElasticInferenceServiceSparseEmbeddingsResponseEntity.parseExpansionResult( + truncationResults, + parser, + index + ) + ); + + if (parsedEmbeddings.isEmpty()) { + return new SparseEmbeddingResults(Collections.emptyList()); + } + + return new SparseEmbeddingResults(parsedEmbeddings); + } + } + + private static SparseEmbeddingResults.Embedding parseExpansionResult(boolean[] truncationResults, XContentParser parser, int index) + throws IOException { + XContentParser.Token token = parser.currentToken(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, token, parser); + + List weightedTokens = new ArrayList<>(); + token = parser.nextToken(); + while (token != null && token != XContentParser.Token.END_OBJECT) { + ensureExpectedToken(XContentParser.Token.FIELD_NAME, token, parser); + var floatToken = parser.nextToken(); + ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, floatToken, parser); + + weightedTokens.add(new WeightedToken(parser.currentName(), parser.floatValue())); + + token = parser.nextToken(); + } + + // prevent an out of bounds if for some reason the truncation list is smaller than the results + var isTruncated = truncationResults != null && index < truncationResults.length && truncationResults[index]; + return new SparseEmbeddingResults.Embedding(weightedTokens, isTruncated); + } + + private ElasticInferenceServiceSparseEmbeddingsResponseEntity() {} +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java new file mode 100644 index 0000000000000..f77217f9c02f9 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -0,0 +1,274 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; +import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; +import org.elasticsearch.xpack.inference.external.action.elastic.ElasticInferenceServiceActionCreator; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.ServiceComponents; + +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; + +public class ElasticInferenceService extends SenderService { + + public static final String NAME = "elastic"; + + private final ElasticInferenceServiceComponents elasticInferenceServiceComponents; + + public ElasticInferenceService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + ElasticInferenceServiceComponents eisComponents + ) { + super(factory, serviceComponents); + this.elasticInferenceServiceComponents = eisComponents; + } + + @Override + protected void doInfer( + Model model, + List input, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener listener + ) { + if (model instanceof ElasticInferenceServiceModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + ElasticInferenceServiceModel elasticInferenceServiceModel = (ElasticInferenceServiceModel) model; + var actionCreator = new ElasticInferenceServiceActionCreator(getSender(), getServiceComponents()); + + var action = elasticInferenceServiceModel.accept(actionCreator, taskSettings); + action.execute(new DocumentsOnlyInput(input), timeout, listener); + } + + @Override + protected void doInfer( + Model model, + String query, + List input, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener listener + ) { + throw new UnsupportedOperationException("Query input not supported for Elastic Inference Service"); + } + + @Override + protected void doChunkedInfer( + Model model, + String query, + List input, + Map taskSettings, + InputType inputType, + ChunkingOptions chunkingOptions, + TimeValue timeout, + ActionListener> listener + ) { + // Pass-through without actually performing chunking (result will have a single chunk per input) + ActionListener inferListener = listener.delegateFailureAndWrap( + (delegate, response) -> delegate.onResponse(translateToChunkedResults(input, response)) + ); + + doInfer(model, input, taskSettings, inputType, timeout, inferListener); + } + + @Override + public String name() { + return NAME; + } + + @Override + public void parseRequestConfig( + String inferenceEntityId, + TaskType taskType, + Map config, + Set platformArchitectures, + ActionListener parsedModelListener + ) { + try { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + ElasticInferenceServiceModel model = createModel( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + serviceSettingsMap, + elasticInferenceServiceComponents, + TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), + ConfigurationParseContext.REQUEST + ); + + throwIfNotEmptyMap(config, NAME); + throwIfNotEmptyMap(serviceSettingsMap, NAME); + throwIfNotEmptyMap(taskSettingsMap, NAME); + + parsedModelListener.onResponse(model); + } catch (Exception e) { + parsedModelListener.onFailure(e); + } + } + + private static ElasticInferenceServiceModel createModel( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + @Nullable Map secretSettings, + ElasticInferenceServiceComponents eisServiceComponents, + String failureMessage, + ConfigurationParseContext context + ) { + return switch (taskType) { + case SPARSE_EMBEDDING -> new ElasticInferenceServiceSparseEmbeddingsModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + taskSettings, + secretSettings, + eisServiceComponents, + context + ); + default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + }; + } + + @Override + public Model parsePersistedConfigWithSecrets( + String inferenceEntityId, + TaskType taskType, + Map config, + Map secrets + ) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); + Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); + + return createModelFromPersistent( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + secretSettingsMap, + parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + ); + } + + @Override + public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + return createModelFromPersistent( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + null, + parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + ); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_EIS_INTEGRATION_ADDED; + } + + private ElasticInferenceServiceModel createModelFromPersistent( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + @Nullable Map secretSettings, + String failureMessage + ) { + return createModel( + inferenceEntityId, + taskType, + serviceSettings, + taskSettings, + secretSettings, + elasticInferenceServiceComponents, + failureMessage, + ConfigurationParseContext.PERSISTENT + ); + } + + @Override + public void checkModelConfig(Model model, ActionListener listener) { + if (model instanceof ElasticInferenceServiceSparseEmbeddingsModel embeddingsModel) { + listener.onResponse(updateModelWithEmbeddingDetails(embeddingsModel)); + } else { + listener.onResponse(model); + } + } + + private static List translateToChunkedResults( + List inputs, + InferenceServiceResults inferenceResults + ) { + if (inferenceResults instanceof SparseEmbeddingResults sparseEmbeddingResults) { + return InferenceChunkedSparseEmbeddingResults.listOf(inputs, sparseEmbeddingResults); + } else if (inferenceResults instanceof ErrorInferenceResults error) { + return List.of(new ErrorChunkedInferenceResults(error.getException())); + } else { + String expectedClass = Strings.format("%s", SparseEmbeddingResults.class.getSimpleName()); + throw createInvalidChunkedResultException(expectedClass, inferenceResults.getWriteableName()); + } + } + + private ElasticInferenceServiceSparseEmbeddingsModel updateModelWithEmbeddingDetails( + ElasticInferenceServiceSparseEmbeddingsModel model + ) { + ElasticInferenceServiceSparseEmbeddingsServiceSettings serviceSettings = new ElasticInferenceServiceSparseEmbeddingsServiceSettings( + model.getServiceSettings().modelId(), + model.getServiceSettings().maxInputTokens(), + model.getServiceSettings().rateLimitSettings() + ); + + return new ElasticInferenceServiceSparseEmbeddingsModel(model, serviceSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceComponents.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceComponents.java new file mode 100644 index 0000000000000..4386964e927d2 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceComponents.java @@ -0,0 +1,10 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic; + +public record ElasticInferenceServiceComponents(String eisGatewayUrl) {} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceFeature.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceFeature.java new file mode 100644 index 0000000000000..b0fb6d14ee6f7 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceFeature.java @@ -0,0 +1,20 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic; + +import org.elasticsearch.common.util.FeatureFlag; + +/** + * Elastic Inference Service (EIS) feature flag. When the feature is complete, this flag will be removed. + * Enable feature via JVM option: `-Des.eis_feature_flag_enabled=true`. + */ +public class ElasticInferenceServiceFeature { + + public static final FeatureFlag ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG = new FeatureFlag("eis"); + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java new file mode 100644 index 0000000000000..e7809d869fec4 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java @@ -0,0 +1,55 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic; + +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.elastic.ElasticInferenceServiceActionVisitor; + +import java.util.Map; +import java.util.Objects; + +public abstract class ElasticInferenceServiceModel extends Model { + + private final ElasticInferenceServiceRateLimitServiceSettings rateLimitServiceSettings; + + private final ElasticInferenceServiceComponents elasticInferenceServiceComponents; + + public ElasticInferenceServiceModel( + ModelConfigurations configurations, + ModelSecrets secrets, + ElasticInferenceServiceRateLimitServiceSettings rateLimitServiceSettings, + ElasticInferenceServiceComponents elasticInferenceServiceComponents + ) { + super(configurations, secrets); + + this.rateLimitServiceSettings = Objects.requireNonNull(rateLimitServiceSettings); + this.elasticInferenceServiceComponents = Objects.requireNonNull(elasticInferenceServiceComponents); + } + + public ElasticInferenceServiceModel(ElasticInferenceServiceModel model, ServiceSettings serviceSettings) { + super(model, serviceSettings); + + this.rateLimitServiceSettings = model.rateLimitServiceSettings(); + this.elasticInferenceServiceComponents = model.elasticInferenceServiceComponents(); + } + + public ElasticInferenceServiceRateLimitServiceSettings rateLimitServiceSettings() { + return rateLimitServiceSettings; + } + + public ElasticInferenceServiceComponents elasticInferenceServiceComponents() { + return elasticInferenceServiceComponents; + } + + public abstract ExecutableAction accept(ElasticInferenceServiceActionVisitor visitor, Map taskSettings); + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceRateLimitServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceRateLimitServiceSettings.java new file mode 100644 index 0000000000000..2ec562b61fa01 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceRateLimitServiceSettings.java @@ -0,0 +1,18 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic; + +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +public interface ElasticInferenceServiceRateLimitServiceSettings { + + String modelId(); + + RateLimitSettings rateLimitSettings(); + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java new file mode 100644 index 0000000000000..8525710c6cf23 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java @@ -0,0 +1,33 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic; + +import org.elasticsearch.common.settings.Setting; +import org.elasticsearch.common.settings.Settings; + +import java.util.List; + +public class ElasticInferenceServiceSettings { + + static final Setting EIS_GATEWAY_URL = Setting.simpleString("xpack.inference.eis.gateway.url", Setting.Property.NodeScope); + + // Adjust this variable to be volatile, if the setting can be updated at some point in time + private final String eisGatewayUrl; + + public ElasticInferenceServiceSettings(Settings settings) { + eisGatewayUrl = EIS_GATEWAY_URL.get(settings); + } + + public static List> getSettingsDefinitions() { + return List.of(EIS_GATEWAY_URL); + } + + public String getEisGatewayUrl() { + return eisGatewayUrl; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModel.java new file mode 100644 index 0000000000000..163e3dd654150 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModel.java @@ -0,0 +1,113 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SecretSettings; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.elastic.ElasticInferenceServiceActionVisitor; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.elser.ElserModels; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Map; + +public class ElasticInferenceServiceSparseEmbeddingsModel extends ElasticInferenceServiceModel { + + private final URI uri; + + public ElasticInferenceServiceSparseEmbeddingsModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettings, + Map taskSettings, + Map secrets, + ElasticInferenceServiceComponents elasticInferenceServiceComponents, + ConfigurationParseContext context + ) { + this( + inferenceEntityId, + taskType, + service, + ElasticInferenceServiceSparseEmbeddingsServiceSettings.fromMap(serviceSettings, context), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + elasticInferenceServiceComponents + ); + } + + public ElasticInferenceServiceSparseEmbeddingsModel( + ElasticInferenceServiceSparseEmbeddingsModel model, + ElasticInferenceServiceSparseEmbeddingsServiceSettings serviceSettings + ) { + super(model, serviceSettings); + + try { + this.uri = createUri(); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + ElasticInferenceServiceSparseEmbeddingsModel( + String inferenceEntityId, + TaskType taskType, + String service, + ElasticInferenceServiceSparseEmbeddingsServiceSettings serviceSettings, + @Nullable TaskSettings taskSettings, + @Nullable SecretSettings secretSettings, + ElasticInferenceServiceComponents elasticInferenceServiceComponents + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), + new ModelSecrets(secretSettings), + serviceSettings, + elasticInferenceServiceComponents + ); + + try { + this.uri = createUri(); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + @Override + public ExecutableAction accept(ElasticInferenceServiceActionVisitor visitor, Map taskSettings) { + return visitor.create(this); + } + + @Override + public ElasticInferenceServiceSparseEmbeddingsServiceSettings getServiceSettings() { + return (ElasticInferenceServiceSparseEmbeddingsServiceSettings) super.getServiceSettings(); + } + + public URI uri() { + return uri; + } + + private URI createUri() throws URISyntaxException { + String modelId = getServiceSettings().modelId(); + String modelIdUriPath; + + switch (modelId) { + case ElserModels.ELSER_V2_MODEL -> modelIdUriPath = "ELSERv2"; + default -> throw new IllegalArgumentException("Unsupported model for EIS [" + modelId + "]"); + } + + return new URI(elasticInferenceServiceComponents().eisGatewayUrl() + "/sparse-text-embedding/" + modelIdUriPath); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettings.java new file mode 100644 index 0000000000000..15b89525f7915 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettings.java @@ -0,0 +1,162 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.elser.ElserModels; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; + +public class ElasticInferenceServiceSparseEmbeddingsServiceSettings extends FilteredXContentObject + implements + ServiceSettings, + ElasticInferenceServiceRateLimitServiceSettings { + + public static final String NAME = "elastic_inference_service_sparse_embeddings_service_settings"; + + private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(1_000); + + public static ElasticInferenceServiceSparseEmbeddingsServiceSettings fromMap( + Map map, + ConfigurationParseContext context + ) { + ValidationException validationException = new ValidationException(); + + String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + Integer maxInputTokens = extractOptionalPositiveInteger( + map, + MAX_INPUT_TOKENS, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + ElasticInferenceService.NAME, + context + ); + + if (modelId != null && ElserModels.isValidEisModel(modelId) == false) { + validationException.addValidationError("unknown ELSER model id [" + modelId + "]"); + } + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new ElasticInferenceServiceSparseEmbeddingsServiceSettings(modelId, maxInputTokens, rateLimitSettings); + } + + private final String modelId; + + private final Integer maxInputTokens; + private final RateLimitSettings rateLimitSettings; + + public ElasticInferenceServiceSparseEmbeddingsServiceSettings( + String modelId, + @Nullable Integer maxInputTokens, + RateLimitSettings rateLimitSettings + ) { + this.modelId = Objects.requireNonNull(modelId); + this.maxInputTokens = maxInputTokens; + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + } + + public ElasticInferenceServiceSparseEmbeddingsServiceSettings(StreamInput in) throws IOException { + this.modelId = in.readString(); + this.maxInputTokens = in.readOptionalVInt(); + this.rateLimitSettings = new RateLimitSettings(in); + } + + @Override + public String getWriteableName() { + return NAME; + } + + public String modelId() { + return modelId; + } + + public Integer maxInputTokens() { + return maxInputTokens; + } + + @Override + public RateLimitSettings rateLimitSettings() { + return rateLimitSettings; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_EIS_INTEGRATION_ADDED; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + toXContentFragmentOfExposedFields(builder, params); + + builder.endObject(); + + return builder; + } + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + builder.field(MODEL_ID, modelId); + if (maxInputTokens != null) { + builder.field(MAX_INPUT_TOKENS, maxInputTokens); + } + rateLimitSettings.toXContent(builder, params); + + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + out.writeOptionalVInt(maxInputTokens); + rateLimitSettings.writeTo(out); + } + + @Override + public boolean equals(Object object) { + if (this == object) return true; + if (object == null || getClass() != object.getClass()) return false; + ElasticInferenceServiceSparseEmbeddingsServiceSettings that = (ElasticInferenceServiceSparseEmbeddingsServiceSettings) object; + return Objects.equals(modelId, that.modelId) + && Objects.equals(maxInputTokens, that.maxInputTokens) + && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(modelId, maxInputTokens, rateLimitSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java index 03d7682600e7c..775ddca160463 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java @@ -47,22 +47,13 @@ import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.elser.ElserModels.ELSER_V2_MODEL; +import static org.elasticsearch.xpack.inference.services.elser.ElserModels.ELSER_V2_MODEL_LINUX_X86; public class ElserInternalService extends BaseElasticsearchInternalService { public static final String NAME = "elser"; - static final String ELSER_V1_MODEL = ".elser_model_1"; - // Default non platform specific v2 model - static final String ELSER_V2_MODEL = ".elser_model_2"; - static final String ELSER_V2_MODEL_LINUX_X86 = ".elser_model_2_linux-x86_64"; - - public static Set VALID_ELSER_MODEL_IDS = Set.of( - ElserInternalService.ELSER_V1_MODEL, - ElserInternalService.ELSER_V2_MODEL, - ElserInternalService.ELSER_V2_MODEL_LINUX_X86 - ); - private static final String OLD_MODEL_ID_FIELD_NAME = "model_version"; public ElserInternalService(InferenceServiceExtension.InferenceServiceFactoryContext context) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettings.java index 75797919b3616..fcbabd5a88fc6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettings.java @@ -18,8 +18,6 @@ import java.util.Arrays; import java.util.Map; -import static org.elasticsearch.xpack.inference.services.elser.ElserInternalService.VALID_ELSER_MODEL_IDS; - public class ElserInternalServiceSettings extends ElasticsearchInternalServiceSettings { public static final String NAME = "elser_mlnode_service_settings"; @@ -29,10 +27,10 @@ public static ElasticsearchInternalServiceSettings.Builder fromRequestMap(Map VALID_ELSER_MODEL_IDS = Set.of( + ElserModels.ELSER_V1_MODEL, + ElserModels.ELSER_V2_MODEL, + ElserModels.ELSER_V2_MODEL_LINUX_X86 + ); + + public static boolean isValidModel(String model) { + return VALID_ELSER_MODEL_IDS.contains(model); + } + + public static boolean isValidEisModel(String model) { + return ELSER_V2_MODEL.equals(model); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/EmptySecretSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/EmptySecretSettingsTests.java new file mode 100644 index 0000000000000..b50ea9e5ee224 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/EmptySecretSettingsTests.java @@ -0,0 +1,35 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +public class EmptySecretSettingsTests extends AbstractWireSerializingTestCase { + + public static EmptySecretSettings createRandom() { + return EmptySecretSettings.INSTANCE; // no options to randomise + } + + @Override + protected Writeable.Reader instanceReader() { + return EmptySecretSettings::new; + } + + @Override + protected EmptySecretSettings createTestInstance() { + return createRandom(); + } + + @Override + protected EmptySecretSettings mutateInstance(EmptySecretSettings instance) { + // All instances are the same and have no fields, nothing to mutate + return null; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/elastic/ElasticInferenceServiceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/elastic/ElasticInferenceServiceActionCreatorTests.java new file mode 100644 index 0000000000000..1081a60ba6866 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/elastic/ElasticInferenceServiceActionCreatorTests.java @@ -0,0 +1,289 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.elastic; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModelTests; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields; +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class ElasticInferenceServiceActionCreatorTests extends ESTestCase { + + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + @SuppressWarnings("unchecked") + public void testExecute_ReturnsSuccessfulResponse_ForElserAction() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "data": [ + { + "hello": 2.1259406, + "greet": 1.7073475 + } + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer)); + var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool)); + var action = actionCreator.create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("hello world")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + assertThat( + result.asMap(), + is( + SparseEmbeddingResultsTests.buildExpectationSparseEmbeddings( + List.of( + new SparseEmbeddingResultsTests.EmbeddingExpectation(Map.of("hello", 2.1259406f, "greet", 1.7073475f), false) + ) + ) + ) + ); + + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), is(1)); + assertThat(requestMap.get("input"), instanceOf(List.class)); + var inputList = (List) requestMap.get("input"); + assertThat(inputList, contains("hello world")); + } + } + + @SuppressWarnings("unchecked") + public void testSend_FailsFromInvalidResponseFormat_ForElserAction() throws IOException { + // timeout as zero for no retries + var settings = buildSettingsWithRetryFields( + TimeValue.timeValueMillis(1), + TimeValue.timeValueMinutes(1), + TimeValue.timeValueSeconds(0) + ); + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + // This will fail because the expected output is {"data": [{...}]} + String responseJson = """ + { + "data": { + "hello": 2.1259406 + } + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer)); + var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool)); + var action = actionCreator.create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("hello world")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + thrownException.getMessage(), + is("Failed to parse object: expecting token of type [START_ARRAY] but found [START_OBJECT]") + ); + + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), is(1)); + assertThat(requestMap.get("input"), instanceOf(List.class)); + var inputList = (List) requestMap.get("input"); + assertThat(inputList, contains("hello world")); + } + } + + public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJsonContentTooLarge = """ + { + "error": "Input validation error: `input` must have less than 512 tokens. Given: 571", + "error_type": "Validation" + } + """; + + String responseJson = """ + { + "data": [ + { + "hello": 2.1259406, + "greet": 1.7073475 + } + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(413).setBody(responseJsonContentTooLarge)); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer)); + var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool)); + var action = actionCreator.create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("hello world")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + assertThat( + result.asMap(), + is( + SparseEmbeddingResultsTests.buildExpectationSparseEmbeddings( + List.of( + new SparseEmbeddingResultsTests.EmbeddingExpectation(Map.of("hello", 2.1259406f, "greet", 1.7073475f), true) + ) + ) + ) + ); + + assertThat(webServer.requests(), hasSize(2)); + { + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + + var initialRequestAsMap = entityAsMap(webServer.requests().get(0).getBody()); + var initialInputs = initialRequestAsMap.get("input"); + assertThat(initialInputs, is(List.of("hello world"))); + } + { + assertNull(webServer.requests().get(1).getUri().getQuery()); + assertThat(webServer.requests().get(1).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + + var truncatedRequest = entityAsMap(webServer.requests().get(1).getBody()); + var truncatedInputs = truncatedRequest.get("input"); + assertThat(truncatedInputs, is(List.of("hello"))); + } + } + } + + public void testExecute_TruncatesInputBeforeSending() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "data": [ + { + "hello": 2.1259406, + "greet": 1.7073475 + } + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + // truncated to 1 token = 3 characters + var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer), 1); + var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool)); + var action = actionCreator.create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("hello world")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + assertThat( + result.asMap(), + is( + SparseEmbeddingResultsTests.buildExpectationSparseEmbeddings( + List.of( + new SparseEmbeddingResultsTests.EmbeddingExpectation(Map.of("hello", 2.1259406f, "greet", 1.7073475f), true) + ) + ) + ) + ); + + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + + var initialRequestAsMap = entityAsMap(webServer.requests().get(0).getBody()); + var initialInputs = initialRequestAsMap.get("input"); + assertThat(initialInputs, is(List.of("hel"))); + } + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/elastic/ElasticInferenceServiceResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/elastic/ElasticInferenceServiceResponseHandlerTests.java new file mode 100644 index 0000000000000..ea30ee29ff5a8 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/elastic/ElasticInferenceServiceResponseHandlerTests.java @@ -0,0 +1,116 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.elastic; + +import org.apache.http.Header; +import org.apache.http.HeaderElement; +import org.apache.http.HttpResponse; +import org.apache.http.StatusLine; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ContentTooLargeException; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.hamcrest.MatcherAssert; + +import java.nio.charset.StandardCharsets; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.core.Is.is; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class ElasticInferenceServiceResponseHandlerTests extends ESTestCase { + + public void testCheckForFailureStatusCode_DoesNotThrowFor200() { + callCheckForFailureStatusCode(200, "id"); + } + + public void testCheckForFailureStatusCode_ThrowsFor400() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(400, "id")); + assertFalse(exception.shouldRetry()); + MatcherAssert.assertThat( + exception.getCause().getMessage(), + containsString("Received a bad request status code for request from inference entity id [id] status [400]") + ); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST)); + } + + public void testCheckForFailureStatusCode_ThrowsFor405() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(405, "id")); + assertFalse(exception.shouldRetry()); + MatcherAssert.assertThat( + exception.getCause().getMessage(), + containsString("Received a method not allowed status code for request from inference entity id [id] status [405]") + ); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.METHOD_NOT_ALLOWED)); + } + + public void testCheckForFailureStatusCode_ThrowsFor413() { + var exception = expectThrows(ContentTooLargeException.class, () -> callCheckForFailureStatusCode(413, "id")); + assertTrue(exception.shouldRetry()); + MatcherAssert.assertThat( + exception.getCause().getMessage(), + containsString("Received a content too large status code for request from inference entity id [id] status [413]") + ); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.REQUEST_ENTITY_TOO_LARGE)); + } + + public void testCheckForFailureStatusCode_ThrowsFor500_WithShouldRetryTrue() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(500, "id")); + assertTrue(exception.shouldRetry()); + MatcherAssert.assertThat( + exception.getCause().getMessage(), + containsString("Received a server error status code for request from inference entity id [id] status [500]") + ); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST)); + } + + public void testCheckForFailureStatusCode_ThrowsFor402() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(402, "id")); + assertFalse(exception.shouldRetry()); + MatcherAssert.assertThat( + exception.getCause().getMessage(), + containsString("Received an unsuccessful status code for request from inference entity id [id] status [402]") + ); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.PAYMENT_REQUIRED)); + } + + private static void callCheckForFailureStatusCode(int statusCode, String modelId) { + callCheckForFailureStatusCode(statusCode, null, modelId); + } + + private static void callCheckForFailureStatusCode(int statusCode, @Nullable String errorMessage, String modelId) { + var statusLine = mock(StatusLine.class); + when(statusLine.getStatusCode()).thenReturn(statusCode); + + var httpResponse = mock(HttpResponse.class); + when(httpResponse.getStatusLine()).thenReturn(statusLine); + var header = mock(Header.class); + when(header.getElements()).thenReturn(new HeaderElement[] {}); + when(httpResponse.getFirstHeader(anyString())).thenReturn(header); + + String responseJson = Strings.format(""" + { + "message": "%s" + } + """, errorMessage); + + var mockRequest = mock(Request.class); + when(mockRequest.getInferenceEntityId()).thenReturn(modelId); + var httpResult = new HttpResult(httpResponse, errorMessage == null ? new byte[] {} : responseJson.getBytes(StandardCharsets.UTF_8)); + var handler = new ElasticInferenceServiceResponseHandler("", (request, result) -> null); + + handler.checkForFailureStatusCode(mockRequest, httpResult); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequestEntityTests.java new file mode 100644 index 0000000000000..7b10cf600275c --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequestEntityTests.java @@ -0,0 +1,50 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.elastic; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString; + +public class ElasticInferenceServiceSparseEmbeddingsRequestEntityTests extends ESTestCase { + + public void testToXContent_SingleInput() throws IOException { + var entity = new ElasticInferenceServiceSparseEmbeddingsRequestEntity(List.of("abc")); + String xContentString = xContentEntityToString(entity); + assertThat(xContentString, equalToIgnoringWhitespaceInJsonString(""" + { + "input": ["abc"] + }""")); + } + + public void testToXContent_MultipleInputs() throws IOException { + var entity = new ElasticInferenceServiceSparseEmbeddingsRequestEntity(List.of("abc", "def")); + String xContentString = xContentEntityToString(entity); + assertThat(xContentString, equalToIgnoringWhitespaceInJsonString(""" + { + "input": [ + "abc", + "def" + ] + } + """)); + } + + private String xContentEntityToString(ElasticInferenceServiceSparseEmbeddingsRequestEntity entity) throws IOException { + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + return Strings.toString(builder); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequestTests.java new file mode 100644 index 0000000000000..0f2c859fb62d5 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequestTests.java @@ -0,0 +1,81 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.elastic; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.common.TruncatorTests; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModelTests; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class ElasticInferenceServiceSparseEmbeddingsRequestTests extends ESTestCase { + + public void testCreateHttpRequest() throws IOException { + var url = "http://eis-gateway.com"; + var input = "input"; + + var request = createRequest(url, input); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap.size(), equalTo(1)); + assertThat(requestMap.get("input"), is(List.of(input))); + } + + public void testTruncate_ReducesInputTextSizeByHalf() throws IOException { + var url = "http://eis-gateway.com"; + var input = "abcd"; + + var request = createRequest(url, input); + var truncatedRequest = request.truncate(); + + var httpRequest = truncatedRequest.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(1)); + assertThat(requestMap.get("input"), is(List.of("ab"))); + } + + public void testIsTruncated_ReturnsTrue() { + var url = "http://eis-gateway.com"; + var input = "abcd"; + + var request = createRequest(url, input); + assertFalse(request.getTruncationInfo()[0]); + + var truncatedRequest = request.truncate(); + assertTrue(truncatedRequest.getTruncationInfo()[0]); + } + + public ElasticInferenceServiceSparseEmbeddingsRequest createRequest(String url, String input) { + var embeddingsModel = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(url); + + return new ElasticInferenceServiceSparseEmbeddingsRequest( + TruncatorTests.createTruncator(), + new Truncator.TruncationResult(List.of(input), new boolean[] { false }), + embeddingsModel + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceErrorResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceErrorResponseEntityTests.java new file mode 100644 index 0000000000000..4da0518084828 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceErrorResponseEntityTests.java @@ -0,0 +1,61 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.elastic; + +import org.apache.http.HttpResponse; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.HttpResult; + +import java.nio.charset.StandardCharsets; + +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class ElasticInferenceServiceErrorResponseEntityTests extends ESTestCase { + + public void testFromResponse() { + String responseJson = """ + { + "error": "error" + } + """; + + ElasticInferenceServiceErrorResponseEntity errorResponseEntity = ElasticInferenceServiceErrorResponseEntity.fromResponse( + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertNotNull(errorResponseEntity); + assertThat(errorResponseEntity.getErrorMessage(), is("error")); + } + + public void testFromResponse_NoErrorMessagePresent() { + String responseJson = """ + { + "not_error": "error" + } + """; + + ElasticInferenceServiceErrorResponseEntity errorResponseEntity = ElasticInferenceServiceErrorResponseEntity.fromResponse( + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertNull(errorResponseEntity); + } + + public void testFromResponse_InvalidJson() { + String invalidResponseJson = """ + { + """; + + ElasticInferenceServiceErrorResponseEntity errorResponseEntity = ElasticInferenceServiceErrorResponseEntity.fromResponse( + new HttpResult(mock(HttpResponse.class), invalidResponseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertNull(errorResponseEntity); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceSparseEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceSparseEmbeddingsResponseEntityTests.java new file mode 100644 index 0000000000000..6e1994260ca0c --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceSparseEmbeddingsResponseEntityTests.java @@ -0,0 +1,241 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.elastic; + +import org.apache.http.HttpResponse; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.core.ml.search.WeightedToken; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.nio.charset.StandardCharsets; +import java.util.List; + +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class ElasticInferenceServiceSparseEmbeddingsResponseEntityTests extends ESTestCase { + + public void testSparseEmbeddingsResponse_SingleEmbeddingInData_NoMeta_NoTruncation() throws Exception { + String responseJson = """ + { + "data": [ + { + "a": 1.23, + "is": 4.56, + "it": 7.89 + } + ] + } + """; + + SparseEmbeddingResults parsedResults = ElasticInferenceServiceSparseEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat( + parsedResults.embeddings(), + is( + List.of( + SparseEmbeddingResults.Embedding.create( + List.of(new WeightedToken("a", 1.23F), new WeightedToken("is", 4.56F), new WeightedToken("it", 7.89F)), + false + ) + ) + ) + ); + } + + public void testSparseEmbeddingsResponse_MultipleEmbeddingsInData_NoMeta_NoTruncation() throws Exception { + String responseJson = """ + { + "data": [ + { + "a": 1.23, + "is": 4.56, + "it": 7.89 + }, + { + "b": 1.23, + "it": 4.56, + "is": 7.89 + } + ] + } + """; + + SparseEmbeddingResults parsedResults = ElasticInferenceServiceSparseEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat( + parsedResults.embeddings(), + is( + List.of( + SparseEmbeddingResults.Embedding.create( + List.of(new WeightedToken("a", 1.23F), new WeightedToken("is", 4.56F), new WeightedToken("it", 7.89F)), + false + ), + SparseEmbeddingResults.Embedding.create( + List.of(new WeightedToken("b", 1.23F), new WeightedToken("it", 4.56F), new WeightedToken("is", 7.89F)), + false + ) + ) + ) + ); + } + + public void testSparseEmbeddingsResponse_SingleEmbeddingInData_NoMeta_Truncated() throws Exception { + String responseJson = """ + { + "data": [ + { + "a": 1.23, + "is": 4.56, + "it": 7.89 + } + ] + } + """; + + var request = mock(Request.class); + when(request.getTruncationInfo()).thenReturn(new boolean[] { true }); + + SparseEmbeddingResults parsedResults = ElasticInferenceServiceSparseEmbeddingsResponseEntity.fromResponse( + request, + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat( + parsedResults.embeddings(), + is( + List.of( + SparseEmbeddingResults.Embedding.create( + List.of(new WeightedToken("a", 1.23F), new WeightedToken("is", 4.56F), new WeightedToken("it", 7.89F)), + true + ) + ) + ) + ); + } + + public void testSparseEmbeddingsResponse_MultipleEmbeddingsInData_NoMeta_Truncated() throws Exception { + String responseJson = """ + { + "data": [ + { + "a": 1.23, + "is": 4.56, + "it": 7.89 + }, + { + "b": 1.23, + "it": 4.56, + "is": 7.89 + } + ] + } + """; + + var request = mock(Request.class); + when(request.getTruncationInfo()).thenReturn(new boolean[] { true, false }); + + SparseEmbeddingResults parsedResults = ElasticInferenceServiceSparseEmbeddingsResponseEntity.fromResponse( + request, + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat( + parsedResults.embeddings(), + is( + List.of( + SparseEmbeddingResults.Embedding.create( + List.of(new WeightedToken("a", 1.23F), new WeightedToken("is", 4.56F), new WeightedToken("it", 7.89F)), + true + ), + SparseEmbeddingResults.Embedding.create( + List.of(new WeightedToken("b", 1.23F), new WeightedToken("it", 4.56F), new WeightedToken("is", 7.89F)), + false + ) + ) + ) + ); + } + + public void testSparseEmbeddingsResponse_SingleEmbeddingInData_IgnoresMetaBeforeData_NoTruncation() throws Exception { + String responseJson = """ + { + "meta": { + "processing_latency": 1.23 + }, + "data": [ + { + "a": 1.23, + "is": 4.56, + "it": 7.89 + } + ] + } + """; + + SparseEmbeddingResults parsedResults = ElasticInferenceServiceSparseEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat( + parsedResults.embeddings(), + is( + List.of( + SparseEmbeddingResults.Embedding.create( + List.of(new WeightedToken("a", 1.23F), new WeightedToken("is", 4.56F), new WeightedToken("it", 7.89F)), + false + ) + ) + ) + ); + } + + public void testSparseEmbeddingsResponse_SingleEmbeddingInData_IgnoresMetaAfterData_NoTruncation() throws Exception { + String responseJson = """ + { + "data": [ + { + "a": 1.23, + "is": 4.56, + "it": 7.89 + } + ], + "meta": { + "processing_latency": 1.23 + } + } + """; + + SparseEmbeddingResults parsedResults = ElasticInferenceServiceSparseEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat( + parsedResults.embeddings(), + is( + List.of( + SparseEmbeddingResults.Embedding.create( + List.of(new WeightedToken("a", 1.23F), new WeightedToken("is", 4.56F), new WeightedToken("it", 7.89F)), + false + ) + ) + ) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java new file mode 100644 index 0000000000000..af13ce7944685 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java @@ -0,0 +1,33 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic; + +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.elser.ElserModels; + +public class ElasticInferenceServiceSparseEmbeddingsModelTests extends ESTestCase { + + public static ElasticInferenceServiceSparseEmbeddingsModel createModel(String url) { + return createModel(url, null); + } + + public static ElasticInferenceServiceSparseEmbeddingsModel createModel(String url, Integer maxInputTokens) { + return new ElasticInferenceServiceSparseEmbeddingsModel( + "id", + TaskType.SPARSE_EMBEDDING, + "service", + new ElasticInferenceServiceSparseEmbeddingsServiceSettings(ElserModels.ELSER_V2_MODEL, maxInputTokens, null), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + new ElasticInferenceServiceComponents(url) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettingsTests.java new file mode 100644 index 0000000000000..a2b36cf9abdd5 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettingsTests.java @@ -0,0 +1,90 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.elser.ElserModels; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.elser.ElserModelsTests.randomElserModel; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class ElasticInferenceServiceSparseEmbeddingsServiceSettingsTests extends AbstractWireSerializingTestCase< + ElasticInferenceServiceSparseEmbeddingsServiceSettings> { + + @Override + protected Writeable.Reader instanceReader() { + return ElasticInferenceServiceSparseEmbeddingsServiceSettings::new; + } + + @Override + protected ElasticInferenceServiceSparseEmbeddingsServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected ElasticInferenceServiceSparseEmbeddingsServiceSettings mutateInstance( + ElasticInferenceServiceSparseEmbeddingsServiceSettings instance + ) throws IOException { + return randomValueOtherThan(instance, ElasticInferenceServiceSparseEmbeddingsServiceSettingsTests::createRandom); + } + + public void testFromMap() { + var modelId = ElserModels.ELSER_V2_MODEL; + + var serviceSettings = ElasticInferenceServiceSparseEmbeddingsServiceSettings.fromMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + ConfigurationParseContext.REQUEST + ); + + assertThat(serviceSettings, is(new ElasticInferenceServiceSparseEmbeddingsServiceSettings(modelId, null, null))); + } + + public void testFromMap_InvalidElserModelId() { + var invalidModelId = "invalid"; + + ValidationException validationException = expectThrows( + ValidationException.class, + () -> ElasticInferenceServiceSparseEmbeddingsServiceSettings.fromMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, invalidModelId)), + ConfigurationParseContext.REQUEST + ) + ); + + assertThat(validationException.getMessage(), containsString(Strings.format("unknown ELSER model id [%s]", invalidModelId))); + } + + public void testToXContent_WritesAlLFields() throws IOException { + var modelId = ElserModels.ELSER_V1_MODEL; + var maxInputTokens = 10; + var serviceSettings = new ElasticInferenceServiceSparseEmbeddingsServiceSettings(modelId, maxInputTokens, null); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(Strings.format(""" + {"model_id":"%s","max_input_tokens":%d,"rate_limit":{"requests_per_minute":1000}}""", modelId, maxInputTokens))); + } + + public static ElasticInferenceServiceSparseEmbeddingsServiceSettings createRandom() { + return new ElasticInferenceServiceSparseEmbeddingsServiceSettings(randomElserModel(), randomNonNegativeInt(), null); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java new file mode 100644 index 0000000000000..62416f05800c6 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -0,0 +1,523 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.elser.ElserModels; +import org.elasticsearch.xpack.inference.services.openai.OpenAiService; +import org.hamcrest.MatcherAssert; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; +import static org.elasticsearch.xpack.inference.Utils.getModelListenerForException; +import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; +import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +public class ElasticInferenceServiceTests extends ESTestCase { + + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testParseRequestConfig_CreatesASparseEmbeddingsModel() throws IOException { + try (var service = createServiceWithMockSender()) { + ActionListener modelListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(ElasticInferenceServiceSparseEmbeddingsModel.class)); + + var completionModel = (ElasticInferenceServiceSparseEmbeddingsModel) model; + assertThat(completionModel.getServiceSettings().modelId(), is(ElserModels.ELSER_V2_MODEL)); + + }, e -> fail("Model parsing should have succeeded, but failed: " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.SPARSE_EMBEDDING, + getRequestConfigMap(Map.of(ServiceFields.MODEL_ID, ElserModels.ELSER_V2_MODEL), Map.of(), Map.of()), + Set.of(), + modelListener + ); + } + } + + public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException { + try (var service = createServiceWithMockSender()) { + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "The [elastic] service does not support task type [completion]" + ); + + service.parseRequestConfig( + "id", + TaskType.COMPLETION, + getRequestConfigMap(Map.of(ServiceFields.MODEL_ID, ElserModels.ELSER_V2_MODEL), Map.of(), Map.of()), + Set.of(), + failureListener + ); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException { + try (var service = createServiceWithMockSender()) { + var config = getRequestConfigMap(Map.of(ServiceFields.MODEL_ID, ElserModels.ELSER_V2_MODEL), Map.of(), Map.of()); + config.put("extra_key", "value"); + + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "Model configuration contains settings [{extra_key=value}] unknown to the [elastic] service" + ); + service.parseRequestConfig("id", TaskType.SPARSE_EMBEDDING, config, Set.of(), failureListener); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException { + try (var service = createServiceWithMockSender()) { + Map serviceSettings = new HashMap<>(Map.of(ServiceFields.MODEL_ID, ElserModels.ELSER_V2_MODEL)); + serviceSettings.put("extra_key", "value"); + + var config = getRequestConfigMap(serviceSettings, Map.of(), Map.of()); + + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "Model configuration contains settings [{extra_key=value}] unknown to the [elastic] service" + ); + service.parseRequestConfig("id", TaskType.SPARSE_EMBEDDING, config, Set.of(), failureListener); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException { + try (var service = createServiceWithMockSender()) { + var taskSettings = Map.of("extra_key", (Object) "value"); + + var config = getRequestConfigMap(Map.of(ServiceFields.MODEL_ID, ElserModels.ELSER_V2_MODEL), taskSettings, Map.of()); + + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "Model configuration contains settings [{extra_key=value}] unknown to the [elastic] service" + ); + service.parseRequestConfig("id", TaskType.SPARSE_EMBEDDING, config, Set.of(), failureListener); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap() throws IOException { + try (var service = createServiceWithMockSender()) { + var secretSettings = Map.of("extra_key", (Object) "value"); + + var config = getRequestConfigMap(Map.of(ServiceFields.MODEL_ID, ElserModels.ELSER_V2_MODEL), Map.of(), secretSettings); + + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "Model configuration contains settings [{extra_key=value}] unknown to the [elastic] service" + ); + service.parseRequestConfig("id", TaskType.SPARSE_EMBEDDING, config, Set.of(), failureListener); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesASparseEmbeddingModel() throws IOException { + try (var service = createServiceWithMockSender()) { + var persistedConfig = getPersistedConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, ElserModels.ELSER_V2_MODEL)), + Map.of(), + Map.of() + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.SPARSE_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(ElasticInferenceServiceSparseEmbeddingsModel.class)); + + var sparseEmbeddingsModel = (ElasticInferenceServiceSparseEmbeddingsModel) model; + assertThat(sparseEmbeddingsModel.getServiceSettings().modelId(), is(ElserModels.ELSER_V2_MODEL)); + assertThat(sparseEmbeddingsModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertThat(sparseEmbeddingsModel.getSecretSettings(), is(EmptySecretSettings.INSTANCE)); + } + } + + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + try (var service = createServiceWithMockSender()) { + var persistedConfig = getPersistedConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, ElserModels.ELSER_V2_MODEL)), + Map.of(), + Map.of() + ); + persistedConfig.config().put("extra_key", "value"); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.SPARSE_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(ElasticInferenceServiceSparseEmbeddingsModel.class)); + + var completionModel = (ElasticInferenceServiceSparseEmbeddingsModel) model; + assertThat(completionModel.getServiceSettings().modelId(), is(ElserModels.ELSER_V2_MODEL)); + assertThat(completionModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertThat(completionModel.getSecretSettings(), is(EmptySecretSettings.INSTANCE)); + } + } + + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + try (var service = createServiceWithMockSender()) { + Map serviceSettingsMap = new HashMap<>(Map.of(ServiceFields.MODEL_ID, ElserModels.ELSER_V2_MODEL)); + serviceSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap(serviceSettingsMap, Map.of(), Map.of()); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.SPARSE_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(ElasticInferenceServiceSparseEmbeddingsModel.class)); + + var completionModel = (ElasticInferenceServiceSparseEmbeddingsModel) model; + assertThat(completionModel.getServiceSettings().modelId(), is(ElserModels.ELSER_V2_MODEL)); + assertThat(completionModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertThat(completionModel.getSecretSettings(), is(EmptySecretSettings.INSTANCE)); + } + } + + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { + try (var service = createServiceWithMockSender()) { + var taskSettings = Map.of("extra_key", (Object) "value"); + + var persistedConfig = getPersistedConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, ElserModels.ELSER_V2_MODEL)), + taskSettings, + Map.of() + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.SPARSE_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(ElasticInferenceServiceSparseEmbeddingsModel.class)); + + var completionModel = (ElasticInferenceServiceSparseEmbeddingsModel) model; + assertThat(completionModel.getServiceSettings().modelId(), is(ElserModels.ELSER_V2_MODEL)); + assertThat(completionModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertThat(completionModel.getSecretSettings(), is(EmptySecretSettings.INSTANCE)); + } + } + + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { + try (var service = createServiceWithMockSender()) { + var secretSettingsMap = Map.of("extra_key", (Object) "value"); + + var persistedConfig = getPersistedConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, ElserModels.ELSER_V2_MODEL)), + Map.of(), + secretSettingsMap + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.SPARSE_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(ElasticInferenceServiceSparseEmbeddingsModel.class)); + + var completionModel = (ElasticInferenceServiceSparseEmbeddingsModel) model; + assertThat(completionModel.getServiceSettings().modelId(), is(ElserModels.ELSER_V2_MODEL)); + assertThat(completionModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertThat(completionModel.getSecretSettings(), is(EmptySecretSettings.INSTANCE)); + } + } + + public void testCheckModelConfig_ReturnsNewModelReference() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer)); + PlainActionFuture listener = new PlainActionFuture<>(); + service.checkModelConfig(model, listener); + + var returnedModel = listener.actionGet(TIMEOUT); + assertThat(returnedModel, is(ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer)))); + } + } + + public void testInfer_ThrowsErrorWhenModelIsNotAValidModel() throws IOException { + var sender = mock(Sender.class); + + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + + var mockModel = getInvalidModel("model_id", "service_name"); + + try ( + var service = new ElasticInferenceService( + factory, + createWithEmptySettings(threadPool), + new ElasticInferenceServiceComponents(null) + ) + ) { + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + mockModel, + null, + List.of(""), + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + MatcherAssert.assertThat( + thrownException.getMessage(), + is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.") + ); + + verify(factory, times(1)).createSender(); + verify(sender, times(1)).start(); + } + + verify(sender, times(1)).close(); + verifyNoMoreInteractions(factory); + verifyNoMoreInteractions(sender); + } + + public void testInfer_ThrowsWhenQueryIsPresent() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try ( + var service = new ElasticInferenceService( + senderFactory, + createWithEmptySettings(threadPool), + new ElasticInferenceServiceComponents(getUrl(webServer)) + ) + ) { + var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer)); + + PlainActionFuture listener = new PlainActionFuture<>(); + UnsupportedOperationException exception = expectThrows( + UnsupportedOperationException.class, + () -> service.infer( + model, + "should throw", + List.of("abc"), + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ) + ); + + assertThat(exception.getMessage(), is("Query input not supported for Elastic Inference Service")); + } + } + + public void testInfer_SendsEmbeddingsRequest() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + var eisGatewayUrl = getUrl(webServer); + + try ( + var service = new ElasticInferenceService( + senderFactory, + createWithEmptySettings(threadPool), + new ElasticInferenceServiceComponents(eisGatewayUrl) + ) + ) { + String responseJson = """ + { + "data": [ + { + "hello": 2.1259406, + "greet": 1.7073475 + } + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(eisGatewayUrl); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + List.of("input text"), + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + var result = listener.actionGet(TIMEOUT); + + assertThat( + result.asMap(), + Matchers.is( + SparseEmbeddingResultsTests.buildExpectationSparseEmbeddings( + List.of( + new SparseEmbeddingResultsTests.EmbeddingExpectation(Map.of("hello", 2.1259406f, "greet", 1.7073475f), false) + ) + ) + ) + ); + var request = webServer.requests().get(0); + assertNull(request.getUri().getQuery()); + assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), Matchers.equalTo(XContentType.JSON.mediaType())); + + var requestMap = entityAsMap(request.getBody()); + assertThat(requestMap, is(Map.of("input", List.of("input text")))); + } + } + + public void testChunkedInfer_PassesThrough() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + var eisGatewayUrl = getUrl(webServer); + + try ( + var service = new ElasticInferenceService( + senderFactory, + createWithEmptySettings(threadPool), + new ElasticInferenceServiceComponents(eisGatewayUrl) + ) + ) { + String responseJson = """ + { + "data": [ + { + "hello": 2.1259406, + "greet": 1.7073475 + } + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(eisGatewayUrl); + PlainActionFuture> listener = new PlainActionFuture<>(); + service.chunkedInfer( + model, + List.of("input text"), + new HashMap<>(), + InputType.INGEST, + new ChunkingOptions(null, null), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var results = listener.actionGet(TIMEOUT); + MatcherAssert.assertThat( + results.get(0).asMap(), + Matchers.is( + Map.of( + InferenceChunkedSparseEmbeddingResults.FIELD_NAME, + List.of( + Map.of( + ChunkedNlpInferenceResults.TEXT, + "input text", + ChunkedNlpInferenceResults.INFERENCE, + Map.of("hello", 2.1259406f, "greet", 1.7073475f) + ) + ) + ) + ) + ); + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + MatcherAssert.assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap, is(Map.of("input", List.of("input text")))); + } + } + + private ElasticInferenceService createServiceWithMockSender() { + return new ElasticInferenceService( + mock(HttpRequestSender.Factory.class), + createWithEmptySettings(threadPool), + new ElasticInferenceServiceComponents(null) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettingsTests.java index ec753b9bec887..ffbdf1a5a6178 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettingsTests.java @@ -16,12 +16,12 @@ import java.io.IOException; import java.util.HashSet; +import static org.elasticsearch.xpack.inference.services.elser.ElserModelsTests.randomElserModel; + public class ElserInternalServiceSettingsTests extends AbstractWireSerializingTestCase { public static ElserInternalServiceSettings createRandom() { - return new ElserInternalServiceSettings( - ElasticsearchInternalServiceSettingsTests.validInstance(randomFrom(ElserInternalService.VALID_ELSER_MODEL_IDS)) - ); + return new ElserInternalServiceSettings(ElasticsearchInternalServiceSettingsTests.validInstance(randomElserModel())); } public void testBwcWrite() throws IOException { @@ -67,7 +67,7 @@ protected ElserInternalServiceSettings mutateInstance(ElserInternalServiceSettin ) ); case 2 -> { - var versions = new HashSet<>(ElserInternalService.VALID_ELSER_MODEL_IDS); + var versions = new HashSet<>(ElserModels.VALID_ELSER_MODEL_IDS); versions.remove(instance.modelId()); yield new ElserInternalServiceSettings( new ElasticsearchInternalServiceSettings( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceTests.java index f950e515a5336..85add1a0090c8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceTests.java @@ -171,7 +171,7 @@ public void testParseConfigStrictWithNoTaskSettings() { "foo", TaskType.SPARSE_EMBEDDING, ElserInternalService.NAME, - new ElserInternalServiceSettings(1, 4, ElserInternalService.ELSER_V2_MODEL, null), + new ElserInternalServiceSettings(1, 4, ElserModels.ELSER_V2_MODEL, null), ElserMlNodeTaskSettings.DEFAULT ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserModelsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserModelsTests.java new file mode 100644 index 0000000000000..f56e941dcc8c0 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserModelsTests.java @@ -0,0 +1,33 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elser; + +import org.elasticsearch.test.ESTestCase; + +public class ElserModelsTests extends ESTestCase { + + public static String randomElserModel() { + return randomFrom(ElserModels.VALID_ELSER_MODEL_IDS); + } + + public void testIsValidModel() { + assertTrue(ElserModels.isValidModel(randomElserModel())); + } + + public void testIsValidEisModel() { + assertTrue(ElserModels.isValidEisModel(ElserModels.ELSER_V2_MODEL)); + } + + public void testIsInvalidModel() { + assertFalse(ElserModels.isValidModel("invalid")); + } + + public void testIsInvalidEisModel() { + assertFalse(ElserModels.isValidEisModel(ElserModels.ELSER_V2_MODEL_LINUX_X86)); + } +}