diff --git a/docs/reference/ml/trained-models/apis/start-trained-model-deployment.asciidoc b/docs/reference/ml/trained-models/apis/start-trained-model-deployment.asciidoc index 0d5fb6c29bc72..af9dc8963fb56 100644 --- a/docs/reference/ml/trained-models/apis/start-trained-model-deployment.asciidoc +++ b/docs/reference/ml/trained-models/apis/start-trained-model-deployment.asciidoc @@ -70,7 +70,7 @@ the inference speed. The inference process is a compute-bound process; any numbe greater than the number of available hardware threads on the machine does not increase the inference speed. If this setting is greater than the number of hardware threads it will automatically be changed to a value less than the number of hardware threads. -Defaults to 1. +Defaults to 1. Must be a power of 2. Max allowed value is 32. `timeout`:: (Optional, time) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java index 3c92bcf30c338..da94bdc19c21b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java @@ -66,6 +66,9 @@ public static class Request extends MasterNodeRequest implements ToXCon AllocationStatus.State.STARTED, AllocationStatus.State.STARTING, AllocationStatus.State.FULLY_ALLOCATED }; + + private static final int MAX_THREADS_PER_ALLOCATION = 32; + public static final ParseField MODEL_ID = new ParseField("model_id"); public static final ParseField TIMEOUT = new ParseField("timeout"); public static final ParseField WAIT_FOR = new ParseField("wait_for"); @@ -209,12 +212,21 @@ public ActionRequestValidationException validate() { if (threadsPerAllocation < 1) { validationException.addValidationError("[" + THREADS_PER_ALLOCATION + "] must be a positive integer"); } + if (threadsPerAllocation > MAX_THREADS_PER_ALLOCATION || isPowerOf2(threadsPerAllocation) == false) { + validationException.addValidationError( + "[" + THREADS_PER_ALLOCATION + "] must be a power of 2 less than or equal to " + MAX_THREADS_PER_ALLOCATION + ); + } if (queueCapacity < 1) { validationException.addValidationError("[" + QUEUE_CAPACITY + "] must be a positive integer"); } return validationException.validationErrors().isEmpty() ? null : validationException; } + private static boolean isPowerOf2(int value) { + return Integer.bitCount(value) == 1; + } + @Override public int hashCode() { return Objects.hash(modelId, timeout, waitForState, numberOfAllocations, threadsPerAllocation, queueCapacity); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentRequestTests.java index c18b995fc25e4..6cdd355997d4e 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentRequestTests.java @@ -16,6 +16,10 @@ import org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus; import java.io.IOException; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -80,6 +84,32 @@ public void testValidate_GivenThreadsPerAllocationIsNegative() { assertThat(e.getMessage(), containsString("[threads_per_allocation] must be a positive integer")); } + public void testValidate_GivenThreadsPerAllocationIsNotPowerOf2() { + Set powersOf2 = IntStream.range(0, 10).map(n -> (int) Math.pow(2, n)).boxed().collect(Collectors.toSet()); + List input = IntStream.range(1, 33).filter(n -> powersOf2.contains(n) == false).boxed().toList(); + + for (int n : input) { + Request request = createRandom(); + request.setThreadsPerAllocation(n); + + ActionRequestValidationException e = request.validate(); + + assertThat(e, is(not(nullValue()))); + assertThat(e.getMessage(), containsString("[threads_per_allocation] must be a power of 2 less than or equal to 32")); + } + } + + public void testValidate_GivenThreadsPerAllocationIsValid() { + for (int n : List.of(1, 2, 4, 8, 16, 32)) { + Request request = createRandom(); + request.setThreadsPerAllocation(n); + + ActionRequestValidationException e = request.validate(); + + assertThat(e, is(nullValue())); + } + } + public void testValidate_GivenNumberOfAllocationsIsZero() { Request request = createRandom(); request.setNumberOfAllocations(0);