diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index fdbe0bb206..9b4bc4f53a 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -35,6 +35,9 @@ public class KNNConstants { public static final String MODEL_TIMESTAMP = "timestamp"; public static final String MODEL_DESCRIPTION = "description"; public static final String MODEL_ERROR = "error"; + public static final String PARAM_SIZE = "size"; + public static final Integer SEARCH_MODEL_MIN_SIZE = 0; + public static final Integer SEARCH_MODEL_MAX_SIZE = 1000; public static final String KNN_THREAD_POOL_PREFIX = "knn"; public static final String TRAIN_THREAD_POOL = "training"; diff --git a/src/main/java/org/opensearch/knn/plugin/rest/RestSearchModelHandler.java b/src/main/java/org/opensearch/knn/plugin/rest/RestSearchModelHandler.java index eeb510b9d3..b85cd9bc95 100644 --- a/src/main/java/org/opensearch/knn/plugin/rest/RestSearchModelHandler.java +++ b/src/main/java/org/opensearch/knn/plugin/rest/RestSearchModelHandler.java @@ -30,6 +30,9 @@ import java.util.function.IntConsumer; import static org.opensearch.knn.common.KNNConstants.MODELS; +import static org.opensearch.knn.common.KNNConstants.PARAM_SIZE; +import static org.opensearch.knn.common.KNNConstants.SEARCH_MODEL_MAX_SIZE; +import static org.opensearch.knn.common.KNNConstants.SEARCH_MODEL_MIN_SIZE; /** * Rest Handler for search model api endpoint. @@ -69,9 +72,30 @@ private void checkUnSupportedParamsExists(RestRequest request) { throw new IllegalArgumentException(errorMessage); } + private void validateSizeParameter(RestRequest request) { + if (!request.hasParam(PARAM_SIZE)) { + return; + } + if (isSearchSizeValueValid(request.paramAsInt(PARAM_SIZE, 1))) { + return; + } + throw new IllegalArgumentException( + String.format("%s must be between %d and %d inclusive", PARAM_SIZE, SEARCH_MODEL_MIN_SIZE, SEARCH_MODEL_MAX_SIZE) + ); + } + + private boolean isSearchSizeValueValid(int searchSize) { + return searchSize > SEARCH_MODEL_MIN_SIZE && searchSize <= SEARCH_MODEL_MAX_SIZE; + } + + private void validateRequest(RestRequest request) { + checkUnSupportedParamsExists(request); + validateSizeParameter(request); + } + @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { - checkUnSupportedParamsExists(request); + validateRequest(request); SearchRequest searchRequest = new SearchRequest(); IntConsumer setSize = size -> searchRequest.source().size(size); request.withContentOrSourceParamParserOrNull( diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestSearchModelHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestSearchModelHandlerIT.java index 8137f50a10..dab14fbaac 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestSearchModelHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestSearchModelHandlerIT.java @@ -36,6 +36,9 @@ import java.util.Map; import static org.opensearch.knn.common.KNNConstants.MODELS; +import static org.opensearch.knn.common.KNNConstants.PARAM_SIZE; +import static org.opensearch.knn.common.KNNConstants.SEARCH_MODEL_MAX_SIZE; +import static org.opensearch.knn.common.KNNConstants.SEARCH_MODEL_MIN_SIZE; /** * Integration tests to check the correctness of {@link org.opensearch.knn.plugin.rest.RestSearchModelHandler} @@ -76,6 +79,18 @@ public void testNoModelExists() throws IOException { } + public void testSizeValidationFailsInvalidSize() throws IOException { + createModelSystemIndex(); + String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "_search?size=2000"); + Request request = new Request("GET", restURI); + + ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request)); + assertTrue( + ex.getMessage() + .contains(String.format("%s must be between %d and %d inclusive", PARAM_SIZE, SEARCH_MODEL_MIN_SIZE, SEARCH_MODEL_MAX_SIZE)) + ); + } + public void testSearchModelExists() throws IOException { createModelSystemIndex(); createIndex("irrelevant-index", Settings.EMPTY);