Skip to content

Commit

Permalink
Fixed bug when mapped field name retrieved incorrectly for getting in…
Browse files Browse the repository at this point in the history
…ferences

Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Oct 3, 2023
1 parent 86eb654 commit 9286598
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -169,18 +169,18 @@ Map<String, Object> buildTextEmbeddingResult(final String knnKey, List<Float> mo
private void validateEmbeddingFieldsValue(final IngestDocument ingestDocument) {
Map<String, Object> sourceAndMetadataMap = ingestDocument.getSourceAndMetadata();
for (Map.Entry<String, String> embeddingFieldsEntry : fieldMap.entrySet()) {
Object sourceValue = sourceAndMetadataMap.get(embeddingFieldsEntry.getKey());
String mappedSourceKey = embeddingFieldsEntry.getValue();
Object sourceValue = sourceAndMetadataMap.get(mappedSourceKey);
if (Objects.isNull(sourceValue)) {
continue;
}
String sourceKey = embeddingFieldsEntry.getKey();
Class<?> sourceValueClass = sourceValue.getClass();
if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) {
validateNestedTypeValue(sourceKey, sourceValue, () -> 1);
validateNestedTypeValue(mappedSourceKey, sourceValue, () -> 1);
} else if (!String.class.isAssignableFrom(sourceValueClass)) {
throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, can not process it");
throw new IllegalArgumentException("field [" + mappedSourceKey + "] is neither string nor nested type, can not process it");

Check warning on line 181 in src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java#L181

Added line #L181 was not covered by tests
} else if (StringUtils.isBlank(sourceValue.toString())) {
throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, can not process it");
throw new IllegalArgumentException("field [" + mappedSourceKey + "] has empty string value, can not process it");

Check warning on line 183 in src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java#L183

Added line #L183 was not covered by tests
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,54 @@ public void testTextEmbeddingProcessConstructor_whenConfigMapEmpty_throwIllegalA
}
}

/**
 * Verifies that the {@code TextImageEmbeddingProcessor} constructor throws an
 * {@link IllegalArgumentException} with the same message when the field map argument is
 * {@code null}, contains an empty key, or contains a {@code null} value.
 */
@SneakyThrows
public void testTextEmbeddingProcessConstructor_whenTypeMappingIsNullOrInvalid_throwIllegalArgumentException() {
    String modelId = "mockModelId";
    String embeddingField = "my_embedding_field";
    // every invalid field map scenario below is expected to report this exact message
    String expectedMessage = "Unable to create the TextImageEmbedding processor as field_map has invalid key or value";

    // create with null type mapping
    IllegalArgumentException exception = expectThrows(
        IllegalArgumentException.class,
        () -> new TextImageEmbeddingProcessor(PROCESSOR_TAG, DESCRIPTION, modelId, embeddingField, null, mlCommonsClientAccessor, env)
    );
    assertEquals(expectedMessage, exception.getMessage());

    // type mapping has empty key
    exception = expectThrows(
        IllegalArgumentException.class,
        () -> new TextImageEmbeddingProcessor(
            PROCESSOR_TAG,
            DESCRIPTION,
            modelId,
            embeddingField,
            Map.of("", "my_field"),
            mlCommonsClientAccessor,
            env
        )
    );
    assertEquals(expectedMessage, exception.getMessage());

    // type mapping has null value
    // use vanilla Java syntax (HashMap) because, unlike Map.of, it allows null values
    Map<String, String> typeMapping = new HashMap<>();
    typeMapping.put("my_field", null);

    exception = expectThrows(
        IllegalArgumentException.class,
        () -> new TextImageEmbeddingProcessor(
            PROCESSOR_TAG,
            DESCRIPTION,
            modelId,
            embeddingField,
            typeMapping,
            mlCommonsClientAccessor,
            env
        )
    );
    assertEquals(expectedMessage, exception.getMessage());
}

@SneakyThrows
public void testTextEmbeddingProcessConstructor_whenEmptyModelId_throwIllegalArgumentException() {
Map<String, Processor.Factory> registry = new HashMap<>();
Expand All @@ -111,6 +159,7 @@ public void testExecute_successful() {
sourceAndMetadata.put("key1", "value1");
sourceAndMetadata.put("my_text_field", "value2");
sourceAndMetadata.put("key3", "value3");
sourceAndMetadata.put("image_field", "base64_of_image_1234567890");
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
TextImageEmbeddingProcessor processor = createInstance();

Expand Down Expand Up @@ -151,8 +200,6 @@ public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeExcep
}

public void testExecute_withListTypeInput_successful() {
List<String> list1 = ImmutableList.of("test1", "test2", "test3");
List<String> list2 = ImmutableList.of("test4", "test5", "test6");
Map<String, Object> sourceAndMetadata = new HashMap<>();
sourceAndMetadata.put("my_text_field", "value1");
sourceAndMetadata.put("another_text_field", "value2");
Expand Down Expand Up @@ -238,6 +285,25 @@ public void testExecute_hybridTypeInput_successful() throws Exception {
assert document.getSourceAndMetadata().containsKey("key2");
}

/**
 * Executes the processor against a document holding only {@code my_field} and
 * {@code another_text_field}, and checks that the completion handler is invoked with an
 * ingest document and a {@code null} second argument.
 *
 * NOTE(review): presumably neither key is part of the field map built by createInstance(),
 * so no inference input is produced - confirm against createInstance().
 */
public void testExecute_whenInferencesAreEmpty_thenSuccessful() {
    Map<String, Object> documentFields = new HashMap<>();
    documentFields.put("my_field", "value1");
    documentFields.put("another_text_field", "value2");
    IngestDocument document = new IngestDocument(documentFields, new HashMap<>());
    TextImageEmbeddingProcessor embeddingProcessor = createInstance();

    // stub the ML client so that any inference request answers its listener with mock vectors
    List<List<Float>> mockVectors = createMockVectorResult();
    doAnswer(invocation -> {
        ActionListener<List<List<Float>>> responseListener = invocation.getArgument(2);
        responseListener.onResponse(mockVectors);
        return null;
    }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyMap(), isA(ActionListener.class));

    BiConsumer completionHandler = mock(BiConsumer.class);
    embeddingProcessor.execute(document, completionHandler);

    // the handler must be called with a document and no exception
    verify(completionHandler).accept(any(IngestDocument.class), isNull());
}

private List<List<Float>> createMockVectorResult() {
List<List<Float>> modelTensorList = new ArrayList<>();
List<Float> number1 = ImmutableList.of(1.234f, 2.354f);
Expand Down

0 comments on commit 9286598

Please sign in to comment.