[ML] adding new defer_definition_decompression parameter to put trained model API #77189

Merged
13 changes: 12 additions & 1 deletion docs/reference/ml/df-analytics/apis/put-trained-models.asciidoc
Expand Up @@ -24,7 +24,7 @@ WARNING: Models created in version 7.8.0 are not backwards compatible
[[ml-put-trained-models-prereq]]
== {api-prereq-title}

Requires the `manage_ml` cluster privilege. This privilege is included in the
`machine_learning_admin` built-in role.


Expand All @@ -42,6 +42,17 @@ created by {dfanalytics}.
(Required, string)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]

[[ml-put-trained-models-query-params]]
== {api-query-parms-title}

`defer_definition_decompression`::
(Optional, Boolean)
Should the request defer definition decompression and skip relevant
Contributor

Suggested change
Should the request defer definition decompression and skip relevant
If set to `true` and a `compressed_definition` is provided, the request defers definition decompression and skips relevant

validations when a `compressed_definition` is provided.
Contributor

If my first suggestion is accepted, this is the second half:

Suggested change
validations when a `compressed_definition` is provided.
validations.

This would be useful for systems or users that know a good JVM heap size estimate for their
Contributor

Suggested change
This would be useful for systems or users that know a good JVM heap size estimate for their
This deferral is useful for systems or users that know a good JVM heap size estimate for their

model and that their model is valid and likely won't fail during inference.
Contributor

Suggested change
model and that their model is valid and likely won't fail during inference.
model and know that their model is valid and likely won't fail during inference.
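
The rule behind this parameter can be sketched outside Elasticsearch. The following is a minimal, standalone illustration and not the code in this PR: `DeferValidationSketch`, `validate`, and its parameters are hypothetical names, and treating a heap estimate of 0 as "not set" is an assumption that mirrors the `getEstimatedHeapMemory() == 0` check in the request's `validate()` method.

```java
import java.util.Optional;

class DeferValidationSketch {

    // Hypothetical stand-in for the request-level validation: returns an error
    // message when the request is invalid, or an empty Optional when it passes.
    static Optional<String> validate(boolean deferDefinitionDecompression,
                                     long estimatedHeapMemoryBytes,
                                     byte[] compressedDefinition) {
        // Deferring skips decompression, so the heap estimate cannot be computed
        // from the definition and has to be supplied by the caller.
        if (deferDefinitionDecompression
            && estimatedHeapMemoryBytes == 0
            && compressedDefinition != null) {
            return Optional.of(
                "when [defer_definition_decompression] is true and a compressed definition is provided, "
                    + "[estimated_heap_memory_usage_bytes] must be set");
        }
        return Optional.empty();
    }

    public static void main(String[] args) {
        byte[] definition = new byte[] {1, 2, 3};
        System.out.println(validate(true, 0L, definition).isPresent());    // true: rejected
        System.out.println(validate(true, 1024L, definition).isPresent()); // false: accepted
        System.out.println(validate(false, 0L, definition).isPresent());   // false: accepted
    }
}
```

Only the combination of deferral, a compressed definition, and a missing heap estimate is rejected; every other combination passes this particular check.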



[role="child_attributes"]
[[ml-put-trained-models-request-body]]
== {api-request-body-title}
Expand Up @@ -26,6 +26,13 @@
}
]
},
"params":{
"defer_definition_decompression": {
"required": false,
"type": "boolean",
"description": "Should the action skip decompressing the definition to validate it and set default values, default value is false"
}
},
"body":{
"description":"The trained model configuration",
"required":true
Expand Up @@ -6,6 +6,7 @@
*/
package org.elasticsearch.xpack.core.ml.action;

import org.elasticsearch.Version;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
Expand All @@ -25,6 +26,7 @@

public class PutTrainedModelAction extends ActionType<PutTrainedModelAction.Response> {

public static final String DEFER_DEFINITION_DECOMPRESSION = "defer_definition_decompression";
public static final PutTrainedModelAction INSTANCE = new PutTrainedModelAction();
public static final String NAME = "cluster:admin/xpack/ml/inference/put";
private PutTrainedModelAction() {
Expand All @@ -33,7 +35,7 @@ private PutTrainedModelAction() {

public static class Request extends AcknowledgedRequest<Request> {

public static Request parseRequest(String modelId, XContentParser parser) {
public static Request parseRequest(String modelId, boolean deferDefinitionValidation, XContentParser parser) {
TrainedModelConfig.Builder builder = TrainedModelConfig.STRICT_PARSER.apply(parser, null);

if (builder.getModelId() == null) {
Expand All @@ -47,18 +49,25 @@ public static Request parseRequest(String modelId, XContentParser parser) {
}
// Validations are done against the builder so we can build the full config object.
// This allows us to not worry about serializing a builder class between nodes.
return new Request(builder.validate(true).build());
return new Request(builder.validate(true).build(), deferDefinitionValidation);
}

private final TrainedModelConfig config;
private final boolean deferDefinitionDecompression;

public Request(TrainedModelConfig config) {
public Request(TrainedModelConfig config, boolean deferDefinitionDecompression) {
this.config = config;
this.deferDefinitionDecompression = deferDefinitionDecompression;
}

public Request(StreamInput in) throws IOException {
super(in);
this.config = new TrainedModelConfig(in);
if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
this.deferDefinitionDecompression = in.readBoolean();
} else {
this.deferDefinitionDecompression = false;
}
}

public TrainedModelConfig getTrainedModelConfig() {
Expand All @@ -67,26 +76,44 @@ public TrainedModelConfig getTrainedModelConfig() {

@Override
public ActionRequestValidationException validate() {
if (deferDefinitionDecompression
&& config.getEstimatedHeapMemory() == 0
&& config.getCompressedDefinitionIfSet() != null) {
ActionRequestValidationException validationException = new ActionRequestValidationException();
validationException.addValidationError(
"when ["
+ DEFER_DEFINITION_DECOMPRESSION
+ "] is true and a compressed definition is provided, estimated_heap_memory_usage_bytes must be set"
Contributor

Suggested change
+ "] is true and a compressed definition is provided, estimated_heap_memory_usage_bytes must be set"
+ "] is true and a compressed definition is provided, [" + ESTIMATED_HEAP_MEMORY_USAGE_BYTES + "] must be set"

);
return validationException;
}
return null;
}

public boolean isDeferDefinitionDecompression() {
return deferDefinitionDecompression;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
config.writeTo(out);
if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
out.writeBoolean(deferDefinitionDecompression);
}
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Request request = (Request) o;
return Objects.equals(config, request.config);
return Objects.equals(config, request.config) && deferDefinitionDecompression == request.deferDefinitionDecompression;
}

@Override
public int hashCode() {
return Objects.hash(config);
return Objects.hash(config, deferDefinitionDecompression);
}

@Override
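
The version-gated read and write in `Request` follow a common pattern: a new field is written only when the receiving side is known to understand it, and the reader falls back to a default otherwise. A minimal sketch of that pattern using plain `java.io` streams is shown here. It is not Elasticsearch's `StreamInput`/`StreamOutput` API; the class name, the integer version ids, and the string used in place of the model config are all assumptions made for illustration.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

class VersionGatedWireSketch {

    // Hypothetical version id at which the new flag joins the wire format.
    static final int FLAG_INTRODUCED_IN = 2;

    // Writes the config, then the flag only if the peer is new enough to read it.
    static byte[] write(int peerVersion, String config, boolean deferFlag) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            out.writeUTF(config);
            if (peerVersion >= FLAG_INTRODUCED_IN) {
                out.writeBoolean(deferFlag);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    // Reads the flag back; payloads from older peers carry no flag, so default to false.
    static boolean readFlag(int peerVersion, byte[] payload) {
        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(payload))) {
            in.readUTF(); // skip the config
            return peerVersion >= FLAG_INTRODUCED_IN ? in.readBoolean() : false;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

The reader and writer must gate on the same version; if only one side checked, the stream would be read one byte out of step.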
Expand Up @@ -283,6 +283,14 @@ public BytesReference getCompressedDefinition() throws IOException {
return definition.getCompressedDefinition();
}

public BytesReference getCompressedDefinitionIfSet() {
if (definition == null) {
return null;
}
return definition.getCompressedDefinitionIfSet();
}


public void clearCompressed() {
definition.compressedRepresentation = null;
}
Expand Down Expand Up @@ -704,6 +712,7 @@ public Builder validate() {

/**
* Runs validations against the builder.
* @param forCreation indicates if we should validate for model creation or for a model read from storage
* @return The current builder object if validations are successful
* @throws ActionRequestValidationException when there are validation failures.
*/
Expand Down Expand Up @@ -773,12 +782,6 @@ public Builder validate(boolean forCreation) {
validationException = checkIllegalSetting(version, VERSION.getPreferredName(), validationException);
validationException = checkIllegalSetting(createdBy, CREATED_BY.getPreferredName(), validationException);
validationException = checkIllegalSetting(createTime, CREATE_TIME.getPreferredName(), validationException);
validationException = checkIllegalSetting(estimatedHeapMemory,
ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(),
validationException);
validationException = checkIllegalSetting(estimatedOperations,
ESTIMATED_OPERATIONS.getPreferredName(),
validationException);
validationException = checkIllegalSetting(licenseLevel, LICENSE_LEVEL.getPreferredName(), validationException);
if (metadata != null) {
validationException = checkIllegalSetting(
Expand Down Expand Up @@ -877,6 +880,10 @@ private BytesReference getCompressedDefinition() throws IOException {
return compressedRepresentation;
}

private BytesReference getCompressedDefinitionIfSet() {
return compressedRepresentation;
}

private String getBase64CompressedDefinition() throws IOException {
BytesReference compressedDef = getCompressedDefinition();

Expand Up @@ -20,9 +20,12 @@ public class PutTrainedModelActionRequestTests extends AbstractWireSerializingTe
@Override
protected Request createTestInstance() {
String modelId = randomAlphaOfLength(10);
return new Request(TrainedModelConfigTests.createTestInstance(modelId)
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
.build());
return new Request(
TrainedModelConfigTests.createTestInstance(modelId)
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
.build(),
randomBoolean()
);
}

@Override
Expand Up @@ -246,7 +246,8 @@ void createModelDeployment() {
)
.setLocation(new IndexLocation(indexname))
.setModelId(TRAINED_MODEL_ID)
.build()
.build(),
false
)
)
.actionGet();
Expand Up @@ -747,7 +747,7 @@ private void putInferenceModel(String modelId) {
.setInput(new TrainedModelInput(Collections.singletonList("feature1")))
.setInferenceConfig(RegressionConfig.EMPTY_PARAMS)
.build();
client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config)).actionGet();
client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config, false)).actionGet();
}

private static OperationMode randomInvalidLicenseType() {
Expand Up @@ -129,7 +129,7 @@ public void testFeatureTrackingInferenceModelPipeline() throws Exception {
.setPreProcessors(Arrays.asList(new OneHotEncoding("other.categorical", oneHotEncoding, false)))
.setTrainedModel(buildClassification(true)))
.build();
client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config)).actionGet();
client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config, false)).actionGet();

String pipelineId = "pipeline-inference-model-tracked";
putTrainedModelIngestPipeline(pipelineId, modelId);
Expand Up @@ -85,7 +85,8 @@ public void testRemoveUnusedStats() throws Exception {
.build())
)
.validate(true)
.build())).actionGet();
.build(),
false)).actionGet();

indexStatDocument(new DataCounts("analytics-with-stats", 1, 1, 1),
DataCounts.documentId("analytics-with-stats"));
Expand Up @@ -76,14 +76,17 @@ protected void masterOperation(Task task,
ActionListener<Response> listener) {
TrainedModelConfig config = request.getTrainedModelConfig();
try {
config.ensureParsedDefinition(xContentRegistry);
if (request.isDeferDefinitionDecompression() == false) {
config.ensureParsedDefinition(xContentRegistry);
}
} catch (IOException ex) {
listener.onFailure(ExceptionsHelper.badRequestException("Failed to parse definition for [{}]",
ex,
config.getModelId()));
return;
}

// NOTE: hasModelDefinition is false if we don't parse it. But, if the fully parsed model was already provided, continue
boolean hasModelDefinition = config.getModelDefinition() != null;
if (hasModelDefinition) {
try {
Expand Down Expand Up @@ -138,11 +141,16 @@ protected void masterOperation(Task task,
minCompatibilityVersion.toString()));
return;
}
} else if (state.nodes().getMinNodeVersion().before(state.nodes().getMaxNodeVersion())
Contributor

Could we move this check to the top of masterOperation?

In addition, do we really need this check? I'm trying to think about what happens if a user starts a rolling upgrade of the cluster and installs a fleet package that tries to put a model with defer_definition_decompression. Is it worth failing the request? If we allowed it, what would break? I assume that if the model were loaded on an older node, it would fail because the definition would be missing. Is that preferable?

Member Author

@dimitris-athanasiou I am being paranoid for sure. My concern is that we have no way of determining whether the model definition can be inflated on the current min node version. Previously, we were able to validate that. I guess it's "buyer beware" and I can remove this check, but they could get an ugly parsing error. Which, I suppose, is the case anyway.

&& request.isDeferDefinitionDecompression()) {
listener.onFailure(ExceptionsHelper.badRequestException(
"deferring model definition parsing is not possible in a cluster with mixed node versions;"
+ " min version [{}] max version [{}]",
state.nodes().getMinNodeVersion(),
state.nodes().getMaxNodeVersion()));
return;
}




TrainedModelConfig.Builder trainedModelConfig = new TrainedModelConfig.Builder(config)
.setVersion(Version.CURRENT)
.setCreateTime(Instant.now())
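
The mixed-version guard discussed in the review thread can be sketched in isolation. This is a minimal illustration of the intent stated in the error message (reject deferral when nodes run different versions), not the code that was merged; `MixedVersionCheckSketch`, `checkDeferralAllowed`, and the plain integer version ids are hypothetical. The reviewers also questioned whether the check should exist at all, so treat it as one option under discussion.

```java
class MixedVersionCheckSketch {

    // Returns an error message when deferral is requested in a mixed-version
    // cluster, or null when the request may proceed.
    static String checkDeferralAllowed(int minNodeVersion, int maxNodeVersion, boolean deferDefinitionDecompression) {
        boolean mixedVersions = minNodeVersion < maxNodeVersion;
        if (mixedVersions && deferDefinitionDecompression) {
            // Without parsing the definition, there is no way to confirm that the
            // oldest node could inflate and load it.
            return "deferring model definition parsing is not possible in a cluster with mixed node versions;"
                + " min version [" + minNodeVersion + "] max version [" + maxNodeVersion + "]";
        }
        return null;
    }
}
```

Requests that do not defer decompression are unaffected by this guard, because their definitions are parsed and validated up front.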
Expand Up @@ -72,7 +72,6 @@
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
Expand Down Expand Up @@ -137,16 +136,16 @@ public void storeTrainedModel(TrainedModelConfig trainedModelConfig,
return;
}

BytesReference definition;
try {
trainedModelConfig.ensureParsedDefinition(xContentRegistry);
definition = trainedModelConfig.getCompressedDefinition();
} catch (IOException ex) {
listener.onFailure(ExceptionsHelper.serverError(
"Unexpected serialization error when parsing model definition for model [" + trainedModelConfig.getModelId() + "]",
ex));
"Unexpected IOException while serializing definition for storage for model [{}]",
ex,
trainedModelConfig.getModelId()));
return;
}

TrainedModelDefinition definition = trainedModelConfig.getModelDefinition();
TrainedModelLocation location = trainedModelConfig.getLocation();
if (definition == null && location == null) {
listener.onFailure(ExceptionsHelper.badRequestException("Unable to store [{}]. [{}] or [{}] is required",
Expand Up @@ -40,9 +40,9 @@ public String getName() {
protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
String id = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName());
XContentParser parser = restRequest.contentParser();
PutTrainedModelAction.Request putRequest = PutTrainedModelAction.Request.parseRequest(id, parser);
boolean deferDefinitionDecompression = restRequest.paramAsBoolean(PutTrainedModelAction.DEFER_DEFINITION_DECOMPRESSION, false);
PutTrainedModelAction.Request putRequest = PutTrainedModelAction.Request.parseRequest(id, deferDefinitionDecompression, parser);
putRequest.timeout(restRequest.paramAsTime("timeout", putRequest.timeout()));

return channel -> client.execute(PutTrainedModelAction.INSTANCE, putRequest, new RestToXContentListener<>(channel));
}
}