Skip to content

Commit

Permalink
Properly designate model state for actively training models when node…
Browse files Browse the repository at this point in the history
…s crash or leave cluster (#1317)

* Initial implementation

Signed-off-by: Ryan Bogan <[email protected]>

* Fix compile errors for tests

Signed-off-by: Ryan Bogan <[email protected]>

* Temporary tests

Signed-off-by: Ryan Bogan <[email protected]>

* Ensure backwards compatibility and add zombie to model state enum

Signed-off-by: Ryan Bogan <[email protected]>

* Update current tests

Signed-off-by: Ryan Bogan <[email protected]>

* Fix current integration tests

Signed-off-by: Ryan Bogan <[email protected]>

* Fix unit tests with new changes

Signed-off-by: Ryan Bogan <[email protected]>

* Add unit tests

Signed-off-by: Ryan Bogan <[email protected]>

* Fix spotless

Signed-off-by: Ryan Bogan <[email protected]>

* Add changelog entry

Signed-off-by: Ryan Bogan <[email protected]>

* Delete temporary test file

Signed-off-by: Ryan Bogan <[email protected]>

* Remove temporary changes to build.gradle

Signed-off-by: Ryan Bogan <[email protected]>

* Add more backwards compatibility

Signed-off-by: Ryan Bogan <[email protected]>

* Attempt to fix bwc tests

Signed-off-by: Ryan Bogan <[email protected]>

* Fix spotless

Signed-off-by: Ryan Bogan <[email protected]>

* Remove star imports

Signed-off-by: Ryan Bogan <[email protected]>

* Add another unit test

Signed-off-by: Ryan Bogan <[email protected]>

* Modify unit test to increase coverage

Signed-off-by: Ryan Bogan <[email protected]>

* Change unit test to increase coverage

Signed-off-by: Ryan Bogan <[email protected]>

* Add method description for clusterChanged

Signed-off-by: Ryan Bogan <[email protected]>

* Address PR feedback

Signed-off-by: Ryan Bogan <[email protected]>

* Refactor into TrainingJobClusterStateListener

Signed-off-by: Ryan Bogan <[email protected]>

* Make node assignment final and added in the constructor of TrainingJob

Signed-off-by: Ryan Bogan <[email protected]>

* Remove clusterService from TrainingJobRunner

Signed-off-by: Ryan Bogan <[email protected]>

* Address PR Feedback

Signed-off-by: Ryan Bogan <[email protected]>

* Add flag when node rejoins and check when serializing model

Signed-off-by: Ryan Bogan <[email protected]>

* Address PR feedback

Signed-off-by: Ryan Bogan <[email protected]>

* Address PR Feedback

Signed-off-by: Ryan Bogan <[email protected]>

* Fix spotless

Signed-off-by: Ryan Bogan <[email protected]>

* Test new version check for StreamInput

Signed-off-by: Ryan Bogan <[email protected]>

* Remove check to test new method

Signed-off-by: Ryan Bogan <[email protected]>

* Add version check for stream input/output logic

Signed-off-by: Ryan Bogan <[email protected]>

* Address PR Feedback

Signed-off-by: Ryan Bogan <[email protected]>

* Address PR Feedback

Signed-off-by: Ryan Bogan <[email protected]>

* Address PR Feedback

Signed-off-by: Ryan Bogan <[email protected]>

* Address PR Feedback

Signed-off-by: Ryan Bogan <[email protected]>

* Address PR Feedback

Signed-off-by: Ryan Bogan <[email protected]>

---------

Signed-off-by: Ryan Bogan <[email protected]>
ryanbogan authored Dec 12, 2023
1 parent 2e3ab95 commit 33da521
Showing 31 changed files with 706 additions and 89 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -19,6 +19,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Bug Fixes
* Fix use-after-free case on nmslib search path [#1305](https://github.com/opensearch-project/k-NN/pull/1305)
* Allow nested knn field mapping when train model [#1318](https://github.com/opensearch-project/k-NN/pull/1318)
* Properly designate model state for actively training models when nodes crash or leave cluster [#1317](https://github.com/opensearch-project/k-NN/pull/1317)

>>>>>>> main
### Infrastructure
* Upgrade gradle to 8.4 [1289](https://github.com/opensearch-project/k-NN/pull/1289)
### Documentation
Original file line number Diff line number Diff line change
@@ -257,6 +257,6 @@ public String modelIndexMapping(String fieldName, String modelId) throws IOExcep
}

private ModelMetadata getModelMetadata() {
return new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, 4, ModelState.CREATED, "2021-03-27", "test model", "");
return new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, 4, ModelState.CREATED, "2021-03-27", "test model", "", "");
}
}
1 change: 1 addition & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
@@ -45,6 +45,7 @@ public class KNNConstants {
public static final String MODEL_TIMESTAMP = "timestamp";
public static final String MODEL_DESCRIPTION = "description";
public static final String MODEL_ERROR = "error";
public static final String MODEL_NODE_ASSIGNMENT = "training_node_assignment";
public static final String PARAM_SIZE = "size";
public static final Integer SEARCH_MODEL_MIN_SIZE = 1;
public static final Integer SEARCH_MODEL_MAX_SIZE = 1000;
12 changes: 12 additions & 0 deletions src/main/java/org/opensearch/knn/index/IndexUtil.java
Original file line number Diff line number Diff line change
@@ -36,10 +36,14 @@

public class IndexUtil {

public static final String MODEL_NODE_ASSIGNMENT_KEY = KNNConstants.MODEL_NODE_ASSIGNMENT;

private static final Version MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED = Version.V_2_11_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT = Version.V_2_12_0;
private static final Map<String, Version> minimalRequiredVersionMap = new HashMap<String, Version>() {
{
put("ignore_unmapped", MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED);
put(MODEL_NODE_ASSIGNMENT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT);
}
};

@@ -251,4 +255,12 @@ public static boolean isClusterOnOrAfterMinRequiredVersion(String key) {
}
return KNNClusterUtil.instance().getClusterMinVersion().onOrAfter(minimalRequiredVersion);
}

public static boolean isVersionOnOrAfterMinRequiredVersion(Version version, String key) {
Version minimalRequiredVersion = minimalRequiredVersionMap.get(key);
if (minimalRequiredVersion == null) {
return false;
}
return version.onOrAfter(minimalRequiredVersion);
}
}
1 change: 1 addition & 0 deletions src/main/java/org/opensearch/knn/indices/ModelDao.java
Original file line number Diff line number Diff line change
@@ -287,6 +287,7 @@ private void putInternal(Model model, ActionListener<IndexResponse> listener, Do
put(KNNConstants.MODEL_TIMESTAMP, modelMetadata.getTimestamp());
put(KNNConstants.MODEL_DESCRIPTION, modelMetadata.getDescription());
put(KNNConstants.MODEL_ERROR, modelMetadata.getError());
put(KNNConstants.MODEL_NODE_ASSIGNMENT, modelMetadata.getNodeAssignment());
}
};

