From ce17f7276ebf1e3928121bb1c67894b5669f8696 Mon Sep 17 00:00:00 2001 From: Jack Mazanec Date: Mon, 1 Nov 2021 11:10:50 -0700 Subject: [PATCH] Return 400 on failed training request (#168) Signed-off-by: John Mazanec --- .../opensearch/knn/plugin/rest/RestTrainModelHandler.java | 6 +++--- .../plugin/transport/TrainingJobRouterTransportAction.java | 5 ++++- .../java/org/opensearch/knn/training/TrainingJobRunner.java | 5 ++++- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java b/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java index bc92836213..a5b2f13e46 100644 --- a/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java +++ b/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java @@ -100,9 +100,9 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr parser.nextToken(); if (TRAIN_INDEX_PARAMETER.equals(fieldName) && ensureNotSet(fieldName, trainingIndex)) { - trainingIndex = parser.text(); + trainingIndex = parser.textOrNull(); } else if (TRAIN_FIELD_PARAMETER.equals(fieldName) && ensureNotSet(fieldName, trainingField)) { - trainingField = parser.text(); + trainingField = parser.textOrNull(); } else if (KNN_METHOD.equals(fieldName) && ensureNotSet(fieldName, knnMethodContext)) { knnMethodContext = KNNMethodContext.parse(parser.map()); } else if (DIMENSION.equals(fieldName) && ensureNotSet(fieldName, dimension)) { @@ -112,7 +112,7 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr } else if (SEARCH_SIZE_PARAMETER.equals(fieldName) && ensureNotSet(fieldName, searchSize)) { searchSize = (Integer) NumberFieldMapper.NumberType.INTEGER.parse(parser.objectBytes(), false); } else if (MODEL_DESCRIPTION.equals(fieldName) && ensureNotSet(fieldName, description)) { - description = parser.text(); + description = parser.textOrNull(); } else { throw new IllegalArgumentException("Unable to parse token. \"" + fieldName + "\" is not a valid " + "parameter."); diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java index c939e652a2..0a24c0c478 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java @@ -20,6 +20,7 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.Strings; +import org.opensearch.common.ValidationException; import org.opensearch.common.collect.ImmutableOpenMap; import org.opensearch.common.inject.Inject; import org.opensearch.search.builder.SearchSourceBuilder; @@ -70,7 +71,9 @@ protected void routeRequest(TrainingModelRequest request, ActionListener liste // the number of training jobs that enter this function. Although the training threadpool size will also prevent // this, we want to prevent this before we perform any serialization. if (!semaphore.tryAcquire()) { - throw new RejectedExecutionException("Unable to run training job: No training capacity on node."); + ValidationException exception = new ValidationException(); + exception.addValidationError("Unable to run training job: No training capacity on node."); + throw exception; } jobCount.incrementAndGet();