[ML] Memory based trained model task allocation (elastic#75378)
This commit adds memory tracking to trained model tasks.
When a trained model deployment is started, we now
choose a node that has enough free memory to accommodate
the native process that loads the model.

The memory usage is estimated as twice the model size
plus some overhead. We require twice the model size because,
while the model is being loaded, it is held in memory once
in its stored form and a second time as the inflated object
that represents it. After loading completes, the process
returns the memory used for the stored form to the OS.
However, if we lowered the memory estimate after the loading
phase, the autoscaling service would flap. For this reason,
and as an initial implementation, we keep the requirement at
twice the model size. In the future, we can avoid this waste
by writing the model to disk and inflating it from there instead.
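For concreteness, here is a minimal sketch of the estimate described above, mirroring the estimateMemoryUsageBytes method added to StartTrainedModelDeploymentAction.TaskParams in this commit. The standalone class and main method are illustrative only; the constants come from the code comments in the diff below.

public final class DeploymentMemoryEstimate {

    // Roughly 300 MB of native process overhead was observed on Linux; 30 MB of that is
    // already covered by the generic MachineLearning.NATIVE_EXECUTABLE_CODE_OVERHEAD,
    // leaving 270 MB to account for here.
    private static final long MEMORY_OVERHEAD_BYTES = 270L * 1024 * 1024;

    static long estimateMemoryUsageBytes(long modelBytes) {
        // Twice the model size: once for the stored model and once for the inflated
        // representation that is built while loading it.
        return MEMORY_OVERHEAD_BYTES + 2 * modelBytes;
    }

    public static void main(String[] args) {
        long modelBytes = 500L * 1024 * 1024; // a hypothetical 500 MB model
        // ~270 MB + 1000 MB, i.e. about 1.27 GB of free memory required on the chosen node
        System.out.println(estimateMemoryUsageBytes(modelBytes));
    }
}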
dimitris-athanasiou authored and ywangd committed Jul 30, 2021
1 parent 5bb47e1 commit 63894ca
Showing 12 changed files with 295 additions and 77 deletions.
@@ -15,6 +15,8 @@
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentState;
import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentTaskState;
import org.elasticsearch.xpack.core.ml.job.config.JobState;
import org.elasticsearch.xpack.core.ml.job.config.JobTaskState;
import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeState;
@@ -223,6 +225,31 @@ public static DataFrameAnalyticsState getDataFrameAnalyticsState(@Nullable Persi
return state;
}

public static TrainedModelDeploymentState getTrainedModelDeploymentState(PersistentTasksCustomMetadata.PersistentTask<?> task) {
if (task == null) {
return TrainedModelDeploymentState.STOPPED;
}
TrainedModelDeploymentTaskState taskState = (TrainedModelDeploymentTaskState) task.getState();
if (taskState == null) {
return TrainedModelDeploymentState.STARTING;
}

TrainedModelDeploymentState state = taskState.getState();
if (taskState.isStatusStale(task)) {
if (state == TrainedModelDeploymentState.STOPPING) {
// the previous executor node failed while the deployment was stopping - it won't
// be restarted on another node, so consider it STOPPED for reassignment purposes
return TrainedModelDeploymentState.STOPPED;
}
if (state != TrainedModelDeploymentState.FAILED) {
// we are relocating at the moment
// TODO Revisit this in the new allocation framework as there won't necessarily be a concept of relocation.
return TrainedModelDeploymentState.STARTING;
}
}
return state;
}

/**
* The job Ids of anomaly detector job tasks.
* All anomaly detector jobs are returned regardless of the status of the
@@ -345,6 +372,8 @@ public static MemoryTrackedTaskState getMemoryTrackedTaskState(PersistentTasksCu
return taskState == null ? SnapshotUpgradeState.LOADING_OLD_STATE : taskState.getState();
case DATA_FRAME_ANALYTICS_TASK_NAME:
return getDataFrameAnalyticsState(task);
case TRAINED_MODEL_DEPLOYMENT_TASK_NAME:
return getTrainedModelDeploymentState(task);
default:
throw new IllegalStateException("unexpected task type [" + task.getTaskName() + "]");
}
@@ -11,6 +11,7 @@
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.master.MasterNodeRequest;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.xcontent.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
@@ -24,6 +25,7 @@
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocation;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.MlTaskParams;

import java.io.IOException;
import java.util.Objects;
@@ -116,21 +118,36 @@ public String toString() {
}
}

public static class TaskParams implements PersistentTaskParams {
public static class TaskParams implements PersistentTaskParams, MlTaskParams {

public static final Version VERSION_INTRODUCED = Version.V_8_0_0;

private static final ParseField MODEL_BYTES = new ParseField("model_bytes");

/**
* This has been found to be approximately 300MB on linux by manual testing.
* We also subtract 30MB that we always add as overhead (see MachineLearning.NATIVE_EXECUTABLE_CODE_OVERHEAD).
* TODO Check if it is substantially different in other platforms.
*/
private static final ByteSizeValue MEMORY_OVERHEAD = ByteSizeValue.ofMb(270);

private final String modelId;
private final String index;
private final long modelBytes;

public TaskParams(String modelId, String index) {
public TaskParams(String modelId, String index, long modelBytes) {
this.modelId = Objects.requireNonNull(modelId);
this.index = Objects.requireNonNull(index);
this.modelBytes = modelBytes;
if (modelBytes < 0) {
throw new IllegalArgumentException("modelBytes must be non-negative");
}
}

public TaskParams(StreamInput in) throws IOException {
this.modelId = in.readString();
this.index = in.readString();
this.modelBytes = in.readVLong();
}

public String getModelId() {
@@ -141,6 +158,11 @@ public String getIndex() {
return index;
}

public long estimateMemoryUsageBytes() {
// While loading the model in the process we need twice the model size.
return MEMORY_OVERHEAD.getBytes() + 2 * modelBytes;
}

@Override
public String getWriteableName() {
return MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_NAME;
@@ -155,20 +177,22 @@ public Version getMinimalSupportedVersion() {
public void writeTo(StreamOutput out) throws IOException {
out.writeString(modelId);
out.writeString(index);
out.writeVLong(modelBytes);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId);
builder.field(IndexLocation.INDEX.getPreferredName(), index);
builder.field(MODEL_BYTES.getPreferredName(), modelBytes);
builder.endObject();
return builder;
}

@Override
public int hashCode() {
return Objects.hash(modelId);
return Objects.hash(modelId, index, modelBytes);
}

@Override
@@ -177,7 +201,14 @@ public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) return false;

TaskParams other = (TaskParams) o;
return Objects.equals(modelId, other.modelId);
return Objects.equals(modelId, other.modelId)
&& Objects.equals(index, other.index)
&& modelBytes == other.modelBytes;
}

@Override
public String getMlId() {
return modelId;
}
}

@@ -10,11 +10,13 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.xpack.core.ml.utils.MemoryTrackedTaskState;

import java.io.IOException;
import java.util.Arrays;
import java.util.Locale;

public enum TrainedModelDeploymentState implements Writeable {
public enum TrainedModelDeploymentState implements Writeable, MemoryTrackedTaskState {

STARTING, STARTED, STOPPING, STOPPED, FAILED;

@@ -35,4 +37,16 @@ public void writeTo(StreamOutput out) throws IOException {
public String toString() {
return name().toLowerCase(Locale.ROOT);
}

/**
* @return {@code true} if state matches none of the given {@code candidates}
*/
public boolean isNoneOf(TrainedModelDeploymentState... candidates) {
return Arrays.stream(candidates).noneMatch(candidate -> this == candidate);
}

@Override
public boolean consumesMemory() {
return isNoneOf(FAILED, STOPPED);
}
}
@@ -15,6 +15,7 @@
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.persistent.PersistentTaskState;
import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
import org.elasticsearch.xpack.core.ml.MlTasks;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;

@@ -109,4 +110,8 @@ public boolean equals(Object o) {
public int hashCode() {
return Objects.hash(state, allocationId, reason);
}

public boolean isStatusStale(PersistentTasksCustomMetadata.PersistentTask<?> task) {
return allocationId != task.getAllocationId();
}
}
@@ -16,13 +16,17 @@
import org.elasticsearch.xpack.core.ml.action.OpenJobAction;
import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentState;
import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentTaskState;
import org.elasticsearch.xpack.core.ml.job.config.JobState;
import org.elasticsearch.xpack.core.ml.job.config.JobTaskState;

import java.net.InetAddress;
import java.util.Arrays;

import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.empty;
@@ -243,6 +247,41 @@ public void testGetDataFrameAnalyticsState_GivenStaleTaskWithFailedState() {
assertThat(state, equalTo(DataFrameAnalyticsState.FAILED));
}

public void testGetTrainedModelDeploymentState_GivenNull() {
assertThat(MlTasks.getTrainedModelDeploymentState(null), equalTo(TrainedModelDeploymentState.STOPPED));
}

public void testGetTrainedModelDeploymentState_GivenTaskStateIsNull() {
PersistentTasksCustomMetadata.PersistentTask<?> task = createTrainedModelTask(null, false);
assertThat(MlTasks.getTrainedModelDeploymentState(task), equalTo(TrainedModelDeploymentState.STARTING));
}

public void testGetTrainedModelDeploymentState_GivenTaskStateIsNotNullAndNotStale() {
TrainedModelDeploymentState state = randomFrom(TrainedModelDeploymentState.values());
PersistentTasksCustomMetadata.PersistentTask<?> task = createTrainedModelTask(state, false);
assertThat(MlTasks.getTrainedModelDeploymentState(task), equalTo(state));
}

public void testGetTrainedModelDeploymentState_GivenTaskStateIsStaleAndStopping() {
PersistentTasksCustomMetadata.PersistentTask<?> task = createTrainedModelTask(TrainedModelDeploymentState.STOPPING, true);
assertThat(MlTasks.getTrainedModelDeploymentState(task), equalTo(TrainedModelDeploymentState.STOPPED));
}

public void testGetTrainedModelDeploymentState_GivenTaskStateIsStaleAndFailed() {
PersistentTasksCustomMetadata.PersistentTask<?> task = createTrainedModelTask(TrainedModelDeploymentState.FAILED, true);
assertThat(MlTasks.getTrainedModelDeploymentState(task), equalTo(TrainedModelDeploymentState.FAILED));
}

public void testGetTrainedModelDeploymentState_GivenTaskStateIsStaleAndNotFailedNorStopping() {
TrainedModelDeploymentState state = randomFrom(
Arrays.stream(TrainedModelDeploymentState.values())
.filter(s -> s != TrainedModelDeploymentState.FAILED && s != TrainedModelDeploymentState.STOPPING)
.toArray(TrainedModelDeploymentState[]::new)
);
PersistentTasksCustomMetadata.PersistentTask<?> task = createTrainedModelTask(state, true);
assertThat(MlTasks.getTrainedModelDeploymentState(task), equalTo(TrainedModelDeploymentState.STARTING));
}

private static PersistentTasksCustomMetadata.PersistentTask<?> createDataFrameAnalyticsTask(String jobId, String nodeId,
DataFrameAnalyticsState state,
boolean isStale) {
@@ -257,4 +296,19 @@ private static PersistentTasksCustomMetadata.PersistentTask<?> createDataFrameAn
PersistentTasksCustomMetadata tasks = builder.build();
return tasks.getTask(MlTasks.dataFrameAnalyticsTaskId(jobId));
}

private static PersistentTasksCustomMetadata.PersistentTask<?> createTrainedModelTask(TrainedModelDeploymentState state,
boolean isStale) {
String id = randomAlphaOfLength(10);
PersistentTasksCustomMetadata.Builder builder = PersistentTasksCustomMetadata.builder();
builder.addTask(MlTasks.trainedModelDeploymentTaskId(id), MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_NAME,
new StartTrainedModelDeploymentAction.TaskParams(id, randomAlphaOfLength(10), randomNonNegativeLong()),
new PersistentTasksCustomMetadata.Assignment(randomAlphaOfLength(10), "test assignment"));
if (state != null) {
builder.updateTaskState(MlTasks.trainedModelDeploymentTaskId(id),
new TrainedModelDeploymentTaskState(state, builder.getLastAllocationId() - (isStale ? 1 : 0), null));
}
PersistentTasksCustomMetadata tasks = builder.build();
return tasks.getTask(MlTasks.trainedModelDeploymentTaskId(id));
}
}
