Skip to content

Commit

Permalink
Return 400 on failed training request (opensearch-project#168)
Browse files Browse the repository at this point in the history
Signed-off-by: John Mazanec <[email protected]>
  • Loading branch information
jmazanec15 authored and martin-gaievski committed Mar 7, 2022
1 parent 4bdeaa5 commit ce17f72
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr
parser.nextToken();

if (TRAIN_INDEX_PARAMETER.equals(fieldName) && ensureNotSet(fieldName, trainingIndex)) {
trainingIndex = parser.text();
trainingIndex = parser.textOrNull();
} else if (TRAIN_FIELD_PARAMETER.equals(fieldName) && ensureNotSet(fieldName, trainingField)) {
trainingField = parser.text();
trainingField = parser.textOrNull();
} else if (KNN_METHOD.equals(fieldName) && ensureNotSet(fieldName, knnMethodContext)) {
knnMethodContext = KNNMethodContext.parse(parser.map());
} else if (DIMENSION.equals(fieldName) && ensureNotSet(fieldName, dimension)) {
Expand All @@ -112,7 +112,7 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr
} else if (SEARCH_SIZE_PARAMETER.equals(fieldName) && ensureNotSet(fieldName, searchSize)) {
searchSize = (Integer) NumberFieldMapper.NumberType.INTEGER.parse(parser.objectBytes(), false);
} else if (MODEL_DESCRIPTION.equals(fieldName) && ensureNotSet(fieldName, description)) {
description = parser.text();
description = parser.textOrNull();
} else {
throw new IllegalArgumentException("Unable to parse token. \"" + fieldName + "\" is not a valid " +
"parameter.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.Strings;
import org.opensearch.common.ValidationException;
import org.opensearch.common.collect.ImmutableOpenMap;
import org.opensearch.common.inject.Inject;
import org.opensearch.search.builder.SearchSourceBuilder;
Expand Down Expand Up @@ -70,7 +71,9 @@ protected void routeRequest(TrainingModelRequest request, ActionListener<Trainin
DiscoveryNode node = selectNode(request.getPreferredNodeId(), response);

if (node == null) {
listener.onFailure(new RejectedExecutionException("Cluster does not have capacity to train"));
ValidationException exception = new ValidationException();
exception.addValidationError("Cluster does not have capacity to train");
listener.onFailure(exception);
return;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionListener;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.common.ValidationException;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelState;
Expand Down Expand Up @@ -82,7 +83,9 @@ public void execute(TrainingJob trainingJob, ActionListener<IndexResponse> liste
// the number of training jobs that enter this function. Although the training threadpool size will also prevent
// this, we want to prevent this before we perform any serialization.
if (!semaphore.tryAcquire()) {
throw new RejectedExecutionException("Unable to run training job: No training capacity on node.");
ValidationException exception = new ValidationException();
exception.addValidationError("Unable to run training job: No training capacity on node.");
throw exception;
}

jobCount.incrementAndGet();
Expand Down

0 comments on commit ce17f72

Please sign in to comment.