Skip to content

Commit

Permalink
Add field max depth limit to prevent malicious attack
Browse files Browse the repository at this point in the history
Signed-off-by: Zan Niu <[email protected]>
  • Loading branch information
zane-neo committed Oct 24, 2022
1 parent c083a7c commit d168e95
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ExecutionException;
import java.util.function.Supplier;
import java.util.stream.IntStream;

import lombok.extern.log4j.Log4j2;

import org.apache.commons.lang3.StringUtils;
import org.opensearch.env.Environment;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.ingest.AbstractProcessor;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
Expand Down Expand Up @@ -43,12 +46,15 @@ public class TextEmbeddingProcessor extends AbstractProcessor {

private final MLCommonsClientAccessor mlCommonsClientAccessor;

private final Environment environment;

public TextEmbeddingProcessor(
String tag,
String description,
String modelId,
Map<String, Object> fieldMap,
MLCommonsClientAccessor clientAccessor
MLCommonsClientAccessor clientAccessor,
Environment environment
) {
super(tag, description);
if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, can not process it");
Expand All @@ -57,6 +63,7 @@ public TextEmbeddingProcessor(
this.modelId = modelId;
this.fieldMap = fieldMap;
this.mlCommonsClientAccessor = clientAccessor;
this.environment = environment;
}

private void validateEmbeddingConfiguration(Map<String, Object> fieldMap) {
Expand Down Expand Up @@ -236,7 +243,7 @@ private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) {
String sourceKey = embeddingFieldsEntry.getKey();
Class<?> sourceValueClass = sourceValue.getClass();
if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) {
validateNestedTypeValue(sourceKey, sourceValue);
validateNestedTypeValue(sourceKey, sourceValue, () -> 1);
} else if (!String.class.isAssignableFrom(sourceValueClass)) {
throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, can not process it");
} else if (StringUtils.isBlank(sourceValue.toString())) {
Expand All @@ -247,11 +254,17 @@ private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) {
}

@SuppressWarnings({ "rawtypes", "unchecked" })
private static void validateNestedTypeValue(String sourceKey, Object sourceValue) {
if ((List.class.isAssignableFrom(sourceValue.getClass()))) {
private void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier<Integer> maxDepthSupplier) {
int maxDepth = maxDepthSupplier.get();
if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) {
throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, can not process it");
} else if ((List.class.isAssignableFrom(sourceValue.getClass()))) {
validateListTypeValue(sourceKey, sourceValue);
} else if (Map.class.isAssignableFrom(sourceValue.getClass())) {
((Map) sourceValue).values().stream().filter(Objects::nonNull).forEach(x -> validateNestedTypeValue(sourceKey, x));
((Map) sourceValue).values()
.stream()
.filter(Objects::nonNull)
.forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1));
} else if (!String.class.isAssignableFrom(sourceValue.getClass())) {
throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, can not process it");
} else if (StringUtils.isBlank(sourceValue.toString())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import java.util.Map;

import org.opensearch.env.Environment;
import org.opensearch.ingest.Processor;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
Expand All @@ -19,8 +20,11 @@ public class TextEmbeddingProcessorFactory implements Processor.Factory {

private final MLCommonsClientAccessor clientAccessor;

public TextEmbeddingProcessorFactory(MLCommonsClientAccessor clientAccessor) {
private final Environment environment;

public TextEmbeddingProcessorFactory(MLCommonsClientAccessor clientAccessor, Environment environment) {
this.clientAccessor = clientAccessor;
this.environment = environment;
}

@Override
Expand All @@ -32,6 +36,6 @@ public TextEmbeddingProcessor create(
) throws Exception {
String modelId = readStringProperty(TYPE, processorTag, config, MODEL_ID_FIELD);
Map<String, Object> filedMap = readMap(TYPE, processorTag, config, FIELD_MAP_FIELD);
return new TextEmbeddingProcessor(processorTag, description, modelId, filedMap, clientAccessor);
return new TextEmbeddingProcessor(processorTag, description, modelId, filedMap, clientAccessor, environment);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,15 @@
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;

import org.junit.Before;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.OpenSearchParseException;
import org.opensearch.common.settings.Settings;
import org.opensearch.env.Environment;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.ingest.Processor;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
Expand All @@ -28,18 +35,31 @@

public class TextEmbeddingProcessorTests extends OpenSearchTestCase {

private static final MLCommonsClientAccessor ML_COMMONS_CLIENT_ACCESSOR = mock(MLCommonsClientAccessor.class);
private static final TextEmbeddingProcessorFactory FACTORY = new TextEmbeddingProcessorFactory(ML_COMMONS_CLIENT_ACCESSOR);
@Mock
private MLCommonsClientAccessor mlCommonsClientAccessor;

@Mock
private Environment env;

@InjectMocks
private TextEmbeddingProcessorFactory textEmbeddingProcessorFactory;
private static final String PROCESSOR_TAG = "mockTag";
private static final String DESCRIPTION = "mockDescription";

/**
 * Prepares the mocks used by every test: initializes the Mockito-annotated fields of this
 * class and stubs {@code env.settings()} to return settings whose
 * {@code index.mapping.depth.limit} is 20.
 */
@Before
public void setup() {
    MockitoAnnotations.openMocks(this);
    // Depth limit that the processor reads back through the mocked environment.
    final Settings depthLimitedSettings = Settings.builder().put("index.mapping.depth.limit", 20).build();
    when(env.settings()).thenReturn(depthLimitedSettings);
}

private TextEmbeddingProcessor createInstance(List<List<Float>> vector) throws Exception {
Map<String, Processor.Factory> registry = new HashMap<>();
Map<String, Object> config = new HashMap<>();
config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId");
config.put(TextEmbeddingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", "key1Mapped", "key2", "key2Mapped"));
TextEmbeddingProcessor processor = FACTORY.create(registry, PROCESSOR_TAG, DESCRIPTION, config);
when(ML_COMMONS_CLIENT_ACCESSOR.inferenceSentences(anyString(), anyList())).thenReturn(vector);
TextEmbeddingProcessor processor = textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config);
when(mlCommonsClientAccessor.inferenceSentences(anyString(), anyList())).thenReturn(vector);
return processor;
}

Expand All @@ -52,7 +72,7 @@ public void testTextEmbeddingProcessConstructor_whenConfigMapError_throwIllegalA
fieldMap.put("key2", "key2Mapped");
config.put(TextEmbeddingProcessor.FIELD_MAP_FIELD, fieldMap);
try {
FACTORY.create(registry, PROCESSOR_TAG, DESCRIPTION, config);
textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config);
} catch (IllegalArgumentException e) {
assertEquals("Unable to create the TextEmbedding processor as field_map has invalid key or value", e.getMessage());
}
Expand All @@ -63,7 +83,7 @@ public void testTextEmbeddingProcessConstructor_whenConfigMapEmpty_throwIllegalA
Map<String, Object> config = new HashMap<>();
config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId");
try {
FACTORY.create(registry, PROCESSOR_TAG, DESCRIPTION, config);
textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config);
} catch (OpenSearchParseException e) {
assertEquals("[field_map] required property is missing", e.getMessage());
}
Expand All @@ -86,7 +106,7 @@ public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeExcep
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
Map<String, Processor.Factory> registry = new HashMap<>();
MLCommonsClientAccessor accessor = mock(MLCommonsClientAccessor.class);
TextEmbeddingProcessorFactory textEmbeddingProcessorFactory = new TextEmbeddingProcessorFactory(accessor);
TextEmbeddingProcessorFactory textEmbeddingProcessorFactory = new TextEmbeddingProcessorFactory(accessor, env);

Map<String, Object> config = new HashMap<>();
config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId");
Expand Down Expand Up @@ -208,6 +228,34 @@ public void testExecute_mapHasEmptyStringValue_throwIllegalArgumentException() t
}
}

/**
 * Executes the processor on a document whose "key2" field is a map nested deeper than the
 * configured depth limit and expects an {@link IllegalArgumentException} carrying the
 * max-depth message.
 */
public void testExecute_mapDepthReachLimit_throwIllegalArgumentException() throws Exception {
    Map<String, Object> tooDeepMap = createMaxDepthLimitExceedMap(() -> 1);
    Map<String, Object> documentSource = new HashMap<>();
    documentSource.put("key1", "hello world");
    documentSource.put("key2", tooDeepMap);
    IngestDocument document = new IngestDocument(documentSource, new HashMap<>());
    TextEmbeddingProcessor embeddingProcessor = createInstance(createMockVectorWithLength(2));

    // Capture the exception instead of asserting inside the catch block.
    IllegalArgumentException thrown = null;
    try {
        embeddingProcessor.execute(document);
    } catch (IllegalArgumentException e) {
        thrown = e;
    }

    if (thrown == null) {
        fail("Shouldn't be here, expected exception!");
    }
    assertEquals("map type field [key2] reached max depth limit, can not process it", thrown.getMessage());
}

/**
 * Builds a chain of maps nested under the key "hello", one map per depth level from the
 * supplied starting depth down to level 21, where the innermost map is empty.
 *
 * @param maxDepthSupplier supplies the depth level assigned to the outermost map
 * @return the outermost map of the chain, or {@code null} when the supplied depth is greater than 21
 */
private Map<String, Object> createMaxDepthLimitExceedMap(Supplier<Integer> maxDepthSupplier) {
    final int deepestLevel = 21;
    int startLevel = maxDepthSupplier.get();
    if (startLevel > deepestLevel) {
        return null;
    }
    // Start from the empty innermost map and wrap it once for every shallower level.
    Map<String, Object> nested = new HashMap<>();
    for (int level = deepestLevel - 1; level >= startLevel; level--) {
        Map<String, Object> wrapper = new HashMap<>();
        wrapper.put("hello", nested);
        nested = wrapper;
    }
    return nested;
}

public void testExecute_hybridTypeInput_successful() throws Exception {
List<String> list1 = ImmutableList.of("test1", "test2");
Map<String, List<String>> map1 = ImmutableMap.of("test3", list1);
Expand Down Expand Up @@ -327,7 +375,7 @@ private TextEmbeddingProcessor createInstanceWithNestedMapConfiguration(Map<Stri
Map<String, Object> config = new HashMap<>();
config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId");
config.put(TextEmbeddingProcessor.FIELD_MAP_FIELD, fieldMap);
return FACTORY.create(registry, PROCESSOR_TAG, DESCRIPTION, config);
return textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config);
}

private Map<String, Object> createPlainStringConfiguration() {
Expand Down

0 comments on commit d168e95

Please sign in to comment.