From 61bd21c0c12e04d235f064e22efdc7e0566413c0 Mon Sep 17 00:00:00 2001 From: David Turner Date: Mon, 27 May 2024 08:08:10 +0100 Subject: [PATCH 01/29] Fix shutdown-metadata-related flakiness in `SnapshotStressTestsIT` (#109049) We must not mark the master for shutdown since this triggers a master failover, and we must keep all the blocks in place until the mark/unmark sequence is complete. Also adds some logging around shutdown metadata manipulation. Found while investigating #108907 although this doesn't fix that issue. --- .../snapshots/SnapshotStressTestsIT.java | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/server/src/internalClusterTest/java/org/elasticsearch/snapshots/SnapshotStressTestsIT.java b/server/src/internalClusterTest/java/org/elasticsearch/snapshots/SnapshotStressTestsIT.java index f70b86fd4fba2..3f43da20fec3e 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/snapshots/SnapshotStressTestsIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/snapshots/SnapshotStressTestsIT.java @@ -1189,6 +1189,12 @@ private void startNodeShutdownMarker() { final var clusterService = cluster.getCurrentMasterNodeInstance(ClusterService.class); + if (node.nodeName.equals(clusterService.localNode().getName())) { + return; + } + + logger.info("--> marking [{}] for removal", node); + SubscribableListener .newForked( @@ -1252,12 +1258,15 @@ public void onFailure(Exception e) { @Override public void clusterStateProcessed(ClusterState initialState, ClusterState newState) { l.onResponse(null); + logger.info("--> unmarked [{}] for removal", node); } } ) ) - .addListener(mustSucceed(ignored -> startNodeShutdownMarker())); + .addListener( + ActionListener.releaseAfter(mustSucceed(ignored -> startNodeShutdownMarker()), localReleasables.transfer()) + ); rerun = false; } finally { From 1b4a057bf752fd01d3e824912b413900dd34561b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20Fred=C3=A9n?= <109296772+jfreden@users.noreply.github.com> Date: Mon, 27 May 2024 09:33:03 +0200 Subject: [PATCH 02/29] Add await security migration to SecuritySingleNodeTestCase (#109024) This addresses: https://github.com/elastic/elasticsearch/issues/109010 https://github.com/elastic/elasticsearch/issues/109023 https://github.com/elastic/elasticsearch/issues/109011 Where the `teardown` in `ESSingleNodeTestCase` is failing intermittently because of `assertThat(searchService.getActiveContexts(), equalTo(0));`. This is because the migration hasn't closed its search context yet. 
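The `awaitSecurityMigration` method below is an instance of a general listener-then-recheck idiom for blocking on a cluster-state condition. A minimal sketch of that idiom, assuming only the `safeAwait(CountDownLatch)` helper from `ESTestCase` (the helper name `awaitClusterStateCondition` is hypothetical):

    import java.util.concurrent.CountDownLatch;
    import java.util.function.Predicate;

    import org.elasticsearch.cluster.ClusterState;
    import org.elasticsearch.cluster.service.ClusterService;

    // Block until `done` holds for the cluster state. The listener is registered
    // BEFORE the current state is sampled, so a state update that lands between
    // the two steps cannot be missed.
    private static void awaitClusterStateCondition(ClusterService clusterService, Predicate<ClusterState> done) {
        final CountDownLatch latch = new CountDownLatch(1);
        clusterService.addListener(event -> {
            if (done.test(event.state())) {
                latch.countDown();
            }
        });
        if (done.test(clusterService.state())) {
            latch.countDown(); // condition already holds; no further update may ever arrive
        }
        safeAwait(latch); // ESTestCase helper: awaits with a timeout, failing the test on interrupt
    }

With `done = state -> getTaskWithId(state, TASK_NAME) == null` this reduces to the `awaitSecurityMigration` method added in this patch.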
--- .../test/SecuritySingleNodeTestCase.java | 24 +++++++++++++++++++ ...ervedRealmElasticAutoconfigIntegTests.java | 5 +++- 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/test/SecuritySingleNodeTestCase.java b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/test/SecuritySingleNodeTestCase.java index 77ae4ab838585..16a3ea53eeeac 100644 --- a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/test/SecuritySingleNodeTestCase.java +++ b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/test/SecuritySingleNodeTestCase.java @@ -14,6 +14,8 @@ import org.elasticsearch.client.RestClient; import org.elasticsearch.client.RestClientBuilder; import org.elasticsearch.client.internal.Client; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.network.NetworkAddress; import org.elasticsearch.common.settings.MockSecureSettings; import org.elasticsearch.common.settings.SecureString; @@ -36,10 +38,13 @@ import java.util.Collection; import java.util.Collections; import java.util.Map; +import java.util.concurrent.CountDownLatch; import java.util.stream.Collectors; +import static org.elasticsearch.persistent.PersistentTasksCustomMetadata.getTaskWithId; import static org.elasticsearch.test.SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING; import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue; +import static org.elasticsearch.xpack.core.security.support.SecurityMigrationTaskParams.TASK_NAME; import static org.hamcrest.Matchers.hasItem; /** @@ -77,12 +82,31 @@ public static void destroyDefaultSettings() { @Override public void tearDown() throws Exception { + awaitSecurityMigration(); super.tearDown(); if (resetNodeAfterTest()) { tearDownRestClient(); } } + private boolean isMigrationComplete(ClusterState state) { + return getTaskWithId(state, TASK_NAME) == null; + } + + protected void awaitSecurityMigration() { + final var latch = new CountDownLatch(1); + ClusterService clusterService = getInstanceFromNode(ClusterService.class); + clusterService.addListener((event) -> { + if (isMigrationComplete(event.state())) { + latch.countDown(); + } + }); + if (isMigrationComplete(clusterService.state())) { + latch.countDown(); + } + safeAwait(latch); + } + private static void tearDownRestClient() { if (restClient != null) { IOUtils.closeWhileHandlingException(restClient); diff --git a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/esnative/ReservedRealmElasticAutoconfigIntegTests.java b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/esnative/ReservedRealmElasticAutoconfigIntegTests.java index 1cd3cfa3a5870..c04630d457959 100644 --- a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/esnative/ReservedRealmElasticAutoconfigIntegTests.java +++ b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/esnative/ReservedRealmElasticAutoconfigIntegTests.java @@ -79,6 +79,8 @@ public void testAutoconfigFailedPasswordPromotion() { if (getIndexResponse.getIndices().length > 0) { assertThat(getIndexResponse.getIndices().length, is(1)); assertThat(getIndexResponse.getIndices()[0], is(TestRestrictedIndices.INTERNAL_SECURITY_MAIN_INDEX_7)); + // Security migration needs to finish before 
deleting the index + awaitSecurityMigration(); DeleteIndexRequest deleteIndexRequest = new DeleteIndexRequest(getIndexResponse.getIndices()); assertAcked(client().admin().indices().delete(deleteIndexRequest).actionGet()); } @@ -137,6 +139,8 @@ public void testAutoconfigSucceedsAfterPromotionFailure() throws Exception { putUserRequest.passwordHash(Hasher.PBKDF2.hash(password)); putUserRequest.roles(Strings.EMPTY_ARRAY); client().execute(PutUserAction.INSTANCE, putUserRequest).get(); + // Security migration needs to finish before making the cluster read only + awaitSecurityMigration(); // but then make the cluster read-only ClusterUpdateSettingsRequest updateSettingsRequest = new ClusterUpdateSettingsRequest(); @@ -160,7 +164,6 @@ public void testAutoconfigSucceedsAfterPromotionFailure() throws Exception { restRequest.setOptions(options); ResponseException exception = expectThrows(ResponseException.class, () -> getRestClient().performRequest(restRequest)); assertThat(exception.getResponse().getStatusLine().getStatusCode(), is(RestStatus.SERVICE_UNAVAILABLE.getStatus())); - // clear cluster-wide write block updateSettingsRequest = new ClusterUpdateSettingsRequest(); updateSettingsRequest.transientSettings( From b1d54e624576bc34c70a0484772a88a576c05b0e Mon Sep 17 00:00:00 2001 From: Kostas Krikellas <131142368+kkrik-es@users.noreply.github.com> Date: Mon, 27 May 2024 10:47:22 +0300 Subject: [PATCH 03/29] Store source for fields in objects with `dynamic` override (#108911) Covers setting `dynamic` to `false` or `runtime`. Related to https://github.com/elastic/elasticsearch/issues/106825, https://github.com/elastic/elasticsearch/pull/108417 --- docs/changelog/108911.yaml | 5 + .../indices.create/20_synthetic_source.yml | 183 +++++++++++++++ .../index/mapper/DocumentParser.java | 141 +++++++++--- .../index/mapper/DocumentParserContext.java | 8 +- .../mapper/IgnoredSourceFieldMapperTests.java | 216 ++++++++++++++++++ 5 files changed, 514 insertions(+), 39 deletions(-) create mode 100644 docs/changelog/108911.yaml diff --git a/docs/changelog/108911.yaml b/docs/changelog/108911.yaml new file mode 100644 index 0000000000000..8832e01f7426e --- /dev/null +++ b/docs/changelog/108911.yaml @@ -0,0 +1,5 @@ +pr: 108911 +summary: Store source for fields in objects with `dynamic` override +area: Mapping +type: enhancement +issues: [] diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/indices.create/20_synthetic_source.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/indices.create/20_synthetic_source.yml index 4b4900dce6504..a763d6e457490 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/indices.create/20_synthetic_source.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/indices.create/20_synthetic_source.yml @@ -653,6 +653,189 @@ nested object array next to other fields: - match: { hits.hits.0._source.f: "2000" } +--- +object with dynamic override: + - requires: + cluster_features: ["mapper.track_ignored_source"] + reason: requires tracking ignored source + + - do: + indices.create: + index: test + body: + mappings: + _source: + mode: synthetic + properties: + path_no: + dynamic: false + properties: + name: + type: keyword + path_runtime: + dynamic: runtime + properties: + name: + type: keyword + + - do: + bulk: + index: test + refresh: true + body: + - '{ "create": { } }' + - '{ "name": "a", "path_no": { "some_int": 10, "to.a.very.deeply.nested.field": "A", "name": "foo" }, "path_runtime": { "some_int": 20, 
"to.a.very.deeply.nested.field": "B", "name": "bar" } }' + + - do: + search: + index: test + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._source.name: a } + - match: { hits.hits.0._source.path_no.name: foo } + - match: { hits.hits.0._source.path_no.some_int: 10 } + - match: { hits.hits.0._source.path_no.to.a.very.deeply.nested.field: A } + - match: { hits.hits.0._source.path_runtime.name: bar } + - match: { hits.hits.0._source.path_runtime.some_int: 20 } + - match: { hits.hits.0._source.path_runtime.to.a.very.deeply.nested.field: B } + + +--- +subobject with dynamic override: + - requires: + cluster_features: ["mapper.track_ignored_source"] + reason: requires tracking ignored source + + - do: + indices.create: + index: test + body: + mappings: + _source: + mode: synthetic + properties: + path: + properties: + to_no: + dynamic: false + properties: + name: + type: keyword + to_runtime: + dynamic: runtime + properties: + name: + type: keyword + + - do: + bulk: + index: test + refresh: true + body: + - '{ "create": { } }' + - '{ "name": "a", "path": { "some_int": 10, "to_no": { "some_text": "A", "name": "foo" }, "to_runtime": { "some_text": "B", "name": "bar" } } }' + + - do: + search: + index: test + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._source.name: a } + - match: { hits.hits.0._source.path.some_int: 10 } + - match: { hits.hits.0._source.path.to_no.name: foo } + - match: { hits.hits.0._source.path.to_no.some_text: A } + - match: { hits.hits.0._source.path.to_runtime.name: bar } + - match: { hits.hits.0._source.path.to_runtime.some_text: B } + + +--- +object array in object with dynamic override: + - requires: + cluster_features: ["mapper.track_ignored_source"] + reason: requires tracking ignored source + + - do: + indices.create: + index: test + body: + mappings: + _source: + mode: synthetic + properties: + path_no: + dynamic: false + properties: + name: + type: keyword + path_runtime: + dynamic: runtime + properties: + name: + type: keyword + + - do: + bulk: + index: test + refresh: true + body: + - '{ "create": { } }' + - '{ "path_no": [ { "some_int": 10 }, {"name": "foo"} ], "path_runtime": [ { "some_int": 20 }, {"name": "bar"} ], "name": "baz" }' + + - do: + search: + index: test + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._source.name: baz } + - match: { hits.hits.0._source.path_no.0.some_int: 10 } + - match: { hits.hits.0._source.path_no.1.name: foo } + - match: { hits.hits.0._source.path_runtime.0.some_int: 20 } + - match: { hits.hits.0._source.path_runtime.1.name: bar } + + +--- +value array in object with dynamic override: + - requires: + cluster_features: ["mapper.track_ignored_source"] + reason: requires tracking ignored source + + - do: + indices.create: + index: test + body: + mappings: + _source: + mode: synthetic + properties: + path_no: + dynamic: false + properties: + name: + type: keyword + path_runtime: + dynamic: runtime + properties: + name: + type: keyword + + - do: + bulk: + index: test + refresh: true + body: + - '{ "create": { } }' + - '{ "path_no": { "values": [ "A", "B" ] }, "path_runtime": { "values": [ "C", "D" ] }, "name": "foo" }' + + - do: + search: + index: test + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._source.name: foo } + - match: { hits.hits.0._source.path_no.values: [ A, B] } + - match: { hits.hits.0._source.path_runtime.values: [ C, D] } + + --- nested object: - requires: diff --git a/server/src/main/java/org/elasticsearch/index/mapper/DocumentParser.java 
b/server/src/main/java/org/elasticsearch/index/mapper/DocumentParser.java index dbc098b6ce2ae..a89a89472a678 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/DocumentParser.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/DocumentParser.java @@ -127,7 +127,7 @@ private static void internalParseDocument(MetadataFieldMapper[] metadataFieldsMa if (context.root().isEnabled() == false) { // entire type is disabled - if (context.mappingLookup().isSourceSynthetic()) { + if (context.canAddIgnoredField()) { context.addIgnoredField( new IgnoredSourceFieldMapper.NameValue( MapperService.SINGLE_MAPPING_NAME, @@ -263,7 +263,7 @@ static void parseObjectOrNested(DocumentParserContext context) throws IOExceptio String currentFieldName = parser.currentName(); if (context.parent().isEnabled() == false) { // entire type is disabled - if (context.mappingLookup().isSourceSynthetic()) { + if (context.canAddIgnoredField()) { context.addIgnoredField( new IgnoredSourceFieldMapper.NameValue( context.parent().fullPath(), @@ -431,8 +431,7 @@ static void parseObjectOrField(DocumentParserContext context, Mapper mapper) thr parseObjectOrNested(context.createFlattenContext(currentFieldName)); context.path().add(currentFieldName); } else { - if (context.mappingLookup().isSourceSynthetic() - && fieldMapper.syntheticSourceMode() == FieldMapper.SyntheticSourceMode.FALLBACK) { + if (context.canAddIgnoredField() && fieldMapper.syntheticSourceMode() == FieldMapper.SyntheticSourceMode.FALLBACK) { Tuple contextWithSourceToStore = XContentDataHelper.cloneSubContext(context); context.addIgnoredField( @@ -521,14 +520,37 @@ private static void parseObjectDynamic(DocumentParserContext context, String cur ensureNotStrict(context, currentFieldName); if (context.dynamic() == ObjectMapper.Dynamic.FALSE) { failIfMatchesRoutingPath(context, currentFieldName); - // not dynamic, read everything up to end object - context.parser().skipChildren(); + if (context.canAddIgnoredField()) { + // read everything up to end object and store it + context.addIgnoredField( + IgnoredSourceFieldMapper.NameValue.fromContext( + context, + context.path().pathAsText(currentFieldName), + XContentDataHelper.encodeToken(context.parser()) + ) + ); + } else { + // not dynamic, read everything up to end object + context.parser().skipChildren(); + } } else { Mapper dynamicObjectMapper; if (context.dynamic() == ObjectMapper.Dynamic.RUNTIME) { // with dynamic:runtime all leaf fields will be runtime fields unless explicitly mapped, // hence we don't dynamically create empty objects under properties, but rather carry around an artificial object mapper dynamicObjectMapper = new NoOpObjectMapper(currentFieldName, context.path().pathAsText(currentFieldName)); + if (context.canAddIgnoredField()) { + // Clone the DocumentParserContext to parse its subtree twice. 
+ Tuple tuple = XContentDataHelper.cloneSubContext(context); + context.addIgnoredField( + IgnoredSourceFieldMapper.NameValue.fromContext( + context, + context.path().pathAsText(currentFieldName), + XContentDataHelper.encodeXContentBuilder(tuple.v2()) + ) + ); + context = tuple.v1(); + } } else { dynamicObjectMapper = DynamicFieldsBuilder.createDynamicObjectMapper(context, currentFieldName); } @@ -592,23 +614,33 @@ private static void parseArray(DocumentParserContext context, String lastFieldNa private static void parseArrayDynamic(DocumentParserContext context, String currentFieldName) throws IOException { ensureNotStrict(context, currentFieldName); if (context.dynamic() == ObjectMapper.Dynamic.FALSE) { - context.parser().skipChildren(); - } else { - Mapper objectMapperFromTemplate = DynamicFieldsBuilder.createObjectMapperFromTemplate(context, currentFieldName); - if (objectMapperFromTemplate == null) { - parseNonDynamicArray(context, objectMapperFromTemplate, currentFieldName, currentFieldName); + if (context.canAddIgnoredField()) { + context.addIgnoredField( + IgnoredSourceFieldMapper.NameValue.fromContext( + context, + context.path().pathAsText(currentFieldName), + XContentDataHelper.encodeToken(context.parser()) + ) + ); } else { - if (parsesArrayValue(objectMapperFromTemplate)) { - if (context.addDynamicMapper(objectMapperFromTemplate) == false) { - context.parser().skipChildren(); - return; - } - context.path().add(currentFieldName); - parseObjectOrField(context, objectMapperFromTemplate); - context.path().remove(); - } else { - parseNonDynamicArray(context, objectMapperFromTemplate, currentFieldName, currentFieldName); + context.parser().skipChildren(); + } + return; + } + Mapper objectMapperFromTemplate = DynamicFieldsBuilder.createObjectMapperFromTemplate(context, currentFieldName); + if (objectMapperFromTemplate == null) { + parseNonDynamicArray(context, objectMapperFromTemplate, currentFieldName, currentFieldName); + } else { + if (parsesArrayValue(objectMapperFromTemplate)) { + if (context.addDynamicMapper(objectMapperFromTemplate) == false) { + context.parser().skipChildren(); + return; } + context.path().add(currentFieldName); + parseObjectOrField(context, objectMapperFromTemplate); + context.path().remove(); + } else { + parseNonDynamicArray(context, objectMapperFromTemplate, currentFieldName, currentFieldName); } } } @@ -624,12 +656,14 @@ private static void parseNonDynamicArray( String arrayFieldName ) throws IOException { // Check if we need to record the array source. This only applies to synthetic source. - if (context.mappingLookup().isSourceSynthetic() && context.getClonedSource() == false) { - boolean storeArraySourceEnabled = mapper instanceof ObjectMapper objectMapper && objectMapper.storeArraySource(); + if (context.canAddIgnoredField()) { + boolean objectRequiresStoringSource = mapper instanceof ObjectMapper objectMapper + && (objectMapper.storeArraySource() || objectMapper.dynamic == ObjectMapper.Dynamic.RUNTIME); boolean fieldWithFallbackSyntheticSource = mapper instanceof FieldMapper fieldMapper && fieldMapper.syntheticSourceMode() == FieldMapper.SyntheticSourceMode.FALLBACK; - if (storeArraySourceEnabled || fieldWithFallbackSyntheticSource || mapper instanceof NestedObjectMapper) { - // Clone the DocumentParserContext to parse its subtree twice. 
+ boolean nestedObject = mapper instanceof NestedObjectMapper; + boolean dynamicRuntimeContext = context.dynamic() == ObjectMapper.Dynamic.RUNTIME; + if (objectRequiresStoringSource || fieldWithFallbackSyntheticSource || nestedObject || dynamicRuntimeContext) { Tuple tuple = XContentDataHelper.cloneSubContext(context); context.addIgnoredField( IgnoredSourceFieldMapper.NameValue.fromContext( @@ -639,16 +673,17 @@ private static void parseNonDynamicArray( ) ); context = tuple.v1(); - } else if (mapper instanceof ObjectMapper objectMapper && objectMapper.isEnabled() == false) { - context.addIgnoredField( - IgnoredSourceFieldMapper.NameValue.fromContext( - context, - context.path().pathAsText(arrayFieldName), - XContentDataHelper.encodeToken(context.parser()) - ) - ); - return; - } + } else if (mapper instanceof ObjectMapper objectMapper + && (objectMapper.isEnabled() == false || objectMapper.dynamic == ObjectMapper.Dynamic.FALSE)) { + context.addIgnoredField( + IgnoredSourceFieldMapper.NameValue.fromContext( + context, + context.path().pathAsText(arrayFieldName), + XContentDataHelper.encodeToken(context.parser()) + ) + ); + return; + } } XContentParser parser = context.parser(); @@ -746,12 +781,30 @@ private static void parseNullValue(DocumentParserContext context, String lastFie } } - private static void parseDynamicValue(final DocumentParserContext context, String currentFieldName) throws IOException { + private static void parseDynamicValue(DocumentParserContext context, String currentFieldName) throws IOException { ensureNotStrict(context, currentFieldName); if (context.dynamic() == ObjectMapper.Dynamic.FALSE) { failIfMatchesRoutingPath(context, currentFieldName); + if (context.canAddIgnoredField()) { + context.addIgnoredField( + IgnoredSourceFieldMapper.NameValue.fromContext( + context, + context.path().pathAsText(currentFieldName), + XContentDataHelper.encodeToken(context.parser()) + ) + ); + } return; } + if (context.dynamic() == ObjectMapper.Dynamic.RUNTIME && context.canAddIgnoredField()) { + context.addIgnoredField( + IgnoredSourceFieldMapper.NameValue.fromContext( + context, + context.path().pathAsText(currentFieldName), + XContentDataHelper.encodeToken(context.parser()) + ) + ); + } if (context.dynamic().getDynamicFieldsBuilder().createDynamicFieldFromValue(context, currentFieldName) == false) { failIfMatchesRoutingPath(context, currentFieldName); } @@ -846,7 +899,21 @@ public Query termQuery(Object value, SearchExecutionContext context) { @Override protected void parseCreateField(DocumentParserContext context) { - // field defined as runtime field, don't index anything + // Run-time fields are mapped to this mapper, so it needs to handle storing values for use in synthetic source. + // #parseValue calls this method once the run-time field is created. 
+ if (context.dynamic() == ObjectMapper.Dynamic.RUNTIME && context.canAddIgnoredField()) { + try { + context.addIgnoredField( + IgnoredSourceFieldMapper.NameValue.fromContext( + context, + context.path().pathAsText(context.parser().currentName()), + XContentDataHelper.encodeToken(context.parser()) + ) + ); + } catch (IOException e) { + throw new IllegalArgumentException("failed to parse run-time field under [" + context.path().pathAsText("") + " ]", e); + } + } } @Override diff --git a/server/src/main/java/org/elasticsearch/index/mapper/DocumentParserContext.java b/server/src/main/java/org/elasticsearch/index/mapper/DocumentParserContext.java index 628610474eeb8..fe1ad85d6a7c1 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/DocumentParserContext.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/DocumentParserContext.java @@ -269,7 +269,7 @@ public final Collection getIgnoredFields() { * Add the given ignored values to the corresponding list. */ public final void addIgnoredField(IgnoredSourceFieldMapper.NameValue values) { - if (clonedSource == false) { + if (canAddIgnoredField()) { // Skip tracking the source for this field twice, it's already tracked for the entire parsing subcontext. ignoredFieldValues.add(values); } @@ -327,6 +327,10 @@ final boolean getClonedSource() { return clonedSource; } + final boolean canAddIgnoredField() { + return mappingLookup.isSourceSynthetic() && clonedSource == false; + } + /** * Description on the document being parsed used in error messages. Not * called unless there is an error. @@ -384,7 +388,7 @@ public final boolean addDynamicMapper(Mapper mapper) { int additionalFieldsToAdd = getNewFieldsSize() + mapperSize; if (indexSettings().isIgnoreDynamicFieldsBeyondLimit()) { if (mappingLookup.exceedsLimit(indexSettings().getMappingTotalFieldsLimit(), additionalFieldsToAdd)) { - if (mappingLookup.isSourceSynthetic()) { + if (canAddIgnoredField()) { try { addIgnoredField( IgnoredSourceFieldMapper.NameValue.fromContext( diff --git a/server/src/test/java/org/elasticsearch/index/mapper/IgnoredSourceFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/IgnoredSourceFieldMapperTests.java index e9e07350a129a..71a0e001dc72a 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/IgnoredSourceFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/IgnoredSourceFieldMapperTests.java @@ -776,4 +776,220 @@ public void testNestedObjectIncludeInRoot() throws IOException { assertEquals(""" {"path":{"foo":"A","bar":"B"}}""", syntheticSource); } + + public void testNoDynamicObjectSingleField() throws IOException { + String name = randomAlphaOfLength(20); + DocumentMapper documentMapper = createMapperService(syntheticSourceMapping(b -> { + b.startObject("path").field("type", "object").field("dynamic", "false").endObject(); + })).documentMapper(); + var syntheticSource = syntheticSource(documentMapper, b -> { + b.startObject("path"); + { + b.field("name", name); + } + b.endObject(); + }); + assertEquals(String.format(Locale.ROOT, """ + {"path":{"name":"%s"}}""", name), syntheticSource); + } + + public void testNoDynamicObjectManyFields() throws IOException { + boolean booleanValue = randomBoolean(); + int intValue = randomInt(); + String stringValue = randomAlphaOfLength(20); + + DocumentMapper documentMapper = createMapperService(syntheticSourceMapping(b -> { + b.startObject("boolean_value").field("type", "boolean").endObject(); + b.startObject("path").field("type", "object").field("dynamic", 
"false"); + { + b.startObject("properties"); + { + b.startObject("string_value").field("type", "keyword").endObject(); + } + b.endObject(); + } + b.endObject(); + })).documentMapper(); + + var syntheticSource = syntheticSource(documentMapper, b -> { + b.field("boolean_value", booleanValue); + b.startObject("path"); + { + b.field("int_value", intValue); + b.startObject("to"); + { + b.startObject("some"); + { + b.startObject("deeply"); + { + b.startObject("nested"); + b.field("string_value", stringValue); + b.endObject(); + } + b.endObject(); + } + b.endObject(); + } + b.field("string_value", stringValue); + b.endObject(); + } + b.endObject(); + }); + + assertEquals(String.format(Locale.ROOT, """ + {"boolean_value":%s,"path":{"int_value":%s,"to":{"some":{"deeply":{"nested":{"string_value":"%s"}}},\ + "string_value":"%s"}}}""", booleanValue, intValue, stringValue, stringValue), syntheticSource); + } + + public void testNoDynamicObjectSimpleArray() throws IOException { + DocumentMapper documentMapper = createMapperService(syntheticSourceMapping(b -> { + b.startObject("path").field("type", "object").field("dynamic", "false").endObject(); + })).documentMapper(); + var syntheticSource = syntheticSource(documentMapper, b -> { + b.startArray("path"); + { + b.startObject().field("name", "foo").endObject(); + b.startObject().field("name", "bar").endObject(); + } + b.endArray(); + }); + assertEquals(""" + {"path":[{"name":"foo"},{"name":"bar"}]}""", syntheticSource); + } + + public void testNoDynamicObjectSimpleValueArray() throws IOException { + DocumentMapper documentMapper = createMapperService(syntheticSourceMapping(b -> { + b.startObject("path").field("type", "object").field("dynamic", "false").endObject(); + })).documentMapper(); + var syntheticSource = syntheticSource( + documentMapper, + b -> { b.startObject("path").array("name", "A", "B", "C", "D").endObject(); } + ); + assertEquals(""" + {"path":{"name":["A","B","C","D"]}}""", syntheticSource); + } + + public void testNoDynamicObjectNestedArray() throws IOException { + DocumentMapper documentMapper = createMapperService(syntheticSourceMapping(b -> { + b.startObject("path").field("type", "object").field("dynamic", "false").endObject(); + })).documentMapper(); + var syntheticSource = syntheticSource(documentMapper, b -> { + b.startArray("path"); + { + b.startObject().startObject("to").field("foo", "A").field("bar", "B").endObject().endObject(); + b.startObject().startObject("to").field("foo", "C").field("bar", "D").endObject().endObject(); + } + b.endArray(); + }); + assertEquals(""" + {"path":[{"to":{"foo":"A","bar":"B"}},{"to":{"foo":"C","bar":"D"}}]}""", syntheticSource); + } + + public void testRuntimeDynamicObjectSingleField() throws IOException { + String name = randomAlphaOfLength(20); + DocumentMapper documentMapper = createMapperService(syntheticSourceMapping(b -> { + b.startObject("path").field("type", "object").field("dynamic", "runtime").endObject(); + })).documentMapper(); + var syntheticSource = syntheticSource(documentMapper, b -> { + b.startObject("path"); + { + b.field("name", name); + } + b.endObject(); + }); + assertEquals(String.format(Locale.ROOT, """ + {"path":{"name":"%s"}}""", name), syntheticSource); + } + + public void testRuntimeDynamicObjectManyFields() throws IOException { + boolean booleanValue = randomBoolean(); + int intValue = randomInt(); + String stringValue = randomAlphaOfLength(20); + + DocumentMapper documentMapper = createMapperService(syntheticSourceMapping(b -> { + 
b.startObject("boolean_value").field("type", "boolean").endObject(); + b.startObject("path").field("type", "object").field("dynamic", "runtime"); + { + b.startObject("properties"); + { + b.startObject("string_value").field("type", "keyword").endObject(); + } + b.endObject(); + } + b.endObject(); + })).documentMapper(); + + var syntheticSource = syntheticSource(documentMapper, b -> { + b.field("boolean_value", booleanValue); + b.startObject("path"); + { + b.field("int_value", intValue); + b.startObject("to"); + { + b.startObject("some"); + { + b.startObject("deeply"); + { + b.startObject("nested"); + b.field("string_value", stringValue); + b.endObject(); + } + b.endObject(); + } + b.endObject(); + } + b.field("string_value", stringValue); + b.endObject(); + } + b.endObject(); + }); + + assertEquals(String.format(Locale.ROOT, """ + {"boolean_value":%s,"path":{"int_value":%s,"to":{"some":{"deeply":{"nested":{"string_value":"%s"}}},\ + "string_value":"%s"}}}""", booleanValue, intValue, stringValue, stringValue), syntheticSource); + } + + public void testRuntimeDynamicObjectSimpleArray() throws IOException { + DocumentMapper documentMapper = createMapperService(syntheticSourceMapping(b -> { + b.startObject("path").field("type", "object").field("dynamic", "runtime").endObject(); + })).documentMapper(); + var syntheticSource = syntheticSource(documentMapper, b -> { + b.startArray("path"); + { + b.startObject().field("name", "foo").endObject(); + b.startObject().field("name", "bar").endObject(); + } + b.endArray(); + }); + assertEquals(""" + {"path":[{"name":"foo"},{"name":"bar"}]}""", syntheticSource); + } + + public void testRuntimeDynamicObjectSimpleValueArray() throws IOException { + DocumentMapper documentMapper = createMapperService(syntheticSourceMapping(b -> { + b.startObject("path").field("type", "object").field("dynamic", "runtime").endObject(); + })).documentMapper(); + var syntheticSource = syntheticSource( + documentMapper, + b -> { b.startObject("path").array("name", "A", "B", "C", "D").endObject(); } + ); + assertEquals(""" + {"path":{"name":["A","B","C","D"]}}""", syntheticSource); + } + + public void testRuntimeDynamicObjectNestedArray() throws IOException { + DocumentMapper documentMapper = createMapperService(syntheticSourceMapping(b -> { + b.startObject("path").field("type", "object").field("dynamic", "runtime").endObject(); + })).documentMapper(); + var syntheticSource = syntheticSource(documentMapper, b -> { + b.startArray("path"); + { + b.startObject().startObject("to").field("foo", "A").field("bar", "B").endObject().endObject(); + b.startObject().startObject("to").field("foo", "C").field("bar", "D").endObject().endObject(); + } + b.endArray(); + }); + assertEquals(""" + {"path":[{"to":{"foo":"A","bar":"B"}},{"to":{"foo":"C","bar":"D"}}]}""", syntheticSource); + } } From 843579652c4f1b1842ef44ffbb496995f00b5b71 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Mon, 27 May 2024 08:51:21 +0100 Subject: [PATCH 04/29] [ML] Refactor the Embedding classes (#107516) Extract and generic-ised the Embedding classes for reuse and so the EmbeddingRequestChunker can handle different types of embedding. 
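For orientation before the diff, here is a condensed sketch of the generic shape this refactor introduces. It mirrors, but deliberately simplifies, the `Embedding`, `EmbeddingValues` and `EmbeddingChunk` classes added below (the real classes also implement `Writeable` and `ToXContentObject`), and is what lets a single `EmbeddingRequestChunker` carry byte, float and sparse embeddings through one code path:

    // Condensed sketch only; serialisation and XContent rendering are omitted.
    interface EmbeddingValues {
        int size(); // number of values in the embedding
    }

    // A single embedding, parameterised over its concrete value representation
    // (byte array, float array, or weighted-token list).
    abstract class Embedding<T extends EmbeddingValues> {
        protected final T embedding;

        protected Embedding(T embedding) {
            this.embedding = embedding;
        }

        public T getEmbedding() {
            return embedding;
        }

        public int getSize() {
            return embedding.size();
        }
    }

    // A chunk of matched text paired with its embedding; generic, so chunking
    // code never needs to know which concrete embedding type it carries.
    class EmbeddingChunk<T extends EmbeddingValues> {
        final String matchedText;
        final Embedding<T> embedding;

        EmbeddingChunk(String matchedText, Embedding<T> embedding) {
            this.matchedText = matchedText;
            this.embedding = embedding;
        }
    }

The concrete `ByteEmbedding`, `FloatEmbedding` and `SparseEmbedding` types in the diff each supply an `EmbeddingValues` implementation plus the wire and XContent serialisation on top of this shape.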
--- .../core/inference/results/ByteEmbedding.java | 98 ++++++++++ .../ChunkedSparseEmbeddingResults.java | 18 +- .../ChunkedTextEmbeddingByteResults.java | 80 +------- .../ChunkedTextEmbeddingFloatResults.java | 76 +------- .../results/ChunkedTextEmbeddingResults.java | 2 +- .../core/inference/results/Embedding.java | 72 +++++++ .../inference/results/EmbeddingChunk.java | 79 ++++++++ .../core/inference/results/EmbeddingInt.java | 12 -- .../inference/results/EmbeddingResults.java | 39 ++++ .../inference/results/FloatEmbedding.java | 98 ++++++++++ .../inference/results/SparseEmbedding.java | 170 +++++++++++++++++ .../results/SparseEmbeddingResults.java | 68 +------ .../core/inference/results/TextEmbedding.java | 2 +- .../results/TextEmbeddingByteResults.java | 104 ++--------- .../results/TextEmbeddingResults.java | 101 ++-------- .../inference/results/TextEmbeddingUtils.java | 2 +- .../MlInferenceNamedXContentProvider.java | 4 + .../results/TextEmbeddingByteResults.java | 100 ++++++++++ .../TextEmbeddingByteResultsTests.java | 61 ++++++ .../TestDenseInferenceServiceExtension.java | 5 +- .../TestSparseInferenceServiceExtension.java | 9 +- .../common/EmbeddingRequestChunker.java | 141 ++++++++++++-- .../CohereEmbeddingsResponseEntity.java | 13 +- .../HuggingFaceElserResponseEntity.java | 14 +- .../HuggingFaceEmbeddingsResponseEntity.java | 10 +- .../OpenAiEmbeddingsResponseEntity.java | 7 +- .../common/EmbeddingRequestChunkerTests.java | 175 +++++++++++++++--- ...AiStudioEmbeddingsResponseEntityTests.java | 3 +- .../CohereEmbeddingsResponseEntityTests.java | 38 +--- ...gingFaceEmbeddingsResponseEntityTests.java | 33 +--- .../OpenAiEmbeddingsResponseEntityTests.java | 23 +-- .../rest/RestInferenceActionTests.java | 5 +- .../ChunkedSparseEmbeddingResultsTests.java | 5 +- .../ChunkedTextEmbeddingByteResultsTests.java | 21 +-- ...ChunkedTextEmbeddingFloatResultsTests.java | 8 +- .../ChunkedTextEmbeddingResultsTests.java | 5 +- .../results/SparseEmbeddingResultsTests.java | 46 ++--- .../TextEmbeddingByteResultsTests.java | 32 ++-- .../results/TextEmbeddingResultsTests.java | 32 ++-- .../inference/services/ServiceUtilsTests.java | 4 +- .../services/cohere/CohereServiceTests.java | 6 +- .../services/openai/OpenAiServiceTests.java | 6 +- 42 files changed, 1213 insertions(+), 614 deletions(-) create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ByteEmbedding.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/Embedding.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingChunk.java delete mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingInt.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingResults.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/FloatEmbedding.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbedding.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingByteResults.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingByteResultsTests.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ByteEmbedding.java 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ByteEmbedding.java new file mode 100644 index 0000000000000..9fb017b75abb5 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ByteEmbedding.java @@ -0,0 +1,98 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.results; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +public class ByteEmbedding extends Embedding { + + public static ByteEmbedding of(List embedding) { + byte[] embeddingBytes = new byte[embedding.size()]; + for (int i = 0; i < embedding.size(); i++) { + embeddingBytes[i] = embedding.get(i); + } + return new ByteEmbedding(embeddingBytes); + } + + /** + * Wrapper so around a primitive byte array so that it can be + * treated as a generic + */ + public static class ByteArrayWrapper implements EmbeddingValues { + + final byte[] bytes; + + public ByteArrayWrapper(byte[] bytes) { + this.bytes = bytes; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return valuesToXContent(EMBEDDING, builder, params); + } + + @Override + public int size() { + return bytes.length; + } + + @Override + public XContentBuilder valuesToXContent(String fieldName, XContentBuilder builder, Params params) throws IOException { + builder.startArray(fieldName); + for (var value : bytes) { + builder.value(value); + } + builder.endArray(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ByteArrayWrapper that = (ByteArrayWrapper) o; + return Arrays.equals(bytes, that.bytes); + } + + @Override + public int hashCode() { + return Arrays.hashCode(bytes); + } + } + + public ByteEmbedding(StreamInput in) throws IOException { + this(in.readByteArray()); + } + + public ByteEmbedding(byte[] embedding) { + super(new ByteArrayWrapper(embedding)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeByteArray(embedding.bytes); + } + + public byte[] bytes() { + return embedding.bytes; + } + + public float[] toFloatArray() { + float[] floatArray = new float[embedding.bytes.length]; + for (int i = 0; i < embedding.bytes.length; i++) { + floatArray[i] = ((Byte) embedding.bytes[i]).floatValue(); + } + return floatArray; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedSparseEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedSparseEmbeddingResults.java index c91d0dc6fd538..5b57a8da9d37c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedSparseEmbeddingResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedSparseEmbeddingResults.java @@ -49,7 +49,7 @@ public static List of(List inputs, Spars return results; } - public static ChunkedSparseEmbeddingResults of(String input, SparseEmbeddingResults.Embedding embedding) { + 
public static ChunkedSparseEmbeddingResults of(String input, SparseEmbedding embedding) { var weightedTokens = embedding.tokens() .stream() .map(weightedToken -> new WeightedToken(weightedToken.token(), weightedToken.weight())) @@ -58,6 +58,22 @@ public static ChunkedSparseEmbeddingResults of(String input, SparseEmbeddingResu return new ChunkedSparseEmbeddingResults(List.of(new ChunkedTextExpansionResults.ChunkedResult(input, weightedTokens))); } + public static ChunkedSparseEmbeddingResults of(List> embeddingChunks) { + var ch = embeddingChunks.stream() + .map( + chunk -> new ChunkedTextExpansionResults.ChunkedResult( + chunk.matchedText(), + chunk.embedding().embedding.tokens() + .stream() + .map(weightedToken -> new WeightedToken(weightedToken.token(), weightedToken.weight())) + .toList() + ) + ) + .toList(); + + return new ChunkedSparseEmbeddingResults(ch); + } + private final List chunkedResults; public ChunkedSparseEmbeddingResults(List chunks) { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedTextEmbeddingByteResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedTextEmbeddingByteResults.java index 86ea70ddd62dd..b88b502a76195 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedTextEmbeddingByteResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedTextEmbeddingByteResults.java @@ -7,26 +7,22 @@ package org.elasticsearch.xpack.core.inference.results; -import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults; import java.io.IOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import java.util.Map; -import java.util.Objects; import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils.validateInputSizeAgainstEmbeddings; -public record ChunkedTextEmbeddingByteResults(List chunks, boolean isTruncated) implements ChunkedInferenceServiceResults { +public record ChunkedTextEmbeddingByteResults(List> chunks, boolean isTruncated) + implements + ChunkedInferenceServiceResults { public static final String NAME = "chunked_text_embedding_service_byte_results"; public static final String FIELD_NAME = "text_embedding_byte_chunk"; @@ -41,23 +37,22 @@ public static List of(List inputs, TextE var results = new ArrayList(inputs.size()); for (int i = 0; i < inputs.size(); i++) { - results.add(of(inputs.get(i), textEmbeddings.embeddings().get(i).values())); + results.add(of(inputs.get(i), textEmbeddings.embeddings().get(i).getEmbedding().bytes)); } return results; } public static ChunkedTextEmbeddingByteResults of(String input, byte[] byteEmbeddings) { - return new ChunkedTextEmbeddingByteResults(List.of(new EmbeddingChunk(input, byteEmbeddings)), false); + return new ChunkedTextEmbeddingByteResults(List.of(new EmbeddingChunk<>(input, new ByteEmbedding(byteEmbeddings))), false); } public ChunkedTextEmbeddingByteResults(StreamInput in) throws IOException { - 
this(in.readCollectionAsList(EmbeddingChunk::new), in.readBoolean()); + this(in.readCollectionAsList(in1 -> new EmbeddingChunk<>(in1.readString(), new ByteEmbedding(in1))), in.readBoolean()); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - // TODO add isTruncated flag builder.startArray(FIELD_NAME); for (var embedding : chunks) { embedding.toXContent(builder, params); @@ -92,68 +87,7 @@ public String getWriteableName() { return NAME; } - public List getChunks() { + public List> getChunks() { return chunks; } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - ChunkedTextEmbeddingByteResults that = (ChunkedTextEmbeddingByteResults) o; - return isTruncated == that.isTruncated && Objects.equals(chunks, that.chunks); - } - - @Override - public int hashCode() { - return Objects.hash(chunks, isTruncated); - } - - public record EmbeddingChunk(String matchedText, byte[] embedding) implements Writeable, ToXContentObject { - - public EmbeddingChunk(StreamInput in) throws IOException { - this(in.readString(), in.readByteArray()); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(matchedText); - out.writeByteArray(embedding); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(ChunkedNlpInferenceResults.TEXT, matchedText); - - builder.startArray(ChunkedNlpInferenceResults.INFERENCE); - for (byte value : embedding) { - builder.value(value); - } - builder.endArray(); - - builder.endObject(); - return builder; - } - - @Override - public String toString() { - return Strings.toString(this); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - EmbeddingChunk that = (EmbeddingChunk) o; - return Objects.equals(matchedText, that.matchedText) && Arrays.equals(embedding, that.embedding); - } - - @Override - public int hashCode() { - int result = Objects.hash(matchedText); - result = 31 * result + Arrays.hashCode(embedding); - return result; - } - } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedTextEmbeddingFloatResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedTextEmbeddingFloatResults.java index 4fcd5a53fc287..65370277aba34 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedTextEmbeddingFloatResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedTextEmbeddingFloatResults.java @@ -7,35 +7,30 @@ package org.elasticsearch.xpack.core.inference.results; -import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.xcontent.ToXContent; -import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults; import java.io.IOException; -import java.util.Arrays; import java.util.List; import java.util.Map; -import 
java.util.Objects; -public record ChunkedTextEmbeddingFloatResults(List chunks) implements ChunkedInferenceServiceResults { +public record ChunkedTextEmbeddingFloatResults(List> chunks) + implements + ChunkedInferenceServiceResults { public static final String NAME = "chunked_text_embedding_service_float_results"; public static final String FIELD_NAME = "text_embedding_float_chunk"; public ChunkedTextEmbeddingFloatResults(StreamInput in) throws IOException { - this(in.readCollectionAsList(EmbeddingChunk::new)); + this(in.readCollectionAsList(in1 -> new EmbeddingChunk<>(in1.readString(), new FloatEmbedding(in1)))); } @Override public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { - // TODO add isTruncated flag builder.startArray(FIELD_NAME); for (var embedding : chunks) { embedding.toXContent(builder, params); @@ -69,69 +64,8 @@ public String getWriteableName() { return NAME; } - public List getChunks() { + public List> getChunks() { return chunks; } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - ChunkedTextEmbeddingFloatResults that = (ChunkedTextEmbeddingFloatResults) o; - return Objects.equals(chunks, that.chunks); - } - - @Override - public int hashCode() { - return Objects.hash(chunks); - } - - public record EmbeddingChunk(String matchedText, float[] embedding) implements Writeable, ToXContentObject { - - public EmbeddingChunk(StreamInput in) throws IOException { - this(in.readString(), in.readFloatArray()); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(matchedText); - out.writeFloatArray(embedding); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(ChunkedNlpInferenceResults.TEXT, matchedText); - - builder.startArray(ChunkedNlpInferenceResults.INFERENCE); - for (float value : embedding) { - builder.value(value); - } - builder.endArray(); - - builder.endObject(); - return builder; - } - - @Override - public String toString() { - return Strings.toString(this); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - EmbeddingChunk that = (EmbeddingChunk) o; - return Objects.equals(matchedText, that.matchedText) && Arrays.equals(embedding, that.embedding); - } - - @Override - public int hashCode() { - int result = Objects.hash(matchedText); - result = 31 * result + Arrays.hashCode(embedding); - return result; - } - } - } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedTextEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedTextEmbeddingResults.java index f09eafc1591dd..5553230cb7e9f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedTextEmbeddingResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedTextEmbeddingResults.java @@ -44,7 +44,7 @@ public static List of(List inputs, TextE var results = new ArrayList(inputs.size()); for (int i = 0; i < inputs.size(); i++) { - results.add(ChunkedTextEmbeddingResults.of(inputs.get(i), textEmbeddings.embeddings().get(i).values())); + results.add(ChunkedTextEmbeddingResults.of(inputs.get(i), textEmbeddings.embeddings().get(i).getEmbedding().floats)); } 
return results; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/Embedding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/Embedding.java new file mode 100644 index 0000000000000..75eeb6ff01f73 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/Embedding.java @@ -0,0 +1,72 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.results; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xcontent.ToXContentFragment; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +public abstract class Embedding implements Writeable, ToXContentObject { + public static final String EMBEDDING = "embedding"; + + public interface EmbeddingValues extends ToXContentFragment { + int size(); + + XContentBuilder valuesToXContent(String fieldName, XContentBuilder builder, Params params) throws IOException; + } + + protected final T embedding; + + protected Embedding(T embedding) { + this.embedding = embedding; + } + + public T getEmbedding() { + return embedding; + } + + public Map asMap() { + return Map.of(EMBEDDING, embedding); + } + + public int getSize() { + return embedding.size(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + embedding.valuesToXContent(EMBEDDING, builder, params); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Embedding embedding1 = (Embedding) o; + return Objects.equals(embedding, embedding1.embedding); + } + + @Override + public int hashCode() { + return Objects.hash(embedding); + } + + @Override + public String toString() { + return Strings.toString(this); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingChunk.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingChunk.java new file mode 100644 index 0000000000000..9f0f3bc7fe952 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingChunk.java @@ -0,0 +1,79 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.inference.results; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +public class EmbeddingChunk implements Writeable, ToXContentObject { + + private final String matchedText; + private final Embedding embedding; + + public EmbeddingChunk(String matchedText, Embedding embedding) { + this.matchedText = matchedText; + this.embedding = embedding; + } + + public String matchedText() { + return matchedText; + } + + public Embedding embedding() { + return embedding; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(matchedText); + embedding.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ChunkedNlpInferenceResults.TEXT, matchedText); + embedding.embedding.valuesToXContent(ChunkedNlpInferenceResults.INFERENCE, builder, params); + builder.endObject(); + return builder; + } + + public Map asMap() { + var map = new HashMap(); + map.put(ChunkedNlpInferenceResults.TEXT, matchedText); + map.put(ChunkedNlpInferenceResults.INFERENCE, embedding.getEmbedding()); + return map; + } + + @Override + public String toString() { + return Strings.toString(this); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + EmbeddingChunk that = (EmbeddingChunk) o; + return Objects.equals(matchedText, that.matchedText) && Objects.equals(embedding, that.embedding); + } + + @Override + public int hashCode() { + return Objects.hash(matchedText, embedding); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingInt.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingInt.java deleted file mode 100644 index 05fc8a3cef1b6..0000000000000 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingInt.java +++ /dev/null @@ -1,12 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.core.inference.results; - -public interface EmbeddingInt { - int getSize(); -} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingResults.java new file mode 100644 index 0000000000000..851a320c262a5 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/EmbeddingResults.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.inference.results; + +import org.elasticsearch.inference.InferenceServiceResults; + +import java.util.List; + +public interface EmbeddingResults { + List> embeddings(); + + EmbeddingType embeddingType(); + + enum EmbeddingType { + SPARSE { + public Class matchedClass() { + return SparseEmbeddingResults.class; + }; + }, + FLOAT { + public Class matchedClass() { + return TextEmbeddingResults.class; + }; + }, + + BYTE { + public Class matchedClass() { + return TextEmbeddingByteResults.class; + }; + }; + + public abstract Class matchedClass(); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/FloatEmbedding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/FloatEmbedding.java new file mode 100644 index 0000000000000..18bdff1a6f47d --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/FloatEmbedding.java @@ -0,0 +1,98 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.results; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +public class FloatEmbedding extends Embedding { + + public static FloatEmbedding of(List embedding) { + float[] embeddingFloats = new float[embedding.size()]; + for (int i = 0; i < embedding.size(); i++) { + embeddingFloats[i] = embedding.get(i); + } + return new FloatEmbedding(embeddingFloats); + } + + public static class FloatArrayWrapper implements EmbeddingValues { + + final float[] floats; + + public FloatArrayWrapper(float[] floats) { + this.floats = floats; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return valuesToXContent(EMBEDDING, builder, params); + } + + @Override + public int size() { + return floats.length; + } + + @Override + public XContentBuilder valuesToXContent(String fieldName, XContentBuilder builder, Params params) throws IOException { + builder.startArray(fieldName); + for (var value : floats) { + builder.value(value); + } + builder.endArray(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + FloatArrayWrapper that = (FloatArrayWrapper) o; + return Arrays.equals(floats, that.floats); + } + + @Override + public int hashCode() { + return Arrays.hashCode(floats); + } + } + + public FloatEmbedding(StreamInput in) throws IOException { + this(in.readFloatArray()); + } + + public FloatEmbedding(float[] embedding) { + super(new FloatArrayWrapper(embedding)); + } + + public float[] asFloatArray() { + return embedding.floats; + } + + public double[] asDoubleArray() { + var result = new double[embedding.floats.length]; + for (int i = 0; i < embedding.floats.length; i++) { + result[i] = embedding.floats[i]; + } + return result; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeFloatArray(embedding.floats); + } + + public static FloatEmbedding 
of(org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults embeddingResult) { + return new FloatEmbedding(embeddingResult.getInferenceAsFloat()); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbedding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbedding.java new file mode 100644 index 0000000000000..18b93dd1fef8d --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbedding.java @@ -0,0 +1,170 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.results; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xcontent.ToXContentFragment; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + +public class SparseEmbedding extends Embedding { + + public static final String IS_TRUNCATED = "is_truncated"; + + public static SparseEmbedding fromMlResults( + List weightedTokens, + boolean isTruncated + ) { + return new SparseEmbedding( + new WeightedTokens(weightedTokens.stream().map(token -> new WeightedToken(token.token(), token.weight())).toList()), + isTruncated + ); + } + + private final boolean isTruncated; + + public SparseEmbedding(StreamInput in) throws IOException { + this(new WeightedTokens(in.readCollectionAsImmutableList(SparseEmbedding.WeightedToken::new)), in.readBoolean()); + } + + public SparseEmbedding(WeightedTokens embedding, boolean isTruncated) { + super(embedding); + this.isTruncated = isTruncated; + } + + public SparseEmbedding(List tokens, boolean isTruncated) { + super(new WeightedTokens(tokens)); + this.isTruncated = isTruncated; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeCollection(embedding.tokens); + out.writeBoolean(isTruncated); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(IS_TRUNCATED, isTruncated); + embedding.toXContent(builder, params); + builder.endObject(); + return builder; + } + + public boolean isTruncated() { + return isTruncated; + } + + public Map asMap() { + var embeddingMap = new LinkedHashMap( + embedding.tokens.stream().collect(Collectors.toMap(WeightedToken::token, WeightedToken::weight)) + ); + + return new LinkedHashMap<>(Map.of(IS_TRUNCATED, isTruncated, EMBEDDING, embeddingMap)); + } + + public List tokens() { + return embedding.tokens; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + if (super.equals(o) == false) return false; + SparseEmbedding that = (SparseEmbedding) o; + return isTruncated == that.isTruncated; + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), isTruncated); + } + + public static class WeightedTokens implements Embedding.EmbeddingValues { + private final List 
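// Illustrative sketch, not part of this patch: the nested map layout produced
// by SparseEmbedding#asMap above. The token name and weight are hypothetical.
SparseEmbedding sketchSparse = new SparseEmbedding(List.of(new SparseEmbedding.WeightedToken("feature_0", 0.75f)), false);
Map<String, Object> sketchSparseMap = sketchSparse.asMap();
// => { "is_truncated": false, "embedding": { "feature_0": 0.75 } }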
tokens; + + public WeightedTokens(List tokens) { + this.tokens = tokens; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return valuesToXContent(EMBEDDING, builder, params); + } + + @Override + public int size() { + return tokens.size(); + } + + @Override + public XContentBuilder valuesToXContent(String fieldName, XContentBuilder builder, Params params) throws IOException { + builder.startObject(fieldName); + for (var weightedToken : tokens) { + weightedToken.toXContent(builder, params); + } + builder.endObject(); + return builder; + } + + public List tokens() { + return tokens; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + WeightedTokens that = (WeightedTokens) o; + return Objects.equals(tokens, that.tokens); + } + + @Override + public int hashCode() { + return Objects.hash(tokens); + } + } + + public record WeightedToken(String token, float weight) implements Writeable, ToXContentFragment { + public WeightedToken(StreamInput in) throws IOException { + this(in.readString(), in.readFloat()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(token); + out.writeFloat(weight); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(token, weight); + return builder; + } + + public Map asMap() { + return Map.of(token, weight); + } + + @Override + public String toString() { + return Strings.toString(this); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbeddingResults.java index 1db6dcc802d00..f108f9e86416c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbeddingResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbeddingResults.java @@ -8,15 +8,12 @@ package org.elasticsearch.xpack.core.inference.results; import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.search.WeightedToken; @@ -26,25 +23,24 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.stream.Collectors; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD; -public record SparseEmbeddingResults(List embeddings) implements InferenceServiceResults { +public record SparseEmbeddingResults(List embeddings) implements InferenceServiceResults, EmbeddingResults { public static final String NAME = "sparse_embedding_results"; public static final String SPARSE_EMBEDDING = TaskType.SPARSE_EMBEDDING.toString(); public SparseEmbeddingResults(StreamInput in) throws 
IOException { - this(in.readCollectionAsList(Embedding::new)); + this(in.readCollectionAsList(SparseEmbedding::new)); } public static SparseEmbeddingResults of(List results) { - List embeddings = new ArrayList<>(results.size()); + List embeddings = new ArrayList<>(results.size()); for (InferenceResults result : results) { if (result instanceof TextExpansionResults expansionResults) { - embeddings.add(Embedding.create(expansionResults.getWeightedTokens(), expansionResults.isTruncated())); + embeddings.add(SparseEmbedding.fromMlResults(expansionResults.getWeightedTokens(), expansionResults.isTruncated())); } else if (result instanceof org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults errorResult) { if (errorResult.getException() instanceof ElasticsearchStatusException statusException) { throw statusException; @@ -71,7 +67,7 @@ public static SparseEmbeddingResults of(List results public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startArray(SPARSE_EMBEDDING); - for (Embedding embedding : embeddings) { + for (var embedding : embeddings) { embedding.toXContent(builder, params); } @@ -112,60 +108,14 @@ public List transformToLegacyFormat() { .stream() .map(weightedToken -> new WeightedToken(weightedToken.token(), weightedToken.weight())) .toList(), - embedding.isTruncated + embedding.isTruncated() ) ) .toList(); } - public record Embedding(List tokens, boolean isTruncated) implements Writeable, ToXContentObject { - - public static final String EMBEDDING = "embedding"; - public static final String IS_TRUNCATED = "is_truncated"; - - public Embedding(StreamInput in) throws IOException { - this(in.readCollectionAsList(WeightedToken::new), in.readBoolean()); - } - - public static Embedding create(List weightedTokens, boolean isTruncated) { - return new Embedding( - weightedTokens.stream().map(token -> new WeightedToken(token.token(), token.weight())).toList(), - isTruncated - ); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeCollection(tokens); - out.writeBoolean(isTruncated); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(IS_TRUNCATED, isTruncated); - builder.startObject(EMBEDDING); - - for (var weightedToken : tokens) { - weightedToken.toXContent(builder, params); - } - - builder.endObject(); - builder.endObject(); - return builder; - } - - public Map asMap() { - var embeddingMap = new LinkedHashMap( - tokens.stream().collect(Collectors.toMap(WeightedToken::token, WeightedToken::weight)) - ); - - return new LinkedHashMap<>(Map.of(IS_TRUNCATED, isTruncated, EMBEDDING, embeddingMap)); - } - - @Override - public String toString() { - return Strings.toString(this); - } + @Override + public EmbeddingType embeddingType() { + return EmbeddingType.SPARSE; } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbedding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbedding.java index a185c2938223e..ef406f5c956ad 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbedding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbedding.java @@ -7,7 +7,7 @@ package org.elasticsearch.xpack.core.inference.results; -public interface TextEmbedding { +public interface TextEmbedding extends EmbeddingResults { /** * 
Returns the first text embedding entry in the result list's array size. diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java index 04986b2d957d7..743795d0530f5 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java @@ -9,18 +9,14 @@ package org.elasticsearch.xpack.core.inference.results; -import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -43,12 +39,27 @@ * ] * } */ -public record TextEmbeddingByteResults(List embeddings) implements InferenceServiceResults, TextEmbedding { +public class TextEmbeddingByteResults implements InferenceServiceResults, TextEmbedding { public static final String NAME = "text_embedding_service_byte_results"; public static final String TEXT_EMBEDDING_BYTES = "text_embedding_bytes"; + private final List embeddings; + + public TextEmbeddingByteResults(List embeddings) { + this.embeddings = embeddings; + } + public TextEmbeddingByteResults(StreamInput in) throws IOException { - this(in.readCollectionAsList(Embedding::new)); + this(in.readCollectionAsList(ByteEmbedding::new)); + } + + public List embeddings() { + return embeddings; + } + + @Override + public EmbeddingType embeddingType() { + return EmbeddingType.BYTE; } @Override @@ -59,7 +70,7 @@ public int getFirstEmbeddingSize() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startArray(TEXT_EMBEDDING_BYTES); - for (Embedding embedding : embeddings) { + for (var embedding : embeddings) { embedding.toXContent(builder, params); } builder.endArray(); @@ -80,9 +91,9 @@ public String getWriteableName() { public List transformToCoordinationFormat() { return embeddings.stream() .map( - embedding -> new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults( + embedding -> new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingByteResults( TEXT_EMBEDDING_BYTES, - embedding.toDoubleArray(), + embedding.bytes(), false ) ) @@ -114,82 +125,7 @@ public boolean equals(Object o) { return Objects.equals(embeddings, that.embeddings); } - @Override public int hashCode() { return Objects.hash(embeddings); } - - public record Embedding(byte[] values) implements Writeable, ToXContentObject, EmbeddingInt { - public static final String EMBEDDING = "embedding"; - - public Embedding(StreamInput in) throws IOException { - this(in.readByteArray()); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeByteArray(values); - } - - public static Embedding of(List embeddingValuesList) { - byte[] embeddingValues = new byte[embeddingValuesList.size()]; - for (int i = 0; i < embeddingValuesList.size(); i++) { - embeddingValues[i] = 
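// Illustrative sketch, not part of this patch: with the coordination format
// above, a byte embedding is handed to the ml-plugin result as raw bytes
// instead of being widened to double[] first. Values are hypothetical.
ByteEmbedding sketchBytes = ByteEmbedding.of(List.of((byte) 3, (byte) -7));
var sketchCoordination = new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingByteResults(
    TEXT_EMBEDDING_BYTES,
    sketchBytes.bytes(),
    false
);
// sketchCoordination.getInference() == { 3, -7 }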
embeddingValuesList.get(i); - } - return new Embedding(embeddingValues); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - - builder.startArray(EMBEDDING); - for (byte value : values) { - builder.value(value); - } - builder.endArray(); - - builder.endObject(); - return builder; - } - - @Override - public String toString() { - return Strings.toString(this); - } - - private float[] toFloatArray() { - float[] floatArray = new float[values.length]; - for (int i = 0; i < values.length; i++) { - floatArray[i] = ((Byte) values[i]).floatValue(); - } - return floatArray; - } - - private double[] toDoubleArray() { - double[] doubleArray = new double[values.length]; - for (int i = 0; i < values.length; i++) { - doubleArray[i] = ((Byte) values[i]).floatValue(); - } - return doubleArray; - } - - @Override - public int getSize() { - return values().length; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - Embedding embedding = (Embedding) o; - return Arrays.equals(values, embedding.values); - } - - @Override - public int hashCode() { - return Arrays.hashCode(values); - } - } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java index 152e10e82d5ba..7652551d9ca1d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java @@ -10,24 +10,19 @@ package org.elasticsearch.xpack.core.inference.results; import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.stream.Collectors; /** @@ -47,12 +42,12 @@ * ] * } */ -public record TextEmbeddingResults(List embeddings) implements InferenceServiceResults, TextEmbedding { +public record TextEmbeddingResults(List embeddings) implements InferenceServiceResults, TextEmbedding, EmbeddingResults { public static final String NAME = "text_embedding_service_results"; public static final String TEXT_EMBEDDING = TaskType.TEXT_EMBEDDING.toString(); public TextEmbeddingResults(StreamInput in) throws IOException { - this(in.readCollectionAsList(Embedding::new)); + this(in.readCollectionAsList(FloatEmbedding::new)); } @SuppressWarnings("deprecation") @@ -60,16 +55,16 @@ public TextEmbeddingResults(StreamInput in) throws IOException { this( legacyTextEmbeddingResults.embeddings() .stream() - .map(embedding -> new Embedding(embedding.values())) + .map(embedding -> new FloatEmbedding(embedding.values())) .collect(Collectors.toList()) ); } public static 
TextEmbeddingResults of(List results) { - List embeddings = new ArrayList<>(results.size()); + List embeddings = new ArrayList<>(results.size()); for (InferenceResults result : results) { if (result instanceof org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults embeddingResult) { - embeddings.add(Embedding.of(embeddingResult)); + embeddings.add(FloatEmbedding.of(embeddingResult)); } else if (result instanceof org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults errorResult) { if (errorResult.getException() instanceof ElasticsearchStatusException statusException) { throw statusException; @@ -97,7 +92,7 @@ public int getFirstEmbeddingSize() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startArray(TEXT_EMBEDDING); - for (Embedding embedding : embeddings) { + for (Embedding embedding : embeddings) { embedding.toXContent(builder, params); } builder.endArray(); @@ -131,7 +126,7 @@ public List transformToCoordinationFormat() { @SuppressWarnings("deprecation") public List transformToLegacyFormat() { var legacyEmbedding = new LegacyTextEmbeddingResults( - embeddings.stream().map(embedding -> new LegacyTextEmbeddingResults.Embedding(embedding.values)).toList() + embeddings.stream().map(embedding -> new LegacyTextEmbeddingResults.Embedding(embedding.asFloatArray())).toList() ); return List.of(legacyEmbedding); @@ -145,86 +140,12 @@ public Map asMap() { } @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - TextEmbeddingResults that = (TextEmbeddingResults) o; - return Objects.equals(embeddings, that.embeddings); + public List embeddings() { + return embeddings; } @Override - public int hashCode() { - return Objects.hash(embeddings); - } - - public record Embedding(float[] values) implements Writeable, ToXContentObject, EmbeddingInt { - public static final String EMBEDDING = "embedding"; - - public Embedding(StreamInput in) throws IOException { - this(in.readFloatArray()); - } - - public static Embedding of(org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults embeddingResult) { - float[] embeddingAsArray = embeddingResult.getInferenceAsFloat(); - return new Embedding(embeddingAsArray); - } - - public static Embedding of(List embeddingValuesList) { - float[] embeddingValues = new float[embeddingValuesList.size()]; - for (int i = 0; i < embeddingValuesList.size(); i++) { - embeddingValues[i] = embeddingValuesList.get(i); - } - return new Embedding(embeddingValues); - } - - @Override - public int getSize() { - return values.length; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeFloatArray(values); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - - builder.startArray(EMBEDDING); - for (float value : values) { - builder.value(value); - } - builder.endArray(); - - builder.endObject(); - return builder; - } - - @Override - public String toString() { - return Strings.toString(this); - } - - private double[] asDoubleArray() { - double[] doubles = new double[values.length]; - for (int i = 0; i < values.length; i++) { - doubles[i] = values[i]; - } - return doubles; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - Embedding embedding = (Embedding) o; - return 
Arrays.equals(values, embedding.values); - } - - @Override - public int hashCode() { - return Arrays.hashCode(values); - } + public EmbeddingType embeddingType() { + return EmbeddingType.FLOAT; } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingUtils.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingUtils.java index 4c68d02264457..d45a0d9731fc9 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingUtils.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingUtils.java @@ -19,7 +19,7 @@ public class TextEmbeddingUtils { * @return the size of the text embedding * @throws IllegalStateException if the list of embeddings is empty */ - public static int getFirstEmbeddingSize(List embeddings) throws IllegalStateException { + public static int getFirstEmbeddingSize(List> embeddings) throws IllegalStateException { if (embeddings.isEmpty()) { throw new IllegalStateException("Embeddings list is empty"); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java index a3fb956c3252d..f7fe2d0f6491a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java @@ -30,6 +30,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults; import org.elasticsearch.xpack.core.ml.inference.results.QuestionAnsweringInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingByteResults; import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.results.TextSimilarityInferenceResults; @@ -653,6 +654,9 @@ public List getNamedWriteables() { ); namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, TextExpansionResults.NAME, TextExpansionResults::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, TextEmbeddingResults.NAME, TextEmbeddingResults::new)); + namedWriteables.add( + new NamedWriteableRegistry.Entry(InferenceResults.class, TextEmbeddingByteResults.NAME, TextEmbeddingByteResults::new) + ); namedWriteables.add( new NamedWriteableRegistry.Entry( InferenceResults.class, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingByteResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingByteResults.java new file mode 100644 index 0000000000000..4871ba208aa6a --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingByteResults.java @@ -0,0 +1,100 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Map; +import java.util.Objects; + +public class TextEmbeddingByteResults extends NlpInferenceResults { + + public static final String NAME = "text_embedding_byte_result"; + + private final String resultsField; + private final byte[] inference; + + public TextEmbeddingByteResults(String resultsField, byte[] inference, boolean isTruncated) { + super(isTruncated); + this.inference = inference; + this.resultsField = resultsField; + } + + public TextEmbeddingByteResults(StreamInput in) throws IOException { + super(in); + inference = in.readByteArray(); + resultsField = in.readString(); + } + + public String getResultsField() { + return resultsField; + } + + public byte[] getInference() { + return inference; + } + + public float[] getInferenceAsFloat() { + float[] floatArray = new float[inference.length]; + for (int i = 0; i < inference.length; i++) { + floatArray[i] = inference[i]; + } + return floatArray; + } + + @Override + void doXContentBody(XContentBuilder builder, Params params) throws IOException { + builder.field(resultsField, inference); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + void doWriteTo(StreamOutput out) throws IOException { + out.writeByteArray(inference); + out.writeString(resultsField); + } + + @Override + void addMapFields(Map map) { + map.put(resultsField, inference); + } + + @Override + public Map asMap(String outputField) { + var map = super.asMap(outputField); + map.put(outputField, inference); + return map; + } + + @Override + public Object predictedValue() { + throw new UnsupportedOperationException("[" + NAME + "] does not support a single predicted value"); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + if (super.equals(o) == false) return false; + TextEmbeddingByteResults that = (TextEmbeddingByteResults) o; + return Objects.equals(resultsField, that.resultsField) && Arrays.equals(inference, that.inference); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), resultsField, Arrays.hashCode(inference)); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingByteResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingByteResultsTests.java new file mode 100644 index 0000000000000..7ea6e531fcbd2 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingByteResultsTests.java @@ -0,0 +1,61 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
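// Illustrative sketch, not part of this patch: every byte value fits into a
// float exactly, which is why getInferenceAsFloat above can use a simple
// per-element widening copy.
byte[] sketchInference = { -128, 0, 127 };
float[] sketchWidened = new float[sketchInference.length];
for (int i = 0; i < sketchInference.length; i++) {
    sketchWidened[i] = sketchInference[i]; // implicit byte -> float widening
}
// sketchWidened == { -128.0f, 0.0f, 127.0f }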
+ */ + +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.ingest.IngestDocument; + +import java.util.Map; + +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; + +public class TextEmbeddingByteResultsTests extends InferenceResultsTestCase { + + public static TextEmbeddingByteResults createRandomResults() { + int columns = randomIntBetween(1, 10); + var arr = new byte[columns]; + for (int i = 0; i < columns; i++) { + arr[i] = randomByte(); + } + + return new TextEmbeddingByteResults(DEFAULT_RESULTS_FIELD, arr, randomBoolean()); + } + + @Override + protected Writeable.Reader instanceReader() { + return TextEmbeddingByteResults::new; + } + + @Override + protected TextEmbeddingByteResults createTestInstance() { + return createRandomResults(); + } + + @Override + protected TextEmbeddingByteResults mutateInstance(TextEmbeddingByteResults instance) { + return null;// TODO implement https://github.com/elastic/elasticsearch/issues/25929 + } + + public void testAsMap() { + TextEmbeddingByteResults testInstance = createTestInstance(); + Map asMap = testInstance.asMap(); + int size = testInstance.isTruncated ? 2 : 1; + assertThat(asMap.keySet(), hasSize(size)); + assertArrayEquals(testInstance.getInference(), (byte[]) asMap.get(DEFAULT_RESULTS_FIELD)); + if (testInstance.isTruncated) { + assertThat(asMap.get("is_truncated"), is(true)); + } + } + + @Override + void assertFieldValues(TextEmbeddingByteResults createdInstance, IngestDocument document, String parentField, String resultsField) { + assertArrayEquals(document.getFieldValue(parentField + resultsField, byte[].class), createdInstance.getInference()); + } +} diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index bb18b71eb3fea..9093598a8f6ff 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -29,6 +29,7 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.inference.results.FloatEmbedding; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults; @@ -137,14 +138,14 @@ public void chunkedInfer( } private TextEmbeddingResults makeResults(List input, int dimensions) { - List embeddings = new ArrayList<>(); + List embeddings = new ArrayList<>(); for (int i = 0; i < input.size(); i++) { double[] doubleEmbeddings = generateEmbedding(input.get(i), dimensions); List floatEmbeddings = new ArrayList<>(dimensions); for (int j = 0; j < dimensions; j++) { floatEmbeddings.add((float) doubleEmbeddings[j]); } - embeddings.add(TextEmbeddingResults.Embedding.of(floatEmbeddings)); + embeddings.add(FloatEmbedding.of(floatEmbeddings)); } return new TextEmbeddingResults(embeddings); } diff --git 
a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index 05e85334cff5a..c39999377ae54 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -29,6 +29,7 @@ import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.SparseEmbedding; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResults; import org.elasticsearch.xpack.core.ml.search.WeightedToken; @@ -127,13 +128,13 @@ public void chunkedInfer( } private SparseEmbeddingResults makeResults(List input) { - var embeddings = new ArrayList(); + var embeddings = new ArrayList(); for (int i = 0; i < input.size(); i++) { - var tokens = new ArrayList(); + var tokens = new ArrayList(); for (int j = 0; j < 5; j++) { - tokens.add(new WeightedToken("feature_" + j, generateEmbedding(input.get(i), j))); + tokens.add(new SparseEmbedding.WeightedToken("feature_" + j, generateEmbedding(input.get(i), j))); } - embeddings.add(new SparseEmbeddingResults.Embedding(tokens, false)); + embeddings.add(new SparseEmbedding(tokens, false)); } return new SparseEmbeddingResults(embeddings); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/EmbeddingRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/EmbeddingRequestChunker.java index 77d03ac660952..6866b04f2347e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/EmbeddingRequestChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/EmbeddingRequestChunker.java @@ -7,19 +7,29 @@ package org.elasticsearch.xpack.inference.common; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.inference.results.ByteEmbedding; +import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.Embedding; +import org.elasticsearch.xpack.core.inference.results.EmbeddingChunk; +import org.elasticsearch.xpack.core.inference.results.EmbeddingResults; import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import 
org.elasticsearch.xpack.core.inference.results.FloatEmbedding; +import org.elasticsearch.xpack.core.inference.results.SparseEmbedding; import java.util.ArrayList; import java.util.List; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; /** @@ -35,6 +45,8 @@ */ public class EmbeddingRequestChunker { + private static final Logger logger = LogManager.getLogger(EmbeddingRequestChunker.class); + public static final int DEFAULT_WORDS_PER_CHUNK = 250; public static final int DEFAULT_CHUNK_OVERLAP = 100; @@ -45,10 +57,12 @@ public class EmbeddingRequestChunker { private final int chunkOverlap; private List> chunkedInputs; - private List>> results; + private List>>> results; private AtomicArray errors; private ActionListener> finalListener; + private AtomicReference firstResultType = new AtomicReference<>(); + public EmbeddingRequestChunker(List inputs, int maxNumberOfInputsPerBatch) { this.maxNumberOfInputsPerBatch = maxNumberOfInputsPerBatch; this.wordsPerChunk = DEFAULT_WORDS_PER_CHUNK; @@ -160,14 +174,14 @@ private class DebatchingListener implements ActionListener chunks, + AtomicArray>> embeddings + ) { + return switch (embeddingType) { + case FLOAT -> mergeFloatResults(chunks, embeddings); + case BYTE -> mergeByteResults(chunks, embeddings); + case SPARSE -> mergeSparseResults(chunks, embeddings); + }; + } + + private ChunkedTextEmbeddingFloatResults mergeFloatResults( List chunks, - AtomicArray> debatchedResults + AtomicArray>> debatchedResults ) { - var all = new ArrayList(); + var all = new ArrayList(); for (int i = 0; i < debatchedResults.length(); i++) { var subBatch = debatchedResults.get(i); - all.addAll(subBatch); + for (var result : subBatch) { + if (result instanceof FloatEmbedding fe) { + all.add(fe); + } else { + var message = "Unexpected embedding result type [" + + result.getClass().getSimpleName() + + "], expected a float embedding"; + logger.error(message); + throw new IllegalStateException(message); + } + } } assert chunks.size() == all.size(); - var embeddingChunks = new ArrayList(); + var embeddingChunks = new ArrayList>(); for (int i = 0; i < chunks.size(); i++) { - embeddingChunks.add(new ChunkedTextEmbeddingFloatResults.EmbeddingChunk(chunks.get(i), all.get(i).values())); + embeddingChunks.add(new EmbeddingChunk<>(chunks.get(i), all.get(i))); } return new ChunkedTextEmbeddingFloatResults(embeddingChunks); } + + private ChunkedTextEmbeddingByteResults mergeByteResults( + List chunks, + AtomicArray>> debatchedResults + ) { + var all = new ArrayList(); + for (int i = 0; i < debatchedResults.length(); i++) { + var subBatch = debatchedResults.get(i); + for (var result : subBatch) { + if (result instanceof ByteEmbedding be) { + all.add(be); + } else { + var message = "Unexpected embedding result type [" + + result.getClass().getSimpleName() + + "], expected a byte embedding"; + logger.error(message); + throw new IllegalStateException(message); + } + } + } + + assert chunks.size() == all.size(); + + var embeddingChunks = new ArrayList>(); + for (int i = 0; i < chunks.size(); i++) { + embeddingChunks.add(new EmbeddingChunk<>(chunks.get(i), all.get(i))); + } + + return new ChunkedTextEmbeddingByteResults(embeddingChunks, false); + } + + private ChunkedSparseEmbeddingResults mergeSparseResults( + List chunks, + AtomicArray>> debatchedResults + ) { + var all = new ArrayList(); + for (int i = 0; i < debatchedResults.length(); i++) { + var subBatch = debatchedResults.get(i); + for (var result : 
subBatch) {
+                if (result instanceof SparseEmbedding se) {
+                    all.add(se);
+                } else {
+                    var message = "Unexpected embedding result type ["
+                        + result.getClass().getSimpleName()
+                        + "], expected a sparse embedding";
+                    logger.error(message);
+                    throw new IllegalStateException(message);
+                }
+            }
+        }
+
+        assert chunks.size() == all.size();
+
+        var embeddingChunks = new ArrayList<EmbeddingChunk<SparseEmbedding>>();
+        for (int i = 0; i < chunks.size(); i++) {
+            embeddingChunks.add(new EmbeddingChunk<>(chunks.get(i), all.get(i)));
+        }
+
+        return ChunkedSparseEmbeddingResults.of(embeddingChunks);
+    }
 }

 public record BatchRequest(List<SubBatch> subBatches) {
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java
index f787c6337d646..7353f9cf1f94e 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java
@@ -18,6 +18,8 @@
 import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xcontent.XContentParserConfiguration;
 import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.inference.results.ByteEmbedding;
+import org.elasticsearch.xpack.core.inference.results.FloatEmbedding;
 import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
 import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
 import org.elasticsearch.xpack.inference.external.http.HttpResult;
@@ -188,11 +190,10 @@ private static InferenceServiceResults parseByteEmbeddingsArray(XContentParser p
         return new TextEmbeddingByteResults(embeddingList);
     }

-    private static TextEmbeddingByteResults.Embedding parseByteArrayEntry(XContentParser parser) throws IOException {
+    private static ByteEmbedding parseByteArrayEntry(XContentParser parser) throws IOException {
         XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
         List<Byte> embeddingValuesList = XContentParserUtils.parseList(parser, CohereEmbeddingsResponseEntity::parseEmbeddingInt8Entry);
-
-        return TextEmbeddingByteResults.Embedding.of(embeddingValuesList);
+        return ByteEmbedding.of(embeddingValuesList);
     }

     private static Byte parseEmbeddingInt8Entry(XContentParser parser) throws IOException {
@@ -216,10 +217,10 @@ private static InferenceServiceResults parseFloatEmbeddingsArray(XContentParser
         return new TextEmbeddingResults(embeddingList);
     }

-    private static TextEmbeddingResults.Embedding parseFloatArrayEntry(XContentParser parser) throws IOException {
+    private static FloatEmbedding parseFloatArrayEntry(XContentParser parser) throws IOException {
         XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
-        List<Float> embeddingValuesList = XContentParserUtils.parseList(parser, CohereEmbeddingsResponseEntity::parseEmbeddingFloatEntry);
-        return TextEmbeddingResults.Embedding.of(embeddingValuesList);
+        List<Float> embeddingValues = XContentParserUtils.parseList(parser, CohereEmbeddingsResponseEntity::parseEmbeddingFloatEntry);
+        return FloatEmbedding.of(embeddingValues);
     }

     private static Float parseEmbeddingFloatEntry(XContentParser parser) throws IOException {
diff --git
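// Illustrative sketch, not part of this patch: one way the chunker's
// firstResultType AtomicReference can reject mixed response types, matching
// the error message exercised by testDifferentResponseTypes below. The
// variable wiring here is hypothetical.
AtomicReference<EmbeddingResults.EmbeddingType> firstResultType = new AtomicReference<>();
EmbeddingResults.EmbeddingType incomingType = embeddingResults.embeddingType(); // hypothetical incoming batch result
if (firstResultType.compareAndSet(null, incomingType) == false && firstResultType.get() != incomingType) {
    // fail this request rather than merge incompatible embedding types
    throw new IllegalStateException(
        "The embedding response types are different. ["
            + incomingType.matchedClass().getSimpleName()
            + "] does not match the first response type ["
            + firstResultType.get().matchedClass().getSimpleName()
            + "]"
    );
}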
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntity.java index 270a981a6998d..a18ac3e8e466b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntity.java @@ -13,8 +13,8 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.SparseEmbedding; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.core.ml.search.WeightedToken; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; @@ -65,7 +65,7 @@ public static SparseEmbeddingResults fromResponse(Request request, HttpResult re moveToFirstToken(jsonParser); var truncationResults = request.getTruncationInfo(); - List parsedEmbeddings = XContentParserUtils.parseList( + List parsedEmbeddings = XContentParserUtils.parseList( jsonParser, (parser, index) -> HuggingFaceElserResponseEntity.parseExpansionResult(truncationResults, parser, index) ); @@ -78,26 +78,24 @@ public static SparseEmbeddingResults fromResponse(Request request, HttpResult re } } - private static SparseEmbeddingResults.Embedding parseExpansionResult(boolean[] truncationResults, XContentParser parser, int index) - throws IOException { + private static SparseEmbedding parseExpansionResult(boolean[] truncationResults, XContentParser parser, int index) throws IOException { XContentParser.Token token = parser.currentToken(); XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, token, parser); - List weightedTokens = new ArrayList<>(); + List weightedTokens = new ArrayList<>(); token = parser.nextToken(); while (token != null && token != XContentParser.Token.END_OBJECT) { XContentParserUtils.ensureExpectedToken(XContentParser.Token.FIELD_NAME, token, parser); var floatToken = parser.nextToken(); XContentParserUtils.ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, floatToken, parser); - weightedTokens.add(new WeightedToken(parser.currentName(), parser.floatValue())); - + weightedTokens.add(new SparseEmbedding.WeightedToken(parser.currentName(), parser.floatValue())); token = parser.nextToken(); } // prevent an out of bounds if for some reason the truncation list is smaller than the results var isTruncated = truncationResults != null && index < truncationResults.length && truncationResults[index]; - return new SparseEmbeddingResults.Embedding(weightedTokens, isTruncated); + return new SparseEmbedding(weightedTokens, isTruncated); } private HuggingFaceElserResponseEntity() {} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java index a3e06b3c2075a..53a2c3e9c7f00 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntity.java @@ -15,6 +15,7 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.FloatEmbedding; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; @@ -92,7 +93,7 @@ public static TextEmbeddingResults fromResponse(Request request, HttpResult resp * sentence-transformers/all-MiniLM-L12-v2 */ private static TextEmbeddingResults parseArrayFormat(XContentParser parser) throws IOException { - List embeddingList = XContentParserUtils.parseList( + List embeddingList = XContentParserUtils.parseList( parser, HuggingFaceEmbeddingsResponseEntity::parseEmbeddingEntry ); @@ -139,7 +140,7 @@ private static TextEmbeddingResults parseArrayFormat(XContentParser parser) thro private static TextEmbeddingResults parseObjectFormat(XContentParser parser) throws IOException { positionParserAtTokenAfterField(parser, "embeddings", FAILED_TO_FIND_FIELD_TEMPLATE); - List embeddingList = XContentParserUtils.parseList( + List embeddingList = XContentParserUtils.parseList( parser, HuggingFaceEmbeddingsResponseEntity::parseEmbeddingEntry ); @@ -147,11 +148,10 @@ private static TextEmbeddingResults parseObjectFormat(XContentParser parser) thr return new TextEmbeddingResults(embeddingList); } - private static TextEmbeddingResults.Embedding parseEmbeddingEntry(XContentParser parser) throws IOException { + private static FloatEmbedding parseEmbeddingEntry(XContentParser parser) throws IOException { XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); - List embeddingValuesList = XContentParserUtils.parseList(parser, HuggingFaceEmbeddingsResponseEntity::parseEmbeddingList); - return TextEmbeddingResults.Embedding.of(embeddingValuesList); + return FloatEmbedding.of(embeddingValuesList); } private static float parseEmbeddingList(XContentParser parser) throws IOException { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java index 39b97014c3619..18d4e78cf0bfe 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java @@ -15,6 +15,7 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.FloatEmbedding; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; @@ -83,7 +84,7 @@ public static TextEmbeddingResults 
fromResponse(Request request, HttpResult resp positionParserAtTokenAfterField(jsonParser, "data", FAILED_TO_FIND_FIELD_TEMPLATE); - List embeddingList = XContentParserUtils.parseList( + List embeddingList = XContentParserUtils.parseList( jsonParser, OpenAiEmbeddingsResponseEntity::parseEmbeddingObject ); @@ -92,7 +93,7 @@ public static TextEmbeddingResults fromResponse(Request request, HttpResult resp } } - private static TextEmbeddingResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException { + private static FloatEmbedding parseEmbeddingObject(XContentParser parser) throws IOException { XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE); @@ -101,7 +102,7 @@ private static TextEmbeddingResults.Embedding parseEmbeddingObject(XContentParse // parse and discard the rest of the object consumeUntilObjectEnd(parser); - return TextEmbeddingResults.Embedding.of(embeddingValuesList); + return FloatEmbedding.of(embeddingValuesList); } private static float parseEmbeddingList(XContentParser parser) throws IOException { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/EmbeddingRequestChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/EmbeddingRequestChunkerTests.java index 164f975cc464f..2dd9a3fde30d1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/EmbeddingRequestChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/EmbeddingRequestChunkerTests.java @@ -10,8 +10,15 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.ByteEmbedding; +import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; +import org.elasticsearch.xpack.core.inference.results.FloatEmbedding; +import org.elasticsearch.xpack.core.inference.results.SparseEmbedding; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; import java.util.ArrayList; @@ -19,6 +26,7 @@ import java.util.concurrent.atomic.AtomicReference; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.startsWith; @@ -157,7 +165,7 @@ public void testLongInputChunkedOverMultipleBatches() { } } - public void testMergingListener() { + public void testMergingListenerFloatEmbedding() { int batchSize = 5; int chunkSize = 20; int overlap = 0; @@ -169,7 +177,7 @@ public void testMergingListener() { for (int i = 0; i < numberOfWordsInPassage; i++) { passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace } - List inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small"); + List inputs = 
List.of("1st small", passageBuilder.toString()); var finalListener = testListener(); var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener); @@ -177,22 +185,22 @@ public void testMergingListener() { // 4 inputs in 2 batches { - var embeddings = new ArrayList(); + var embeddings = new ArrayList(); for (int i = 0; i < batchSize; i++) { - embeddings.add(new TextEmbeddingResults.Embedding(new float[] { randomFloat() })); + embeddings.add(FloatEmbedding.of(List.of(randomFloat()))); } batches.get(0).listener().onResponse(new TextEmbeddingResults(embeddings)); } { - var embeddings = new ArrayList(); - for (int i = 0; i < 4; i++) { // 4 requests in the 2nd batch - embeddings.add(new TextEmbeddingResults.Embedding(new float[] { randomFloat() })); + var embeddings = new ArrayList(); + for (int i = 0; i < 2; i++) { // 2 requests in the 2nd batch + embeddings.add(FloatEmbedding.of(List.of(randomFloat()))); } batches.get(1).listener().onResponse(new TextEmbeddingResults(embeddings)); } assertNotNull(finalListener.results); - assertThat(finalListener.results, hasSize(4)); + assertThat(finalListener.results, hasSize(2)); { var chunkedResult = finalListener.results.get(0); assertThat(chunkedResult, instanceOf(ChunkedTextEmbeddingFloatResults.class)); @@ -213,22 +221,147 @@ public void testMergingListener() { assertThat(chunkedFloatResult.chunks().get(4).matchedText(), startsWith(" passage_input80 ")); assertThat(chunkedFloatResult.chunks().get(5).matchedText(), startsWith(" passage_input100 ")); } + } + + public void testMergingListenerSparseEmbedding() { + int batchSize = 5; + int chunkSize = 20; + int overlap = 0; + // passage will be chunked into batchSize + 1 parts + // and spread over 2 batch requests + int numberOfWordsInPassage = (chunkSize * batchSize) + 5; + + var passageBuilder = new StringBuilder(); + for (int i = 0; i < numberOfWordsInPassage; i++) { + passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace + } + List inputs = List.of("1st small", passageBuilder.toString()); + + var finalListener = testListener(); + var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener); + assertThat(batches, hasSize(2)); + + // 4 inputs in 2 batches { - var chunkedResult = finalListener.results.get(2); - assertThat(chunkedResult, instanceOf(ChunkedTextEmbeddingFloatResults.class)); - var chunkedFloatResult = (ChunkedTextEmbeddingFloatResults) chunkedResult; - assertThat(chunkedFloatResult.chunks(), hasSize(1)); - assertEquals("2nd small", chunkedFloatResult.chunks().get(0).matchedText()); + var embeddings = new ArrayList(); + for (int i = 0; i < batchSize; i++) { + embeddings.add(new SparseEmbedding(List.of(new SparseEmbedding.WeightedToken("a", 1.0f)), false)); + } + batches.get(0).listener().onResponse(new SparseEmbeddingResults(embeddings)); } { - var chunkedResult = finalListener.results.get(3); - assertThat(chunkedResult, instanceOf(ChunkedTextEmbeddingFloatResults.class)); - var chunkedFloatResult = (ChunkedTextEmbeddingFloatResults) chunkedResult; - assertThat(chunkedFloatResult.chunks(), hasSize(1)); - assertEquals("3rd small", chunkedFloatResult.chunks().get(0).matchedText()); + var embeddings = new ArrayList(); + for (int i = 0; i < 2; i++) { // 2 requests in the 2nd batch + embeddings.add(new SparseEmbedding(List.of(new SparseEmbedding.WeightedToken("b", 1.0f)), false)); + } + batches.get(1).listener().onResponse(new 
SparseEmbeddingResults(embeddings)); + } + + assertNotNull(finalListener.results); + assertThat(finalListener.results, hasSize(2)); + { + var chunkedResult = finalListener.results.get(0); + assertThat(chunkedResult, instanceOf(ChunkedSparseEmbeddingResults.class)); + var chunkedSparseResult = (ChunkedSparseEmbeddingResults) chunkedResult; + assertThat(chunkedSparseResult.getChunkedResults(), hasSize(1)); + assertEquals("1st small", chunkedSparseResult.getChunkedResults().get(0).matchedText()); + } + { + // this is the large input split in multiple chunks + var chunkedResult = finalListener.results.get(1); + assertThat(chunkedResult, instanceOf(ChunkedSparseEmbeddingResults.class)); + var chunkedSparseResult = (ChunkedSparseEmbeddingResults) chunkedResult; + assertThat(chunkedSparseResult.getChunkedResults(), hasSize(6)); + assertThat(chunkedSparseResult.getChunkedResults().get(0).matchedText(), startsWith("passage_input0 ")); + assertThat(chunkedSparseResult.getChunkedResults().get(1).matchedText(), startsWith(" passage_input20 ")); + assertThat(chunkedSparseResult.getChunkedResults().get(2).matchedText(), startsWith(" passage_input40 ")); + assertThat(chunkedSparseResult.getChunkedResults().get(3).matchedText(), startsWith(" passage_input60 ")); + assertThat(chunkedSparseResult.getChunkedResults().get(4).matchedText(), startsWith(" passage_input80 ")); + assertThat(chunkedSparseResult.getChunkedResults().get(5).matchedText(), startsWith(" passage_input100 ")); } } + public void testMergingListenerByteEmbedding() { + int batchSize = 5; + int chunkSize = 20; + int overlap = 0; + // passage will be chunked into batchSize + 1 parts + // and spread over 2 batch requests + int numberOfWordsInPassage = (chunkSize * batchSize) + 5; + + var passageBuilder = new StringBuilder(); + for (int i = 0; i < numberOfWordsInPassage; i++) { + passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace + } + List inputs = List.of("1st small", passageBuilder.toString()); + + var finalListener = testListener(); + var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener); + assertThat(batches, hasSize(2)); + + // 2 inputs in 2 batches + { + var embeddings = new ArrayList(); + for (int i = 0; i < batchSize; i++) { + embeddings.add(ByteEmbedding.of(List.of(randomByte()))); + } + batches.get(0).listener().onResponse(new TextEmbeddingByteResults(embeddings)); + } + { + var embeddings = new ArrayList(); + for (int i = 0; i < 2; i++) { // 2 requests in the 2nd batch + embeddings.add(ByteEmbedding.of(List.of(randomByte()))); + } + batches.get(1).listener().onResponse(new TextEmbeddingByteResults(embeddings)); + } + + assertNotNull(finalListener.results); + assertThat(finalListener.results, hasSize(2)); + { + var chunkedResult = finalListener.results.get(0); + assertThat(chunkedResult, instanceOf(ChunkedTextEmbeddingByteResults.class)); + var chunkedByteResult = (ChunkedTextEmbeddingByteResults) chunkedResult; + assertThat(chunkedByteResult.chunks(), hasSize(1)); + assertEquals("1st small", chunkedByteResult.chunks().get(0).matchedText()); + } + { + // this is the large input split in multiple chunks + var chunkedResult = finalListener.results.get(1); + assertThat(chunkedResult, instanceOf(ChunkedTextEmbeddingByteResults.class)); + var chunkedByteResult = (ChunkedTextEmbeddingByteResults) chunkedResult; + assertThat(chunkedByteResult.chunks(), hasSize(6)); + assertThat(chunkedByteResult.chunks().get(0).matchedText(), 
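// Illustrative sketch, not part of this patch: the arithmetic behind these
// test fixtures. With chunkSize = 20 words, batchSize = 5 and no overlap, a
// passage of (20 * 5) + 5 = 105 words splits into 6 chunks; together with the
// separate "1st small" input that is 7 requests, which fill 2 batches of at
// most 5.
int sketchChunkSize = 20;
int sketchBatchSize = 5;
int sketchWords = sketchChunkSize * sketchBatchSize + 5;                        // 105
int sketchChunks = (sketchWords + sketchChunkSize - 1) / sketchChunkSize;       // ceil(105 / 20) = 6
int sketchRequests = sketchChunks + 1;                                          // plus the "1st small" input
int sketchBatches = (sketchRequests + sketchBatchSize - 1) / sketchBatchSize;   // 2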
startsWith("passage_input0 ")); + assertThat(chunkedByteResult.chunks().get(1).matchedText(), startsWith(" passage_input20 ")); + assertThat(chunkedByteResult.chunks().get(2).matchedText(), startsWith(" passage_input40 ")); + assertThat(chunkedByteResult.chunks().get(3).matchedText(), startsWith(" passage_input60 ")); + assertThat(chunkedByteResult.chunks().get(4).matchedText(), startsWith(" passage_input80 ")); + assertThat(chunkedByteResult.chunks().get(5).matchedText(), startsWith(" passage_input100 ")); + } + } + + public void testDifferentResponseTypes() { + List inputs = List.of("one", "two"); + var result = new AtomicReference>(); + + var listener = ActionListener.>wrap(result::set, e -> fail(e.getMessage())); + + var batches = new EmbeddingRequestChunker(inputs, 1, 100, 0).batchRequestsWithListeners(listener); + assertThat(batches, hasSize(2)); + + batches.get(0).listener().onResponse(new TextEmbeddingByteResults(List.of(ByteEmbedding.of(List.of(randomByte()))))); + batches.get(1).listener().onResponse(new TextEmbeddingResults(List.of(FloatEmbedding.of(List.of(randomFloat()))))); + + assertThat(result.get().get(0), instanceOf(ChunkedTextEmbeddingByteResults.class)); + assertThat(result.get().get(1), instanceOf(ErrorChunkedInferenceResults.class)); + assertThat( + ((ErrorChunkedInferenceResults) result.get().get(1)).getException().getMessage(), + containsString( + "The embedding response types are different. [TextEmbeddingResults] does not match the first response type " + + "[TextEmbeddingByteResults]" + ) + ); + } + public void testListenerErrorsWithWrongNumberOfResponses() { List inputs = List.of("1st small", "2nd small", "3rd small"); @@ -251,9 +384,9 @@ public void onFailure(Exception e) { var batches = new EmbeddingRequestChunker(inputs, 10, 100, 0).batchRequestsWithListeners(listener); assertThat(batches, hasSize(1)); - var embeddings = new ArrayList(); - embeddings.add(new TextEmbeddingResults.Embedding(new float[] { randomFloat() })); - embeddings.add(new TextEmbeddingResults.Embedding(new float[] { randomFloat() })); + var embeddings = new ArrayList(); + embeddings.add(FloatEmbedding.of(List.of(randomFloat()))); + embeddings.add(FloatEmbedding.of(List.of(randomFloat()))); batches.get(0).listener().onResponse(new TextEmbeddingResults(embeddings)); assertEquals("Error the number of embedding responses [2] does not equal the number of requests [3]", failureMessage.get()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/azureaistudio/AzureAiStudioEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/azureaistudio/AzureAiStudioEmbeddingsResponseEntityTests.java index 41768a6814f36..e744b579924c8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/azureaistudio/AzureAiStudioEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/azureaistudio/AzureAiStudioEmbeddingsResponseEntityTests.java @@ -9,6 +9,7 @@ import org.apache.http.HttpResponse; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.FloatEmbedding; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; @@ -55,6 +56,6 @@ public void 
testFromResponse_CreatesResultsForASingleItem() throws IOException { new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.of(List.of(0.014539449F, -0.015288644F))))); + assertThat(parsedResults.embeddings(), is(List.of(FloatEmbedding.of(List.of(0.014539449F, -0.015288644F))))); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java index d809635aa4f38..d0f6f45974d1e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java @@ -10,6 +10,8 @@ import org.apache.http.HttpResponse; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.ByteEmbedding; +import org.elasticsearch.xpack.core.inference.results.FloatEmbedding; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; @@ -58,7 +60,7 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException { MatcherAssert.assertThat(parsedResults, instanceOf(TextEmbeddingResults.class)); MatcherAssert.assertThat( ((TextEmbeddingResults) parsedResults).embeddings(), - is(List.of(new TextEmbeddingResults.Embedding(new float[] { -0.0018434525F, 0.01777649F }))) + is(List.of(FloatEmbedding.of(List.of(-0.0018434525F, 0.01777649F)))) ); } @@ -94,10 +96,7 @@ public void testFromResponse_CreatesResultsForASingleItem_ObjectFormat() throws new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - MatcherAssert.assertThat( - parsedResults.embeddings(), - is(List.of(new TextEmbeddingResults.Embedding(new float[] { -0.0018434525F, 0.01777649F }))) - ); + MatcherAssert.assertThat(parsedResults.embeddings(), is(List.of(FloatEmbedding.of(List.of(-0.0018434525F, 0.01777649F))))); } public void testFromResponse_UsesTheFirstValidEmbeddingsEntry() throws IOException { @@ -138,10 +137,7 @@ public void testFromResponse_UsesTheFirstValidEmbeddingsEntry() throws IOExcepti new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - MatcherAssert.assertThat( - parsedResults.embeddings(), - is(List.of(new TextEmbeddingResults.Embedding(new float[] { -0.0018434525F, 0.01777649F }))) - ); + MatcherAssert.assertThat(parsedResults.embeddings(), is(List.of(FloatEmbedding.of(List.of(-0.0018434525F, 0.01777649F))))); } public void testFromResponse_UsesTheFirstValidEmbeddingsEntryInt8_WithInvalidFirst() throws IOException { @@ -182,10 +178,7 @@ public void testFromResponse_UsesTheFirstValidEmbeddingsEntryInt8_WithInvalidFir new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - MatcherAssert.assertThat( - parsedResults.embeddings(), - is(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) -1, (byte) 0 }))) - ); + MatcherAssert.assertThat(parsedResults.embeddings(), 
is(List.of(ByteEmbedding.of(List.of((byte) -1, (byte) 0))))); } public void testFromResponse_ParsesBytes() throws IOException { @@ -220,10 +213,7 @@ public void testFromResponse_ParsesBytes() throws IOException { new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - MatcherAssert.assertThat( - parsedResults.embeddings(), - is(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) -1, (byte) 0 }))) - ); + MatcherAssert.assertThat(parsedResults.embeddings(), is(List.of(ByteEmbedding.of(List.of((byte) -1, (byte) 0))))); } public void testFromResponse_CreatesResultsForMultipleItems() throws IOException { @@ -262,12 +252,7 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException MatcherAssert.assertThat( parsedResults.embeddings(), - is( - List.of( - new TextEmbeddingResults.Embedding(new float[] { -0.0018434525F, 0.01777649F }), - new TextEmbeddingResults.Embedding(new float[] { -0.123F, 0.123F }) - ) - ) + is(List.of(FloatEmbedding.of(List.of(-0.0018434525F, 0.01777649F)), FloatEmbedding.of(List.of(-0.123F, 0.123F)))) ); } @@ -309,12 +294,7 @@ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat() throw MatcherAssert.assertThat( parsedResults.embeddings(), - is( - List.of( - new TextEmbeddingResults.Embedding(new float[] { -0.0018434525F, 0.01777649F }), - new TextEmbeddingResults.Embedding(new float[] { -0.123F, 0.123F }) - ) - ) + is(List.of(FloatEmbedding.of(List.of(-0.0018434525F, 0.01777649F)), FloatEmbedding.of(List.of(-0.123F, 0.123F)))) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntityTests.java index 238dab5929139..b69bb37d35c53 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceEmbeddingsResponseEntityTests.java @@ -10,6 +10,7 @@ import org.apache.http.HttpResponse; import org.elasticsearch.common.ParsingException; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.FloatEmbedding; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; @@ -37,10 +38,7 @@ public void testFromResponse_CreatesResultsForASingleItem_ArrayFormat() throws I new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat( - parsedResults.embeddings(), - is(List.of(new TextEmbeddingResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))) - ); + assertThat(parsedResults.embeddings(), is(List.of(FloatEmbedding.of(List.of(0.014539449F, -0.015288644F))))); } public void testFromResponse_CreatesResultsForASingleItem_ObjectFormat() throws IOException { @@ -60,10 +58,7 @@ public void testFromResponse_CreatesResultsForASingleItem_ObjectFormat() throws new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat( - parsedResults.embeddings(), - is(List.of(new TextEmbeddingResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))) - ); + 
assertThat(parsedResults.embeddings(), is(List.of(FloatEmbedding.of(List.of(0.014539449F, -0.015288644F))))); } public void testFromResponse_CreatesResultsForMultipleItems_ArrayFormat() throws IOException { @@ -87,12 +82,7 @@ public void testFromResponse_CreatesResultsForMultipleItems_ArrayFormat() throws assertThat( parsedResults.embeddings(), - is( - List.of( - new TextEmbeddingResults.Embedding(new float[] { 0.014539449F, -0.015288644F }), - new TextEmbeddingResults.Embedding(new float[] { 0.0123F, -0.0123F }) - ) - ) + is(List.of(FloatEmbedding.of(List.of(0.014539449F, -0.015288644F)), FloatEmbedding.of(List.of(0.0123F, -0.0123F)))) ); } @@ -119,12 +109,7 @@ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat() throw assertThat( parsedResults.embeddings(), - is( - List.of( - new TextEmbeddingResults.Embedding(new float[] { 0.014539449F, -0.015288644F }), - new TextEmbeddingResults.Embedding(new float[] { 0.0123F, -0.0123F }) - ) - ) + is(List.of(FloatEmbedding.of(List.of(0.014539449F, -0.015288644F)), FloatEmbedding.of(List.of(0.0123F, -0.0123F)))) ); } @@ -260,7 +245,7 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsInt_ArrayFormat() throw new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(new float[] { 1.0F })))); + assertThat(parsedResults.embeddings(), is(List.of(FloatEmbedding.of(List.of(1.0F))))); } public void testFromResponse_SucceedsWhenEmbeddingValueIsInt_ObjectFormat() throws IOException { @@ -279,7 +264,7 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsInt_ObjectFormat() thro new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(new float[] { 1.0F })))); + assertThat(parsedResults.embeddings(), is(List.of(FloatEmbedding.of(List.of(1.0F))))); } public void testFromResponse_SucceedsWhenEmbeddingValueIsLong_ArrayFormat() throws IOException { @@ -296,7 +281,7 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsLong_ArrayFormat() thro new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(new float[] { 4.0294965E10F })))); + assertThat(parsedResults.embeddings(), is(List.of(FloatEmbedding.of(List.of(4.0294965E10F))))); } public void testFromResponse_SucceedsWhenEmbeddingValueIsLong_ObjectFormat() throws IOException { @@ -315,7 +300,7 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsLong_ObjectFormat() thr new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(new float[] { 4.0294965E10F })))); + assertThat(parsedResults.embeddings(), is(List.of(FloatEmbedding.of(List.of(4.0294965E10F))))); } public void testFromResponse_FailsWhenEmbeddingValueIsAnObject_ObjectFormat() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntityTests.java index 6c38092f509a7..7299f7ab6ef7b 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntityTests.java @@ -10,6 +10,7 @@ import org.apache.http.HttpResponse; import org.elasticsearch.common.ParsingException; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.FloatEmbedding; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; @@ -49,10 +50,7 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException { new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat( - parsedResults.embeddings(), - is(List.of(new TextEmbeddingResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))) - ); + assertThat(parsedResults.embeddings(), is(List.of(FloatEmbedding.of(List.of(0.014539449F, -0.015288644F))))); } public void testFromResponse_CreatesResultsForMultipleItems() throws IOException { @@ -92,12 +90,7 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException assertThat( parsedResults.embeddings(), - is( - List.of( - new TextEmbeddingResults.Embedding(new float[] { 0.014539449F, -0.015288644F }), - new TextEmbeddingResults.Embedding(new float[] { 0.0123F, -0.0123F }) - ) - ) + is(List.of(FloatEmbedding.of(List.of(0.014539449F, -0.015288644F)), FloatEmbedding.of(List.of(0.0123F, -0.0123F)))) ); } @@ -264,7 +257,7 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsInt() throws IOExceptio new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(new float[] { 1.0F })))); + assertThat(parsedResults.embeddings(), is(List.of(FloatEmbedding.of(List.of(1.0F))))); } public void testFromResponse_SucceedsWhenEmbeddingValueIsLong() throws IOException { @@ -293,7 +286,7 @@ public void testFromResponse_SucceedsWhenEmbeddingValueIsLong() throws IOExcepti new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(new float[] { 4.0294965E10F })))); + assertThat(parsedResults.embeddings(), is(List.of(FloatEmbedding.of(List.of(4.0294965E10F))))); } public void testFromResponse_FailsWhenEmbeddingValueIsAnObject() { @@ -382,9 +375,9 @@ public void testFieldsInDifferentOrderServer() throws IOException { parsedResults.embeddings(), is( List.of( - new TextEmbeddingResults.Embedding(new float[] { -0.9F, 0.5F, 0.3F }), - new TextEmbeddingResults.Embedding(new float[] { 0.1F, 0.5F }), - new TextEmbeddingResults.Embedding(new float[] { 0.5F, 0.5F }) + FloatEmbedding.of(List.of(-0.9F, 0.5F, 0.3F)), + FloatEmbedding.of(List.of(0.1F, 0.5F)), + FloatEmbedding.of(List.of(0.5F, 0.5F)) ) ) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestInferenceActionTests.java index 8365ebdfad786..eb9cf8324683f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestInferenceActionTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestInferenceActionTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.test.rest.RestActionTestCase; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.ByteEmbedding; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; import org.junit.Before; @@ -75,8 +76,6 @@ public void testUses3SecondTimeoutFromParams() { } private static InferenceAction.Response createResponse() { - return new InferenceAction.Response( - new TextEmbeddingByteResults(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) -1 }))) - ); + return new InferenceAction.Response(new TextEmbeddingByteResults(List.of(ByteEmbedding.of(List.of((byte) -1))))); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ChunkedSparseEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ChunkedSparseEmbeddingResultsTests.java index 073a662c1e8f2..392a49be7cd42 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ChunkedSparseEmbeddingResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ChunkedSparseEmbeddingResultsTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.SparseEmbedding; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResults; @@ -73,7 +74,7 @@ public void testToXContent_CreatesTheRightJsonForASingleChunk() { public void testToXContent_CreatesTheRightJsonForASingleChunk_FromSparseEmbeddingResults() { var entity = ChunkedSparseEmbeddingResults.of( List.of("text"), - new SparseEmbeddingResults(List.of(new SparseEmbeddingResults.Embedding(List.of(new WeightedToken("token", 0.1f)), false))) + new SparseEmbeddingResults(List.of(new SparseEmbedding(List.of(new SparseEmbedding.WeightedToken("token", 0.1f)), false))) ); assertThat(entity.size(), is(1)); @@ -109,7 +110,7 @@ public void testToXContent_ThrowsWhenInputSizeIsDifferentThanEmbeddings() { IllegalArgumentException.class, () -> ChunkedSparseEmbeddingResults.of( List.of("text", "text2"), - new SparseEmbeddingResults(List.of(new SparseEmbeddingResults.Embedding(List.of(new WeightedToken("token", 0.1f)), false))) + new SparseEmbeddingResults(List.of(new SparseEmbedding(List.of(new SparseEmbedding.WeightedToken("token", 0.1f)), false))) ) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ChunkedTextEmbeddingByteResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ChunkedTextEmbeddingByteResultsTests.java index 6d6fbe956280a..cc1b000644dad 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ChunkedTextEmbeddingByteResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ChunkedTextEmbeddingByteResultsTests.java @@ -10,7 +10,9 @@ import 
org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.core.inference.results.ByteEmbedding;
 import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.EmbeddingChunk;
 import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
 
 import java.io.IOException;
@@ -24,7 +26,7 @@ public class ChunkedTextEmbeddingByteResultsTests extends AbstractWireSerializin
 
     public static ChunkedTextEmbeddingByteResults createRandomResults() {
         int numChunks = randomIntBetween(1, 5);
-        var chunks = new ArrayList<ChunkedTextEmbeddingByteResults.EmbeddingChunk>(numChunks);
+        var chunks = new ArrayList<EmbeddingChunk<ByteEmbedding>>(numChunks);
 
         for (int i = 0; i < numChunks; i++) {
             chunks.add(createRandomChunk());
@@ -33,28 +35,25 @@ public static ChunkedTextEmbeddingByteResults createRandomResults() {
         return new ChunkedTextEmbeddingByteResults(chunks, randomBoolean());
     }
 
-    private static ChunkedTextEmbeddingByteResults.EmbeddingChunk createRandomChunk() {
+    private static EmbeddingChunk<ByteEmbedding> createRandomChunk() {
         int columns = randomIntBetween(1, 10);
         byte[] bytes = new byte[columns];
 
         for (int i = 0; i < columns; i++) {
             bytes[i] = randomByte();
         }
 
-        return new ChunkedTextEmbeddingByteResults.EmbeddingChunk(randomAlphaOfLength(6), bytes);
+        return new EmbeddingChunk<>(randomAlphaOfLength(6), new ByteEmbedding(bytes));
     }
 
     public void testToXContent_CreatesTheRightJsonForASingleChunk() {
-        var entity = new ChunkedTextEmbeddingByteResults(
-            List.of(new ChunkedTextEmbeddingByteResults.EmbeddingChunk("text", new byte[] { (byte) 1 })),
-            false
-        );
+        var entity = new ChunkedTextEmbeddingByteResults(List.of(new EmbeddingChunk<>("text", ByteEmbedding.of(List.of((byte) 1)))), false);
 
         assertThat(
             entity.asMap(),
             is(
                 Map.of(
                     ChunkedTextEmbeddingByteResults.FIELD_NAME,
-                    List.of(new ChunkedTextEmbeddingByteResults.EmbeddingChunk("text", new byte[] { (byte) 1 }))
+                    List.of(new EmbeddingChunk<>("text", new ByteEmbedding(new byte[] { (byte) 1 })))
                 )
             )
         );
@@ -75,7 +74,7 @@ public void testToXContent_CreatesTheRightJsonForASingleChunk() {
     public void testToXContent_CreatesTheRightJsonForASingleChunk_ForTextEmbeddingByteResults() {
         var entity = ChunkedTextEmbeddingByteResults.of(
             List.of("text"),
-            new TextEmbeddingByteResults(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 1 })))
+            new TextEmbeddingByteResults(List.of(ByteEmbedding.of(List.of((byte) 1))))
         );
 
         assertThat(entity.size(), is(1));
@@ -87,7 +86,7 @@ public void testToXContent_CreatesTheRightJsonForASingleChunk_ForTextEmbeddingBy
             is(
                 Map.of(
                     ChunkedTextEmbeddingByteResults.FIELD_NAME,
-                    List.of(new ChunkedTextEmbeddingByteResults.EmbeddingChunk("text", new byte[] { (byte) 1 }))
+                    List.of(new EmbeddingChunk<>("text", new ByteEmbedding(new byte[] { (byte) 1 })))
                 )
             )
         );
@@ -110,7 +109,7 @@ public void testToXContent_ThrowsWhenInputSizeIsDifferentThanEmbeddings() {
             IllegalArgumentException.class,
             () -> ChunkedTextEmbeddingByteResults.of(
                 List.of("text", "text2"),
-                new TextEmbeddingByteResults(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 1 })))
+                new TextEmbeddingByteResults(List.of(ByteEmbedding.of(List.of((byte) 1))))
             )
         );
 
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ChunkedTextEmbeddingFloatResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ChunkedTextEmbeddingFloatResultsTests.java
index beb75fbfa36a6..461bf793996af 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ChunkedTextEmbeddingFloatResultsTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ChunkedTextEmbeddingFloatResultsTests.java
@@ -10,6 +10,8 @@
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.EmbeddingChunk;
+import org.elasticsearch.xpack.core.inference.results.FloatEmbedding;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -18,7 +20,7 @@ public class ChunkedTextEmbeddingFloatResultsTests extends AbstractWireSerializi
 
     public static ChunkedTextEmbeddingFloatResults createRandomResults() {
         int numChunks = randomIntBetween(1, 5);
-        var chunks = new ArrayList<ChunkedTextEmbeddingFloatResults.EmbeddingChunk>(numChunks);
+        var chunks = new ArrayList<EmbeddingChunk<FloatEmbedding>>(numChunks);
 
         for (int i = 0; i < numChunks; i++) {
             chunks.add(createRandomChunk());
@@ -27,14 +29,14 @@ public static ChunkedTextEmbeddingFloatResults createRandomResults() {
         return new ChunkedTextEmbeddingFloatResults(chunks);
     }
 
-    private static ChunkedTextEmbeddingFloatResults.EmbeddingChunk createRandomChunk() {
+    private static EmbeddingChunk<FloatEmbedding> createRandomChunk() {
         int columns = randomIntBetween(1, 10);
         float[] floats = new float[columns];
 
         for (int i = 0; i < columns; i++) {
             floats[i] = randomFloat();
         }
-        return new ChunkedTextEmbeddingFloatResults.EmbeddingChunk(randomAlphaOfLength(6), floats);
+        return new EmbeddingChunk<>(randomAlphaOfLength(6), new FloatEmbedding(floats));
     }
 
     @Override
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ChunkedTextEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ChunkedTextEmbeddingResultsTests.java
index 1fc0282b5d96d..935180cce5d95 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ChunkedTextEmbeddingResultsTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ChunkedTextEmbeddingResultsTests.java
@@ -11,6 +11,7 @@
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults;
+import org.elasticsearch.xpack.core.inference.results.FloatEmbedding;
 import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults;
 
@@ -97,7 +98,7 @@ public void testToXContent_CreatesTheRightJsonForASingleChunk() {
     public void testToXContent_CreatesTheRightJsonForASingleChunk_FromTextEmbeddingResults() {
         var entity = ChunkedTextEmbeddingResults.of(
             List.of("text"),
-            new TextEmbeddingResults(List.of(new TextEmbeddingResults.Embedding(new float[] { 0.1f, 0.2f })))
+            new TextEmbeddingResults(List.of(FloatEmbedding.of(List.of(0.1f, 0.2f))))
         );
 
         assertThat(entity.size(), is(1));
@@ -140,7 +141,7 @@ public void testToXContent_ThrowsWhenInputSizeIsDifferentThanEmbeddings() {
             IllegalArgumentException.class,
             () -> ChunkedTextEmbeddingResults.of(
                 List.of("text", "text2"),
-                new TextEmbeddingResults(List.of(new TextEmbeddingResults.Embedding(new float[] { 0.1f, 0.2f })))
+                new TextEmbeddingResults(List.of(FloatEmbedding.of(List.of(0.1f, 0.2f))))
             )
         );
 
diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java index 17950ea8056c2..727df98d27bbb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.inference.results.SparseEmbedding; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.search.WeightedToken; @@ -31,7 +32,7 @@ public static SparseEmbeddingResults createRandomResults() { } public static SparseEmbeddingResults createRandomResults(int numEmbeddings, int numTokens) { - List embeddings = new ArrayList<>(numEmbeddings); + List embeddings = new ArrayList<>(numEmbeddings); for (int i = 0; i < numEmbeddings; i++) { embeddings.add(createRandomEmbedding(numTokens)); @@ -41,7 +42,7 @@ public static SparseEmbeddingResults createRandomResults(int numEmbeddings, int } public static SparseEmbeddingResults createRandomResults(List input) { - List embeddings = new ArrayList<>(input.size()); + List embeddings = new ArrayList<>(input.size()); for (String s : input) { int numTokens = Strings.tokenizeToStringArray(s, " ").length; @@ -51,13 +52,13 @@ public static SparseEmbeddingResults createRandomResults(List input) { return new SparseEmbeddingResults(embeddings); } - private static SparseEmbeddingResults.Embedding createRandomEmbedding(int numTokens) { - List tokenList = new ArrayList<>(numTokens); + private static SparseEmbedding createRandomEmbedding(int numTokens) { + List tokenList = new ArrayList<>(numTokens); for (int i = 0; i < numTokens; i++) { - tokenList.add(new WeightedToken(Integer.toString(i), (float) randomDoubleBetween(0.0, 5.0, false))); + tokenList.add(new SparseEmbedding.WeightedToken(Integer.toString(i), (float) randomDoubleBetween(0.0, 5.0, false))); } - return new SparseEmbeddingResults.Embedding(tokenList, randomBoolean()); + return new SparseEmbedding(tokenList, randomBoolean()); } @Override @@ -78,14 +79,14 @@ protected SparseEmbeddingResults mutateInstance(SparseEmbeddingResults instance) int end = randomInt(instance.embeddings().size() - 1); return new SparseEmbeddingResults(instance.embeddings().subList(0, end)); } else { - List embeddings = new ArrayList<>(instance.embeddings()); + List embeddings = new ArrayList<>(instance.embeddings()); embeddings.add(createRandomEmbedding(randomIntBetween(0, 20))); return new SparseEmbeddingResults(embeddings); } } public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOException { - var entity = createSparseResult(List.of(createEmbedding(List.of(new WeightedToken("token", 0.1F)), false))); + var entity = createSparseResult(List.of(createEmbedding(List.of(new SparseEmbedding.WeightedToken("token", 0.1F)), false))); assertThat(entity.asMap(), is(buildExpectation(List.of(new EmbeddingExpectation(Map.of("token", 0.1F), false))))); String xContentResult = Strings.toString(entity, true, true); assertThat(xContentResult, is(""" @@ -104,8 +105,14 @@ public void 
testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOE public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws IOException { var entity = createSparseResult( List.of( - new SparseEmbeddingResults.Embedding(List.of(new WeightedToken("token", 0.1F), new WeightedToken("token2", 0.2F)), false), - new SparseEmbeddingResults.Embedding(List.of(new WeightedToken("token3", 0.3F), new WeightedToken("token4", 0.4F)), false) + new SparseEmbedding( + List.of(new SparseEmbedding.WeightedToken("token", 0.1F), new SparseEmbedding.WeightedToken("token2", 0.2F)), + false + ), + new SparseEmbedding( + List.of(new SparseEmbedding.WeightedToken("token3", 0.3F), new SparseEmbedding.WeightedToken("token4", 0.4F)), + false + ) ) ); assertThat( @@ -145,8 +152,8 @@ public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws I public void testTransformToCoordinationFormat() { var results = createSparseResult( List.of( - createEmbedding(List.of(new WeightedToken("token", 0.1F)), false), - createEmbedding(List.of(new WeightedToken("token2", 0.2F)), true) + createEmbedding(List.of(new SparseEmbedding.WeightedToken("token", 0.1F)), false), + createEmbedding(List.of(new SparseEmbedding.WeightedToken("token2", 0.2F)), true) ) ).transformToCoordinationFormat(); @@ -167,23 +174,16 @@ public static Map buildExpectation(List em return Map.of( SparseEmbeddingResults.SPARSE_EMBEDDING, embeddings.stream() - .map( - embedding -> Map.of( - SparseEmbeddingResults.Embedding.EMBEDDING, - embedding.tokens, - SparseEmbeddingResults.Embedding.IS_TRUNCATED, - embedding.isTruncated - ) - ) + .map(embedding -> Map.of(SparseEmbedding.EMBEDDING, embedding.tokens, SparseEmbedding.IS_TRUNCATED, embedding.isTruncated)) .toList() ); } - public static SparseEmbeddingResults createSparseResult(List embeddings) { + public static SparseEmbeddingResults createSparseResult(List embeddings) { return new SparseEmbeddingResults(embeddings); } - public static SparseEmbeddingResults.Embedding createEmbedding(List tokensList, boolean isTruncated) { - return new SparseEmbeddingResults.Embedding(tokensList, isTruncated); + public static SparseEmbedding createEmbedding(List tokensList, boolean isTruncated) { + return new SparseEmbedding(tokensList, isTruncated); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingByteResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingByteResultsTests.java index baa46ff299bcc..48784b9bd8652 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingByteResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingByteResultsTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.inference.results.ByteEmbedding; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; import java.io.IOException; @@ -22,7 +23,7 @@ public class TextEmbeddingByteResultsTests extends AbstractWireSerializingTestCase { public static TextEmbeddingByteResults createRandomResults() { int embeddings = randomIntBetween(1, 10); - List embeddingResults = new ArrayList<>(embeddings); + List embeddingResults = new ArrayList<>(embeddings); for (int i = 0; i < embeddings; i++) { 
embeddingResults.add(createRandomEmbedding()); @@ -31,7 +32,7 @@ public static TextEmbeddingByteResults createRandomResults() { return new TextEmbeddingByteResults(embeddingResults); } - private static TextEmbeddingByteResults.Embedding createRandomEmbedding() { + private static ByteEmbedding createRandomEmbedding() { int columns = randomIntBetween(1, 10); byte[] bytes = new byte[columns]; @@ -39,11 +40,11 @@ private static TextEmbeddingByteResults.Embedding createRandomEmbedding() { bytes[i] = randomByte(); } - return new TextEmbeddingByteResults.Embedding(bytes); + return new ByteEmbedding(bytes); } public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOException { - var entity = new TextEmbeddingByteResults(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23 }))); + var entity = new TextEmbeddingByteResults(List.of(ByteEmbedding.of(List.of((byte) 23)))); String xContentResult = Strings.toString(entity, true, true); assertThat(xContentResult, is(""" @@ -60,10 +61,8 @@ public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOE public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws IOException { var entity = new TextEmbeddingByteResults( - List.of( - new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23 }), - new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 24 }) - ) + List.of(ByteEmbedding.of(List.of((byte) 23)), ByteEmbedding.of(List.of((byte) 24))) + ); String xContentResult = Strings.toString(entity, true, true); @@ -86,24 +85,21 @@ public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws I public void testTransformToCoordinationFormat() { var results = new TextEmbeddingByteResults( - List.of( - new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }), - new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 }) - ) + List.of(ByteEmbedding.of(List.of((byte) 23, (byte) 24)), ByteEmbedding.of(List.of((byte) 25, (byte) 26))) ).transformToCoordinationFormat(); assertThat( results, is( List.of( - new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults( + new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingByteResults( TextEmbeddingByteResults.TEXT_EMBEDDING_BYTES, - new double[] { 23F, 24F }, + new byte[] { 23, 24 }, false ), - new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults( + new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingByteResults( TextEmbeddingByteResults.TEXT_EMBEDDING_BYTES, - new double[] { 25F, 26F }, + new byte[] { 25, 26 }, false ) ) @@ -129,7 +125,7 @@ protected TextEmbeddingByteResults mutateInstance(TextEmbeddingByteResults insta int end = randomInt(instance.embeddings().size() - 1); return new TextEmbeddingByteResults(instance.embeddings().subList(0, end)); } else { - List embeddings = new ArrayList<>(instance.embeddings()); + List embeddings = new ArrayList<>(instance.embeddings()); embeddings.add(createRandomEmbedding()); return new TextEmbeddingByteResults(embeddings); } @@ -138,7 +134,7 @@ protected TextEmbeddingByteResults mutateInstance(TextEmbeddingByteResults insta public static Map buildExpectation(List> embeddings) { return Map.of( TextEmbeddingByteResults.TEXT_EMBEDDING_BYTES, - embeddings.stream().map(embedding -> Map.of(TextEmbeddingByteResults.Embedding.EMBEDDING, embedding)).toList() + embeddings.stream().map(embedding -> Map.of(ByteEmbedding.EMBEDDING, embedding)).toList() ); } } diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java index 716568fdb5645..bac25c9181dee 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java @@ -10,6 +10,8 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.inference.results.ByteEmbedding; +import org.elasticsearch.xpack.core.inference.results.FloatEmbedding; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; @@ -23,7 +25,7 @@ public class TextEmbeddingResultsTests extends AbstractWireSerializingTestCase { public static TextEmbeddingResults createRandomResults() { int embeddings = randomIntBetween(1, 10); - List embeddingResults = new ArrayList<>(embeddings); + List embeddingResults = new ArrayList<>(embeddings); for (int i = 0; i < embeddings; i++) { embeddingResults.add(createRandomEmbedding()); @@ -32,18 +34,18 @@ public static TextEmbeddingResults createRandomResults() { return new TextEmbeddingResults(embeddingResults); } - private static TextEmbeddingResults.Embedding createRandomEmbedding() { + private static FloatEmbedding createRandomEmbedding() { int columns = randomIntBetween(1, 10); float[] floats = new float[columns]; for (int i = 0; i < columns; i++) { floats[i] = randomFloat(); } - return new TextEmbeddingResults.Embedding(floats); + return new FloatEmbedding(floats); } public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOException { - var entity = new TextEmbeddingResults(List.of(new TextEmbeddingResults.Embedding(new float[] { 0.1F }))); + var entity = new TextEmbeddingResults(List.of(FloatEmbedding.of(List.of(0.1F)))); String xContentResult = Strings.toString(entity, true, true); assertThat(xContentResult, is(""" @@ -59,10 +61,7 @@ public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOE } public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws IOException { - var entity = new TextEmbeddingResults( - List.of(new TextEmbeddingResults.Embedding(new float[] { 0.1F }), new TextEmbeddingResults.Embedding(new float[] { 0.2F })) - - ); + var entity = new TextEmbeddingResults(List.of(FloatEmbedding.of(List.of(0.1F)), FloatEmbedding.of(List.of(0.2F)))); String xContentResult = Strings.toString(entity, true, true); assertThat(xContentResult, is(""" @@ -83,12 +82,8 @@ public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws I } public void testTransformToCoordinationFormat() { - var results = new TextEmbeddingResults( - List.of( - new TextEmbeddingResults.Embedding(new float[] { 0.1F, 0.2F }), - new TextEmbeddingResults.Embedding(new float[] { 0.3F, 0.4F }) - ) - ).transformToCoordinationFormat(); + var results = new TextEmbeddingResults(List.of(FloatEmbedding.of(List.of(0.1F, 0.2F)), FloatEmbedding.of(List.of(0.3F, 0.4F)))) + .transformToCoordinationFormat(); assertThat( results, @@ -127,21 +122,18 @@ protected TextEmbeddingResults mutateInstance(TextEmbeddingResults instance) thr int end = randomInt(instance.embeddings().size() - 1); 
return new TextEmbeddingResults(instance.embeddings().subList(0, end)); } else { - List embeddings = new ArrayList<>(instance.embeddings()); + List embeddings = new ArrayList<>(instance.embeddings()); embeddings.add(createRandomEmbedding()); return new TextEmbeddingResults(embeddings); } } public static Map buildExpectationFloat(List embeddings) { - return Map.of(TextEmbeddingResults.TEXT_EMBEDDING, embeddings.stream().map(TextEmbeddingResults.Embedding::new).toList()); + return Map.of(TextEmbeddingResults.TEXT_EMBEDDING, embeddings.stream().map(FloatEmbedding::new).toList()); } public static Map buildExpectationByte(List embeddings) { - return Map.of( - TextEmbeddingByteResults.TEXT_EMBEDDING_BYTES, - embeddings.stream().map(TextEmbeddingByteResults.Embedding::new).toList() - ); + return Map.of(TextEmbeddingByteResults.TEXT_EMBEDDING_BYTES, embeddings.stream().map(ByteEmbedding::new).toList()); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java index 828e333d4fd27..0a34de7b342ee 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java @@ -765,7 +765,7 @@ public void testGetEmbeddingSize_ReturnsSize_ForTextEmbeddingResults() { var size = listener.actionGet(TIMEOUT); - assertThat(size, is(textEmbedding.embeddings().get(0).getSize())); + assertThat(size, is(textEmbedding.embeddings().get(0).getEmbedding().size())); } public void testGetEmbeddingSize_ReturnsSize_ForTextEmbeddingByteResults() { @@ -789,7 +789,7 @@ public void testGetEmbeddingSize_ReturnsSize_ForTextEmbeddingByteResults() { var size = listener.actionGet(TIMEOUT); - assertThat(size, is(textEmbedding.embeddings().get(0).getSize())); + assertThat(size, is(textEmbedding.embeddings().get(0).getEmbedding().size())); } private static Map modifiableMap(Map aMap) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index f06fee4b0b9c4..66097e62c3ed8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -33,6 +33,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.FloatEmbedding; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -51,7 +52,6 @@ import org.junit.Before; import java.io.IOException; -import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -1226,14 +1226,14 @@ public void testChunkedInfer_BatchesCalls() throws IOException { var floatResult = (ChunkedTextEmbeddingFloatResults) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals("foo", 
floatResult.chunks().get(0).matchedText()); - assertTrue(Arrays.equals(new float[] { 0.123f, -0.123f }, floatResult.chunks().get(0).embedding())); + assertEquals(FloatEmbedding.of(List.of(0.123f, -0.123f)), floatResult.chunks().get(0).embedding()); } { assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedTextEmbeddingFloatResults.class)); var floatResult = (ChunkedTextEmbeddingFloatResults) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertEquals("bar", floatResult.chunks().get(0).matchedText()); - assertTrue(Arrays.equals(new float[] { 0.223f, -0.223f }, floatResult.chunks().get(0).embedding())); + assertEquals(FloatEmbedding.of(List.of(0.223f, -0.223f)), floatResult.chunks().get(0).embedding()); } MatcherAssert.assertThat(webServer.requests(), hasSize(1)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index cbac29c452772..124b90e10dc88 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -32,6 +32,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.FloatEmbedding; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -48,7 +49,6 @@ import org.junit.Before; import java.io.IOException; -import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -1267,14 +1267,14 @@ public void testChunkedInfer_Batches() throws IOException { var floatResult = (ChunkedTextEmbeddingFloatResults) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals("foo", floatResult.chunks().get(0).matchedText()); - assertTrue(Arrays.equals(new float[] { 0.123f, -0.123f }, floatResult.chunks().get(0).embedding())); + assertEquals(FloatEmbedding.of(List.of(0.123f, -0.123f)), floatResult.chunks().get(0).embedding()); } { assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedTextEmbeddingFloatResults.class)); var floatResult = (ChunkedTextEmbeddingFloatResults) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertEquals("bar", floatResult.chunks().get(0).matchedText()); - assertTrue(Arrays.equals(new float[] { 0.223f, -0.223f }, floatResult.chunks().get(0).embedding())); + assertEquals(FloatEmbedding.of(List.of(0.223f, -0.223f)), floatResult.chunks().get(0).embedding()); } assertThat(webServer.requests(), hasSize(1)); From a7bb2ff34e127f15a82ae244b021748632c74f2b Mon Sep 17 00:00:00 2001 From: Ignacio Vera Date: Mon, 27 May 2024 10:15:32 +0200 Subject: [PATCH 05/29] Make TTestStats#EMPTY final (#109052) --- .../org/elasticsearch/xpack/analytics/ttest/TTestStats.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/ttest/TTestStats.java b/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/ttest/TTestStats.java index 832952e957e7a..6ead5cbd9ee59 
100644
--- a/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/ttest/TTestStats.java
+++ b/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/ttest/TTestStats.java
@@ -20,7 +20,7 @@
  * Collects basic stats that are needed to perform t-test
  */
 public class TTestStats implements Writeable {
-    static TTestStats EMPTY = new TTestStats(0, 0, 0);
+    static final TTestStats EMPTY = new TTestStats(0, 0, 0);
 
     public final long count;
     public final double sum;

From 6312d49c0556b4841797f4eb2bbe14289b4c8231 Mon Sep 17 00:00:00 2001
From: David Turner
Date: Mon, 27 May 2024 09:26:44 +0100
Subject: [PATCH 06/29] Pausable chunked HTTP responses (#104851)

Today we must collect together in memory everything needed to send an
HTTP response before starting to send it to the client, even if we're
using the `chunked` transfer encoding to bound the memory needed for
the final serialization step.

To properly bound the memory usage on the coordinating node we must
instead be able to start sending the response to the client before we
have collected everything needed to finish it. If we do this then we
must be able to handle the case where we run out of data to send by
pausing the transmission and resuming it once there's more data to
send.

This commit extends the `ChunkedRestResponseBody` interface to allow it
to express that it has run out of chunks for immediate transmission,
but that the body will continue at a later point.
---
 .../netty4/Netty4ChunkedContinuationsIT.java  | 713 ++++++++++++++++++
 .../http/netty4/Netty4ChunkedEncodingIT.java  |  11 +
 .../http/netty4/Netty4PipeliningIT.java       |  10 +
 .../netty4/Netty4ChunkedHttpContinuation.java |  38 +
 .../netty4/Netty4ChunkedHttpResponse.java     |   2 +-
 .../netty4/Netty4HttpPipeliningHandler.java   |  73 +-
 .../http/netty4/Netty4HttpResponse.java       |   2 +-
 .../Netty4HttpPipeliningHandlerTests.java     |  11 +
 .../elasticsearch/http/HttpBodyTracer.java    |   2 +-
 .../rest/ChunkedRestResponseBody.java         |  65 +-
 .../rest/LoggingChunkedRestResponseBody.java  |  11 +
 .../elasticsearch/rest/RestController.java    |  14 +-
 .../org/elasticsearch/rest/RestResponse.java  |   1 +
 .../http/DefaultRestChannelTests.java         | 108 ++-
 .../elasticsearch/rest/RestResponseUtils.java |   1 +
 15 files changed, 1026 insertions(+), 36 deletions(-)
 create mode 100644 modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4ChunkedContinuationsIT.java
 create mode 100644 modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4ChunkedHttpContinuation.java

diff --git a/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4ChunkedContinuationsIT.java b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4ChunkedContinuationsIT.java
new file mode 100644
index 0000000000000..25195a1176fb8
--- /dev/null
+++ b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4ChunkedContinuationsIT.java
@@ -0,0 +1,713 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
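
In outline, the pause/resume contract that this commit message describes can be
sketched as follows. This is a minimal illustration only, not the interface as
committed: the names `BodyPart`, `encodeChunk`, `isPartComplete`, `isLastPart`
and `getNextPart` are placeholders, and a `CompletableFuture` stands in for the
`ActionListener`-based continuation plumbing that the patch actually wires
through the Netty pipeline.

import java.util.Iterator;
import java.util.List;
import java.util.concurrent.CompletableFuture;

public class PausableBodySketch {

    // Placeholder contract for one part of a chunked response body.
    interface BodyPart {
        String encodeChunk();                      // next chunk, ready for immediate transmission

        boolean isPartComplete();                  // no more chunks are available right now

        boolean isLastPart();                      // the response ends once this part is drained

        CompletableFuture<BodyPart> getNextPart(); // continuation, completed when more data arrives
    }

    record Part(Iterator<String> chunks, CompletableFuture<BodyPart> continuation) implements BodyPart {
        @Override
        public String encodeChunk() {
            return chunks.next();
        }

        @Override
        public boolean isPartComplete() {
            return chunks.hasNext() == false;
        }

        @Override
        public boolean isLastPart() {
            return continuation == null;
        }

        @Override
        public CompletableFuture<BodyPart> getNextPart() {
            return continuation;
        }
    }

    // Drain the chunks we have, then either finish the response or pause until the
    // continuation completes; nothing blocks and nothing buffers the whole body.
    static void send(BodyPart part, StringBuilder wire) {
        while (part.isPartComplete() == false) {
            wire.append(part.encodeChunk()).append('\n');
        }
        if (part.isLastPart()) {
            wire.append("[end of response]\n");
        } else {
            part.getNextPart().thenAccept(next -> send(next, wire));
        }
    }

    public static void main(String[] args) {
        var wire = new StringBuilder();
        var continuation = new CompletableFuture<BodyPart>();
        send(new Part(List.of("batch-0-chunk-0", "batch-0-chunk-1").iterator(), continuation), wire);
        // transmission is now paused; when the next batch is ready it resumes:
        continuation.complete(new Part(List.of("batch-1-chunk-0").iterator(), null));
        System.out.print(wire);
    }
}

The point of this shape is that pausing happens only at part boundaries: the
writer drains whatever chunks are ready without buffering the whole response,
and resumption is event-driven via the continuation rather than polled.
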
+ */ + +package org.elasticsearch.http.netty4; + +import org.apache.logging.log4j.Level; +import org.apache.logging.log4j.core.LogEvent; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.ESNetty4IntegTestCase; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.ActionRunnable; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.ActionTestUtils; +import org.elasticsearch.action.support.CountDownActionListener; +import org.elasticsearch.action.support.SubscribableListener; +import org.elasticsearch.action.support.TransportAction; +import org.elasticsearch.client.Request; +import org.elasticsearch.client.Response; +import org.elasticsearch.client.ResponseListener; +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.node.DiscoveryNodes; +import org.elasticsearch.common.ReferenceDocs; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.ReleasableBytesReference; +import org.elasticsearch.common.collect.Iterators; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.io.Streams; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.RecyclerBytesStreamOutput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.logging.ChunkedLoggingStreamTestUtils; +import org.elasticsearch.common.recycler.Recycler; +import org.elasticsearch.common.settings.ClusterSettings; +import org.elasticsearch.common.settings.IndexScopedSettings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.settings.SettingsFilter; +import org.elasticsearch.common.util.CollectionUtils; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.core.AbstractRefCounted; +import org.elasticsearch.core.RefCounted; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.features.NodeFeature; +import org.elasticsearch.http.HttpBodyTracer; +import org.elasticsearch.http.HttpRouteStats; +import org.elasticsearch.http.HttpRouteStatsTracker; +import org.elasticsearch.http.HttpServerTransport; +import org.elasticsearch.plugins.ActionPlugin; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.ChunkedRestResponseBody; +import org.elasticsearch.rest.RestChannel; +import org.elasticsearch.rest.RestController; +import org.elasticsearch.rest.RestHandler; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestResponse; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.rest.action.RestActionListener; +import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskCancelledException; +import org.elasticsearch.test.MockLog; +import org.elasticsearch.test.junit.annotations.TestLogging; +import org.elasticsearch.threadpool.ThreadPool; +import 
org.elasticsearch.transport.TransportService; +import org.elasticsearch.transport.netty4.Netty4Utils; +import org.elasticsearch.xcontent.ToXContentObject; + +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.OutputStreamWriter; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.CancellationException; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.function.Predicate; +import java.util.function.Supplier; +import java.util.regex.Pattern; + +import static org.elasticsearch.rest.RestRequest.Method.GET; +import static org.elasticsearch.rest.RestResponse.TEXT_CONTENT_TYPE; +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; + +public class Netty4ChunkedContinuationsIT extends ESNetty4IntegTestCase { + + @Override + protected Collection> nodePlugins() { + return CollectionUtils.concatLists( + List.of(YieldsContinuationsPlugin.class, InfiniteContinuationsPlugin.class, CountDown3Plugin.class), + super.nodePlugins() + ); + } + + @Override + protected boolean addMockHttpTransport() { + return false; // enable http + } + + private static final String expectedBody = """ + batch-0-chunk-0 + batch-0-chunk-1 + batch-0-chunk-2 + batch-1-chunk-0 + batch-1-chunk-1 + batch-1-chunk-2 + batch-2-chunk-0 + batch-2-chunk-1 + batch-2-chunk-2 + """; + + public void testBasic() throws IOException { + try (var ignored = withResourceTracker()) { + final var response = getRestClient().performRequest(new Request("GET", YieldsContinuationsPlugin.ROUTE)); + assertEquals(200, response.getStatusLine().getStatusCode()); + assertThat(response.getEntity().getContentType().toString(), containsString(TEXT_CONTENT_TYPE)); + assertTrue(response.getEntity().isChunked()); + final String body; + try (var reader = new InputStreamReader(response.getEntity().getContent(), StandardCharsets.UTF_8)) { + body = Streams.copyToString(reader); + } + assertEquals(expectedBody, body); + } + } + + @TestLogging( + reason = "testing TRACE logging", + value = "org.elasticsearch.http.HttpTracer:TRACE,org.elasticsearch.http.HttpBodyTracer:TRACE" + ) + public void testTraceLogging() { + + // slightly awkward test, we can't use ChunkedLoggingStreamTestUtils.getDecodedLoggedBody directly because it asserts that we _only_ + // log one thing and we can't easily separate the request body from the response body logging, so instead we capture the body log + // message and then log it again with a different logger. 
+ final var resources = new ArrayList(); + try (var ignored = Releasables.wrap(resources)) { + resources.add(withResourceTracker()); + final var executor = EsExecutors.newFixed( + "test", + 1, + -1, + EsExecutors.daemonThreadFactory(Settings.EMPTY, "test"), + new ThreadContext(Settings.EMPTY), + EsExecutors.TaskTrackingConfig.DO_NOT_TRACK + ); + resources.add(() -> assertTrue(ThreadPool.terminate(executor, 10, TimeUnit.SECONDS))); + var loggingFinishedLatch = new CountDownLatch(1); + MockLog.assertThatLogger( + () -> assertEquals( + expectedBody, + ChunkedLoggingStreamTestUtils.getDecodedLoggedBody( + logger, + Level.INFO, + "response body", + ReferenceDocs.HTTP_TRACER, + () -> { + final var request = new Request("GET", YieldsContinuationsPlugin.ROUTE); + request.addParameter("error_trace", "true"); + getRestClient().performRequest(request); + safeAwait(loggingFinishedLatch); + } + ).utf8ToString() + ), + HttpBodyTracer.class, + new MockLog.LoggingExpectation() { + final Pattern messagePattern = Pattern.compile("^\\[[1-9][0-9]*] (response body.*)"); + + @Override + public void match(LogEvent event) { + final var formattedMessage = event.getMessage().getFormattedMessage(); + final var matcher = messagePattern.matcher(formattedMessage); + if (matcher.matches()) { + executor.execute(() -> { + logger.info("{}", matcher.group(1)); + if (formattedMessage.contains(ReferenceDocs.HTTP_TRACER.toString())) { + loggingFinishedLatch.countDown(); + } + }); + } + } + + @Override + public void assertMatched() {} + } + ); + } + } + + public void testResponseBodySizeStats() throws IOException { + try (var ignored = withResourceTracker()) { + final var totalResponseSizeBefore = getTotalResponseSize(); + getRestClient().performRequest(new Request("GET", YieldsContinuationsPlugin.ROUTE)); + final var totalResponseSizeAfter = getTotalResponseSize(); + assertEquals(expectedBody.length(), totalResponseSizeAfter - totalResponseSizeBefore); + } + } + + private static final HttpRouteStats EMPTY_ROUTE_STATS = new HttpRouteStatsTracker().getStats(); + + private long getTotalResponseSize() { + return client().admin() + .cluster() + .prepareNodesStats() + .clear() + .setHttp(true) + .get() + .getNodes() + .stream() + .mapToLong( + ns -> ns.getHttp().httpRouteStats().getOrDefault(YieldsContinuationsPlugin.ROUTE, EMPTY_ROUTE_STATS).totalResponseSize() + ) + .sum(); + } + + public void testPipelining() throws Exception { + try (var ignored = withResourceTracker(); var nettyClient = new Netty4HttpClient()) { + final var responses = nettyClient.get( + randomFrom(internalCluster().getInstance(HttpServerTransport.class).boundAddress().boundAddresses()).address(), + CountDown3Plugin.ROUTE, + YieldsContinuationsPlugin.ROUTE, + CountDown3Plugin.ROUTE, + YieldsContinuationsPlugin.ROUTE, + CountDown3Plugin.ROUTE + ); + + assertEquals("{}", Netty4Utils.toBytesReference(responses.get(0).content()).utf8ToString()); + assertEquals(expectedBody, Netty4Utils.toBytesReference(responses.get(1).content()).utf8ToString()); + assertEquals("{}", Netty4Utils.toBytesReference(responses.get(2).content()).utf8ToString()); + assertEquals(expectedBody, Netty4Utils.toBytesReference(responses.get(3).content()).utf8ToString()); + assertEquals("{}", Netty4Utils.toBytesReference(responses.get(4).content()).utf8ToString()); + } finally { + internalCluster().fullRestart(); // reset countdown listener + } + } + + public void testContinuationFailure() throws Exception { + try (var ignored = withResourceTracker(); var nettyClient = new Netty4HttpClient()) { + 
final var failIndex = between(0, 2); + final var responses = nettyClient.get( + randomFrom(internalCluster().getInstance(HttpServerTransport.class).boundAddress().boundAddresses()).address(), + YieldsContinuationsPlugin.ROUTE, + YieldsContinuationsPlugin.ROUTE + "?" + YieldsContinuationsPlugin.FAIL_INDEX_PARAM + "=" + failIndex + ); + + if (failIndex == 0) { + assertThat( + responses, + anyOf( + // might get a 500 response if the failure is early enough + hasSize(2), + // might get no response before channel closed + hasSize(1), + // might even close the channel before flushing the previous response + hasSize(0) + ) + ); + + if (responses.size() == 2) { + assertEquals(expectedBody, Netty4Utils.toBytesReference(responses.get(0).content()).utf8ToString()); + assertEquals(500, responses.get(1).status().code()); + } + } else { + assertThat(responses, hasSize(1)); + } + + if (responses.size() > 0) { + assertEquals(expectedBody, Netty4Utils.toBytesReference(responses.get(0).content()).utf8ToString()); + assertEquals(200, responses.get(0).status().code()); + } + } + } + + public void testClientCancellation() { + try (var ignored = withResourceTracker()) { + final var cancellable = getRestClient().performRequestAsync( + new Request("GET", InfiniteContinuationsPlugin.ROUTE), + new ResponseListener() { + @Override + public void onSuccess(Response response) { + fail("should not succeed"); + } + + @Override + public void onFailure(Exception exception) { + assertThat(exception, instanceOf(CancellationException.class)); + } + } + ); + if (randomBoolean()) { + safeSleep(scaledRandomIntBetween(10, 500)); // make it more likely the request started executing + } + cancellable.cancel(); + } // closing the request tracker ensures that everything is released, including all response chunks and the overall response + } + + private static Releasable withResourceTracker() { + assertNull(refs); + final var latch = new CountDownLatch(1); + refs = AbstractRefCounted.of(latch::countDown); + return () -> { + refs.decRef(); + try { + safeAwait(latch); + } finally { + refs = null; + } + }; + } + + private static volatile RefCounted refs = null; + + /** + * Adds a REST route which yields a sequence of continuations, each computed asynchronously, effectively pausing after each one.
+ */ + public static class YieldsContinuationsPlugin extends Plugin implements ActionPlugin { + static final String ROUTE = "/_test/yields_continuations"; + static final String FAIL_INDEX_PARAM = "fail_index"; + + private static final ActionType TYPE = new ActionType<>("test:yields_continuations"); + + @Override + public Collection> getActions() { + return List.of(new ActionHandler<>(TYPE, TransportYieldsContinuationsAction.class)); + } + + public static class Request extends ActionRequest { + final int failIndex; + + public Request(int failIndex) { + this.failIndex = failIndex; + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + } + + public static class Response extends ActionResponse { + private final int failIndex; + private final Executor executor; + + public Response(int failIndex, Executor executor) { + this.failIndex = failIndex; + this.executor = executor; + } + + @Override + public void writeTo(StreamOutput out) { + TransportAction.localOnly(); + } + + public ChunkedRestResponseBody getChunkedBody() { + return getChunkBatch(0); + } + + private ChunkedRestResponseBody getChunkBatch(int batchIndex) { + if (batchIndex == failIndex && randomBoolean()) { + throw new ElasticsearchException("simulated failure creating next batch"); + } + return new ChunkedRestResponseBody() { + + private final Iterator lines = Iterators.forRange(0, 3, i -> "batch-" + batchIndex + "-chunk-" + i + "\n"); + + @Override + public boolean isDone() { + return lines.hasNext() == false; + } + + @Override + public boolean isEndOfResponse() { + return batchIndex == 2; + } + + @Override + public void getContinuation(ActionListener listener) { + executor.execute(ActionRunnable.supply(listener, () -> getChunkBatch(batchIndex + 1))); + } + + @Override + public ReleasableBytesReference encodeChunk(int sizeHint, Recycler recycler) throws IOException { + assertTrue(lines.hasNext()); + refs.mustIncRef(); + final var output = new RecyclerBytesStreamOutput(recycler); + boolean success = false; + try { + try (var writer = new OutputStreamWriter(Streams.flushOnCloseStream(output), StandardCharsets.UTF_8)) { + writer.write(lines.next()); + } + final var result = new ReleasableBytesReference(output.bytes(), Releasables.wrap(output, refs::decRef)); + if (batchIndex == failIndex) { + throw new ElasticsearchException("simulated failure encoding chunk"); + } + success = true; + return result; + } finally { + if (success == false) { + refs.decRef(); + output.close(); + } + } + } + + @Override + public String getResponseContentTypeString() { + assertEquals(0, batchIndex); + return TEXT_CONTENT_TYPE; + } + }; + } + } + + public static class TransportYieldsContinuationsAction extends TransportAction { + private final ExecutorService executor; + + @Inject + public TransportYieldsContinuationsAction(ActionFilters actionFilters, TransportService transportService) { + super(TYPE.name(), actionFilters, transportService.getTaskManager()); + executor = transportService.getThreadPool().executor(ThreadPool.Names.GENERIC); + } + + @Override + protected void doExecute(Task task, Request request, ActionListener listener) { + executor.execute(ActionRunnable.supply(listener, () -> new Response(request.failIndex, executor))); + } + } + + @Override + public Collection getRestHandlers( + Settings settings, + NamedWriteableRegistry namedWriteableRegistry, + RestController restController, + ClusterSettings clusterSettings, + IndexScopedSettings indexScopedSettings, + SettingsFilter settingsFilter, + 
IndexNameExpressionResolver indexNameExpressionResolver, + Supplier nodesInCluster, + Predicate clusterSupportsFeature + ) { + return List.of(new BaseRestHandler() { + @Override + public String getName() { + return ROUTE; + } + + @Override + public List routes() { + return List.of(new Route(GET, ROUTE)); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) { + final var failIndex = request.paramAsInt(FAIL_INDEX_PARAM, Integer.MAX_VALUE); + refs.mustIncRef(); + return new RestChannelConsumer() { + + @Override + public void close() { + refs.decRef(); + } + + @Override + public void accept(RestChannel channel) { + refs.mustIncRef(); + client.execute(TYPE, new Request(failIndex), new RestActionListener<>(channel) { + @Override + protected void processResponse(Response response) { + try { + final var responseBody = response.getChunkedBody(); // might fail, so do this before acquiring ref + refs.mustIncRef(); + channel.sendResponse(RestResponse.chunked(RestStatus.OK, responseBody, refs::decRef)); + } finally { + refs.decRef(); + } + } + }); + } + }; + } + }); + } + } + + /** + * Adds a REST route which yields an infinite sequence of continuations which can only be stopped by the client closing the connection. + */ + public static class InfiniteContinuationsPlugin extends Plugin implements ActionPlugin { + static final String ROUTE = "/_test/infinite_continuations"; + + private static final ActionType TYPE = new ActionType<>("test:infinite_continuations"); + + @Override + public Collection> getActions() { + return List.of(new ActionHandler<>(TYPE, TransportInfiniteContinuationsAction.class)); + } + + public static class Request extends ActionRequest { + @Override + public ActionRequestValidationException validate() { + return null; + } + } + + public static class Response extends ActionResponse { + private final Executor executor; + volatile boolean computingContinuation; + + public Response(Executor executor) { + this.executor = executor; + } + + @Override + public void writeTo(StreamOutput out) { + TransportAction.localOnly(); + } + + public ChunkedRestResponseBody getChunkedBody() { + return new ChunkedRestResponseBody() { + private final Iterator lines = Iterators.single("infinite response\n"); + + @Override + public boolean isDone() { + return lines.hasNext() == false; + } + + @Override + public boolean isEndOfResponse() { + return false; + } + + @Override + public void getContinuation(ActionListener listener) { + computingContinuation = true; + executor.execute(ActionRunnable.supply(listener, () -> { + computingContinuation = false; + return getChunkedBody(); + })); + } + + @Override + public ReleasableBytesReference encodeChunk(int sizeHint, Recycler recycler) { + assertTrue(lines.hasNext()); + refs.mustIncRef(); + return new ReleasableBytesReference(new BytesArray(lines.next()), refs::decRef); + } + + @Override + public String getResponseContentTypeString() { + return TEXT_CONTENT_TYPE; + } + }; + } + } + + public static class TransportInfiniteContinuationsAction extends TransportAction { + private final ExecutorService executor; + + @Inject + public TransportInfiniteContinuationsAction(ActionFilters actionFilters, TransportService transportService) { + super(TYPE.name(), actionFilters, transportService.getTaskManager()); + this.executor = transportService.getThreadPool().executor(ThreadPool.Names.GENERIC); + } + + @Override + protected void doExecute(Task task, Request request, ActionListener listener) { + executor.execute( + 
ActionRunnable.supply(ActionTestUtils.assertNoFailureListener(listener::onResponse), () -> new Response(executor)) + ); + } + } + + @Override + public Collection getRestHandlers( + Settings settings, + NamedWriteableRegistry namedWriteableRegistry, + RestController restController, + ClusterSettings clusterSettings, + IndexScopedSettings indexScopedSettings, + SettingsFilter settingsFilter, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier nodesInCluster, + Predicate clusterSupportsFeature + ) { + return List.of(new BaseRestHandler() { + @Override + public String getName() { + return ROUTE; + } + + @Override + public List routes() { + return List.of(new Route(GET, ROUTE)); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) { + final var localRefs = refs; // single volatile read + if (localRefs != null && localRefs.tryIncRef()) { + return new RestChannelConsumer() { + @Override + public void close() { + localRefs.decRef(); + } + + @Override + public void accept(RestChannel channel) { + localRefs.mustIncRef(); + client.execute(TYPE, new Request(), new RestActionListener<>(channel) { + @Override + protected void processResponse(Response response) { + channel.sendResponse(RestResponse.chunked(RestStatus.OK, response.getChunkedBody(), () -> { + // cancellation notification only happens while processing a continuation, not while computing + // the next one; prompt cancellation requires use of something like RestCancellableNodeClient + assertFalse(response.computingContinuation); + assertSame(localRefs, refs); + localRefs.decRef(); + })); + } + }); + } + }; + } else { + throw new TaskCancelledException("request cancelled"); + } + } + }); + } + } + + /** + * Adds an HTTP route that waits for 3 concurrent executions before returning any of them + */ + public static class CountDown3Plugin extends Plugin implements ActionPlugin { + + static final String ROUTE = "/_test/countdown_3"; + + @Override + public Collection getRestHandlers( + Settings settings, + NamedWriteableRegistry namedWriteableRegistry, + RestController restController, + ClusterSettings clusterSettings, + IndexScopedSettings indexScopedSettings, + SettingsFilter settingsFilter, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier nodesInCluster, + Predicate clusterSupportsFeature + ) { + return List.of(new BaseRestHandler() { + private final SubscribableListener subscribableListener = new SubscribableListener<>(); + private final CountDownActionListener countDownActionListener = new CountDownActionListener( + 3, + subscribableListener.map(v -> EMPTY_RESPONSE) + ); + + private void addListener(ActionListener listener) { + subscribableListener.addListener(listener); + countDownActionListener.onResponse(null); + } + + @Override + public String getName() { + return ROUTE; + } + + @Override + public List routes() { + return List.of(new Route(GET, ROUTE)); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) { + refs.mustIncRef(); + return new RestChannelConsumer() { + + @Override + public void close() { + refs.decRef(); + } + + @Override + public void accept(RestChannel channel) { + refs.mustIncRef(); + addListener(ActionListener.releaseAfter(new RestToXContentListener<>(channel), refs::decRef)); + } + }; + } + }); + } + } + + private static final ToXContentObject EMPTY_RESPONSE = (builder, params) -> builder.startObject().endObject(); +} diff --git 
a/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4ChunkedEncodingIT.java b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4ChunkedEncodingIT.java index 2f472dab23afa..b2a54e2027308 100644 --- a/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4ChunkedEncodingIT.java +++ b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4ChunkedEncodingIT.java @@ -10,6 +10,7 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.ESNetty4IntegTestCase; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.client.Request; import org.elasticsearch.client.Response; import org.elasticsearch.client.ResponseListener; @@ -250,6 +251,16 @@ public boolean isDone() { return chunkIterator.hasNext() == false; } + @Override + public boolean isEndOfResponse() { + return true; + } + + @Override + public void getContinuation(ActionListener listener) { + assert false : "no continuations"; + } + @Override public ReleasableBytesReference encodeChunk(int sizeHint, Recycler recycler) { localRefs.mustIncRef(); diff --git a/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4PipeliningIT.java b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4PipeliningIT.java index 130a1168d455c..ce8da0c08af54 100644 --- a/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4PipeliningIT.java +++ b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4PipeliningIT.java @@ -251,6 +251,16 @@ public boolean isDone() { return false; } + @Override + public boolean isEndOfResponse() { + return true; + } + + @Override + public void getContinuation(ActionListener listener) { + fail("no continuations here"); + } + @Override public ReleasableBytesReference encodeChunk(int sizeHint, Recycler recycler) throws IOException { assert bytesRemaining >= 0 : "already failed"; diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4ChunkedHttpContinuation.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4ChunkedHttpContinuation.java new file mode 100644 index 0000000000000..156f1c27aa67c --- /dev/null +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4ChunkedHttpContinuation.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. 
+ */ + +package org.elasticsearch.http.netty4; + +import io.netty.util.concurrent.PromiseCombiner; + +import org.elasticsearch.rest.ChunkedRestResponseBody; + +final class Netty4ChunkedHttpContinuation implements Netty4HttpResponse { + private final int sequence; + private final ChunkedRestResponseBody body; + private final PromiseCombiner combiner; + + Netty4ChunkedHttpContinuation(int sequence, ChunkedRestResponseBody body, PromiseCombiner combiner) { + this.sequence = sequence; + this.body = body; + this.combiner = combiner; + } + + @Override + public int getSequence() { + return sequence; + } + + public ChunkedRestResponseBody body() { + return body; + } + + public PromiseCombiner combiner() { + return combiner; + } +} diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4ChunkedHttpResponse.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4ChunkedHttpResponse.java index f5f32bf333779..783c02da0bbcc 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4ChunkedHttpResponse.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4ChunkedHttpResponse.java @@ -19,7 +19,7 @@ /** * An HTTP response that will be transferred via chunked encoding when handled by {@link Netty4HttpPipeliningHandler}. */ -public final class Netty4ChunkedHttpResponse extends DefaultHttpResponse implements Netty4HttpResponse, HttpResponse { +final class Netty4ChunkedHttpResponse extends DefaultHttpResponse implements Netty4HttpResponse, HttpResponse { private final int sequence; diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandler.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandler.java index b86e168e2e620..8280c438613a2 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandler.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandler.java @@ -28,6 +28,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.ReleasableBytesReference; import org.elasticsearch.core.Booleans; @@ -148,6 +149,8 @@ public void write(final ChannelHandlerContext ctx, final Object msg, final Chann } private void enqueuePipelinedResponse(ChannelHandlerContext ctx, Netty4HttpResponse restResponse, ChannelPromise promise) { + assert restResponse instanceof Netty4ChunkedHttpContinuation == false + : "received out-of-order continuation at [" + restResponse.getSequence() + "], expecting [" + writeSequence + "]"; assert restResponse.getSequence() > writeSequence : "response sequence [" + restResponse.getSequence() + "] is below write sequence [" + writeSequence + "]"; if (outboundHoldingQueue.size() >= maxEventsHeld) { @@ -187,6 +190,8 @@ private void doWrite(ChannelHandlerContext ctx, Netty4HttpResponse readyResponse doWriteFullResponse(ctx, fullResponse, promise); } else if (readyResponse instanceof Netty4ChunkedHttpResponse chunkedResponse) { doWriteChunkedResponse(ctx, chunkedResponse, promise); + } else if (readyResponse instanceof Netty4ChunkedHttpContinuation chunkedContinuation) { + doWriteChunkedContinuation(ctx, chunkedContinuation, promise); } else { assert false :
readyResponse.getClass().getCanonicalName(); throw new IllegalStateException("illegal message type: " + readyResponse.getClass().getCanonicalName()); @@ -224,16 +229,75 @@ private void doWriteChunkedResponse(ChannelHandlerContext ctx, Netty4ChunkedHttp } } + private void doWriteChunkedContinuation(ChannelHandlerContext ctx, Netty4ChunkedHttpContinuation continuation, ChannelPromise promise) { + final PromiseCombiner combiner = continuation.combiner(); + assert currentChunkedWrite == null; + final var responseBody = continuation.body(); + assert responseBody.isDone() == false : "response with continuations must have at least one (possibly-empty) chunk in each part"; + currentChunkedWrite = new ChunkedWrite(combiner, promise, responseBody); + // NB "writable" means there's space in the downstream ChannelOutboundBuffer, we aren't trying to saturate the physical channel. + while (ctx.channel().isWritable()) { + if (writeChunk(ctx, currentChunkedWrite)) { + finishChunkedWrite(); + return; + } + } + } + private void finishChunkedWrite() { if (currentChunkedWrite == null) { // failure during chunked response serialization, we're closing the channel return; } - assert currentChunkedWrite.responseBody().isDone(); final var finishingWrite = currentChunkedWrite; currentChunkedWrite = null; - writeSequence++; - finishingWrite.combiner().finish(finishingWrite.onDone()); + final var finishingWriteBody = finishingWrite.responseBody(); + assert finishingWriteBody.isDone(); + final var endOfResponse = finishingWriteBody.isEndOfResponse(); + if (endOfResponse) { + writeSequence++; + finishingWrite.combiner().finish(finishingWrite.onDone()); + } else { + final var channel = finishingWrite.onDone().channel(); + ActionListener.run(ActionListener.assertOnce(new ActionListener<>() { + @Override + public void onResponse(ChunkedRestResponseBody continuation) { + channel.writeAndFlush( + new Netty4ChunkedHttpContinuation(writeSequence, continuation, finishingWrite.combiner()), + finishingWrite.onDone() // pass the terminal listener/promise along the line + ); + checkShutdown(); + } + + @Override + public void onFailure(Exception e) { + logger.error( + Strings.format("failed to get continuation of HTTP response body for [%s], closing connection", channel), + e + ); + channel.close().addListener(ignored -> { + finishingWrite.combiner().add(channel.newFailedFuture(e)); + finishingWrite.combiner().finish(finishingWrite.onDone()); + }); + checkShutdown(); + } + + private void checkShutdown() { + if (channel.eventLoop().isShuttingDown()) { + // The event loop is shutting down, and https://github.com/netty/netty/issues/8007 means that we cannot know if the + // preceding activity made it onto its queue before shutdown or whether it will just vanish without a trace, so + // to avoid a leak we must double-check that the final listener is completed once the event loop is terminated. + // Note that the final listener came from Netty4Utils#safeWriteAndFlush so its executor is an ImmediateEventExecutor + // which means this completion is not subject to the same issue, it still works even if the event loop has already + // terminated. 
+ channel.eventLoop() + .terminationFuture() + .addListener(ignored -> finishingWrite.onDone().tryFailure(new ClosedChannelException())); + } + } + + }), finishingWriteBody::getContinuation); + } } private void splitAndWrite(ChannelHandlerContext ctx, Netty4FullHttpResponse msg, ChannelPromise promise) { @@ -321,7 +385,8 @@ private boolean writeChunk(ChannelHandlerContext ctx, ChunkedWrite chunkedWrite) } final ByteBuf content = Netty4Utils.toByteBuf(bytes); final boolean done = body.isDone(); - final ChannelFuture f = ctx.write(done ? new DefaultLastHttpContent(content) : new DefaultHttpContent(content)); + final boolean lastChunk = done && body.isEndOfResponse(); + final ChannelFuture f = ctx.write(lastChunk ? new DefaultLastHttpContent(content) : new DefaultHttpContent(content)); f.addListener(ignored -> bytes.close()); combiner.add(f); return done; diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpResponse.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpResponse.java index 3396b13cdab0f..80cf3469c00ca 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpResponse.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpResponse.java @@ -11,7 +11,7 @@ /** * Super-interface for responses handled by the Netty4 HTTP transport. */ -public sealed interface Netty4HttpResponse permits Netty4FullHttpResponse, Netty4ChunkedHttpResponse { +sealed interface Netty4HttpResponse permits Netty4FullHttpResponse, Netty4ChunkedHttpResponse, Netty4ChunkedHttpContinuation { /** * @return The sequence number for the request which corresponds with this response, for making sure that we send responses to pipelined * requests in the correct order. 
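For readers tracing the new control flow, here is a minimal consumer-side sketch of the continuation protocol that the handler above implements. It is illustrative only and not part of the patch: ContinuationProtocolSketch, ChunkWriter and drain() are hypothetical names, standing in for the chunk-writing loop, finishChunkedWrite() and the Netty4ChunkedHttpContinuation plumbing in Netty4HttpPipeliningHandler.

    // Illustrative sketch only; assumed names: ContinuationProtocolSketch, ChunkWriter, drain.
    import org.apache.lucene.util.BytesRef;
    import org.elasticsearch.action.ActionListener;
    import org.elasticsearch.common.bytes.ReleasableBytesReference;
    import org.elasticsearch.common.recycler.Recycler;
    import org.elasticsearch.rest.ChunkedRestResponseBody;

    import java.io.IOException;

    final class ContinuationProtocolSketch {

        /** Stand-in for emitting DefaultHttpContent / DefaultLastHttpContent frames. */
        interface ChunkWriter {
            void write(ReleasableBytesReference chunk, boolean lastChunkOfResponse); // takes ownership of the chunk
        }

        static void drain(ChunkedRestResponseBody body, ChunkWriter writer, Recycler<BytesRef> recycler) throws IOException {
            // Phase 1: write every synchronously-available chunk of the current part.
            while (body.isDone() == false) {
                ReleasableBytesReference chunk = body.encodeChunk(8192, recycler);
                writer.write(chunk, body.isDone() && body.isEndOfResponse());
            }
            // Phase 2: if this part is not the end of the response, transmission pauses
            // here and resumes only when the continuation arrives asynchronously.
            if (body.isEndOfResponse() == false) {
                body.getContinuation(ActionListener.wrap(next -> drain(next, writer, recycler), e -> {
                    // the real handler logs the failure and closes the channel at this point
                }));
            }
        }
    }

The key invariant matches finishChunkedWrite() above: the write sequence only advances once isEndOfResponse() returns true, which is how pipelined responses stay ordered even while a response is paused between parts.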
diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandlerTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandlerTests.java index 9e0f30caec755..bb4a0939c98f0 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandlerTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandlerTests.java @@ -28,6 +28,7 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.Randomness; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; @@ -511,6 +512,16 @@ public boolean isDone() { return remaining == 0; } + @Override + public boolean isEndOfResponse() { + return true; + } + + @Override + public void getContinuation(ActionListener listener) { + fail("no continuations here"); + } + @Override public ReleasableBytesReference encodeChunk(int sizeHint, Recycler recycler) { assertThat(remaining, greaterThan(0)); diff --git a/server/src/main/java/org/elasticsearch/http/HttpBodyTracer.java b/server/src/main/java/org/elasticsearch/http/HttpBodyTracer.java index 1773a4803f62a..1dd2868f7bfa6 100644 --- a/server/src/main/java/org/elasticsearch/http/HttpBodyTracer.java +++ b/server/src/main/java/org/elasticsearch/http/HttpBodyTracer.java @@ -17,7 +17,7 @@ import java.io.OutputStream; -class HttpBodyTracer { +public class HttpBodyTracer { private static final Logger logger = LogManager.getLogger(HttpBodyTracer.class); public static boolean isEnabled() { diff --git a/server/src/main/java/org/elasticsearch/rest/ChunkedRestResponseBody.java b/server/src/main/java/org/elasticsearch/rest/ChunkedRestResponseBody.java index 5c41be0fc9f9f..2f7fc458ca020 100644 --- a/server/src/main/java/org/elasticsearch/rest/ChunkedRestResponseBody.java +++ b/server/src/main/java/org/elasticsearch/rest/ChunkedRestResponseBody.java @@ -8,6 +8,8 @@ package org.elasticsearch.rest; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.bytes.ReleasableBytesReference; import org.elasticsearch.common.io.stream.BytesStream; import org.elasticsearch.common.io.stream.RecyclerBytesStreamOutput; @@ -20,6 +22,8 @@ import org.elasticsearch.core.Streams; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; +import org.elasticsearch.tasks.CancellableTask; +import org.elasticsearch.tasks.Task; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentBuilder; @@ -31,18 +35,51 @@ import java.util.Iterator; /** - * The body of a rest response that uses chunked HTTP encoding. Implementations are used to avoid materializing full responses on heap and - * instead serialize only as much of the response as can be flushed to the network right away. + *
A body (or a part thereof) of an HTTP response that uses the {@code chunked} transfer-encoding. This allows Elasticsearch to avoid + * materializing the full response into on-heap buffers up front, instead serializing only as much of the response as can be flushed to the + * network right away. + * + * Each {@link ChunkedRestResponseBody} represents a sequence of chunks that are ready for immediate transmission: if {@link #isDone} + * returns {@code false} then {@link #encodeChunk} can be called at any time and must synchronously return the next chunk to be sent. + * Many HTTP responses will be a single such sequence. However, if an implementation's {@link #isEndOfResponse} returns {@code false} at the + * end of the sequence then the transmission is paused and {@link #getContinuation} is called to compute the next sequence of chunks + * asynchronously.
*/ public interface ChunkedRestResponseBody { Logger logger = LogManager.getLogger(ChunkedRestResponseBody.class); /** - * @return true once this response has been written fully. + * @return {@code true} if this body contains no more chunks and the REST layer should check for a possible continuation by calling + * {@link #isEndOfResponse}, or {@code false} if the REST layer should request another chunk from this body using {@link #encodeChunk}. */ boolean isDone(); + /** + * @return {@code true} if this is the last chunked body in the response, or {@code false} if the REST layer should request further + * chunked bodies by calling {@link #getContinuation}. + */ + boolean isEndOfResponse(); + + /** + *
Asynchronously retrieves the next part of the body. Called if {@link #isEndOfResponse} returns {@code false}. + * + * Note that this is called on a transport thread, so implementations must take care to dispatch any nontrivial work elsewhere. + * + * Note that the {@link Task} corresponding to any invocation of {@link Client#execute} completes as soon as the client action + * returns its response, so it no longer exists when this method is called and cannot be used to receive cancellation notifications. + * Instead, if the HTTP channel is closed while sending a response then the REST layer will invoke {@link RestResponse#close}. If the + * HTTP channel is closed while the REST layer is waiting for a continuation then the {@link RestResponse} will not be closed until the + * continuation listener is completed. Implementations will typically explicitly create a {@link CancellableTask} to represent the + * computation and transmission of the entire {@link RestResponse}, and will cancel this task if the {@link RestResponse} is closed + * prematurely.
+ * + * @param listener Listener to complete with the next part of the body. By the time this is called we have already started to send + * the body of the response, so there's no good way to handle an exception here. Completing the listener exceptionally + * will log an error, abort sending the response, and close the HTTP connection. + */ + void getContinuation(ActionListener<ChunkedRestResponseBody> listener); + /** * Serializes approximately as many bytes of the response as requested by {@code sizeHint} to a {@link ReleasableBytesReference} that * is created from buffers backed by the given {@code recycler}. @@ -102,6 +139,17 @@ public boolean isDone() { return serialization.hasNext() == false; } + @Override + public boolean isEndOfResponse() { + return true; + } + + @Override + public void getContinuation(ActionListener<ChunkedRestResponseBody> listener) { + assert false : "no continuations"; + listener.onFailure(new IllegalStateException("no continuations available")); + } + @Override public ReleasableBytesReference encodeChunk(int sizeHint, Recycler<BytesRef> recycler) throws IOException { try { @@ -180,6 +228,17 @@ public boolean isDone() { return chunkIterator.hasNext() == false; } + @Override + public boolean isEndOfResponse() { + return true; + } + + @Override + public void getContinuation(ActionListener<ChunkedRestResponseBody> listener) { + assert false : "no continuations"; + listener.onFailure(new IllegalStateException("no continuations available")); + } + @Override public ReleasableBytesReference encodeChunk(int sizeHint, Recycler<BytesRef> recycler) throws IOException { try { diff --git a/server/src/main/java/org/elasticsearch/rest/LoggingChunkedRestResponseBody.java b/server/src/main/java/org/elasticsearch/rest/LoggingChunkedRestResponseBody.java index 0508828c70da1..865f433e25aa4 100644 --- a/server/src/main/java/org/elasticsearch/rest/LoggingChunkedRestResponseBody.java +++ b/server/src/main/java/org/elasticsearch/rest/LoggingChunkedRestResponseBody.java @@ -9,6 +9,7 @@ package org.elasticsearch.rest; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.bytes.ReleasableBytesReference; import org.elasticsearch.common.recycler.Recycler; @@ -30,6 +31,16 @@ public boolean isDone() { return inner.isDone(); } + @Override + public boolean isEndOfResponse() { + return inner.isEndOfResponse(); + } + + @Override + public void getContinuation(ActionListener<ChunkedRestResponseBody> listener) { + inner.getContinuation(listener.map(continuation -> new LoggingChunkedRestResponseBody(continuation, loggerStream))); + } + @Override public ReleasableBytesReference encodeChunk(int sizeHint, Recycler<BytesRef> recycler) throws IOException { var chunk = inner.encodeChunk(sizeHint, recycler); diff --git a/server/src/main/java/org/elasticsearch/rest/RestController.java b/server/src/main/java/org/elasticsearch/rest/RestController.java index 16813f1141e12..0c08520a5dd0b 100644 --- a/server/src/main/java/org/elasticsearch/rest/RestController.java +++ b/server/src/main/java/org/elasticsearch/rest/RestController.java @@ -934,11 +934,23 @@ public boolean isDone() { return delegate.isDone(); } + @Override + public boolean isEndOfResponse() { + return delegate.isEndOfResponse(); + } + + @Override + public void getContinuation(ActionListener<ChunkedRestResponseBody> listener) { + delegate.getContinuation( + listener.map(continuation -> new EncodedLengthTrackingChunkedRestResponseBody(continuation, responseLengthRecorder)) + ); + } + @Override public ReleasableBytesReference encodeChunk(int sizeHint, Recycler<BytesRef> recycler) throws IOException { final ReleasableBytesReference bytesReference =
delegate.encodeChunk(sizeHint, recycler); responseLengthRecorder.addChunkLength(bytesReference.length()); - if (isDone()) { + if (isDone() && isEndOfResponse()) { responseLengthRecorder.close(); } return bytesReference; diff --git a/server/src/main/java/org/elasticsearch/rest/RestResponse.java b/server/src/main/java/org/elasticsearch/rest/RestResponse.java index a4a44a5a65561..9862ab31bd53f 100644 --- a/server/src/main/java/org/elasticsearch/rest/RestResponse.java +++ b/server/src/main/java/org/elasticsearch/rest/RestResponse.java @@ -86,6 +86,7 @@ private RestResponse(RestStatus status, String responseMediaType, BytesReference public static RestResponse chunked(RestStatus restStatus, ChunkedRestResponseBody content, @Nullable Releasable releasable) { if (content.isDone()) { + assert content.isEndOfResponse() : "response with continuations must have at least one (possibly-empty) chunk in each part"; return new RestResponse(restStatus, content.getResponseContentTypeString(), BytesArray.EMPTY, releasable); } else { return new RestResponse(restStatus, content.getResponseContentTypeString(), null, content, releasable); diff --git a/server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java b/server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java index 82eb88a90873f..f12d8ea5c631a 100644 --- a/server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java +++ b/server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java @@ -12,6 +12,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.lucene.util.BytesRef; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.action.support.SubscribableListener; import org.elasticsearch.common.ReferenceDocs; import org.elasticsearch.common.bytes.BytesArray; @@ -51,9 +52,11 @@ import org.mockito.ArgumentCaptor; import java.io.IOException; +import java.io.OutputStream; import java.net.InetSocketAddress; import java.nio.channels.ClosedChannelException; import java.nio.charset.StandardCharsets; +import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; @@ -534,6 +537,16 @@ public boolean isDone() { return false; } + @Override + public boolean isEndOfResponse() { + throw new AssertionError("should not check for end-of-response for HEAD request"); + } + + @Override + public void getContinuation(ActionListener listener) { + throw new AssertionError("should not get any continuations for HEAD request"); + } + @Override public ReleasableBytesReference encodeChunk(int sizeHint, Recycler recycler) { throw new AssertionError("should not try to serialize response body for HEAD request"); @@ -677,16 +690,24 @@ public void testResponseBodyTracing() { @Override public HttpResponse createResponse(RestStatus status, ChunkedRestResponseBody content) { try (var bso = new BytesStreamOutput()) { - while (content.isDone() == false) { - try (var bytes = content.encodeChunk(1 << 14, BytesRefRecycler.NON_RECYCLING_INSTANCE)) { - bytes.writeTo(bso); - } - } + writeContent(bso, content); return new TestHttpResponse(status, bso.bytes()); } catch (IOException e) { return fail(e); } } + + private static void writeContent(OutputStream bso, ChunkedRestResponseBody content) throws IOException { + while (content.isDone() == false) { + try (var bytes = content.encodeChunk(1 << 14, BytesRefRecycler.NON_RECYCLING_INSTANCE)) { + bytes.writeTo(bso); + } + } + if (content.isEndOfResponse()) { + return; + } + 
writeContent(bso, PlainActionFuture.get(content::getContinuation)); + } }; final RestRequest request = RestRequest.request(parserConfig(), httpRequest, httpChannel); @@ -714,7 +735,58 @@ public HttpResponse createResponse(RestStatus status, ChunkedRestResponseBody co ) ); + final var parts = new ArrayList(); + class TestBody implements ChunkedRestResponseBody { + boolean isDone; + final BytesReference thisChunk; + final BytesReference remainingChunks; + final int remainingContinuations; + + TestBody(BytesReference content, int remainingContinuations) { + if (remainingContinuations == 0) { + thisChunk = content; + remainingChunks = BytesArray.EMPTY; + } else { + var splitAt = between(0, content.length()); + thisChunk = content.slice(0, splitAt); + remainingChunks = content.slice(splitAt, content.length() - splitAt); + } + this.remainingContinuations = remainingContinuations; + } + + @Override + public boolean isDone() { + return isDone; + } + + @Override + public boolean isEndOfResponse() { + return remainingContinuations == 0; + } + + @Override + public void getContinuation(ActionListener listener) { + final var continuation = new TestBody(remainingChunks, remainingContinuations - 1); + parts.add(continuation); + listener.onResponse(continuation); + } + + @Override + public ReleasableBytesReference encodeChunk(int sizeHint, Recycler recycler) { + assertFalse(isDone); + isDone = true; + return ReleasableBytesReference.wrap(thisChunk); + } + + @Override + public String getResponseContentTypeString() { + return RestResponse.TEXT_CONTENT_TYPE; + } + } + final var isClosed = new AtomicBoolean(); + final var firstPart = new TestBody(responseBody, between(0, 3)); + parts.add(firstPart); assertEquals( responseBody, ChunkedLoggingStreamTestUtils.getDecodedLoggedBody( @@ -722,27 +794,13 @@ public HttpResponse createResponse(RestStatus status, ChunkedRestResponseBody co Level.TRACE, "[" + request.getRequestId() + "] response body", ReferenceDocs.HTTP_TRACER, - () -> channel.sendResponse(RestResponse.chunked(RestStatus.OK, new ChunkedRestResponseBody() { - - boolean isDone; - - @Override - public boolean isDone() { - return isDone; - } - - @Override - public ReleasableBytesReference encodeChunk(int sizeHint, Recycler recycler) { - assertFalse(isDone); - isDone = true; - return ReleasableBytesReference.wrap(responseBody); - } - - @Override - public String getResponseContentTypeString() { - return RestResponse.TEXT_CONTENT_TYPE; + () -> channel.sendResponse(RestResponse.chunked(RestStatus.OK, firstPart, () -> { + assertTrue(isClosed.compareAndSet(false, true)); + for (int i = 0; i < parts.size(); i++) { + assertTrue("isDone " + i, parts.get(i).isDone()); + assertEquals("isEndOfResponse " + i, i == parts.size() - 1, parts.get(i).isEndOfResponse()); } - }, () -> assertTrue(isClosed.compareAndSet(false, true)))) + })) ) ); diff --git a/test/framework/src/main/java/org/elasticsearch/rest/RestResponseUtils.java b/test/framework/src/main/java/org/elasticsearch/rest/RestResponseUtils.java index dfbd7266cc4a2..1b1331fe25bbf 100644 --- a/test/framework/src/main/java/org/elasticsearch/rest/RestResponseUtils.java +++ b/test/framework/src/main/java/org/elasticsearch/rest/RestResponseUtils.java @@ -42,6 +42,7 @@ public static BytesReference getBodyContent(RestResponse restResponse) { chunk.writeTo(out); } } + assert chunkedRestResponseBody.isEndOfResponse() : "RestResponseUtils#getBodyContent does not support continuations (yet)"; out.flush(); return out.bytes(); From c6760ddcf3a568618be9fdf6f1f95b7038da5721 Mon Sep 
17 00:00:00 2001 From: =?UTF-8?q?Johannes=20Fred=C3=A9n?= <109296772+jfreden@users.noreply.github.com> Date: Mon, 27 May 2024 10:27:50 +0200 Subject: [PATCH 07/29] Remove awaits fix from dls tests (#109055) --- .../org/elasticsearch/integration/DlsFlsRequestCacheTests.java | 2 -- 1 file changed, 2 deletions(-) diff --git a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/DlsFlsRequestCacheTests.java b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/DlsFlsRequestCacheTests.java index f83aeb117b7b8..3fbcd00690e82 100644 --- a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/DlsFlsRequestCacheTests.java +++ b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/DlsFlsRequestCacheTests.java @@ -252,7 +252,6 @@ public void testRequestCacheForFLS() { assertCacheState(FLS_INDEX, 2, 4); } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/109010") public void testRequestCacheForBothDLSandFLS() throws ExecutionException, InterruptedException { final Client powerClient = client(); final Client limitedClient = limitedClient(); @@ -316,7 +315,6 @@ public void testRequestCacheForBothDLSandFLS() throws ExecutionException, Interr assertCacheState(INDEX, 2, 5); } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/109011") public void testRequestCacheWithTemplateRoleQuery() { final Client client1 = client().filterWithHeader( Map.of("Authorization", basicAuthHeaderValue(DLS_TEMPLATE_ROLE_QUERY_USER_1, new SecureString(TEST_PASSWORD.toCharArray()))) From a5b1848c14ab387a0f222ecbf59908a06afcecd0 Mon Sep 17 00:00:00 2001 From: Luigi Dell'Aquila Date: Mon, 27 May 2024 10:36:06 +0200 Subject: [PATCH 08/29] ES|QL: more tests for coalesce() function (#109032) Adding more unit tests for `coalesce()` function, in particular adding tests for `ip`, `date` and spatial data types. This also generates the right signatures for Kibana. Related to https://github.com/elastic/elasticsearch/issues/108982 --- .../functions/kibana/definition/coalesce.json | 108 ++++++++++++++++++ .../esql/functions/types/coalesce.asciidoc | 6 + .../src/main/resources/meta.csv-spec | 6 +- .../function/scalar/nulls/Coalesce.java | 39 ++++++- .../function/scalar/nulls/CoalesceTests.java | 60 +++++++++- .../SpatialRelatesFunctionTestCase.java | 4 +- 6 files changed, 214 insertions(+), 9 deletions(-) diff --git a/docs/reference/esql/functions/kibana/definition/coalesce.json b/docs/reference/esql/functions/kibana/definition/coalesce.json index 1081b42839577..d9659fa03e809 100644 --- a/docs/reference/esql/functions/kibana/definition/coalesce.json +++ b/docs/reference/esql/functions/kibana/definition/coalesce.json @@ -34,6 +34,96 @@ "variadic" : true, "returnType" : "boolean" }, + { + "params" : [ + { + "name" : "first", + "type" : "cartesian_point", + "optional" : false, + "description" : "Expression to evaluate." + }, + { + "name" : "rest", + "type" : "cartesian_point", + "optional" : true, + "description" : "Other expression to evaluate." + } + ], + "variadic" : true, + "returnType" : "cartesian_point" + }, + { + "params" : [ + { + "name" : "first", + "type" : "cartesian_shape", + "optional" : false, + "description" : "Expression to evaluate." + }, + { + "name" : "rest", + "type" : "cartesian_shape", + "optional" : true, + "description" : "Other expression to evaluate." 
+ } + ], + "variadic" : true, + "returnType" : "cartesian_shape" + }, + { + "params" : [ + { + "name" : "first", + "type" : "datetime", + "optional" : false, + "description" : "Expression to evaluate." + }, + { + "name" : "rest", + "type" : "datetime", + "optional" : true, + "description" : "Other expression to evaluate." + } + ], + "variadic" : true, + "returnType" : "datetime" + }, + { + "params" : [ + { + "name" : "first", + "type" : "geo_point", + "optional" : false, + "description" : "Expression to evaluate." + }, + { + "name" : "rest", + "type" : "geo_point", + "optional" : true, + "description" : "Other expression to evaluate." + } + ], + "variadic" : true, + "returnType" : "geo_point" + }, + { + "params" : [ + { + "name" : "first", + "type" : "geo_shape", + "optional" : false, + "description" : "Expression to evaluate." + }, + { + "name" : "rest", + "type" : "geo_shape", + "optional" : true, + "description" : "Other expression to evaluate." + } + ], + "variadic" : true, + "returnType" : "geo_shape" + }, { "params" : [ { @@ -64,6 +154,24 @@ "variadic" : true, "returnType" : "integer" }, + { + "params" : [ + { + "name" : "first", + "type" : "ip", + "optional" : false, + "description" : "Expression to evaluate." + }, + { + "name" : "rest", + "type" : "ip", + "optional" : true, + "description" : "Other expression to evaluate." + } + ], + "variadic" : true, + "returnType" : "ip" + }, { "params" : [ { diff --git a/docs/reference/esql/functions/types/coalesce.asciidoc b/docs/reference/esql/functions/types/coalesce.asciidoc index e7d513f2aad86..a5d8f85aa564e 100644 --- a/docs/reference/esql/functions/types/coalesce.asciidoc +++ b/docs/reference/esql/functions/types/coalesce.asciidoc @@ -7,8 +7,14 @@ first | rest | result boolean | boolean | boolean boolean | | boolean +cartesian_point | cartesian_point | cartesian_point +cartesian_shape | cartesian_shape | cartesian_shape +datetime | datetime | datetime +geo_point | geo_point | geo_point +geo_shape | geo_shape | geo_shape integer | integer | integer integer | | integer +ip | ip | ip keyword | keyword | keyword keyword | | keyword long | long | long diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec index c824f837bf249..f68dd15e9c516 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec @@ -14,7 +14,7 @@ synopsis:keyword "double cbrt(number:double|integer|long|unsigned_long)" "double|integer|long|unsigned_long ceil(number:double|integer|long|unsigned_long)" "boolean cidr_match(ip:ip, blockX...:keyword|text)" -"boolean|text|integer|keyword|long coalesce(first:boolean|text|integer|keyword|long, ?rest...:boolean|text|integer|keyword|long)" +"boolean|cartesian_point|cartesian_shape|date|geo_point|geo_shape|integer|ip|keyword|long|text coalesce(first:boolean|cartesian_point|cartesian_shape|date|geo_point|geo_shape|integer|ip|keyword|long|text, ?rest...:boolean|cartesian_point|cartesian_shape|date|geo_point|geo_shape|integer|ip|keyword|long|text)" "keyword concat(string1:keyword|text, string2...:keyword|text)" "double cos(angle:double|integer|long|unsigned_long)" "double cosh(angle:double|integer|long|unsigned_long)" @@ -128,7 +128,7 @@ case |[condition, trueValue] |[boolean, "boolean|cartesian cbrt |number |"double|integer|long|unsigned_long" |"Numeric expression. If `null`, the function returns `null`." 
ceil |number |"double|integer|long|unsigned_long" |Numeric expression. If `null`, the function returns `null`. cidr_match |[ip, blockX] |[ip, "keyword|text"] |[IP address of type `ip` (both IPv4 and IPv6 are supported)., CIDR block to test the IP against.] -coalesce |first |"boolean|text|integer|keyword|long" |Expression to evaluate. +coalesce |first |"boolean|cartesian_point|cartesian_shape|date|geo_point|geo_shape|integer|ip|keyword|long|text" |Expression to evaluate. concat |[string1, string2] |["keyword|text", "keyword|text"] |[Strings to concatenate., Strings to concatenate.] cos |angle |"double|integer|long|unsigned_long" |An angle, in radians. If `null`, the function returns `null`. cosh |angle |"double|integer|long|unsigned_long" |An angle, in radians. If `null`, the function returns `null`. @@ -359,7 +359,7 @@ case |"boolean|cartesian_point|date|double|geo_point|integer|ip|keyword cbrt |double |false |false |false ceil |"double|integer|long|unsigned_long" |false |false |false cidr_match |boolean |[false, false] |true |false -coalesce |"boolean|text|integer|keyword|long" |false |true |false +coalesce |"boolean|cartesian_point|cartesian_shape|date|geo_point|geo_shape|integer|ip|keyword|long|text" |false |true |false concat |keyword |[false, false] |true |false cos |double |false |false |false cosh |double |false |false |false diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/Coalesce.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/Coalesce.java index 97d558ce53b68..5c823f47f794f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/Coalesce.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/Coalesce.java @@ -43,7 +43,18 @@ public class Coalesce extends EsqlScalarFunction implements OptionalArgument { private DataType dataType; @FunctionInfo( - returnType = { "boolean", "text", "integer", "keyword", "long" }, + returnType = { + "boolean", + "cartesian_point", + "cartesian_shape", + "date", + "geo_point", + "geo_shape", + "integer", + "ip", + "keyword", + "long", + "text" }, description = "Returns the first of its arguments that is not null. If all arguments are null, it returns `null`.", examples = { @Example(file = "null", tag = "coalesce") } ) @@ -51,12 +62,34 @@ public Coalesce( Source source, @Param( name = "first", - type = { "boolean", "text", "integer", "keyword", "long" }, + type = { + "boolean", + "cartesian_point", + "cartesian_shape", + "date", + "geo_point", + "geo_shape", + "integer", + "ip", + "keyword", + "long", + "text" }, description = "Expression to evaluate." 
) Expression first, @Param( name = "rest", - type = { "boolean", "text", "integer", "keyword", "long" }, + type = { + "boolean", + "cartesian_point", + "cartesian_shape", + "date", + "geo_point", + "geo_shape", + "integer", + "ip", + "keyword", + "long", + "text" }, description = "Other expression to evaluate.", optional = true ) List rest diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/CoalesceTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/CoalesceTests.java index b3b0fcae2f3d5..7e925a1cb7f25 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/CoalesceTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/CoalesceTests.java @@ -10,6 +10,7 @@ import com.carrotsearch.randomizedtesting.annotations.Name; import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import org.elasticsearch.common.network.NetworkAddress; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.EvalOperator; @@ -18,13 +19,19 @@ import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.Nullability; import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.type.DataTypes; import org.elasticsearch.xpack.esql.core.type.EsField; import org.elasticsearch.xpack.esql.evaluator.EvalMapper; import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; import org.elasticsearch.xpack.esql.expression.function.scalar.VaragsTestCaseBuilder; +import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.SpatialRelatesFunctionTestCase; import org.elasticsearch.xpack.esql.planner.Layout; +import org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter; +import org.elasticsearch.xpack.esql.type.EsqlDataTypes; +import java.time.ZonedDateTime; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -45,12 +52,63 @@ public CoalesceTests(@Name("TestCase") Supplier testC */ @ParametersFactory public static Iterable parameters() { + List suppliers = new ArrayList<>(); VaragsTestCaseBuilder builder = new VaragsTestCaseBuilder(type -> "Coalesce"); builder.expectString(strings -> strings.filter(v -> v != null).findFirst()); builder.expectLong(longs -> longs.filter(v -> v != null).findFirst()); builder.expectInt(ints -> ints.filter(v -> v != null).findFirst()); builder.expectBoolean(booleans -> booleans.filter(v -> v != null).findFirst()); - return parameterSuppliersFromTypedData(builder.suppliers()); + suppliers.addAll(builder.suppliers()); + addSpatialCombinations(suppliers); + suppliers.add(new TestCaseSupplier(List.of(DataTypes.IP, DataTypes.IP), () -> { + var first = randomBoolean() ? null : EsqlDataTypeConverter.stringToIP(NetworkAddress.format(randomIp(true))); + var second = EsqlDataTypeConverter.stringToIP(NetworkAddress.format(randomIp(true))); + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(first, DataTypes.IP, "first"), + new TestCaseSupplier.TypedData(second, DataTypes.IP, "second") + ), + "CoalesceEvaluator[values=[Attribute[channel=0], Attribute[channel=1]]]", + DataTypes.IP, + equalTo(first == null ? 
second : first) + ); + })); + suppliers.add(new TestCaseSupplier(List.of(DataTypes.DATETIME, DataTypes.DATETIME), () -> { + Long firstDate = randomBoolean() ? null : ZonedDateTime.parse("2023-12-04T10:15:30Z").toInstant().toEpochMilli(); + Long secondDate = ZonedDateTime.parse("2023-12-05T10:45:00Z").toInstant().toEpochMilli(); + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(firstDate, DataTypes.DATETIME, "first"), + new TestCaseSupplier.TypedData(secondDate, DataTypes.DATETIME, "second") + ), + "CoalesceEvaluator[values=[Attribute[channel=0], Attribute[channel=1]]]", + DataTypes.DATETIME, + equalTo(firstDate == null ? secondDate : firstDate) + ); + })); + + return parameterSuppliersFromTypedData(suppliers); + } + + protected static void addSpatialCombinations(List suppliers) { + for (DataType dataType : List.of( + EsqlDataTypes.GEO_POINT, + EsqlDataTypes.GEO_SHAPE, + EsqlDataTypes.CARTESIAN_POINT, + EsqlDataTypes.CARTESIAN_SHAPE + )) { + TestCaseSupplier.TypedDataSupplier leftDataSupplier = SpatialRelatesFunctionTestCase.testCaseSupplier(dataType); + TestCaseSupplier.TypedDataSupplier rightDataSupplier = SpatialRelatesFunctionTestCase.testCaseSupplier(dataType); + suppliers.add( + TestCaseSupplier.testCaseSupplier( + leftDataSupplier, + rightDataSupplier, + (l, r) -> equalTo("CoalesceEvaluator[values=[Attribute[channel=0], Attribute[channel=1]]]"), + dataType, + (l, r) -> l + ) + ); + } } @Override diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialRelatesFunctionTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialRelatesFunctionTestCase.java index a38ad43c00f71..c689adfe50b29 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialRelatesFunctionTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialRelatesFunctionTestCase.java @@ -112,7 +112,7 @@ private static String compatibleTypes(DataType spatialDataType) { return Strings.join(compatibleTypeNames(spatialDataType), " or "); } - private static TestCaseSupplier.TypedDataSupplier testCaseSupplier(DataType dataType) { + public static TestCaseSupplier.TypedDataSupplier testCaseSupplier(DataType dataType) { return switch (dataType.esType()) { case "geo_point" -> TestCaseSupplier.geoPointCases(() -> false).get(0); case "geo_shape" -> TestCaseSupplier.geoShapeCases(() -> false).get(0); @@ -190,7 +190,7 @@ private static DataType pickSpatialType(DataType leftType, DataType rightType) { } } - private static Matcher spatialEvaluatorString(DataType leftType, DataType rightType) { + public static Matcher spatialEvaluatorString(DataType leftType, DataType rightType) { String crsType = isSpatialGeo(pickSpatialType(leftType, rightType)) ? "Geo" : "Cartesian"; return equalTo( getFunctionClassName() + crsType + "SourceAndSourceEvaluator[leftValue=Attribute[channel=0], rightValue=Attribute[channel=1]]" From d770ce28d158ee6bb5d7052fe41b1550b6bab272 Mon Sep 17 00:00:00 2001 From: David Turner Date: Mon, 27 May 2024 09:48:47 +0100 Subject: [PATCH 09/29] Remove deprecated ctors in `GetShutdownStatusAction$Request` (#108962) These things are now unused. 
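The only remaining constructor takes an explicit master-node timeout, so any caller that previously leaned on the trappy implicit default now has to spell one out; the removed shims and the surviving constructor appear in the diff below. A minimal sketch of the caller-side change (the node id and the 30-second timeout are illustrative values, not recommendations):

    // before (removed): implicitly used TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT
    // var request = new GetShutdownStatusAction.Request("node-1");

    // after: the master-node timeout is passed explicitly by the caller
    var request = new GetShutdownStatusAction.Request(TimeValue.timeValueSeconds(30), "node-1");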
Relates #107984 --- .../xpack/shutdown/GetShutdownStatusAction.java | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/GetShutdownStatusAction.java b/x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/GetShutdownStatusAction.java index d88d3c35bf3ac..7e5498a7676ba 100644 --- a/x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/GetShutdownStatusAction.java +++ b/x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/GetShutdownStatusAction.java @@ -45,18 +45,6 @@ public static class Request extends MasterNodeRequest { private final String[] nodeIds; - @Deprecated(forRemoval = true) // temporary compatibility shim - public Request() { - super(MasterNodeRequest.TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT); - nodeIds = Strings.EMPTY_ARRAY; - } - - @Deprecated(forRemoval = true) // temporary compatibility shim - public Request(String nodeId) { - super(MasterNodeRequest.TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT); - nodeIds = new String[] { nodeId }; - } - public Request(TimeValue masterNodeTimeout, String... nodeIds) { super(masterNodeTimeout); this.nodeIds = nodeIds; From 88bd9c39263adc73cef89d99aadf18b89616f92b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20Fred=C3=A9n?= <109296772+jfreden@users.noreply.github.com> Date: Mon, 27 May 2024 10:53:26 +0200 Subject: [PATCH 10/29] AwaitsFix: https://github.com/elastic/elasticsearch/issues/109058 --- muted-tests.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/muted-tests.yml b/muted-tests.yml index 4f64492466375..be0bc4a1b8b1c 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -35,6 +35,9 @@ tests: - class: "org.elasticsearch.upgrades.MlTrainedModelsUpgradeIT" issue: "https://github.com/elastic/elasticsearch/issues/108993" method: "testTrainedModelInference" +- class: "org.elasticsearch.xpack.security.authc.esnative.ReservedRealmElasticAutoconfigIntegTests" + issue: "https://github.com/elastic/elasticsearch/issues/109058" + method: "testAutoconfigSucceedsAfterPromotionFailure" # Examples: # # Mute a single test case in a YAML test suite: From 9eb8accebf1df5e68dbda0df49490407aa6e4ce2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20Fred=C3=A9n?= <109296772+jfreden@users.noreply.github.com> Date: Mon, 27 May 2024 10:54:28 +0200 Subject: [PATCH 11/29] AwaitsFix: https://github.com/elastic/elasticsearch/issues/109059 --- muted-tests.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/muted-tests.yml b/muted-tests.yml index be0bc4a1b8b1c..8a2c56d94c5b6 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -38,6 +38,9 @@ tests: - class: "org.elasticsearch.xpack.security.authc.esnative.ReservedRealmElasticAutoconfigIntegTests" issue: "https://github.com/elastic/elasticsearch/issues/109058" method: "testAutoconfigSucceedsAfterPromotionFailure" +- class: "org.elasticsearch.xpack.security.authc.esnative.ReservedRealmElasticAutoconfigIntegTests" + issue: "https://github.com/elastic/elasticsearch/issues/109059" + method: "testAutoconfigFailedPasswordPromotion" # Examples: # # Mute a single test case in a YAML test suite: From 520a1599a65301c0cac44afe1ea306d3f718416f Mon Sep 17 00:00:00 2001 From: David Turner Date: Mon, 27 May 2024 10:46:03 +0100 Subject: [PATCH 12/29] AwaitsFix for #101608 --- .../repositories/s3/S3BlobStoreRepositoryTests.java | 1 + 1 file changed, 1 insertion(+) diff --git 
a/modules/repository-s3/src/internalClusterTest/java/org/elasticsearch/repositories/s3/S3BlobStoreRepositoryTests.java b/modules/repository-s3/src/internalClusterTest/java/org/elasticsearch/repositories/s3/S3BlobStoreRepositoryTests.java index 94cfce5357857..ebc60e8027d81 100644 --- a/modules/repository-s3/src/internalClusterTest/java/org/elasticsearch/repositories/s3/S3BlobStoreRepositoryTests.java +++ b/modules/repository-s3/src/internalClusterTest/java/org/elasticsearch/repositories/s3/S3BlobStoreRepositoryTests.java @@ -225,6 +225,7 @@ public void testAbortRequestStats() throws Exception { } @TestIssueLogging(issueUrl = "https://github.com/elastic/elasticsearch/issues/101608", value = "com.amazonaws.request:DEBUG") + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/101608") public void testMetrics() throws Exception { // Create the repository and perform some activities final String repository = createRepository(randomRepositoryName()); From 8b64ee3970b35231d7b3db9b854a04d3068552ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20Fred=C3=A9n?= <109296772+jfreden@users.noreply.github.com> Date: Mon, 27 May 2024 12:27:42 +0200 Subject: [PATCH 13/29] Await at least one security migration in AutoconfigIntegTests (#109062) Relates: https://github.com/elastic/elasticsearch/issues/109058 https://github.com/elastic/elasticsearch/issues/109059 `ReservedRealmElasticAutoconfigIntegTests` deletes the `.security-7` index and makes the index metadata read-only. This prevents the persistent task that runs the migration job from updating its status to completed, even though the migration itself succeeds. When the `teardown` runs at the end of the test, it then fails because the migration task appears unfinished (marked neither completed nor failed). This adds a check to these specific tests that makes sure the migration finishes before the index is removed or its metadata is made read-only.
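The await helper added below follows the usual cluster-state-listener-plus-latch idiom: register a listener that trips a latch when the condition holds, then re-check the current state in case the condition was already true before the listener was installed. A generic sketch of that idiom (the awaitClusterState name and the Predicate parameter are illustrative; the actual helpers in the diff below are hard-wired to the security migration):

    // Sketch: block until the cluster state satisfies a condition, without
    // missing an update that landed before the listener was registered.
    private void awaitClusterState(ClusterService clusterService, Predicate<ClusterState> condition) {
        final var latch = new CountDownLatch(1);
        clusterService.addListener(event -> {
            if (condition.test(event.state())) {
                latch.countDown();
            }
        });
        if (condition.test(clusterService.state())) {
            latch.countDown(); // condition already held when we subscribed
        }
        safeAwait(latch);
    }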
--- muted-tests.yml | 6 ---- .../test/SecuritySingleNodeTestCase.java | 2 +- ...ervedRealmElasticAutoconfigIntegTests.java | 29 +++++++++++++++++-- 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/muted-tests.yml b/muted-tests.yml index 8a2c56d94c5b6..4f64492466375 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -35,12 +35,6 @@ tests: - class: "org.elasticsearch.upgrades.MlTrainedModelsUpgradeIT" issue: "https://github.com/elastic/elasticsearch/issues/108993" method: "testTrainedModelInference" -- class: "org.elasticsearch.xpack.security.authc.esnative.ReservedRealmElasticAutoconfigIntegTests" - issue: "https://github.com/elastic/elasticsearch/issues/109058" - method: "testAutoconfigSucceedsAfterPromotionFailure" -- class: "org.elasticsearch.xpack.security.authc.esnative.ReservedRealmElasticAutoconfigIntegTests" - issue: "https://github.com/elastic/elasticsearch/issues/109059" - method: "testAutoconfigFailedPasswordPromotion" # Examples: # # Mute a single test case in a YAML test suite: diff --git a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/test/SecuritySingleNodeTestCase.java b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/test/SecuritySingleNodeTestCase.java index 16a3ea53eeeac..2eb45021a5bfe 100644 --- a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/test/SecuritySingleNodeTestCase.java +++ b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/test/SecuritySingleNodeTestCase.java @@ -93,7 +93,7 @@ private boolean isMigrationComplete(ClusterState state) { return getTaskWithId(state, TASK_NAME) == null; } - protected void awaitSecurityMigration() { + private void awaitSecurityMigration() { final var latch = new CountDownLatch(1); ClusterService clusterService = getInstanceFromNode(ClusterService.class); clusterService.addListener((event) -> { diff --git a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/esnative/ReservedRealmElasticAutoconfigIntegTests.java b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/esnative/ReservedRealmElasticAutoconfigIntegTests.java index c04630d457959..ae48d7563494f 100644 --- a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/esnative/ReservedRealmElasticAutoconfigIntegTests.java +++ b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/esnative/ReservedRealmElasticAutoconfigIntegTests.java @@ -15,7 +15,10 @@ import org.elasticsearch.client.Request; import org.elasticsearch.client.RequestOptions; import org.elasticsearch.client.ResponseException; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.MockSecureSettings; import org.elasticsearch.common.settings.SecureString; @@ -29,7 +32,10 @@ import org.elasticsearch.xpack.core.security.test.TestRestrictedIndices; import org.junit.BeforeClass; +import java.util.concurrent.CountDownLatch; + import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.elasticsearch.xpack.core.security.action.UpdateIndexMigrationVersionAction.MIGRATION_VERSION_CUSTOM_KEY; import static 
org.elasticsearch.xpack.security.support.SecuritySystemIndices.SECURITY_MAIN_ALIAS; import static org.hamcrest.Matchers.is; @@ -64,6 +70,25 @@ protected SecureString getBootstrapPassword() { return null; // no bootstrap password for this test } + private boolean isMigrationComplete(ClusterState state) { + IndexMetadata indexMetadata = state.metadata().getIndices().get(TestRestrictedIndices.INTERNAL_SECURITY_MAIN_INDEX_7); + return indexMetadata.getCustomData(MIGRATION_VERSION_CUSTOM_KEY) != null; + } + + private void awaitSecurityMigrationRanOnce() { + final var latch = new CountDownLatch(1); + ClusterService clusterService = getInstanceFromNode(ClusterService.class); + clusterService.addListener((event) -> { + if (isMigrationComplete(event.state())) { + latch.countDown(); + } + }); + if (isMigrationComplete(clusterService.state())) { + latch.countDown(); + } + safeAwait(latch); + } + public void testAutoconfigFailedPasswordPromotion() { try { // prevents the .security index from being created automatically (after elastic user authentication) @@ -80,7 +105,7 @@ public void testAutoconfigFailedPasswordPromotion() { assertThat(getIndexResponse.getIndices().length, is(1)); assertThat(getIndexResponse.getIndices()[0], is(TestRestrictedIndices.INTERNAL_SECURITY_MAIN_INDEX_7)); // Security migration needs to finish before deleting the index - awaitSecurityMigration(); + awaitSecurityMigrationRanOnce(); DeleteIndexRequest deleteIndexRequest = new DeleteIndexRequest(getIndexResponse.getIndices()); assertAcked(client().admin().indices().delete(deleteIndexRequest).actionGet()); } @@ -140,7 +165,7 @@ public void testAutoconfigSucceedsAfterPromotionFailure() throws Exception { putUserRequest.roles(Strings.EMPTY_ARRAY); client().execute(PutUserAction.INSTANCE, putUserRequest).get(); // Security migration needs to finish before making the cluster read only - awaitSecurityMigration(); + awaitSecurityMigrationRanOnce(); // but then make the cluster read-only ClusterUpdateSettingsRequest updateSettingsRequest = new ClusterUpdateSettingsRequest(); From f5e6f2ab401a8a96f89ec962563c6abc07f976f8 Mon Sep 17 00:00:00 2001 From: David Turner Date: Mon, 27 May 2024 12:26:16 +0100 Subject: [PATCH 14/29] Rename to `ChunkedRestResponseBodyPart` (#109057) Follow-up to #104851 to rename some symbols to reflect that the class formerly known as a `ChunkedRestResponseBody` may now only be _part_ of the whole response body. 
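Under the new names the protocol reads: drain chunks from the current part while isPartComplete() is false, then either finish (isLastPart() true) or fetch a continuation with getNextPart(). A compact consumer sketch, adapted from the test helper further down in this patch (the 16 KiB size hint is arbitrary, and the blocking PlainActionFuture.get wait is a test-only shortcut, never appropriate on a transport thread):

    static void drain(OutputStream out, ChunkedRestResponseBodyPart part) throws IOException {
        while (true) {
            // drain every chunk the current part can produce synchronously
            while (part.isPartComplete() == false) {
                try (var chunk = part.encodeChunk(16 * 1024, BytesRefRecycler.NON_RECYCLING_INSTANCE)) {
                    chunk.writeTo(out);
                }
            }
            if (part.isLastPart()) {
                return; // whole response body sent
            }
            // not the last part: wait for the asynchronously-computed continuation
            part = PlainActionFuture.get(part::getNextPart);
        }
    }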
--- .../netty4/Netty4ChunkedContinuationsIT.java | 37 ++++++------ .../http/netty4/Netty4ChunkedEncodingIT.java | 10 ++-- .../http/netty4/Netty4PipeliningIT.java | 10 ++-- .../netty4/Netty4ChunkedHttpContinuation.java | 12 ++-- .../netty4/Netty4ChunkedHttpResponse.java | 12 ++-- .../netty4/Netty4HttpPipeliningHandler.java | 39 ++++++------- .../http/netty4/Netty4HttpRequest.java | 6 +- .../Netty4HttpPipeliningHandlerTests.java | 12 ++-- .../Netty4HttpServerTransportTests.java | 6 +- .../elasticsearch/rest/RestControllerIT.java | 2 +- .../http/DefaultRestChannel.java | 8 +-- .../org/elasticsearch/http/HttpRequest.java | 4 +- ....java => ChunkedRestResponseBodyPart.java} | 56 +++++++++---------- ...> LoggingChunkedRestResponseBodyPart.java} | 18 +++--- .../elasticsearch/rest/RestController.java | 26 ++++----- .../org/elasticsearch/rest/RestResponse.java | 12 ++-- .../action/RestChunkedToXContentListener.java | 4 +- .../cluster/RestNodesHotThreadsAction.java | 2 +- .../rest/action/cat/RestTable.java | 6 +- .../action/info/RestClusterInfoAction.java | 4 +- .../http/DefaultRestChannelTests.java | 42 +++++++------- .../elasticsearch/http/TestHttpRequest.java | 4 +- ... => ChunkedRestResponseBodyPartTests.java} | 19 ++++--- .../rest/RestControllerTests.java | 2 +- .../elasticsearch/rest/RestResponseTests.java | 2 +- .../rest/action/cat/RestTableTests.java | 7 ++- .../elasticsearch/rest/RestResponseUtils.java | 10 ++-- .../test/rest/FakeRestRequest.java | 4 +- .../esql/action/EsqlResponseListener.java | 6 +- 29 files changed, 195 insertions(+), 187 deletions(-) rename server/src/main/java/org/elasticsearch/rest/{ChunkedRestResponseBody.java => ChunkedRestResponseBodyPart.java} (79%) rename server/src/main/java/org/elasticsearch/rest/{LoggingChunkedRestResponseBody.java => LoggingChunkedRestResponseBodyPart.java} (68%) rename server/src/test/java/org/elasticsearch/rest/{ChunkedRestResponseBodyTests.java => ChunkedRestResponseBodyPartTests.java} (81%) diff --git a/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4ChunkedContinuationsIT.java b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4ChunkedContinuationsIT.java index 25195a1176fb8..77333677120a9 100644 --- a/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4ChunkedContinuationsIT.java +++ b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4ChunkedContinuationsIT.java @@ -60,7 +60,7 @@ import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.rest.BaseRestHandler; -import org.elasticsearch.rest.ChunkedRestResponseBody; +import org.elasticsearch.rest.ChunkedRestResponseBodyPart; import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestHandler; @@ -370,31 +370,31 @@ public void writeTo(StreamOutput out) { TransportAction.localOnly(); } - public ChunkedRestResponseBody getChunkedBody() { - return getChunkBatch(0); + public ChunkedRestResponseBodyPart getFirstResponseBodyPart() { + return getResponseBodyPart(0); } - private ChunkedRestResponseBody getChunkBatch(int batchIndex) { + private ChunkedRestResponseBodyPart getResponseBodyPart(int batchIndex) { if (batchIndex == failIndex && randomBoolean()) { throw new ElasticsearchException("simulated failure creating next batch"); } - return new ChunkedRestResponseBody() { + return new ChunkedRestResponseBodyPart() { private 
final Iterator lines = Iterators.forRange(0, 3, i -> "batch-" + batchIndex + "-chunk-" + i + "\n"); @Override - public boolean isDone() { + public boolean isPartComplete() { return lines.hasNext() == false; } @Override - public boolean isEndOfResponse() { + public boolean isLastPart() { return batchIndex == 2; } @Override - public void getContinuation(ActionListener listener) { - executor.execute(ActionRunnable.supply(listener, () -> getChunkBatch(batchIndex + 1))); + public void getNextPart(ActionListener listener) { + executor.execute(ActionRunnable.supply(listener, () -> getResponseBodyPart(batchIndex + 1))); } @Override @@ -486,11 +486,12 @@ public void accept(RestChannel channel) { @Override protected void processResponse(Response response) { try { - final var responseBody = response.getChunkedBody(); // might fail, so do this before acquiring ref + final var responseBody = response.getFirstResponseBodyPart(); + // preceding line might fail, so needs to be done before acquiring the sendResponse ref refs.mustIncRef(); channel.sendResponse(RestResponse.chunked(RestStatus.OK, responseBody, refs::decRef)); } finally { - refs.decRef(); + refs.decRef(); // release the ref acquired at the top of accept() } } }); @@ -534,26 +535,26 @@ public void writeTo(StreamOutput out) { TransportAction.localOnly(); } - public ChunkedRestResponseBody getChunkedBody() { - return new ChunkedRestResponseBody() { + public ChunkedRestResponseBodyPart getResponseBodyPart() { + return new ChunkedRestResponseBodyPart() { private final Iterator lines = Iterators.single("infinite response\n"); @Override - public boolean isDone() { + public boolean isPartComplete() { return lines.hasNext() == false; } @Override - public boolean isEndOfResponse() { + public boolean isLastPart() { return false; } @Override - public void getContinuation(ActionListener listener) { + public void getNextPart(ActionListener listener) { computingContinuation = true; executor.execute(ActionRunnable.supply(listener, () -> { computingContinuation = false; - return getChunkedBody(); + return getResponseBodyPart(); })); } @@ -628,7 +629,7 @@ public void accept(RestChannel channel) { client.execute(TYPE, new Request(), new RestActionListener<>(channel) { @Override protected void processResponse(Response response) { - channel.sendResponse(RestResponse.chunked(RestStatus.OK, response.getChunkedBody(), () -> { + channel.sendResponse(RestResponse.chunked(RestStatus.OK, response.getResponseBodyPart(), () -> { // cancellation notification only happens while processing a continuation, not while computing // the next one; prompt cancellation requires use of something like RestCancellableNodeClient assertFalse(response.computingContinuation); diff --git a/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4ChunkedEncodingIT.java b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4ChunkedEncodingIT.java index b2a54e2027308..e3f60ea7a48e0 100644 --- a/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4ChunkedEncodingIT.java +++ b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4ChunkedEncodingIT.java @@ -37,7 +37,7 @@ import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.rest.BaseRestHandler; -import org.elasticsearch.rest.ChunkedRestResponseBody; +import org.elasticsearch.rest.ChunkedRestResponseBodyPart; import 
org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestHandler; @@ -245,19 +245,19 @@ public BytesReference next() { private static void sendChunksResponse(RestChannel channel, Iterator chunkIterator) { final var localRefs = refs; // single volatile read if (localRefs != null && localRefs.tryIncRef()) { - channel.sendResponse(RestResponse.chunked(RestStatus.OK, new ChunkedRestResponseBody() { + channel.sendResponse(RestResponse.chunked(RestStatus.OK, new ChunkedRestResponseBodyPart() { @Override - public boolean isDone() { + public boolean isPartComplete() { return chunkIterator.hasNext() == false; } @Override - public boolean isEndOfResponse() { + public boolean isLastPart() { return true; } @Override - public void getContinuation(ActionListener listener) { + public void getNextPart(ActionListener listener) { assert false : "no continuations"; } diff --git a/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4PipeliningIT.java b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4PipeliningIT.java index ce8da0c08af54..89a76dd26e285 100644 --- a/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4PipeliningIT.java +++ b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4PipeliningIT.java @@ -34,7 +34,7 @@ import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.rest.BaseRestHandler; -import org.elasticsearch.rest.ChunkedRestResponseBody; +import org.elasticsearch.rest.ChunkedRestResponseBodyPart; import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestHandler; import org.elasticsearch.rest.RestRequest; @@ -243,21 +243,21 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli throw new IllegalArgumentException("[" + FAIL_AFTER_BYTES_PARAM + "] must be present and non-negative"); } return channel -> randomExecutor(client.threadPool()).execute( - () -> channel.sendResponse(RestResponse.chunked(RestStatus.OK, new ChunkedRestResponseBody() { + () -> channel.sendResponse(RestResponse.chunked(RestStatus.OK, new ChunkedRestResponseBodyPart() { int bytesRemaining = failAfterBytes; @Override - public boolean isDone() { + public boolean isPartComplete() { return false; } @Override - public boolean isEndOfResponse() { + public boolean isLastPart() { return true; } @Override - public void getContinuation(ActionListener listener) { + public void getNextPart(ActionListener listener) { fail("no continuations here"); } diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4ChunkedHttpContinuation.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4ChunkedHttpContinuation.java index 156f1c27aa67c..cde0249216981 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4ChunkedHttpContinuation.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4ChunkedHttpContinuation.java @@ -10,16 +10,16 @@ import io.netty.util.concurrent.PromiseCombiner; -import org.elasticsearch.rest.ChunkedRestResponseBody; +import org.elasticsearch.rest.ChunkedRestResponseBodyPart; final class Netty4ChunkedHttpContinuation implements Netty4HttpResponse { private final int sequence; - private final ChunkedRestResponseBody body; + private final ChunkedRestResponseBodyPart bodyPart; private final 
PromiseCombiner combiner; - Netty4ChunkedHttpContinuation(int sequence, ChunkedRestResponseBody body, PromiseCombiner combiner) { + Netty4ChunkedHttpContinuation(int sequence, ChunkedRestResponseBodyPart bodyPart, PromiseCombiner combiner) { this.sequence = sequence; - this.body = body; + this.bodyPart = bodyPart; this.combiner = combiner; } @@ -28,8 +28,8 @@ public int getSequence() { return sequence; } - public ChunkedRestResponseBody body() { - return body; + public ChunkedRestResponseBodyPart bodyPart() { + return bodyPart; } public PromiseCombiner combiner() { diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4ChunkedHttpResponse.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4ChunkedHttpResponse.java index 783c02da0bbcc..3abab9fa2526f 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4ChunkedHttpResponse.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4ChunkedHttpResponse.java @@ -13,7 +13,7 @@ import io.netty.handler.codec.http.HttpVersion; import org.elasticsearch.http.HttpResponse; -import org.elasticsearch.rest.ChunkedRestResponseBody; +import org.elasticsearch.rest.ChunkedRestResponseBodyPart; import org.elasticsearch.rest.RestStatus; /** @@ -23,16 +23,16 @@ final class Netty4ChunkedHttpResponse extends DefaultHttpResponse implements Net private final int sequence; - private final ChunkedRestResponseBody body; + private final ChunkedRestResponseBodyPart firstBodyPart; - Netty4ChunkedHttpResponse(int sequence, HttpVersion version, RestStatus status, ChunkedRestResponseBody body) { + Netty4ChunkedHttpResponse(int sequence, HttpVersion version, RestStatus status, ChunkedRestResponseBodyPart firstBodyPart) { super(version, HttpResponseStatus.valueOf(status.getStatus())); this.sequence = sequence; - this.body = body; + this.firstBodyPart = firstBodyPart; } - public ChunkedRestResponseBody body() { - return body; + public ChunkedRestResponseBodyPart firstBodyPart() { + return firstBodyPart; } @Override diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandler.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandler.java index 8280c438613a2..9cf210c2a8aab 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandler.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandler.java @@ -34,7 +34,7 @@ import org.elasticsearch.core.Booleans; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Tuple; -import org.elasticsearch.rest.ChunkedRestResponseBody; +import org.elasticsearch.rest.ChunkedRestResponseBodyPart; import org.elasticsearch.transport.Transports; import org.elasticsearch.transport.netty4.Netty4Utils; import org.elasticsearch.transport.netty4.Netty4WriteThrottlingHandler; @@ -58,7 +58,7 @@ public class Netty4HttpPipeliningHandler extends ChannelDuplexHandler { private final int maxEventsHeld; private final PriorityQueue> outboundHoldingQueue; - private record ChunkedWrite(PromiseCombiner combiner, ChannelPromise onDone, ChunkedRestResponseBody responseBody) {} + private record ChunkedWrite(PromiseCombiner combiner, ChannelPromise onDone, ChunkedRestResponseBodyPart responseBodyPart) {} /** * The current {@link ChunkedWrite} if a chunked write is executed at the moment. 
@@ -214,9 +214,9 @@ private void doWriteChunkedResponse(ChannelHandlerContext ctx, Netty4ChunkedHttp final PromiseCombiner combiner = new PromiseCombiner(ctx.executor()); final ChannelPromise first = ctx.newPromise(); combiner.add((Future) first); - final var responseBody = readyResponse.body(); + final var firstBodyPart = readyResponse.firstBodyPart(); assert currentChunkedWrite == null; - currentChunkedWrite = new ChunkedWrite(combiner, promise, responseBody); + currentChunkedWrite = new ChunkedWrite(combiner, promise, firstBodyPart); if (enqueueWrite(ctx, readyResponse, first)) { // We were able to write out the first chunk directly, try writing out subsequent chunks until the channel becomes unwritable. // NB "writable" means there's space in the downstream ChannelOutboundBuffer, we aren't trying to saturate the physical channel. @@ -232,9 +232,10 @@ private void doWriteChunkedResponse(ChannelHandlerContext ctx, Netty4ChunkedHttp private void doWriteChunkedContinuation(ChannelHandlerContext ctx, Netty4ChunkedHttpContinuation continuation, ChannelPromise promise) { final PromiseCombiner combiner = continuation.combiner(); assert currentChunkedWrite == null; - final var responseBody = continuation.body(); - assert responseBody.isDone() == false : "response with continuations must have at least one (possibly-empty) chunk in each part"; - currentChunkedWrite = new ChunkedWrite(combiner, promise, responseBody); + final var bodyPart = continuation.bodyPart(); + assert bodyPart.isPartComplete() == false + : "response with continuations must have at least one (possibly-empty) chunk in each part"; + currentChunkedWrite = new ChunkedWrite(combiner, promise, bodyPart); // NB "writable" means there's space in the downstream ChannelOutboundBuffer, we aren't trying to saturate the physical channel. 
while (ctx.channel().isWritable()) { if (writeChunk(ctx, currentChunkedWrite)) { @@ -251,9 +252,9 @@ private void finishChunkedWrite() { } final var finishingWrite = currentChunkedWrite; currentChunkedWrite = null; - final var finishingWriteBody = finishingWrite.responseBody(); - assert finishingWriteBody.isDone(); - final var endOfResponse = finishingWriteBody.isEndOfResponse(); + final var finishingWriteBodyPart = finishingWrite.responseBodyPart(); + assert finishingWriteBodyPart.isPartComplete(); + final var endOfResponse = finishingWriteBodyPart.isLastPart(); if (endOfResponse) { writeSequence++; finishingWrite.combiner().finish(finishingWrite.onDone()); @@ -261,7 +262,7 @@ private void finishChunkedWrite() { final var channel = finishingWrite.onDone().channel(); ActionListener.run(ActionListener.assertOnce(new ActionListener<>() { @Override - public void onResponse(ChunkedRestResponseBody continuation) { + public void onResponse(ChunkedRestResponseBodyPart continuation) { channel.writeAndFlush( new Netty4ChunkedHttpContinuation(writeSequence, continuation, finishingWrite.combiner()), finishingWrite.onDone() // pass the terminal listener/promise along the line @@ -296,7 +297,7 @@ private void checkShutdown() { } } - }), finishingWriteBody::getContinuation); + }), finishingWriteBodyPart::getNextPart); } } @@ -374,22 +375,22 @@ private boolean doFlush(ChannelHandlerContext ctx) throws IOException { } private boolean writeChunk(ChannelHandlerContext ctx, ChunkedWrite chunkedWrite) { - final var body = chunkedWrite.responseBody(); + final var bodyPart = chunkedWrite.responseBodyPart(); final var combiner = chunkedWrite.combiner(); - assert body.isDone() == false : "should not continue to try and serialize once done"; + assert bodyPart.isPartComplete() == false : "should not continue to try and serialize once done"; final ReleasableBytesReference bytes; try { - bytes = body.encodeChunk(Netty4WriteThrottlingHandler.MAX_BYTES_PER_WRITE, serverTransport.recycler()); + bytes = bodyPart.encodeChunk(Netty4WriteThrottlingHandler.MAX_BYTES_PER_WRITE, serverTransport.recycler()); } catch (Exception e) { return handleChunkingFailure(ctx, chunkedWrite, e); } final ByteBuf content = Netty4Utils.toByteBuf(bytes); - final boolean done = body.isDone(); - final boolean lastChunk = done && body.isEndOfResponse(); - final ChannelFuture f = ctx.write(lastChunk ? new DefaultLastHttpContent(content) : new DefaultHttpContent(content)); + final boolean isPartComplete = bodyPart.isPartComplete(); + final boolean isBodyComplete = isPartComplete && bodyPart.isLastPart(); + final ChannelFuture f = ctx.write(isBodyComplete ? 
new DefaultLastHttpContent(content) : new DefaultHttpContent(content)); f.addListener(ignored -> bytes.close()); combiner.add(f); - return done; + return isPartComplete; } private boolean handleChunkingFailure(ChannelHandlerContext ctx, ChunkedWrite chunkedWrite, Exception e) { diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequest.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequest.java index 0e1bb527fed9d..1e35f084c87ec 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequest.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequest.java @@ -22,7 +22,7 @@ import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.http.HttpRequest; import org.elasticsearch.http.HttpResponse; -import org.elasticsearch.rest.ChunkedRestResponseBody; +import org.elasticsearch.rest.ChunkedRestResponseBodyPart; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.transport.netty4.Netty4Utils; @@ -176,8 +176,8 @@ public Netty4FullHttpResponse createResponse(RestStatus status, BytesReference c } @Override - public HttpResponse createResponse(RestStatus status, ChunkedRestResponseBody content) { - return new Netty4ChunkedHttpResponse(sequence, request.protocolVersion(), status, content); + public HttpResponse createResponse(RestStatus status, ChunkedRestResponseBodyPart firstBodyPart) { + return new Netty4ChunkedHttpResponse(sequence, request.protocolVersion(), status, firstBodyPart); } @Override diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandlerTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandlerTests.java index bb4a0939c98f0..4dca3d17bf072 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandlerTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandlerTests.java @@ -36,7 +36,7 @@ import org.elasticsearch.common.bytes.ZeroBytesReference; import org.elasticsearch.common.recycler.Recycler; import org.elasticsearch.http.HttpResponse; -import org.elasticsearch.rest.ChunkedRestResponseBody; +import org.elasticsearch.rest.ChunkedRestResponseBodyPart; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.transport.netty4.Netty4Utils; @@ -502,23 +502,23 @@ protected void handlePipelinedRequest(ChannelHandlerContext ctx, Netty4HttpReque }; } - private static ChunkedRestResponseBody getRepeatedChunkResponseBody(int chunkCount, BytesReference chunk) { - return new ChunkedRestResponseBody() { + private static ChunkedRestResponseBodyPart getRepeatedChunkResponseBody(int chunkCount, BytesReference chunk) { + return new ChunkedRestResponseBodyPart() { private int remaining = chunkCount; @Override - public boolean isDone() { + public boolean isPartComplete() { return remaining == 0; } @Override - public boolean isEndOfResponse() { + public boolean isLastPart() { return true; } @Override - public void getContinuation(ActionListener listener) { + public void getNextPart(ActionListener listener) { fail("no continuations here"); } diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java 
b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java index d2be4212cf41e..bc6e5fef834e8 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java @@ -71,7 +71,7 @@ import org.elasticsearch.http.NullDispatcher; import org.elasticsearch.http.netty4.internal.HttpHeadersAuthenticatorUtils; import org.elasticsearch.http.netty4.internal.HttpValidator; -import org.elasticsearch.rest.ChunkedRestResponseBody; +import org.elasticsearch.rest.ChunkedRestResponseBodyPart; import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestResponse; @@ -692,7 +692,7 @@ public void testHeadRequestToChunkedApi() throws InterruptedException { public void dispatchRequest(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) { try { channel.sendResponse( - RestResponse.chunked(OK, ChunkedRestResponseBody.fromXContent(ignored -> Iterators.single((builder, params) -> { + RestResponse.chunked(OK, ChunkedRestResponseBodyPart.fromXContent(ignored -> Iterators.single((builder, params) -> { throw new AssertionError("should not be called for HEAD REQUEST"); }), ToXContent.EMPTY_PARAMS, channel), null) ); @@ -1048,7 +1048,7 @@ public void dispatchRequest(final RestRequest request, final RestChannel channel assertEquals(request.uri(), url); final var response = RestResponse.chunked( OK, - ChunkedRestResponseBody.fromTextChunks(RestResponse.TEXT_CONTENT_TYPE, Collections.emptyIterator()), + ChunkedRestResponseBodyPart.fromTextChunks(RestResponse.TEXT_CONTENT_TYPE, Collections.emptyIterator()), responseReleasedLatch::countDown ); transportClosedFuture.addListener(ActionListener.running(() -> channel.sendResponse(response))); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/rest/RestControllerIT.java b/server/src/internalClusterTest/java/org/elasticsearch/rest/RestControllerIT.java index 809ecbc858706..b76bec0652732 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/rest/RestControllerIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/rest/RestControllerIT.java @@ -82,7 +82,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli return channel -> { final var response = RestResponse.chunked( RestStatus.OK, - ChunkedRestResponseBody.fromXContent( + ChunkedRestResponseBodyPart.fromXContent( params -> Iterators.single((b, p) -> b.startObject().endObject()), request, channel diff --git a/server/src/main/java/org/elasticsearch/http/DefaultRestChannel.java b/server/src/main/java/org/elasticsearch/http/DefaultRestChannel.java index 9719716c57ce4..f04b8f13bfe7e 100644 --- a/server/src/main/java/org/elasticsearch/http/DefaultRestChannel.java +++ b/server/src/main/java/org/elasticsearch/http/DefaultRestChannel.java @@ -21,8 +21,8 @@ import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; import org.elasticsearch.rest.AbstractRestChannel; -import org.elasticsearch.rest.ChunkedRestResponseBody; -import org.elasticsearch.rest.LoggingChunkedRestResponseBody; +import org.elasticsearch.rest.ChunkedRestResponseBodyPart; +import org.elasticsearch.rest.LoggingChunkedRestResponseBodyPart; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestResponse; import org.elasticsearch.rest.RestStatus; @@ -113,7 
+113,7 @@ public void sendResponse(RestResponse restResponse) { try { final HttpResponse httpResponse; if (isHeadRequest == false && restResponse.isChunked()) { - ChunkedRestResponseBody chunkedContent = restResponse.chunkedContent(); + ChunkedRestResponseBodyPart chunkedContent = restResponse.chunkedContent(); if (httpLogger != null && httpLogger.isBodyTracerEnabled()) { final var loggerStream = httpLogger.openResponseBodyLoggingStream(request.getRequestId()); toClose.add(() -> { @@ -123,7 +123,7 @@ public void sendResponse(RestResponse restResponse) { assert false : e; // nothing much to go wrong here } }); - chunkedContent = new LoggingChunkedRestResponseBody(chunkedContent, loggerStream); + chunkedContent = new LoggingChunkedRestResponseBodyPart(chunkedContent, loggerStream); } httpResponse = httpRequest.createResponse(restResponse.status(), chunkedContent); diff --git a/server/src/main/java/org/elasticsearch/http/HttpRequest.java b/server/src/main/java/org/elasticsearch/http/HttpRequest.java index b82947e42308b..2757fa15ce477 100644 --- a/server/src/main/java/org/elasticsearch/http/HttpRequest.java +++ b/server/src/main/java/org/elasticsearch/http/HttpRequest.java @@ -10,7 +10,7 @@ import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.core.Nullable; -import org.elasticsearch.rest.ChunkedRestResponseBody; +import org.elasticsearch.rest.ChunkedRestResponseBodyPart; import org.elasticsearch.rest.RestStatus; import java.util.List; @@ -40,7 +40,7 @@ enum HttpVersion { */ HttpResponse createResponse(RestStatus status, BytesReference content); - HttpResponse createResponse(RestStatus status, ChunkedRestResponseBody content); + HttpResponse createResponse(RestStatus status, ChunkedRestResponseBodyPart firstBodyPart); @Nullable Exception getInboundException(); diff --git a/server/src/main/java/org/elasticsearch/rest/ChunkedRestResponseBody.java b/server/src/main/java/org/elasticsearch/rest/ChunkedRestResponseBodyPart.java similarity index 79% rename from server/src/main/java/org/elasticsearch/rest/ChunkedRestResponseBody.java rename to server/src/main/java/org/elasticsearch/rest/ChunkedRestResponseBodyPart.java index 2f7fc458ca020..4888b59f19561 100644 --- a/server/src/main/java/org/elasticsearch/rest/ChunkedRestResponseBody.java +++ b/server/src/main/java/org/elasticsearch/rest/ChunkedRestResponseBodyPart.java @@ -39,32 +39,32 @@ * materializing the full response into on-heap buffers up front, instead serializing only as much of the response as can be flushed to the * network right away.

* - * <p>Each {@link ChunkedRestResponseBody} represents a sequence of chunks that are ready for immediate transmission: if {@link #isDone} - * returns {@code false} then {@link #encodeChunk} can be called at any time and must synchronously return the next chunk to be sent. - * Many HTTP responses will be a single such sequence. However, if an implementation's {@link #isEndOfResponse} returns {@code false} at the - * end of the sequence then the transmission is paused and {@link #getContinuation} is called to compute the next sequence of chunks + * <p>Each {@link ChunkedRestResponseBodyPart} represents a sequence of chunks that are ready for immediate transmission: if + * {@link #isPartComplete} returns {@code false} then {@link #encodeChunk} can be called at any time and must synchronously return the next + * chunk to be sent. Many HTTP responses will be a single part, but if an implementation's {@link #isLastPart} returns {@code false} at the + * end of the part then the transmission is paused and {@link #getNextPart} is called to compute the next sequence of chunks * asynchronously.

*/ -public interface ChunkedRestResponseBody { +public interface ChunkedRestResponseBodyPart { - Logger logger = LogManager.getLogger(ChunkedRestResponseBody.class); + Logger logger = LogManager.getLogger(ChunkedRestResponseBodyPart.class); /** - * @return {@code true} if this body contains no more chunks and the REST layer should check for a possible continuation by calling - * {@link #isEndOfResponse}, or {@code false} if the REST layer should request another chunk from this body using {@link #encodeChunk}. + * @return {@code true} if this body part contains no more chunks and the REST layer should check for a possible continuation by calling + * {@link #isLastPart}, or {@code false} if the REST layer should request another chunk from this body using {@link #encodeChunk}. */ - boolean isDone(); + boolean isPartComplete(); /** - * @return {@code true} if this is the last chunked body in the response, or {@code false} if the REST layer should request further - * chunked bodies by calling {@link #getContinuation}. + * @return {@code true} if this is the last chunked body part in the response, or {@code false} if the REST layer should request further + * chunked bodies by calling {@link #getNextPart}. */ - boolean isEndOfResponse(); + boolean isLastPart(); /** - *

<p>Asynchronously retrieves the next part of the body. Called if {@link #isEndOfResponse} returns {@code false}.</p> + * <p>Asynchronously retrieves the next part of the response body. Called if {@link #isLastPart} returns {@code false}.</p> * - * <p>Note that this is called on a transport thread, so implementations must take care to dispatch any nontrivial work elsewhere.</p> + * <p>Note that this is called on a transport thread: implementations must take care to dispatch any nontrivial work elsewhere.</p> * <p>

Note that the {@link Task} corresponding to any invocation of {@link Client#execute} completes as soon as the client action * returns its response, so it no longer exists when this method is called and cannot be used to receive cancellation notifications. @@ -78,7 +78,7 @@ public interface ChunkedRestResponseBody { * the body of the response, so there's no good ways to handle an exception here. Completing the listener exceptionally * will log an error, abort sending the response, and close the HTTP connection. */ - void getContinuation(ActionListener listener); + void getNextPart(ActionListener listener); /** * Serializes approximately as many bytes of the response as request by {@code sizeHint} to a {@link ReleasableBytesReference} that @@ -97,17 +97,17 @@ public interface ChunkedRestResponseBody { String getResponseContentTypeString(); /** - * Create a chunked response body to be written to a specific {@link RestChannel} from a {@link ChunkedToXContent}. + * Create a one-part chunked response body to be written to a specific {@link RestChannel} from a {@link ChunkedToXContent}. * * @param chunkedToXContent chunked x-content instance to serialize * @param params parameters to use for serialization * @param channel channel the response will be written to * @return chunked rest response body */ - static ChunkedRestResponseBody fromXContent(ChunkedToXContent chunkedToXContent, ToXContent.Params params, RestChannel channel) + static ChunkedRestResponseBodyPart fromXContent(ChunkedToXContent chunkedToXContent, ToXContent.Params params, RestChannel channel) throws IOException { - return new ChunkedRestResponseBody() { + return new ChunkedRestResponseBodyPart() { private final OutputStream out = new OutputStream() { @Override @@ -135,17 +135,17 @@ public void write(byte[] b, int off, int len) throws IOException { private BytesStream target; @Override - public boolean isDone() { + public boolean isPartComplete() { return serialization.hasNext() == false; } @Override - public boolean isEndOfResponse() { + public boolean isLastPart() { return true; } @Override - public void getContinuation(ActionListener listener) { + public void getNextPart(ActionListener listener) { assert false : "no continuations"; listener.onFailure(new IllegalStateException("no continuations available")); } @@ -191,11 +191,11 @@ public String getResponseContentTypeString() { } /** - * Create a chunked response body to be written to a specific {@link RestChannel} from a stream of text chunks, each represented as a - * consumer of a {@link Writer}. + * Create a one-part chunked response body to be written to a specific {@link RestChannel} from a stream of UTF-8-encoded text chunks, + * each represented as a consumer of a {@link Writer}. 
*/ - static ChunkedRestResponseBody fromTextChunks(String contentType, Iterator> chunkIterator) { - return new ChunkedRestResponseBody() { + static ChunkedRestResponseBodyPart fromTextChunks(String contentType, Iterator> chunkIterator) { + return new ChunkedRestResponseBodyPart() { private RecyclerBytesStreamOutput currentOutput; private final Writer writer = new OutputStreamWriter(new OutputStream() { @Override @@ -224,17 +224,17 @@ public void close() { }, StandardCharsets.UTF_8); @Override - public boolean isDone() { + public boolean isPartComplete() { return chunkIterator.hasNext() == false; } @Override - public boolean isEndOfResponse() { + public boolean isLastPart() { return true; } @Override - public void getContinuation(ActionListener listener) { + public void getNextPart(ActionListener listener) { assert false : "no continuations"; listener.onFailure(new IllegalStateException("no continuations available")); } diff --git a/server/src/main/java/org/elasticsearch/rest/LoggingChunkedRestResponseBody.java b/server/src/main/java/org/elasticsearch/rest/LoggingChunkedRestResponseBodyPart.java similarity index 68% rename from server/src/main/java/org/elasticsearch/rest/LoggingChunkedRestResponseBody.java rename to server/src/main/java/org/elasticsearch/rest/LoggingChunkedRestResponseBodyPart.java index 865f433e25aa4..f7a018eaacf7e 100644 --- a/server/src/main/java/org/elasticsearch/rest/LoggingChunkedRestResponseBody.java +++ b/server/src/main/java/org/elasticsearch/rest/LoggingChunkedRestResponseBodyPart.java @@ -16,29 +16,29 @@ import java.io.IOException; import java.io.OutputStream; -public class LoggingChunkedRestResponseBody implements ChunkedRestResponseBody { +public class LoggingChunkedRestResponseBodyPart implements ChunkedRestResponseBodyPart { - private final ChunkedRestResponseBody inner; + private final ChunkedRestResponseBodyPart inner; private final OutputStream loggerStream; - public LoggingChunkedRestResponseBody(ChunkedRestResponseBody inner, OutputStream loggerStream) { + public LoggingChunkedRestResponseBodyPart(ChunkedRestResponseBodyPart inner, OutputStream loggerStream) { this.inner = inner; this.loggerStream = loggerStream; } @Override - public boolean isDone() { - return inner.isDone(); + public boolean isPartComplete() { + return inner.isPartComplete(); } @Override - public boolean isEndOfResponse() { - return inner.isEndOfResponse(); + public boolean isLastPart() { + return inner.isLastPart(); } @Override - public void getContinuation(ActionListener listener) { - inner.getContinuation(listener.map(continuation -> new LoggingChunkedRestResponseBody(continuation, loggerStream))); + public void getNextPart(ActionListener listener) { + inner.getNextPart(listener.map(continuation -> new LoggingChunkedRestResponseBodyPart(continuation, loggerStream))); } @Override diff --git a/server/src/main/java/org/elasticsearch/rest/RestController.java b/server/src/main/java/org/elasticsearch/rest/RestController.java index 0c08520a5dd0b..b08f6ed81017a 100644 --- a/server/src/main/java/org/elasticsearch/rest/RestController.java +++ b/server/src/main/java/org/elasticsearch/rest/RestController.java @@ -857,7 +857,7 @@ public void sendResponse(RestResponse response) { final var headers = response.getHeaders(); response = RestResponse.chunked( response.status(), - new EncodedLengthTrackingChunkedRestResponseBody(response.chunkedContent(), responseLengthRecorder), + new EncodedLengthTrackingChunkedRestResponseBodyPart(response.chunkedContent(), responseLengthRecorder), 
Releasables.wrap(responseLengthRecorder, response) ); for (final var header : headers.entrySet()) { @@ -916,13 +916,13 @@ void addChunkLength(long chunkLength) { } } - private static class EncodedLengthTrackingChunkedRestResponseBody implements ChunkedRestResponseBody { + private static class EncodedLengthTrackingChunkedRestResponseBodyPart implements ChunkedRestResponseBodyPart { - private final ChunkedRestResponseBody delegate; + private final ChunkedRestResponseBodyPart delegate; private final ResponseLengthRecorder responseLengthRecorder; - private EncodedLengthTrackingChunkedRestResponseBody( - ChunkedRestResponseBody delegate, + private EncodedLengthTrackingChunkedRestResponseBodyPart( + ChunkedRestResponseBodyPart delegate, ResponseLengthRecorder responseLengthRecorder ) { this.delegate = delegate; @@ -930,19 +930,19 @@ private EncodedLengthTrackingChunkedRestResponseBody( } @Override - public boolean isDone() { - return delegate.isDone(); + public boolean isPartComplete() { + return delegate.isPartComplete(); } @Override - public boolean isEndOfResponse() { - return delegate.isEndOfResponse(); + public boolean isLastPart() { + return delegate.isLastPart(); } @Override - public void getContinuation(ActionListener listener) { - delegate.getContinuation( - listener.map(continuation -> new EncodedLengthTrackingChunkedRestResponseBody(continuation, responseLengthRecorder)) + public void getNextPart(ActionListener listener) { + delegate.getNextPart( + listener.map(continuation -> new EncodedLengthTrackingChunkedRestResponseBodyPart(continuation, responseLengthRecorder)) ); } @@ -950,7 +950,7 @@ public void getContinuation(ActionListener listener) { public ReleasableBytesReference encodeChunk(int sizeHint, Recycler recycler) throws IOException { final ReleasableBytesReference bytesReference = delegate.encodeChunk(sizeHint, recycler); responseLengthRecorder.addChunkLength(bytesReference.length()); - if (isDone() && isEndOfResponse()) { + if (isPartComplete() && isLastPart()) { responseLengthRecorder.close(); } return bytesReference; diff --git a/server/src/main/java/org/elasticsearch/rest/RestResponse.java b/server/src/main/java/org/elasticsearch/rest/RestResponse.java index 9862ab31bd53f..8cc0e35a64802 100644 --- a/server/src/main/java/org/elasticsearch/rest/RestResponse.java +++ b/server/src/main/java/org/elasticsearch/rest/RestResponse.java @@ -48,7 +48,7 @@ public final class RestResponse implements Releasable { private final BytesReference content; @Nullable - private final ChunkedRestResponseBody chunkedResponseBody; + private final ChunkedRestResponseBodyPart chunkedResponseBody; private final String responseMediaType; private Map> customHeaders; @@ -84,9 +84,9 @@ private RestResponse(RestStatus status, String responseMediaType, BytesReference this(status, responseMediaType, content, null, releasable); } - public static RestResponse chunked(RestStatus restStatus, ChunkedRestResponseBody content, @Nullable Releasable releasable) { - if (content.isDone()) { - assert content.isEndOfResponse() : "response with continuations must have at least one (possibly-empty) chunk in each part"; + public static RestResponse chunked(RestStatus restStatus, ChunkedRestResponseBodyPart content, @Nullable Releasable releasable) { + if (content.isPartComplete()) { + assert content.isLastPart() : "response with continuations must have at least one (possibly-empty) chunk in each part"; return new RestResponse(restStatus, content.getResponseContentTypeString(), BytesArray.EMPTY, releasable); } else { return 
new RestResponse(restStatus, content.getResponseContentTypeString(), null, content, releasable); @@ -100,7 +100,7 @@ private RestResponse( RestStatus status, String responseMediaType, @Nullable BytesReference content, - @Nullable ChunkedRestResponseBody chunkedResponseBody, + @Nullable ChunkedRestResponseBodyPart chunkedResponseBody, @Nullable Releasable releasable ) { this.status = status; @@ -162,7 +162,7 @@ public BytesReference content() { } @Nullable - public ChunkedRestResponseBody chunkedContent() { + public ChunkedRestResponseBodyPart chunkedContent() { return chunkedResponseBody; } diff --git a/server/src/main/java/org/elasticsearch/rest/action/RestChunkedToXContentListener.java b/server/src/main/java/org/elasticsearch/rest/action/RestChunkedToXContentListener.java index 3798f2b6b6fb1..ef2aa8418eef3 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/RestChunkedToXContentListener.java +++ b/server/src/main/java/org/elasticsearch/rest/action/RestChunkedToXContentListener.java @@ -10,7 +10,7 @@ import org.elasticsearch.common.xcontent.ChunkedToXContent; import org.elasticsearch.core.Releasable; -import org.elasticsearch.rest.ChunkedRestResponseBody; +import org.elasticsearch.rest.ChunkedRestResponseBodyPart; import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestResponse; import org.elasticsearch.rest.RestStatus; @@ -40,7 +40,7 @@ protected void processResponse(Response response) throws IOException { channel.sendResponse( RestResponse.chunked( getRestStatus(response), - ChunkedRestResponseBody.fromXContent(response, params, channel), + ChunkedRestResponseBodyPart.fromXContent(response, params, channel), releasableFromResponse(response) ) ); diff --git a/server/src/main/java/org/elasticsearch/rest/action/admin/cluster/RestNodesHotThreadsAction.java b/server/src/main/java/org/elasticsearch/rest/action/admin/cluster/RestNodesHotThreadsAction.java index bcf0d99325594..9cf2d6a2ed395 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/admin/cluster/RestNodesHotThreadsAction.java +++ b/server/src/main/java/org/elasticsearch/rest/action/admin/cluster/RestNodesHotThreadsAction.java @@ -27,7 +27,7 @@ import java.util.List; import java.util.Locale; -import static org.elasticsearch.rest.ChunkedRestResponseBody.fromTextChunks; +import static org.elasticsearch.rest.ChunkedRestResponseBodyPart.fromTextChunks; import static org.elasticsearch.rest.RestRequest.Method.GET; import static org.elasticsearch.rest.RestResponse.TEXT_CONTENT_TYPE; import static org.elasticsearch.rest.RestUtils.getTimeout; diff --git a/server/src/main/java/org/elasticsearch/rest/action/cat/RestTable.java b/server/src/main/java/org/elasticsearch/rest/action/cat/RestTable.java index 5999d1b81da47..2f94e3ab90cbf 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/cat/RestTable.java +++ b/server/src/main/java/org/elasticsearch/rest/action/cat/RestTable.java @@ -17,7 +17,7 @@ import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.core.Booleans; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.rest.ChunkedRestResponseBody; +import org.elasticsearch.rest.ChunkedRestResponseBodyPart; import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestResponse; @@ -63,7 +63,7 @@ public static RestResponse buildXContentBuilder(Table table, RestChannel channel return RestResponse.chunked( RestStatus.OK, - ChunkedRestResponseBody.fromXContent( + ChunkedRestResponseBodyPart.fromXContent( ignored 
-> Iterators.concat( Iterators.single((builder, params) -> builder.startArray()), Iterators.map(rowOrder.iterator(), row -> (builder, params) -> { @@ -94,7 +94,7 @@ public static RestResponse buildTextPlainResponse(Table table, RestChannel chann return RestResponse.chunked( RestStatus.OK, - ChunkedRestResponseBody.fromTextChunks( + ChunkedRestResponseBodyPart.fromTextChunks( RestResponse.TEXT_CONTENT_TYPE, Iterators.concat( // optional header diff --git a/server/src/main/java/org/elasticsearch/rest/action/info/RestClusterInfoAction.java b/server/src/main/java/org/elasticsearch/rest/action/info/RestClusterInfoAction.java index 8be023bb4a182..0a38d59d29729 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/info/RestClusterInfoAction.java +++ b/server/src/main/java/org/elasticsearch/rest/action/info/RestClusterInfoAction.java @@ -19,7 +19,7 @@ import org.elasticsearch.http.HttpStats; import org.elasticsearch.ingest.IngestStats; import org.elasticsearch.rest.BaseRestHandler; -import org.elasticsearch.rest.ChunkedRestResponseBody; +import org.elasticsearch.rest.ChunkedRestResponseBodyPart; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestResponse; import org.elasticsearch.rest.RestStatus; @@ -122,7 +122,7 @@ public RestResponse buildResponse(NodesStatsResponse response) throws Exception return RestResponse.chunked( RestStatus.OK, - ChunkedRestResponseBody.fromXContent( + ChunkedRestResponseBodyPart.fromXContent( outerParams -> Iterators.concat( ChunkedToXContentHelper.startObject(), Iterators.single((builder, params) -> builder.field("cluster_name", response.getClusterName().value())), diff --git a/server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java b/server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java index f12d8ea5c631a..d49347a0dd3fc 100644 --- a/server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java +++ b/server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java @@ -31,7 +31,7 @@ import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.Releasable; import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; -import org.elasticsearch.rest.ChunkedRestResponseBody; +import org.elasticsearch.rest.ChunkedRestResponseBodyPart; import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestResponse; @@ -530,20 +530,20 @@ public void testHandleHeadRequest() { { // chunked response final var isClosed = new AtomicBoolean(); - channel.sendResponse(RestResponse.chunked(RestStatus.OK, new ChunkedRestResponseBody() { + channel.sendResponse(RestResponse.chunked(RestStatus.OK, new ChunkedRestResponseBodyPart() { @Override - public boolean isDone() { + public boolean isPartComplete() { return false; } @Override - public boolean isEndOfResponse() { + public boolean isLastPart() { throw new AssertionError("should not check for end-of-response for HEAD request"); } @Override - public void getContinuation(ActionListener listener) { + public void getNextPart(ActionListener listener) { throw new AssertionError("should not get any continuations for HEAD request"); } @@ -688,25 +688,25 @@ public void testResponseBodyTracing() { HttpRequest httpRequest = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/") { @Override - public HttpResponse createResponse(RestStatus status, ChunkedRestResponseBody content) { + public HttpResponse createResponse(RestStatus status, 
ChunkedRestResponseBodyPart firstBodyPart) { try (var bso = new BytesStreamOutput()) { - writeContent(bso, content); + writeContent(bso, firstBodyPart); return new TestHttpResponse(status, bso.bytes()); } catch (IOException e) { return fail(e); } } - private static void writeContent(OutputStream bso, ChunkedRestResponseBody content) throws IOException { - while (content.isDone() == false) { + private static void writeContent(OutputStream bso, ChunkedRestResponseBodyPart content) throws IOException { + while (content.isPartComplete() == false) { try (var bytes = content.encodeChunk(1 << 14, BytesRefRecycler.NON_RECYCLING_INSTANCE)) { bytes.writeTo(bso); } } - if (content.isEndOfResponse()) { + if (content.isLastPart()) { return; } - writeContent(bso, PlainActionFuture.get(content::getContinuation)); + writeContent(bso, PlainActionFuture.get(content::getNextPart)); } }; @@ -735,14 +735,14 @@ private static void writeContent(OutputStream bso, ChunkedRestResponseBody conte ) ); - final var parts = new ArrayList(); - class TestBody implements ChunkedRestResponseBody { + final var parts = new ArrayList(); + class TestBodyPart implements ChunkedRestResponseBodyPart { boolean isDone; final BytesReference thisChunk; final BytesReference remainingChunks; final int remainingContinuations; - TestBody(BytesReference content, int remainingContinuations) { + TestBodyPart(BytesReference content, int remainingContinuations) { if (remainingContinuations == 0) { thisChunk = content; remainingChunks = BytesArray.EMPTY; @@ -755,18 +755,18 @@ class TestBody implements ChunkedRestResponseBody { } @Override - public boolean isDone() { + public boolean isPartComplete() { return isDone; } @Override - public boolean isEndOfResponse() { + public boolean isLastPart() { return remainingContinuations == 0; } @Override - public void getContinuation(ActionListener listener) { - final var continuation = new TestBody(remainingChunks, remainingContinuations - 1); + public void getNextPart(ActionListener listener) { + final var continuation = new TestBodyPart(remainingChunks, remainingContinuations - 1); parts.add(continuation); listener.onResponse(continuation); } @@ -785,7 +785,7 @@ public String getResponseContentTypeString() { } final var isClosed = new AtomicBoolean(); - final var firstPart = new TestBody(responseBody, between(0, 3)); + final var firstPart = new TestBodyPart(responseBody, between(0, 3)); parts.add(firstPart); assertEquals( responseBody, @@ -797,8 +797,8 @@ public String getResponseContentTypeString() { () -> channel.sendResponse(RestResponse.chunked(RestStatus.OK, firstPart, () -> { assertTrue(isClosed.compareAndSet(false, true)); for (int i = 0; i < parts.size(); i++) { - assertTrue("isDone " + i, parts.get(i).isDone()); - assertEquals("isEndOfResponse " + i, i == parts.size() - 1, parts.get(i).isEndOfResponse()); + assertTrue("isPartComplete " + i, parts.get(i).isPartComplete()); + assertEquals("isLastPart " + i, i == parts.size() - 1, parts.get(i).isLastPart()); } })) ) diff --git a/server/src/test/java/org/elasticsearch/http/TestHttpRequest.java b/server/src/test/java/org/elasticsearch/http/TestHttpRequest.java index 4e30dde5e5e7e..e7b0232afa245 100644 --- a/server/src/test/java/org/elasticsearch/http/TestHttpRequest.java +++ b/server/src/test/java/org/elasticsearch/http/TestHttpRequest.java @@ -10,7 +10,7 @@ import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.rest.ChunkedRestResponseBody; +import 
org.elasticsearch.rest.ChunkedRestResponseBodyPart; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestStatus; @@ -78,7 +78,7 @@ public HttpResponse createResponse(RestStatus status, BytesReference content) { } @Override - public HttpResponse createResponse(RestStatus status, ChunkedRestResponseBody content) { + public HttpResponse createResponse(RestStatus status, ChunkedRestResponseBodyPart firstBodyPart) { throw new UnsupportedOperationException("chunked responses not supported"); } diff --git a/server/src/test/java/org/elasticsearch/rest/ChunkedRestResponseBodyTests.java b/server/src/test/java/org/elasticsearch/rest/ChunkedRestResponseBodyPartTests.java similarity index 81% rename from server/src/test/java/org/elasticsearch/rest/ChunkedRestResponseBodyTests.java rename to server/src/test/java/org/elasticsearch/rest/ChunkedRestResponseBodyPartTests.java index cce2a8db25c8e..9c703d83e7d0a 100644 --- a/server/src/test/java/org/elasticsearch/rest/ChunkedRestResponseBodyTests.java +++ b/server/src/test/java/org/elasticsearch/rest/ChunkedRestResponseBodyPartTests.java @@ -30,7 +30,7 @@ import java.util.List; import java.util.Map; -public class ChunkedRestResponseBodyTests extends ESTestCase { +public class ChunkedRestResponseBodyPartTests extends ESTestCase { public void testEncodesChunkedXContentCorrectly() throws IOException { final ChunkedToXContent chunkedToXContent = (ToXContent.Params outerParams) -> Iterators.forArray( @@ -50,7 +50,7 @@ public void testEncodesChunkedXContentCorrectly() throws IOException { } final var bytesDirect = BytesReference.bytes(builderDirect); - var chunkedResponse = ChunkedRestResponseBody.fromXContent( + var firstBodyPart = ChunkedRestResponseBodyPart.fromXContent( chunkedToXContent, ToXContent.EMPTY_PARAMS, new FakeRestChannel( @@ -61,20 +61,25 @@ public void testEncodesChunkedXContentCorrectly() throws IOException { ); final List refsGenerated = new ArrayList<>(); - while (chunkedResponse.isDone() == false) { - refsGenerated.add(chunkedResponse.encodeChunk(randomIntBetween(2, 10), BytesRefRecycler.NON_RECYCLING_INSTANCE)); + while (firstBodyPart.isPartComplete() == false) { + refsGenerated.add(firstBodyPart.encodeChunk(randomIntBetween(2, 10), BytesRefRecycler.NON_RECYCLING_INSTANCE)); } + assertTrue(firstBodyPart.isLastPart()); assertEquals(bytesDirect, CompositeBytesReference.of(refsGenerated.toArray(new BytesReference[0]))); } public void testFromTextChunks() throws IOException { final var chunks = randomList(1000, () -> randomUnicodeOfLengthBetween(1, 100)); - var body = ChunkedRestResponseBody.fromTextChunks("text/plain", Iterators.map(chunks.iterator(), s -> w -> w.write(s))); + var firstBodyPart = ChunkedRestResponseBodyPart.fromTextChunks( + "text/plain", + Iterators.map(chunks.iterator(), s -> w -> w.write(s)) + ); final List refsGenerated = new ArrayList<>(); - while (body.isDone() == false) { - refsGenerated.add(body.encodeChunk(randomIntBetween(2, 10), BytesRefRecycler.NON_RECYCLING_INSTANCE)); + while (firstBodyPart.isPartComplete() == false) { + refsGenerated.add(firstBodyPart.encodeChunk(randomIntBetween(2, 10), BytesRefRecycler.NON_RECYCLING_INSTANCE)); } + assertTrue(firstBodyPart.isLastPart()); final BytesReference chunkedBytes = CompositeBytesReference.of(refsGenerated.toArray(new BytesReference[0])); try (var outputStream = new ByteArrayOutputStream(); var writer = new OutputStreamWriter(outputStream, StandardCharsets.UTF_8)) { diff --git a/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java 
b/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java index 37300f1c19b1c..10ea83e59c0ad 100644 --- a/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java +++ b/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java @@ -733,7 +733,7 @@ public HttpResponse createResponse(RestStatus status, BytesReference content) { } @Override - public HttpResponse createResponse(RestStatus status, ChunkedRestResponseBody content) { + public HttpResponse createResponse(RestStatus status, ChunkedRestResponseBodyPart firstBodyPart) { throw new AssertionError("should not be called"); } diff --git a/server/src/test/java/org/elasticsearch/rest/RestResponseTests.java b/server/src/test/java/org/elasticsearch/rest/RestResponseTests.java index 41a54ac580a55..eaef60e15822d 100644 --- a/server/src/test/java/org/elasticsearch/rest/RestResponseTests.java +++ b/server/src/test/java/org/elasticsearch/rest/RestResponseTests.java @@ -97,7 +97,7 @@ public void testWithHeaders() throws Exception { public void testEmptyChunkedBody() { RestResponse response = RestResponse.chunked( RestStatus.OK, - ChunkedRestResponseBody.fromTextChunks(RestResponse.TEXT_CONTENT_TYPE, Collections.emptyIterator()), + ChunkedRestResponseBodyPart.fromTextChunks(RestResponse.TEXT_CONTENT_TYPE, Collections.emptyIterator()), null ); assertFalse(response.isChunked()); diff --git a/server/src/test/java/org/elasticsearch/rest/action/cat/RestTableTests.java b/server/src/test/java/org/elasticsearch/rest/action/cat/RestTableTests.java index dff6b52e470df..cb98eaddb77cd 100644 --- a/server/src/test/java/org/elasticsearch/rest/action/cat/RestTableTests.java +++ b/server/src/test/java/org/elasticsearch/rest/action/cat/RestTableTests.java @@ -432,14 +432,15 @@ public int pageSize() { }; final var bodyChunks = new ArrayList(); - final var chunkedRestResponseBody = response.chunkedContent(); + final var firstBodyPart = response.chunkedContent(); - while (chunkedRestResponseBody.isDone() == false) { - try (var chunk = chunkedRestResponseBody.encodeChunk(pageSize, recycler)) { + while (firstBodyPart.isPartComplete() == false) { + try (var chunk = firstBodyPart.encodeChunk(pageSize, recycler)) { assertThat(chunk.length(), greaterThan(0)); bodyChunks.add(chunk.utf8ToString()); } } + assertTrue(firstBodyPart.isLastPart()); assertEquals(0, openPages.get()); return bodyChunks; } diff --git a/test/framework/src/main/java/org/elasticsearch/rest/RestResponseUtils.java b/test/framework/src/main/java/org/elasticsearch/rest/RestResponseUtils.java index 1b1331fe25bbf..fe2df39b21591 100644 --- a/test/framework/src/main/java/org/elasticsearch/rest/RestResponseUtils.java +++ b/test/framework/src/main/java/org/elasticsearch/rest/RestResponseUtils.java @@ -28,8 +28,8 @@ public static BytesReference getBodyContent(RestResponse restResponse) { return restResponse.content(); } - final var chunkedRestResponseBody = restResponse.chunkedContent(); - assert chunkedRestResponseBody.isDone() == false; + final var firstResponseBodyPart = restResponse.chunkedContent(); + assert firstResponseBodyPart.isPartComplete() == false; final int pageSize; try (var page = NON_RECYCLING_INSTANCE.obtain()) { @@ -37,12 +37,12 @@ public static BytesReference getBodyContent(RestResponse restResponse) { } try (var out = new BytesStreamOutput()) { - while (chunkedRestResponseBody.isDone() == false) { - try (var chunk = chunkedRestResponseBody.encodeChunk(pageSize, NON_RECYCLING_INSTANCE)) { + while (firstResponseBodyPart.isPartComplete() == false) { + try (var 
chunk = firstResponseBodyPart.encodeChunk(pageSize, NON_RECYCLING_INSTANCE)) { chunk.writeTo(out); } } - assert chunkedRestResponseBody.isEndOfResponse() : "RestResponseUtils#getBodyContent does not support continuations (yet)"; + assert firstResponseBodyPart.isLastPart() : "RestResponseUtils#getBodyContent does not support continuations (yet)"; out.flush(); return out.bytes(); diff --git a/test/framework/src/main/java/org/elasticsearch/test/rest/FakeRestRequest.java b/test/framework/src/main/java/org/elasticsearch/test/rest/FakeRestRequest.java index 726d2ec0d963d..3a9c4b371c9da 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/rest/FakeRestRequest.java +++ b/test/framework/src/main/java/org/elasticsearch/test/rest/FakeRestRequest.java @@ -16,7 +16,7 @@ import org.elasticsearch.http.HttpChannel; import org.elasticsearch.http.HttpRequest; import org.elasticsearch.http.HttpResponse; -import org.elasticsearch.rest.ChunkedRestResponseBody; +import org.elasticsearch.rest.ChunkedRestResponseBodyPart; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.NamedXContentRegistry; @@ -129,7 +129,7 @@ public boolean containsHeader(String name) { } @Override - public HttpResponse createResponse(RestStatus status, ChunkedRestResponseBody content) { + public HttpResponse createResponse(RestStatus status, ChunkedRestResponseBodyPart firstBodyPart) { return createResponse(status, BytesArray.EMPTY); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlResponseListener.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlResponseListener.java index 34f2906d003ae..0ed77b624f5b0 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlResponseListener.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlResponseListener.java @@ -12,7 +12,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; -import org.elasticsearch.rest.ChunkedRestResponseBody; +import org.elasticsearch.rest.ChunkedRestResponseBodyPart; import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestResponse; @@ -132,13 +132,13 @@ private RestResponse buildResponse(EsqlQueryResponse esqlResponse) throws IOExce if (mediaType instanceof TextFormat format) { restResponse = RestResponse.chunked( RestStatus.OK, - ChunkedRestResponseBody.fromTextChunks(format.contentType(restRequest), format.format(restRequest, esqlResponse)), + ChunkedRestResponseBodyPart.fromTextChunks(format.contentType(restRequest), format.format(restRequest, esqlResponse)), releasable ); } else { restResponse = RestResponse.chunked( RestStatus.OK, - ChunkedRestResponseBody.fromXContent(esqlResponse, channel.request(), channel), + ChunkedRestResponseBodyPart.fromXContent(esqlResponse, channel.request(), channel), releasable ); } From a97f2abf672eafdc7efd074b6d53a642e12c7cd5 Mon Sep 17 00:00:00 2001 From: Yang Wang Date: Mon, 27 May 2024 21:43:38 +1000 Subject: [PATCH 15/29] [Test] Fix EsExecutorsTests.testParseExecutorName (#109053) Use the overloaded method that takes settings as a parameter to test a null nodeName.
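For context, a minimal sketch of the overloads involved (node and executor names here are illustrative; the three calls mirror those in the diff below):

    import java.util.concurrent.ThreadFactory;
    import org.elasticsearch.common.settings.Settings;
    import org.elasticsearch.common.util.concurrent.EsExecutors;
    import org.elasticsearch.node.Node;

    // a null node name is now exercised through the Settings-based overload instead of passing null directly
    ThreadFactory noNodeName = EsExecutors.daemonThreadFactory(Settings.EMPTY, "my_executor");
    // node name carried in Settings
    ThreadFactory viaSettings = EsExecutors.daemonThreadFactory(
        Settings.builder().put(Node.NODE_NAME_SETTING.getKey(), "node-0").build(),
        "my_executor"
    );
    // node name passed directly
    ThreadFactory viaName = EsExecutors.daemonThreadFactory("node-0", "my_executor");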
Resolves: #109014 --- .../common/util/concurrent/EsExecutorsTests.java | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/server/src/test/java/org/elasticsearch/common/util/concurrent/EsExecutorsTests.java b/server/src/test/java/org/elasticsearch/common/util/concurrent/EsExecutorsTests.java index 037df07d1e078..df7b02c2309a3 100644 --- a/server/src/test/java/org/elasticsearch/common/util/concurrent/EsExecutorsTests.java +++ b/server/src/test/java/org/elasticsearch/common/util/concurrent/EsExecutorsTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.Processors; +import org.elasticsearch.node.Node; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; import org.hamcrest.Matcher; @@ -22,6 +23,7 @@ import java.util.concurrent.CyclicBarrier; import java.util.concurrent.Executor; import java.util.concurrent.ExecutorService; +import java.util.concurrent.ThreadFactory; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; @@ -628,7 +630,19 @@ public void testFixedUnboundedRejectOnShutdown() { public void testParseExecutorName() throws InterruptedException { final var executorName = randomAlphaOfLength(10); - final var threadFactory = EsExecutors.daemonThreadFactory(rarely() ? null : randomAlphaOfLength(10), executorName); + final String nodeName = rarely() ? null : randomIdentifier(); + final ThreadFactory threadFactory; + if (nodeName == null) { + threadFactory = EsExecutors.daemonThreadFactory(Settings.EMPTY, executorName); + } else if (randomBoolean()) { + threadFactory = EsExecutors.daemonThreadFactory( + Settings.builder().put(Node.NODE_NAME_SETTING.getKey(), nodeName).build(), + executorName + ); + } else { + threadFactory = EsExecutors.daemonThreadFactory(nodeName, executorName); + } + final var thread = threadFactory.newThread(() -> {}); try { assertThat(EsExecutors.executorName(thread.getName()), equalTo(executorName)); From 415b68f1590d48abab443d697b20ce21b3d407ef Mon Sep 17 00:00:00 2001 From: Patrick Doyle <810052+prdoyle@users.noreply.github.com> Date: Mon, 27 May 2024 08:57:42 -0400 Subject: [PATCH 16/29] Add missing ThreadPoolTypes (#109019) --- .../src/main/java/org/elasticsearch/threadpool/ThreadPool.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/server/src/main/java/org/elasticsearch/threadpool/ThreadPool.java b/server/src/main/java/org/elasticsearch/threadpool/ThreadPool.java index a887a2be558e0..88c507404e76b 100644 --- a/server/src/main/java/org/elasticsearch/threadpool/ThreadPool.java +++ b/server/src/main/java/org/elasticsearch/threadpool/ThreadPool.java @@ -152,12 +152,14 @@ public static ThreadPoolType fromType(String type) { public static final Map THREAD_POOL_TYPES = Map.ofEntries( entry(Names.GENERIC, ThreadPoolType.SCALING), + entry(Names.CLUSTER_COORDINATION, ThreadPoolType.FIXED), entry(Names.GET, ThreadPoolType.FIXED), entry(Names.ANALYZE, ThreadPoolType.FIXED), entry(Names.WRITE, ThreadPoolType.FIXED), entry(Names.SEARCH, ThreadPoolType.FIXED), entry(Names.SEARCH_WORKER, ThreadPoolType.FIXED), entry(Names.SEARCH_COORDINATION, ThreadPoolType.FIXED), + entry(Names.AUTO_COMPLETE, ThreadPoolType.FIXED), entry(Names.MANAGEMENT, ThreadPoolType.SCALING), entry(Names.FLUSH, ThreadPoolType.SCALING), entry(Names.REFRESH, ThreadPoolType.SCALING), From 2b0d2c9c2370c7e294c34984e2c2dd64fae54048 Mon Sep 17 
00:00:00 2001 From: =?UTF-8?q?Istv=C3=A1n=20Zolt=C3=A1n=20Szab=C3=B3?= Date: Mon, 27 May 2024 16:34:01 +0200 Subject: [PATCH 17/29] [DOCS] updates transforms at scale doc with date rounding. (#109073) --- .../transform/transforms-at-scale.asciidoc | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/docs/reference/transform/transforms-at-scale.asciidoc b/docs/reference/transform/transforms-at-scale.asciidoc index f1d47c9943242..f052b2e8a5284 100644 --- a/docs/reference/transform/transforms-at-scale.asciidoc +++ b/docs/reference/transform/transforms-at-scale.asciidoc @@ -15,7 +15,7 @@ relevant considerations in this guide to improve performance. It also helps to understand how {transforms} work as different considerations apply depending on whether or not your transform is running in continuous mode or in batch. -In this guide, you’ll learn how to: +In this guide, you'll learn how to: * Understand the impact of configuration options on the performance of {transforms}. @@ -111,10 +111,17 @@ group of IPs, in order to calculate the total `bytes_sent`. If this second search matches many shards, then this could be resource intensive. Consider limiting the scope that the source index pattern and query will match. -Use an absolute time value as a date range filter in your source query (for -example, greater than `2020-01-01T00:00:00`) to limit which historical indices -are accessed. If you use a relative time value (for example, `now-30d`) then -this date range is re-evaluated at the point of each checkpoint execution. +To limit which historical indices are accessed, exclude certain tiers (for +example `"must_not": { "terms": { "_tier": [ "data_frozen", "data_cold" ] } }`) +and/or use an absolute time value as a date range filter in your source query +(for example, greater than `2024-01-01T00:00:00`). If you use a relative time +value (for example, `gte`: `now-30d/d`) then ensure date rounding is applied to take +advantage of query caching and ensure that the relative time is much larger than +the largest of `frequency` or `time.sync.delay` or the date histogram bucket, +otherwise data may be missed. Do not use date filters which are less than a date +value (for example, `lt`: less than or `lte`: less than or equal to) as this +conflicts with the logic applied at each checkpoint execution and data may be +missed. Consider using <> in your index names to reduce the number of indices to resolve in your queries.
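For illustration, a rounded relative date filter in the source query might look like the following sketch (the `@timestamp` field name is an assumption; with `gte`, the `/d` date-math suffix rounds the bound down to the start of the day, which keeps the filter cacheable between checkpoints):

    "query": {
      "range": {
        "@timestamp": {
          "gte": "now-30d/d"
        }
      }
    }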
Add a date pattern From 2704d3a8d1e32efb04d3ad88d17bf9791506f948 Mon Sep 17 00:00:00 2001 From: Rene Groeschke Date: Mon, 27 May 2024 17:59:14 +0200 Subject: [PATCH 18/29] Remove cross project support in TestFixturesPlugin (#109077) - One step closer to configuration cache support - Crossproject support has been replaced by using testcontainer based fixtures --- .../testfixtures/TestFixtureExtension.java | 112 ------------ .../testfixtures/TestFixturesPlugin.java | 168 ++++++++---------- distribution/docker/build.gradle | 2 - qa/apm/build.gradle | 2 - qa/remote-clusters/build.gradle | 2 - 5 files changed, 75 insertions(+), 211 deletions(-) delete mode 100644 build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/testfixtures/TestFixtureExtension.java diff --git a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/testfixtures/TestFixtureExtension.java b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/testfixtures/TestFixtureExtension.java deleted file mode 100644 index 2bcfb7c76d5cd..0000000000000 --- a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/testfixtures/TestFixtureExtension.java +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ -package org.elasticsearch.gradle.internal.testfixtures; - -import org.gradle.api.GradleException; -import org.gradle.api.NamedDomainObjectContainer; -import org.gradle.api.Project; - -import java.util.HashMap; -import java.util.Locale; -import java.util.Map; -import java.util.Optional; - -public class TestFixtureExtension { - - private final Project project; - final NamedDomainObjectContainer fixtures; - final Map serviceToProjectUseMap = new HashMap<>(); - - public TestFixtureExtension(Project project) { - this.project = project; - this.fixtures = project.container(Project.class); - } - - public void useFixture() { - useFixture(this.project.getPath()); - } - - public void useFixture(String path) { - addFixtureProject(path); - serviceToProjectUseMap.put(path, this.project.getPath()); - } - - public void useFixture(String path, String serviceName) { - addFixtureProject(path); - String key = getServiceNameKey(path, serviceName); - serviceToProjectUseMap.put(key, this.project.getPath()); - - Optional otherProject = this.findOtherProjectUsingService(key); - if (otherProject.isPresent()) { - throw new GradleException( - String.format( - Locale.ROOT, - "Projects %s and %s both claim the %s service defined in the docker-compose.yml of " - + "%sThis is not supported because it breaks running in parallel. 
Configure dedicated " - + "services for each project and use those instead.", - otherProject.get(), - this.project.getPath(), - serviceName, - path - ) - ); - } - } - - private String getServiceNameKey(String fixtureProjectPath, String serviceName) { - return fixtureProjectPath + "::" + serviceName; - } - - private Optional findOtherProjectUsingService(String serviceName) { - return this.project.getRootProject() - .getAllprojects() - .stream() - .filter(p -> p.equals(this.project) == false) - .filter(p -> p.getExtensions().findByType(TestFixtureExtension.class) != null) - .map(project -> project.getExtensions().getByType(TestFixtureExtension.class)) - .flatMap(ext -> ext.serviceToProjectUseMap.entrySet().stream()) - .filter(entry -> entry.getKey().equals(serviceName)) - .map(Map.Entry::getValue) - .findAny(); - } - - private void addFixtureProject(String path) { - Project fixtureProject = this.project.findProject(path); - if (fixtureProject == null) { - throw new IllegalArgumentException("Could not find test fixture " + fixtureProject); - } - if (fixtureProject.file(TestFixturesPlugin.DOCKER_COMPOSE_YML).exists() == false) { - throw new IllegalArgumentException( - "Project " + path + " is not a valid test fixture: missing " + TestFixturesPlugin.DOCKER_COMPOSE_YML - ); - } - fixtures.add(fixtureProject); - // Check for exclusive access - Optional otherProject = this.findOtherProjectUsingService(path); - if (otherProject.isPresent()) { - throw new GradleException( - String.format( - Locale.ROOT, - "Projects %s and %s both claim all services from %s. This is not supported because it" - + " breaks running in parallel. Configure specific services in docker-compose.yml " - + "for each and add the service name to `useFixture`", - otherProject.get(), - this.project.getPath(), - path - ) - ); - } - } - - boolean isServiceRequired(String serviceName, String fixtureProject) { - if (serviceToProjectUseMap.containsKey(fixtureProject)) { - return true; - } - return serviceToProjectUseMap.containsKey(getServiceNameKey(fixtureProject, serviceName)); - } -} diff --git a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/testfixtures/TestFixturesPlugin.java b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/testfixtures/TestFixturesPlugin.java index c50ff97498c31..4c5f2abb9515c 100644 --- a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/testfixtures/TestFixturesPlugin.java +++ b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/testfixtures/TestFixturesPlugin.java @@ -70,7 +70,6 @@ public void apply(Project project) { project.getRootProject().getPluginManager().apply(DockerSupportPlugin.class); TaskContainer tasks = project.getTasks(); - TestFixtureExtension extension = project.getExtensions().create("testFixtures", TestFixtureExtension.class, project); Provider dockerComposeThrottle = project.getGradle() .getSharedServices() .registerIfAbsent(DOCKER_COMPOSE_THROTTLE, DockerComposeThrottle.class, spec -> spec.getMaxParallelUsages().set(1)); @@ -84,73 +83,63 @@ public void apply(Project project) { File testFixturesDir = project.file("testfixtures_shared"); ext.set("testFixturesDir", testFixturesDir); - if (project.file(DOCKER_COMPOSE_YML).exists()) { - project.getPluginManager().apply(BasePlugin.class); - project.getPluginManager().apply(DockerComposePlugin.class); - TaskProvider preProcessFixture = project.getTasks().register("preProcessFixture", TestFixtureTask.class, t -> { - t.getFixturesDir().set(testFixturesDir); - 
t.doFirst(task -> { - try { - Files.createDirectories(testFixturesDir.toPath()); - } catch (IOException e) { - throw new UncheckedIOException(e); - } - }); - }); - TaskProvider buildFixture = project.getTasks() - .register("buildFixture", t -> t.dependsOn(preProcessFixture, tasks.named("composeUp"))); - - TaskProvider postProcessFixture = project.getTasks() - .register("postProcessFixture", TestFixtureTask.class, task -> { - task.getFixturesDir().set(testFixturesDir); - task.dependsOn(buildFixture); - configureServiceInfoForTask( - task, - project, - false, - (name, port) -> task.getExtensions().getByType(ExtraPropertiesExtension.class).set(name, port) - ); - }); - - maybeSkipTask(dockerSupport, preProcessFixture); - maybeSkipTask(dockerSupport, postProcessFixture); - maybeSkipTask(dockerSupport, buildFixture); - - ComposeExtension composeExtension = project.getExtensions().getByType(ComposeExtension.class); - composeExtension.setProjectName(project.getName()); - composeExtension.getUseComposeFiles().addAll(Collections.singletonList(DOCKER_COMPOSE_YML)); - composeExtension.getRemoveContainers().set(true); - composeExtension.getCaptureContainersOutput() - .set(EnumSet.of(LogLevel.INFO, LogLevel.DEBUG).contains(project.getGradle().getStartParameter().getLogLevel())); - composeExtension.getUseDockerComposeV2().set(false); - composeExtension.getExecutable().set(this.providerFactory.provider(() -> { - String composePath = dockerSupport.get().getDockerAvailability().dockerComposePath(); - LOGGER.debug("Docker Compose path: {}", composePath); - return composePath != null ? composePath : "/usr/bin/docker-compose"; - })); - - tasks.named("composeUp").configure(t -> { - // Avoid running docker-compose tasks in parallel in CI due to some issues on certain Linux distributions - if (BuildParams.isCi()) { - t.usesService(dockerComposeThrottle); + if (project.file(DOCKER_COMPOSE_YML).exists() == false) { + // if only one fixture is used, that's this one, but without a compose file that's not a valid configuration + throw new IllegalStateException("No " + DOCKER_COMPOSE_YML + " found for " + project.getPath() + "."); + } + project.getPluginManager().apply(BasePlugin.class); + project.getPluginManager().apply(DockerComposePlugin.class); + TaskProvider preProcessFixture = project.getTasks().register("preProcessFixture", TestFixtureTask.class, t -> { + t.getFixturesDir().set(testFixturesDir); + t.doFirst(task -> { + try { + Files.createDirectories(testFixturesDir.toPath()); + } catch (IOException e) { + throw new UncheckedIOException(e); } - t.mustRunAfter(preProcessFixture); }); - tasks.named("composePull").configure(t -> t.mustRunAfter(preProcessFixture)); - tasks.named("composeDown").configure(t -> t.doLast(t2 -> getFileSystemOperations().delete(d -> d.delete(testFixturesDir)))); - } else { - project.afterEvaluate(spec -> { - if (extension.fixtures.isEmpty()) { - // if only one fixture is used, that's this one, but without a compose file that's not a valid configuration - throw new IllegalStateException( - "No " + DOCKER_COMPOSE_YML + " found for " + project.getPath() + " nor does it use other fixtures." 
- ); - } + }); + TaskProvider buildFixture = project.getTasks() + .register("buildFixture", t -> t.dependsOn(preProcessFixture, tasks.named("composeUp"))); + + TaskProvider postProcessFixture = project.getTasks() + .register("postProcessFixture", TestFixtureTask.class, task -> { + task.getFixturesDir().set(testFixturesDir); + task.dependsOn(buildFixture); + configureServiceInfoForTask( + task, + project, + false, + (name, port) -> task.getExtensions().getByType(ExtraPropertiesExtension.class).set(name, port) + ); }); - } - extension.fixtures.matching(fixtureProject -> fixtureProject.equals(project) == false) - .all(fixtureProject -> project.evaluationDependsOn(fixtureProject.getPath())); + maybeSkipTask(dockerSupport, preProcessFixture); + maybeSkipTask(dockerSupport, postProcessFixture); + maybeSkipTask(dockerSupport, buildFixture); + + ComposeExtension composeExtension = project.getExtensions().getByType(ComposeExtension.class); + composeExtension.setProjectName(project.getName()); + composeExtension.getUseComposeFiles().addAll(Collections.singletonList(DOCKER_COMPOSE_YML)); + composeExtension.getRemoveContainers().set(true); + composeExtension.getCaptureContainersOutput() + .set(EnumSet.of(LogLevel.INFO, LogLevel.DEBUG).contains(project.getGradle().getStartParameter().getLogLevel())); + composeExtension.getUseDockerComposeV2().set(false); + composeExtension.getExecutable().set(this.providerFactory.provider(() -> { + String composePath = dockerSupport.get().getDockerAvailability().dockerComposePath(); + LOGGER.debug("Docker Compose path: {}", composePath); + return composePath != null ? composePath : "/usr/bin/docker-compose"; + })); + + tasks.named("composeUp").configure(t -> { + // Avoid running docker-compose tasks in parallel in CI due to some issues on certain Linux distributions + if (BuildParams.isCi()) { + t.usesService(dockerComposeThrottle); + } + t.mustRunAfter(preProcessFixture); + }); + tasks.named("composePull").configure(t -> t.mustRunAfter(preProcessFixture)); + tasks.named("composeDown").configure(t -> t.doLast(t2 -> getFileSystemOperations().delete(d -> d.delete(testFixturesDir)))); // Skip docker compose tasks if it is unavailable maybeSkipTasks(tasks, dockerSupport, Test.class); @@ -161,17 +150,18 @@ public void apply(Project project) { maybeSkipTasks(tasks, dockerSupport, ComposePull.class); maybeSkipTasks(tasks, dockerSupport, ComposeDown.class); - tasks.withType(Test.class).configureEach(task -> extension.fixtures.all(fixtureProject -> { - task.dependsOn(fixtureProject.getTasks().named("postProcessFixture")); - task.finalizedBy(fixtureProject.getTasks().named("composeDown")); + tasks.withType(Test.class).configureEach(testTask -> { + testTask.dependsOn(postProcessFixture); + testTask.finalizedBy(tasks.named("composeDown")); configureServiceInfoForTask( - task, - fixtureProject, + testTask, + project, true, - (name, host) -> task.getExtensions().getByType(SystemPropertyCommandLineArgumentProvider.class).systemProperty(name, host) + (name, host) -> testTask.getExtensions() + .getByType(SystemPropertyCommandLineArgumentProvider.class) + .systemProperty(name, host) ); - })); - + }); } private void maybeSkipTasks(TaskContainer tasks, Provider dockerSupport, Class taskClass) { @@ -203,28 +193,20 @@ private void configureServiceInfoForTask( task.doFirst(new Action() { @Override public void execute(Task theTask) { - TestFixtureExtension extension = theTask.getProject().getExtensions().getByType(TestFixtureExtension.class); - - fixtureProject.getExtensions() - 
.getByType(ComposeExtension.class) - .getServicesInfos() - .entrySet() - .stream() - .filter(entry -> enableFilter == false || extension.isServiceRequired(entry.getKey(), fixtureProject.getPath())) - .forEach(entry -> { - String service = entry.getKey(); - ServiceInfo infos = entry.getValue(); - infos.getTcpPorts().forEach((container, host) -> { - String name = "test.fixtures." + service + ".tcp." + container; - theTask.getLogger().info("port mapping property: {}={}", name, host); - consumer.accept(name, host); - }); - infos.getUdpPorts().forEach((container, host) -> { - String name = "test.fixtures." + service + ".udp." + container; - theTask.getLogger().info("port mapping property: {}={}", name, host); - consumer.accept(name, host); - }); + fixtureProject.getExtensions().getByType(ComposeExtension.class).getServicesInfos().entrySet().stream().forEach(entry -> { + String service = entry.getKey(); + ServiceInfo infos = entry.getValue(); + infos.getTcpPorts().forEach((container, host) -> { + String name = "test.fixtures." + service + ".tcp." + container; + theTask.getLogger().info("port mapping property: {}={}", name, host); + consumer.accept(name, host); }); + infos.getUdpPorts().forEach((container, host) -> { + String name = "test.fixtures." + service + ".udp." + container; + theTask.getLogger().info("port mapping property: {}={}", name, host); + consumer.accept(name, host); + }); + }); } }); } diff --git a/distribution/docker/build.gradle b/distribution/docker/build.gradle index a3bb202780c7a..68ff2028b92a3 100644 --- a/distribution/docker/build.gradle +++ b/distribution/docker/build.gradle @@ -72,8 +72,6 @@ if (useDra == false) { } } -testFixtures.useFixture() - configurations { aarch64DockerSource { attributes { diff --git a/qa/apm/build.gradle b/qa/apm/build.gradle index b26efdf1f9a69..ff22334462fdc 100644 --- a/qa/apm/build.gradle +++ b/qa/apm/build.gradle @@ -16,8 +16,6 @@ apply plugin: 'elasticsearch.standalone-rest-test' apply plugin: 'elasticsearch.test.fixtures' apply plugin: 'elasticsearch.internal-distribution-download' -testFixtures.useFixture() - dockerCompose { environment.put 'STACK_VERSION', BuildParams.snapshotBuild ? 
VersionProperties.elasticsearch : VersionProperties.elasticsearch + "-SNAPSHOT" } diff --git a/qa/remote-clusters/build.gradle b/qa/remote-clusters/build.gradle index 0475b7e0eeb80..67f62c0fee04d 100644 --- a/qa/remote-clusters/build.gradle +++ b/qa/remote-clusters/build.gradle @@ -15,8 +15,6 @@ apply plugin: 'elasticsearch.standalone-rest-test' apply plugin: 'elasticsearch.test.fixtures' apply plugin: 'elasticsearch.internal-distribution-download' -testFixtures.useFixture() - tasks.register("copyNodeKeyMaterial", Sync) { from project(':x-pack:plugin:core') .files( From 42f4294a8602acc1f9ca2d68ec611007b3fb130e Mon Sep 17 00:00:00 2001 From: Oleksandr Kolomiiets Date: Mon, 27 May 2024 10:22:59 -0700 Subject: [PATCH 19/29] Enable fallback synthetic source for token_count (#109044) --- docs/changelog/109044.yaml | 5 ++ .../mapping/fields/synthetic-source.asciidoc | 1 + .../mapping/types/token-count.asciidoc | 20 +++++- .../mapper/extras/TokenCountFieldMapper.java | 5 ++ .../extras/TokenCountFieldMapperTests.java | 65 ++++++++++++++++++- .../test/token_count/10_basic.yml | 65 +++++++++++++++++++ .../test/search/330_fetch_fields.yml | 35 ---------- 7 files changed, 157 insertions(+), 39 deletions(-) create mode 100644 docs/changelog/109044.yaml create mode 100644 modules/mapper-extras/src/yamlRestTest/resources/rest-api-spec/test/token_count/10_basic.yml diff --git a/docs/changelog/109044.yaml b/docs/changelog/109044.yaml new file mode 100644 index 0000000000000..9e50c377606a0 --- /dev/null +++ b/docs/changelog/109044.yaml @@ -0,0 +1,5 @@ +pr: 109044 +summary: Enable fallback synthetic source for `token_count` +area: Mapping +type: feature +issues: [] diff --git a/docs/reference/mapping/fields/synthetic-source.asciidoc b/docs/reference/mapping/fields/synthetic-source.asciidoc index 1eba9dfba8b50..a0e7aed177a9c 100644 --- a/docs/reference/mapping/fields/synthetic-source.asciidoc +++ b/docs/reference/mapping/fields/synthetic-source.asciidoc @@ -64,6 +64,7 @@ types: ** <> ** <> ** <> +** <> ** <> ** <> diff --git a/docs/reference/mapping/types/token-count.asciidoc b/docs/reference/mapping/types/token-count.asciidoc index 23bbc775243af..7d9dffcc82082 100644 --- a/docs/reference/mapping/types/token-count.asciidoc +++ b/docs/reference/mapping/types/token-count.asciidoc @@ -64,10 +64,10 @@ The following parameters are accepted by `token_count` fields: value. Required. For best performance, use an analyzer without token filters. -`enable_position_increments`:: +`enable_position_increments`:: -Indicates if position increments should be counted. -Set to `false` if you don't want to count tokens removed by analyzer filters (like <>). +Indicates if position increments should be counted. +Set to `false` if you don't want to count tokens removed by analyzer filters (like <>). Defaults to `true`. <>:: @@ -91,3 +91,17 @@ Defaults to `true`. Whether the field value should be stored and retrievable separately from the <> field. Accepts `true` or `false` (default). + +[[token-count-synthetic-source]] +===== Synthetic `_source` + +IMPORTANT: Synthetic `_source` is Generally Available only for TSDB indices +(indices that have `index.mode` set to `time_series`). For other indices +synthetic `_source` is in technical preview. Features in technical preview may +be changed or removed in a future release. Elastic will work to fix +any issues, but features in technical preview are not subject to the support SLA +of official GA features. + +`token_count` fields support <> in their +default configuration. 
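As a quick illustration of that default path (index and field names are made up; this mirrors the YAML test added below):

    "mappings": {
      "_source": { "mode": "synthetic" },
      "properties": {
        "title_length": { "type": "token_count", "analyzer": "standard" }
      }
    }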
Synthetic `_source` cannot be used together with +<>. diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/extras/TokenCountFieldMapper.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/extras/TokenCountFieldMapper.java index 831306a8e8594..c538c7641a015 100644 --- a/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/extras/TokenCountFieldMapper.java +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/extras/TokenCountFieldMapper.java @@ -215,4 +215,9 @@ protected String contentType() { public FieldMapper.Builder getMergeBuilder() { return new Builder(simpleName()).init(this); } + + @Override + protected SyntheticSourceMode syntheticSourceMode() { + return SyntheticSourceMode.FALLBACK; + } } diff --git a/modules/mapper-extras/src/test/java/org/elasticsearch/index/mapper/extras/TokenCountFieldMapperTests.java b/modules/mapper-extras/src/test/java/org/elasticsearch/index/mapper/extras/TokenCountFieldMapperTests.java index 1636def53536b..d34d9c3178c78 100644 --- a/modules/mapper-extras/src/test/java/org/elasticsearch/index/mapper/extras/TokenCountFieldMapperTests.java +++ b/modules/mapper-extras/src/test/java/org/elasticsearch/index/mapper/extras/TokenCountFieldMapperTests.java @@ -33,7 +33,11 @@ import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.function.Function; +import java.util.stream.Collectors; import static org.hamcrest.Matchers.equalTo; @@ -196,7 +200,66 @@ protected boolean supportsIgnoreMalformed() { @Override protected SyntheticSourceSupport syntheticSourceSupport(boolean ignoreMalformed) { - throw new AssumptionViolatedException("not supported"); + assertFalse(ignoreMalformed); + + var nullValue = usually() ? null : randomNonNegativeInt(); + return new SyntheticSourceSupport() { + @Override + public boolean preservesExactSource() { + return true; + } + + public SyntheticSourceExample example(int maxValues) { + if (randomBoolean()) { + var value = generateValue(); + return new SyntheticSourceExample(value.text, value.text, value.tokenCount, this::mapping); + } + + var values = randomList(1, 5, this::generateValue); + + var textArray = values.stream().map(Value::text).toList(); + + var blockExpectedList = values.stream().map(Value::tokenCount).filter(Objects::nonNull).toList(); + var blockExpected = blockExpectedList.size() == 1 ? blockExpectedList.get(0) : blockExpectedList; + + return new SyntheticSourceExample(textArray, textArray, blockExpected, this::mapping); + } + + private record Value(String text, Integer tokenCount) {} + + private Value generateValue() { + if (rarely()) { + return new Value(null, null); + } + + var text = randomList(0, 10, () -> randomAlphaOfLengthBetween(0, 10)).stream().collect(Collectors.joining(" ")); + // with keyword analyzer token count is always 1 + return new Value(text, 1); + } + + private void mapping(XContentBuilder b) throws IOException { + b.field("type", "token_count").field("analyzer", "keyword"); + if (rarely()) { + b.field("index", false); + } + if (rarely()) { + b.field("store", true); + } + if (nullValue != null) { + b.field("null_value", nullValue); + } + } + + @Override + public List invalidExample() throws IOException { + return List.of(); + } + }; + } + + protected Function loadBlockExpected() { + // we can get either a number from doc values or null + return v -> v != null ? 
(Number) v : null; } @Override diff --git a/modules/mapper-extras/src/yamlRestTest/resources/rest-api-spec/test/token_count/10_basic.yml b/modules/mapper-extras/src/yamlRestTest/resources/rest-api-spec/test/token_count/10_basic.yml new file mode 100644 index 0000000000000..03b72a2623497 --- /dev/null +++ b/modules/mapper-extras/src/yamlRestTest/resources/rest-api-spec/test/token_count/10_basic.yml @@ -0,0 +1,65 @@ +"Test token count": + - requires: + cluster_features: ["gte_v7.10.0"] + reason: "support for token_count was instroduced in 7.10" + - do: + indices.create: + index: test + body: + mappings: + properties: + count: + type: token_count + analyzer: standard + count_without_dv: + type: token_count + analyzer: standard + doc_values: false + + - do: + index: + index: test + id: "1" + refresh: true + body: + count: "some text" + - do: + search: + index: test + body: + fields: [count, count_without_dv] + + - is_true: hits.hits.0._id + - match: { hits.hits.0.fields.count: [2] } + - is_false: hits.hits.0.fields.count_without_dv + +--- +"Synthetic source": + - requires: + cluster_features: ["mapper.track_ignored_source"] + reason: requires tracking ignored source + - do: + indices.create: + index: test + body: + mappings: + _source: + mode: synthetic + properties: + count: + type: token_count + analyzer: standard + + - do: + index: + index: test + id: "1" + refresh: true + body: + count: "quick brown fox jumps over a lazy dog" + - do: + get: + index: test + id: "1" + + - match: { _source.count: "quick brown fox jumps over a lazy dog" } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/330_fetch_fields.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/330_fetch_fields.yml index 52b55098ec4db..703f2a0352fbd 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/330_fetch_fields.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/330_fetch_fields.yml @@ -262,41 +262,6 @@ - match: { hits.hits.0.fields.date.0: "1990/12/29" } --- -"Test token count": - - requires: - cluster_features: ["gte_v7.10.0"] - reason: "support for token_count was instroduced in 7.10" - - do: - indices.create: - index: test - body: - mappings: - properties: - count: - type: token_count - analyzer: standard - count_without_dv: - type: token_count - analyzer: standard - doc_values: false - - - do: - index: - index: test - id: "1" - refresh: true - body: - count: "some text" - - do: - search: - index: test - body: - fields: [count, count_without_dv] - - - is_true: hits.hits.0._id - - match: { hits.hits.0.fields.count: [2] } - - is_false: hits.hits.0.fields.count_without_dv ---- Test unmapped field: - requires: cluster_features: "gte_v7.11.0" From 19b1218882161498742fb73bddbedafe8e96b2e1 Mon Sep 17 00:00:00 2001 From: Tim Grein Date: Tue, 28 May 2024 08:18:41 +0200 Subject: [PATCH 20/29] [Inference API] Add Google AI Studio completion (#109029) --- .../org/elasticsearch/TransportVersions.java | 1 + .../InferenceNamedWriteablesProvider.java | 13 + .../xpack/inference/InferencePlugin.java | 2 + .../GoogleAiStudioActionCreator.java | 34 + .../GoogleAiStudioActionVisitor.java | 19 + .../GoogleAiStudioCompletionAction.java | 73 ++ .../GoogleAiStudioResponseHandler.java | 75 +++ .../http/retry/BaseResponseHandler.java | 1 + ...oogleAiStudioCompletionRequestManager.java | 56 ++ .../sender/GoogleAiStudioRequestManager.java | 27 + .../GoogleAiStudioCompletionRequest.java | 72 ++ ...GoogleAiStudioCompletionRequestEntity.java | 79 +++ 
.../googleaistudio/GoogleAiStudioRequest.java | 38 ++ .../googleaistudio/GoogleAiStudioUtils.java | 22 + ...oogleAiStudioCompletionResponseEntity.java | 109 +++ .../GoogleAiStudioErrorResponseEntity.java | 78 +++ .../googleaistudio/GoogleAiStudioModel.java | 39 ++ ...oogleAiStudioRateLimitServiceSettings.java | 18 + .../GoogleAiStudioSecretSettings.java | 106 +++ .../googleaistudio/GoogleAiStudioService.java | 218 ++++++ .../GoogleAiStudioCompletionModel.java | 124 ++++ ...ogleAiStudioCompletionServiceSettings.java | 126 ++++ .../xpack/inference/MatchersUtils.java | 83 +++ .../xpack/inference/MatchersUtilsTests.java | 186 ++++++ .../GoogleAiStudioCompletionActionTests.java | 274 ++++++++ .../GoogleAiStudioResponseHandlerTests.java | 133 ++++ .../GoogleAiStudioRequestTests.java | 55 ++ ...eAiStudioCompletionRequestEntityTests.java | 49 ++ .../GoogleAiStudioCompletionRequestTests.java | 73 ++ ...AiStudioCompletionResponseEntityTests.java | 189 ++++++ ...oogleAiStudioErrorResponseEntityTests.java | 68 ++ .../GoogleAiStudioSecretSettingsTests.java | 71 ++ .../GoogleAiStudioServiceTests.java | 630 ++++++++++++++++++ .../GoogleAiStudioCompletionModelTests.java | 66 ++ ...iStudioCompletionServiceSettingsTests.java | 76 +++ 35 files changed, 3283 insertions(+) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioActionCreator.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioActionVisitor.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionAction.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/googleaistudio/GoogleAiStudioResponseHandler.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioRequestManager.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioCompletionRequest.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioCompletionRequestEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioRequest.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioUtils.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/googleaistudio/GoogleAiStudioCompletionResponseEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/googleaistudio/GoogleAiStudioErrorResponseEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioModel.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioRateLimitServiceSettings.java create mode 100644 
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioSecretSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModel.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionServiceSettings.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/MatchersUtils.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/MatchersUtilsTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionActionTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/googleaistudio/GoogleAiStudioResponseHandlerTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioRequestTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/completion/GoogleAiStudioCompletionRequestEntityTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/completion/GoogleAiStudioCompletionRequestTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googleaistudio/GoogleAiStudioCompletionResponseEntityTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googleaistudio/GoogleAiStudioErrorResponseEntityTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioSecretSettingsTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModelTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionServiceSettingsTests.java diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index d372f4ee023bd..22460775300f3 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -178,6 +178,7 @@ static TransportVersion def(int id) { public static final TransportVersion GET_SHUTDOWN_STATUS_TIMEOUT = def(8_669_00_0); public static final TransportVersion FAILURE_STORE_TELEMETRY = def(8_670_00_0); public static final TransportVersion ADD_METADATA_FLATTENED_TO_ROLES = def(8_671_00_0); + public static final TransportVersion ML_INFERENCE_GOOGLE_AI_STUDIO_COMPLETION_ADDED = def(8_672_00_0); /* * STOP! READ THIS FIRST! 
No, really, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 4931b4da6f724..edea0104ded16 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -44,6 +44,7 @@ import org.elasticsearch.xpack.inference.services.elasticsearch.MultilingualE5SmallInternalServiceSettings; import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings; +import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserSecretSettings; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings; @@ -106,6 +107,7 @@ public static List getNamedWriteables() { addCohereNamedWriteables(namedWriteables); addAzureOpenAiNamedWriteables(namedWriteables); addAzureAiStudioNamedWriteables(namedWriteables); + addGoogleAiStudioNamedWritables(namedWriteables); return namedWriteables; } @@ -254,6 +256,16 @@ private static void addHuggingFaceNamedWriteables(List namedWriteables) { + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + GoogleAiStudioCompletionServiceSettings.NAME, + GoogleAiStudioCompletionServiceSettings::new + ) + ); + } + private static void addInternalElserNamedWriteables(List namedWriteables) { namedWriteables.add( new NamedWriteableRegistry.Entry(ServiceSettings.class, ElserInternalServiceSettings.NAME, ElserInternalServiceSettings::new) @@ -318,4 +330,5 @@ private static void addInferenceResultsNamedWriteables(List getInferenceServiceFactories() { context -> new CohereService(httpFactory.get(), serviceComponents.get()), context -> new AzureOpenAiService(httpFactory.get(), serviceComponents.get()), context -> new AzureAiStudioService(httpFactory.get(), serviceComponents.get()), + context -> new GoogleAiStudioService(httpFactory.get(), serviceComponents.get()), ElasticsearchInternalService::new ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioActionCreator.java new file mode 100644 index 0000000000000..51a8cc7a0bd56 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioActionCreator.java @@ -0,0 +1,34 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.action.googleaistudio; + +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModel; + +import java.util.Map; +import java.util.Objects; + +public class GoogleAiStudioActionCreator implements GoogleAiStudioActionVisitor { + + private final Sender sender; + + private final ServiceComponents serviceComponents; + + public GoogleAiStudioActionCreator(Sender sender, ServiceComponents serviceComponents) { + this.sender = Objects.requireNonNull(sender); + this.serviceComponents = Objects.requireNonNull(serviceComponents); + } + + @Override + public ExecutableAction create(GoogleAiStudioCompletionModel model, Map taskSettings) { + // no overridden model as task settings are always empty for Google AI Studio completion model + return new GoogleAiStudioCompletionAction(sender, model, serviceComponents.threadPool()); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioActionVisitor.java new file mode 100644 index 0000000000000..090d3f9a69710 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioActionVisitor.java @@ -0,0 +1,19 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.googleaistudio; + +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModel; + +import java.util.Map; + +public interface GoogleAiStudioActionVisitor { + + ExecutableAction create(GoogleAiStudioCompletionModel model, Map taskSettings); + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionAction.java new file mode 100644 index 0000000000000..7f918ae9a7db7 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionAction.java @@ -0,0 +1,73 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.action.googleaistudio; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.GoogleAiStudioCompletionRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModel; + +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException; + +public class GoogleAiStudioCompletionAction implements ExecutableAction { + + private final String failedToSendRequestErrorMessage; + + private final GoogleAiStudioCompletionRequestManager requestManager; + + private final Sender sender; + + public GoogleAiStudioCompletionAction(Sender sender, GoogleAiStudioCompletionModel model, ThreadPool threadPool) { + Objects.requireNonNull(threadPool); + Objects.requireNonNull(model); + this.sender = Objects.requireNonNull(sender); + this.requestManager = new GoogleAiStudioCompletionRequestManager(model, threadPool); + this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(model.uri(), "Google AI Studio completion"); + } + + @Override + public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener listener) { + if (inferenceInputs instanceof DocumentsOnlyInput == false) { + listener.onFailure(new ElasticsearchStatusException("Invalid inference input type", RestStatus.INTERNAL_SERVER_ERROR)); + return; + } + + var docsOnlyInput = (DocumentsOnlyInput) inferenceInputs; + if (docsOnlyInput.getInputs().size() > 1) { + listener.onFailure( + new ElasticsearchStatusException("Google AI Studio completion only accepts 1 input", RestStatus.BAD_REQUEST) + ); + return; + } + + try { + ActionListener wrappedListener = wrapFailuresInElasticsearchException( + failedToSendRequestErrorMessage, + listener + ); + sender.send(requestManager, inferenceInputs, timeout, wrappedListener); + } catch (ElasticsearchException e) { + listener.onFailure(e); + } catch (Exception e) { + listener.onFailure(createInternalServerError(e, failedToSendRequestErrorMessage)); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/googleaistudio/GoogleAiStudioResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/googleaistudio/GoogleAiStudioResponseHandler.java new file mode 100644 index 0000000000000..1138cfcb7cdc6 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/googleaistudio/GoogleAiStudioResponseHandler.java @@ -0,0 +1,75 @@ +/* + * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.googleaistudio; + +import org.apache.logging.log4j.Logger; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.response.googleaistudio.GoogleAiStudioErrorResponseEntity; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; + +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody; + +public class GoogleAiStudioResponseHandler extends BaseResponseHandler { + + static final String GOOGLE_AI_STUDIO_UNAVAILABLE = "The Google AI Studio service may be temporarily overloaded or down"; + + public GoogleAiStudioResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, GoogleAiStudioErrorResponseEntity::fromResponse); + } + + @Override + public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result) + throws RetryException { + checkForFailureStatusCode(request, result); + checkForEmptyBody(throttlerManager, logger, request, result); + } + + /** + * Validates the status code and throws a RetryException if not in the range [200, 300). + * + * The Google AI Studio error codes are documented here. 
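+     * A summary of the mapping implemented below: only status codes 500, 503 and 429
+     * are treated as retryable; every other non-2xx status (including the remaining
+     * 5xx codes) fails the request without retry.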
+ * @param request The originating request + * @param result The http response and body + * @throws RetryException Throws if status code is {@code >= 300 or < 200 } + */ + void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException { + int statusCode = result.response().getStatusLine().getStatusCode(); + if (statusCode >= 200 && statusCode < 300) { + return; + } + + // handle error codes + if (statusCode == 500) { + throw new RetryException(true, buildError(SERVER_ERROR, request, result)); + } else if (statusCode == 503) { + throw new RetryException(true, buildError(GOOGLE_AI_STUDIO_UNAVAILABLE, request, result)); + } else if (statusCode > 500) { + throw new RetryException(false, buildError(SERVER_ERROR, request, result)); + } else if (statusCode == 429) { + throw new RetryException(true, buildError(RATE_LIMIT, request, result)); + } else if (statusCode == 404) { + throw new RetryException(false, buildError(resourceNotFoundError(request), request, result)); + } else if (statusCode == 403) { + throw new RetryException(false, buildError(PERMISSION_DENIED, request, result)); + } else if (statusCode >= 300 && statusCode < 400) { + throw new RetryException(false, buildError(REDIRECTION, request, result)); + } else { + throw new RetryException(false, buildError(UNSUCCESSFUL, request, result)); + } + } + + private static String resourceNotFoundError(Request request) { + return format("Resource not found at [%s]", request.getURI()); + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java index b703cf2f14b75..f793cb3586924 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java @@ -23,6 +23,7 @@ public abstract class BaseResponseHandler implements ResponseHandler { public static final String SERVER_ERROR = "Received a server error status code"; public static final String RATE_LIMIT = "Received a rate limit status code"; public static final String AUTHENTICATION = "Received an authentication error status code"; + public static final String PERMISSION_DENIED = "Received a permission denied error status code"; public static final String REDIRECTION = "Unhandled redirection"; public static final String CONTENT_TOO_LARGE = "Received a content too large status code"; public static final String UNSUCCESSFUL = "Received an unsuccessful status code"; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java new file mode 100644 index 0000000000000..eb9baa680446a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java @@ -0,0 +1,56 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.http.client.protocol.HttpClientContext; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.googleaistudio.GoogleAiStudioResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.request.googleaistudio.GoogleAiStudioCompletionRequest; +import org.elasticsearch.xpack.inference.external.response.googleaistudio.GoogleAiStudioCompletionResponseEntity; +import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModel; + +import java.util.List; +import java.util.Objects; +import java.util.function.Supplier; + +public class GoogleAiStudioCompletionRequestManager extends GoogleAiStudioRequestManager { + + private static final Logger logger = LogManager.getLogger(GoogleAiStudioCompletionRequestManager.class); + + private static final ResponseHandler HANDLER = createCompletionHandler(); + + private final GoogleAiStudioCompletionModel model; + + private static ResponseHandler createCompletionHandler() { + return new GoogleAiStudioResponseHandler("google ai studio completion", GoogleAiStudioCompletionResponseEntity::fromResponse); + } + + public GoogleAiStudioCompletionRequestManager(GoogleAiStudioCompletionModel model, ThreadPool threadPool) { + super(threadPool, model); + this.model = Objects.requireNonNull(model); + } + + @Override + public Runnable create( + String query, + List input, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + HttpClientContext context, + ActionListener listener + ) { + GoogleAiStudioCompletionRequest request = new GoogleAiStudioCompletionRequest(input, model); + return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioRequestManager.java new file mode 100644 index 0000000000000..670c00f9a2808 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioRequestManager.java @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioModel; + +import java.util.Objects; + +public abstract class GoogleAiStudioRequestManager extends BaseRequestManager { + GoogleAiStudioRequestManager(ThreadPool threadPool, GoogleAiStudioModel model) { + super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings()); + } + + record RateLimitGrouping(int modelIdHash) { + public static RateLimitGrouping of(GoogleAiStudioModel model) { + Objects.requireNonNull(model); + + return new RateLimitGrouping(model.rateLimitServiceSettings().modelId().hashCode()); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioCompletionRequest.java new file mode 100644 index 0000000000000..f52fe623e7918 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioCompletionRequest.java @@ -0,0 +1,72 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.googleaistudio; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModel; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Objects; + +public class GoogleAiStudioCompletionRequest implements GoogleAiStudioRequest { + + private final List input; + + private final URI uri; + + private final GoogleAiStudioCompletionModel model; + + public GoogleAiStudioCompletionRequest(List input, GoogleAiStudioCompletionModel model) { + this.input = input; + this.model = Objects.requireNonNull(model); + this.uri = model.uri(); + } + + @Override + public HttpRequest createHttpRequest() { + var httpPost = new HttpPost(uri); + var requestEntity = Strings.toString(new GoogleAiStudioCompletionRequestEntity(input)); + + ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8)); + httpPost.setEntity(byteEntity); + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); + GoogleAiStudioRequest.decorateWithApiKeyParameter(httpPost, model.getSecretSettings()); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public URI getURI() { + return this.uri; + } + + @Override + public Request truncate() { + // No truncation for Google AI Studio completion + return this; + } + + @Override + public boolean[] getTruncationInfo() { + // No truncation for Google AI Studio completion + return null; + } + + @Override + public String getInferenceEntityId() { + 
return model.getInferenceEntityId(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioCompletionRequestEntity.java new file mode 100644 index 0000000000000..85e4d616c16e5 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioCompletionRequestEntity.java @@ -0,0 +1,79 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.googleaistudio; + +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public record GoogleAiStudioCompletionRequestEntity(List input) implements ToXContentObject { + + private static final String CONTENTS_FIELD = "contents"; + + private static final String PARTS_FIELD = "parts"; + + private static final String TEXT_FIELD = "text"; + + private static final String GENERATION_CONFIG_FIELD = "generationConfig"; + + private static final String CANDIDATE_COUNT_FIELD = "candidateCount"; + + private static final String ROLE_FIELD = "role"; + + private static final String ROLE_USER = "user"; + + public GoogleAiStudioCompletionRequestEntity { + Objects.requireNonNull(input); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.startArray(CONTENTS_FIELD); + + { + for (String content : input) { + builder.startObject(); + + { + builder.startArray(PARTS_FIELD); + builder.startObject(); + + { + builder.field(TEXT_FIELD, content); + } + + builder.endObject(); + builder.endArray(); + } + + builder.field(ROLE_FIELD, ROLE_USER); + + builder.endObject(); + } + } + + builder.endArray(); + + builder.startObject(GENERATION_CONFIG_FIELD); + + { + // default is already 1, but we want to guard ourselves against API changes so setting it explicitly + builder.field(CANDIDATE_COUNT_FIELD, 1); + } + + builder.endObject(); + + builder.endObject(); + + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioRequest.java new file mode 100644 index 0000000000000..ede9c6193aa21 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioRequest.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.request.googleaistudio; + +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.utils.URIBuilder; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioSecretSettings; + +public interface GoogleAiStudioRequest extends Request { + + String API_KEY_PARAMETER = "key"; + + static void decorateWithApiKeyParameter(HttpPost httpPost, GoogleAiStudioSecretSettings secretSettings) { + try { + var uri = httpPost.getURI(); + var uriWithApiKey = new URIBuilder().setScheme(uri.getScheme()) + .setHost(uri.getHost()) + .setPort(uri.getPort()) + .setPath(uri.getPath()) + .addParameter(API_KEY_PARAMETER, secretSettings.apiKey().toString()) + .build(); + + httpPost.setURI(uriWithApiKey); + } catch (Exception e) { + ValidationException validationException = new ValidationException(e); + validationException.addValidationError(e.getMessage()); + throw validationException; + } + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioUtils.java new file mode 100644 index 0000000000000..d63a0bbe2af91 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioUtils.java @@ -0,0 +1,22 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.googleaistudio; + +public class GoogleAiStudioUtils { + + public static final String HOST_SUFFIX = "generativelanguage.googleapis.com"; + + public static final String V1 = "v1"; + + public static final String MODELS = "models"; + + public static final String GENERATE_CONTENT_ACTION = "generateContent"; + + private GoogleAiStudioUtils() {} + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/googleaistudio/GoogleAiStudioCompletionResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/googleaistudio/GoogleAiStudioCompletionResponseEntity.java new file mode 100644 index 0000000000000..852f25705d6ff --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/googleaistudio/GoogleAiStudioCompletionResponseEntity.java @@ -0,0 +1,109 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.response.googleaistudio; + +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; + +public class GoogleAiStudioCompletionResponseEntity { + + private static final String FAILED_TO_FIND_FIELD_TEMPLATE = + "Failed to find required field [%s] in Google AI Studio completion response"; + + /** + * Parses the Google AI Studio completion response. + * + * For a request like: + * + *

+     * <pre>
+     *     <code>
+     *         {
+     *           "contents": [
+     *                          {
+     *                              "parts": [{
+     *                                  "text": "input"
+     *                              }]
+     *                          }
+     *                      ]
+     *          }
+     *     </code>
+     * </pre>
+     *
+     * The response would look like:
+     *
+     * <pre>
+     *     <code>
+     *         {
+     *     "candidates": [
+     *         {
+     *             "content": {
+     *                 "parts": [
+     *                     {
+     *                         "text": "response"
+     *                     }
+     *                 ],
+     *                 "role": "model"
+     *             },
+     *             "finishReason": "STOP",
+     *             "index": 0,
+     *             "safetyRatings": [...]
+     *         }
+     *     ],
+     *     "usageMetadata": { ... }
+     * }
+     *     </code>
+     * </pre>
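+     *
+     * Note: only the {@code text} of the first part of the first candidate is
+     * extracted; the request side pins {@code candidateCount} to 1, so a single
+     * candidate is expected here.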
+ * + */ + + public static ChatCompletionResults fromResponse(Request request, HttpResult response) throws IOException { + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { + moveToFirstToken(jsonParser); + + XContentParser.Token token = jsonParser.currentToken(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); + + positionParserAtTokenAfterField(jsonParser, "candidates", FAILED_TO_FIND_FIELD_TEMPLATE); + + jsonParser.nextToken(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, jsonParser.currentToken(), jsonParser); + + positionParserAtTokenAfterField(jsonParser, "content", FAILED_TO_FIND_FIELD_TEMPLATE); + + token = jsonParser.currentToken(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); + + positionParserAtTokenAfterField(jsonParser, "parts", FAILED_TO_FIND_FIELD_TEMPLATE); + + jsonParser.nextToken(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); + + positionParserAtTokenAfterField(jsonParser, "text", FAILED_TO_FIND_FIELD_TEMPLATE); + + XContentParser.Token contentToken = jsonParser.currentToken(); + ensureExpectedToken(XContentParser.Token.VALUE_STRING, contentToken, jsonParser); + String content = jsonParser.text(); + + return new ChatCompletionResults(List.of(new ChatCompletionResults.Result(content))); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/googleaistudio/GoogleAiStudioErrorResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/googleaistudio/GoogleAiStudioErrorResponseEntity.java new file mode 100644 index 0000000000000..f57f672e10b16 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/googleaistudio/GoogleAiStudioErrorResponseEntity.java @@ -0,0 +1,78 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.googleaistudio; + +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorMessage; + +import java.util.Map; + +public class GoogleAiStudioErrorResponseEntity implements ErrorMessage { + + private final String errorMessage; + + private GoogleAiStudioErrorResponseEntity(String errorMessage) { + this.errorMessage = errorMessage; + } + + @Override + public String getErrorMessage() { + return errorMessage; + } + + /** + * An example error response for invalid auth would look like + * + * { + * "error": { + * "code": 400, + * "message": "API key not valid. 
Please pass a valid API key.", + * "status": "INVALID_ARGUMENT", + * "details": [ + * { + * "@type": "type.googleapis.com/google.rpc.ErrorInfo", + * "reason": "API_KEY_INVALID", + * "domain": "googleapis.com", + * "metadata": { + * "service": "generativelanguage.googleapis.com" + * } + * } + * ] + * } + * } + * + * @param response The error response + * @return An error entity if the response is JSON with the above structure + * or null if the response does not contain the `error.message` field + */ + + @SuppressWarnings("unchecked") + public static GoogleAiStudioErrorResponseEntity fromResponse(HttpResult response) { + try ( + XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON) + .createParser(XContentParserConfiguration.EMPTY, response.body()) + ) { + var responseMap = jsonParser.map(); + var error = (Map) responseMap.get("error"); + if (error != null) { + var message = (String) error.get("message"); + if (message != null) { + return new GoogleAiStudioErrorResponseEntity(message); + } + } + } catch (Exception e) { + // swallow the error + } + + return null; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioModel.java new file mode 100644 index 0000000000000..4ddffd0bae615 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioModel.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.googleaistudio; + +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.googleaistudio.GoogleAiStudioActionVisitor; + +import java.util.Map; +import java.util.Objects; + +public abstract class GoogleAiStudioModel extends Model { + + private final GoogleAiStudioRateLimitServiceSettings rateLimitServiceSettings; + + public GoogleAiStudioModel( + ModelConfigurations configurations, + ModelSecrets secrets, + GoogleAiStudioRateLimitServiceSettings rateLimitServiceSettings + ) { + super(configurations, secrets); + + this.rateLimitServiceSettings = Objects.requireNonNull(rateLimitServiceSettings); + } + + public abstract ExecutableAction accept(GoogleAiStudioActionVisitor creator, Map taskSettings, InputType inputType); + + public GoogleAiStudioRateLimitServiceSettings rateLimitServiceSettings() { + return rateLimitServiceSettings; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioRateLimitServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioRateLimitServiceSettings.java new file mode 100644 index 0000000000000..2e443263c7f54 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioRateLimitServiceSettings.java @@ -0,0 +1,18 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googleaistudio; + +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +public interface GoogleAiStudioRateLimitServiceSettings { + + String modelId(); + + RateLimitSettings rateLimitSettings(); + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioSecretSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioSecretSettings.java new file mode 100644 index 0000000000000..bf702d010e2a8 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioSecretSettings.java @@ -0,0 +1,106 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.googleaistudio; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SecretSettings; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalSecureString; + +public class GoogleAiStudioSecretSettings implements SecretSettings { + + public static final String NAME = "google_ai_studio_secret_settings"; + public static final String API_KEY = "api_key"; + + private final SecureString apiKey; + + public static GoogleAiStudioSecretSettings fromMap(@Nullable Map map) { + if (map == null) { + return null; + } + + ValidationException validationException = new ValidationException(); + SecureString secureApiKey = extractOptionalSecureString(map, API_KEY, ModelSecrets.SECRET_SETTINGS, validationException); + + if (secureApiKey == null) { + validationException.addValidationError(format("[secret_settings] must have [%s] set", API_KEY)); + } + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new GoogleAiStudioSecretSettings(secureApiKey); + } + + public GoogleAiStudioSecretSettings(SecureString apiKey) { + Objects.requireNonNull(apiKey); + this.apiKey = apiKey; + } + + public GoogleAiStudioSecretSettings(StreamInput in) throws IOException { + this(in.readOptionalSecureString()); + } + + public SecureString apiKey() { + return apiKey; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + if (apiKey != null) { + builder.field(API_KEY, apiKey.toString()); + } + + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_GOOGLE_AI_STUDIO_COMPLETION_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalSecureString(apiKey); + } + + @Override + public boolean equals(Object object) { + if (this == object) return true; + if (object == null || getClass() != object.getClass()) return false; + GoogleAiStudioSecretSettings that = (GoogleAiStudioSecretSettings) object; + return Objects.equals(apiKey, that.apiKey); + } + + @Override + public int hashCode() { + return Objects.hash(apiKey); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java new file mode 100644 index 0000000000000..f990923cee922 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -0,0 +1,218 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googleaistudio; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.external.action.googleaistudio.GoogleAiStudioActionCreator; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModel; + +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; + +public class GoogleAiStudioService extends SenderService { + + public static final String NAME = "googleaistudio"; + + public GoogleAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { + super(factory, serviceComponents); + } + + @Override + public String name() { + return NAME; + } + + @Override + public void parseRequestConfig( + String inferenceEntityId, + TaskType taskType, + Map config, + Set platfromArchitectures, + ActionListener parsedModelListener + ) { + try { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + GoogleAiStudioModel model = createModel( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + serviceSettingsMap, + TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), + ConfigurationParseContext.REQUEST + ); + + throwIfNotEmptyMap(config, NAME); + throwIfNotEmptyMap(serviceSettingsMap, NAME); + throwIfNotEmptyMap(taskSettingsMap, NAME); + + parsedModelListener.onResponse(model); + } catch (Exception e) { + parsedModelListener.onFailure(e); + } + + } + + private static GoogleAiStudioModel createModel( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + @Nullable Map secretSettings, + String failureMessage, + ConfigurationParseContext context + ) { + return 
switch (taskType) { + case COMPLETION -> new GoogleAiStudioCompletionModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + taskSettings, + secretSettings + ); + default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + }; + } + + @Override + public GoogleAiStudioModel parsePersistedConfigWithSecrets( + String inferenceEntityId, + TaskType taskType, + Map config, + Map secrets + ) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); + Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); + + return createModelFromPersistent( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + secretSettingsMap, + parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + ); + } + + private static GoogleAiStudioModel createModelFromPersistent( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + Map secretSettings, + String failureMessage + ) { + return createModel( + inferenceEntityId, + taskType, + serviceSettings, + taskSettings, + secretSettings, + failureMessage, + ConfigurationParseContext.PERSISTENT + ); + } + + @Override + public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); + + return createModelFromPersistent( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + null, + parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + ); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_GOOGLE_AI_STUDIO_COMPLETION_ADDED; + } + + @Override + protected void doInfer( + Model model, + List input, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener listener + ) { + if (model instanceof GoogleAiStudioModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + GoogleAiStudioModel googleAiStudioModel = (GoogleAiStudioModel) model; + var actionCreator = new GoogleAiStudioActionCreator(getSender(), getServiceComponents()); + + var action = googleAiStudioModel.accept(actionCreator, taskSettings, inputType); + action.execute(new DocumentsOnlyInput(input), timeout, listener); + } + + @Override + protected void doInfer( + Model model, + String query, + List input, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener listener + ) { + throw new UnsupportedOperationException("Query input not supported for Google AI Studio"); + } + + @Override + protected void doChunkedInfer( + Model model, + String query, + List input, + Map taskSettings, + InputType inputType, + ChunkingOptions chunkingOptions, + TimeValue timeout, + ActionListener> listener + ) { + throw new UnsupportedOperationException("Chunked inference not supported yet for Google AI Studio"); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModel.java new file mode 100644 index 0000000000000..6a11f678158b6 --- /dev/null +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModel.java @@ -0,0 +1,124 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googleaistudio.completion; + +import org.apache.http.client.utils.URIBuilder; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.googleaistudio.GoogleAiStudioActionVisitor; +import org.elasticsearch.xpack.inference.external.request.googleaistudio.GoogleAiStudioUtils; +import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioModel; +import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioSecretSettings; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Map; + +import static org.elasticsearch.core.Strings.format; + +public class GoogleAiStudioCompletionModel extends GoogleAiStudioModel { + + private URI uri; + + public GoogleAiStudioCompletionModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettings, + Map taskSettings, + Map secrets + ) { + this( + inferenceEntityId, + taskType, + service, + GoogleAiStudioCompletionServiceSettings.fromMap(serviceSettings), + EmptyTaskSettings.INSTANCE, + GoogleAiStudioSecretSettings.fromMap(secrets) + ); + } + + // Should only be used directly for testing + GoogleAiStudioCompletionModel( + String inferenceEntityId, + TaskType taskType, + String service, + GoogleAiStudioCompletionServiceSettings serviceSettings, + TaskSettings taskSettings, + @Nullable GoogleAiStudioSecretSettings secrets + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), + new ModelSecrets(secrets), + serviceSettings + ); + try { + this.uri = buildUri(serviceSettings.modelId()); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + // Should only be used directly for testing + GoogleAiStudioCompletionModel( + String inferenceEntityId, + TaskType taskType, + String service, + String url, + GoogleAiStudioCompletionServiceSettings serviceSettings, + TaskSettings taskSettings, + @Nullable GoogleAiStudioSecretSettings secrets + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), + new ModelSecrets(secrets), + serviceSettings + ); + try { + this.uri = new URI(url); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + public URI uri() { + return uri; + } + + @Override + public GoogleAiStudioCompletionServiceSettings getServiceSettings() { + return (GoogleAiStudioCompletionServiceSettings) super.getServiceSettings(); + } + + @Override + public GoogleAiStudioSecretSettings getSecretSettings() { + return (GoogleAiStudioSecretSettings) super.getSecretSettings(); + } + + public static URI buildUri(String model) throws URISyntaxException { 
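+        // Builds https://generativelanguage.googleapis.com/v1/models/{model}:generateContent.
+        // The API key is deliberately not part of this URI; it is appended later as the
+        // "key" query parameter by GoogleAiStudioRequest#decorateWithApiKeyParameter.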
+ return new URIBuilder().setScheme("https") + .setHost(GoogleAiStudioUtils.HOST_SUFFIX) + .setPathSegments( + GoogleAiStudioUtils.V1, + GoogleAiStudioUtils.MODELS, + format("%s:%s", model, GoogleAiStudioUtils.GENERATE_CONTENT_ACTION) + ) + .build(); + } + + @Override + public ExecutableAction accept(GoogleAiStudioActionVisitor visitor, Map taskSettings, InputType inputType) { + return visitor.create(this, taskSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionServiceSettings.java new file mode 100644 index 0000000000000..f8f343be8eb4c --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionServiceSettings.java @@ -0,0 +1,126 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googleaistudio.completion; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; + +public class GoogleAiStudioCompletionServiceSettings extends FilteredXContentObject + implements + ServiceSettings, + GoogleAiStudioRateLimitServiceSettings { + + public static final String NAME = "google_ai_studio_completion_service_settings"; + + /** + * Rate limits are defined at Google Gemini API Pricing. + * For pay-as-you-go you've 360 requests per minute. 
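+     * The default applied below (360 requests per minute) matches that documented
+     * pay-as-you-go limit.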
+ */ + private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(360); + + public static GoogleAiStudioCompletionServiceSettings fromMap(Map map) { + ValidationException validationException = new ValidationException(); + + String model = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new GoogleAiStudioCompletionServiceSettings(model, rateLimitSettings); + } + + private final String modelId; + + private final RateLimitSettings rateLimitSettings; + + public GoogleAiStudioCompletionServiceSettings(String modelId, @Nullable RateLimitSettings rateLimitSettings) { + this.modelId = modelId; + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + } + + public GoogleAiStudioCompletionServiceSettings(StreamInput in) throws IOException { + modelId = in.readString(); + rateLimitSettings = new RateLimitSettings(in); + } + + @Override + public String modelId() { + return modelId; + } + + @Override + public RateLimitSettings rateLimitSettings() { + return rateLimitSettings; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + toXContentFragmentOfExposedFields(builder, params); + rateLimitSettings.toXContent(builder, params); + + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_GOOGLE_AI_STUDIO_COMPLETION_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + rateLimitSettings.writeTo(out); + } + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + builder.field(MODEL_ID, modelId); + + return builder; + } + + @Override + public boolean equals(Object object) { + if (this == object) return true; + if (object == null || getClass() != object.getClass()) return false; + GoogleAiStudioCompletionServiceSettings that = (GoogleAiStudioCompletionServiceSettings) object; + return Objects.equals(modelId, that.modelId) && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(modelId, rateLimitSettings); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/MatchersUtils.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/MatchersUtils.java new file mode 100644 index 0000000000000..6397e83fc246e --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/MatchersUtils.java @@ -0,0 +1,83 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference; + +import org.hamcrest.Description; +import org.hamcrest.Matcher; +import org.hamcrest.TypeSafeMatcher; + +import java.util.regex.Pattern; + +/** + * Utility class containing custom hamcrest {@link Matcher} implementations or other utility functionality related to hamcrest. + */ +public class MatchersUtils { + + /** + * Custom matcher implementing a matcher operating on json strings ignoring whitespaces, which are not inside a key or a value. + * + * Example: + * { + * "key": "value" + * } + * + * will match + * + * {"key":"value"} + * + * as both json strings are equal ignoring the whitespace, which does not reside in a key or a value. + * + */ + protected static class IsEqualIgnoreWhitespaceInJsonString extends TypeSafeMatcher { + + protected static final Pattern WHITESPACE_IN_JSON_EXCEPT_KEYS_AND_VALUES_PATTERN = createPattern(); + + private static Pattern createPattern() { + String regex = "(?<=[:,\\[{])\\s+|\\s+(?=[\\]}:,])|^\\s+|\\s+$"; + return Pattern.compile(regex); + } + + private final String string; + + IsEqualIgnoreWhitespaceInJsonString(String string) { + if (string == null) { + throw new IllegalArgumentException("Non-null value required"); + } + this.string = string; + } + + @Override + protected boolean matchesSafely(String item) { + java.util.regex.Matcher itemMatcher = WHITESPACE_IN_JSON_EXCEPT_KEYS_AND_VALUES_PATTERN.matcher(item); + java.util.regex.Matcher stringMatcher = WHITESPACE_IN_JSON_EXCEPT_KEYS_AND_VALUES_PATTERN.matcher(string); + + String itemReplacedWhitespaces = itemMatcher.replaceAll(""); + String stringReplacedWhitespaces = stringMatcher.replaceAll(""); + + return itemReplacedWhitespaces.equals(stringReplacedWhitespaces); + } + + @Override + public void describeTo(Description description) { + java.util.regex.Matcher stringMatcher = WHITESPACE_IN_JSON_EXCEPT_KEYS_AND_VALUES_PATTERN.matcher(string); + String stringReplacedWhitespaces = stringMatcher.replaceAll(""); + + description.appendText("a string equal to (when all whitespaces are ignored expect in keys and values): ") + .appendValue(stringReplacedWhitespaces); + } + + public static Matcher equalToIgnoringWhitespaceInJsonString(String expectedString) { + return new IsEqualIgnoreWhitespaceInJsonString(expectedString); + } + } + + public static Matcher equalToIgnoringWhitespaceInJsonString(String expectedString) { + return IsEqualIgnoreWhitespaceInJsonString.equalToIgnoringWhitespaceInJsonString(expectedString); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/MatchersUtilsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/MatchersUtilsTests.java new file mode 100644 index 0000000000000..6f30d23a45ae5 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/MatchersUtilsTests.java @@ -0,0 +1,186 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */
+
+package org.elasticsearch.xpack.inference;
+
+import org.elasticsearch.test.ESTestCase;
+import org.hamcrest.Description;
+import org.hamcrest.SelfDescribing;
+
+import java.util.Objects;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+import static org.hamcrest.Matchers.is;
+
+public class MatchersUtilsTests extends ESTestCase {
+
+    public void testIsEqualIgnoreWhitespaceInJsonString_Pattern() {
+        var json = """
+
+            {
+              "key": "value"
+            }
+
+            """;
+
+        Pattern pattern = MatchersUtils.IsEqualIgnoreWhitespaceInJsonString.WHITESPACE_IN_JSON_EXCEPT_KEYS_AND_VALUES_PATTERN;
+        Matcher matcher = pattern.matcher(json);
+        String jsonWithRemovedWhitespaces = matcher.replaceAll("");
+
+        assertThat(jsonWithRemovedWhitespaces, is("""
+            {"key":"value"}"""));
+    }
+
+    public void testIsEqualIgnoreWhitespaceInJsonString_Pattern_DoesNotRemoveWhitespaceInKeysAndValues() {
+        var json = """
+
+            {
+              "key 1": "value 1"
+            }
+
+            """;
+
+        Pattern pattern = MatchersUtils.IsEqualIgnoreWhitespaceInJsonString.WHITESPACE_IN_JSON_EXCEPT_KEYS_AND_VALUES_PATTERN;
+        Matcher matcher = pattern.matcher(json);
+        String jsonWithRemovedWhitespaces = matcher.replaceAll("");
+
+        assertThat(jsonWithRemovedWhitespaces, is("""
+            {"key 1":"value 1"}"""));
+    }
+
+    public void testIsEqualIgnoreWhitespaceInJsonString_MatchesSafely_DoesMatch() {
+        var json = """
+
+            {
+              "key 1": "value 1",
+              "key 2": {
+                "key 3": "value 3"
+              },
+              "key 4": [
+                "value 4", "value 5"
+              ]
+            }
+
+            """;
+
+        var jsonWithDifferentSpacing = """
+            {"key 1": "value 1",
+              "key 2": {
+              "key 3": "value 3"
+            },
+              "key 4": [
+              "value 4", "value 5"
+            ]
+            }
+
+            """;
+
+        var typeSafeMatcher = new MatchersUtils.IsEqualIgnoreWhitespaceInJsonString(json);
+        boolean matches = typeSafeMatcher.matchesSafely(jsonWithDifferentSpacing);
+
+        assertTrue(matches);
+    }
+
+    public void testIsEqualIgnoreWhitespaceInJsonString_MatchesSafely_DoesNotMatch() {
+        var json = """
+
+            {
+              "key 1": "value 1",
+              "key 2": {
+                "key 3": "value 3"
+              },
+              "key 4": [
+                "value 4", "value 5"
+              ]
+            }
+
+            """;
+
+        // one value missing in array
+        var jsonWithDifferentSpacing = """
+            {"key 1": "value 1",
+              "key 2": {
+              "key 3": "value 3"
+            },
+              "key 4": [
+              "value 4"
+            ]
+            }
+
+            """;
+
+        var typeSafeMatcher = new MatchersUtils.IsEqualIgnoreWhitespaceInJsonString(json);
+        boolean matches = typeSafeMatcher.matchesSafely(jsonWithDifferentSpacing);
+
+        assertFalse(matches);
+    }
+
+    public void testIsEqualIgnoreWhitespaceInJsonString_DescribeTo() {
+        var jsonOne = """
+            {
+              "key": "value"
+            }
+            """;
+
+        var typeSafeMatcher = new MatchersUtils.IsEqualIgnoreWhitespaceInJsonString(jsonOne);
+        var description = new TestDescription("");
+
+        typeSafeMatcher.describeTo(description);
+
+        assertThat(description.toString(), is("""
+            a string equal to (when all whitespaces are ignored except in keys and values): {"key":"value"}"""));
+    }
+
+    private static class TestDescription implements Description {
+
+        private String descriptionContent;
+
+        TestDescription(String descriptionContent) {
+            Objects.requireNonNull(descriptionContent);
+            this.descriptionContent = descriptionContent;
+        }
+
+        @Override
+        public Description appendText(String text) {
+            descriptionContent += text;
+            return this;
+        }
+
+        @Override
+        public Description appendDescriptionOf(SelfDescribing value) {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public Description appendValue(Object value) {
+            descriptionContent += value;
+            return this;
+        }
+
+        @SafeVarargs
+        @Override
+        public final <T> Description appendValueList(String
start, String separator, String end, T... values) { + throw new UnsupportedOperationException(); + } + + @Override + public Description appendValueList(String start, String separator, String end, Iterable values) { + throw new UnsupportedOperationException(); + } + + @Override + public Description appendList(String start, String separator, String end, Iterable values) { + throw new UnsupportedOperationException(); + } + + @Override + public String toString() { + return descriptionContent; + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionActionTests.java new file mode 100644 index 0000000000000..09ef5351eb1fc --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionActionTests.java @@ -0,0 +1,274 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.googleaistudio; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModelTests; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioServiceTests.buildExpectationCompletions; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static 
org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; + +public class GoogleAiStudioCompletionActionTests extends ESTestCase { + + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testExecute_ReturnsSuccessfulResponse() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory, "test_service")) { + sender.start(); + + String responseJson = """ + { + "candidates": [ + { + "content": { + "parts": [ + { + "text": "result" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] + } + ], + "usageMetadata": { + "promptTokenCount": 4, + "candidatesTokenCount": 566, + "totalTokenCount": 570 + } + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var action = createAction(getUrl(webServer), "secret", "model", sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("input")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectationCompletions(List.of("result")))); + assertThat(webServer.requests(), hasSize(1)); + assertThat(webServer.requests().get(0).getUri().getQuery(), is("key=secret")); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat( + requestMap, + is( + Map.of( + "contents", + List.of(Map.of("role", "user", "parts", List.of(Map.of("text", "input")))), + "generationConfig", + Map.of("candidateCount", 1) + ) + ) + ); + } + } + + public void testExecute_ThrowsElasticsearchException() { + var sender = mock(Sender.class); + doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any(), any(), any()); + + var action = createAction(getUrl(webServer), "secret", "model", sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat(thrownException.getMessage(), is("failed")); + } + + public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled() { + var sender = mock(Sender.class); + + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + ActionListener listener = 
(ActionListener) invocation.getArguments()[2]; + listener.onFailure(new IllegalStateException("failed")); + + return Void.TYPE; + }).when(sender).send(any(), any(), any(), any()); + + var action = createAction(getUrl(webServer), "secret", "model", sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat( + thrownException.getMessage(), + is(format("Failed to send Google AI Studio completion request to [%s]", getUrl(webServer))) + ); + } + + public void testExecute_ThrowsException() { + var sender = mock(Sender.class); + doThrow(new IllegalArgumentException("failed")).when(sender).send(any(), any(), any(), any()); + + var action = createAction(getUrl(webServer), "secret", "model", sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat( + thrownException.getMessage(), + is(format("Failed to send Google AI Studio completion request to [%s]", getUrl(webServer))) + ); + } + + public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = senderFactory.createSender("test_service")) { + sender.start(); + + String responseJson = """ + { + "candidates": [ + { + "content": { + "parts": [ + { + "text": "result" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] + } + ], + "usageMetadata": { + "promptTokenCount": 4, + "candidatesTokenCount": 566, + "totalTokenCount": 570 + } + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var action = createAction(getUrl(webServer), "secret", "model", sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat(thrownException.getMessage(), is("Google AI Studio completion only accepts 1 input")); + assertThat(thrownException.status(), is(RestStatus.BAD_REQUEST)); + } + } + + private GoogleAiStudioCompletionAction createAction(String url, String apiKey, String modelName, Sender sender) { + var model = GoogleAiStudioCompletionModelTests.createModel(modelName, url, apiKey); + + return new GoogleAiStudioCompletionAction(sender, model, threadPool); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/googleaistudio/GoogleAiStudioResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/googleaistudio/GoogleAiStudioResponseHandlerTests.java new file mode 100644 index 
0000000000000..ba20799978d45 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/googleaistudio/GoogleAiStudioResponseHandlerTests.java @@ -0,0 +1,133 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.googleaistudio; + +import org.apache.http.Header; +import org.apache.http.HeaderElement; +import org.apache.http.HttpResponse; +import org.apache.http.StatusLine; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class GoogleAiStudioResponseHandlerTests extends ESTestCase { + + public void testCheckForFailureStatusCode_DoesNotThrowFor200() { + callCheckForFailureStatusCode(200, "id"); + } + + public void testCheckForFailureStatusCode_ThrowsFor500_ShouldRetry() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(500, "id")); + assertTrue(exception.shouldRetry()); + assertThat( + exception.getCause().getMessage(), + containsString("Received a server error status code for request from inference entity id [id] status [500]") + ); + assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST)); + } + + public void testCheckForFailureStatusCode_ThrowsFor503_ShouldRetry() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(503, "id")); + assertTrue(exception.shouldRetry()); + assertThat( + exception.getCause().getMessage(), + containsString( + "The Google AI Studio service may be temporarily overloaded or down for request from inference entity id [id] status [503]" + ) + ); + assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST)); + } + + public void testCheckForFailureStatusCode_ThrowsFor505_ShouldNotRetry() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(505, "id")); + assertFalse(exception.shouldRetry()); + assertThat( + exception.getCause().getMessage(), + containsString("Received a server error status code for request from inference entity id [id] status [505]") + ); + assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST)); + } + + public void testCheckForFailureStatusCode_ThrowsFor429_ShouldRetry() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(429, "id")); + assertTrue(exception.shouldRetry()); + assertThat( + exception.getCause().getMessage(), + containsString("Received a rate limit status code for request from inference entity id [id] status [429]") + ); + assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.TOO_MANY_REQUESTS)); + } + + public void testCheckForFailureStatusCode_ThrowsFor404_ShouldNotRetry() { + 
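+        // 404 is terminal for this handler: the RetryException must not be retryable and its cause should carry NOT_FOUND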
var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(404, "id")); + assertFalse(exception.shouldRetry()); + assertThat( + exception.getCause().getMessage(), + containsString("Resource not found at [null] for request from inference entity id [id] status [404]") + ); + assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.NOT_FOUND)); + } + + public void testCheckForFailureStatusCode_ThrowsFor403_ShouldNotRetry() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(403, "id")); + assertFalse(exception.shouldRetry()); + assertThat( + exception.getCause().getMessage(), + containsString("Received a permission denied error status code for request from inference entity id [id] status [403]") + ); + assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.FORBIDDEN)); + } + + public void testCheckForFailureStatusCode_ThrowsFor300_ShouldNotRetry() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(300, "id")); + assertFalse(exception.shouldRetry()); + assertThat( + exception.getCause().getMessage(), + containsString("Unhandled redirection for request from inference entity id [id] status [300]") + ); + assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.MULTIPLE_CHOICES)); + } + + public void testCheckForFailureStatusCode_ThrowsFor425_ShouldNotRetry() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(425, "id")); + assertFalse(exception.shouldRetry()); + assertThat( + exception.getCause().getMessage(), + containsString("Received an unsuccessful status code for request from inference entity id [id] status [425]") + ); + assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST)); + } + + private static void callCheckForFailureStatusCode(int statusCode, String modelId) { + var statusLine = mock(StatusLine.class); + when(statusLine.getStatusCode()).thenReturn(statusCode); + + var httpResponse = mock(HttpResponse.class); + when(httpResponse.getStatusLine()).thenReturn(statusLine); + var header = mock(Header.class); + when(header.getElements()).thenReturn(new HeaderElement[] {}); + when(httpResponse.getFirstHeader(anyString())).thenReturn(header); + + var mockRequest = mock(Request.class); + when(mockRequest.getInferenceEntityId()).thenReturn(modelId); + var httpResult = new HttpResult(httpResponse, new byte[] {}); + var handler = new GoogleAiStudioResponseHandler("", (request, result) -> null); + + handler.checkForFailureStatusCode(mockRequest, httpResult); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioRequestTests.java new file mode 100644 index 0000000000000..d77c88dacd06f --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioRequestTests.java @@ -0,0 +1,55 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.request.googleaistudio; + +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioSecretSettings; + +import java.net.URI; +import java.net.URISyntaxException; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class GoogleAiStudioRequestTests extends ESTestCase { + + public void testDecorateWithApiKeyParameter() throws URISyntaxException { + var uriString = "https://localhost:3000"; + var secureApiKey = new SecureString("api_key".toCharArray()); + var httpPost = new HttpPost(uriString); + var secretSettings = new GoogleAiStudioSecretSettings(secureApiKey); + + GoogleAiStudioRequest.decorateWithApiKeyParameter(httpPost, secretSettings); + + assertThat(httpPost.getURI(), is(new URI(Strings.format("%s?key=%s", uriString, secureApiKey)))); + } + + public void testDecorateWithApiKeyParameter_ThrowsValidationException_WhenAnyExceptionIsThrown() { + var errorMessage = "something went wrong"; + var cause = new RuntimeException(errorMessage); + var httpPost = mock(HttpPost.class); + when(httpPost.getURI()).thenThrow(cause); + + ValidationException validationException = expectThrows( + ValidationException.class, + () -> GoogleAiStudioRequest.decorateWithApiKeyParameter( + httpPost, + new GoogleAiStudioSecretSettings(new SecureString("abc".toCharArray())) + ) + ); + assertThat(validationException.getCause(), is(cause)); + assertThat(validationException.getMessage(), containsString(errorMessage)); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/completion/GoogleAiStudioCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/completion/GoogleAiStudioCompletionRequestEntityTests.java new file mode 100644 index 0000000000000..0b8ded1a4f118 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/completion/GoogleAiStudioCompletionRequestEntityTests.java @@ -0,0 +1,49 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.request.googleaistudio.completion; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.request.googleaistudio.GoogleAiStudioCompletionRequestEntity; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString; + +public class GoogleAiStudioCompletionRequestEntityTests extends ESTestCase { + + public void testToXContent_WritesSingleMessage() throws IOException { + var entity = new GoogleAiStudioCompletionRequestEntity(List.of("input")); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "contents": [ + { + "parts": [ + { + "text":"input" + } + ], + "role": "user" + } + ], + "generationConfig": { + "candidateCount": 1 + } + }""")); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/completion/GoogleAiStudioCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/completion/GoogleAiStudioCompletionRequestTests.java new file mode 100644 index 0000000000000..7d7ee1dcba6c2 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/completion/GoogleAiStudioCompletionRequestTests.java @@ -0,0 +1,73 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.request.googleaistudio.completion; + +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.request.googleaistudio.GoogleAiStudioCompletionRequest; +import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModelTests; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.endsWith; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; + +public class GoogleAiStudioCompletionRequestTests extends ESTestCase { + + public void testCreateRequest() throws IOException { + var apiKey = "api_key"; + var input = "input"; + + var request = new GoogleAiStudioCompletionRequest(List.of(input), GoogleAiStudioCompletionModelTests.createModel("model", apiKey)); + + var httpRequest = request.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getURI().toString(), endsWith(Strings.format("%s=%s", "key", apiKey))); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat( + requestMap, + is( + Map.of( + "contents", + List.of(Map.of("role", "user", "parts", List.of(Map.of("text", input)))), + "generationConfig", + Map.of("candidateCount", 1) + ) + ) + ); + } + + public void testTruncate_ReturnsSameInstance() { + var request = new GoogleAiStudioCompletionRequest( + List.of("input"), + GoogleAiStudioCompletionModelTests.createModel("model", "api key") + ); + var truncatedRequest = request.truncate(); + + assertThat(truncatedRequest, sameInstance(request)); + } + + public void testTruncationInfo_ReturnsNull() { + var request = new GoogleAiStudioCompletionRequest( + List.of("input"), + GoogleAiStudioCompletionModelTests.createModel("model", "api key") + ); + + assertNull(request.getTruncationInfo()); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googleaistudio/GoogleAiStudioCompletionResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googleaistudio/GoogleAiStudioCompletionResponseEntityTests.java new file mode 100644 index 0000000000000..ea4dd6ce47e22 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googleaistudio/GoogleAiStudioCompletionResponseEntityTests.java @@ -0,0 +1,189 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.response.googleaistudio; + +import org.apache.http.HttpResponse; +import org.elasticsearch.common.ParsingException; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; + +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class GoogleAiStudioCompletionResponseEntityTests extends ESTestCase { + + public void testFromResponse_CreatesResultsForASingleItem() throws IOException { + String responseJson = """ + { + "candidates": [ + { + "content": { + "parts": [ + { + "text": "result" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] + } + ], + "usageMetadata": { + "promptTokenCount": 4, + "candidatesTokenCount": 312, + "totalTokenCount": 316 + } + } + """; + + ChatCompletionResults chatCompletionResults = GoogleAiStudioCompletionResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat(chatCompletionResults.getResults().size(), is(1)); + assertThat(chatCompletionResults.getResults().get(0).content(), is("result")); + } + + public void testFromResponse_FailsWhenCandidatesFieldIsNotPresent() { + String responseJson = """ + { + "not_candidates": [ + { + "content": { + "parts": [ + { + "text": "result" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] + } + ], + "usageMetadata": { + "promptTokenCount": 4, + "candidatesTokenCount": 312, + "totalTokenCount": 316 + } + } + """; + + var thrownException = expectThrows( + IllegalStateException.class, + () -> GoogleAiStudioCompletionResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + + assertThat(thrownException.getMessage(), is("Failed to find required field [candidates] in Google AI Studio completion response")); + } + + public void testFromResponse_FailsWhenTextFieldIsNotAString() { + String responseJson = """ + { + "candidates": [ + { + "content": { + "parts": [ + { + "text": { + "key": "value" + } + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] 
+ } + ], + "usageMetadata": { + "promptTokenCount": 4, + "candidatesTokenCount": 312, + "totalTokenCount": 316 + } + } + """; + + var thrownException = expectThrows( + ParsingException.class, + () -> GoogleAiStudioCompletionResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + + assertThat( + thrownException.getMessage(), + is("Failed to parse object: expecting token of type [VALUE_STRING] but found [START_OBJECT]") + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googleaistudio/GoogleAiStudioErrorResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googleaistudio/GoogleAiStudioErrorResponseEntityTests.java new file mode 100644 index 0000000000000..61448f2e35bdf --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googleaistudio/GoogleAiStudioErrorResponseEntityTests.java @@ -0,0 +1,68 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.googleaistudio; + +import org.apache.http.HttpResponse; +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.HttpResult; + +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class GoogleAiStudioErrorResponseEntityTests extends ESTestCase { + + private static HttpResult getMockResult(String jsonString) { + var response = mock(HttpResponse.class); + return new HttpResult(response, Strings.toUTF8Bytes(jsonString)); + } + + public void testErrorResponse_ExtractsError() { + var result = getMockResult(""" + { + "error": { + "code": 400, + "message": "error message", + "status": "INVALID_ARGUMENT", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.BadRequest", + "fieldViolations": [ + { + "description": "Invalid JSON payload received. Unknown name \\"abc\\": Cannot find field." + } + ] + } + ] + } + } + """); + + var error = GoogleAiStudioErrorResponseEntity.fromResponse(result); + assertNotNull(error); + assertThat(error.getErrorMessage(), is("error message")); + } + + public void testErrorResponse_ReturnsNullIfNoError() { + var result = getMockResult(""" + { + "foo": "bar" + } + """); + + var error = GoogleAiStudioErrorResponseEntity.fromResponse(result); + assertNull(error); + } + + public void testErrorResponse_ReturnsNullIfNotJson() { + var result = getMockResult("error message"); + + var error = GoogleAiStudioErrorResponseEntity.fromResponse(result); + assertNull(error); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioSecretSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioSecretSettingsTests.java new file mode 100644 index 0000000000000..a0339934783d8 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioSecretSettingsTests.java @@ -0,0 +1,71 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googleaistudio; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class GoogleAiStudioSecretSettingsTests extends AbstractWireSerializingTestCase { + + public static GoogleAiStudioSecretSettings createRandom() { + return new GoogleAiStudioSecretSettings(randomSecureStringOfLength(15)); + } + + public void testFromMap() { + var apiKey = "abc"; + var secretSettings = GoogleAiStudioSecretSettings.fromMap(new HashMap<>(Map.of(GoogleAiStudioSecretSettings.API_KEY, apiKey))); + + assertThat(new GoogleAiStudioSecretSettings(new SecureString(apiKey.toCharArray())), is(secretSettings)); + } + + public void testFromMap_ReturnsNull_WhenMapIsNull() { + assertNull(GoogleAiStudioSecretSettings.fromMap(null)); + } + + public void testFromMap_ThrowsError_WhenApiKeyIsNull() { + var throwException = expectThrows(ValidationException.class, () -> GoogleAiStudioSecretSettings.fromMap(new HashMap<>())); + + assertThat(throwException.getMessage(), containsString("[secret_settings] must have [api_key] set")); + } + + public void testFromMap_ThrowsError_WhenApiKeyIsEmpty() { + var thrownException = expectThrows( + ValidationException.class, + () -> GoogleAiStudioSecretSettings.fromMap(new HashMap<>(Map.of(GoogleAiStudioSecretSettings.API_KEY, ""))) + ); + + assertThat( + thrownException.getMessage(), + containsString("[secret_settings] Invalid value empty string. [api_key] must be a non-empty string") + ); + } + + @Override + protected Writeable.Reader instanceReader() { + return GoogleAiStudioSecretSettings::new; + } + + @Override + protected GoogleAiStudioSecretSettings createTestInstance() { + return createRandom(); + } + + @Override + protected GoogleAiStudioSecretSettings mutateInstance(GoogleAiStudioSecretSettings instance) throws IOException { + return randomValueOtherThan(instance, GoogleAiStudioSecretSettingsTests::createRandom); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java new file mode 100644 index 0000000000000..f157622ea7291 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java @@ -0,0 +1,630 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.googleaistudio; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModel; +import org.hamcrest.MatcherAssert; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty; +import static org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModelTests.createModel; +import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasSize; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +public class GoogleAiStudioServiceTests extends ESTestCase { + + private static final 
TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testParseRequestConfig_CreatesAGoogleAiStudioCompletionModel() throws IOException { + var apiKey = "apiKey"; + var modelId = "model"; + + try (var service = createGoogleAiStudioService()) { + ActionListener modelListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(GoogleAiStudioCompletionModel.class)); + + var completionModel = (GoogleAiStudioCompletionModel) model; + assertThat(completionModel.getServiceSettings().modelId(), is(modelId)); + assertThat(completionModel.getSecretSettings().apiKey().toString(), is(apiKey)); + }, e -> fail("Model parsing should have succeeded, but failed: " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.COMPLETION, + getRequestConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + new HashMap<>(Map.of()), + getSecretSettingsMap(apiKey) + ), + Set.of(), + modelListener + ); + } + } + + public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException { + try (var service = createGoogleAiStudioService()) { + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "The [googleaistudio] service does not support task type [sparse_embedding]" + ); + + service.parseRequestConfig( + "id", + TaskType.SPARSE_EMBEDDING, + getRequestConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, "model")), + new HashMap<>(Map.of()), + getSecretSettingsMap("secret") + ), + Set.of(), + failureListener + ); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException { + try (var service = createGoogleAiStudioService()) { + var config = getRequestConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, "model")), + getTaskSettingsMapEmpty(), + getSecretSettingsMap("secret") + ); + config.put("extra_key", "value"); + + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "Model configuration contains settings [{extra_key=value}] unknown to the [googleaistudio] service" + ); + service.parseRequestConfig("id", TaskType.COMPLETION, config, Set.of(), failureListener); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException { + try (var service = createGoogleAiStudioService()) { + Map serviceSettings = new HashMap<>(Map.of(ServiceFields.MODEL_ID, "model")); + serviceSettings.put("extra_key", "value"); + + var config = getRequestConfigMap(serviceSettings, getTaskSettingsMapEmpty(), getSecretSettingsMap("api_key")); + + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "Model configuration contains settings [{extra_key=value}] unknown to the [googleaistudio] service" + ); + service.parseRequestConfig("id", TaskType.COMPLETION, config, Set.of(), failureListener); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException { + try (var service = 
createGoogleAiStudioService()) { + Map taskSettingsMap = new HashMap<>(); + taskSettingsMap.put("extra_key", "value"); + + var config = getRequestConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, "model")), + taskSettingsMap, + getSecretSettingsMap("secret") + ); + + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "Model configuration contains settings [{extra_key=value}] unknown to the [googleaistudio] service" + ); + service.parseRequestConfig("id", TaskType.COMPLETION, config, Set.of(), failureListener); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap() throws IOException { + try (var service = createGoogleAiStudioService()) { + Map secretSettings = getSecretSettingsMap("secret"); + secretSettings.put("extra_key", "value"); + + var config = getRequestConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, "model")), + getTaskSettingsMapEmpty(), + secretSettings + ); + + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "Model configuration contains settings [{extra_key=value}] unknown to the [googleaistudio] service" + ); + service.parseRequestConfig("id", TaskType.COMPLETION, config, Set.of(), failureListener); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAGoogleAiStudioCompletionModel() throws IOException { + var modelId = "model"; + var apiKey = "apiKey"; + + try (var service = createGoogleAiStudioService()) { + var persistedConfig = getPersistedConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + getTaskSettingsMapEmpty(), + getSecretSettingsMap(apiKey) + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.COMPLETION, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(GoogleAiStudioCompletionModel.class)); + + var completionModel = (GoogleAiStudioCompletionModel) model; + assertThat(completionModel.getServiceSettings().modelId(), is(modelId)); + assertThat(completionModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertThat(completionModel.getSecretSettings().apiKey().toString(), is(apiKey)); + } + } + + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + var modelId = "model"; + var apiKey = "apiKey"; + + try (var service = createGoogleAiStudioService()) { + var persistedConfig = getPersistedConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + getTaskSettingsMapEmpty(), + getSecretSettingsMap(apiKey) + ); + persistedConfig.config().put("extra_key", "value"); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.COMPLETION, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(GoogleAiStudioCompletionModel.class)); + + var completionModel = (GoogleAiStudioCompletionModel) model; + assertThat(completionModel.getServiceSettings().modelId(), is(modelId)); + assertThat(completionModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertThat(completionModel.getSecretSettings().apiKey(), is(apiKey)); + } + } + + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { + var modelId = "model"; + var apiKey = "apiKey"; + + try (var service = createGoogleAiStudioService()) { + var secretSettingsMap = getSecretSettingsMap(apiKey); + secretSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap( 
+ new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + getTaskSettingsMapEmpty(), + secretSettingsMap + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.COMPLETION, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(GoogleAiStudioCompletionModel.class)); + + var completionModel = (GoogleAiStudioCompletionModel) model; + assertThat(completionModel.getServiceSettings().modelId(), is(modelId)); + assertThat(completionModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertThat(completionModel.getSecretSettings().apiKey().toString(), is(apiKey)); + } + } + + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + var modelId = "model"; + var apiKey = "apiKey"; + + try (var service = createGoogleAiStudioService()) { + Map serviceSettingsMap = new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)); + serviceSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getTaskSettingsMapEmpty(), getSecretSettingsMap(apiKey)); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.COMPLETION, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(GoogleAiStudioCompletionModel.class)); + + var completionModel = (GoogleAiStudioCompletionModel) model; + assertThat(completionModel.getServiceSettings().modelId(), is(modelId)); + assertThat(completionModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertThat(completionModel.getSecretSettings().apiKey().toString(), is(apiKey)); + } + } + + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { + var modelId = "model"; + var apiKey = "apiKey"; + + try (var service = createGoogleAiStudioService()) { + Map taskSettings = getTaskSettingsMapEmpty(); + taskSettings.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + taskSettings, + getSecretSettingsMap(apiKey) + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.COMPLETION, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(GoogleAiStudioCompletionModel.class)); + + var completionModel = (GoogleAiStudioCompletionModel) model; + assertThat(completionModel.getServiceSettings().modelId(), is(modelId)); + assertThat(completionModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertThat(completionModel.getSecretSettings().apiKey().toString(), is(apiKey)); + } + } + + public void testParsePersistedConfig_CreatesAGoogleAiStudioCompletionModel() throws IOException { + var modelId = "model"; + + try (var service = createGoogleAiStudioService()) { + var persistedConfig = getPersistedConfigMap(new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), getTaskSettingsMapEmpty()); + + var model = service.parsePersistedConfig("id", TaskType.COMPLETION, persistedConfig.config()); + + assertThat(model, instanceOf(GoogleAiStudioCompletionModel.class)); + + var completionModel = (GoogleAiStudioCompletionModel) model; + assertThat(completionModel.getServiceSettings().modelId(), is(modelId)); + assertThat(completionModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertNull(completionModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + var 
modelId = "model"; + + try (var service = createGoogleAiStudioService()) { + var persistedConfig = getPersistedConfigMap(new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), getTaskSettingsMapEmpty()); + persistedConfig.config().put("extra_key", "value"); + + var model = service.parsePersistedConfig("id", TaskType.COMPLETION, persistedConfig.config()); + + assertThat(model, instanceOf(GoogleAiStudioCompletionModel.class)); + + var completionModel = (GoogleAiStudioCompletionModel) model; + assertThat(completionModel.getServiceSettings().modelId(), is(modelId)); + assertThat(completionModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertNull(completionModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + var modelId = "model"; + + try (var service = createGoogleAiStudioService()) { + Map serviceSettingsMap = new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)); + serviceSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getTaskSettingsMapEmpty()); + + var model = service.parsePersistedConfig("id", TaskType.COMPLETION, persistedConfig.config()); + + assertThat(model, instanceOf(GoogleAiStudioCompletionModel.class)); + + var completionModel = (GoogleAiStudioCompletionModel) model; + assertThat(completionModel.getServiceSettings().modelId(), is(modelId)); + assertThat(completionModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertNull(completionModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { + var modelId = "model"; + + try (var service = createGoogleAiStudioService()) { + Map taskSettings = getTaskSettingsMapEmpty(); + taskSettings.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap(new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), taskSettings); + + var model = service.parsePersistedConfig("id", TaskType.COMPLETION, persistedConfig.config()); + + assertThat(model, instanceOf(GoogleAiStudioCompletionModel.class)); + + var completionModel = (GoogleAiStudioCompletionModel) model; + assertThat(completionModel.getServiceSettings().modelId(), is(modelId)); + assertThat(completionModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertNull(completionModel.getSecretSettings()); + } + } + + public void testInfer_ThrowsErrorWhenModelIsNotGoogleAiStudioModel() throws IOException { + var sender = mock(Sender.class); + + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender(anyString())).thenReturn(sender); + + var mockModel = getInvalidModel("model_id", "service_name"); + + try (var service = new GoogleAiStudioService(factory, createWithEmptySettings(threadPool))) { + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + mockModel, + null, + List.of(""), + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + MatcherAssert.assertThat( + thrownException.getMessage(), + is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.") + ); + + verify(factory, times(1)).createSender(anyString()); + verify(sender, times(1)).start(); + } + + verify(sender, times(1)).close(); + verifyNoMoreInteractions(factory); + 
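+        // likewise the sender itself must see no interactions beyond the verified start() and close() calls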
verifyNoMoreInteractions(sender); + } + + public void testInfer_SendsRequest() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + String responseJson = """ + { + "candidates": [ + { + "content": { + "parts": [ + { + "text": "result" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] + } + ], + "usageMetadata": { + "promptTokenCount": 4, + "candidatesTokenCount": 215, + "totalTokenCount": 219 + } + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = createModel("model", getUrl(webServer), "secret"); + PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); + service.infer( + model, + null, + List.of("input"), + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectationCompletions(List.of("result")))); + assertThat(webServer.requests(), hasSize(1)); + assertThat(webServer.requests().get(0).getUri().getQuery(), is("key=secret")); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat( + requestMap, + is( + Map.of( + "contents", + List.of(Map.of("role", "user", "parts", List.of(Map.of("text", "input")))), + "generationConfig", + Map.of("candidateCount", 1) + ) + ) + ); + } + } + + public void testInfer_ResourceNotFound() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "error": { + "message": "error" + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson)); + + var model = createModel("model", getUrl(webServer), "secret"); + PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); + service.infer( + model, + null, + List.of("abc"), + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(error.getMessage(), containsString("Resource not found at ")); + assertThat(error.getMessage(), containsString("Error message: [error]")); + assertThat(webServer.requests(), hasSize(1)); + } + } + + public static Map<String, Object> buildExpectationCompletions(List<String> completions) { + return Map.of( + ChatCompletionResults.COMPLETION, + completions.stream().map(completion -> Map.of(ChatCompletionResults.Result.RESULT, completion)).collect(Collectors.toList()) + ); + } + + private static ActionListener<Model> getModelListenerForException(Class<? extends Exception> exceptionClass, String expectedMessage) { + return ActionListener.wrap((model) -> fail("Model parsing should have failed"), e -> { + assertThat(e, Matchers.instanceOf(exceptionClass)); + assertThat(e.getMessage(), 
is(expectedMessage)); + }); + } + + private Map<String, Object> getRequestConfigMap( + Map<String, Object> serviceSettings, + Map<String, Object> taskSettings, + Map<String, Object> secretSettings + ) { + var builtServiceSettings = new HashMap<>(); + builtServiceSettings.putAll(serviceSettings); + builtServiceSettings.putAll(secretSettings); + + return new HashMap<>( + Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings) + ); + } + + private GoogleAiStudioService createGoogleAiStudioService() { + return new GoogleAiStudioService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + } + + private PersistedConfig getPersistedConfigMap( + Map<String, Object> serviceSettings, + Map<String, Object> taskSettings, + Map<String, Object> secretSettings + ) { + + return new PersistedConfig( + new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, serviceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings)), + new HashMap<>(Map.of(ModelSecrets.SECRET_SETTINGS, secretSettings)) + ); + } + + private PersistedConfig getPersistedConfigMap(Map<String, Object> serviceSettings, Map<String, Object> taskSettings) { + return new PersistedConfig( + new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, serviceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings)), + null + ); + } + + private record PersistedConfig(Map<String, Object> config, Map<String, Object> secrets) {} + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModelTests.java new file mode 100644 index 0000000000000..1f8233f7eb103 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModelTests.java @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.googleaistudio.completion; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioSecretSettings; + +import java.net.URISyntaxException; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.is; + +public class GoogleAiStudioCompletionModelTests extends ESTestCase { + + public void testCreateModel_AlwaysWithEmptyTaskSettings() { + var model = new GoogleAiStudioCompletionModel( + "inference entity id", + TaskType.COMPLETION, + "service", + new HashMap<>(Map.of("model_id", "model")), + new HashMap<>(Map.of()), + null + ); + + assertThat(model.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + } + + public void testBuildUri() throws URISyntaxException { + assertThat( + GoogleAiStudioCompletionModel.buildUri("model").toString(), + is("https://generativelanguage.googleapis.com/v1/models/model:generateContent") + ); + } + + public static GoogleAiStudioCompletionModel createModel(String model, String apiKey) { + return new GoogleAiStudioCompletionModel( + "id", + TaskType.COMPLETION, + "service", + new GoogleAiStudioCompletionServiceSettings(model, null), + EmptyTaskSettings.INSTANCE, + new GoogleAiStudioSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static GoogleAiStudioCompletionModel createModel(String model, String url, String apiKey) { + return new GoogleAiStudioCompletionModel( + "id", + TaskType.COMPLETION, + "service", + url, + new GoogleAiStudioCompletionServiceSettings(model, null), + EmptyTaskSettings.INSTANCE, + new GoogleAiStudioSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionServiceSettingsTests.java new file mode 100644 index 0000000000000..46e6e60af493c --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionServiceSettingsTests.java @@ -0,0 +1,76 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.googleaistudio.completion; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.is; + +public class GoogleAiStudioCompletionServiceSettingsTests extends AbstractWireSerializingTestCase<GoogleAiStudioCompletionServiceSettings> { + + public static GoogleAiStudioCompletionServiceSettings createRandom() { + return new GoogleAiStudioCompletionServiceSettings(randomAlphaOfLength(8), randomFrom(RateLimitSettingsTests.createRandom(), null)); + } + + public void testFromMap_Request_CreatesSettingsCorrectly() { + var model = "some model"; + + var serviceSettings = GoogleAiStudioCompletionServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.MODEL_ID, model))); + + assertThat(serviceSettings, is(new GoogleAiStudioCompletionServiceSettings(model, null))); + } + + public void testToXContent_WritesAllValues() throws IOException { + var entity = new GoogleAiStudioCompletionServiceSettings("model", null); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(""" + {"model_id":"model","rate_limit":{"requests_per_minute":360}}""")); + } + + public void testToFilteredXContent_WritesAllValues_Except_RateLimit() throws IOException { + var entity = new GoogleAiStudioCompletionServiceSettings("model", null); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + var filteredXContent = entity.getFilteredXContentObject(); + filteredXContent.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(""" + {"model_id":"model"}""")); + } + + @Override + protected Writeable.Reader<GoogleAiStudioCompletionServiceSettings> instanceReader() { + return GoogleAiStudioCompletionServiceSettings::new; + } + + @Override + protected GoogleAiStudioCompletionServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected GoogleAiStudioCompletionServiceSettings mutateInstance(GoogleAiStudioCompletionServiceSettings instance) throws IOException { + return randomValueOtherThan(instance, GoogleAiStudioCompletionServiceSettingsTests::createRandom); + } +} From a5ac6a25c758bb2316b8b92a1b4d80d391e6f193 Mon Sep 17 00:00:00 2001 From: Tim Grein Date: Tue, 28 May 2024 08:19:47 +0200 Subject: [PATCH 21/29] [Inference API] Add CustomElandRerankTaskSettingsTests (#109069)
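For orientation before the diff: the merge rule these new tests pin down is the one described by the javadoc corrected below, namely that a non-null request-scoped field wins over the stored settings. A one-line sketch of that semantics for the single returnDocuments field exercised by the tests (illustrative only, not the actual implementation):

    static CustomElandRerankTaskSettings of(CustomElandRerankTaskSettings originalSettings, CustomElandRerankTaskSettings requestTaskSettings) {
        // prefer the per-request value when the caller supplied one, else fall back to the stored settings
        return requestTaskSettings.returnDocuments() != null ? requestTaskSettings : originalSettings;
    }

--- .../CustomElandRerankTaskSettings.java | 2 +- .../CustomElandRerankTaskSettingsTests.java | 115 ++++++++++++++++++ 2 files changed, 116 insertions(+), 1 deletion(-) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankTaskSettingsTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankTaskSettings.java index 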
a82ffbba3d688..0b586af5005fb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankTaskSettings.java @@ -66,7 +66,7 @@ public static CustomElandRerankTaskSettings fromMap(Map<String, Object> map) { } /** - * Return either the request or orignal settings by preferring non-null fields + * Return either the request or original settings by preferring non-null fields * from the request settings over the original settings. * * @param originalSettings the settings stored as part of the inference entity configuration diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankTaskSettingsTests.java new file mode 100644 index 0000000000000..05515bf9e3865 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankTaskSettingsTests.java @@ -0,0 +1,115 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elasticsearch; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; + +import java.io.IOException; +import java.util.HashMap; + +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; + +public class CustomElandRerankTaskSettingsTests extends AbstractWireSerializingTestCase<CustomElandRerankTaskSettings> { + + public void testDefaultsFromMap_MapIsNull_ReturnsDefaultSettings() { + var customElandRerankTaskSettings = CustomElandRerankTaskSettings.defaultsFromMap(null); + + assertThat(customElandRerankTaskSettings, sameInstance(CustomElandRerankTaskSettings.DEFAULT_SETTINGS)); + } + + public void testDefaultsFromMap_MapIsEmpty_ReturnsDefaultSettings() { + var customElandRerankTaskSettings = CustomElandRerankTaskSettings.defaultsFromMap(new HashMap<>()); + + assertThat(customElandRerankTaskSettings, sameInstance(CustomElandRerankTaskSettings.DEFAULT_SETTINGS)); + } + + public void testDefaultsFromMap_ExtractedReturnDocumentsNull_SetsReturnDocumentToTrue() { + var customElandRerankTaskSettings = CustomElandRerankTaskSettings.defaultsFromMap(new HashMap<>()); + + assertThat(customElandRerankTaskSettings.returnDocuments(), is(Boolean.TRUE)); + } + + public void testFromMap_MapIsNull_ReturnsDefaultSettings() { + var customElandRerankTaskSettings = CustomElandRerankTaskSettings.fromMap(null); + + assertThat(customElandRerankTaskSettings, sameInstance(CustomElandRerankTaskSettings.DEFAULT_SETTINGS)); + } + + public void testFromMap_MapIsEmpty_ReturnsDefaultSettings() { + var customElandRerankTaskSettings = CustomElandRerankTaskSettings.fromMap(new HashMap<>()); + + assertThat(customElandRerankTaskSettings, sameInstance(CustomElandRerankTaskSettings.DEFAULT_SETTINGS)); + } + + public void testToXContent_WritesAllValues() throws 
IOException { + var serviceSettings = new CustomElandRerankTaskSettings(Boolean.TRUE); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(""" + {"return_documents":true}""")); + } + + public void testToXContent_DoesNotWriteReturnDocuments_IfNull() throws IOException { + Boolean bool = null; + var serviceSettings = new CustomElandRerankTaskSettings(bool); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(""" + {}""")); + } + + public void testOf_PrefersNonNullRequestTaskSettings() { + var originalSettings = new CustomElandRerankTaskSettings(Boolean.FALSE); + var requestTaskSettings = new CustomElandRerankTaskSettings(Boolean.TRUE); + + var taskSettings = CustomElandRerankTaskSettings.of(originalSettings, requestTaskSettings); + + assertThat(taskSettings, sameInstance(requestTaskSettings)); + } + + public void testOf_UseOriginalSettings_IfRequestSettingsValuesAreNull() { + Boolean bool = null; + var originalSettings = new CustomElandRerankTaskSettings(Boolean.TRUE); + var requestTaskSettings = new CustomElandRerankTaskSettings(bool); + + var taskSettings = CustomElandRerankTaskSettings.of(originalSettings, requestTaskSettings); + + assertThat(taskSettings, sameInstance(originalSettings)); + } + + private static CustomElandRerankTaskSettings createRandom() { + return new CustomElandRerankTaskSettings(randomOptionalBoolean()); + } + + @Override + protected Writeable.Reader<CustomElandRerankTaskSettings> instanceReader() { + return CustomElandRerankTaskSettings::new; + } + + @Override + protected CustomElandRerankTaskSettings createTestInstance() { + return createRandom(); + } + + @Override + protected CustomElandRerankTaskSettings mutateInstance(CustomElandRerankTaskSettings instance) throws IOException { + return randomValueOtherThan(instance, CustomElandRerankTaskSettingsTests::createRandom); + } +} From aab02782e0f47b49c249c53510461abdddb838dd Mon Sep 17 00:00:00 2001 From: Tim Grein Date: Tue, 28 May 2024 09:38:17 +0200 Subject: [PATCH 22/29] [Inference API] Improve URI validation in ServiceUtils#convertToUri (#109060)
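The gist of the change: java.net.URI reports the low-level parse failure via URISyntaxException#getReason(), which the old wrapping code discarded. A minimal sketch of the new error chain (illustrative wrapping only; the real code is in the ServiceUtils diff that follows):

    try {
        new URI("^^");
    } catch (URISyntaxException e) {
        // e.getReason() is "Illegal character in path" here, so callers now see:
        // unable to parse url [^^]. Reason: Illegal character in path
        throw new IllegalArgumentException(format("unable to parse url [%s]. Reason: %s", "^^", e.getReason()), e);
    }

Accordingly, the tests below switch from exact message matches to containsString, since the reason suffix varies by input.

--- .../xpack/inference/services/ServiceUtils.java | 11 +++++------ .../cohere/CohereCompletionActionTests.java | 3 ++- .../cohere/CohereEmbeddingsActionTests.java | 3 ++- .../openai/OpenAiChatCompletionActionTests.java | 3 ++- .../openai/OpenAiEmbeddingsActionTests.java | 3 ++- .../inference/services/ServiceUtilsTests.java | 16 ++++++++++++++-- .../cohere/CohereServiceSettingsTests.java | 4 +++- .../HuggingFaceServiceSettingsTests.java | 4 +++- .../elser/HuggingFaceElserModelTests.java | 4 ++-- .../HuggingFaceElserServiceSettingsTests.java | 4 ++-- .../HuggingFaceEmbeddingsModelTests.java | 4 ++-- ...OpenAiChatCompletionServiceSettingsTests.java | 4 +++- .../OpenAiEmbeddingsServiceSettingsTests.java | 4 +++- 13 files changed, 45 insertions(+), 22 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 25e8afbe1d16c..4b5ec48f99b74 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ 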
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -196,8 +196,8 @@ public static String invalidTypesErrorMsg(String settingName, Object foundObject ); } - public static String invalidUrlErrorMsg(String url, String settingName, String settingScope) { - return Strings.format("[%s] Invalid url [%s] received for field [%s]", settingScope, url, settingName); + public static String invalidUrlErrorMsg(String url, String settingName, String settingScope, String error) { + return Strings.format("[%s] Invalid url [%s] received for field [%s]. Error: %s", settingScope, url, settingName, error); } public static String mustBeNonEmptyString(String settingName, String scope) { @@ -231,7 +231,6 @@ public static String invalidSettingError(String settingName, String scope) { return Strings.format("[%s] does not allow the setting [%s]", scope, settingName); } - // TODO improve URI validation logic public static URI convertToUri(@Nullable String url, String settingName, String settingScope, ValidationException validationException) { try { if (url == null) { @@ -239,8 +238,8 @@ public static URI convertToUri(@Nullable String url, String settingName, String } return createUri(url); - } catch (IllegalArgumentException ignored) { - validationException.addValidationError(ServiceUtils.invalidUrlErrorMsg(url, settingName, settingScope)); + } catch (IllegalArgumentException cause) { + validationException.addValidationError(ServiceUtils.invalidUrlErrorMsg(url, settingName, settingScope, cause.getMessage())); return null; } } @@ -251,7 +250,7 @@ public static URI createUri(String url) throws IllegalArgumentException { try { return new URI(url); } catch (URISyntaxException e) { - throw new IllegalArgumentException(format("unable to parse url [%s]", url), e); + throw new IllegalArgumentException(format("unable to parse url [%s]. 
Reason: %s", url, e.getReason()), e); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java index 195f2bab1d6b5..9ac4674ef0b1e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java @@ -44,6 +44,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; @@ -198,7 +199,7 @@ public void testExecute_ReturnsSuccessfulResponse_WithoutModelSpecified() throws public void testExecute_ThrowsURISyntaxException_ForInvalidUrl() throws IOException { try (var sender = mock(Sender.class)) { var thrownException = expectThrows(IllegalArgumentException.class, () -> createAction("a^b", "api key", "model", sender)); - assertThat(thrownException.getMessage(), is("unable to parse url [a^b]")); + assertThat(thrownException.getMessage(), containsString("unable to parse url [a^b]")); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java index 6ca4cb305ab32..dbc97fa2e13d8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java @@ -49,6 +49,7 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationByte; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; @@ -245,7 +246,7 @@ public void testExecute_ThrowsURISyntaxException_ForInvalidUrl() throws IOExcept IllegalArgumentException.class, () -> createAction("^^", "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender) ); - MatcherAssert.assertThat(thrownException.getMessage(), is("unable to parse url [^^]")); + MatcherAssert.assertThat(thrownException.getMessage(), containsString("unable to parse url [^^]")); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java index e28c3e817b351..35d1ee8fc5a5a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java @@ -47,6 +47,7 @@ import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createChatCompletionModel; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; @@ -142,7 +143,7 @@ public void testExecute_ThrowsURISyntaxException_ForInvalidUrl() throws IOExcept IllegalArgumentException.class, () -> createAction("^^", "org", "secret", "model", "user", sender) ); - assertThat(thrownException.getMessage(), is("unable to parse url [^^]")); + assertThat(thrownException.getMessage(), containsString("unable to parse url [^^]")); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java index 260e352fd26c8..15b7417912ef5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java @@ -43,6 +43,7 @@ import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModelTests.createModel; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; @@ -131,7 +132,7 @@ public void testExecute_ThrowsURISyntaxException_ForInvalidUrl() throws IOExcept IllegalArgumentException.class, () -> createAction("^^", "org", "secret", "model", "user", sender) ); - assertThat(thrownException.getMessage(), is("unable to parse url [^^]")); + assertThat(thrownException.getMessage(), containsString("unable to parse url [^^]")); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java index 0a34de7b342ee..edd9637d92dd8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java @@ -307,7 +307,19 @@ public void testConvertToUri_AddsValidationError_WhenUrlIsInvalid() { assertNull(uri); assertThat(validation.validationErrors().size(), is(1)); - assertThat(validation.validationErrors().get(0), is("[scope] Invalid url [^^] received for field [name]")); + assertThat(validation.validationErrors().get(0), containsString("[scope] Invalid url [^^] received for field [name]")); + } + + public void testConvertToUri_AddsValidationError_WhenUrlIsInvalid_PreservesReason() { + var validation = new ValidationException(); + var 
uri = convertToUri("^^", "name", "scope", validation); + + assertNull(uri); + assertThat(validation.validationErrors().size(), is(1)); + assertThat( + validation.validationErrors().get(0), + is("[scope] Invalid url [^^] received for field [name]. Error: unable to parse url [^^]. Reason: Illegal character in path") + ); } public void testCreateUri_CreatesUri() { @@ -320,7 +332,7 @@ public void testCreateUri_CreatesUri() { public void testCreateUri_ThrowsException_WithInvalidUrl() { var exception = expectThrows(IllegalArgumentException.class, () -> createUri("^^")); - assertThat(exception.getMessage(), is("unable to parse url [^^]")); + assertThat(exception.getMessage(), containsString("unable to parse url [^^]")); } public void testCreateUri_ThrowsException_WithNullUrl() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java index 303ed1cab2c50..f4dad7546c8a2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java @@ -229,7 +229,9 @@ public void testFromMap_InvalidUrl_ThrowsError() { MatcherAssert.assertThat( thrownException.getMessage(), - is(Strings.format("Validation Failed: 1: [service_settings] Invalid url [%s] received for field [%s];", url, ServiceFields.URL)) + containsString( + Strings.format("Validation Failed: 1: [service_settings] Invalid url [%s] received for field [%s]", url, ServiceFields.URL) + ) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettingsTests.java index 9d92f756dd31c..91b91593adee7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettingsTests.java @@ -141,7 +141,9 @@ public void testFromMap_InvalidUrl_ThrowsError() { assertThat( thrownException.getMessage(), - is(Strings.format("Validation Failed: 1: [service_settings] Invalid url [%s] received for field [%s];", url, ServiceFields.URL)) + containsString( + Strings.format("Validation Failed: 1: [service_settings] Invalid url [%s] received for field [%s]", url, ServiceFields.URL) + ) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserModelTests.java index 33dbee2a32b9f..2ad2c12b4a97c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserModelTests.java @@ -11,13 +11,13 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; -import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.containsString; public class 
HuggingFaceElserModelTests extends ESTestCase { public void testThrowsURISyntaxException_ForInvalidUrl() { var thrownException = expectThrows(IllegalArgumentException.class, () -> createModel("^^", "secret")); - assertThat(thrownException.getMessage(), is("unable to parse url [^^]")); + assertThat(thrownException.getMessage(), containsString("unable to parse url [^^]")); } public static HuggingFaceElserModel createModel(String url, String apiKey) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettingsTests.java index bd6a5007b72ee..57f9c59b65e12 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettingsTests.java @@ -77,9 +77,9 @@ public void testFromMap_InvalidUrl_ThrowsError() { assertThat( thrownException.getMessage(), - is( + containsString( Strings.format( - "Validation Failed: 1: [service_settings] Invalid url [%s] received for field [%s];", + "Validation Failed: 1: [service_settings] Invalid url [%s] received for field [%s]", url, HuggingFaceElserServiceSettings.URL ) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/embeddings/HuggingFaceEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/embeddings/HuggingFaceEmbeddingsModelTests.java index d579da2d9fbc5..baf5467d8fe06 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/embeddings/HuggingFaceEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/embeddings/HuggingFaceEmbeddingsModelTests.java @@ -16,13 +16,13 @@ import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; -import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.containsString; public class HuggingFaceEmbeddingsModelTests extends ESTestCase { public void testThrowsURISyntaxException_ForInvalidUrl() { var thrownException = expectThrows(IllegalArgumentException.class, () -> createModel("^^", "secret")); - assertThat(thrownException.getMessage(), is("unable to parse url [^^]")); + assertThat(thrownException.getMessage(), containsString("unable to parse url [^^]")); } public static HuggingFaceEmbeddingsModel createModel(String url, String apiKey) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettingsTests.java index 75ea63eba8a34..186ca89426418 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettingsTests.java @@ -170,7 +170,9 @@ public void 
testFromMap_InvalidUrl_ThrowsError() { assertThat( thrownException.getMessage(), - is(Strings.format("Validation Failed: 1: [service_settings] Invalid url [%s] received for field [%s];", url, ServiceFields.URL)) + containsString( + Strings.format("Validation Failed: 1: [service_settings] Invalid url [%s] received for field [%s]", url, ServiceFields.URL) + ) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettingsTests.java index 1be70ee586835..438f895fe48ad 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettingsTests.java @@ -335,7 +335,9 @@ public void testFromMap_InvalidUrl_ThrowsError() { assertThat( thrownException.getMessage(), - is(Strings.format("Validation Failed: 1: [service_settings] Invalid url [%s] received for field [%s];", url, ServiceFields.URL)) + containsString( + Strings.format("Validation Failed: 1: [service_settings] Invalid url [%s] received for field [%s]", url, ServiceFields.URL) + ) ); } From 1cff3b4366d8cf5f6e6616820f14585d69073872 Mon Sep 17 00:00:00 2001 From: Tim Grein Date: Tue, 28 May 2024 09:44:40 +0200 Subject: [PATCH 23/29] [Inference API] Make method names for building expectations more explicit (#109066) --- .../HuggingFaceActionCreatorTests.java | 2 +- .../HuggingFaceElserResponseEntityTests.java | 24 ++++++++++++------- .../results/SparseEmbeddingResultsTests.java | 6 ++--- .../TextEmbeddingByteResultsTests.java | 2 +- .../huggingface/HuggingFaceServiceTests.java | 2 +- 5 files changed, 22 insertions(+), 14 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java index 099ac166dda72..fceea8810f6c2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java @@ -99,7 +99,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForElserAction() throws IOExce assertThat( result.asMap(), is( - SparseEmbeddingResultsTests.buildExpectation( + SparseEmbeddingResultsTests.buildExpectationSparseEmbeddings( List.of(new SparseEmbeddingResultsTests.EmbeddingExpectation(Map.of(".", 0.13315596f), false)) ) ) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntityTests.java index bdb8e38fa8228..c3c416d8fe65e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntityTests.java @@ -22,7 
+22,7 @@ import java.util.List; import java.util.Map; -import static org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests.buildExpectation; +import static org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests.buildExpectationSparseEmbeddings; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; import static org.mockito.Mockito.mock; @@ -46,7 +46,7 @@ public void testFromResponse_CreatesTextExpansionResults() throws IOException { assertThat( parsedResults.asMap(), is( - buildExpectation( + buildExpectationSparseEmbeddings( List.of(new SparseEmbeddingResultsTests.EmbeddingExpectation(Map.of(".", 0.13315596f, "the", 0.67472112f), false)) ) ) @@ -73,7 +73,7 @@ public void testFromResponse_CreatesTextExpansionResults_ThatAreTruncated() thro assertThat( parsedResults.asMap(), is( - buildExpectation( + buildExpectationSparseEmbeddings( List.of(new SparseEmbeddingResultsTests.EmbeddingExpectation(Map.of(".", 0.13315596f, "the", 0.67472112f), true)) ) ) @@ -101,7 +101,7 @@ public void testFromResponse_CreatesTextExpansionResultsForMultipleItems_Truncat assertThat( parsedResults.asMap(), is( - buildExpectation( + buildExpectationSparseEmbeddings( List.of( new SparseEmbeddingResultsTests.EmbeddingExpectation(Map.of(".", 0.13315596f, "the", 0.67472112f), false), new SparseEmbeddingResultsTests.EmbeddingExpectation(Map.of("hi", 0.13315596f, "super", 0.67472112f), false) @@ -135,7 +135,7 @@ public void testFromResponse_CreatesTextExpansionResults_WithTruncation() throws assertThat( parsedResults.asMap(), is( - buildExpectation( + buildExpectationSparseEmbeddings( List.of( new SparseEmbeddingResultsTests.EmbeddingExpectation(Map.of(".", 0.13315596f, "the", 0.67472112f), true), new SparseEmbeddingResultsTests.EmbeddingExpectation(Map.of("hi", 0.13315596f, "super", 0.67472112f), false) @@ -169,7 +169,7 @@ public void testFromResponse_CreatesTextExpansionResults_WithTruncationLessArray assertThat( parsedResults.asMap(), is( - buildExpectation( + buildExpectationSparseEmbeddings( List.of( new SparseEmbeddingResultsTests.EmbeddingExpectation(Map.of(".", 0.13315596f, "the", 0.67472112f), false), new SparseEmbeddingResultsTests.EmbeddingExpectation(Map.of("hi", 0.13315596f, "super", 0.67472112f), false) @@ -239,7 +239,11 @@ public void testFromResponse_CreatesResultsWithValueInt() throws IOException { assertThat( parsedResults.asMap(), - is(buildExpectation(List.of(new SparseEmbeddingResultsTests.EmbeddingExpectation(Map.of("field", 1.0f), false)))) + is( + buildExpectationSparseEmbeddings( + List.of(new SparseEmbeddingResultsTests.EmbeddingExpectation(Map.of("field", 1.0f), false)) + ) + ) ); } @@ -259,7 +263,11 @@ public void testFromResponse_CreatesResultsWithValueLong() throws IOException { assertThat( parsedResults.asMap(), - is(buildExpectation(List.of(new SparseEmbeddingResultsTests.EmbeddingExpectation(Map.of("field", 4.0294965E10F), false)))) + is( + buildExpectationSparseEmbeddings( + List.of(new SparseEmbeddingResultsTests.EmbeddingExpectation(Map.of("field", 4.0294965E10F), false)) + ) + ) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java index 727df98d27bbb..acc0ef6eed269 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java @@ -87,7 +87,7 @@ protected SparseEmbeddingResults mutateInstance(SparseEmbeddingResults instance) public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOException { var entity = createSparseResult(List.of(createEmbedding(List.of(new SparseEmbedding.WeightedToken("token", 0.1F)), false))); - assertThat(entity.asMap(), is(buildExpectation(List.of(new EmbeddingExpectation(Map.of("token", 0.1F), false))))); + assertThat(entity.asMap(), is(buildExpectationSparseEmbeddings(List.of(new EmbeddingExpectation(Map.of("token", 0.1F), false))))); String xContentResult = Strings.toString(entity, true, true); assertThat(xContentResult, is(""" { @@ -118,7 +118,7 @@ public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws I assertThat( entity.asMap(), is( - buildExpectation( + buildExpectationSparseEmbeddings( List.of( new EmbeddingExpectation(Map.of("token", 0.1F, "token2", 0.2F), false), new EmbeddingExpectation(Map.of("token3", 0.3F, "token4", 0.4F), false) @@ -170,7 +170,7 @@ public void testTransformToCoordinationFormat() { public record EmbeddingExpectation(Map<String, Float> tokens, boolean isTruncated) {} - public static Map<String, Object> buildExpectation(List<EmbeddingExpectation> embeddings) { + public static Map<String, Object> buildExpectationSparseEmbeddings(List<EmbeddingExpectation> embeddings) { return Map.of( SparseEmbeddingResults.SPARSE_EMBEDDING, embeddings.stream() diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingByteResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingByteResultsTests.java index 48784b9bd8652..a07f75ec2c536 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingByteResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingByteResultsTests.java @@ -131,7 +131,7 @@ protected TextEmbeddingByteResults mutateInstance(TextEmbeddingByteResults insta } } - public static Map<String, Object> buildExpectation(List<List<Byte>> embeddings) { + public static Map<String, Object> buildExpectationByte(List<List<Byte>> embeddings) { return Map.of( TextEmbeddingByteResults.TEXT_EMBEDDING_BYTES, embeddings.stream().map(embedding -> Map.of(ByteEmbedding.EMBEDDING, embedding)).toList() diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index 54eef58fb2f7b..914775bf9fa61 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -494,7 +494,7 @@ public void testInfer_SendsElserRequest() throws IOException { assertThat( result.asMap(), Matchers.is( - SparseEmbeddingResultsTests.buildExpectationSparseEmbeddings( + SparseEmbeddingResultsTests.buildExpectationSparseEmbeddings( List.of(new SparseEmbeddingResultsTests.EmbeddingExpectation(Map.of(".", 0.13315596f), false)) ) ) From bcb82796af1bbc86ea9ab3f0e68d4e5ec501bd00 Mon Sep 17 00:00:00 2001 From: David Turner Date: Tue, 28 May 2024 09:11:18 +0100 Subject: [PATCH 24/29] Fix trappy timeouts in ILM rest actions (#109056) Relates #107984
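In case the "trappy" pattern is unfamiliar: these requests used to be constructed with implicit default timeouts that the REST handler then overwrote via setters, so forgetting the two extra calls silently kept the defaults. A rough before/after sketch, simplified from the diff below:

    // before: implicit defaults, mutated after construction (easy to forget)
    TransportRetryAction.Request request = new TransportRetryAction.Request(indices);
    request.masterNodeTimeout(getMasterNodeTimeout(restRequest));
    request.ackTimeout(getAckTimeout(restRequest));

    // after: both timeouts are constructor arguments, so they cannot be skipped
    TransportRetryAction.Request request2 = new TransportRetryAction.Request(
        getMasterNodeTimeout(restRequest),
        getAckTimeout(restRequest),
        indices
    );

--- .../ilm/action/RestMoveToStepAction.java | 19 +++++++--- 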
.../xpack/ilm/action/RestRetryAction.java | 6 +-- .../ilm/action/TransportMoveToStepAction.java | 38 +++++++++++-------- .../ilm/action/TransportRetryAction.java | 12 ++---- .../ilm/action/MoveToStepRequestTests.java | 21 ++++++++-- .../xpack/ilm/action/RetryRequestTests.java | 13 ++++--- 6 files changed, 68 insertions(+), 41 deletions(-) diff --git a/x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/ilm/action/RestMoveToStepAction.java b/x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/ilm/action/RestMoveToStepAction.java index 9256a61addd8f..64ce857a0198b 100644 --- a/x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/ilm/action/RestMoveToStepAction.java +++ b/x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/ilm/action/RestMoveToStepAction.java @@ -36,13 +36,22 @@ public String getName() { @Override protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { - String index = restRequest.param("name"); - TransportMoveToStepAction.Request request; + final var masterNodeTimeout = getMasterNodeTimeout(restRequest); + final var ackTimeout = getAckTimeout(restRequest); + final var index = restRequest.param("name"); + final TransportMoveToStepAction.Request request; try (XContentParser parser = restRequest.contentParser()) { - request = TransportMoveToStepAction.Request.parseRequest(index, parser); + request = TransportMoveToStepAction.Request.parseRequest( + (currentStepKey, nextStepKey) -> new TransportMoveToStepAction.Request( + masterNodeTimeout, + ackTimeout, + index, + currentStepKey, + nextStepKey + ), + parser + ); } - request.ackTimeout(getAckTimeout(restRequest)); - request.masterNodeTimeout(getMasterNodeTimeout(restRequest)); return channel -> client.execute(ILMActions.MOVE_TO_STEP, request, new RestToXContentListener<>(channel)); } } diff --git a/x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/ilm/action/RestRetryAction.java b/x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/ilm/action/RestRetryAction.java index 10a3fa38df672..1000bd1e68249 100644 --- a/x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/ilm/action/RestRetryAction.java +++ b/x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/ilm/action/RestRetryAction.java @@ -36,10 +36,8 @@ public String getName() { @Override protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) { - String[] indices = Strings.splitStringByCommaToArray(restRequest.param("index")); - TransportRetryAction.Request request = new TransportRetryAction.Request(indices); - request.ackTimeout(getAckTimeout(restRequest)); - request.masterNodeTimeout(getMasterNodeTimeout(restRequest)); + final var indices = Strings.splitStringByCommaToArray(restRequest.param("index")); + final var request = new TransportRetryAction.Request(getMasterNodeTimeout(restRequest), getAckTimeout(restRequest), indices); request.indices(indices); request.indicesOptions(IndicesOptions.fromRequest(restRequest, IndicesOptions.strictExpandOpen())); return channel -> client.execute(ILMActions.RETRY, request, new RestToXContentListener<>(channel)); diff --git a/x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/ilm/action/TransportMoveToStepAction.java b/x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/ilm/action/TransportMoveToStepAction.java index 87c93a9198215..ec905c0e9eb48 100644 --- a/x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/ilm/action/TransportMoveToStepAction.java +++ 
b/x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/ilm/action/TransportMoveToStepAction.java @@ -32,6 +32,7 @@ import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.SuppressForbidden; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; @@ -188,15 +189,20 @@ protected ClusterBlockException checkBlock(Request request, ClusterState state) } public static class Request extends AcknowledgedRequest implements ToXContentObject { + + public interface Factory { + Request create(Step.StepKey currentStepKey, PartialStepKey nextStepKey); + } + static final ParseField CURRENT_KEY_FIELD = new ParseField("current_step"); static final ParseField NEXT_KEY_FIELD = new ParseField("next_step"); - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( "move_to_step_request", false, - (a, index) -> { + (a, factory) -> { Step.StepKey currentStepKey = (Step.StepKey) a[0]; PartialStepKey nextStepKey = (PartialStepKey) a[1]; - return new Request(index, currentStepKey, nextStepKey); + return factory.create(currentStepKey, nextStepKey); } ); @@ -207,12 +213,18 @@ public static class Request extends AcknowledgedRequest implements ToXC PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, name) -> PartialStepKey.parse(p), NEXT_KEY_FIELD); } - private String index; - private Step.StepKey currentStepKey; - private PartialStepKey nextStepKey; - - public Request(String index, Step.StepKey currentStepKey, PartialStepKey nextStepKey) { - super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT); + private final String index; + private final Step.StepKey currentStepKey; + private final PartialStepKey nextStepKey; + + public Request( + TimeValue masterNodeTimeout, + TimeValue ackTimeout, + String index, + Step.StepKey currentStepKey, + PartialStepKey nextStepKey + ) { + super(masterNodeTimeout, ackTimeout); this.index = index; this.currentStepKey = currentStepKey; this.nextStepKey = nextStepKey; @@ -225,10 +237,6 @@ public Request(StreamInput in) throws IOException { this.nextStepKey = new PartialStepKey(in); } - public Request() { - super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT); - } - public String getIndex() { return index; } @@ -246,8 +254,8 @@ public ActionRequestValidationException validate() { return null; } - public static Request parseRequest(String name, XContentParser parser) { - return PARSER.apply(parser, name); + public static Request parseRequest(Factory factory, XContentParser parser) { + return PARSER.apply(parser, factory); } @Override diff --git a/x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/ilm/action/TransportRetryAction.java b/x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/ilm/action/TransportRetryAction.java index 95358adb832c7..ee96fa73838df 100644 --- a/x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/ilm/action/TransportRetryAction.java +++ b/x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/ilm/action/TransportRetryAction.java @@ -26,12 +26,12 @@ import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.metadata.LifecycleExecutionState; import org.elasticsearch.cluster.service.ClusterService; -import org.elasticsearch.common.Strings; import 
org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.core.SuppressForbidden; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; @@ -114,11 +114,11 @@ protected ClusterBlockException checkBlock(Request request, ClusterState state) } public static class Request extends AcknowledgedRequest implements IndicesRequest.Replaceable { - private String[] indices = Strings.EMPTY_ARRAY; + private String[] indices; private IndicesOptions indicesOptions = IndicesOptions.strictExpandOpen(); - public Request(String... indices) { - super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT); + public Request(TimeValue masterNodeTimeout, TimeValue ackTimeout, String... indices) { + super(masterNodeTimeout, ackTimeout); this.indices = indices; } @@ -128,10 +128,6 @@ public Request(StreamInput in) throws IOException { this.indicesOptions = IndicesOptions.readIndicesOptions(in); } - public Request() { - super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT); - } - @Override public Request indices(String... indices) { this.indices = indices; diff --git a/x-pack/plugin/ilm/src/test/java/org/elasticsearch/xpack/ilm/action/MoveToStepRequestTests.java b/x-pack/plugin/ilm/src/test/java/org/elasticsearch/xpack/ilm/action/MoveToStepRequestTests.java index 441e61708e3cc..16d6f5fdd8579 100644 --- a/x-pack/plugin/ilm/src/test/java/org/elasticsearch/xpack/ilm/action/MoveToStepRequestTests.java +++ b/x-pack/plugin/ilm/src/test/java/org/elasticsearch/xpack/ilm/action/MoveToStepRequestTests.java @@ -26,7 +26,13 @@ public void setup() { @Override protected TransportMoveToStepAction.Request createTestInstance() { - return new TransportMoveToStepAction.Request(index, stepKeyTests.createTestInstance(), randomStepSpecification()); + return new TransportMoveToStepAction.Request( + TEST_REQUEST_TIMEOUT, + TEST_REQUEST_TIMEOUT, + index, + stepKeyTests.createTestInstance(), + randomStepSpecification() + ); } @Override @@ -36,7 +42,16 @@ protected Writeable.Reader instanceReader() { @Override protected TransportMoveToStepAction.Request doParseInstance(XContentParser parser) { - return TransportMoveToStepAction.Request.parseRequest(index, parser); + return TransportMoveToStepAction.Request.parseRequest( + (currentStepKey, nextStepKey) -> new TransportMoveToStepAction.Request( + TEST_REQUEST_TIMEOUT, + TEST_REQUEST_TIMEOUT, + index, + currentStepKey, + nextStepKey + ), + parser + ); } @Override @@ -52,7 +67,7 @@ protected TransportMoveToStepAction.Request mutateInstance(TransportMoveToStepAc default -> throw new AssertionError("Illegal randomisation branch"); } - return new TransportMoveToStepAction.Request(indexName, currentStepKey, nextStepKey); + return new TransportMoveToStepAction.Request(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT, indexName, currentStepKey, nextStepKey); } private static TransportMoveToStepAction.Request.PartialStepKey randomStepSpecification() { diff --git a/x-pack/plugin/ilm/src/test/java/org/elasticsearch/xpack/ilm/action/RetryRequestTests.java b/x-pack/plugin/ilm/src/test/java/org/elasticsearch/xpack/ilm/action/RetryRequestTests.java index e4f3c58fe6e66..4f053ddc2caa4 100644 --- a/x-pack/plugin/ilm/src/test/java/org/elasticsearch/xpack/ilm/action/RetryRequestTests.java +++ 
b/x-pack/plugin/ilm/src/test/java/org/elasticsearch/xpack/ilm/action/RetryRequestTests.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.ilm.action; import org.elasticsearch.action.support.IndicesOptions; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; @@ -17,10 +18,11 @@ public class RetryRequestTests extends AbstractWireSerializingTestCase throw new AssertionError("Illegal randomisation branch"); } - TransportRetryAction.Request newRequest = new TransportRetryAction.Request(); - newRequest.indices(indices); + final var newRequest = new TransportRetryAction.Request(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT, indices); newRequest.indicesOptions(indicesOptions); return newRequest; } From a4bd25687cb93a8ed1d8d1b276b7c4731eee4ca7 Mon Sep 17 00:00:00 2001 From: Henning Andersen <33268011+henningandersen@users.noreply.github.com> Date: Tue, 28 May 2024 10:12:51 +0200 Subject: [PATCH 25/29] Assert not same executor when completing future (#108934) A common deadlock pattern is waiting and completing a future on the same executor. This only works until the executor is fully depleted of threads. Now assert that waiting for a future to be completed and the completion happens on different executors. Introduced UnsafePlainActionFuture, used in all offending places, allowing those to be tackled independently.
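To make the deadlock concrete, here is a minimal sketch (hypothetical single-threaded executor setup, not code from this change):

    import org.elasticsearch.action.support.PlainActionFuture;

    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;

    ExecutorService executor = Executors.newFixedThreadPool(1);
    PlainActionFuture<Void> future = new PlainActionFuture<>();
    executor.execute(() -> future.actionGet());      // occupies the only worker thread, blocked waiting
    executor.execute(() -> future.onResponse(null)); // queued behind the waiter, never runs -> deadlock

The new assertion compares the thread blocked on the future with the thread completing it and trips when both belong to the same executor; UnsafePlainActionFuture exempts explicitly named executors so the known offenders keep working until each can be fixed on its own.

--- .../discovery/ClusterDisruptionIT.java | 4 +- .../CorruptedBlobStoreRepositoryIT.java | 3 +- .../action/support/PlainActionFuture.java | 28 ++++++++++ .../support/UnsafePlainActionFuture.java | 52 +++++++++++++++++++ .../internal/support/AbstractClient.java | 9 +++- .../common/util/concurrent/EsExecutors.java | 4 ++ .../index/engine/CompletionStatsCache.java | 4 +- .../elasticsearch/index/engine/Engine.java | 4 +- .../index/shard/IndexShardTestCase.java | 3 +- .../AbstractSimpleTransportTestCase.java | 7 +-- .../shared/SharedBlobCacheService.java | 6 ++- .../xpack/ccr/repository/CcrRepository.java | 7 ++- .../ShardFollowTaskReplicationTests.java | 3 +- .../dataframe/inference/InferenceRunner.java | 6 ++- .../TrainedModelAssignmentNodeService.java | 5 +- .../cache/full/CacheService.java | 3 +- .../xpack/security/Security.java | 3 +- 17 files changed, 135 insertions(+), 16 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/action/support/UnsafePlainActionFuture.java diff --git a/server/src/internalClusterTest/java/org/elasticsearch/discovery/ClusterDisruptionIT.java index 7f94809e64fa6..cd9adea500dbd 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/discovery/ClusterDisruptionIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/discovery/ClusterDisruptionIT.java @@ -16,6 +16,7 @@ import org.elasticsearch.action.get.GetResponse; import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.UnsafePlainActionFuture; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.action.shard.ShardStateAction; @@ -26,6 +27,7 @@ import org.elasticsearch.cluster.routing.Murmur3HashFunction; import org.elasticsearch.cluster.routing.ShardRouting; import org.elasticsearch.cluster.routing.ShardRoutingState; +import org.elasticsearch.cluster.service.ClusterApplierService; import 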
org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.ConcurrentCollections; @@ -542,7 +544,7 @@ public void testRejoinWhileBeingRemoved() { }); final ClusterService dataClusterService = internalCluster().getInstance(ClusterService.class, dataNode); - final PlainActionFuture failedLeader = new PlainActionFuture<>() { + final PlainActionFuture failedLeader = new UnsafePlainActionFuture<>(ClusterApplierService.CLUSTER_UPDATE_THREAD_NAME) { @Override protected boolean blockingAllowed() { // we're deliberately blocking the cluster applier on the master until the data node starts to rejoin diff --git a/server/src/internalClusterTest/java/org/elasticsearch/snapshots/CorruptedBlobStoreRepositoryIT.java b/server/src/internalClusterTest/java/org/elasticsearch/snapshots/CorruptedBlobStoreRepositoryIT.java index f507e27c6073e..9eb9041aa51f1 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/snapshots/CorruptedBlobStoreRepositoryIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/snapshots/CorruptedBlobStoreRepositoryIT.java @@ -299,7 +299,8 @@ public void testHandlingMissingRootLevelSnapshotMetadata() throws Exception { final ThreadPool threadPool = internalCluster().getCurrentMasterNodeInstance(ThreadPool.class); assertThat( PlainActionFuture.get( - f -> threadPool.generic() + // any other executor than generic and management + f -> threadPool.executor(ThreadPool.Names.SNAPSHOT) .execute( ActionRunnable.supply( f, diff --git a/server/src/main/java/org/elasticsearch/action/support/PlainActionFuture.java b/server/src/main/java/org/elasticsearch/action/support/PlainActionFuture.java index e2b8fcbf2825c..938fe4c84480b 100644 --- a/server/src/main/java/org/elasticsearch/action/support/PlainActionFuture.java +++ b/server/src/main/java/org/elasticsearch/action/support/PlainActionFuture.java @@ -9,10 +9,12 @@ package org.elasticsearch.action.support; import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.ActionFuture; import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.service.ClusterApplierService; import org.elasticsearch.cluster.service.MasterService; +import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.FutureUtils; import org.elasticsearch.common.util.concurrent.UncategorizedExecutionException; import org.elasticsearch.core.CheckedConsumer; @@ -37,6 +39,7 @@ public void onResponse(@Nullable T result) { @Override public void onFailure(Exception e) { + assert assertCompleteAllowed(); if (sync.setException(Objects.requireNonNull(e))) { done(false); } @@ -113,6 +116,7 @@ public boolean isCancelled() { @Override public boolean cancel(boolean mayInterruptIfRunning) { + assert assertCompleteAllowed(); if (sync.cancel() == false) { return false; } @@ -130,6 +134,7 @@ public boolean cancel(boolean mayInterruptIfRunning) { * @return true if the state was successfully changed. 
*/ protected final boolean set(@Nullable T value) { + assert assertCompleteAllowed(); boolean result = sync.set(value); if (result) { done(true); @@ -399,4 +404,27 @@ public static T get(CheckedConsumer extends PlainActionFuture { + + private final String unsafeExecutor; + private final String unsafeExecutor2; + + public UnsafePlainActionFuture(String unsafeExecutor) { + this(unsafeExecutor, null); + } + + public UnsafePlainActionFuture(String unsafeExecutor, String unsafeExecutor2) { + Objects.requireNonNull(unsafeExecutor); + this.unsafeExecutor = unsafeExecutor; + this.unsafeExecutor2 = unsafeExecutor2; + } + + @Override + boolean allowedExecutors(Thread thread1, Thread thread2) { + return super.allowedExecutors(thread1, thread2) + || unsafeExecutor.equals(EsExecutors.executorName(thread1)) + || unsafeExecutor2 == null + || unsafeExecutor2.equals(EsExecutors.executorName(thread1)); + } + + public static T get(CheckedConsumer, E> e, String allowedExecutor) throws E { + PlainActionFuture fut = new UnsafePlainActionFuture<>(allowedExecutor); + e.accept(fut); + return fut.actionGet(); + } +} diff --git a/server/src/main/java/org/elasticsearch/client/internal/support/AbstractClient.java b/server/src/main/java/org/elasticsearch/client/internal/support/AbstractClient.java index 966299408a678..f4e86c8a4eca6 100644 --- a/server/src/main/java/org/elasticsearch/client/internal/support/AbstractClient.java +++ b/server/src/main/java/org/elasticsearch/client/internal/support/AbstractClient.java @@ -59,6 +59,7 @@ import org.elasticsearch.action.search.TransportSearchAction; import org.elasticsearch.action.search.TransportSearchScrollAction; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.UnsafePlainActionFuture; import org.elasticsearch.action.termvectors.MultiTermVectorsAction; import org.elasticsearch.action.termvectors.MultiTermVectorsRequest; import org.elasticsearch.action.termvectors.MultiTermVectorsRequestBuilder; @@ -410,7 +411,13 @@ protected void * on the result before it goes out of scope. * @param reference counted result type */ - private static class RefCountedFuture extends PlainActionFuture { + // todo: the use of UnsafePlainActionFuture here is quite broad, we should find a better way to be more specific + // (unless making all usages safe is easy). 
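+ // Illustrative sketch (not part of this change) of the deadlock the new assertion guards against:
+ //   var future = new PlainActionFuture<Void>();
+ //   var executor = threadPool.generic();
+ //   executor.execute(() -> future.onResponse(null)); // completion queued on a GENERIC thread
+ //   executor.execute(future::actionGet); // waiter also on GENERIC; once all GENERIC threads
+ //   // block in actionGet, the completing task can never run and both tasks wait forever.
+ // Passing ThreadPool.Names.GENERIC to the constructor below declares GENERIC as an executor
+ // on which completing this future is nevertheless tolerated.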
+ private static class RefCountedFuture extends UnsafePlainActionFuture { + + private RefCountedFuture() { + super(ThreadPool.Names.GENERIC); + } @Override public final void onResponse(R result) { diff --git a/server/src/main/java/org/elasticsearch/common/util/concurrent/EsExecutors.java b/server/src/main/java/org/elasticsearch/common/util/concurrent/EsExecutors.java index 015d3899ab90d..9bf381e6f4719 100644 --- a/server/src/main/java/org/elasticsearch/common/util/concurrent/EsExecutors.java +++ b/server/src/main/java/org/elasticsearch/common/util/concurrent/EsExecutors.java @@ -276,6 +276,10 @@ public static String executorName(String threadName) { return threadName.substring(executorNameStart + 1, executorNameEnd); } + public static String executorName(Thread thread) { + return executorName(thread.getName()); + } + public static ThreadFactory daemonThreadFactory(Settings settings, String namePrefix) { return daemonThreadFactory(threadName(settings, namePrefix)); } diff --git a/server/src/main/java/org/elasticsearch/index/engine/CompletionStatsCache.java b/server/src/main/java/org/elasticsearch/index/engine/CompletionStatsCache.java index f66b856471894..91eea9f6b1b12 100644 --- a/server/src/main/java/org/elasticsearch/index/engine/CompletionStatsCache.java +++ b/server/src/main/java/org/elasticsearch/index/engine/CompletionStatsCache.java @@ -15,10 +15,12 @@ import org.apache.lucene.search.suggest.document.CompletionTerms; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.UnsafePlainActionFuture; import org.elasticsearch.common.FieldMemoryStats; import org.elasticsearch.common.regex.Regex; import org.elasticsearch.common.util.CollectionUtils; import org.elasticsearch.search.suggest.completion.CompletionStats; +import org.elasticsearch.threadpool.ThreadPool; import java.util.HashMap; import java.util.Map; @@ -42,7 +44,7 @@ public CompletionStatsCache(Supplier searcherSupplier) { } public CompletionStats get(String... 
fieldNamePatterns) { - final PlainActionFuture newFuture = new PlainActionFuture<>(); + final PlainActionFuture newFuture = new UnsafePlainActionFuture<>(ThreadPool.Names.MANAGEMENT); final PlainActionFuture oldFuture = completionStatsFutureRef.compareAndExchange(null, newFuture); if (oldFuture != null) { diff --git a/server/src/main/java/org/elasticsearch/index/engine/Engine.java b/server/src/main/java/org/elasticsearch/index/engine/Engine.java index 65f47dd3994af..c219e16659c99 100644 --- a/server/src/main/java/org/elasticsearch/index/engine/Engine.java +++ b/server/src/main/java/org/elasticsearch/index/engine/Engine.java @@ -36,6 +36,7 @@ import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.action.support.SubscribableListener; +import org.elasticsearch.action.support.UnsafePlainActionFuture; import org.elasticsearch.cluster.service.ClusterApplierService; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.logging.Loggers; @@ -75,6 +76,7 @@ import org.elasticsearch.index.translog.Translog; import org.elasticsearch.index.translog.TranslogStats; import org.elasticsearch.search.suggest.completion.CompletionStats; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.Transports; import java.io.Closeable; @@ -1956,7 +1958,7 @@ private boolean drainForClose() { logger.debug("drainForClose(): draining ops"); releaseEnsureOpenRef.close(); - final var future = new PlainActionFuture() { + final var future = new UnsafePlainActionFuture(ThreadPool.Names.GENERIC) { @Override protected boolean blockingAllowed() { // TODO remove this blocking, or at least do it elsewhere, see https://github.com/elastic/elasticsearch/issues/89821 diff --git a/test/framework/src/main/java/org/elasticsearch/index/shard/IndexShardTestCase.java b/test/framework/src/main/java/org/elasticsearch/index/shard/IndexShardTestCase.java index a389020cdcde8..442a8c3b82dc6 100644 --- a/test/framework/src/main/java/org/elasticsearch/index/shard/IndexShardTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/index/shard/IndexShardTestCase.java @@ -13,6 +13,7 @@ import org.elasticsearch.action.admin.indices.flush.FlushRequest; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.UnsafePlainActionFuture; import org.elasticsearch.action.support.replication.TransportReplicationAction; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.MappingMetadata; @@ -869,7 +870,7 @@ protected final void recoverUnstartedReplica( routingTable ); try { - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture future = new UnsafePlainActionFuture<>(ThreadPool.Names.GENERIC); recovery.recoverToTarget(future); future.actionGet(); recoveryTarget.markAsDone(); diff --git a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java index 3dc7201535e0a..d966a21a56b5f 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java @@ -20,6 +20,7 @@ import org.elasticsearch.action.ActionListenerResponseHandler; import 
org.elasticsearch.action.support.ChannelActionListener; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.UnsafePlainActionFuture; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodeUtils; import org.elasticsearch.cluster.node.VersionInformation; @@ -960,7 +961,7 @@ public void onFailure(Exception e) { protected void doRun() throws Exception { safeAwait(go); for (int iter = 0; iter < 10; iter++) { - PlainActionFuture listener = new PlainActionFuture<>(); + PlainActionFuture listener = new UnsafePlainActionFuture<>(ThreadPool.Names.GENERIC); final String info = sender + "_B_" + iter; serviceB.sendRequest( nodeA, @@ -996,7 +997,7 @@ public void onFailure(Exception e) { protected void doRun() throws Exception { go.await(); for (int iter = 0; iter < 10; iter++) { - PlainActionFuture listener = new PlainActionFuture<>(); + PlainActionFuture listener = new UnsafePlainActionFuture<>(ThreadPool.Names.GENERIC); final String info = sender + "_" + iter; final DiscoveryNode node = nodeB; // capture now try { @@ -3464,7 +3465,7 @@ public static void connectToNode(TransportService service, DiscoveryNode node) t * @param connectionProfile the connection profile to use when connecting to this node */ public static void connectToNode(TransportService service, DiscoveryNode node, ConnectionProfile connectionProfile) { - PlainActionFuture.get(fut -> service.connectToNode(node, connectionProfile, fut.map(x -> null))); + UnsafePlainActionFuture.get(fut -> service.connectToNode(node, connectionProfile, fut.map(x -> null)), ThreadPool.Names.GENERIC); } /** diff --git a/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java b/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java index 481f39d673410..c5ef1d7c2bf1d 100644 --- a/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java +++ b/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java @@ -14,6 +14,7 @@ import org.elasticsearch.action.ActionRunnable; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.action.support.RefCountingListener; +import org.elasticsearch.action.support.UnsafePlainActionFuture; import org.elasticsearch.blobcache.BlobCacheMetrics; import org.elasticsearch.blobcache.BlobCacheUtils; import org.elasticsearch.blobcache.common.ByteRange; @@ -36,6 +37,7 @@ import org.elasticsearch.env.NodeEnvironment; import org.elasticsearch.monitor.fs.FsProbe; import org.elasticsearch.node.NodeRoleSettings; +import org.elasticsearch.repositories.blobstore.BlobStoreRepository; import org.elasticsearch.threadpool.ThreadPool; import java.io.IOException; @@ -1136,7 +1138,9 @@ private int readMultiRegions( int startRegion, int endRegion ) throws InterruptedException, ExecutionException { - final PlainActionFuture readsComplete = new PlainActionFuture<>(); + final PlainActionFuture readsComplete = new UnsafePlainActionFuture<>( + BlobStoreRepository.STATELESS_SHARD_PREWARMING_THREAD_NAME + ); final AtomicInteger bytesRead = new AtomicInteger(); try (var listeners = new RefCountingListener(1, readsComplete)) { for (int region = startRegion; region <= endRegion; region++) { diff --git a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/repository/CcrRepository.java 
b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/repository/CcrRepository.java index baf1509c73883..67c4c769d21d1 100644 --- a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/repository/CcrRepository.java +++ b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/repository/CcrRepository.java @@ -26,6 +26,7 @@ import org.elasticsearch.action.support.ListenerTimeouts; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.action.support.ThreadedActionListener; +import org.elasticsearch.action.support.UnsafePlainActionFuture; import org.elasticsearch.client.internal.Client; import org.elasticsearch.client.internal.RemoteClusterClient; import org.elasticsearch.cluster.ClusterName; @@ -599,7 +600,11 @@ private void updateMappings( Client followerClient, Index followerIndex ) { - final PlainActionFuture indexMetadataFuture = new PlainActionFuture<>(); + // todo: this could manifest in production and seems we could make this async easily. + final PlainActionFuture indexMetadataFuture = new UnsafePlainActionFuture<>( + Ccr.CCR_THREAD_POOL_NAME, + ThreadPool.Names.GENERIC + ); final long startTimeInNanos = System.nanoTime(); final Supplier timeout = () -> { final long elapsedInNanos = System.nanoTime() - startTimeInNanos; diff --git a/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/ShardFollowTaskReplicationTests.java b/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/ShardFollowTaskReplicationTests.java index 3a16f368d322a..04a97ad9e7f95 100644 --- a/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/ShardFollowTaskReplicationTests.java +++ b/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/ShardFollowTaskReplicationTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.action.delete.DeleteRequest; import org.elasticsearch.action.support.ActionTestUtils; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.UnsafePlainActionFuture; import org.elasticsearch.action.support.replication.PostWriteRefresh; import org.elasticsearch.action.support.replication.ReplicationResponse; import org.elasticsearch.action.support.replication.TransportWriteAction; @@ -802,7 +803,7 @@ class CcrAction extends ReplicationAction listener) { - final PlainActionFuture permitFuture = new PlainActionFuture<>(); + final PlainActionFuture permitFuture = new UnsafePlainActionFuture<>(ThreadPool.Names.GENERIC); primary.acquirePrimaryOperationPermit(permitFuture, EsExecutors.DIRECT_EXECUTOR_SERVICE); final TransportWriteAction.WritePrimaryResult ccrResult; final var threadpool = mock(ThreadPool.class); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunner.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunner.java index 637b37853363f..06075363997c7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunner.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunner.java @@ -17,6 +17,7 @@ import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.UnsafePlainActionFuture; import org.elasticsearch.client.internal.Client; import org.elasticsearch.client.internal.OriginSettingClient; import 
org.elasticsearch.common.settings.Settings; @@ -31,6 +32,7 @@ import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.dataframe.DestinationIndex; import org.elasticsearch.xpack.ml.dataframe.stats.DataCountsTracker; import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker; @@ -100,7 +102,9 @@ public void run(String modelId) { LOGGER.info("[{}] Started inference on test data against model [{}]", config.getId(), modelId); try { - PlainActionFuture localModelPlainActionFuture = new PlainActionFuture<>(); + PlainActionFuture localModelPlainActionFuture = new UnsafePlainActionFuture<>( + MachineLearning.UTILITY_THREAD_POOL_NAME + ); modelLoadingService.getModelForInternalInference(modelId, localModelPlainActionFuture); InferenceState inferenceState = restoreInferenceState(); dataCountsTracker.setTestDocsCount(inferenceState.processedTestDocsCount); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java index e181e1fc86684..7052e6f147b36 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java @@ -13,6 +13,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.search.SearchPhaseExecutionException; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.UnsafePlainActionFuture; import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.cluster.ClusterChangedEvent; import org.elasticsearch.cluster.ClusterState; @@ -205,7 +206,9 @@ void loadQueuedModels() { if (stopped) { return; } - final PlainActionFuture listener = new PlainActionFuture<>(); + final PlainActionFuture listener = new UnsafePlainActionFuture<>( + MachineLearning.UTILITY_THREAD_POOL_NAME + ); try { deploymentManager.startDeployment(loadingTask, listener); // This needs to be synchronous here in the utility thread to keep queueing order diff --git a/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/cache/full/CacheService.java b/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/cache/full/CacheService.java index 6e480a21d507a..636d138c8a3e2 100644 --- a/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/cache/full/CacheService.java +++ b/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/cache/full/CacheService.java @@ -10,6 +10,7 @@ import org.apache.logging.log4j.Logger; import org.apache.lucene.store.AlreadyClosedException; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.UnsafePlainActionFuture; import org.elasticsearch.blobcache.common.ByteRange; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.UUIDs; @@ -347,7 +348,7 @@ public void markShardAsEvictedInCache(String snapshotUUID, String snapshotIndexN if (allowShardsEvictions) { final ShardEviction 
shardEviction = new ShardEviction(snapshotUUID, snapshotIndexName, shardId); pendingShardsEvictions.computeIfAbsent(shardEviction, shard -> { - final PlainActionFuture future = new PlainActionFuture<>(); + final PlainActionFuture future = new UnsafePlainActionFuture<>(ThreadPool.Names.GENERIC); threadPool.generic().execute(new AbstractRunnable() { @Override protected void doRun() { diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java index ecfce2f858428..84fa92bb7d2d4 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java @@ -25,6 +25,7 @@ import org.elasticsearch.action.support.ActionFilter; import org.elasticsearch.action.support.DestructiveOperations; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.UnsafePlainActionFuture; import org.elasticsearch.bootstrap.BootstrapCheck; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterState; @@ -2130,7 +2131,7 @@ private void reloadRemoteClusterCredentials(Settings settingsWithKeystore) { return; } - final PlainActionFuture future = new PlainActionFuture<>(); + final PlainActionFuture future = new UnsafePlainActionFuture<>(ThreadPool.Names.GENERIC); getClient().execute( ActionTypes.RELOAD_REMOTE_CLUSTER_CREDENTIALS_ACTION, new TransportReloadRemoteClusterCredentialsAction.Request(settingsWithKeystore), From eff8d0ff9532bc2837797835234d5cd916d8f104 Mon Sep 17 00:00:00 2001 From: Yang Wang Date: Tue, 28 May 2024 18:29:58 +1000 Subject: [PATCH 26/29] [Test] Re-enable debug logging for AWS request (#109068) With the logging restriction (#105020), the networkTrace flag needs to be set for AWS request debug logging.
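For illustration (assuming the restriction from #105020 behaves as described there, i.e. trace-level logging of raw network traffic stays suppressed unless explicitly opted in), the flag is an ordinary JVM system property, so a manually started JVM would enable it with something like:

    -Des.insecure_network_trace_enabled=true

The build.gradle change below sets the equivalent for the internalClusterTest task.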
Relates: #101608 --- modules/repository-s3/build.gradle | 2 ++ .../repositories/s3/S3BlobStoreRepositoryTests.java | 2 ++ 2 files changed, 4 insertions(+) diff --git a/modules/repository-s3/build.gradle b/modules/repository-s3/build.gradle index 8b1f30a1bba61..1732fd39794b9 100644 --- a/modules/repository-s3/build.gradle +++ b/modules/repository-s3/build.gradle @@ -164,6 +164,8 @@ tasks.named("processYamlRestTestResources").configure { tasks.named("internalClusterTest").configure { // this is tested explicitly in a separate test task exclude '**/S3RepositoryThirdPartyTests.class' + // TODO: remove once https://github.com/elastic/elasticsearch/issues/101608 is fixed + systemProperty 'es.insecure_network_trace_enabled', 'true' } tasks.named("yamlRestTest").configure { diff --git a/modules/repository-s3/src/internalClusterTest/java/org/elasticsearch/repositories/s3/S3BlobStoreRepositoryTests.java b/modules/repository-s3/src/internalClusterTest/java/org/elasticsearch/repositories/s3/S3BlobStoreRepositoryTests.java index ebc60e8027d81..030f791feee16 100644 --- a/modules/repository-s3/src/internalClusterTest/java/org/elasticsearch/repositories/s3/S3BlobStoreRepositoryTests.java +++ b/modules/repository-s3/src/internalClusterTest/java/org/elasticsearch/repositories/s3/S3BlobStoreRepositoryTests.java @@ -627,6 +627,8 @@ public void maybeTrack(final String rawRequest, Headers requestHeaders) { trackRequest("HeadObject"); metricsCount.computeIfAbsent(new S3BlobStore.StatsKey(S3BlobStore.Operation.HEAD_OBJECT, purpose), k -> new AtomicLong()) .incrementAndGet(); + } else { + logger.info("--> rawRequest not tracked [{}] with parsed purpose [{}]", request, purpose.getKey()); } } From 8525751dca198cab4dade9714e1a47d214543f85 Mon Sep 17 00:00:00 2001 From: Tim Grein Date: Tue, 28 May 2024 10:42:25 +0200 Subject: [PATCH 27/29] [Inference API] Move completion expectation to ChatCompletionResultsTests and remove code duplication (#109065) --- .../AzureAiStudioActionAndCreatorTests.java | 7 ++----- .../action/cohere/CohereActionCreatorTests.java | 6 +++--- .../action/cohere/CohereCompletionActionTests.java | 13 +++---------- .../action/openai/OpenAiActionCreatorTests.java | 8 ++++---- .../openai/OpenAiChatCompletionActionTests.java | 11 ++--------- .../results/ChatCompletionResultsTests.java | 7 +++++++ 6 files changed, 21 insertions(+), 31 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java index c4878c495a94e..88d408d309a7b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java @@ -20,7 +20,6 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.common.TruncatorTests; -import org.elasticsearch.xpack.inference.external.action.openai.OpenAiChatCompletionActionTests; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import 
org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; @@ -46,6 +45,7 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields; import static org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioRequestFields.API_KEY_HEADER; +import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -163,10 +163,7 @@ public void testChatCompletionRequestAction() throws IOException { var result = listener.actionGet(TIMEOUT); - assertThat( - result.asMap(), - is(OpenAiChatCompletionActionTests.buildExpectedChatCompletionResultMap(List.of("test input string"))) - ); + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("test input string")))); assertThat(webServer.requests(), hasSize(1)); MockRequest request = webServer.requests().get(0); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java index 6ca6985c9e8f7..9b0371ad51f8c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java @@ -40,9 +40,9 @@ import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; -import static org.elasticsearch.xpack.inference.external.action.cohere.CohereCompletionActionTests.buildExpectedChatCompletionResultMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.hamcrest.Matchers.equalTo; @@ -200,7 +200,7 @@ public void testCreate_CohereCompletionModel_WithModelSpecified() throws IOExcep var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(buildExpectedChatCompletionResultMap(List.of("result")))); + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("result")))); assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), is(XContentType.JSON.mediaType())); @@ -260,7 +260,7 @@ public void testCreate_CohereCompletionModel_WithoutModelSpecified() throws IOEx var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(buildExpectedChatCompletionResultMap(List.of("result")))); + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("result")))); assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); 
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), is(XContentType.JSON.mediaType())); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java index 9ac4674ef0b1e..12c3d132d1244 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java @@ -23,7 +23,6 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -44,6 +43,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -120,7 +120,7 @@ public void testExecute_ReturnsSuccessfulResponse_WithModelSpecified() throws IO var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(buildExpectedChatCompletionResultMap(List.of("result")))); + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("result")))); assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); @@ -181,7 +181,7 @@ public void testExecute_ReturnsSuccessfulResponse_WithoutModelSpecified() throws var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(buildExpectedChatCompletionResultMap(List.of("result")))); + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("result")))); assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); @@ -339,13 +339,6 @@ public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOExc } } - public static Map buildExpectedChatCompletionResultMap(List results) { - return Map.of( - ChatCompletionResults.COMPLETION, - results.stream().map(result -> Map.of(ChatCompletionResults.Result.RESULT, result)).toList() - ); - } - private CohereCompletionAction createAction(String url, String apiKey, @Nullable String modelName, Sender sender) { var model = CohereCompletionModelTests.createModel(url, apiKey, modelName); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java index b62e8fc9865e4..496238eaad0e4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java @@ -35,11 +35,11 @@ import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; -import static org.elasticsearch.xpack.inference.external.action.openai.OpenAiChatCompletionActionTests.buildExpectedChatCompletionResultMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields; import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER; +import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createChatCompletionModel; @@ -333,7 +333,7 @@ public void testCreate_OpenAiChatCompletionModel() throws IOException { var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(buildExpectedChatCompletionResultMap(List.of("Hello there, how may I assist you today?")))); + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("Hello there, how may I assist you today?")))); assertThat(webServer.requests(), hasSize(1)); var request = webServer.requests().get(0); @@ -396,7 +396,7 @@ public void testCreate_OpenAiChatCompletionModel_WithoutUser() throws IOExceptio var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(buildExpectedChatCompletionResultMap(List.of("Hello there, how may I assist you today?")))); + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("Hello there, how may I assist you today?")))); assertThat(webServer.requests(), hasSize(1)); var request = webServer.requests().get(0); @@ -458,7 +458,7 @@ public void testCreate_OpenAiChatCompletionModel_WithoutOrganization() throws IO var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(buildExpectedChatCompletionResultMap(List.of("Hello there, how may I assist you today?")))); + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("Hello there, how may I assist you today?")))); assertThat(webServer.requests(), hasSize(1)); var request = webServer.requests().get(0); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java index 35d1ee8fc5a5a..914ff12db259a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java @@ -24,7 +24,6 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; @@ -45,6 +44,7 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER; +import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createChatCompletionModel; import static org.hamcrest.Matchers.containsString; @@ -118,7 +118,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(buildExpectedChatCompletionResultMap(List.of("result content")))); + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("result content")))); assertThat(webServer.requests(), hasSize(1)); MockRequest request = webServer.requests().get(0); @@ -277,13 +277,6 @@ public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOExc } } - public static Map buildExpectedChatCompletionResultMap(List results) { - return Map.of( - ChatCompletionResults.COMPLETION, - results.stream().map(result -> Map.of(ChatCompletionResults.Result.RESULT, result)).toList() - ); - } - private OpenAiChatCompletionAction createAction( String url, String org, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ChatCompletionResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ChatCompletionResultsTests.java index 6bbe6eea5394f..1b9b2db660bf3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ChatCompletionResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ChatCompletionResultsTests.java @@ -125,6 +125,13 @@ public static ChatCompletionResults createRandomResults() { return new ChatCompletionResults(chatCompletionResults); } + public static Map buildExpectationCompletion(List results) { + return Map.of( + ChatCompletionResults.COMPLETION, + results.stream().map(result -> Map.of(ChatCompletionResults.Result.RESULT, result)).toList() + ); + } + private static ChatCompletionResults.Result createRandomChatCompletionResult() { return new ChatCompletionResults.Result(randomAlphaOfLengthBetween(10, 300)); } From 42e5293c049539c5e2416961998a41a740ba9ae3 Mon Sep 17 00:00:00 2001 From: David Turner Date: Tue, 28 May 2024 09:54:04 +0100 Subject: [PATCH 28/29] Capture GC logs alongside heap dumps (#109087) GC logs can be important for understanding a heap dump, especially if there are lots of unreachable objects and the GC is struggling to keep up.
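As a concrete illustration (the exact flags below are an assumption based on JDK unified logging and the defaults shipped in jvm.options, and are not changed by this patch), GC logging with a rotating file sink is typically enabled along these lines:

    -Xlog:gc*,gc+age=trace,safepoint:file=logs/gc.log:utctime,pid,tags:filecount=32,filesize=64m

Keeping a rotating window of GC logs like this makes it likely that the period covered by a heap dump is still on disk when the dump is captured.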
--- .../common-issues/high-jvm-memory-pressure.asciidoc | 3 ++- docs/reference/troubleshooting/network-timeouts.asciidoc | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/reference/troubleshooting/common-issues/high-jvm-memory-pressure.asciidoc b/docs/reference/troubleshooting/common-issues/high-jvm-memory-pressure.asciidoc index e88927f159f21..267d6594b8025 100644 --- a/docs/reference/troubleshooting/common-issues/high-jvm-memory-pressure.asciidoc +++ b/docs/reference/troubleshooting/common-issues/high-jvm-memory-pressure.asciidoc @@ -30,7 +30,8 @@ collection. **Capture a JVM heap dump** To determine the exact reason for the high JVM memory pressure, capture a heap -dump of the JVM while its memory usage is high. +dump of the JVM while its memory usage is high, and also capture the +<> covering the same time period. [discrete] [[reduce-jvm-memory-pressure]] diff --git a/docs/reference/troubleshooting/network-timeouts.asciidoc b/docs/reference/troubleshooting/network-timeouts.asciidoc index 1920dafe62210..ef942ac1d268d 100644 --- a/docs/reference/troubleshooting/network-timeouts.asciidoc +++ b/docs/reference/troubleshooting/network-timeouts.asciidoc @@ -4,8 +4,8 @@ usually by the `JvmMonitorService` in the main node logs. Use these logs to confirm whether or not the node is experiencing high heap usage with long GC pauses. If so, <> has some suggestions for further investigation but typically you -will need to capture a heap dump during a time of high heap usage to fully -understand the problem. +will need to capture a heap dump and the <> +during a time of high heap usage to fully understand the problem. * VM pauses also affect other processes on the same host. A VM pause also typically causes a discontinuity in the system clock, which {es} will report in From f2c2fc0d994f4cc28496aeac4e96821dd3ee66f9 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Tue, 28 May 2024 10:07:57 +0100 Subject: [PATCH 29/29] Add error indexes to all datetime formatter tests (#108944) And sometimes test the java.time implementations too --- .../common/time/DateFormatters.java | 3 + .../common/time/DateFormattersTests.java | 202 +++++++++--------- .../org/elasticsearch/test/ESTestCase.java | 9 + 3 files changed, 118 insertions(+), 96 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/common/time/DateFormatters.java b/server/src/main/java/org/elasticsearch/common/time/DateFormatters.java index 1133eac3f8f7b..55c421b87196d 100644 --- a/server/src/main/java/org/elasticsearch/common/time/DateFormatters.java +++ b/server/src/main/java/org/elasticsearch/common/time/DateFormatters.java @@ -53,6 +53,9 @@ public class DateFormatters { * If a string cannot be parsed by the ISO parser, it then tries the java.time one. * If there's lots of these strings, trying the ISO parser, then the java.time parser, might cause a performance drop. * So provide a JVM option so that users can just use the java.time parsers, if they really need to. + *

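+ * (For example, starting the JVM with {@code -Des.datetime.java_time_parsers=true} opts in to the java.time parsers only.)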
+ * Note that this property is sometimes set by {@code ESTestCase.setTestSysProps} to flip between implementations in tests, + * to ensure both are fully tested */ @UpdateForV9 // evaluate if we need to deprecate/remove this private static final boolean JAVA_TIME_PARSERS_ONLY = Booleans.parseBoolean(System.getProperty("es.datetime.java_time_parsers"), false); diff --git a/server/src/test/java/org/elasticsearch/common/time/DateFormattersTests.java b/server/src/test/java/org/elasticsearch/common/time/DateFormattersTests.java index fa333ddf6b0c7..3b0935e8f7b5c 100644 --- a/server/src/test/java/org/elasticsearch/common/time/DateFormattersTests.java +++ b/server/src/test/java/org/elasticsearch/common/time/DateFormattersTests.java @@ -12,8 +12,10 @@ import org.elasticsearch.common.util.LocaleUtils; import org.elasticsearch.index.mapper.DateFieldMapper; import org.elasticsearch.test.ESTestCase; +import org.hamcrest.Matcher; import java.time.Clock; +import java.time.DateTimeException; import java.time.Instant; import java.time.LocalDateTime; import java.time.ZoneId; @@ -39,12 +41,25 @@ public class DateFormattersTests extends ESTestCase { - private IllegalArgumentException assertParseException(String input, String format) { + private void assertParseException(String input, String format) { DateFormatter javaTimeFormatter = DateFormatter.forPattern(format); IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> javaTimeFormatter.parse(input)); assertThat(e.getMessage(), containsString(input)); assertThat(e.getMessage(), containsString(format)); - return e; + assertThat(e.getCause(), instanceOf(DateTimeException.class)); + } + + private void assertParseException(String input, String format, int errorIndex) { + assertParseException(input, format, equalTo(errorIndex)); + } + + private void assertParseException(String input, String format, Matcher indexMatcher) { + DateFormatter javaTimeFormatter = DateFormatter.forPattern(format); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> javaTimeFormatter.parse(input)); + assertThat(e.getMessage(), containsString(input)); + assertThat(e.getMessage(), containsString(format)); + assertThat(e.getCause(), instanceOf(DateTimeParseException.class)); + assertThat(((DateTimeParseException) e.getCause()).getErrorIndex(), indexMatcher); } private void assertParses(String input, String format) { @@ -698,7 +713,7 @@ public void testPartialTimeParsing() { ES java.time implementation does not suffer from this, but we intentionally not allow parsing timezone without a time part as it is not allowed in iso8601 */ - assertParseException("2016-11-30T+01", "strict_date_optional_time"); + assertParseException("2016-11-30T+01", "strict_date_optional_time", 11); assertParses("2016-11-30T12+01", "strict_date_optional_time"); assertParses("2016-11-30T12:00+01", "strict_date_optional_time"); @@ -792,8 +807,8 @@ public void testDecimalPointParsing() { assertParses("2001-01-01T00:00:00.123Z", javaFormatter); assertParses("2001-01-01T00:00:00,123Z", javaFormatter); - assertParseException("2001-01-01T00:00:00.123,456Z", "strict_date_optional_time"); - assertParseException("2001-01-01T00:00:00.123,456Z", "date_optional_time"); + assertParseException("2001-01-01T00:00:00.123,456Z", "strict_date_optional_time", 23); + assertParseException("2001-01-01T00:00:00.123,456Z", "date_optional_time", 23); // This should fail, but java is ok with this because the field has the same value // 
assertJavaTimeParseException("2001-01-01T00:00:00.123,123Z", "strict_date_optional_time_nanos"); } @@ -911,7 +926,7 @@ public void testFormatsValidParsing() { assertParses("2018-12-31T12:12:12.123456789", "date_hour_minute_second_fraction"); assertParses("2018-12-31T12:12:12.1", "date_hour_minute_second_millis"); assertParses("2018-12-31T12:12:12.123", "date_hour_minute_second_millis"); - assertParseException("2018-12-31T12:12:12.123456789", "date_hour_minute_second_millis"); + assertParseException("2018-12-31T12:12:12.123456789", "date_hour_minute_second_millis", 23); assertParses("2018-12-31T12:12:12.1", "date_hour_minute_second_millis"); assertParses("2018-12-31T12:12:12.1", "date_hour_minute_second_fraction"); @@ -981,11 +996,11 @@ public void testFormatsValidParsing() { assertParses("12:12:12.123", "hour_minute_second_fraction"); assertParses("12:12:12.123456789", "hour_minute_second_fraction"); assertParses("12:12:12.1", "hour_minute_second_fraction"); - assertParseException("12:12:12", "hour_minute_second_fraction"); + assertParseException("12:12:12", "hour_minute_second_fraction", 8); assertParses("12:12:12.123", "hour_minute_second_millis"); - assertParseException("12:12:12.123456789", "hour_minute_second_millis"); + assertParseException("12:12:12.123456789", "hour_minute_second_millis", 12); assertParses("12:12:12.1", "hour_minute_second_millis"); - assertParseException("12:12:12", "hour_minute_second_millis"); + assertParseException("12:12:12", "hour_minute_second_millis", 8); assertParses("2018-128", "ordinal_date"); assertParses("2018-1", "ordinal_date"); @@ -1025,8 +1040,8 @@ public void testFormatsValidParsing() { assertParses("10:15:3.123Z", "time"); assertParses("10:15:3.123+0100", "time"); assertParses("10:15:3.123+01:00", "time"); - assertParseException("10:15:3.1", "time"); - assertParseException("10:15:3Z", "time"); + assertParseException("10:15:3.1", "time", 9); + assertParseException("10:15:3Z", "time", 7); assertParses("10:15:30Z", "time_no_millis"); assertParses("10:15:30+0100", "time_no_millis"); @@ -1043,7 +1058,7 @@ public void testFormatsValidParsing() { assertParses("10:15:3Z", "time_no_millis"); assertParses("10:15:3+0100", "time_no_millis"); assertParses("10:15:3+01:00", "time_no_millis"); - assertParseException("10:15:3", "time_no_millis"); + assertParseException("10:15:3", "time_no_millis", 7); assertParses("T10:15:30.1Z", "t_time"); assertParses("T10:15:30.123Z", "t_time"); @@ -1061,8 +1076,8 @@ public void testFormatsValidParsing() { assertParses("T10:15:3.123Z", "t_time"); assertParses("T10:15:3.123+0100", "t_time"); assertParses("T10:15:3.123+01:00", "t_time"); - assertParseException("T10:15:3.1", "t_time"); - assertParseException("T10:15:3Z", "t_time"); + assertParseException("T10:15:3.1", "t_time", 10); + assertParseException("T10:15:3Z", "t_time", 8); assertParses("T10:15:30Z", "t_time_no_millis"); assertParses("T10:15:30+0100", "t_time_no_millis"); @@ -1076,12 +1091,12 @@ public void testFormatsValidParsing() { assertParses("T10:15:3Z", "t_time_no_millis"); assertParses("T10:15:3+0100", "t_time_no_millis"); assertParses("T10:15:3+01:00", "t_time_no_millis"); - assertParseException("T10:15:3", "t_time_no_millis"); + assertParseException("T10:15:3", "t_time_no_millis", 8); assertParses("2012-W48-6", "week_date"); assertParses("2012-W01-6", "week_date"); assertParses("2012-W1-6", "week_date"); - assertParseException("2012-W1-8", "week_date"); + assertParseException("2012-W1-8", "week_date", 0); assertParses("2012-W48-6T10:15:30.1Z", "week_date_time"); 
assertParses("2012-W48-6T10:15:30.123Z", "week_date_time"); @@ -1135,17 +1150,12 @@ public void testCompositeParsing() { } public void testExceptionWhenCompositeParsingFails() { - assertParseException("2014-06-06T12:01:02.123", "yyyy-MM-dd'T'HH:mm:ss||yyyy-MM-dd'T'HH:mm:ss.SS"); - } - - public void testExceptionErrorIndex() { - Exception e = assertParseException("2024-01-01j", "iso8601||strict_date_optional_time"); - assertThat(((DateTimeParseException) e.getCause()).getErrorIndex(), equalTo(10)); + assertParseException("2014-06-06T12:01:02.123", "yyyy-MM-dd'T'HH:mm:ss||yyyy-MM-dd'T'HH:mm:ss.SS", 19); } public void testStrictParsing() { assertParses("2018W313", "strict_basic_week_date"); - assertParseException("18W313", "strict_basic_week_date"); + assertParseException("18W313", "strict_basic_week_date", 0); assertParses("2018W313T121212.1Z", "strict_basic_week_date_time"); assertParses("2018W313T121212.123Z", "strict_basic_week_date_time"); assertParses("2018W313T121212.123456789Z", "strict_basic_week_date_time"); @@ -1153,52 +1163,52 @@ public void testStrictParsing() { assertParses("2018W313T121212.123+0100", "strict_basic_week_date_time"); assertParses("2018W313T121212.1+01:00", "strict_basic_week_date_time"); assertParses("2018W313T121212.123+01:00", "strict_basic_week_date_time"); - assertParseException("2018W313T12128.123Z", "strict_basic_week_date_time"); - assertParseException("2018W313T12128.123456789Z", "strict_basic_week_date_time"); - assertParseException("2018W313T81212.123Z", "strict_basic_week_date_time"); - assertParseException("2018W313T12812.123Z", "strict_basic_week_date_time"); - assertParseException("2018W313T12812.1Z", "strict_basic_week_date_time"); + assertParseException("2018W313T12128.123Z", "strict_basic_week_date_time", 13); + assertParseException("2018W313T12128.123456789Z", "strict_basic_week_date_time", 13); + assertParseException("2018W313T81212.123Z", "strict_basic_week_date_time", 13); + assertParseException("2018W313T12812.123Z", "strict_basic_week_date_time", 13); + assertParseException("2018W313T12812.1Z", "strict_basic_week_date_time", 13); assertParses("2018W313T121212Z", "strict_basic_week_date_time_no_millis"); assertParses("2018W313T121212+0100", "strict_basic_week_date_time_no_millis"); assertParses("2018W313T121212+01:00", "strict_basic_week_date_time_no_millis"); - assertParseException("2018W313T12128Z", "strict_basic_week_date_time_no_millis"); - assertParseException("2018W313T12128+0100", "strict_basic_week_date_time_no_millis"); - assertParseException("2018W313T12128+01:00", "strict_basic_week_date_time_no_millis"); - assertParseException("2018W313T81212Z", "strict_basic_week_date_time_no_millis"); - assertParseException("2018W313T81212+0100", "strict_basic_week_date_time_no_millis"); - assertParseException("2018W313T81212+01:00", "strict_basic_week_date_time_no_millis"); - assertParseException("2018W313T12812Z", "strict_basic_week_date_time_no_millis"); - assertParseException("2018W313T12812+0100", "strict_basic_week_date_time_no_millis"); - assertParseException("2018W313T12812+01:00", "strict_basic_week_date_time_no_millis"); + assertParseException("2018W313T12128Z", "strict_basic_week_date_time_no_millis", 13); + assertParseException("2018W313T12128+0100", "strict_basic_week_date_time_no_millis", 13); + assertParseException("2018W313T12128+01:00", "strict_basic_week_date_time_no_millis", 13); + assertParseException("2018W313T81212Z", "strict_basic_week_date_time_no_millis", 13); + assertParseException("2018W313T81212+0100", 
"strict_basic_week_date_time_no_millis", 13); + assertParseException("2018W313T81212+01:00", "strict_basic_week_date_time_no_millis", 13); + assertParseException("2018W313T12812Z", "strict_basic_week_date_time_no_millis", 13); + assertParseException("2018W313T12812+0100", "strict_basic_week_date_time_no_millis", 13); + assertParseException("2018W313T12812+01:00", "strict_basic_week_date_time_no_millis", 13); assertParses("2018-12-31", "strict_date"); - assertParseException("10000-12-31", "strict_date"); - assertParseException("2018-8-31", "strict_date"); + assertParseException("10000-12-31", "strict_date", 0); + assertParseException("2018-8-31", "strict_date", 5); assertParses("2018-12-31T12", "strict_date_hour"); - assertParseException("2018-12-31T8", "strict_date_hour"); + assertParseException("2018-12-31T8", "strict_date_hour", 11); assertParses("2018-12-31T12:12", "strict_date_hour_minute"); - assertParseException("2018-12-31T8:3", "strict_date_hour_minute"); + assertParseException("2018-12-31T8:3", "strict_date_hour_minute", 11); assertParses("2018-12-31T12:12:12", "strict_date_hour_minute_second"); - assertParseException("2018-12-31T12:12:1", "strict_date_hour_minute_second"); + assertParseException("2018-12-31T12:12:1", "strict_date_hour_minute_second", 17); assertParses("2018-12-31T12:12:12.1", "strict_date_hour_minute_second_fraction"); assertParses("2018-12-31T12:12:12.123", "strict_date_hour_minute_second_fraction"); assertParses("2018-12-31T12:12:12.123456789", "strict_date_hour_minute_second_fraction"); assertParses("2018-12-31T12:12:12.123", "strict_date_hour_minute_second_millis"); assertParses("2018-12-31T12:12:12.1", "strict_date_hour_minute_second_millis"); assertParses("2018-12-31T12:12:12.1", "strict_date_hour_minute_second_fraction"); - assertParseException("2018-12-31T12:12:12", "strict_date_hour_minute_second_millis"); - assertParseException("2018-12-31T12:12:12", "strict_date_hour_minute_second_fraction"); + assertParseException("2018-12-31T12:12:12", "strict_date_hour_minute_second_millis", 19); + assertParseException("2018-12-31T12:12:12", "strict_date_hour_minute_second_fraction", 19); assertParses("2018-12-31", "strict_date_optional_time"); - assertParseException("2018-12-1", "strict_date_optional_time"); - assertParseException("2018-1-31", "strict_date_optional_time"); - assertParseException("10000-01-31", "strict_date_optional_time"); + assertParseException("2018-12-1", "strict_date_optional_time", 7); + assertParseException("2018-1-31", "strict_date_optional_time", 4); + assertParseException("10000-01-31", "strict_date_optional_time", 4); assertParses("2010-01-05T02:00", "strict_date_optional_time"); assertParses("2018-12-31T10:15:30", "strict_date_optional_time"); assertParses("2018-12-31T10:15:30Z", "strict_date_optional_time"); assertParses("2018-12-31T10:15:30+0100", "strict_date_optional_time"); assertParses("2018-12-31T10:15:30+01:00", "strict_date_optional_time"); - assertParseException("2018-12-31T10:15:3", "strict_date_optional_time"); - assertParseException("2018-12-31T10:5:30", "strict_date_optional_time"); - assertParseException("2018-12-31T9:15:30", "strict_date_optional_time"); + assertParseException("2018-12-31T10:15:3", "strict_date_optional_time", 16); + assertParseException("2018-12-31T10:5:30", "strict_date_optional_time", 13); + assertParseException("2018-12-31T9:15:30", "strict_date_optional_time", 11); assertParses("2015-01-04T00:00Z", "strict_date_optional_time"); assertParses("2018-12-31T10:15:30.1Z", "strict_date_time"); 
assertParses("2018-12-31T10:15:30.123Z", "strict_date_time"); @@ -1210,33 +1220,33 @@ public void testStrictParsing() { assertParses("2018-12-31T10:15:30.11Z", "strict_date_time"); assertParses("2018-12-31T10:15:30.11+0100", "strict_date_time"); assertParses("2018-12-31T10:15:30.11+01:00", "strict_date_time"); - assertParseException("2018-12-31T10:15:3.123Z", "strict_date_time"); - assertParseException("2018-12-31T10:5:30.123Z", "strict_date_time"); - assertParseException("2018-12-31T1:15:30.123Z", "strict_date_time"); + assertParseException("2018-12-31T10:15:3.123Z", "strict_date_time", 17); + assertParseException("2018-12-31T10:5:30.123Z", "strict_date_time", 14); + assertParseException("2018-12-31T1:15:30.123Z", "strict_date_time", 11); assertParses("2018-12-31T10:15:30Z", "strict_date_time_no_millis"); assertParses("2018-12-31T10:15:30+0100", "strict_date_time_no_millis"); assertParses("2018-12-31T10:15:30+01:00", "strict_date_time_no_millis"); - assertParseException("2018-12-31T10:5:30Z", "strict_date_time_no_millis"); - assertParseException("2018-12-31T10:15:3Z", "strict_date_time_no_millis"); - assertParseException("2018-12-31T1:15:30Z", "strict_date_time_no_millis"); + assertParseException("2018-12-31T10:5:30Z", "strict_date_time_no_millis", 14); + assertParseException("2018-12-31T10:15:3Z", "strict_date_time_no_millis", 17); + assertParseException("2018-12-31T1:15:30Z", "strict_date_time_no_millis", 11); assertParses("12", "strict_hour"); assertParses("01", "strict_hour"); - assertParseException("1", "strict_hour"); + assertParseException("1", "strict_hour", 0); assertParses("12:12", "strict_hour_minute"); assertParses("12:01", "strict_hour_minute"); - assertParseException("12:1", "strict_hour_minute"); + assertParseException("12:1", "strict_hour_minute", 3); assertParses("12:12:12", "strict_hour_minute_second"); assertParses("12:12:01", "strict_hour_minute_second"); - assertParseException("12:12:1", "strict_hour_minute_second"); + assertParseException("12:12:1", "strict_hour_minute_second", 6); assertParses("12:12:12.123", "strict_hour_minute_second_fraction"); assertParses("12:12:12.123456789", "strict_hour_minute_second_fraction"); assertParses("12:12:12.1", "strict_hour_minute_second_fraction"); - assertParseException("12:12:12", "strict_hour_minute_second_fraction"); + assertParseException("12:12:12", "strict_hour_minute_second_fraction", 8); assertParses("12:12:12.123", "strict_hour_minute_second_millis"); assertParses("12:12:12.1", "strict_hour_minute_second_millis"); - assertParseException("12:12:12", "strict_hour_minute_second_millis"); + assertParseException("12:12:12", "strict_hour_minute_second_millis", 8); assertParses("2018-128", "strict_ordinal_date"); - assertParseException("2018-1", "strict_ordinal_date"); + assertParseException("2018-1", "strict_ordinal_date", 5); assertParses("2018-128T10:15:30.1Z", "strict_ordinal_date_time"); assertParses("2018-128T10:15:30.123Z", "strict_ordinal_date_time"); @@ -1245,23 +1255,23 @@ public void testStrictParsing() { assertParses("2018-128T10:15:30.123+0100", "strict_ordinal_date_time"); assertParses("2018-128T10:15:30.1+01:00", "strict_ordinal_date_time"); assertParses("2018-128T10:15:30.123+01:00", "strict_ordinal_date_time"); - assertParseException("2018-1T10:15:30.123Z", "strict_ordinal_date_time"); + assertParseException("2018-1T10:15:30.123Z", "strict_ordinal_date_time", 5); assertParses("2018-128T10:15:30Z", "strict_ordinal_date_time_no_millis"); assertParses("2018-128T10:15:30+0100", 
"strict_ordinal_date_time_no_millis"); assertParses("2018-128T10:15:30+01:00", "strict_ordinal_date_time_no_millis"); - assertParseException("2018-1T10:15:30Z", "strict_ordinal_date_time_no_millis"); + assertParseException("2018-1T10:15:30Z", "strict_ordinal_date_time_no_millis", 5); assertParses("10:15:30.1Z", "strict_time"); assertParses("10:15:30.123Z", "strict_time"); assertParses("10:15:30.123456789Z", "strict_time"); assertParses("10:15:30.123+0100", "strict_time"); assertParses("10:15:30.123+01:00", "strict_time"); - assertParseException("1:15:30.123Z", "strict_time"); - assertParseException("10:1:30.123Z", "strict_time"); - assertParseException("10:15:3.123Z", "strict_time"); - assertParseException("10:15:3.1", "strict_time"); - assertParseException("10:15:3Z", "strict_time"); + assertParseException("1:15:30.123Z", "strict_time", 0); + assertParseException("10:1:30.123Z", "strict_time", 3); + assertParseException("10:15:3.123Z", "strict_time", 6); + assertParseException("10:15:3.1", "strict_time", 6); + assertParseException("10:15:3Z", "strict_time", 6); assertParses("10:15:30Z", "strict_time_no_millis"); assertParses("10:15:30+0100", "strict_time_no_millis"); @@ -1269,10 +1279,10 @@ public void testStrictParsing() { assertParses("01:15:30Z", "strict_time_no_millis"); assertParses("01:15:30+0100", "strict_time_no_millis"); assertParses("01:15:30+01:00", "strict_time_no_millis"); - assertParseException("1:15:30Z", "strict_time_no_millis"); - assertParseException("10:5:30Z", "strict_time_no_millis"); - assertParseException("10:15:3Z", "strict_time_no_millis"); - assertParseException("10:15:3", "strict_time_no_millis"); + assertParseException("1:15:30Z", "strict_time_no_millis", 0); + assertParseException("10:5:30Z", "strict_time_no_millis", 3); + assertParseException("10:15:3Z", "strict_time_no_millis", 6); + assertParseException("10:15:3", "strict_time_no_millis", 6); assertParses("T10:15:30.1Z", "strict_t_time"); assertParses("T10:15:30.123Z", "strict_t_time"); @@ -1281,28 +1291,28 @@ public void testStrictParsing() { assertParses("T10:15:30.123+0100", "strict_t_time"); assertParses("T10:15:30.1+01:00", "strict_t_time"); assertParses("T10:15:30.123+01:00", "strict_t_time"); - assertParseException("T1:15:30.123Z", "strict_t_time"); - assertParseException("T10:1:30.123Z", "strict_t_time"); - assertParseException("T10:15:3.123Z", "strict_t_time"); - assertParseException("T10:15:3.1", "strict_t_time"); - assertParseException("T10:15:3Z", "strict_t_time"); + assertParseException("T1:15:30.123Z", "strict_t_time", 1); + assertParseException("T10:1:30.123Z", "strict_t_time", 4); + assertParseException("T10:15:3.123Z", "strict_t_time", 7); + assertParseException("T10:15:3.1", "strict_t_time", 7); + assertParseException("T10:15:3Z", "strict_t_time", 7); assertParses("T10:15:30Z", "strict_t_time_no_millis"); assertParses("T10:15:30+0100", "strict_t_time_no_millis"); assertParses("T10:15:30+01:00", "strict_t_time_no_millis"); - assertParseException("T1:15:30Z", "strict_t_time_no_millis"); - assertParseException("T10:1:30Z", "strict_t_time_no_millis"); - assertParseException("T10:15:3Z", "strict_t_time_no_millis"); - assertParseException("T10:15:3", "strict_t_time_no_millis"); + assertParseException("T1:15:30Z", "strict_t_time_no_millis", 1); + assertParseException("T10:1:30Z", "strict_t_time_no_millis", 4); + assertParseException("T10:15:3Z", "strict_t_time_no_millis", 7); + assertParseException("T10:15:3", "strict_t_time_no_millis", 7); assertParses("2012-W48-6", "strict_week_date"); 
assertParses("2012-W01-6", "strict_week_date"); - assertParseException("2012-W1-6", "strict_week_date"); - assertParseException("2012-W1-8", "strict_week_date"); + assertParseException("2012-W1-6", "strict_week_date", 6); + assertParseException("2012-W1-8", "strict_week_date", 6); assertParses("2012-W48-6", "strict_week_date"); assertParses("2012-W01-6", "strict_week_date"); - assertParseException("2012-W1-6", "strict_week_date"); + assertParseException("2012-W1-6", "strict_week_date", 6); assertParseException("2012-W01-8", "strict_week_date"); assertParses("2012-W48-6T10:15:30.1Z", "strict_week_date_time"); @@ -1312,38 +1322,38 @@ public void testStrictParsing() { assertParses("2012-W48-6T10:15:30.123+0100", "strict_week_date_time"); assertParses("2012-W48-6T10:15:30.1+01:00", "strict_week_date_time"); assertParses("2012-W48-6T10:15:30.123+01:00", "strict_week_date_time"); - assertParseException("2012-W1-6T10:15:30.123Z", "strict_week_date_time"); + assertParseException("2012-W1-6T10:15:30.123Z", "strict_week_date_time", 6); assertParses("2012-W48-6T10:15:30Z", "strict_week_date_time_no_millis"); assertParses("2012-W48-6T10:15:30+0100", "strict_week_date_time_no_millis"); assertParses("2012-W48-6T10:15:30+01:00", "strict_week_date_time_no_millis"); - assertParseException("2012-W1-6T10:15:30Z", "strict_week_date_time_no_millis"); + assertParseException("2012-W1-6T10:15:30Z", "strict_week_date_time_no_millis", 6); assertParses("2012", "strict_year"); - assertParseException("1", "strict_year"); + assertParseException("1", "strict_year", 0); assertParses("-2000", "strict_year"); assertParses("2012-12", "strict_year_month"); - assertParseException("1-1", "strict_year_month"); + assertParseException("1-1", "strict_year_month", 0); assertParses("2012-12-31", "strict_year_month_day"); - assertParseException("1-12-31", "strict_year_month_day"); - assertParseException("2012-1-31", "strict_year_month_day"); - assertParseException("2012-12-1", "strict_year_month_day"); + assertParseException("1-12-31", "strict_year_month_day", 0); + assertParseException("2012-1-31", "strict_year_month_day", 4); + assertParseException("2012-12-1", "strict_year_month_day", 7); assertParses("2018", "strict_weekyear"); - assertParseException("1", "strict_weekyear"); + assertParseException("1", "strict_weekyear", 0); assertParses("2018", "strict_weekyear"); assertParses("2017", "strict_weekyear"); - assertParseException("1", "strict_weekyear"); + assertParseException("1", "strict_weekyear", 0); assertParses("2018-W29", "strict_weekyear_week"); assertParses("2018-W01", "strict_weekyear_week"); - assertParseException("2018-W1", "strict_weekyear_week"); + assertParseException("2018-W1", "strict_weekyear_week", 6); assertParses("2012-W31-5", "strict_weekyear_week_day"); - assertParseException("2012-W1-1", "strict_weekyear_week_day"); + assertParseException("2012-W1-1", "strict_weekyear_week_day", 6); } public void testDateFormatterWithLocale() { diff --git a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java index eeb94beff04d5..14269a8835f57 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java @@ -374,6 +374,15 @@ private static void setTestSysProps() { // We have to disable setting the number of available processors as tests in the same JVM randomize processors and will step on each // other if we allow them to set the number of available 
processors as it's set-once in Netty.
         System.setProperty("es.set.netty.runtime.available.processors", "false");
+
+        // sometimes use the java.time date formatters instead of the default ones
+        // we can't use randomBoolean() here because the randomization context isn't set up yet,
+        // so read the parity of the test seed directly, in an unfortunately hacky way
+        String testSeed = System.getProperty("tests.seed", "0");
+        boolean seedIsOdd = (Integer.parseInt(testSeed.substring(testSeed.length() - 1), 16) & 1) == 1;
+        if (seedIsOdd) {
+            System.setProperty("es.datetime.java_time_parsers", "true");
+        }
     }
 
     protected final Logger logger = LogManager.getLogger(getClass());
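
A note on the seed check in the `ESTestCase` hunk above: the randomized runner exposes `tests.seed` as a hex string, so the low bit of its final hex digit equals the parity of the whole seed value. That yields a stable, roughly 50/50 split across test seeds without touching the not-yet-initialized randomization context. A minimal standalone sketch of the same trick (the seed values are made up for illustration; this class is not part of the patch):

    public class SeedParityDemo {
        public static void main(String[] args) {
            // Hypothetical seeds in the hex format used by the randomized runner.
            for (String seed : new String[] { "C2E4A19D2A43BC0", "7", "DEADBEEF" }) {
                // The low bit of the last hex digit is the parity of the whole seed.
                int lastDigit = Integer.parseInt(seed.substring(seed.length() - 1), 16);
                boolean odd = (lastDigit & 1) == 1;
                System.out.println(seed + " -> " + (odd ? "java.time parsers" : "default parsers"));
            }
        }
    }

Because the value is derived only from the seed, the choice reproduces exactly when a failing test is re-run with the same `-Dtests.seed`.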
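
Similarly, the integer arguments added throughout the `DateFormattersTests` changes are the zero-based offsets reported by `DateTimeParseException#getErrorIndex()` at the point where parsing stops, which the reworked `assertParseException` now verifies directly. A self-contained sketch of where such an index comes from (the pattern is an illustrative stand-in for ES's `strict_date`, which the patch asserts fails at index 5 for this input):

    import java.time.format.DateTimeFormatter;
    import java.time.format.DateTimeParseException;

    public class ErrorIndexDemo {
        public static void main(String[] args) {
            // Strict-style pattern: the month must be exactly two digits.
            DateTimeFormatter strictDate = DateTimeFormatter.ofPattern("uuuu-MM-dd");
            try {
                strictDate.parse("2018-8-31"); // single-digit month
            } catch (DateTimeParseException e) {
                // Prints 5: the offset of the '8' where the two-digit month parse failed.
                System.out.println("failed at index " + e.getErrorIndex());
            }
        }
    }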