80 changes: 65 additions & 15 deletions src/main/java/org/opensearch/knn/indices/ModelMetadata.java
Original file line number Diff line number Diff line change
@@ -11,6 +11,7 @@

package org.opensearch.knn.indices;

import lombok.extern.log4j.Log4j2;
import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;
import org.opensearch.core.common.io.stream.StreamInput;
@@ -19,6 +20,7 @@
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.IndexUtil;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.util.KNNEngine;

@@ -34,7 +36,9 @@
import static org.opensearch.knn.common.KNNConstants.MODEL_ERROR;
import static org.opensearch.knn.common.KNNConstants.MODEL_STATE;
import static org.opensearch.knn.common.KNNConstants.MODEL_TIMESTAMP;
import static org.opensearch.knn.common.KNNConstants.MODEL_NODE_ASSIGNMENT;

@Log4j2
public class ModelMetadata implements Writeable, ToXContentObject {

private static final String DELIMITER = ",";
@@ -46,6 +50,7 @@ public class ModelMetadata implements Writeable, ToXContentObject {
private AtomicReference<ModelState> state;
final private String timestamp;
final private String description;
final private String trainingNodeAssignment;
private String error;

/**
@@ -54,6 +59,7 @@ public class ModelMetadata implements Writeable, ToXContentObject {
* @param in Stream input
*/
public ModelMetadata(StreamInput in) throws IOException {
String tempTrainingNodeAssignment;
this.knnEngine = KNNEngine.getEngine(in.readString());
this.spaceType = SpaceType.getSpace(in.readString());
this.dimension = in.readInt();
@@ -64,6 +70,12 @@ public ModelMetadata(StreamInput in) throws IOException {
// which is checked in constructor and setters
this.description = in.readString();
this.error = in.readString();

if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(in.getVersion(), IndexUtil.MODEL_NODE_ASSIGNMENT_KEY)) {
this.trainingNodeAssignment = in.readString();
} else {
this.trainingNodeAssignment = "";
}
}

/**
@@ -84,7 +96,8 @@ public ModelMetadata(
ModelState modelState,
String timestamp,
String description,
String error
String error,
String trainingNodeAssignment
) {
this.knnEngine = Objects.requireNonNull(knnEngine, "knnEngine must not be null");
this.spaceType = Objects.requireNonNull(spaceType, "spaceType must not be null");
@@ -104,6 +117,7 @@ public ModelMetadata(
this.timestamp = Objects.requireNonNull(timestamp, "timestamp must not be null");
this.description = Objects.requireNonNull(description, "description must not be null");
this.error = Objects.requireNonNull(error, "error must not be null");
this.trainingNodeAssignment = Objects.requireNonNull(trainingNodeAssignment, "node assignment must not be null");
}

/**
@@ -169,6 +183,15 @@ public String getError() {
return error;
}

/**
* getter for model's node assignment
*
* @return trainingNodeAssignment
*/
public String getNodeAssignment() {
return trainingNodeAssignment;
}

/**
* setter for model's state
*
@@ -197,7 +220,8 @@ public String toString() {
getState().toString(),
timestamp,
description,
error
error,
trainingNodeAssignment
);
}

@@ -240,22 +264,36 @@ public int hashCode() {
public static ModelMetadata fromString(String modelMetadataString) {
String[] modelMetadataArray = modelMetadataString.split(DELIMITER, -1);

if (modelMetadataArray.length != 7) {
// Training node assignment was added as a field in Version 2.12.0
// Because models can be created on older versions and the cluster can be upgraded after,
// we need to accept model metadata arrays both with and without the training node assignment.
if (modelMetadataArray.length == 7) {
log.debug("Model metadata array does not contain training node assignment. Assuming empty string.");
KNNEngine knnEngine = KNNEngine.getEngine(modelMetadataArray[0]);
SpaceType spaceType = SpaceType.getSpace(modelMetadataArray[1]);
int dimension = Integer.parseInt(modelMetadataArray[2]);
ModelState modelState = ModelState.getModelState(modelMetadataArray[3]);
String timestamp = modelMetadataArray[4];
String description = modelMetadataArray[5];
String error = modelMetadataArray[6];
return new ModelMetadata(knnEngine, spaceType, dimension, modelState, timestamp, description, error, "");
} else if (modelMetadataArray.length == 8) {
log.debug("Model metadata contains training node assignment");
KNNEngine knnEngine = KNNEngine.getEngine(modelMetadataArray[0]);
SpaceType spaceType = SpaceType.getSpace(modelMetadataArray[1]);
int dimension = Integer.parseInt(modelMetadataArray[2]);
ModelState modelState = ModelState.getModelState(modelMetadataArray[3]);
String timestamp = modelMetadataArray[4];
String description = modelMetadataArray[5];
String error = modelMetadataArray[6];
String trainingNodeAssignment = modelMetadataArray[7];
return new ModelMetadata(knnEngine, spaceType, dimension, modelState, timestamp, description, error, trainingNodeAssignment);
} else {
throw new IllegalArgumentException(
"Illegal format for model metadata. Must be of the form "
+ "\"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>\"."
+ "\"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>\" or \"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>,<NodeAssignment>\"."
);
}

KNNEngine knnEngine = KNNEngine.getEngine(modelMetadataArray[0]);
SpaceType spaceType = SpaceType.getSpace(modelMetadataArray[1]);
int dimension = Integer.parseInt(modelMetadataArray[2]);
ModelState modelState = ModelState.getModelState(modelMetadataArray[3]);
String timestamp = modelMetadataArray[4];
String description = modelMetadataArray[5];
String error = modelMetadataArray[6];

return new ModelMetadata(knnEngine, spaceType, dimension, modelState, timestamp, description, error);
}

private static String objectToString(Object value) {
@@ -282,6 +320,11 @@ public static ModelMetadata getMetadataFromSourceMap(final Map<String, Object> m
Object timestamp = modelSourceMap.get(KNNConstants.MODEL_TIMESTAMP);
Object description = modelSourceMap.get(KNNConstants.MODEL_DESCRIPTION);
Object error = modelSourceMap.get(KNNConstants.MODEL_ERROR);
Object trainingNodeAssignment = modelSourceMap.get(KNNConstants.MODEL_NODE_ASSIGNMENT);

if (trainingNodeAssignment == null) {
trainingNodeAssignment = "";
}

ModelMetadata modelMetadata = new ModelMetadata(
KNNEngine.getEngine(objectToString(engine)),
@@ -290,7 +333,8 @@ public static ModelMetadata getMetadataFromSourceMap(final Map<String, Object> m
ModelState.getModelState(objectToString(state)),
objectToString(timestamp),
objectToString(description),
objectToString(error)
objectToString(error),
objectToString(trainingNodeAssignment)
);
return modelMetadata;
}
@@ -304,6 +348,9 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(getTimestamp());
out.writeString(getDescription());
out.writeString(getError());
if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), IndexUtil.MODEL_NODE_ASSIGNMENT_KEY)) {
out.writeString(getNodeAssignment());
}
}

@Override
@@ -316,6 +363,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(METHOD_PARAMETER_SPACE_TYPE, getSpaceType().getValue());
builder.field(DIMENSION, getDimension());
builder.field(KNN_ENGINE, getKnnEngine().getName());
if (IndexUtil.isClusterOnOrAfterMinRequiredVersion(IndexUtil.MODEL_NODE_ASSIGNMENT_KEY)) {
builder.field(MODEL_NODE_ASSIGNMENT, getNodeAssignment());
}
return builder;
}
}
5 changes: 5 additions & 0 deletions src/main/java/org/opensearch/knn/plugin/KNNPlugin.java
Original file line number Diff line number Diff line change
@@ -78,6 +78,7 @@
import org.opensearch.knn.plugin.transport.UpdateModelMetadataTransportAction;
import org.opensearch.knn.plugin.transport.UpdateModelGraveyardAction;
import org.opensearch.knn.plugin.transport.UpdateModelGraveyardTransportAction;
import org.opensearch.knn.training.TrainingJobClusterStateListener;
import org.opensearch.knn.training.TrainingJobRunner;
import org.opensearch.knn.training.VectorReader;
import org.opensearch.plugins.ActionPlugin;
@@ -200,10 +201,14 @@ public Collection<Object> createComponents(
ModelDao.OpenSearchKNNModelDao.initialize(client, clusterService, environment.settings());
ModelCache.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService);
TrainingJobRunner.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance());
TrainingJobClusterStateListener.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService);
KNNCircuitBreaker.getInstance().initialize(threadPool, clusterService, client);
KNNQueryBuilder.initialize(ModelDao.OpenSearchKNNModelDao.getInstance());
KNNWeight.initialize(ModelDao.OpenSearchKNNModelDao.getInstance());
TrainingModelRequest.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService);

clusterService.addListener(TrainingJobClusterStateListener.getInstance());

knnStats = new KNNStats();
return ImmutableList.of(knnStats);
}
Original file line number Diff line number Diff line change
@@ -26,6 +26,7 @@
import org.opensearch.transport.TransportService;

import java.io.IOException;
import java.util.concurrent.ExecutionException;

/**
* Transport action that trains a model and serializes it to model system index
@@ -66,7 +67,8 @@ protected void doExecute(Task task, TrainingModelRequest request, ActionListener
trainingDataEntryContext,
modelAnonymousEntryContext,
request.getDimension(),
request.getDescription()
request.getDescription(),
clusterService.localNode().getEphemeralId()
);

KNNCounter.TRAINING_REQUESTS.increment();
@@ -84,7 +86,7 @@ protected void doExecute(Task task, TrainingModelRequest request, ActionListener
wrappedListener::onFailure
)
);
} catch (IOException e) {
} catch (IOException | ExecutionException | InterruptedException e) {
wrappedListener.onFailure(e);
}
}
6 changes: 4 additions & 2 deletions src/main/java/org/opensearch/knn/training/TrainingJob.java
Original file line number Diff line number Diff line change
@@ -65,7 +65,8 @@ public TrainingJob(
NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext,
NativeMemoryEntryContext.AnonymousEntryContext modelAnonymousEntryContext,
int dimension,
String description
String description,
String nodeAssignment
) {
// Generate random base64 string if one is not provided
this.modelId = StringUtils.isNotBlank(modelId) ? modelId : UUIDs.randomBase64UUID();
@@ -81,7 +82,8 @@ public TrainingJob(
ModelState.TRAINING,
ZonedDateTime.now(ZoneOffset.UTC).toString(),
description,
""
"",
nodeAssignment
),
null,
this.modelId
Loading

0 comments on commit 33da521

Please sign in to comment.