diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java index f77d72157..70ddc0d60 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java @@ -169,18 +169,18 @@ Map buildTextEmbeddingResult(final String knnKey, List mo private void validateEmbeddingFieldsValue(final IngestDocument ingestDocument) { Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); for (Map.Entry embeddingFieldsEntry : fieldMap.entrySet()) { - Object sourceValue = sourceAndMetadataMap.get(embeddingFieldsEntry.getKey()); + String mappedSourceKey = embeddingFieldsEntry.getValue(); + Object sourceValue = sourceAndMetadataMap.get(mappedSourceKey); if (Objects.isNull(sourceValue)) { continue; } - String sourceKey = embeddingFieldsEntry.getKey(); Class sourceValueClass = sourceValue.getClass(); if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) { - validateNestedTypeValue(sourceKey, sourceValue, () -> 1); + validateNestedTypeValue(mappedSourceKey, sourceValue, () -> 1); } else if (!String.class.isAssignableFrom(sourceValueClass)) { - throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, can not process it"); + throw new IllegalArgumentException("field [" + mappedSourceKey + "] is neither string nor nested type, can not process it"); } else if (StringUtils.isBlank(sourceValue.toString())) { - throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, can not process it"); + throw new IllegalArgumentException("field [" + mappedSourceKey + "] has empty string value, can not process it"); } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java 
b/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java index 97597691d..c0cab4422 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java @@ -89,6 +89,54 @@ public void testTextEmbeddingProcessConstructor_whenConfigMapEmpty_throwIllegalA } } + @SneakyThrows + public void testTextEmbeddingProcessConstructor_whenTypeMappingIsNullOrInvalid_throwIllegalArgumentException() { + boolean ignoreFailure = false; + String modelId = "mockModelId"; + String embeddingField = "my_embedding_field"; + + // create with null type mapping + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> new TextImageEmbeddingProcessor(PROCESSOR_TAG, DESCRIPTION, modelId, embeddingField, null, mlCommonsClientAccessor, env) + ); + assertEquals("Unable to create the TextImageEmbedding processor as field_map has invalid key or value", exception.getMessage()); + + // type mapping has empty key + exception = expectThrows( + IllegalArgumentException.class, + () -> new TextImageEmbeddingProcessor( + PROCESSOR_TAG, + DESCRIPTION, + modelId, + embeddingField, + Map.of("", "my_field"), + mlCommonsClientAccessor, + env + ) + ); + assertEquals("Unable to create the TextImageEmbedding processor as field_map has invalid key or value", exception.getMessage()); + + // type mapping has empty value + // use vanilla Java syntax because it allows null values + Map typeMapping = new HashMap<>(); + typeMapping.put("my_field", null); + + exception = expectThrows( + IllegalArgumentException.class, + () -> new TextImageEmbeddingProcessor( + PROCESSOR_TAG, + DESCRIPTION, + modelId, + embeddingField, + typeMapping, + mlCommonsClientAccessor, + env + ) + ); + assertEquals("Unable to create the TextImageEmbedding processor as field_map has invalid key or value", exception.getMessage()); + } + @SneakyThrows 
public void testTextEmbeddingProcessConstructor_whenEmptyModelId_throwIllegalArgumentException() { Map registry = new HashMap<>(); @@ -111,6 +159,7 @@ public void testExecute_successful() { sourceAndMetadata.put("key1", "value1"); sourceAndMetadata.put("my_text_field", "value2"); sourceAndMetadata.put("key3", "value3"); + sourceAndMetadata.put("image_field", "base64_of_image_1234567890"); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); TextImageEmbeddingProcessor processor = createInstance(); @@ -151,8 +200,6 @@ public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeExcep } public void testExecute_withListTypeInput_successful() { - List list1 = ImmutableList.of("test1", "test2", "test3"); - List list2 = ImmutableList.of("test4", "test5", "test6"); Map sourceAndMetadata = new HashMap<>(); sourceAndMetadata.put("my_text_field", "value1"); sourceAndMetadata.put("another_text_field", "value2"); @@ -238,6 +285,25 @@ public void testExecute_hybridTypeInput_successful() throws Exception { assert document.getSourceAndMetadata().containsKey("key2"); } + public void testExecute_whenInferencesAreEmpty_thenSuccessful() { + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("my_field", "value1"); + sourceAndMetadata.put("another_text_field", "value2"); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + TextImageEmbeddingProcessor processor = createInstance(); + + List> modelTensorList = createMockVectorResult(); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(modelTensorList); + return null; + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyMap(), isA(ActionListener.class)); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(any(IngestDocument.class), isNull()); + } + private List> createMockVectorResult() { List> 
modelTensorList = new ArrayList<>(); List number1 = ImmutableList.of(1.234f, 2.354f);