Skip to content

Commit

Permalink
[ML] Require that threads_per_allocation is a power of 2 (#87697)
Browse files Browse the repository at this point in the history
As the number of cores in CPUs is typically a power of 2,
this commit adds a validation that trained model deployments
start with `threads_per_allocation` set to be a power of 2.
When allocations are distributed across the cluster, this
prevents situations in which a large number of CPU cores
are left unused.

In addition, we add a max value limit of `32`.
  • Loading branch information
dimitris-athanasiou authored Jun 17, 2022
1 parent 4fb8550 commit 679351e
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ the inference speed. The inference process is a compute-bound process; any numbe
greater than the number of available hardware threads on the machine does not increase the
inference speed. If this setting is greater than the number of hardware threads
it will automatically be changed to a value less than the number of hardware threads.
Defaults to 1.
Defaults to 1. Must be a power of 2. The maximum allowed value is 32.

`timeout`::
(Optional, time)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ public static class Request extends MasterNodeRequest<Request> implements ToXCon
AllocationStatus.State.STARTED,
AllocationStatus.State.STARTING,
AllocationStatus.State.FULLY_ALLOCATED };

/** Largest {@code threads_per_allocation} value that request validation accepts. */
private static final int MAX_THREADS_PER_ALLOCATION = 32;

public static final ParseField MODEL_ID = new ParseField("model_id");
public static final ParseField TIMEOUT = new ParseField("timeout");
public static final ParseField WAIT_FOR = new ParseField("wait_for");
Expand Down Expand Up @@ -209,12 +212,21 @@ public ActionRequestValidationException validate() {
if (threadsPerAllocation < 1) {
validationException.addValidationError("[" + THREADS_PER_ALLOCATION + "] must be a positive integer");
}
if (threadsPerAllocation > MAX_THREADS_PER_ALLOCATION || isPowerOf2(threadsPerAllocation) == false) {
validationException.addValidationError(
"[" + THREADS_PER_ALLOCATION + "] must be a power of 2 less than or equal to " + MAX_THREADS_PER_ALLOCATION
);
}
if (queueCapacity < 1) {
validationException.addValidationError("[" + QUEUE_CAPACITY + "] must be a positive integer");
}
return validationException.validationErrors().isEmpty() ? null : validationException;
}

/**
 * Tells whether {@code value} is a positive power of 2.
 *
 * @param value the number to test
 * @return {@code true} if {@code value} is greater than zero and has exactly one bit set
 */
private static boolean isPowerOf2(int value) {
    // Integer.MIN_VALUE has exactly one bit set (the sign bit), so the bit count
    // alone would wrongly accept it; require a positive value as well.
    return value > 0 && Integer.bitCount(value) == 1;
}

@Override
public int hashCode() {
return Objects.hash(modelId, timeout, waitForState, numberOfAllocations, threadsPerAllocation, queueCapacity);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
import org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus;

import java.io.IOException;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
Expand Down Expand Up @@ -80,6 +84,32 @@ public void testValidate_GivenThreadsPerAllocationIsNegative() {
assertThat(e.getMessage(), containsString("[threads_per_allocation] must be a positive integer"));
}

/**
 * Checks that every value from 1 to 32 that is not a power of 2 fails validation
 * with the power-of-2 error message.
 */
public void testValidate_GivenThreadsPerAllocationIsNotPowerOf2() {
    // 2^0 .. 2^9, which includes every power of 2 inside the tested range.
    Set<Integer> validValues = IntStream.range(0, 10)
        .map(exponent -> 1 << exponent)
        .boxed()
        .collect(Collectors.toSet());
    List<Integer> invalidValues = IntStream.rangeClosed(1, 32)
        .boxed()
        .filter(candidate -> validValues.contains(candidate) == false)
        .toList();

    for (int threads : invalidValues) {
        Request request = createRandom();
        request.setThreadsPerAllocation(threads);

        ActionRequestValidationException validationError = request.validate();

        assertThat(validationError, is(not(nullValue())));
        assertThat(
            validationError.getMessage(),
            containsString("[threads_per_allocation] must be a power of 2 less than or equal to 32")
        );
    }
}

/**
 * Checks that each power of 2 from 1 up to 32 passes validation.
 */
public void testValidate_GivenThreadsPerAllocationIsValid() {
    // Doubling from 1 visits exactly 1, 2, 4, 8, 16 and 32.
    for (int threads = 1; threads <= 32; threads *= 2) {
        Request request = createRandom();
        request.setThreadsPerAllocation(threads);

        ActionRequestValidationException validationError = request.validate();

        assertThat(validationError, is(nullValue()));
    }
}

public void testValidate_GivenNumberOfAllocationsIsZero() {
Request request = createRandom();
request.setNumberOfAllocations(0);
Expand Down

0 comments on commit 679351e

Please sign in to comment.