Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add size validation for Search Model API #352

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ public class KNNConstants {
public static final String MODEL_TIMESTAMP = "timestamp";
public static final String MODEL_DESCRIPTION = "description";
public static final String MODEL_ERROR = "error";
public static final String PARAM_SIZE = "size";
public static final Integer SEARCH_MODEL_MIN_SIZE = 1;
public static final Integer SEARCH_MODEL_MAX_SIZE = 1000;

public static final String KNN_THREAD_POOL_PREFIX = "knn";
public static final String TRAIN_THREAD_POOL = "training";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
import java.util.function.IntConsumer;

import static org.opensearch.knn.common.KNNConstants.MODELS;
import static org.opensearch.knn.common.KNNConstants.PARAM_SIZE;
import static org.opensearch.knn.common.KNNConstants.SEARCH_MODEL_MAX_SIZE;
import static org.opensearch.knn.common.KNNConstants.SEARCH_MODEL_MIN_SIZE;

/**
* Rest Handler for search model api endpoint.
Expand Down Expand Up @@ -69,9 +72,30 @@ private void checkUnSupportedParamsExists(RestRequest request) {
throw new IllegalArgumentException(errorMessage);
}

private void validateSizeParameter(RestRequest request) {
if (!request.hasParam(PARAM_SIZE)) {
return;
}
if (isSearchSizeValueValid(request.paramAsInt(PARAM_SIZE, 1))) {
return;
}
throw new IllegalArgumentException(
String.format("%s must be between %d and %d inclusive", PARAM_SIZE, SEARCH_MODEL_MIN_SIZE, SEARCH_MODEL_MAX_SIZE)
naveentatikonda marked this conversation as resolved.
Show resolved Hide resolved
);
}

private boolean isSearchSizeValueValid(int searchSize) {
return (searchSize >= SEARCH_MODEL_MIN_SIZE) && (searchSize <= SEARCH_MODEL_MAX_SIZE);
}

private void validateRequest(RestRequest request) {
checkUnSupportedParamsExists(request);
validateSizeParameter(request);
}

@Override
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
checkUnSupportedParamsExists(request);
validateRequest(request);
SearchRequest searchRequest = new SearchRequest();
IntConsumer setSize = size -> searchRequest.source().size(size);
request.withContentOrSourceParamParserOrNull(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@
import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.MODELS;
import static org.opensearch.knn.common.KNNConstants.PARAM_SIZE;
import static org.opensearch.knn.common.KNNConstants.SEARCH_MODEL_MAX_SIZE;
import static org.opensearch.knn.common.KNNConstants.SEARCH_MODEL_MIN_SIZE;

/**
* Integration tests to check the correctness of {@link org.opensearch.knn.plugin.rest.RestSearchModelHandler}
Expand Down Expand Up @@ -76,6 +79,23 @@ public void testNoModelExists() throws IOException {

}

public void testSizeValidationFailsInvalidSize() throws IOException {
naveentatikonda marked this conversation as resolved.
Show resolved Hide resolved
createModelSystemIndex();
for (Integer invalidSize : Arrays.asList(SEARCH_MODEL_MIN_SIZE - 1, SEARCH_MODEL_MAX_SIZE + 1)) {
String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "_search?" + PARAM_SIZE + "=" + invalidSize);
Request request = new Request("GET", restURI);

ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request));
assertTrue(
ex.getMessage()
.contains(
String.format("%s must be between %d and %d inclusive", PARAM_SIZE, SEARCH_MODEL_MIN_SIZE, SEARCH_MODEL_MAX_SIZE)
)
);
}

}

public void testSearchModelExists() throws IOException {
createModelSystemIndex();
createIndex("irrelevant-index", Settings.EMPTY);
Expand Down