diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index fdbe0bb20..46d429fcc 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -35,6 +35,9 @@ public class KNNConstants { public static final String MODEL_TIMESTAMP = "timestamp"; public static final String MODEL_DESCRIPTION = "description"; public static final String MODEL_ERROR = "error"; + public static final String PARAM_SIZE = "size"; + public static final Integer SEARCH_MODEL_MIN_SIZE = 1; + public static final Integer SEARCH_MODEL_MAX_SIZE = 1000; public static final String KNN_THREAD_POOL_PREFIX = "knn"; public static final String TRAIN_THREAD_POOL = "training"; diff --git a/src/main/java/org/opensearch/knn/plugin/rest/RestSearchModelHandler.java b/src/main/java/org/opensearch/knn/plugin/rest/RestSearchModelHandler.java index eeb510b9d..675e3c1d1 100644 --- a/src/main/java/org/opensearch/knn/plugin/rest/RestSearchModelHandler.java +++ b/src/main/java/org/opensearch/knn/plugin/rest/RestSearchModelHandler.java @@ -30,6 +30,9 @@ import java.util.function.IntConsumer; import static org.opensearch.knn.common.KNNConstants.MODELS; +import static org.opensearch.knn.common.KNNConstants.PARAM_SIZE; +import static org.opensearch.knn.common.KNNConstants.SEARCH_MODEL_MAX_SIZE; +import static org.opensearch.knn.common.KNNConstants.SEARCH_MODEL_MIN_SIZE; /** * Rest Handler for search model api endpoint. @@ -69,9 +72,30 @@ private void checkUnSupportedParamsExists(RestRequest request) { throw new IllegalArgumentException(errorMessage); } + private void validateSizeParameter(RestRequest request) { + if (!request.hasParam(PARAM_SIZE)) { + return; + } + if (isSearchSizeValueValid(request.paramAsInt(PARAM_SIZE, 1))) { + return; + } + throw new IllegalArgumentException( + String.format("%s must be between %d and %d inclusive", PARAM_SIZE, SEARCH_MODEL_MIN_SIZE, SEARCH_MODEL_MAX_SIZE) + ); + } + + private boolean isSearchSizeValueValid(int searchSize) { + return (searchSize >= SEARCH_MODEL_MIN_SIZE) && (searchSize <= SEARCH_MODEL_MAX_SIZE); + } + + private void validateRequest(RestRequest request) { + checkUnSupportedParamsExists(request); + validateSizeParameter(request); + } + @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { - checkUnSupportedParamsExists(request); + validateRequest(request); SearchRequest searchRequest = new SearchRequest(); IntConsumer setSize = size -> searchRequest.source().size(size); request.withContentOrSourceParamParserOrNull( diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestSearchModelHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestSearchModelHandlerIT.java index 8137f50a1..d65f434f6 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestSearchModelHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestSearchModelHandlerIT.java @@ -36,6 +36,9 @@ import java.util.Map; import static org.opensearch.knn.common.KNNConstants.MODELS; +import static org.opensearch.knn.common.KNNConstants.PARAM_SIZE; +import static org.opensearch.knn.common.KNNConstants.SEARCH_MODEL_MAX_SIZE; +import static org.opensearch.knn.common.KNNConstants.SEARCH_MODEL_MIN_SIZE; /** * Integration tests to check the correctness of {@link org.opensearch.knn.plugin.rest.RestSearchModelHandler} @@ -76,6 +79,23 @@ public void testNoModelExists() throws IOException { } + public void testSizeValidationFailsInvalidSize() throws IOException { + createModelSystemIndex(); + for (Integer invalidSize : Arrays.asList(SEARCH_MODEL_MIN_SIZE - 1, SEARCH_MODEL_MAX_SIZE + 1)) { + String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "_search?" + PARAM_SIZE + "=" + invalidSize); + Request request = new Request("GET", restURI); + + ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request)); + assertTrue( + ex.getMessage() + .contains( + String.format("%s must be between %d and %d inclusive", PARAM_SIZE, SEARCH_MODEL_MIN_SIZE, SEARCH_MODEL_MAX_SIZE) + ) + ); + } + + } + public void testSearchModelExists() throws IOException { createModelSystemIndex(); createIndex("irrelevant-index", Settings.EMPTY);