Skip to content

Commit

Permalink
Enhance syntax for nested mapping in destination fields (#841)
Browse files Browse the repository at this point in the history
* Enhance syntax for nested mapping in destination fields

Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski authored Jul 22, 2024
1 parent d96f7d1 commit 770a8ca
Show file tree
Hide file tree
Showing 8 changed files with 242 additions and 49 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- InferenceProcessor inherits from AbstractBatchingProcessor to support sub batching in processor [#820](https://github.com/opensearch-project/neural-search/pull/820)
- Adds dynamic knn query parameters efsearch and nprobes [#814](https://github.com/opensearch-project/neural-search/pull/814/)
- Enable '.' for nested field in text embedding processor ([#811](https://github.com/opensearch-project/neural-search/pull/811))
- Enhance syntax for nested mapping in destination fields ([#841](https://github.com/opensearch-project/neural-search/pull/841))
### Bug Fixes
- Fix function names and comments in the gradle file for BWC tests ([#795](https://github.com/opensearch-project/neural-search/pull/795/files))
- Fix for missing HybridQuery results when concurrent segment search is enabled ([#800](https://github.com/opensearch-project/neural-search/pull/800))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
Expand All @@ -29,6 +30,7 @@
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.env.Environment;
import org.opensearch.index.mapper.IndexFieldMapper;

import org.opensearch.ingest.AbstractBatchingProcessor;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.ingest.IngestDocumentWrapper;
Expand All @@ -50,6 +52,17 @@ public abstract class InferenceProcessor extends AbstractBatchingProcessor {

public static final String MODEL_ID_FIELD = "model_id";
public static final String FIELD_MAP_FIELD = "field_map";
/**
 * Remapping function for {@link Map#merge}: decides what to store when a key already holds a value.
 * <ul>
 *   <li>both values are collections: the new elements are appended to the existing collection;</li>
 *   <li>both values are maps: the new entries are copied into the existing map with {@code putAll}
 *       (shallow merge, same-key entries are overwritten);</li>
 *   <li>any other combination: the new value replaces the existing one.</li>
 * </ul>
 * In the first two cases the existing value is mutated in place and returned.
 */
private static final BiFunction<Object, Object, Object> REMAPPING_FUNCTION = (existing, incoming) -> {
    if (existing instanceof Collection && incoming instanceof Collection) {
        ((Collection) existing).addAll((Collection) incoming);
        return existing;
    }
    if (existing instanceof Map && incoming instanceof Map) {
        ((Map) existing).putAll((Map) incoming);
        return existing;
    }
    return incoming;
};

private final String type;

Expand Down Expand Up @@ -325,17 +338,7 @@ void buildNestedMap(String parentKey, Object processorKey, Map<String, Object> s
buildNestedMap(nestedFieldMapEntry.getKey(), nestedFieldMapEntry.getValue(), map, next);
}
}
treeRes.merge(parentKey, next, (v1, v2) -> {
if (v1 instanceof Collection && v2 instanceof Collection) {
((Collection) v1).addAll((Collection) v2);
return v1;
} else if (v1 instanceof Map && v2 instanceof Map) {
((Map) v1).putAll((Map) v2);
return v1;
} else {
return v2;
}
});
treeRes.merge(parentKey, next, REMAPPING_FUNCTION);
} else {
String key = String.valueOf(processorKey);
treeRes.put(key, sourceAndMetadataMap.get(parentKey));
Expand Down Expand Up @@ -389,8 +392,9 @@ Map<String, Object> buildNLPResult(Map<String, Object> processorMap, List<?> res
IndexWrapper indexWrapper = new IndexWrapper(0);
Map<String, Object> result = new LinkedHashMap<>();
for (Map.Entry<String, Object> knnMapEntry : processorMap.entrySet()) {
String knnKey = knnMapEntry.getKey();
Object sourceValue = knnMapEntry.getValue();
Pair<String, Object> processedNestedKey = processNestedKey(knnMapEntry);
String knnKey = processedNestedKey.getKey();
Object sourceValue = processedNestedKey.getValue();
if (sourceValue instanceof String) {
result.put(knnKey, results.get(indexWrapper.index++));
} else if (sourceValue instanceof List) {
Expand Down Expand Up @@ -419,19 +423,31 @@ private void putNLPResultToSourceMapForMapType(
nestedElement.put(inputNestedMapEntry.getKey(), results.get(indexWrapper.index++));
}
} else {
Pair<String, Object> processedNestedKey = processNestedKey(inputNestedMapEntry);
Map<String, Object> sourceMap;
if (sourceAndMetadataMap.get(processorKey) == null) {
sourceMap = new HashMap<>();
sourceAndMetadataMap.put(processorKey, sourceMap);
} else {
sourceMap = (Map<String, Object>) sourceAndMetadataMap.get(processorKey);
}
putNLPResultToSourceMapForMapType(
inputNestedMapEntry.getKey(),
inputNestedMapEntry.getValue(),
processedNestedKey.getKey(),
processedNestedKey.getValue(),
results,
indexWrapper,
(Map<String, Object>) sourceAndMetadataMap.get(processorKey)
sourceMap
);
}
}
} else if (sourceValue instanceof String) {
sourceAndMetadataMap.put(processorKey, results.get(indexWrapper.index++));
sourceAndMetadataMap.merge(processorKey, results.get(indexWrapper.index++), REMAPPING_FUNCTION);
} else if (sourceValue instanceof List) {
sourceAndMetadataMap.put(processorKey, buildNLPResultForListType((List<String>) sourceValue, results, indexWrapper));
sourceAndMetadataMap.merge(
processorKey,
buildNLPResultForListType((List<String>) sourceValue, results, indexWrapper),
REMAPPING_FUNCTION
);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

import org.apache.commons.lang3.StringUtils;
Expand Down Expand Up @@ -41,10 +42,15 @@ public class TextEmbeddingProcessorIT extends BaseNeuralSearchIT {
// Field names that make up the nested document structure checked by the nested-field mapping test.
protected static final String LEVEL_1_FIELD = "nested_passages";
protected static final String LEVEL_2_FIELD = "level_2";
protected static final String LEVEL_3_FIELD_TEXT = "level_3_text";
// Object under level 2 that holds the generated embedding (and, optionally, an extra text field).
protected static final String LEVEL_3_FIELD_CONTAINER = "level_3_container";
protected static final String LEVEL_3_FIELD_EMBEDDING = "level_3_embedding";
// Text values the nested-field mapping test expects to find in the ingested documents.
protected static final String TEXT_FIELD_VALUE_1 = "hello";
protected static final String TEXT_FIELD_VALUE_2 = "clown";
protected static final String TEXT_FIELD_VALUE_3 = "abc";
// Document bodies and the bulk item template, loaded from classpath resources under "processor/".
private final String INGEST_DOC1 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc1.json").toURI()));
private final String INGEST_DOC2 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc2.json").toURI()));
private final String INGEST_DOC3 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc3.json").toURI()));
private final String INGEST_DOC4 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc4.json").toURI()));
private final String BULK_ITEM_TEMPLATE = Files.readString(
Path.of(classLoader.getResource("processor/bulk_item_template.json").toURI())
);
Expand Down Expand Up @@ -99,23 +105,17 @@ public void testNestedFieldMapping_whenDocumentsIngested_thenSuccessful() throws
createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING_WITH_NESTED_FIELDS_MAPPING);
createTextEmbeddingIndex();
ingestDocument(INGEST_DOC3, "3");
ingestDocument(INGEST_DOC4, "4");

Map<String, Object> sourceMap = (Map<String, Object>) getDocById(INDEX_NAME, "3").get("_source");
assertNotNull(sourceMap);
assertTrue(sourceMap.containsKey(LEVEL_1_FIELD));
Map<String, Object> nestedPassages = (Map<String, Object>) sourceMap.get(LEVEL_1_FIELD);
assertTrue(nestedPassages.containsKey(LEVEL_2_FIELD));
Map<String, Object> level2 = (Map<String, Object>) nestedPassages.get(LEVEL_2_FIELD);
assertEquals(QUERY_TEXT, level2.get(LEVEL_3_FIELD_TEXT));
assertTrue(level2.containsKey(LEVEL_3_FIELD_EMBEDDING));
List<Double> embeddings = (List<Double>) level2.get(LEVEL_3_FIELD_EMBEDDING);
assertEquals(768, embeddings.size());
for (Double embedding : embeddings) {
assertTrue(embedding >= 0.0 && embedding <= 1.0);
}
assertDoc(
(Map<String, Object>) getDocById(INDEX_NAME, "3").get("_source"),
TEXT_FIELD_VALUE_1,
Optional.of(TEXT_FIELD_VALUE_3)
);
assertDoc((Map<String, Object>) getDocById(INDEX_NAME, "4").get("_source"), TEXT_FIELD_VALUE_2, Optional.empty());

NeuralQueryBuilder neuralQueryBuilderQuery = new NeuralQueryBuilder(
LEVEL_1_FIELD + "." + LEVEL_2_FIELD + "." + LEVEL_3_FIELD_EMBEDDING,
LEVEL_1_FIELD + "." + LEVEL_2_FIELD + "." + LEVEL_3_FIELD_CONTAINER + "." + LEVEL_3_FIELD_EMBEDDING,
QUERY_TEXT,
"",
modelId,
Expand All @@ -133,7 +133,7 @@ public void testNestedFieldMapping_whenDocumentsIngested_thenSuccessful() throws
);
QueryBuilder queryNestedHighLevel = QueryBuilders.nestedQuery(LEVEL_1_FIELD, queryNestedLowerLevel, ScoreMode.Total);

Map<String, Object> searchResponseAsMap = search(INDEX_NAME, queryNestedHighLevel, 1);
Map<String, Object> searchResponseAsMap = search(INDEX_NAME, queryNestedHighLevel, 2);
assertNotNull(searchResponseAsMap);

Map<String, Object> hits = (Map<String, Object>) searchResponseAsMap.get("hits");
Expand All @@ -142,15 +142,38 @@ public void testNestedFieldMapping_whenDocumentsIngested_thenSuccessful() throws
assertEquals(1.0, hits.get("max_score"));
List<Map<String, Object>> listOfHits = (List<Map<String, Object>>) hits.get("hits");
assertNotNull(listOfHits);
assertEquals(1, listOfHits.size());
Map<String, Object> hitsInner = listOfHits.get(0);
assertEquals("3", hitsInner.get("_id"));
assertEquals(1.0, hitsInner.get("_score"));
assertEquals(2, listOfHits.size());

Map<String, Object> innerHitDetails = listOfHits.get(0);
assertEquals("3", innerHitDetails.get("_id"));
assertEquals(1.0, innerHitDetails.get("_score"));

innerHitDetails = listOfHits.get(1);
assertEquals("4", innerHitDetails.get("_id"));
assertTrue((double) innerHitDetails.get("_score") <= 1.0);
} finally {
wipeOfTestResources(INDEX_NAME, PIPELINE_NAME, modelId, null);
}
}

/**
 * Asserts that an ingested document's source has the expected nested layout:
 * {@code LEVEL_1_FIELD -> LEVEL_2_FIELD -> { LEVEL_3_FIELD_TEXT, LEVEL_3_FIELD_CONTAINER -> LEVEL_3_FIELD_EMBEDDING }}.
 * The embedding must have 768 elements, each within [0.0, 1.0].
 * Missing intermediate levels are reported as assertion failures rather than NullPointerExceptions.
 *
 * @param sourceMap           the document's {@code _source} map
 * @param textFieldValue      expected value of {@code LEVEL_3_FIELD_TEXT} under level 2
 * @param level3ExpectedValue when present, the expected value of {@code "level_4_text_field"} inside the
 *                            level 3 container; when empty, that field is not checked
 */
private void assertDoc(Map<String, Object> sourceMap, String textFieldValue, Optional<String> level3ExpectedValue) {
    assertNotNull(sourceMap);
    assertTrue(sourceMap.containsKey(LEVEL_1_FIELD));
    Map<String, Object> nestedPassages = (Map<String, Object>) sourceMap.get(LEVEL_1_FIELD);
    assertNotNull(nestedPassages);
    assertTrue(nestedPassages.containsKey(LEVEL_2_FIELD));
    Map<String, Object> level2 = (Map<String, Object>) nestedPassages.get(LEVEL_2_FIELD);
    assertNotNull(level2);
    assertEquals(textFieldValue, level2.get(LEVEL_3_FIELD_TEXT));

    // Fail with an assertion error, not an NPE, when the processor did not create the container or the embedding.
    Map<String, Object> level3 = (Map<String, Object>) level2.get(LEVEL_3_FIELD_CONTAINER);
    assertNotNull(level3);
    List<Double> embeddings = (List<Double>) level3.get(LEVEL_3_FIELD_EMBEDDING);
    assertNotNull(embeddings);
    assertEquals(768, embeddings.size());
    for (Double embedding : embeddings) {
        assertTrue(embedding >= 0.0 && embedding <= 1.0);
    }

    level3ExpectedValue.ifPresent(expected -> assertEquals(expected, level3.get("level_4_text_field")));
}

public void testTextEmbeddingProcessor_withBatchSizeInProcessor() throws Exception {
String modelId = null;
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ public class TextEmbeddingProcessorTests extends InferenceProcessorTestCase {
// Name of the second-level child object in the nested test documents.
protected static final String CHILD_FIELD_LEVEL_2 = "child_level2";
// NOTE: despite the "_VALUE" suffix, the tests use this as a field name (map key) inside the level-2 child.
protected static final String CHILD_LEVEL_2_TEXT_FIELD_VALUE = "text_field_value";
// Name of the embedding (knn) field written into the level-2 child.
protected static final String CHILD_LEVEL_2_KNN_FIELD = "test3_knn";
// Name of the text field under the level-1 child that is used as the embedding source.
protected static final String CHILD_1_TEXT_FIELD = "child_1_text_field";
// Text stored in CHILD_1_TEXT_FIELD.
protected static final String TEXT_VALUE_1 = "text_value";
// NOTE: despite the name, the tests use this as a text value (stored under CHILD_LEVEL_2_TEXT_FIELD_VALUE).
protected static final String TEXT_FIELD_2 = "abc";
@Mock
private MLCommonsClientAccessor mlCommonsClientAccessor;

Expand Down Expand Up @@ -363,6 +366,126 @@ public void testNestedFieldInMapping_withMapTypeInput_successful() {
}
}

/**
 * Source and destination of the field map both use dot notation, and the destination object
 * already exists in the ingested document with another field in it. After processing, the
 * embedding must sit next to that pre-existing field, which must be left intact.
 */
@SneakyThrows
public void testNestedFieldInMappingForSourceAndDestination_withIngestDocumentHasTheDestinationStructure_theSuccessful() {
    /*
    Document under test (keys and values shown by constant name):
      PARENT_FIELD:
        CHILD_FIELD_LEVEL_1:
          CHILD_1_TEXT_FIELD: TEXT_VALUE_1
          CHILD_FIELD_LEVEL_2:
            CHILD_LEVEL_2_TEXT_FIELD_VALUE: TEXT_FIELD_2
    */
    Map<String, String> existingDestinationContent = new HashMap<>();
    existingDestinationContent.put(CHILD_LEVEL_2_TEXT_FIELD_VALUE, TEXT_FIELD_2);

    Map<String, Object> childLevel1Content = new HashMap<>();
    childLevel1Content.put(CHILD_FIELD_LEVEL_2, existingDestinationContent);
    childLevel1Content.put(CHILD_1_TEXT_FIELD, TEXT_VALUE_1);

    Map<String, Object> parentContent = new HashMap<>();
    parentContent.put(CHILD_FIELD_LEVEL_1, childLevel1Content);

    Map<String, Object> sourceAndMetadata = new HashMap<>();
    sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index");
    sourceAndMetadata.put(PARENT_FIELD, parentContent);
    IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());

    // Field map: "<parent>.<child level 1>.<text field>" -> "<child level 2>.<knn field>".
    String sourcePath = String.join(".", PARENT_FIELD, CHILD_FIELD_LEVEL_1, CHILD_1_TEXT_FIELD);
    String destinationPath = CHILD_FIELD_LEVEL_2 + "." + CHILD_LEVEL_2_KNN_FIELD;

    Map<String, Processor.Factory> registry = new HashMap<>();
    Map<String, Object> config = new HashMap<>();
    config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId");
    config.put(TextEmbeddingProcessor.FIELD_MAP_FIELD, ImmutableMap.of(sourcePath, destinationPath));
    TextEmbeddingProcessor processor = (TextEmbeddingProcessor) textEmbeddingProcessorFactory.create(
        registry,
        PROCESSOR_TAG,
        DESCRIPTION,
        config
    );

    // The mocked ML client answers every inference request with one random 100-dimensional vector.
    List<List<Float>> modelTensorList = createRandomOneDimensionalMockVector(1, 100, 0.0f, 1.0f);
    doAnswer(invocation -> {
        ActionListener<List<List<Float>>> listener = invocation.getArgument(2);
        listener.onResponse(modelTensorList);
        return null;
    }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class));

    processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {});

    assertNotNull(ingestDocument);
    Object parentAfterProcessor = ingestDocument.getSourceAndMetadata().get(PARENT_FIELD);
    assertNotNull(parentAfterProcessor);
    Map<String, Object> parentActual = (Map<String, Object>) parentAfterProcessor;

    // Level 1 still holds the source text plus the level-2 object.
    Map<String, Object> childLevel1Actual = (Map<String, Object>) parentActual.get(CHILD_FIELD_LEVEL_1);
    assertEquals(2, childLevel1Actual.size());
    assertEquals(TEXT_VALUE_1, childLevel1Actual.get(CHILD_1_TEXT_FIELD));

    // Level 2 keeps its original field and gains the embedding.
    Map<String, Object> childLevel2Actual = (Map<String, Object>) childLevel1Actual.get(CHILD_FIELD_LEVEL_2);
    assertEquals(2, childLevel2Actual.size());
    assertEquals(TEXT_FIELD_2, childLevel2Actual.get(CHILD_LEVEL_2_TEXT_FIELD_VALUE));

    List<Float> embedding = (List<Float>) childLevel2Actual.get(CHILD_LEVEL_2_KNN_FIELD);
    assertEquals(100, embedding.size());
    for (Float component : embedding) {
        assertTrue(component >= 0.0f && component <= 1.0f);
    }
}

/**
 * Source and destination of the field map both use dot notation, but the destination object
 * does not exist in the ingested document. After processing, the destination object must have
 * been created and must contain only the embedding.
 */
@SneakyThrows
public void testNestedFieldInMappingForSourceAndDestination_withIngestDocumentWithoutDestinationStructure_theSuccessful() {
    /*
    Document under test (keys and values shown by constant name):
      PARENT_FIELD:
        CHILD_FIELD_LEVEL_1:
          CHILD_1_TEXT_FIELD: TEXT_VALUE_1
    */
    Map<String, Object> childLevel1Content = new HashMap<>();
    childLevel1Content.put(CHILD_1_TEXT_FIELD, TEXT_VALUE_1);

    Map<String, Object> parentContent = new HashMap<>();
    parentContent.put(CHILD_FIELD_LEVEL_1, childLevel1Content);

    Map<String, Object> sourceAndMetadata = new HashMap<>();
    sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index");
    sourceAndMetadata.put(PARENT_FIELD, parentContent);
    IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());

    // Field map: "<parent>.<child level 1>.<text field>" -> "<child level 2>.<knn field>".
    String sourcePath = String.join(".", PARENT_FIELD, CHILD_FIELD_LEVEL_1, CHILD_1_TEXT_FIELD);
    String destinationPath = CHILD_FIELD_LEVEL_2 + "." + CHILD_LEVEL_2_KNN_FIELD;

    Map<String, Processor.Factory> registry = new HashMap<>();
    Map<String, Object> config = new HashMap<>();
    config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId");
    config.put(TextEmbeddingProcessor.FIELD_MAP_FIELD, ImmutableMap.of(sourcePath, destinationPath));
    TextEmbeddingProcessor processor = (TextEmbeddingProcessor) textEmbeddingProcessorFactory.create(
        registry,
        PROCESSOR_TAG,
        DESCRIPTION,
        config
    );

    // The mocked ML client answers every inference request with one random 100-dimensional vector.
    List<List<Float>> modelTensorList = createRandomOneDimensionalMockVector(1, 100, 0.0f, 1.0f);
    doAnswer(invocation -> {
        ActionListener<List<List<Float>>> listener = invocation.getArgument(2);
        listener.onResponse(modelTensorList);
        return null;
    }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class));

    processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {});

    assertNotNull(ingestDocument);
    Object parentAfterProcessor = ingestDocument.getSourceAndMetadata().get(PARENT_FIELD);
    assertNotNull(parentAfterProcessor);
    Map<String, Object> parentActual = (Map<String, Object>) parentAfterProcessor;

    // Level 1 holds the source text plus the newly created level-2 object.
    Map<String, Object> childLevel1Actual = (Map<String, Object>) parentActual.get(CHILD_FIELD_LEVEL_1);
    assertEquals(2, childLevel1Actual.size());
    assertEquals(TEXT_VALUE_1, childLevel1Actual.get(CHILD_1_TEXT_FIELD));

    // The created level-2 object contains nothing but the embedding.
    Map<String, Object> childLevel2Actual = (Map<String, Object>) childLevel1Actual.get(CHILD_FIELD_LEVEL_2);
    assertEquals(1, childLevel2Actual.size());

    List<Float> embedding = (List<Float>) childLevel2Actual.get(CHILD_LEVEL_2_KNN_FIELD);
    assertEquals(100, embedding.size());
    for (Float component : embedding) {
        assertTrue(component >= 0.0f && component <= 1.0f);
    }
}

@SneakyThrows
public void testNestedFieldInMappingMixedSyntax_withMapTypeInput_successful() {
Map<String, Object> sourceAndMetadata = new HashMap<>();
Expand Down
Loading

0 comments on commit 770a8ca

Please sign in to comment.