diff --git a/server/src/main/java/org/elasticsearch/index/mapper/DocumentParser.java b/server/src/main/java/org/elasticsearch/index/mapper/DocumentParser.java
index 1fda9ababfabd..5793e1f32808f 100644
--- a/server/src/main/java/org/elasticsearch/index/mapper/DocumentParser.java
+++ b/server/src/main/java/org/elasticsearch/index/mapper/DocumentParser.java
@@ -696,6 +696,10 @@ private static void failIfMatchesRoutingPath(DocumentParserContext context, Stri
      */
     private static void parseCopyFields(DocumentParserContext context, List<String> copyToFields) throws IOException {
         for (String field : copyToFields) {
+            if (context.mappingLookup().inferenceFields().get(field) != null) {
+                // Ignore copy_to that targets inference fields; values are already extracted on the coordinating node to perform inference.
+                continue;
+            }
             // In case of a hierarchy of nested documents, we need to figure out
             // which document the field should go to
             LuceneDocument targetDoc = null;
diff --git a/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java
index 21a0c4d393a23..181e6a5ed8ce0 100644
--- a/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java
+++ b/server/src/main/java/org/elasticsearch/index/mapper/FieldMapper.java
@@ -1199,7 +1199,7 @@ public static final class Conflicts {
         private final String mapperName;
         private final List<String> conflicts = new ArrayList<>();
 
-        Conflicts(String mapperName) {
+        public Conflicts(String mapperName) {
             this.mapperName = mapperName;
         }
 
@@ -1211,7 +1211,7 @@ void addConflict(String parameter, String existing, String toMerge) {
             conflicts.add("Cannot update parameter [" + parameter + "] from [" + existing + "] to [" + toMerge + "]");
         }
 
-        void check() {
+        public void check() {
             if (conflicts.isEmpty()) {
                 return;
             }
diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MapperMergeContext.java b/server/src/main/java/org/elasticsearch/index/mapper/MapperMergeContext.java
index 1e3f69baf86dd..48e04a938d2b2 100644
--- a/server/src/main/java/org/elasticsearch/index/mapper/MapperMergeContext.java
+++ b/server/src/main/java/org/elasticsearch/index/mapper/MapperMergeContext.java
@@ -55,7 +55,7 @@ public static MapperMergeContext from(MapperBuilderContext mapperBuilderContext,
      * @param name the name of the child context
      * @return a new {@link MapperMergeContext} with this context as its parent
      */
-    MapperMergeContext createChildContext(String name, ObjectMapper.Dynamic dynamic) {
+    public MapperMergeContext createChildContext(String name, ObjectMapper.Dynamic dynamic) {
         return createChildContext(mapperBuilderContext.createChildContext(name, dynamic));
     }
 
@@ -69,7 +69,7 @@ MapperMergeContext createChildContext(MapperBuilderContext childContext) {
         return new MapperMergeContext(childContext, newFieldsBudget);
     }
 
-    MapperBuilderContext getMapperBuilderContext() {
+    public MapperBuilderContext getMapperBuilderContext() {
         return mapperBuilderContext;
     }
 
diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java
index b2ec1612ea3da..5159a76206ef6 100644
--- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java
+++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java
@@ -230,6 +230,16 @@ protected Parameter<?>[] getParameters() {
         return new Parameter<?>[] { elementType, dims, indexed, similarity, indexOptions, meta };
} + public Builder similarity(VectorSimilarity vectorSimilarity) { + similarity.setValue(vectorSimilarity); + return this; + } + + public Builder dimensions(int dimensions) { + this.dims.setValue(dimensions); + return this; + } + @Override public DenseVectorFieldMapper build(MapperBuilderContext context) { return new DenseVectorFieldMapper( @@ -754,7 +764,7 @@ public static ElementType fromString(String name) { ElementType.FLOAT ); - enum VectorSimilarity { + public enum VectorSimilarity { L2_NORM { @Override float score(float similarity, ElementType elementType, int dim) { diff --git a/server/src/test/java/org/elasticsearch/index/mapper/CopyToMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/CopyToMapperTests.java index 5eacfe6f2e3ab..33341e6b36987 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/CopyToMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/CopyToMapperTests.java @@ -22,6 +22,7 @@ import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.Set; import static org.elasticsearch.xcontent.XContentFactory.jsonBuilder; import static org.hamcrest.Matchers.equalTo; @@ -106,6 +107,12 @@ public void testCopyToFieldsParsing() throws Exception { fieldMapper = mapperService.documentMapper().mappers().getMapper("new_field"); assertThat(fieldMapper.typeName(), equalTo("long")); + + MappingLookup mappingLookup = mapperService.mappingLookup(); + assertThat(mappingLookup.sourcePaths("another_field"), equalTo(Set.of("copy_test", "int_to_str_test", "another_field"))); + assertThat(mappingLookup.sourcePaths("new_field"), equalTo(Set.of("new_field", "int_to_str_test"))); + assertThat(mappingLookup.sourcePaths("copy_test"), equalTo(Set.of("copy_test", "cyclic_test"))); + assertThat(mappingLookup.sourcePaths("cyclic_test"), equalTo(Set.of("cyclic_test", "copy_test"))); } public void testCopyToFieldsInnerObjectParsing() throws Exception { diff --git a/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupInferenceFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupInferenceFieldMapperTests.java new file mode 100644 index 0000000000000..b13f6b79b0de5 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupInferenceFieldMapperTests.java @@ -0,0 +1,160 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. 
+ */
+
+package org.elasticsearch.index.mapper;
+
+import org.apache.lucene.search.Query;
+import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
+import org.elasticsearch.index.query.SearchExecutionContext;
+import org.elasticsearch.plugins.MapperPlugin;
+import org.elasticsearch.plugins.Plugin;
+
+import java.io.IOException;
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import static org.hamcrest.Matchers.arrayContaining;
+import static org.hamcrest.Matchers.arrayContainingInAnyOrder;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
+
+public class MappingLookupInferenceFieldMapperTests extends MapperServiceTestCase {
+
+    @Override
+    protected Collection<? extends Plugin> getPlugins() {
+        return List.of(new TestInferenceFieldMapperPlugin());
+    }
+
+    public void testInferenceFieldMapper() throws Exception {
+        MapperService mapperService = createMapperService(mapping(b -> {
+            b.startObject("non_inference_field").field("type", "text").endObject();
+            b.startObject("another_non_inference_field").field("type", "text").endObject();
+            b.startObject("inference_field").field("type", TestInferenceFieldMapper.CONTENT_TYPE).endObject();
+            b.startObject("another_inference_field").field("type", TestInferenceFieldMapper.CONTENT_TYPE).endObject();
+        }));
+
+        Map<String, InferenceFieldMetadata> inferenceFieldMetadataMap = mapperService.mappingLookup().inferenceFields();
+        assertThat(inferenceFieldMetadataMap.keySet(), hasSize(2));
+
+        InferenceFieldMetadata inferenceFieldMetadata = inferenceFieldMetadataMap.get("inference_field");
+        assertThat(inferenceFieldMetadata.getInferenceId(), equalTo(TestInferenceFieldMapper.INFERENCE_ID));
+        assertThat(inferenceFieldMetadata.getSourceFields(), arrayContaining("inference_field"));
+
+        inferenceFieldMetadata = inferenceFieldMetadataMap.get("another_inference_field");
+        assertThat(inferenceFieldMetadata.getInferenceId(), equalTo(TestInferenceFieldMapper.INFERENCE_ID));
+        assertThat(inferenceFieldMetadata.getSourceFields(), arrayContaining("another_inference_field"));
+    }
+
+    public void testInferenceFieldMapperWithCopyTo() throws Exception {
+        MapperService mapperService = createMapperService(mapping(b -> {
+            b.startObject("non_inference_field");
+            {
+                b.field("type", "text");
+                b.array("copy_to", "inference_field");
+            }
+            b.endObject();
+            b.startObject("another_non_inference_field");
+            {
+                b.field("type", "text");
+                b.array("copy_to", "inference_field");
+            }
+            b.endObject();
+            b.startObject("inference_field").field("type", TestInferenceFieldMapper.CONTENT_TYPE).endObject();
+            b.startObject("independent_field").field("type", "text").endObject();
+        }));
+
+        Map<String, InferenceFieldMetadata> inferenceFieldMetadataMap = mapperService.mappingLookup().inferenceFields();
+        assertThat(inferenceFieldMetadataMap.keySet(), hasSize(1));
+
+        InferenceFieldMetadata inferenceFieldMetadata = inferenceFieldMetadataMap.get("inference_field");
+        assertThat(inferenceFieldMetadata.getInferenceId(), equalTo(TestInferenceFieldMapper.INFERENCE_ID));
+        assertThat(
+            inferenceFieldMetadata.getSourceFields(),
+            arrayContainingInAnyOrder("another_non_inference_field", "inference_field", "non_inference_field")
+        );
+    }
+
+    private static class TestInferenceFieldMapperPlugin extends Plugin implements MapperPlugin {
+
+        @Override
+        public Map<String, Mapper.TypeParser> getMappers() {
+            return Map.of(TestInferenceFieldMapper.CONTENT_TYPE, TestInferenceFieldMapper.PARSER);
+        }
+    }
+
+    private static class TestInferenceFieldMapper extends FieldMapper implements InferenceFieldMapper {
+
+        public static final TypeParser PARSER = new TypeParser((n, c) -> new Builder(n));
+        public static final String INFERENCE_ID = "test_inference_id";
+        public static final String CONTENT_TYPE = "test_inference_field";
+
+        TestInferenceFieldMapper(String simpleName) {
+            super(simpleName, new TestInferenceFieldMapperFieldType(simpleName), MultiFields.empty(), CopyTo.empty());
+        }
+
+        @Override
+        public InferenceFieldMetadata getMetadata(Set<String> sourcePaths) {
+            return new InferenceFieldMetadata(name(), INFERENCE_ID, sourcePaths.toArray(new String[0]));
+        }
+
+        @Override
+        protected void parseCreateField(DocumentParserContext context) throws IOException {}
+
+        @Override
+        public Builder getMergeBuilder() {
+            return new Builder(simpleName());
+        }
+
+        @Override
+        protected String contentType() {
+            return CONTENT_TYPE;
+        }
+
+        public static class Builder extends FieldMapper.Builder {
+
+            @Override
+            protected Parameter<?>[] getParameters() {
+                return new Parameter<?>[0];
+            }
+
+            Builder(String name) {
+                super(name);
+            }
+
+            @Override
+            public FieldMapper build(MapperBuilderContext context) {
+                return new TestInferenceFieldMapper(name());
+            }
+        }
+
+        private static class TestInferenceFieldMapperFieldType extends MappedFieldType {
+
+            TestInferenceFieldMapperFieldType(String name) {
+                super(name, false, false, false, TextSearchInfo.NONE, Map.of());
+            }
+
+            @Override
+            public ValueFetcher valueFetcher(SearchExecutionContext context, String format) {
+                return null;
+            }
+
+            @Override
+            public String typeName() {
+                return CONTENT_TYPE;
+            }
+
+            @Override
+            public Query termQuery(Object value, SearchExecutionContext context) {
+                return null;
+            }
+        }
+    }
+
+}
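
The two tests above pin down the contract that ties copy_to to inference fields: MappingLookup.sourcePaths(field) yields the field itself plus every field that copies into it, and InferenceFieldMapper.getMetadata(sourcePaths) records that set as the InferenceFieldMetadata source fields. The following is a minimal, self-contained sketch of that resolution in plain Java — the field names are hypothetical and it deliberately ignores multi-fields, which in the real MappingLookup resolve to their parent's source path instead:

import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class SourcePathsSketch {
    /** For each target field, collect the target itself plus every field declaring copy_to -> target. */
    static Map<String, Set<String>> resolveSourcePaths(Map<String, List<String>> copyToByField) {
        Map<String, Set<String>> sourcePaths = new HashMap<>();
        // Every field is a source of itself.
        copyToByField.keySet().forEach(f -> sourcePaths.computeIfAbsent(f, k -> new HashSet<>(Set.of(k))));
        // Each copy_to declaration adds the declaring field as a source of the target.
        copyToByField.forEach((source, targets) -> {
            for (String target : targets) {
                sourcePaths.computeIfAbsent(target, k -> new HashSet<>(Set.of(k))).add(source);
            }
        });
        return sourcePaths;
    }

    public static void main(String[] args) {
        // Mirrors testInferenceFieldMapperWithCopyTo: two text fields copy into one inference field.
        Map<String, List<String>> mapping = Map.of(
            "non_inference_field", List.of("inference_field"),
            "another_non_inference_field", List.of("inference_field"),
            "inference_field", List.of(),
            "independent_field", List.of()
        );
        System.out.println(resolveSourcePaths(mapping).get("inference_field"));
    }
}

Running it prints the same three-element set that testInferenceFieldMapperWithCopyTo asserts for inference_field.
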
diff --git a/server/src/test/java/org/elasticsearch/index/mapper/MultiFieldTests.java b/server/src/test/java/org/elasticsearch/index/mapper/MultiFieldTests.java
index d7df41131414e..6446033c07c5b 100644
--- a/server/src/test/java/org/elasticsearch/index/mapper/MultiFieldTests.java
+++ b/server/src/test/java/org/elasticsearch/index/mapper/MultiFieldTests.java
@@ -224,6 +224,9 @@ public void testSourcePathFields() throws IOException {
         final Set<String> fieldsUsingSourcePath = new HashSet<>();
         ((FieldMapper) mapper).sourcePathUsedBy().forEachRemaining(mapper1 -> fieldsUsingSourcePath.add(mapper1.name()));
         assertThat(fieldsUsingSourcePath, equalTo(Set.of("field.subfield1", "field.subfield2")));
+
+        assertThat(mapperService.mappingLookup().sourcePaths("field.subfield1"), equalTo(Set.of("field")));
+        assertThat(mapperService.mappingLookup().sourcePaths("field.subfield2"), equalTo(Set.of("field")));
     }
 
     public void testUnknownLegacyFieldsUnderKnownRootField() throws Exception {
diff --git a/test/framework/src/main/java/org/elasticsearch/cluster/metadata/DataStreamTestHelper.java b/test/framework/src/main/java/org/elasticsearch/cluster/metadata/DataStreamTestHelper.java
index e6252e46a12a3..bbbafef514e30 100644
--- a/test/framework/src/main/java/org/elasticsearch/cluster/metadata/DataStreamTestHelper.java
+++ b/test/framework/src/main/java/org/elasticsearch/cluster/metadata/DataStreamTestHelper.java
@@ -641,7 +641,7 @@ public static MetadataRolloverService getMetadataRolloverService(
         AllocationService allocationService = mock(AllocationService.class);
         when(allocationService.reroute(any(ClusterState.class), any(String.class), any())).then(i -> i.getArguments()[0]);
         when(allocationService.getShardRoutingRoleStrategy()).thenReturn(TestShardRoutingRoleStrategies.DEFAULT_ROLE_ONLY);
-        MappingLookup mappingLookup = null;
+        MappingLookup mappingLookup = MappingLookup.EMPTY;
         if (dataStream !=
null) { RootObjectMapper.Builder root = new RootObjectMapper.Builder("_doc", ObjectMapper.Defaults.SUBOBJECTS); root.add( @@ -721,6 +721,7 @@ public static IndicesService mockIndicesServices(MappingLookup mappingLookup) th when(documentMapper.mapping()).thenReturn(mapping); when(documentMapper.mappers()).thenReturn(MappingLookup.EMPTY); when(documentMapper.mappingSource()).thenReturn(mapping.toCompressedXContent()); + when(documentMapper.mappers()).thenReturn(mappingLookup); RoutingFieldMapper routingFieldMapper = mock(RoutingFieldMapper.class); when(routingFieldMapper.required()).thenReturn(false); when(documentMapper.routingFieldMapper()).thenReturn(routingFieldMapper); diff --git a/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperTestCase.java b/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperTestCase.java index 097c23b96bb76..5f60e0eedbf03 100644 --- a/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperTestCase.java @@ -1030,7 +1030,7 @@ public final void testMinimalIsInvalidInRoutingPath() throws IOException { } } - private String minimalIsInvalidRoutingPathErrorMessage(Mapper mapper) { + protected String minimalIsInvalidRoutingPathErrorMessage(Mapper mapper) { if (mapper instanceof FieldMapper fieldMapper && fieldMapper.fieldType().isDimension() == false) { return "All fields that match routing_path must be configured with [time_series_dimension: true] " + "or flattened fields with a list of dimensions in [time_series_dimensions] and " diff --git a/x-pack/plugin/inference/build.gradle b/x-pack/plugin/inference/build.gradle index 2c473517e5aab..0aef8601ffcc6 100644 --- a/x-pack/plugin/inference/build.gradle +++ b/x-pack/plugin/inference/build.gradle @@ -6,6 +6,13 @@ */ apply plugin: 'elasticsearch.internal-es-plugin' apply plugin: 'elasticsearch.internal-cluster-test' +apply plugin: 'elasticsearch.internal-yaml-rest-test' + +restResources { + restApi { + include '_common', 'indices', 'inference', 'index' + } +} esplugin { name 'x-pack-inference' @@ -24,6 +31,11 @@ dependencies { compileOnly project(path: xpackModule('core')) testImplementation(testArtifact(project(xpackModule('core')))) testImplementation project(':modules:reindex') + clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin') api "com.ibm.icu:icu4j:${versions.icu4j}" } + +tasks.named('yamlRestTest') { + usesDefaultDistribution() +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index f41f9a97cec18..1afe3c891db80 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -21,11 +21,13 @@ import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.TimeValue; import org.elasticsearch.features.NodeFeature; +import org.elasticsearch.index.mapper.Mapper; import org.elasticsearch.indices.SystemIndexDescriptor; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.ExtensiblePlugin; +import org.elasticsearch.plugins.MapperPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.SystemIndexPlugin; 
 import org.elasticsearch.rest.RestController;
@@ -50,6 +52,7 @@
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceSettings;
 import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
 import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceModelAction;
 import org.elasticsearch.xpack.inference.rest.RestGetInferenceModelAction;
@@ -67,12 +70,13 @@
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.List;
+import java.util.Map;
 import java.util.function.Predicate;
 import java.util.function.Supplier;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
-public class InferencePlugin extends Plugin implements ActionPlugin, ExtensiblePlugin, SystemIndexPlugin {
+public class InferencePlugin extends Plugin implements ActionPlugin, ExtensiblePlugin, SystemIndexPlugin, MapperPlugin {
 
     /**
      * When this setting is true the verification check that
@@ -260,4 +264,12 @@ public void close() {
         IOUtils.closeWhileHandlingException(inferenceServiceRegistry.get(), throttlerToClose);
     }
+
+    @Override
+    public Map<String, Mapper.TypeParser> getMappers() {
+        if (SemanticTextFeature.isEnabled()) {
+            return Map.of(SemanticTextFieldMapper.CONTENT_TYPE, SemanticTextFieldMapper.PARSER);
+        }
+        return Map.of();
+    }
 }
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/SemanticTextFeature.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/SemanticTextFeature.java
new file mode 100644
index 0000000000000..4f2c5c564bcb8
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/SemanticTextFeature.java
@@ -0,0 +1,24 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference;
+
+import org.elasticsearch.common.util.FeatureFlag;
+
+/**
+ * semantic_text feature flag. When the feature is complete, this flag will be removed.
+ */
+public class SemanticTextFeature {
+
+    private SemanticTextFeature() {}
+
+    private static final FeatureFlag FEATURE_FLAG = new FeatureFlag("semantic_text");
+
+    public static boolean isEnabled() {
+        return FEATURE_FLAG.isEnabled();
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java
new file mode 100644
index 0000000000000..b3b0c39f5a1d1
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java
@@ -0,0 +1,351 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */ + +package org.elasticsearch.xpack.inference.mapper; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.common.xcontent.support.XContentMapValues; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.DeprecationHandler; +import org.elasticsearch.xcontent.NamedXContentRegistry; +import org.elasticsearch.xcontent.ObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xcontent.support.MapXContentParser; +import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING; +import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; + +/** + * A {@link ToXContentObject} that is used to represent the transformation of the semantic text field's inputs. + * The resulting object preserves the original input under the {@link SemanticTextField#TEXT_FIELD} and exposes + * the inference results under the {@link SemanticTextField#INFERENCE_FIELD}. + * + * @param fieldName The original field name. + * @param originalValues The original values associated with the field name. + * @param inference The inference result. + * @param contentType The {@link XContentType} used to store the embeddings chunks. 
+ */
+public record SemanticTextField(String fieldName, List<String> originalValues, InferenceResult inference, XContentType contentType)
+    implements
+        ToXContentObject {
+
+    static final String TEXT_FIELD = "text";
+    static final String INFERENCE_FIELD = "inference";
+    static final String INFERENCE_ID_FIELD = "inference_id";
+    static final String CHUNKS_FIELD = "chunks";
+    static final String CHUNKED_EMBEDDINGS_FIELD = "embeddings";
+    static final String CHUNKED_TEXT_FIELD = "text";
+    static final String MODEL_SETTINGS_FIELD = "model_settings";
+    static final String TASK_TYPE_FIELD = "task_type";
+    static final String DIMENSIONS_FIELD = "dimensions";
+    static final String SIMILARITY_FIELD = "similarity";
+
+    public record InferenceResult(String inferenceId, ModelSettings modelSettings, List<Chunk> chunks) {}
+
+    public record Chunk(String text, BytesReference rawEmbeddings) {}
+
+    public record ModelSettings(TaskType taskType, Integer dimensions, SimilarityMeasure similarity) implements ToXContentObject {
+        public ModelSettings(Model model) {
+            this(model.getTaskType(), model.getServiceSettings().dimensions(), model.getServiceSettings().similarity());
+        }
+
+        public ModelSettings(TaskType taskType, Integer dimensions, SimilarityMeasure similarity) {
+            this.taskType = Objects.requireNonNull(taskType, "task type must not be null");
+            this.dimensions = dimensions;
+            this.similarity = similarity;
+            validate();
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field(TASK_TYPE_FIELD, taskType.toString());
+            if (dimensions != null) {
+                builder.field(DIMENSIONS_FIELD, dimensions);
+            }
+            if (similarity != null) {
+                builder.field(SIMILARITY_FIELD, similarity);
+            }
+            return builder.endObject();
+        }
+
+        @Override
+        public String toString() {
+            final StringBuilder sb = new StringBuilder();
+            sb.append("task_type=").append(taskType);
+            if (dimensions != null) {
+                sb.append(", dimensions=").append(dimensions);
+            }
+            if (similarity != null) {
+                sb.append(", similarity=").append(similarity);
+            }
+            return sb.toString();
+        }
+
+        private void validate() {
+            switch (taskType) {
+                case TEXT_EMBEDDING:
+                    validateFieldPresent(DIMENSIONS_FIELD, dimensions);
+                    validateFieldPresent(SIMILARITY_FIELD, similarity);
+                    break;
+                case SPARSE_EMBEDDING:
+                    validateFieldNotPresent(DIMENSIONS_FIELD, dimensions);
+                    validateFieldNotPresent(SIMILARITY_FIELD, similarity);
+                    break;
+
+                default:
+                    throw new IllegalArgumentException(
+                        "Wrong ["
+                            + TASK_TYPE_FIELD
+                            + "], expected "
+                            + TEXT_EMBEDDING
+                            + " or "
+                            + SPARSE_EMBEDDING
+                            + ", got "
+                            + taskType.name()
+                    );
+            }
+        }
+
+        private void validateFieldPresent(String field, Object fieldValue) {
+            if (fieldValue == null) {
+                throw new IllegalArgumentException("required [" + field + "] field is missing for task_type [" + taskType.name() + "]");
+            }
+        }
+
+        private void validateFieldNotPresent(String field, Object fieldValue) {
+            if (fieldValue != null) {
+                throw new IllegalArgumentException("[" + field + "] is not allowed for task_type [" + taskType.name() + "]");
+            }
+        }
+    }
+
+    public static String getOriginalTextFieldName(String fieldName) {
+        return fieldName + "." + TEXT_FIELD;
+    }
+
+    public static String getInferenceFieldName(String fieldName) {
+        return fieldName + "." + INFERENCE_FIELD;
+    }
+
+    public static String getChunksFieldName(String fieldName) {
+        return getInferenceFieldName(fieldName) + "." + CHUNKS_FIELD;
+    }
+
+    public static String getEmbeddingsFieldName(String fieldName) {
+        return getChunksFieldName(fieldName) + "." + CHUNKED_EMBEDDINGS_FIELD;
+    }
+
+    static SemanticTextField parse(XContentParser parser, Tuple<String, XContentType> context) throws IOException {
+        return SEMANTIC_TEXT_FIELD_PARSER.parse(parser, context);
+    }
+
+    static ModelSettings parseModelSettings(XContentParser parser) throws IOException {
+        return MODEL_SETTINGS_PARSER.parse(parser, null);
+    }
+
+    static ModelSettings parseModelSettingsFromMap(Object node) {
+        if (node == null) {
+            return null;
+        }
+        try {
+            Map<String, Object> map = XContentMapValues.nodeMapValue(node, MODEL_SETTINGS_FIELD);
+            XContentParser parser = new MapXContentParser(
+                NamedXContentRegistry.EMPTY,
+                DeprecationHandler.IGNORE_DEPRECATIONS,
+                map,
+                XContentType.JSON
+            );
+            return parseModelSettings(parser);
+        } catch (Exception exc) {
+            throw new ElasticsearchException(exc);
+        }
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        if (originalValues.isEmpty() == false) {
+            builder.field(TEXT_FIELD, originalValues.size() == 1 ? originalValues.get(0) : originalValues);
+        }
+        builder.startObject(INFERENCE_FIELD);
+        builder.field(INFERENCE_ID_FIELD, inference.inferenceId);
+        builder.field(MODEL_SETTINGS_FIELD, inference.modelSettings);
+        builder.startArray(CHUNKS_FIELD);
+        for (var chunk : inference.chunks) {
+            builder.startObject();
+            builder.field(CHUNKED_TEXT_FIELD, chunk.text);
+            XContentParser parser = XContentHelper.createParserNotCompressed(
+                XContentParserConfiguration.EMPTY,
+                chunk.rawEmbeddings,
+                contentType
+            );
+            builder.field(CHUNKED_EMBEDDINGS_FIELD).copyCurrentStructure(parser);
+            builder.endObject();
+        }
+        builder.endArray();
+        builder.endObject();
+        builder.endObject();
+        return builder;
+    }
+
+    @SuppressWarnings("unchecked")
+    private static final ConstructingObjectParser<SemanticTextField, Tuple<String, XContentType>> SEMANTIC_TEXT_FIELD_PARSER =
+        new ConstructingObjectParser<>(
+            SemanticTextFieldMapper.CONTENT_TYPE,
+            true,
+            (args, context) -> new SemanticTextField(
+                context.v1(),
+                (List<String>) (args[0] == null ? List.of() : args[0]),
+                (InferenceResult) args[1],
+                context.v2()
+            )
+        );
+
+    @SuppressWarnings("unchecked")
+    private static final ConstructingObjectParser<InferenceResult, Void> INFERENCE_RESULT_PARSER = new ConstructingObjectParser<>(
+        INFERENCE_FIELD,
+        true,
+        args -> new InferenceResult((String) args[0], (ModelSettings) args[1], (List<Chunk>) args[2])
+    );
+
+    private static final ConstructingObjectParser<Chunk, Void> CHUNKS_PARSER = new ConstructingObjectParser<>(
+        CHUNKS_FIELD,
+        true,
+        args -> new Chunk((String) args[0], (BytesReference) args[1])
+    );
+
+    private static final ConstructingObjectParser<ModelSettings, Void> MODEL_SETTINGS_PARSER = new ConstructingObjectParser<>(
+        MODEL_SETTINGS_FIELD,
+        true,
+        args -> {
+            TaskType taskType = TaskType.fromString((String) args[0]);
+            Integer dimensions = (Integer) args[1];
+            SimilarityMeasure similarity = args[2] == null ? null : SimilarityMeasure.fromString((String) args[2]);
+            return new ModelSettings(taskType, dimensions, similarity);
+        }
+    );
+
+    static {
+        SEMANTIC_TEXT_FIELD_PARSER.declareStringArray(optionalConstructorArg(), new ParseField(TEXT_FIELD));
+        SEMANTIC_TEXT_FIELD_PARSER.declareObject(
+            constructorArg(),
+            (p, c) -> INFERENCE_RESULT_PARSER.parse(p, null),
+            new ParseField(INFERENCE_FIELD)
+        );
+
+        INFERENCE_RESULT_PARSER.declareString(constructorArg(), new ParseField(INFERENCE_ID_FIELD));
+        INFERENCE_RESULT_PARSER.declareObject(constructorArg(), MODEL_SETTINGS_PARSER, new ParseField(MODEL_SETTINGS_FIELD));
+        INFERENCE_RESULT_PARSER.declareObjectArray(constructorArg(), CHUNKS_PARSER, new ParseField(CHUNKS_FIELD));
+
+        CHUNKS_PARSER.declareString(constructorArg(), new ParseField(CHUNKED_TEXT_FIELD));
+        CHUNKS_PARSER.declareField(constructorArg(), (p, c) -> {
+            XContentBuilder b = XContentBuilder.builder(p.contentType().xContent());
+            b.copyCurrentStructure(p);
+            return BytesReference.bytes(b);
+        }, new ParseField(CHUNKED_EMBEDDINGS_FIELD), ObjectParser.ValueType.OBJECT_ARRAY);
+
+        MODEL_SETTINGS_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(TASK_TYPE_FIELD));
+        MODEL_SETTINGS_PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField(DIMENSIONS_FIELD));
+        MODEL_SETTINGS_PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(SIMILARITY_FIELD));
+    }
+
+    /**
+     * Converts the provided {@link ChunkedInferenceServiceResults} into a list of {@link Chunk}.
+     */
+    public static List<Chunk> toSemanticTextFieldChunks(
+        String field,
+        String inferenceId,
+        List<ChunkedInferenceServiceResults> results,
+        XContentType contentType
+    ) {
+        List<Chunk> chunks = new ArrayList<>();
+        for (var result : results) {
+            if (result instanceof ChunkedSparseEmbeddingResults textExpansionResults) {
+                for (var chunk : textExpansionResults.getChunkedResults()) {
+                    chunks.add(new Chunk(chunk.matchedText(), toBytesReference(contentType.xContent(), chunk.weightedTokens())));
+                }
+            } else if (result instanceof ChunkedTextEmbeddingResults textEmbeddingResults) {
+                for (var chunk : textEmbeddingResults.getChunks()) {
+                    chunks.add(new Chunk(chunk.matchedText(), toBytesReference(contentType.xContent(), chunk.embedding())));
+                }
+            } else {
+                throw new ElasticsearchStatusException(
+                    "Invalid inference results format for field [{}] with inference id [{}], got {}",
+                    RestStatus.BAD_REQUEST,
+                    field,
+                    inferenceId,
+                    result.getWriteableName()
+                );
+            }
+        }
+        return chunks;
+    }
+
+    /**
+     * Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}.
+     */
+    private static BytesReference toBytesReference(XContent xContent, double[] value) {
+        try {
+            XContentBuilder b = XContentBuilder.builder(xContent);
+            b.startArray();
+            for (double v : value) {
+                b.value(v);
+            }
+            b.endArray();
+            return BytesReference.bytes(b);
+        } catch (IOException exc) {
+            throw new RuntimeException(exc);
+        }
+    }
+
+    /**
+     * Serialises the {@link TextExpansionResults.WeightedToken} list, according to the provided {@link XContent},
+     * into a {@link BytesReference}.
+     */
+    private static BytesReference toBytesReference(XContent xContent, List<TextExpansionResults.WeightedToken> tokens) {
+        try {
+            XContentBuilder b = XContentBuilder.builder(xContent);
+            b.startObject();
+            for (var weightedToken : tokens) {
+                weightedToken.toXContent(b, ToXContent.EMPTY_PARAMS);
+            }
+            b.endObject();
+            return BytesReference.bytes(b);
+        } catch (IOException exc) {
+            throw new RuntimeException(exc);
+        }
+    }
+
+}
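
Before the mapper that consumes this record, it helps to see the document shape SemanticTextField serialises and parses. The sample below is derived from toXContent and the parser declarations above (TEXT_FIELD, INFERENCE_FIELD, MODEL_SETTINGS_FIELD, CHUNKS_FIELD); the field name, inference id, and token weights are illustrative. The embeddings object shown is the sparse_embedding variant; for text_embedding it would be an array of floats, and model_settings would additionally carry dimensions and similarity:

{
  "my_semantic_field": {
    "text": "the original input text",
    "inference": {
      "inference_id": "my-elser-endpoint",
      "model_settings": {
        "task_type": "sparse_embedding"
      },
      "chunks": [
        {
          "text": "the original input text",
          "embeddings": {
            "original": 1.92,
            "input": 1.35,
            "text": 0.87
          }
        }
      ]
    }
  }
}
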
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
new file mode 100644
index 0000000000000..c4293d16ce6a4
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
@@ -0,0 +1,396 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.mapper;
+
+import org.apache.lucene.search.Query;
+import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
+import org.elasticsearch.common.Explicit;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.Tuple;
+import org.elasticsearch.index.IndexVersion;
+import org.elasticsearch.index.fielddata.FieldDataContext;
+import org.elasticsearch.index.fielddata.IndexFieldData;
+import org.elasticsearch.index.mapper.DocumentParserContext;
+import org.elasticsearch.index.mapper.DocumentParsingException;
+import org.elasticsearch.index.mapper.FieldMapper;
+import org.elasticsearch.index.mapper.InferenceFieldMapper;
+import org.elasticsearch.index.mapper.KeywordFieldMapper;
+import org.elasticsearch.index.mapper.MappedFieldType;
+import org.elasticsearch.index.mapper.Mapper;
+import org.elasticsearch.index.mapper.MapperBuilderContext;
+import org.elasticsearch.index.mapper.MapperMergeContext;
+import org.elasticsearch.index.mapper.NestedObjectMapper;
+import org.elasticsearch.index.mapper.ObjectMapper;
+import org.elasticsearch.index.mapper.SimpleMappedFieldType;
+import org.elasticsearch.index.mapper.SourceValueFetcher;
+import org.elasticsearch.index.mapper.TextSearchInfo;
+import org.elasticsearch.index.mapper.ValueFetcher;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper;
+import org.elasticsearch.index.query.SearchExecutionContext;
+import org.elasticsearch.inference.SimilarityMeasure;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentLocation;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xcontent.XContentParserConfiguration;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+import java.util.function.Function;
+
+import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD;
+import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_TEXT_FIELD;
+import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKS_FIELD;
+import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.INFERENCE_FIELD;
+import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.INFERENCE_ID_FIELD;
+import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getOriginalTextFieldName;
+
+/**
+ * A {@link FieldMapper} for semantic text fields.
+ */
+public class SemanticTextFieldMapper extends FieldMapper implements InferenceFieldMapper {
+    public static final String CONTENT_TYPE = "semantic_text";
+
+    public static final TypeParser PARSER = new TypeParser(
+        (n, c) -> new Builder(n, c.indexVersionCreated()),
+        notInMultiFields(CONTENT_TYPE)
+    );
+
+    public static class Builder extends FieldMapper.Builder {
+        private final IndexVersion indexVersionCreated;
+
+        private final Parameter<String> inferenceId = Parameter.stringParam(
+            "inference_id",
+            false,
+            mapper -> ((SemanticTextFieldType) mapper.fieldType()).inferenceId,
+            null
+        ).addValidator(v -> {
+            if (Strings.isEmpty(v)) {
+                throw new IllegalArgumentException("field [inference_id] must be specified");
+            }
+        });
+
+        private final Parameter<SemanticTextField.ModelSettings> modelSettings = new Parameter<>(
+            "model_settings",
+            true,
+            () -> null,
+            (n, c, o) -> SemanticTextField.parseModelSettingsFromMap(o),
+            mapper -> ((SemanticTextFieldType) mapper.fieldType()).modelSettings,
+            XContentBuilder::field,
+            Objects::toString
+        ).acceptsNull().setMergeValidator(SemanticTextFieldMapper::canMergeModelSettings);
+
+        private final Parameter<Map<String, String>> meta = Parameter.metaParam();
+
+        private Function<MapperBuilderContext, ObjectMapper> inferenceFieldBuilder;
+
+        public Builder(String name, IndexVersion indexVersionCreated) {
+            super(name);
+            this.indexVersionCreated = indexVersionCreated;
+            this.inferenceFieldBuilder = c -> createInferenceField(c, indexVersionCreated, modelSettings.get());
+        }
+
+        public Builder setInferenceId(String id) {
+            this.inferenceId.setValue(id);
+            return this;
+        }
+
+        public Builder setModelSettings(SemanticTextField.ModelSettings value) {
+            this.modelSettings.setValue(value);
+            return this;
+        }
+
+        @Override
+        protected Parameter<?>[] getParameters() {
+            return new Parameter<?>[] { inferenceId, modelSettings, meta };
+        }
+
+        @Override
+        protected void merge(FieldMapper mergeWith, Conflicts conflicts, MapperMergeContext mapperMergeContext) {
+            super.merge(mergeWith, conflicts, mapperMergeContext);
+            conflicts.check();
+            var semanticMergeWith = (SemanticTextFieldMapper) mergeWith;
+            var context = mapperMergeContext.createChildContext(mergeWith.simpleName(), ObjectMapper.Dynamic.FALSE);
+            var inferenceField = inferenceFieldBuilder.apply(context.getMapperBuilderContext());
+            var mergedInferenceField = inferenceField.merge(semanticMergeWith.fieldType().getInferenceField(), context);
+            inferenceFieldBuilder = c -> mergedInferenceField;
+        }
+
+        @Override
+        public SemanticTextFieldMapper build(MapperBuilderContext context) {
+            if (copyTo.copyToFields().isEmpty() == false) {
+                throw new IllegalArgumentException(CONTENT_TYPE + " field [" + name() + "] does not support [copy_to]");
+            }
+            if (multiFieldsBuilder.hasMultiFields()) {
+                throw new IllegalArgumentException(CONTENT_TYPE + " field [" + name() + "] does not support multi-fields");
+            }
+            final String fullName = context.buildFullName(name());
+            var childContext = context.createChildContext(name(), ObjectMapper.Dynamic.FALSE);
+            final ObjectMapper inferenceField = inferenceFieldBuilder.apply(childContext);
+            return new SemanticTextFieldMapper(
+                name(),
+                new SemanticTextFieldType(
+                    fullName,
+                    inferenceId.getValue(),
+                    modelSettings.getValue(),
+                    inferenceField,
+                    indexVersionCreated,
+                    meta.getValue()
+                ),
+                copyTo
+            );
+        }
+    }
+
+    private SemanticTextFieldMapper(String simpleName, MappedFieldType mappedFieldType, CopyTo copyTo) {
+        super(simpleName, mappedFieldType, MultiFields.empty(), copyTo);
+    }
+
+    @Override
+    public Iterator<Mapper> iterator() {
+        List<Mapper> subIterators = new ArrayList<>();
+        subIterators.add(fieldType().getInferenceField());
+        return subIterators.iterator();
+    }
+
+    @Override
+    public FieldMapper.Builder getMergeBuilder() {
+        return new Builder(simpleName(), fieldType().indexVersionCreated).init(this);
+    }
+
+    @Override
+    protected void parseCreateField(DocumentParserContext context) throws IOException {
+        XContentParser parser = context.parser();
+        if (parser.currentToken() == XContentParser.Token.VALUE_NULL) {
+            return;
+        }
+        XContentLocation xContentLocation = parser.getTokenLocation();
+        final SemanticTextField field;
+        boolean isWithinLeaf = context.path().isWithinLeafObject();
+        try {
+            context.path().setWithinLeafObject(true);
+            field = SemanticTextField.parse(parser, new Tuple<>(name(), context.parser().contentType()));
+        } finally {
+            context.path().setWithinLeafObject(isWithinLeaf);
+        }
+        final String fullFieldName = fieldType().name();
+        if (field.inference().inferenceId().equals(fieldType().getInferenceId()) == false) {
+            throw new DocumentParsingException(
+                xContentLocation,
+                Strings.format(
+                    "The configured %s [%s] for field [%s] doesn't match the %s [%s] reported in the document.",
+                    INFERENCE_ID_FIELD,
+                    field.inference().inferenceId(),
+                    fullFieldName,
+                    INFERENCE_ID_FIELD,
+                    fieldType().getInferenceId()
+                )
+            );
+        }
+        final SemanticTextFieldMapper mapper;
+        if (fieldType().getModelSettings() == null) {
+            context.path().remove();
+            Builder builder = (Builder) new Builder(simpleName(), fieldType().indexVersionCreated).init(this);
+            try {
+                mapper = builder.setModelSettings(field.inference().modelSettings())
+                    .setInferenceId(field.inference().inferenceId())
+                    .build(context.createDynamicMapperBuilderContext());
+                context.addDynamicMapper(mapper);
+            } finally {
+                context.path().add(simpleName());
+            }
+        } else {
+            Conflicts conflicts = new Conflicts(fullFieldName);
+            canMergeModelSettings(field.inference().modelSettings(), fieldType().getModelSettings(), conflicts);
+            try {
+                conflicts.check();
+            } catch (Exception exc) {
+                throw new DocumentParsingException(
+                    xContentLocation,
+                    "Incompatible model settings for field ["
+                        + name()
+                        + "]. Check that the "
+                        + INFERENCE_ID_FIELD
+                        + " is not using different model settings",
+                    exc
+                );
+            }
+            mapper = this;
+        }
+        var chunksField = mapper.fieldType().getChunksField();
+        var embeddingsField = mapper.fieldType().getEmbeddingsField();
+        for (var chunk : field.inference().chunks()) {
+            try (
+                XContentParser subParser = XContentHelper.createParserNotCompressed(
+                    XContentParserConfiguration.EMPTY,
+                    chunk.rawEmbeddings(),
+                    context.parser().contentType()
+                )
+            ) {
+                DocumentParserContext subContext = context.createNestedContext(chunksField).switchParser(subParser);
+                subParser.nextToken();
+                embeddingsField.parse(subContext);
+            }
+        }
+    }
+
+    @Override
+    protected String contentType() {
+        return CONTENT_TYPE;
+    }
+
+    @Override
+    public SemanticTextFieldType fieldType() {
+        return (SemanticTextFieldType) super.fieldType();
+    }
+
+    @Override
+    public InferenceFieldMetadata getMetadata(Set<String> sourcePaths) {
+        String[] copyFields = sourcePaths.toArray(String[]::new);
+        // ensure consistent order
+        Arrays.sort(copyFields);
+        return new InferenceFieldMetadata(name(), fieldType().inferenceId, copyFields);
+    }
+
+    public static class SemanticTextFieldType extends SimpleMappedFieldType {
+        private final String inferenceId;
+        private final SemanticTextField.ModelSettings modelSettings;
+        private final ObjectMapper inferenceField;
+        private final IndexVersion indexVersionCreated;
+
+        public SemanticTextFieldType(
+            String name,
+            String modelId,
+            SemanticTextField.ModelSettings modelSettings,
+            ObjectMapper inferenceField,
+            IndexVersion indexVersionCreated,
+            Map<String, String> meta
+        ) {
+            super(name, false, false, false, TextSearchInfo.NONE, meta);
+            this.inferenceId = modelId;
+            this.modelSettings = modelSettings;
+            this.inferenceField = inferenceField;
+            this.indexVersionCreated = indexVersionCreated;
+        }
+
+        @Override
+        public String typeName() {
+            return CONTENT_TYPE;
+        }
+
+        public String getInferenceId() {
+            return inferenceId;
+        }
+
+        public SemanticTextField.ModelSettings getModelSettings() {
+            return modelSettings;
+        }
+
+        public ObjectMapper getInferenceField() {
+            return inferenceField;
+        }
+
+        public NestedObjectMapper getChunksField() {
+            return (NestedObjectMapper) inferenceField.getMapper(CHUNKS_FIELD);
+        }
+
+        public FieldMapper getEmbeddingsField() {
+            return (FieldMapper) getChunksField().getMapper(CHUNKED_EMBEDDINGS_FIELD);
+        }
+
+        @Override
+        public Query termQuery(Object value, SearchExecutionContext context) {
+            throw new IllegalArgumentException(CONTENT_TYPE + " fields do not support term query");
+        }
+
+        @Override
+        public ValueFetcher valueFetcher(SearchExecutionContext context, String format) {
+            // Redirect the fetcher to load the original values of the field
+            return SourceValueFetcher.toString(getOriginalTextFieldName(name()), context, format);
+        }
+
+        @Override
+        public IndexFieldData.Builder fielddataBuilder(FieldDataContext fieldDataContext) {
+            throw new IllegalArgumentException("[semantic_text] fields do not support sorting, scripting or aggregating");
+        }
+    }
+
+    private static ObjectMapper createInferenceField(
+        MapperBuilderContext context,
+        IndexVersion indexVersionCreated,
+        @Nullable SemanticTextField.ModelSettings modelSettings
+    ) {
+        return new ObjectMapper.Builder(INFERENCE_FIELD, Explicit.EXPLICIT_TRUE).dynamic(ObjectMapper.Dynamic.FALSE)
+            .add(createChunksField(indexVersionCreated, modelSettings))
+            .build(context);
+    }
+
+    private static NestedObjectMapper.Builder createChunksField(
+        IndexVersion indexVersionCreated,
+        SemanticTextField.ModelSettings modelSettings
+    ) {
+        NestedObjectMapper.Builder chunksField = new NestedObjectMapper.Builder(CHUNKS_FIELD, indexVersionCreated);
+        chunksField.dynamic(ObjectMapper.Dynamic.FALSE);
+        KeywordFieldMapper.Builder chunkTextField = new KeywordFieldMapper.Builder(CHUNKED_TEXT_FIELD, indexVersionCreated).indexed(false)
+            .docValues(false);
+        if (modelSettings != null) {
+            chunksField.add(createEmbeddingsField(indexVersionCreated, modelSettings));
+        }
+        chunksField.add(chunkTextField);
+        return chunksField;
+    }
+
+    private static Mapper.Builder createEmbeddingsField(IndexVersion indexVersionCreated, SemanticTextField.ModelSettings modelSettings) {
+        return switch (modelSettings.taskType()) {
+            case SPARSE_EMBEDDING -> new SparseVectorFieldMapper.Builder(CHUNKED_EMBEDDINGS_FIELD);
+            case TEXT_EMBEDDING -> {
+                DenseVectorFieldMapper.Builder denseVectorMapperBuilder = new DenseVectorFieldMapper.Builder(
+                    CHUNKED_EMBEDDINGS_FIELD,
+                    indexVersionCreated
+                );
+                SimilarityMeasure similarity = modelSettings.similarity();
+                if (similarity != null) {
+                    switch (similarity) {
+                        case COSINE -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.COSINE);
+                        case DOT_PRODUCT -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT);
+                        case L2_NORM -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.L2_NORM);
+                        default -> throw new IllegalArgumentException(
+                            "Unknown similarity measure in model_settings [" + similarity.name() + "]"
+                        );
+                    }
+                }
+                denseVectorMapperBuilder.dimensions(modelSettings.dimensions());
+                yield denseVectorMapperBuilder;
+            }
+            default -> throw new IllegalArgumentException("Invalid task_type in model_settings [" + modelSettings.taskType().name() + "]");
+        };
+    }
+
+    private static boolean canMergeModelSettings(
+        SemanticTextField.ModelSettings previous,
+        SemanticTextField.ModelSettings current,
+        Conflicts conflicts
+    ) {
+        if (Objects.equals(previous, current)) {
+            return true;
+        }
+        if (previous == null) {
+            return true;
+        }
+        conflicts.addConflict("model_settings", "");
+        return false;
+    }
+}
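
To make the Builder and merge flow above concrete: the user-facing mapping only declares the type and the inference endpoint (mirroring the type=semantic_text,inference_id=test_model mappings used in the tests below), while model_settings is pinned later by the dynamic mapping update that parseCreateField registers on first ingest. A hedged sketch, with an illustrative index and endpoint name:

PUT my-index
{
  "mappings": {
    "properties": {
      "body": { "type": "semantic_text", "inference_id": "my-e5-endpoint" }
    }
  }
}

After the first document is indexed, the effective mapping for body would additionally carry model_settings such as { "task_type": "text_embedding", "dimensions": 384, "similarity": "cosine" } (the values depend on the endpoint), and createInferenceField will have generated the hidden body.inference.chunks nested field holding a text keyword sub-field plus an embeddings sub-field that is a sparse_vector or dense_vector depending on the task type — the dense case wired up through the similarity() and dimensions() setters added to DenseVectorFieldMapper.Builder above.
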
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java
new file mode 100644
index 0000000000000..1c4a2f561ad4a
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java
@@ -0,0 +1,104 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.cluster.metadata;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.admin.indices.mapping.put.PutMappingClusterStateUpdateRequest;
+import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.cluster.service.ClusterStateTaskExecutorUtils;
+import org.elasticsearch.index.Index;
+import org.elasticsearch.index.IndexService;
+import org.elasticsearch.plugins.Plugin;
+import org.elasticsearch.test.ESSingleNodeTestCase;
+import org.elasticsearch.xpack.inference.InferencePlugin;
+import org.hamcrest.Matchers;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+
+import static org.hamcrest.CoreMatchers.equalTo;
+
+public class SemanticTextClusterMetadataTests extends ESSingleNodeTestCase {
+
+    @Override
+    protected Collection<Class<? extends Plugin>> getPlugins() {
+        return List.of(InferencePlugin.class);
+    }
+
+    public void testCreateIndexWithSemanticTextField() {
+        final IndexService indexService = createIndex(
+            "test",
+            client().admin().indices().prepareCreate("test").setMapping("field", "type=semantic_text,inference_id=test_model")
+        );
+        assertEquals(indexService.getMetadata().getInferenceFields().get("field").getInferenceId(), "test_model");
+    }
+
+    public void testSingleSourceSemanticTextField() throws Exception {
+        final IndexService indexService = createIndex("test", client().admin().indices().prepareCreate("test"));
+        final MetadataMappingService mappingService = getInstanceFromNode(MetadataMappingService.class);
+        final MetadataMappingService.PutMappingExecutor putMappingExecutor = mappingService.new PutMappingExecutor();
+        final ClusterService clusterService = getInstanceFromNode(ClusterService.class);
+
+        final PutMappingClusterStateUpdateRequest request = new PutMappingClusterStateUpdateRequest("""
+            { "properties": { "field": { "type": "semantic_text", "inference_id": "test_model" }}}""");
+        request.indices(new Index[] { indexService.index() });
+        final var resultingState = ClusterStateTaskExecutorUtils.executeAndAssertSuccessful(
+            clusterService.state(),
+            putMappingExecutor,
+            singleTask(request)
+        );
+        assertEquals(resultingState.metadata().index("test").getInferenceFields().get("field").getInferenceId(), "test_model");
+    }
+
+    public void testCopyToSemanticTextField() throws Exception {
+        final IndexService indexService = createIndex("test", client().admin().indices().prepareCreate("test"));
+        final MetadataMappingService mappingService = getInstanceFromNode(MetadataMappingService.class);
+        final MetadataMappingService.PutMappingExecutor putMappingExecutor = mappingService.new PutMappingExecutor();
+        final ClusterService clusterService = getInstanceFromNode(ClusterService.class);
+
+        final PutMappingClusterStateUpdateRequest request = new PutMappingClusterStateUpdateRequest("""
+            {
+                "properties": {
+                    "semantic": {
+                        "type": "semantic_text",
+                        "inference_id": "test_model"
+                    },
+                    "copy_origin_1": {
+                        "type": "text",
+                        "copy_to": "semantic"
+                    },
+                    "copy_origin_2": {
+                        "type": "text",
+                        "copy_to": "semantic"
+                    }
+                }
+            }
+            """);
+        request.indices(new Index[] { indexService.index() });
+        final var resultingState = ClusterStateTaskExecutorUtils.executeAndAssertSuccessful(
+            clusterService.state(),
+            putMappingExecutor,
+            singleTask(request)
+        );
+        IndexMetadata indexMetadata = resultingState.metadata().index("test");
+        InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get("semantic");
+        assertThat(inferenceFieldMetadata.getInferenceId(), equalTo("test_model"));
+        assertThat(
+            Arrays.asList(inferenceFieldMetadata.getSourceFields()),
+            Matchers.containsInAnyOrder("semantic", "copy_origin_1", "copy_origin_2")
+        );
+    }
+
+    private static List<MetadataMappingService.PutMappingClusterStateUpdateTask> singleTask(PutMappingClusterStateUpdateRequest request) {
+        return Collections.singletonList(new MetadataMappingService.PutMappingClusterStateUpdateTask(request, ActionListener.running(() -> {
+            throw new AssertionError("task should not complete publication");
+        })));
+    }
+}
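
The test above also shows the consumer-facing payoff of InferenceFieldMetadata: anything with access to cluster state can discover which inference endpoint a semantic_text field uses and which source fields feed it, without re-parsing mappings. A minimal sketch of such a lookup — the class and method names are hypothetical, while the accessors mirror the ones exercised in the assertions:

import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;

final class InferenceFieldLookupSketch {
    /**
     * Returns the fields whose values must be gathered and sent to the inference
     * endpoint for {@code field}, or null when the field is not an inference field.
     * For the mapping in testCopyToSemanticTextField this would return
     * ["copy_origin_1", "copy_origin_2", "semantic"] for the "semantic" field.
     */
    static String[] sourceFieldsFor(ClusterState state, String index, String field) {
        IndexMetadata indexMetadata = state.metadata().index(index);
        if (indexMetadata == null) {
            return null;
        }
        InferenceFieldMetadata metadata = indexMetadata.getInferenceFields().get(field);
        return metadata == null ? null : metadata.getSourceFields();
    }
}
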
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java
new file mode 100644
index 0000000000000..19776628a8d00
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java
@@ -0,0 +1,560 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.mapper;
+
+import org.apache.lucene.document.FeatureField;
+import org.apache.lucene.index.IndexableField;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.BooleanQuery;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.TermQuery;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.join.BitSetProducer;
+import org.apache.lucene.search.join.QueryBitSetProducer;
+import org.apache.lucene.search.join.ScoreMode;
+import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequest;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.bytes.BytesArray;
+import org.elasticsearch.common.compress.CompressedXContent;
+import org.elasticsearch.common.lucene.search.Queries;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.index.IndexVersion;
+import org.elasticsearch.index.mapper.DocumentMapper;
+import org.elasticsearch.index.mapper.DocumentParsingException;
+import org.elasticsearch.index.mapper.FieldMapper;
+import org.elasticsearch.index.mapper.KeywordFieldMapper;
+import org.elasticsearch.index.mapper.LuceneDocument;
+import org.elasticsearch.index.mapper.MappedFieldType;
+import org.elasticsearch.index.mapper.Mapper;
+import org.elasticsearch.index.mapper.MapperParsingException;
+import org.elasticsearch.index.mapper.MapperService;
+import org.elasticsearch.index.mapper.MapperTestCase;
+import org.elasticsearch.index.mapper.NestedLookup;
+import org.elasticsearch.index.mapper.NestedObjectMapper;
+import org.elasticsearch.index.mapper.ParsedDocument;
+import org.elasticsearch.index.mapper.SourceToParse;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper;
+import org.elasticsearch.index.search.ESToParentBlockJoinQuery;
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.plugins.Plugin;
+import org.elasticsearch.search.LeafNestedDocuments;
+import org.elasticsearch.search.NestedDocuments;
+import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.InferencePlugin;
+import org.elasticsearch.xpack.inference.model.TestModel;
+import org.junit.AssumptionViolatedException;
+
+import java.io.IOException;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static java.util.Collections.singletonList;
+import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD;
+import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_TEXT_FIELD;
+import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKS_FIELD;
+import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.INFERENCE_FIELD;
+import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.INFERENCE_ID_FIELD;
+import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.MODEL_SETTINGS_FIELD;
+import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getChunksFieldName;
+import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getEmbeddingsFieldName;
+import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText;
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.instanceOf;
+
+public class SemanticTextFieldMapperTests extends MapperTestCase {
+    @Override
+    protected Collection<? extends Plugin> getPlugins() {
+        return singletonList(new InferencePlugin(Settings.EMPTY));
+    }
+
+    @Override
+    protected void minimalMapping(XContentBuilder b) throws IOException {
+        b.field("type", "semantic_text").field("inference_id", "test_model");
+    }
+
+    @Override
+    protected String minimalIsInvalidRoutingPathErrorMessage(Mapper mapper) {
+        return "cannot have nested fields when index is in [index.mode=time_series]";
+    }
+
+    @Override
+    protected Object getSampleValueForDocument() {
+        return null;
+    }
+
+    @Override
+    protected boolean supportsIgnoreMalformed() {
+        return false;
+    }
+
+    @Override
+    protected boolean supportsStoredFields() {
+        return false;
+    }
+
+    @Override
+    protected void registerParameters(ParameterChecker checker) throws IOException {}
+
+    @Override
+    protected Object generateRandomInputValue(MappedFieldType ft) {
+        assumeFalse("doc_values are not supported in semantic_text", true);
+        return null;
+    }
+
+    @Override
+    protected SyntheticSourceSupport syntheticSourceSupport(boolean ignoreMalformed) {
+        throw new AssumptionViolatedException("not supported");
+    }
+
+    @Override
+    protected IngestScriptSupport ingestScriptSupport() {
+        throw new AssumptionViolatedException("not supported");
+    }
+
+    public void testDefaults() throws Exception {
+        DocumentMapper mapper = createDocumentMapper(fieldMapping(this::minimalMapping));
+        assertEquals(Strings.toString(fieldMapping(this::minimalMapping)), mapper.mappingSource().toString());
+
+        ParsedDocument doc1 = mapper.parse(source(this::writeField));
+        List<IndexableField> fields = doc1.rootDoc().getFields("field");
+
+        // No indexable fields
+        assertTrue(fields.isEmpty());
+    }
+
+    public void testInferenceIdNotPresent() {
+        Exception e = expectThrows(
+            MapperParsingException.class,
+            () -> createMapperService(fieldMapping(b -> b.field("type", "semantic_text")))
+        );
+        assertThat(e.getMessage(), containsString("field [inference_id] must be specified"));
+    }
+
+    public void testCannotBeUsedInMultiFields() {
+        Exception e =
expectThrows(MapperParsingException.class, () -> createMapperService(fieldMapping(b -> { + b.field("type", "text"); + b.startObject("fields"); + b.startObject("semantic"); + b.field("type", "semantic_text"); + b.field("inference_id", "my_inference_id"); + b.endObject(); + b.endObject(); + }))); + assertThat(e.getMessage(), containsString("Field [semantic] of type [semantic_text] can't be used in multifields")); + } + + public void testUpdatesToInferenceIdNotSupported() throws IOException { + String fieldName = randomAlphaOfLengthBetween(5, 15); + MapperService mapperService = createMapperService( + mapping(b -> b.startObject(fieldName).field("type", "semantic_text").field("inference_id", "test_model").endObject()) + ); + assertSemanticTextField(mapperService, fieldName, false); + Exception e = expectThrows( + IllegalArgumentException.class, + () -> merge( + mapperService, + mapping(b -> b.startObject(fieldName).field("type", "semantic_text").field("inference_id", "another_model").endObject()) + ) + ); + assertThat(e.getMessage(), containsString("Cannot update parameter [inference_id] from [test_model] to [another_model]")); + } + + public void testDynamicUpdate() throws IOException { + MapperService mapperService = createMapperService(mapping(b -> {})); + mapperService.merge( + "_doc", + new CompressedXContent( + Strings.toString(PutMappingRequest.simpleMapping("semantic", "type=semantic_text,inference_id=test_service")) + ), + MapperService.MergeReason.MAPPING_UPDATE + ); + String source = """ + { + "semantic": { + "inference": { + "inference_id": "test_service", + "model_settings": { + "task_type": "SPARSE_EMBEDDING" + }, + "chunks": [ + { + "embeddings": { + "feature_0": 1 + }, + "text": "feature_0" + } + ] + } + } + } + """; + SourceToParse sourceToParse = new SourceToParse("test", new BytesArray(source), XContentType.JSON); + ParsedDocument parsedDocument = mapperService.documentMapper().parse(sourceToParse); + mapperService.merge( + "_doc", + parsedDocument.dynamicMappingsUpdate().toCompressedXContent(), + MapperService.MergeReason.MAPPING_UPDATE + ); + assertSemanticTextField(mapperService, "semantic", true); + } + + public void testUpdateModelSettings() throws IOException { + for (int depth = 1; depth < 5; depth++) { + String fieldName = randomFieldName(depth); + MapperService mapperService = createMapperService( + mapping(b -> b.startObject(fieldName).field("type", "semantic_text").field("inference_id", "test_model").endObject()) + ); + assertSemanticTextField(mapperService, fieldName, false); + { + Exception exc = expectThrows( + MapperParsingException.class, + () -> merge( + mapperService, + mapping( + b -> b.startObject(fieldName) + .field("type", "semantic_text") + .field("inference_id", "test_model") + .startObject("model_settings") + .field("inference_id", "test_model") + .endObject() + .endObject() + ) + ) + ); + assertThat(exc.getMessage(), containsString("Required [task_type]")); + } + { + merge( + mapperService, + mapping( + b -> b.startObject(fieldName) + .field("type", "semantic_text") + .field("inference_id", "test_model") + .startObject("model_settings") + .field("task_type", "sparse_embedding") + .endObject() + .endObject() + ) + ); + assertSemanticTextField(mapperService, fieldName, true); + } + { + Exception exc = expectThrows( + IllegalArgumentException.class, + () -> merge( + mapperService, + mapping( + b -> b.startObject(fieldName).field("type", "semantic_text").field("inference_id", "test_model").endObject() + ) + ) + ); + assertThat( + exc.getMessage(), + 
containsString("Cannot update parameter [model_settings] " + "from [task_type=sparse_embedding] to [null]") + ); + } + { + Exception exc = expectThrows( + IllegalArgumentException.class, + () -> merge( + mapperService, + mapping( + b -> b.startObject(fieldName) + .field("type", "semantic_text") + .field("inference_id", "test_model") + .startObject("model_settings") + .field("task_type", "text_embedding") + .field("dimensions", 10) + .field("similarity", "cosine") + .endObject() + .endObject() + ) + ) + ); + assertThat( + exc.getMessage(), + containsString( + "Cannot update parameter [model_settings] " + + "from [task_type=sparse_embedding] " + + "to [task_type=text_embedding, dimensions=10, similarity=cosine]" + ) + ); + } + } + } + + static void assertSemanticTextField(MapperService mapperService, String fieldName, boolean expectedModelSettings) { + Mapper mapper = mapperService.mappingLookup().getMapper(fieldName); + assertNotNull(mapper); + assertThat(mapper, instanceOf(SemanticTextFieldMapper.class)); + SemanticTextFieldMapper semanticFieldMapper = (SemanticTextFieldMapper) mapper; + + var fieldType = mapperService.fieldType(fieldName); + assertNotNull(fieldType); + assertThat(fieldType, instanceOf(SemanticTextFieldMapper.SemanticTextFieldType.class)); + SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType = (SemanticTextFieldMapper.SemanticTextFieldType) fieldType; + assertTrue(semanticFieldMapper.fieldType() == semanticTextFieldType); + + NestedObjectMapper chunksMapper = mapperService.mappingLookup() + .nestedLookup() + .getNestedMappers() + .get(getChunksFieldName(fieldName)); + assertThat(chunksMapper, equalTo(semanticFieldMapper.fieldType().getChunksField())); + assertThat(chunksMapper.name(), equalTo(getChunksFieldName(fieldName))); + Mapper textMapper = chunksMapper.getMapper(CHUNKED_TEXT_FIELD); + assertNotNull(textMapper); + assertThat(textMapper, instanceOf(KeywordFieldMapper.class)); + KeywordFieldMapper textFieldMapper = (KeywordFieldMapper) textMapper; + assertFalse(textFieldMapper.fieldType().isIndexed()); + assertFalse(textFieldMapper.fieldType().hasDocValues()); + if (expectedModelSettings) { + assertNotNull(semanticFieldMapper.fieldType().getModelSettings()); + Mapper embeddingsMapper = chunksMapper.getMapper(CHUNKED_EMBEDDINGS_FIELD); + assertNotNull(embeddingsMapper); + assertThat(embeddingsMapper, instanceOf(FieldMapper.class)); + FieldMapper embeddingsFieldMapper = (FieldMapper) embeddingsMapper; + assertTrue(embeddingsFieldMapper.fieldType() == mapperService.mappingLookup().getFieldType(getEmbeddingsFieldName(fieldName))); + assertThat(embeddingsMapper.name(), equalTo(getEmbeddingsFieldName(fieldName))); + switch (semanticFieldMapper.fieldType().getModelSettings().taskType()) { + case SPARSE_EMBEDDING -> assertThat(embeddingsMapper, instanceOf(SparseVectorFieldMapper.class)); + case TEXT_EMBEDDING -> assertThat(embeddingsMapper, instanceOf(DenseVectorFieldMapper.class)); + default -> throw new AssertionError("Invalid task type"); + } + } else { + assertNull(semanticFieldMapper.fieldType().getModelSettings()); + } + } + + public void testSuccessfulParse() throws IOException { + for (int depth = 1; depth < 4; depth++) { + final String fieldName1 = randomFieldName(depth); + final String fieldName2 = randomFieldName(depth + 1); + + Model model1 = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); + Model model2 = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); + XContentBuilder mapping = mapping(b -> { + addSemanticTextMapping(b, 
fieldName1, model1.getInferenceEntityId()); + addSemanticTextMapping(b, fieldName2, model2.getInferenceEntityId()); + }); + + MapperService mapperService = createMapperService(mapping); + SemanticTextFieldMapperTests.assertSemanticTextField(mapperService, fieldName1, false); + SemanticTextFieldMapperTests.assertSemanticTextField(mapperService, fieldName2, false); + DocumentMapper documentMapper = mapperService.documentMapper(); + ParsedDocument doc = documentMapper.parse( + source( + b -> addSemanticTextInferenceResults( + b, + List.of( + randomSemanticText(fieldName1, model1, List.of("a b", "c"), XContentType.JSON), + randomSemanticText(fieldName2, model2, List.of("d e f"), XContentType.JSON) + ) + ) + ) + ); + + List<LuceneDocument> luceneDocs = doc.docs(); + assertEquals(4, luceneDocs.size()); + for (int i = 0; i < 3; i++) { + assertEquals(doc.rootDoc(), luceneDocs.get(i).getParent()); + } + // nested docs are in reversed order + assertSparseFeatures(luceneDocs.get(0), getEmbeddingsFieldName(fieldName1), 2); + assertSparseFeatures(luceneDocs.get(1), getEmbeddingsFieldName(fieldName1), 1); + assertSparseFeatures(luceneDocs.get(2), getEmbeddingsFieldName(fieldName2), 3); + assertEquals(doc.rootDoc(), luceneDocs.get(3)); + assertNull(luceneDocs.get(3).getParent()); + + withLuceneIndex(mapperService, iw -> iw.addDocuments(doc.docs()), reader -> { + NestedDocuments nested = new NestedDocuments( + mapperService.mappingLookup(), + QueryBitSetProducer::new, + IndexVersion.current() + ); + LeafNestedDocuments leaf = nested.getLeafNestedDocuments(reader.leaves().get(0)); + + Set<SearchHit.NestedIdentity> visitedNestedIdentities = new HashSet<>(); + Set<SearchHit.NestedIdentity> expectedVisitedNestedIdentities = Set.of( + new SearchHit.NestedIdentity(getChunksFieldName(fieldName1), 0, null), + new SearchHit.NestedIdentity(getChunksFieldName(fieldName1), 1, null), + new SearchHit.NestedIdentity(getChunksFieldName(fieldName2), 0, null) + ); + + assertChildLeafNestedDocument(leaf, 0, 3, visitedNestedIdentities); + assertChildLeafNestedDocument(leaf, 1, 3, visitedNestedIdentities); + assertChildLeafNestedDocument(leaf, 2, 3, visitedNestedIdentities); + assertEquals(expectedVisitedNestedIdentities, visitedNestedIdentities); + + assertNull(leaf.advance(3)); + assertEquals(3, leaf.doc()); + assertEquals(3, leaf.rootDoc()); + assertNull(leaf.nestedIdentity()); + + IndexSearcher searcher = newSearcher(reader); + { + TopDocs topDocs = searcher.search( + generateNestedTermSparseVectorQuery(mapperService.mappingLookup().nestedLookup(), fieldName1, List.of("a")), + 10 + ); + assertEquals(1, topDocs.totalHits.value); + assertEquals(3, topDocs.scoreDocs[0].doc); + } + { + TopDocs topDocs = searcher.search( + generateNestedTermSparseVectorQuery(mapperService.mappingLookup().nestedLookup(), fieldName1, List.of("a", "b")), + 10 + ); + assertEquals(1, topDocs.totalHits.value); + assertEquals(3, topDocs.scoreDocs[0].doc); + } + { + TopDocs topDocs = searcher.search( + generateNestedTermSparseVectorQuery(mapperService.mappingLookup().nestedLookup(), fieldName2, List.of("d")), + 10 + ); + assertEquals(1, topDocs.totalHits.value); + assertEquals(3, topDocs.scoreDocs[0].doc); + } + { + TopDocs topDocs = searcher.search( + generateNestedTermSparseVectorQuery(mapperService.mappingLookup().nestedLookup(), fieldName2, List.of("z")), + 10 + ); + assertEquals(0, topDocs.totalHits.value); + } + }); + } + } + + public void testMissingInferenceId() throws IOException { + DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id"))); +
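// Missing [inference_id] inside the inference section must fail the parse, surfacing the required-field error as the cause. +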
IllegalArgumentException ex = expectThrows( + DocumentParsingException.class, + IllegalArgumentException.class, + () -> documentMapper.parse( + source( + b -> b.startObject("field") + .startObject(INFERENCE_FIELD) + .field(MODEL_SETTINGS_FIELD, new SemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, null)) + .field(CHUNKS_FIELD, List.of()) + .endObject() + .endObject() + ) + ) + ); + assertThat(ex.getCause().getMessage(), containsString("Required [inference_id]")); + } + + public void testMissingModelSettings() throws IOException { + DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id"))); + IllegalArgumentException ex = expectThrows( + DocumentParsingException.class, + IllegalArgumentException.class, + () -> documentMapper.parse( + source(b -> b.startObject("field").startObject(INFERENCE_FIELD).field(INFERENCE_ID_FIELD, "my_id").endObject().endObject()) + ) + ); + assertThat(ex.getCause().getMessage(), containsString("Required [model_settings, chunks]")); + } + + public void testMissingTaskType() throws IOException { + DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id"))); + IllegalArgumentException ex = expectThrows( + DocumentParsingException.class, + IllegalArgumentException.class, + () -> documentMapper.parse( + source( + b -> b.startObject("field") + .startObject(INFERENCE_FIELD) + .field(INFERENCE_ID_FIELD, "my_id") + .startObject(MODEL_SETTINGS_FIELD) + .endObject() + .endObject() + .endObject() + ) + ) + ); + assertThat(ex.getCause().getMessage(), containsString("failed to parse field [model_settings]")); + } + + private static void addSemanticTextMapping(XContentBuilder mappingBuilder, String fieldName, String modelId) throws IOException { + mappingBuilder.startObject(fieldName); + mappingBuilder.field("type", SemanticTextFieldMapper.CONTENT_TYPE); + mappingBuilder.field("inference_id", modelId); + mappingBuilder.endObject(); + } + + private static void addSemanticTextInferenceResults(XContentBuilder sourceBuilder, List<SemanticTextField> semanticTextInferenceResults) + throws IOException { + for (var field : semanticTextInferenceResults) { + sourceBuilder.field(field.fieldName()); + sourceBuilder.value(field); + } + } + + static String randomFieldName(int numLevel) { + StringBuilder builder = new StringBuilder(); + for (int i = 0; i < numLevel; i++) { + if (i > 0) { + builder.append('.'); + } + builder.append(randomAlphaOfLengthBetween(5, 15)); + } + return builder.toString(); + } + + private static Query generateNestedTermSparseVectorQuery(NestedLookup nestedLookup, String fieldName, List<String> tokens) { + NestedObjectMapper mapper = nestedLookup.getNestedMappers().get(getChunksFieldName(fieldName)); + assertNotNull(mapper); + + BitSetProducer parentFilter = new QueryBitSetProducer(Queries.newNonNestedFilter(IndexVersion.current())); + BooleanQuery.Builder queryBuilder = new BooleanQuery.Builder(); + for (String token : tokens) { + queryBuilder.add( + new BooleanClause(new TermQuery(new Term(getEmbeddingsFieldName(fieldName), token)), BooleanClause.Occur.MUST) + ); + } + queryBuilder.add(new BooleanClause(mapper.nestedTypeFilter(), BooleanClause.Occur.FILTER)); + + return new ESToParentBlockJoinQuery(queryBuilder.build(), parentFilter, ScoreMode.Total, null); + } + + private static void assertChildLeafNestedDocument( + LeafNestedDocuments leaf, + int advanceToDoc, + int expectedRootDoc, + Set<SearchHit.NestedIdentity> visitedNestedIdentities + ) throws IOException { + +
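// Step to the given child doc, verify it resolves to the expected root doc, and record which nested field produced it. +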
assertNotNull(leaf.advance(advanceToDoc)); + assertEquals(advanceToDoc, leaf.doc()); + assertEquals(expectedRootDoc, leaf.rootDoc()); + assertNotNull(leaf.nestedIdentity()); + visitedNestedIdentities.add(leaf.nestedIdentity()); + } + + private static void assertSparseFeatures(LuceneDocument doc, String fieldName, int expectedCount) { + int count = 0; + for (IndexableField field : doc.getFields()) { + if (field instanceof FeatureField featureField) { + assertThat(featureField.name(), equalTo(fieldName)); + ++count; + } + } + assertThat(count, equalTo(expectedCount)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java new file mode 100644 index 0000000000000..3885563720484 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java @@ -0,0 +1,237 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.mapper; + +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.AbstractXContentTestCase; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResults; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; +import org.elasticsearch.xpack.inference.model.TestModel; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.function.Predicate; + +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunks; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; + +public class SemanticTextFieldTests extends AbstractXContentTestCase<SemanticTextField> { + private static final String NAME = "field"; + + @Override + protected Predicate<String> getRandomFieldsExcludeFilter() { + return n -> n.endsWith(CHUNKED_EMBEDDINGS_FIELD); + } + + @Override + protected void assertEqualInstances(SemanticTextField expectedInstance, SemanticTextField newInstance) { + assertThat(newInstance.fieldName(), equalTo(expectedInstance.fieldName())); + assertThat(newInstance.originalValues(), equalTo(expectedInstance.originalValues())); + assertThat(newInstance.inference().modelSettings(), equalTo(expectedInstance.inference().modelSettings())); + assertThat(newInstance.inference().chunks().size(), equalTo(expectedInstance.inference().chunks().size())); +
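// Raw embedding bytes differ between XContent types, so chunks are compared by their parsed values per task type. +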
SemanticTextField.ModelSettings modelSettings = newInstance.inference().modelSettings(); + for (int i = 0; i < newInstance.inference().chunks().size(); i++) { + assertThat(newInstance.inference().chunks().get(i).text(), equalTo(expectedInstance.inference().chunks().get(i).text())); + switch (modelSettings.taskType()) { + case TEXT_EMBEDDING -> { + double[] expectedVector = parseDenseVector( + expectedInstance.inference().chunks().get(i).rawEmbeddings(), + modelSettings.dimensions(), + expectedInstance.contentType() + ); + double[] newVector = parseDenseVector( + newInstance.inference().chunks().get(i).rawEmbeddings(), + modelSettings.dimensions(), + newInstance.contentType() + ); + assertArrayEquals(expectedVector, newVector, 0f); + } + case SPARSE_EMBEDDING -> { + List<TextExpansionResults.WeightedToken> expectedTokens = parseWeightedTokens( + expectedInstance.inference().chunks().get(i).rawEmbeddings(), + expectedInstance.contentType() + ); + List<TextExpansionResults.WeightedToken> newTokens = parseWeightedTokens( + newInstance.inference().chunks().get(i).rawEmbeddings(), + newInstance.contentType() + ); + assertThat(newTokens, equalTo(expectedTokens)); + } + default -> throw new AssertionError("Invalid task type " + modelSettings.taskType()); + } + } + } + + @Override + protected SemanticTextField createTestInstance() { + List<String> rawValues = randomList(1, 5, () -> randomAlphaOfLengthBetween(10, 20)); + return randomSemanticText(NAME, TestModel.createRandomInstance(), rawValues, randomFrom(XContentType.values())); + } + + @Override + protected SemanticTextField doParseInstance(XContentParser parser) throws IOException { + return SemanticTextField.parse(parser, new Tuple<>(NAME, parser.contentType())); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + public void testModelSettingsValidation() { + NullPointerException npe = expectThrows(NullPointerException.class, () -> { + new SemanticTextField.ModelSettings(null, 10, SimilarityMeasure.COSINE); + }); + assertThat(npe.getMessage(), equalTo("task type must not be null")); + + IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> { + new SemanticTextField.ModelSettings(TaskType.COMPLETION, 10, SimilarityMeasure.COSINE); + }); + assertThat(ex.getMessage(), containsString("Wrong [task_type]")); + + ex = expectThrows( + IllegalArgumentException.class, + () -> { new SemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, 10, null); } + ); + assertThat(ex.getMessage(), containsString("[dimensions] is not allowed")); + + ex = expectThrows(IllegalArgumentException.class, () -> { + new SemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, SimilarityMeasure.COSINE); + }); + assertThat(ex.getMessage(), containsString("[similarity] is not allowed")); + + ex = expectThrows(IllegalArgumentException.class, () -> { + new SemanticTextField.ModelSettings(TaskType.TEXT_EMBEDDING, null, SimilarityMeasure.COSINE); + }); + assertThat(ex.getMessage(), containsString("required [dimensions] field is missing")); + + ex = expectThrows( + IllegalArgumentException.class, + () -> { new SemanticTextField.ModelSettings(TaskType.TEXT_EMBEDDING, 10, null); } + ); + assertThat(ex.getMessage(), containsString("required [similarity] field is missing")); + } + + public static ChunkedTextEmbeddingResults randomTextEmbeddings(Model model, List<String> inputs) { + List<org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults.EmbeddingChunk> chunks = new ArrayList<>(); + for (String input : inputs) { + double[] values = new double[model.getServiceSettings().dimensions()]; + for (int j = 0; j < values.length; j++) { + values[j] = randomDouble(); + } +
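// Fully qualified: the ml plugin's ChunkedTextEmbeddingResults.EmbeddingChunk shares a simple name with the imported inference results class. +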
chunks.add(new org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults.EmbeddingChunk(input, values)); + } + return new ChunkedTextEmbeddingResults(chunks); + } + + public static ChunkedSparseEmbeddingResults randomSparseEmbeddings(List<String> inputs) { + List<ChunkedTextExpansionResults.ChunkedResult> chunks = new ArrayList<>(); + for (String input : inputs) { + var tokens = new ArrayList<TextExpansionResults.WeightedToken>(); + for (var token : input.split("\\s+")) { + tokens.add(new TextExpansionResults.WeightedToken(token, randomFloat())); + } + chunks.add(new ChunkedTextExpansionResults.ChunkedResult(input, tokens)); + } + return new ChunkedSparseEmbeddingResults(chunks); + } + + public static SemanticTextField randomSemanticText(String fieldName, Model model, List<String> inputs, XContentType contentType) { + ChunkedInferenceServiceResults results = switch (model.getTaskType()) { + case TEXT_EMBEDDING -> randomTextEmbeddings(model, inputs); + case SPARSE_EMBEDDING -> randomSparseEmbeddings(inputs); + default -> throw new AssertionError("invalid task type: " + model.getTaskType().name()); + }; + return new SemanticTextField( + fieldName, + inputs, + new SemanticTextField.InferenceResult( + model.getInferenceEntityId(), + new SemanticTextField.ModelSettings(model), + toSemanticTextFieldChunks(fieldName, model.getInferenceEntityId(), List.of(results), contentType) + ), + contentType + ); + } + + public static ChunkedInferenceServiceResults toChunkedResult(SemanticTextField field) { + switch (field.inference().modelSettings().taskType()) { + case SPARSE_EMBEDDING -> { + List<ChunkedTextExpansionResults.ChunkedResult> chunks = new ArrayList<>(); + for (var chunk : field.inference().chunks()) { + var tokens = parseWeightedTokens(chunk.rawEmbeddings(), field.contentType()); + chunks.add(new ChunkedTextExpansionResults.ChunkedResult(chunk.text(), tokens)); + } + return new ChunkedSparseEmbeddingResults(chunks); + } + case TEXT_EMBEDDING -> { + List<org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults.EmbeddingChunk> chunks = + new ArrayList<>(); + for (var chunk : field.inference().chunks()) { + double[] values = parseDenseVector( + chunk.rawEmbeddings(), + field.inference().modelSettings().dimensions(), + field.contentType() + ); + chunks.add( + new org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults.EmbeddingChunk( + chunk.text(), + values + ) + ); + } + return new ChunkedTextEmbeddingResults(chunks); + } + default -> throw new AssertionError("Invalid task_type: " + field.inference().modelSettings().taskType().name()); + } + } + + private static double[] parseDenseVector(BytesReference value, int numDims, XContentType contentType) { + try (XContentParser parser = XContentHelper.createParserNotCompressed(XContentParserConfiguration.EMPTY, value, contentType)) { + parser.nextToken(); + assertThat(parser.currentToken(), equalTo(XContentParser.Token.START_ARRAY)); + double[] values = new double[numDims]; + for (int i = 0; i < numDims; i++) { + assertThat(parser.nextToken(), equalTo(XContentParser.Token.VALUE_NUMBER)); + values[i] = parser.doubleValue(); + } + assertThat(parser.nextToken(), equalTo(XContentParser.Token.END_ARRAY)); + return values; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private static List<TextExpansionResults.WeightedToken> parseWeightedTokens(BytesReference value, XContentType contentType) { + try (XContentParser parser = XContentHelper.createParserNotCompressed(XContentParserConfiguration.EMPTY, value, contentType)) { + Map<String, Object> map = parser.map(); + List<TextExpansionResults.WeightedToken> weightedTokens = new ArrayList<>(); + for (var entry : map.entrySet()) { + weightedTokens.add(new TextExpansionResults.WeightedToken(entry.getKey(), ((Number) entry.getValue()).floatValue())); + } +
return weightedTokens; + } catch (IOException e) { + throw new RuntimeException(e); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java index 75e7ca12c1d56..ced6e3ff43e2c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/model/TestModel.java @@ -16,6 +16,7 @@ import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.SecretSettings; import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.ToXContentObject; @@ -26,16 +27,23 @@ import java.util.Map; import static org.elasticsearch.test.ESTestCase.randomAlphaOfLength; +import static org.elasticsearch.test.ESTestCase.randomFrom; import static org.elasticsearch.test.ESTestCase.randomInt; public class TestModel extends Model { public static TestModel createRandomInstance() { + return createRandomInstance(randomFrom(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING)); + } + + public static TestModel createRandomInstance(TaskType taskType) { + var dimensions = taskType == TaskType.TEXT_EMBEDDING ? randomInt(1024) : null; + var similarity = taskType == TaskType.TEXT_EMBEDDING ? randomFrom(SimilarityMeasure.values()) : null; return new TestModel( randomAlphaOfLength(4), - TaskType.TEXT_EMBEDDING, + taskType, randomAlphaOfLength(10), - new TestModel.TestServiceSettings(randomAlphaOfLength(4)), + new TestModel.TestServiceSettings(randomAlphaOfLength(4), dimensions, similarity), new TestModel.TestTaskSettings(randomInt(3)), new TestModel.TestSecretSettings(randomAlphaOfLength(4)) ); @@ -70,7 +78,7 @@ public TestSecretSettings getSecretSettings() { return (TestSecretSettings) super.getSecretSettings(); } - public record TestServiceSettings(String model) implements ServiceSettings { + public record TestServiceSettings(String model, Integer dimensions, SimilarityMeasure similarity) implements ServiceSettings { private static final String NAME = "test_service_settings"; @@ -87,17 +95,23 @@ public static TestServiceSettings fromMap(Map<String, Object> map) { throw validationException; } - return new TestServiceSettings(model); + return new TestServiceSettings(model, null, null); } public TestServiceSettings(StreamInput in) throws IOException { - this(in.readString()); + this(in.readString(), in.readOptionalVInt(), in.readOptionalEnum(SimilarityMeasure.class)); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field("model", model); + if (dimensions != null) { + builder.field("dimensions", dimensions()); + } + if (similarity != null) { + builder.field("similarity", similarity); + } builder.endObject(); return builder; } @@ -115,12 +129,24 @@ public TransportVersion getMinimalSupportedVersion() { @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(model); + out.writeOptionalVInt(dimensions); + out.writeOptionalEnum(similarity); } @Override public ToXContentObject getFilteredXContentObject() { return this; } + + @Override + public SimilarityMeasure similarity() { + return similarity; + } + + @Override + public Integer dimensions() { + return dimensions; + } } public
record TestTaskSettings(Integer temperature) implements TaskSettings { diff --git a/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java b/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java new file mode 100644 index 0000000000000..a594c577dcdd2 --- /dev/null +++ b/x-pack/plugin/inference/src/yamlRestTest/java/org/elasticsearch/xpack/inference/InferenceRestIT.java @@ -0,0 +1,42 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.test.cluster.ElasticsearchCluster; +import org.elasticsearch.test.cluster.local.distribution.DistributionType; +import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate; +import org.elasticsearch.test.rest.yaml.ESClientYamlSuiteTestCase; +import org.junit.ClassRule; + +public class InferenceRestIT extends ESClientYamlSuiteTestCase { + + @ClassRule + public static ElasticsearchCluster cluster = ElasticsearchCluster.local() + .setting("xpack.security.enabled", "false") + .setting("xpack.security.http.ssl.enabled", "false") + .plugin("x-pack-inference") + .plugin("inference-service-test") + .distribution(DistributionType.INTEG_TEST) + .build(); + + public InferenceRestIT(final ClientYamlTestCandidate testCandidate) { + super(testCandidate); + } + + @Override + protected String getTestRestCluster() { + return cluster.getHttpAddresses(); + } + + @ParametersFactory + public static Iterable<Object[]> parameters() throws Exception { + return ESClientYamlSuiteTestCase.createParameters(); + } +} diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_field_mapping.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_field_mapping.yml new file mode 100644 index 0000000000000..532f25a556f6d --- /dev/null +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/10_semantic_text_field_mapping.yml @@ -0,0 +1,175 @@ +setup: + - skip: + version: " - 8.14.99" + reason: semantic_text introduced in 8.15.0 + + - do: + indices.create: + index: test-index + body: + mappings: + properties: + sparse_field: + type: semantic_text + inference_id: sparse-inference-id + dense_field: + type: semantic_text + inference_id: dense-inference-id + +--- +"Indexes sparse vector document": + + # Checks mapping is not updated until first doc arrives + - do: + indices.get_mapping: + index: test-index + + - match: { "test-index.mappings.properties.sparse_field.type": semantic_text } + - match: { "test-index.mappings.properties.sparse_field.inference_id": sparse-inference-id } + - length: { "test-index.mappings.properties.sparse_field": 2 } + + - do: + index: + index: test-index + id: doc_1 + body: + sparse_field: + text: "these are not the droids you're looking for.
He's free to go around" + inference: + inference_id: sparse-inference-id + model_settings: + task_type: sparse_embedding + chunks: + - text: "these are not the droids you're looking for" + embeddings: + feature_0: 1.0 + feature_1: 2.0 + feature_2: 3.0 + feature_3: 4.0 + - text: "He's free to go around" + embeddings: + feature_4: 0.1 + feature_5: 0.2 + feature_6: 0.3 + feature_7: 0.4 + + # Checks mapping is updated when first doc arrives + - do: + indices.get_mapping: + index: test-index + + - match: { "test-index.mappings.properties.sparse_field.type": semantic_text } + - match: { "test-index.mappings.properties.sparse_field.inference_id": sparse-inference-id } + - match: { "test-index.mappings.properties.sparse_field.model_settings.task_type": sparse_embedding } + - length: { "test-index.mappings.properties.sparse_field": 3 } + +--- +"Indexes dense vector document": + + # Checks mapping is not updated until first doc arrives + - do: + indices.get_mapping: + index: test-index + + - match: { "test-index.mappings.properties.dense_field.type": semantic_text } + - match: { "test-index.mappings.properties.dense_field.inference_id": dense-inference-id } + - length: { "test-index.mappings.properties.dense_field": 2 } + + - do: + index: + index: test-index + id: doc_2 + body: + dense_field: + text: "these are not the droids you're looking for. He's free to go around" + inference: + inference_id: dense-inference-id + model_settings: + task_type: text_embedding + dimensions: 4 + similarity: cosine + chunks: + - text: "these are not the droids you're looking for" + embeddings: [0.04673296958208084, -0.03237321600317955, -0.02543032355606556, 0.056035321205854416] + - text: "He's free to go around" + embeddings: [0.00641461368650198, -0.0016253676731139421, -0.05126338079571724, 0.053438711911439896] + + # Checks mapping is updated when first doc arrives + - do: + indices.get_mapping: + index: test-index + + - match: { "test-index.mappings.properties.dense_field.type": semantic_text } + - match: { "test-index.mappings.properties.dense_field.inference_id": dense-inference-id } + - match: { "test-index.mappings.properties.dense_field.model_settings.task_type": text_embedding } + - length: { "test-index.mappings.properties.dense_field": 3 } + +--- +"Can't be used as a multifield": + + - do: + catch: /Field \[semantic\] of type \[semantic_text\] can't be used in multifields/ + indices.create: + index: test-multi-index + body: + mappings: + properties: + text_field: + type: text + fields: + semantic: + type: semantic_text + inference_id: sparse-inference-id + +--- +"Can't have multifields": + + - do: + catch: /semantic_text field \[semantic\] does not support multi-fields/ + indices.create: + index: test-multi-index + body: + mappings: + properties: + semantic: + type: semantic_text + inference_id: sparse-inference-id + fields: + keyword_field: + type: keyword + +--- +"Can't configure copy_to in semantic_text": + + - do: + catch: /semantic_text field \[semantic\] does not support \[copy_to\]/ + indices.create: + index: test-copy_to-index + body: + mappings: + properties: + semantic: + type: semantic_text + inference_id: sparse-inference-id + copy_to: another_field + another_field: + type: keyword + +--- +"Can be used as a nested field": + + - do: + indices.create: + index: test-copy_to-index + body: + mappings: + properties: + nested: + type: nested + properties: + semantic: + type: semantic_text + inference_id: sparse-inference-id + another_field: + type: keyword + diff --git 
a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapping_incompatible_field_mapping.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapping_incompatible_field_mapping.yml new file mode 100644 index 0000000000000..3dc9081d121ab --- /dev/null +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/20_semantic_text_field_mapping_incompatible_field_mapping.yml @@ -0,0 +1,223 @@ +setup: + - skip: + version: " - 8.14.99" + reason: semantic_text introduced in 8.15.0 + + - do: + indices.create: + index: test-index + body: + mappings: + properties: + sparse_field: + type: semantic_text + inference_id: sparse-inference-id + dense_field: + type: semantic_text + inference_id: dense-inference-id + + # Indexes a doc with inference results to update mappings + - do: + index: + index: test-index + id: doc_1 + body: + sparse_field: + text: "these are not the droids you're looking for. He's free to go around" + inference: + inference_id: sparse-inference-id + model_settings: + task_type: sparse_embedding + chunks: + - text: "these are not the droids you're looking for" + embeddings: + feature_0: 1.0 + feature_1: 2.0 + feature_2: 3.0 + feature_3: 4.0 + - text: "He's free to go around" + embeddings: + feature_4: 0.1 + feature_5: 0.2 + feature_6: 0.3 + feature_7: 0.4 + dense_field: + text: "these are not the droids you're looking for. He's free to go around" + inference: + inference_id: dense-inference-id + model_settings: + task_type: text_embedding + dimensions: 4 + similarity: cosine + chunks: + - text: "these are not the droids you're looking for" + embeddings: [0.04673296958208084, -0.03237321600317955, -0.02543032355606556, 0.056035321205854416] + - text: "He's free to go around" + embeddings: [0.00641461368650198, -0.0016253676731139421, -0.05126338079571724, 0.053438711911439896] + + +--- +"Fails for non-compatible dimensions": + + - do: + catch: /Incompatible model settings for field \[dense_field\].+/ + index: + index: test-index + id: doc_2 + body: + dense_field: + text: "other text" + inference: + inference_id: dense-inference-id + model_settings: + task_type: text_embedding + dimensions: 5 + similarity: cosine + chunks: + - text: "other text" + embeddings: [0.04673296958208084, -0.03237321600317955, -0.02543032355606556, 0.056035321205854416, 0.053438711911439896] + +--- +"Fails for non-compatible inference id": + + - do: + catch: /The configured inference_id \[a-different-inference-id\] for field \[dense_field\] doesn't match the inference_id \[dense-inference-id\].+/ + index: + index: test-index + id: doc_2 + body: + dense_field: + text: "other text" + inference: + inference_id: a-different-inference-id + model_settings: + task_type: text_embedding + dimensions: 4 + similarity: cosine + chunks: + - text: "other text" + embeddings: [0.04673296958208084, -0.03237321600317955, -0.02543032355606556, 0.056035321205854416] + +--- +"Fails for non-compatible similarity": + + - do: + catch: /Incompatible model settings for field \[dense_field\].+/ + index: + index: test-index + id: doc_2 + body: + dense_field: + text: "other text" + inference: + inference_id: dense-inference-id + model_settings: + task_type: text_embedding + dimensions: 4 + similarity: dot_product + chunks: + - text: "other text" + embeddings: [0.04673296958208084, -0.03237321600317955, -0.02543032355606556, 0.056035321205854416] + +--- +"Fails for non-compatible task type for dense vectors": 
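+ # The setup document pinned dense_field's model settings to text_embedding, so a sparse_embedding payload must be rejected.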
+ + - do: + catch: /Incompatible model settings for field \[dense_field\].+/ + index: + index: test-index + id: doc_2 + body: + dense_field: + text: "other text" + inference: + inference_id: dense-inference-id + model_settings: + task_type: sparse_embedding + chunks: + - text: "these are not the droids you're looking for" + embeddings: + feature_0: 1.0 + feature_1: 2.0 + feature_2: 3.0 + feature_3: 4.0 + +--- +"Fails for non-compatible task type for sparse vectors": + + - do: + catch: /Incompatible model settings for field \[sparse_field\].+/ + index: + index: test-index + id: doc_2 + body: + sparse_field: + text: "these are not the droids you're looking for. He's free to go around" + inference: + inference_id: sparse-inference-id + model_settings: + task_type: text_embedding + dimensions: 4 + similarity: cosine + chunks: + - text: "these are not the droids you're looking for" + embeddings: [0.04673296958208084, -0.03237321600317955, -0.02543032355606556, 0.056035321205854416] + +--- +"Fails for missing dense vector inference results in chunks": + + - do: + catch: /failed to parse field \[dense_field\] of type \[semantic_text\]/ + index: + index: test-index + id: doc_2 + body: + dense_field: + text: "these are not the droids you're looking for. He's free to go around" + inference: + inference_id: dense-inference-id + model_settings: + task_type: text_embedding + dimensions: 4 + similarity: cosine + chunks: + - text: "these are not the droids you're looking for" + +--- +"Fails for missing sparse vector inference results in chunks": + + - do: + catch: /failed to parse field \[sparse_field\] of type \[semantic_text\]/ + index: + index: test-index + id: doc_2 + body: + sparse_field: + text: "these are not the droids you're looking for. He's free to go around" + inference: + inference_id: sparse-inference-id + model_settings: + task_type: sparse_embedding + chunks: + - text: "these are not the droids you're looking for" + +--- +"Fails for missing text in chunks": + + - do: + catch: /failed to parse field \[dense_field\] of type \[semantic_text\]/ + index: + index: test-index + id: doc_2 + body: + dense_field: + text: "these are not the droids you're looking for. He's free to go around" + inference: + inference_id: dense-inference-id + model_settings: + task_type: text_embedding + dimensions: 4 + similarity: cosine + chunks: + - embeddings: [ 0.04673296958208084, -0.03237321600317955, -0.02543032355606556, 0.056035321205854416 ] +
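# The chunk above carries embeddings but no text, so the document must be rejected.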