diff --git a/CHANGELOG.md b/CHANGELOG.md index 3328dda52..cfd852424 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Bug Fixes * Fix use-after-free case on nmslib search path [#1305](https://github.com/opensearch-project/k-NN/pull/1305) * Allow nested knn field mapping when train model [#1318](https://github.com/opensearch-project/k-NN/pull/1318) +* Properly designate model state for actively training models when nodes crash or leave cluster [#1317](https://github.com/opensearch-project/k-NN/pull/1317) + +>>>>>>> main ### Infrastructure * Upgrade gradle to 8.4 [1289](https://github.com/opensearch-project/k-NN/pull/1289) ### Documentation diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ModelIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ModelIT.java index 0324822f7..7eb75e24c 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ModelIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ModelIT.java @@ -257,6 +257,6 @@ public String modelIndexMapping(String fieldName, String modelId) throws IOExcep } private ModelMetadata getModelMetadata() { - return new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, 4, ModelState.CREATED, "2021-03-27", "test model", ""); + return new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, 4, ModelState.CREATED, "2021-03-27", "test model", "", ""); } } diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index 85654efd0..5b968ce31 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -45,6 +45,7 @@ public class KNNConstants { public static final String MODEL_TIMESTAMP = "timestamp"; public static final String MODEL_DESCRIPTION = "description"; public static final String MODEL_ERROR = "error"; + public static final String MODEL_NODE_ASSIGNMENT = "training_node_assignment"; public static final String PARAM_SIZE = "size"; public static final Integer SEARCH_MODEL_MIN_SIZE = 1; public static final Integer SEARCH_MODEL_MAX_SIZE = 1000; diff --git a/src/main/java/org/opensearch/knn/index/IndexUtil.java b/src/main/java/org/opensearch/knn/index/IndexUtil.java index d5db06df6..e98c00197 100644 --- a/src/main/java/org/opensearch/knn/index/IndexUtil.java +++ b/src/main/java/org/opensearch/knn/index/IndexUtil.java @@ -36,10 +36,14 @@ public class IndexUtil { + public static final String MODEL_NODE_ASSIGNMENT_KEY = KNNConstants.MODEL_NODE_ASSIGNMENT; + private static final Version MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED = Version.V_2_11_0; + private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT = Version.V_2_12_0; private static final Map minimalRequiredVersionMap = new HashMap() { { put("ignore_unmapped", MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED); + put(MODEL_NODE_ASSIGNMENT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT); } }; @@ -251,4 +255,12 @@ public static boolean isClusterOnOrAfterMinRequiredVersion(String key) { } return KNNClusterUtil.instance().getClusterMinVersion().onOrAfter(minimalRequiredVersion); } + + public static boolean isVersionOnOrAfterMinRequiredVersion(Version version, String key) { + Version minimalRequiredVersion = minimalRequiredVersionMap.get(key); + if (minimalRequiredVersion == null) { + return false; + } + return version.onOrAfter(minimalRequiredVersion); + } } diff --git a/src/main/java/org/opensearch/knn/indices/ModelDao.java b/src/main/java/org/opensearch/knn/indices/ModelDao.java index eada08b44..1d88c9a00 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelDao.java +++ b/src/main/java/org/opensearch/knn/indices/ModelDao.java @@ -287,6 +287,7 @@ private void putInternal(Model model, ActionListener listener, Do put(KNNConstants.MODEL_TIMESTAMP, modelMetadata.getTimestamp()); put(KNNConstants.MODEL_DESCRIPTION, modelMetadata.getDescription()); put(KNNConstants.MODEL_ERROR, modelMetadata.getError()); + put(KNNConstants.MODEL_NODE_ASSIGNMENT, modelMetadata.getNodeAssignment()); } }; diff --git a/src/main/java/org/opensearch/knn/indices/ModelMetadata.java b/src/main/java/org/opensearch/knn/indices/ModelMetadata.java index 0d0c79bc3..04836f184 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelMetadata.java +++ b/src/main/java/org/opensearch/knn/indices/ModelMetadata.java @@ -11,6 +11,7 @@ package org.opensearch.knn.indices; +import lombok.extern.log4j.Log4j2; import org.apache.commons.lang.builder.EqualsBuilder; import org.apache.commons.lang.builder.HashCodeBuilder; import org.opensearch.core.common.io.stream.StreamInput; @@ -19,6 +20,7 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.IndexUtil; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.util.KNNEngine; @@ -34,7 +36,9 @@ import static org.opensearch.knn.common.KNNConstants.MODEL_ERROR; import static org.opensearch.knn.common.KNNConstants.MODEL_STATE; import static org.opensearch.knn.common.KNNConstants.MODEL_TIMESTAMP; +import static org.opensearch.knn.common.KNNConstants.MODEL_NODE_ASSIGNMENT; +@Log4j2 public class ModelMetadata implements Writeable, ToXContentObject { private static final String DELIMITER = ","; @@ -46,6 +50,7 @@ public class ModelMetadata implements Writeable, ToXContentObject { private AtomicReference state; final private String timestamp; final private String description; + final private String trainingNodeAssignment; private String error; /** @@ -54,6 +59,7 @@ public class ModelMetadata implements Writeable, ToXContentObject { * @param in Stream input */ public ModelMetadata(StreamInput in) throws IOException { + String tempTrainingNodeAssignment; this.knnEngine = KNNEngine.getEngine(in.readString()); this.spaceType = SpaceType.getSpace(in.readString()); this.dimension = in.readInt(); @@ -64,6 +70,12 @@ public ModelMetadata(StreamInput in) throws IOException { // which is checked in constructor and setters this.description = in.readString(); this.error = in.readString(); + + if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(in.getVersion(), IndexUtil.MODEL_NODE_ASSIGNMENT_KEY)) { + this.trainingNodeAssignment = in.readString(); + } else { + this.trainingNodeAssignment = ""; + } } /** @@ -84,7 +96,8 @@ public ModelMetadata( ModelState modelState, String timestamp, String description, - String error + String error, + String trainingNodeAssignment ) { this.knnEngine = Objects.requireNonNull(knnEngine, "knnEngine must not be null"); this.spaceType = Objects.requireNonNull(spaceType, "spaceType must not be null"); @@ -104,6 +117,7 @@ public ModelMetadata( this.timestamp = Objects.requireNonNull(timestamp, "timestamp must not be null"); this.description = Objects.requireNonNull(description, "description must not be null"); this.error = Objects.requireNonNull(error, "error must not be null"); + this.trainingNodeAssignment = Objects.requireNonNull(trainingNodeAssignment, "node assignment must not be null"); } /** @@ -169,6 +183,15 @@ public String getError() { return error; } + /** + * getter for model's node assignment + * + * @return trainingNodeAssignment + */ + public String getNodeAssignment() { + return trainingNodeAssignment; + } + /** * setter for model's state * @@ -197,7 +220,8 @@ public String toString() { getState().toString(), timestamp, description, - error + error, + trainingNodeAssignment ); } @@ -240,22 +264,36 @@ public int hashCode() { public static ModelMetadata fromString(String modelMetadataString) { String[] modelMetadataArray = modelMetadataString.split(DELIMITER, -1); - if (modelMetadataArray.length != 7) { + // Training node assignment was added as a field in Version 2.12.0 + // Because models can be created on older versions and the cluster can be upgraded after, + // we need to accept model metadata arrays both with and without the training node assignment. + if (modelMetadataArray.length == 7) { + log.debug("Model metadata array does not contain training node assignment. Assuming empty string."); + KNNEngine knnEngine = KNNEngine.getEngine(modelMetadataArray[0]); + SpaceType spaceType = SpaceType.getSpace(modelMetadataArray[1]); + int dimension = Integer.parseInt(modelMetadataArray[2]); + ModelState modelState = ModelState.getModelState(modelMetadataArray[3]); + String timestamp = modelMetadataArray[4]; + String description = modelMetadataArray[5]; + String error = modelMetadataArray[6]; + return new ModelMetadata(knnEngine, spaceType, dimension, modelState, timestamp, description, error, ""); + } else if (modelMetadataArray.length == 8) { + log.debug("Model metadata contains training node assignment"); + KNNEngine knnEngine = KNNEngine.getEngine(modelMetadataArray[0]); + SpaceType spaceType = SpaceType.getSpace(modelMetadataArray[1]); + int dimension = Integer.parseInt(modelMetadataArray[2]); + ModelState modelState = ModelState.getModelState(modelMetadataArray[3]); + String timestamp = modelMetadataArray[4]; + String description = modelMetadataArray[5]; + String error = modelMetadataArray[6]; + String trainingNodeAssignment = modelMetadataArray[7]; + return new ModelMetadata(knnEngine, spaceType, dimension, modelState, timestamp, description, error, trainingNodeAssignment); + } else { throw new IllegalArgumentException( "Illegal format for model metadata. Must be of the form " - + "\",,,,,,\"." + + "\",,,,,,\" or \",,,,,,,\"." ); } - - KNNEngine knnEngine = KNNEngine.getEngine(modelMetadataArray[0]); - SpaceType spaceType = SpaceType.getSpace(modelMetadataArray[1]); - int dimension = Integer.parseInt(modelMetadataArray[2]); - ModelState modelState = ModelState.getModelState(modelMetadataArray[3]); - String timestamp = modelMetadataArray[4]; - String description = modelMetadataArray[5]; - String error = modelMetadataArray[6]; - - return new ModelMetadata(knnEngine, spaceType, dimension, modelState, timestamp, description, error); } private static String objectToString(Object value) { @@ -282,6 +320,11 @@ public static ModelMetadata getMetadataFromSourceMap(final Map m Object timestamp = modelSourceMap.get(KNNConstants.MODEL_TIMESTAMP); Object description = modelSourceMap.get(KNNConstants.MODEL_DESCRIPTION); Object error = modelSourceMap.get(KNNConstants.MODEL_ERROR); + Object trainingNodeAssignment = modelSourceMap.get(KNNConstants.MODEL_NODE_ASSIGNMENT); + + if (trainingNodeAssignment == null) { + trainingNodeAssignment = ""; + } ModelMetadata modelMetadata = new ModelMetadata( KNNEngine.getEngine(objectToString(engine)), @@ -290,7 +333,8 @@ public static ModelMetadata getMetadataFromSourceMap(final Map m ModelState.getModelState(objectToString(state)), objectToString(timestamp), objectToString(description), - objectToString(error) + objectToString(error), + objectToString(trainingNodeAssignment) ); return modelMetadata; } @@ -304,6 +348,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(getTimestamp()); out.writeString(getDescription()); out.writeString(getError()); + if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), IndexUtil.MODEL_NODE_ASSIGNMENT_KEY)) { + out.writeString(getNodeAssignment()); + } } @Override @@ -316,6 +363,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(METHOD_PARAMETER_SPACE_TYPE, getSpaceType().getValue()); builder.field(DIMENSION, getDimension()); builder.field(KNN_ENGINE, getKnnEngine().getName()); + if (IndexUtil.isClusterOnOrAfterMinRequiredVersion(IndexUtil.MODEL_NODE_ASSIGNMENT_KEY)) { + builder.field(MODEL_NODE_ASSIGNMENT, getNodeAssignment()); + } return builder; } } diff --git a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java index 6e6a2b21c..2e5a55092 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java @@ -78,6 +78,7 @@ import org.opensearch.knn.plugin.transport.UpdateModelMetadataTransportAction; import org.opensearch.knn.plugin.transport.UpdateModelGraveyardAction; import org.opensearch.knn.plugin.transport.UpdateModelGraveyardTransportAction; +import org.opensearch.knn.training.TrainingJobClusterStateListener; import org.opensearch.knn.training.TrainingJobRunner; import org.opensearch.knn.training.VectorReader; import org.opensearch.plugins.ActionPlugin; @@ -200,10 +201,14 @@ public Collection createComponents( ModelDao.OpenSearchKNNModelDao.initialize(client, clusterService, environment.settings()); ModelCache.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService); TrainingJobRunner.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance()); + TrainingJobClusterStateListener.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService); KNNCircuitBreaker.getInstance().initialize(threadPool, clusterService, client); KNNQueryBuilder.initialize(ModelDao.OpenSearchKNNModelDao.getInstance()); KNNWeight.initialize(ModelDao.OpenSearchKNNModelDao.getInstance()); TrainingModelRequest.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService); + + clusterService.addListener(TrainingJobClusterStateListener.getInstance()); + knnStats = new KNNStats(); return ImmutableList.of(knnStats); } diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java index 379d8a809..33b420e2c 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java @@ -26,6 +26,7 @@ import org.opensearch.transport.TransportService; import java.io.IOException; +import java.util.concurrent.ExecutionException; /** * Transport action that trains a model and serializes it to model system index @@ -66,7 +67,8 @@ protected void doExecute(Task task, TrainingModelRequest request, ActionListener trainingDataEntryContext, modelAnonymousEntryContext, request.getDimension(), - request.getDescription() + request.getDescription(), + clusterService.localNode().getEphemeralId() ); KNNCounter.TRAINING_REQUESTS.increment(); @@ -84,7 +86,7 @@ protected void doExecute(Task task, TrainingModelRequest request, ActionListener wrappedListener::onFailure ) ); - } catch (IOException e) { + } catch (IOException | ExecutionException | InterruptedException e) { wrappedListener.onFailure(e); } } diff --git a/src/main/java/org/opensearch/knn/training/TrainingJob.java b/src/main/java/org/opensearch/knn/training/TrainingJob.java index 27a1c6025..8a2af4319 100644 --- a/src/main/java/org/opensearch/knn/training/TrainingJob.java +++ b/src/main/java/org/opensearch/knn/training/TrainingJob.java @@ -65,7 +65,8 @@ public TrainingJob( NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext, NativeMemoryEntryContext.AnonymousEntryContext modelAnonymousEntryContext, int dimension, - String description + String description, + String nodeAssignment ) { // Generate random base64 string if one is not provided this.modelId = StringUtils.isNotBlank(modelId) ? modelId : UUIDs.randomBase64UUID(); @@ -81,7 +82,8 @@ public TrainingJob( ModelState.TRAINING, ZonedDateTime.now(ZoneOffset.UTC).toString(), description, - "" + "", + nodeAssignment ), null, this.modelId diff --git a/src/main/java/org/opensearch/knn/training/TrainingJobClusterStateListener.java b/src/main/java/org/opensearch/knn/training/TrainingJobClusterStateListener.java new file mode 100644 index 000000000..45d2197e8 --- /dev/null +++ b/src/main/java/org/opensearch/knn/training/TrainingJobClusterStateListener.java @@ -0,0 +1,176 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.training; + +import lombok.extern.log4j.Log4j2; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.cluster.ClusterChangedEvent; +import org.opensearch.cluster.ClusterStateListener; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.action.ActionListener; +import org.opensearch.knn.indices.Model; +import org.opensearch.knn.indices.ModelDao; +import org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.knn.indices.ModelState; +import org.opensearch.search.SearchHit; +import org.opensearch.threadpool.ThreadPool; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; + +/** + * TrainingJobClusterStateListener is a ClusterStateListener that is used to update models that are still training when a node leaves or the cluster crashes. + * This class also sets a flag in TrainingJobRunner to block serialization when a node rejoins a cluster. + */ +@Log4j2 +public class TrainingJobClusterStateListener implements ClusterStateListener { + private static TrainingJobClusterStateListener INSTANCE; + + private static ModelDao modelDao; + private static ThreadPool threadPool; + private static ClusterService clusterService; + private String oldClusterManagerNodeId = ""; + private String currentClusterManagerNodeId = ""; + private boolean clusterManagerNodeRemoved = false; + + /** + * Get singleton instance of TrainingJobRunner + * + * @return singleton instance of TrainingJobRunner + */ + public static synchronized TrainingJobClusterStateListener getInstance() { + if (INSTANCE == null) { + INSTANCE = new TrainingJobClusterStateListener(); + } + return INSTANCE; + } + + /** + * Initializes static components. + * + * @param threadPool threadPool to use to schedule update of models + * @param modelDao modelDao used to get modelIds + * @param clusterService clusterService used to add a listener + */ + public static synchronized void initialize(ThreadPool threadPool, ModelDao modelDao, ClusterService clusterService) { + TrainingJobClusterStateListener.threadPool = threadPool; + TrainingJobClusterStateListener.modelDao = modelDao; + TrainingJobClusterStateListener.clusterService = clusterService; + } + + /** + * This method is called whenever the cluster state changes. + * It is used to update models that are still training when a node leaves or the cluster crashes. + * It is also used to cancel training jobs when a node rejoins the cluster. + * @param event the event that changed the cluster change + */ + @Override + public void clusterChanged(ClusterChangedEvent event) { + if (event.localNodeClusterManager()) { + if (event.isNewCluster()) { + // When the cluster is first created, the cluster manager will update models that are still marked as training. + threadPool.schedule(() -> { + try { + updateModelsNewCluster(); + } catch (IOException | InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + }, TimeValue.timeValueSeconds(1), ThreadPool.Names.GENERIC); + } else if (event.nodesRemoved()) { + List removedNodes = event.nodesDelta().removedNodes(); + threadPool.schedule(() -> { + try { + updateModelsNodesRemoved(removedNodes); + } catch (IOException | InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + }, TimeValue.timeValueSeconds(0), ThreadPool.Names.GENERIC); + } + } + } + + protected void updateModelsNewCluster() throws IOException, InterruptedException, ExecutionException { + if (modelDao.isCreated()) { + List modelIds = searchModelIds(); + for (String modelId : modelIds) { + Model model = modelDao.get(modelId); + ModelMetadata modelMetadata = model.getModelMetadata(); + if (modelMetadata.getState().equals(ModelState.TRAINING)) { + updateModelStateAsFailed(model, "Training failed to complete as cluster crashed"); + } + } + } + } + + protected void updateModelsNodesRemoved(List removedNodes) throws IOException, InterruptedException, ExecutionException { + if (modelDao.isCreated()) { + List modelIds = searchModelIds(); + for (DiscoveryNode removedNode : removedNodes) { + for (String modelId : modelIds) { + Model model = modelDao.get(modelId); + ModelMetadata modelMetadata = model.getModelMetadata(); + if (modelMetadata.getNodeAssignment().equals(removedNode.getEphemeralId()) + && modelMetadata.getState().equals(ModelState.TRAINING)) { + updateModelStateAsFailed(model, "Training failed to complete as node dropped"); + } + } + } + } + } + + private List searchModelIds() throws IOException, InterruptedException { + List modelIds = new ArrayList(); + CountDownLatch latch = new CountDownLatch(1); + modelDao.search(new SearchRequest(), new ActionListener() { + @Override + public void onResponse(SearchResponse searchResponse) { + try { + for (SearchHit searchHit : searchResponse.getHits().getHits()) { + modelIds.add(searchHit.getId()); + } + } finally { + latch.countDown(); + } + } + + @Override + public void onFailure(Exception e) { + latch.countDown(); + } + }); + latch.await(); + return modelIds; + } + + private void updateModelStateAsFailed(Model model, String msg) throws IOException { + model.getModelMetadata().setState(ModelState.FAILED); + model.getModelMetadata().setError(msg); + modelDao.update(model, new ActionListener() { + @Override + public void onResponse(IndexResponse indexResponse) { + log.info("Model {} marked as {}", model.getModelID(), model.getModelMetadata().getState()); + } + + @Override + public void onFailure(Exception e) { + log.error("Failed to update model state", e); + } + }); + } +} diff --git a/src/main/java/org/opensearch/knn/training/TrainingJobRunner.java b/src/main/java/org/opensearch/knn/training/TrainingJobRunner.java index a999735a4..8884f8102 100644 --- a/src/main/java/org/opensearch/knn/training/TrainingJobRunner.java +++ b/src/main/java/org/opensearch/knn/training/TrainingJobRunner.java @@ -16,6 +16,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.action.index.IndexResponse; import org.opensearch.common.ValidationException; +import org.opensearch.knn.indices.Model; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelState; @@ -23,6 +24,7 @@ import org.opensearch.threadpool.ThreadPool; import java.io.IOException; +import java.util.concurrent.ExecutionException; import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.Semaphore; import java.util.concurrent.atomic.AtomicInteger; @@ -79,7 +81,8 @@ public static void initialize(ThreadPool threadPool, ModelDao modelDao) { * @param trainingJob training job to be executed * @param listener listener to handle final model serialization response (or exception) */ - public void execute(TrainingJob trainingJob, ActionListener listener) throws IOException { + public void execute(TrainingJob trainingJob, ActionListener listener) throws IOException, ExecutionException, + InterruptedException { // If the semaphore cannot be acquired, the node is unable to execute this job. This allows us to limit // the number of training jobs that enter this function. Although the training threadpool size will also prevent // this, we want to prevent this before we perform any serialization. @@ -106,10 +109,10 @@ public void execute(TrainingJob trainingJob, ActionListener liste logger.error("Unable to initialize model serialization: " + exception.getMessage()); listener.onFailure(exception); }), false); - } catch (IOException ioe) { + } catch (IOException | ExecutionException | InterruptedException e) { jobCount.decrementAndGet(); semaphore.release(); - throw ioe; + throw e; } } @@ -130,7 +133,7 @@ private void train(TrainingJob trainingJob) { try { trainingJob.run(); serializeModel(trainingJob, loggingListener, true); - } catch (IOException e) { + } catch (IOException | ExecutionException | InterruptedException e) { logger.error("Unable to serialize model \"" + trainingJob.getModelId() + "\": " + e.getMessage()); KNNCounter.TRAINING_ERRORS.increment(); } catch (Exception e) { @@ -150,8 +153,8 @@ private void train(TrainingJob trainingJob) { try { serializeModel(trainingJob, loggingListener, true); - } catch (IOException ioe) { - logger.error("Unable to serialize the failure for model \"" + trainingJob.getModelId() + "\": " + ioe); + } catch (IOException | ExecutionException | InterruptedException e) { + logger.error("Unable to serialize the failure for model \"{}\": ", trainingJob.getModelId(), e); } finally { jobCount.decrementAndGet(); semaphore.release(); @@ -160,9 +163,15 @@ private void train(TrainingJob trainingJob) { } } - private void serializeModel(TrainingJob trainingJob, ActionListener listener, boolean update) throws IOException { + private void serializeModel(TrainingJob trainingJob, ActionListener listener, boolean update) throws IOException, + ExecutionException, InterruptedException { if (update) { - modelDao.update(trainingJob.getModel(), listener); + Model model = modelDao.get(trainingJob.getModelId()); + if (model.getModelMetadata().getState().equals(ModelState.TRAINING)) { + modelDao.update(trainingJob.getModel(), listener); + } else { + logger.info("Model state is {}. Skipping serialization of trained data", model.getModelMetadata().getState()); + } } else { modelDao.put(trainingJob.getModel(), listener); } diff --git a/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java b/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java index c1c52e63a..8fdc55766 100644 --- a/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java @@ -61,7 +61,8 @@ public void testCreateIndexFromModel() throws IOException, InterruptedException ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", - "" + "", + "test-node" ); Model model = new Model(modelMetadata, modelBlob, modelId); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java index 6af83de87..242aeaa17 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java @@ -341,7 +341,7 @@ public void testAddKNNBinaryField_fromModel_faiss() throws IOException, Executio byte[] modelBytes = JNIService.trainIndex(parameters, dimension, trainingPtr, knnEngine.getName()); Model model = new Model( - new ModelMetadata(knnEngine, spaceType, dimension, ModelState.CREATED, "timestamp", "Empty description", ""), + new ModelMetadata(knnEngine, spaceType, dimension, ModelState.CREATED, "timestamp", "Empty description", "", ""), modelBytes, modelId ); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java index 6c7631216..eb9b4fa2d 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java @@ -204,6 +204,7 @@ public void testBuildFromModelTemplate(Codec codec) throws IOException, Executio ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ); diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index 1f3598781..2de98d803 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@ -163,6 +163,7 @@ public void testBuilder_build_fromModel() { ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ); builder.modelId.setValue(modelId); @@ -689,6 +690,7 @@ public void testKNNVectorFieldMapper_merge_fromModel() throws IOException { ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ); when(mockModelDao.getMetadata(modelId)).thenReturn(mockModelMetadata); diff --git a/src/test/java/org/opensearch/knn/indices/ModelCacheTests.java b/src/test/java/org/opensearch/knn/indices/ModelCacheTests.java index fb810d969..3146d898e 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelCacheTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelCacheTests.java @@ -42,6 +42,7 @@ public void testGet_normal() throws ExecutionException, InterruptedException { ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ), "hello".getBytes(), @@ -77,6 +78,7 @@ public void testGet_modelDoesNotFitInCache() throws ExecutionException, Interrup ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ), new byte[BYTES_PER_KILOBYTES + 1], @@ -133,6 +135,7 @@ public void testGetTotalWeight() throws ExecutionException, InterruptedException ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ), new byte[size1], @@ -147,6 +150,7 @@ public void testGetTotalWeight() throws ExecutionException, InterruptedException ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ), new byte[size2], @@ -189,6 +193,7 @@ public void testRemove_normal() throws ExecutionException, InterruptedException ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ), new byte[size1], @@ -203,6 +208,7 @@ public void testRemove_normal() throws ExecutionException, InterruptedException ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ), new byte[size2], @@ -250,6 +256,7 @@ public void testRebuild_normal() throws ExecutionException, InterruptedException ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ), "hello".getBytes(), @@ -294,6 +301,7 @@ public void testRebuild_afterSettingUpdate() throws ExecutionException, Interrup ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ), new byte[modelSize], @@ -361,6 +369,7 @@ public void testContains() throws ExecutionException, InterruptedException { ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ), new byte[modelSize1], @@ -401,6 +410,7 @@ public void testRemoveAll() throws ExecutionException, InterruptedException { ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ), new byte[modelSize1], @@ -417,6 +427,7 @@ public void testRemoveAll() throws ExecutionException, InterruptedException { ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ), new byte[modelSize2], @@ -461,6 +472,7 @@ public void testModelCacheEvictionDueToSize() throws ExecutionException, Interru ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ), new byte[BYTES_PER_KILOBYTES * 2], diff --git a/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java index cb6628c16..1297dc184 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java @@ -136,6 +136,7 @@ public void testModelIndexHealth() throws InterruptedException, ExecutionExcepti ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ), modelBlob, @@ -154,6 +155,7 @@ public void testModelIndexHealth() throws InterruptedException, ExecutionExcepti ModelState.FAILED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ), modelBlob, @@ -180,6 +182,7 @@ public void testPut_withId() throws InterruptedException, IOException { ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ), modelBlob, @@ -239,6 +242,7 @@ public void testPut_withoutModel() throws InterruptedException, IOException { ModelState.TRAINING, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ), modelBlob, @@ -299,6 +303,7 @@ public void testPut_invalid_badState() { ModelState.TRAINING, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ), modelBlob, @@ -334,6 +339,7 @@ public void testUpdate() throws IOException, InterruptedException { ModelState.TRAINING, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ), null, @@ -371,6 +377,7 @@ public void testUpdate() throws IOException, InterruptedException { ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ), modelBlob, @@ -420,6 +427,7 @@ public void testGet() throws IOException, InterruptedException, ExecutionExcepti ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ), modelBlob, @@ -437,6 +445,7 @@ public void testGet() throws IOException, InterruptedException, ExecutionExcepti ModelState.TRAINING, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ), null, @@ -472,6 +481,7 @@ public void testGetMetadata() throws IOException, InterruptedException { ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ); @@ -481,7 +491,6 @@ public void testGetMetadata() throws IOException, InterruptedException { final CountDownLatch inProgressLatch1 = new CountDownLatch(1); ActionListener docCreationListener = ActionListener.wrap(response -> { assertEquals(modelId, response.getId()); - ModelMetadata modelMetadata1 = modelDao.getMetadata(modelId); assertEquals(modelMetadata, modelMetadata1); @@ -548,6 +557,7 @@ public void testDelete() throws IOException, InterruptedException { ModelState.TRAINING, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ), modelBlob, @@ -580,6 +590,7 @@ public void testDelete() throws IOException, InterruptedException { ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ), modelBlob, @@ -646,6 +657,7 @@ public void testDeleteModelInTrainingWithStepListeners() throws IOException, Exe ModelState.TRAINING, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ), modelBlob, @@ -686,6 +698,7 @@ public void testDeleteWithStepListeners() throws IOException, InterruptedExcepti ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ), modelBlob, diff --git a/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java b/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java index a2e5c6bbe..219710308 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java @@ -38,6 +38,7 @@ public void testStreams() throws IOException { ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ); @@ -58,6 +59,7 @@ public void testGetKnnEngine() { ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ); @@ -73,6 +75,7 @@ public void testGetSpaceType() { ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ); @@ -88,6 +91,7 @@ public void testGetDimension() { ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ); @@ -103,6 +107,7 @@ public void testGetState() { modelState, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ); @@ -111,7 +116,7 @@ public void testGetState() { public void testGetTimestamp() { String timeValue = ZonedDateTime.now(ZoneOffset.UTC).toString(); - ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 12, ModelState.CREATED, timeValue, "", ""); + ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 12, ModelState.CREATED, timeValue, "", "", ""); assertEquals(timeValue, modelMetadata.getTimestamp()); } @@ -125,6 +130,7 @@ public void testDescription() { ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), description, + "", "" ); @@ -140,7 +146,8 @@ public void testGetError() { ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", - error + error, + "" ); assertEquals(error, modelMetadata.getError()); @@ -155,6 +162,7 @@ public void testSetState() { modelState, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ); @@ -174,7 +182,8 @@ public void testSetError() { ModelState.TRAINING, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", - error + error, + "" ); assertEquals(error, modelMetadata.getError()); @@ -192,6 +201,7 @@ public void testToString() { String timestamp = ZonedDateTime.now(ZoneOffset.UTC).toString(); String description = "test-description"; String error = "test-error"; + String nodeAssignment = ""; String expected = knnEngine.getName() + "," @@ -205,9 +215,20 @@ public void testToString() { + "," + description + "," - + error; + + error + + "," + + nodeAssignment; - ModelMetadata modelMetadata = new ModelMetadata(knnEngine, spaceType, dimension, modelState, timestamp, description, error); + ModelMetadata modelMetadata = new ModelMetadata( + knnEngine, + spaceType, + dimension, + modelState, + timestamp, + description, + error, + nodeAssignment + ); assertEquals(expected, modelMetadata.toString()); } @@ -217,14 +238,14 @@ public void testEquals() { String time1 = ZonedDateTime.now(ZoneOffset.UTC).toString(); String time2 = ZonedDateTime.of(2021, 9, 30, 12, 20, 45, 1, ZoneId.systemDefault()).toString(); - ModelMetadata modelMetadata1 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, time1, "", ""); - ModelMetadata modelMetadata2 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, time1, "", ""); + ModelMetadata modelMetadata1 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, time1, "", "", ""); + ModelMetadata modelMetadata2 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, time1, "", "", ""); - ModelMetadata modelMetadata3 = new ModelMetadata(KNNEngine.NMSLIB, SpaceType.L2, 128, ModelState.CREATED, time1, "", ""); - ModelMetadata modelMetadata4 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L1, 128, ModelState.CREATED, time1, "", ""); - ModelMetadata modelMetadata5 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 129, ModelState.CREATED, time1, "", ""); - ModelMetadata modelMetadata6 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.TRAINING, time1, "", ""); - ModelMetadata modelMetadata7 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, time2, "", ""); + ModelMetadata modelMetadata3 = new ModelMetadata(KNNEngine.NMSLIB, SpaceType.L2, 128, ModelState.CREATED, time1, "", "", ""); + ModelMetadata modelMetadata4 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L1, 128, ModelState.CREATED, time1, "", "", ""); + ModelMetadata modelMetadata5 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 129, ModelState.CREATED, time1, "", "", ""); + ModelMetadata modelMetadata6 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.TRAINING, time1, "", "", ""); + ModelMetadata modelMetadata7 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, time2, "", "", ""); ModelMetadata modelMetadata8 = new ModelMetadata( KNNEngine.FAISS, SpaceType.L2, @@ -232,9 +253,19 @@ public void testEquals() { ModelState.CREATED, time1, "diff descript", + "", + "" + ); + ModelMetadata modelMetadata9 = new ModelMetadata( + KNNEngine.FAISS, + SpaceType.L2, + 128, + ModelState.CREATED, + time1, + "", + "diff error", "" ); - ModelMetadata modelMetadata9 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, time1, "", "diff error"); assertEquals(modelMetadata1, modelMetadata1); assertEquals(modelMetadata1, modelMetadata2); @@ -254,14 +285,14 @@ public void testHashCode() { String time1 = ZonedDateTime.now(ZoneOffset.UTC).toString(); String time2 = ZonedDateTime.of(2021, 9, 30, 12, 20, 45, 1, ZoneId.systemDefault()).toString(); - ModelMetadata modelMetadata1 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, time1, "", ""); - ModelMetadata modelMetadata2 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, time1, "", ""); + ModelMetadata modelMetadata1 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, time1, "", "", ""); + ModelMetadata modelMetadata2 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, time1, "", "", ""); - ModelMetadata modelMetadata3 = new ModelMetadata(KNNEngine.NMSLIB, SpaceType.L2, 128, ModelState.CREATED, time1, "", ""); - ModelMetadata modelMetadata4 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L1, 128, ModelState.CREATED, time1, "", ""); - ModelMetadata modelMetadata5 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 129, ModelState.CREATED, time1, "", ""); - ModelMetadata modelMetadata6 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.TRAINING, time1, "", ""); - ModelMetadata modelMetadata7 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, time2, "", ""); + ModelMetadata modelMetadata3 = new ModelMetadata(KNNEngine.NMSLIB, SpaceType.L2, 128, ModelState.CREATED, time1, "", "", ""); + ModelMetadata modelMetadata4 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L1, 128, ModelState.CREATED, time1, "", "", ""); + ModelMetadata modelMetadata5 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 129, ModelState.CREATED, time1, "", "", ""); + ModelMetadata modelMetadata6 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.TRAINING, time1, "", "", ""); + ModelMetadata modelMetadata7 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, time2, "", "", ""); ModelMetadata modelMetadata8 = new ModelMetadata( KNNEngine.FAISS, SpaceType.L2, @@ -269,9 +300,19 @@ public void testHashCode() { ModelState.CREATED, time1, "diff descript", + "", + "" + ); + ModelMetadata modelMetadata9 = new ModelMetadata( + KNNEngine.FAISS, + SpaceType.L2, + 128, + ModelState.CREATED, + time1, + "", + "diff error", "" ); - ModelMetadata modelMetadata9 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, time1, "", "diff error"); assertEquals(modelMetadata1.hashCode(), modelMetadata1.hashCode()); assertEquals(modelMetadata1.hashCode(), modelMetadata2.hashCode()); @@ -294,8 +335,25 @@ public void testFromString() { String timestamp = ZonedDateTime.now(ZoneOffset.UTC).toString(); String description = "test-description"; String error = "test-error"; + String nodeAssignment = "test-node"; String stringRep1 = knnEngine.getName() + + "," + + spaceType.getValue() + + "," + + dimension + + "," + + modelState.getName() + + "," + + timestamp + + "," + + description + + "," + + error + + "," + + nodeAssignment; + + String stringRep2 = knnEngine.getName() + "," + spaceType.getValue() + "," @@ -309,10 +367,24 @@ public void testFromString() { + "," + error; - ModelMetadata expected = new ModelMetadata(knnEngine, spaceType, dimension, modelState, timestamp, description, error); + ModelMetadata expected1 = new ModelMetadata( + knnEngine, + spaceType, + dimension, + modelState, + timestamp, + description, + error, + nodeAssignment + ); + + ModelMetadata expected2 = new ModelMetadata(knnEngine, spaceType, dimension, modelState, timestamp, description, error, ""); + ModelMetadata fromString1 = ModelMetadata.fromString(stringRep1); + ModelMetadata fromString2 = ModelMetadata.fromString(stringRep2); - assertEquals(expected, fromString1); + assertEquals(expected1, fromString1); + assertEquals(expected2, fromString2); expectThrows(IllegalArgumentException.class, () -> ModelMetadata.fromString("invalid")); } @@ -325,8 +397,19 @@ public void testFromResponseMap() { String timestamp = ZonedDateTime.now(ZoneOffset.UTC).toString(); String description = "test-description"; String error = "test-error"; + String nodeAssignment = "test-node"; - ModelMetadata expected = new ModelMetadata(knnEngine, spaceType, dimension, modelState, timestamp, description, error); + ModelMetadata expected = new ModelMetadata( + knnEngine, + spaceType, + dimension, + modelState, + timestamp, + description, + error, + nodeAssignment + ); + ModelMetadata expected2 = new ModelMetadata(knnEngine, spaceType, dimension, modelState, timestamp, description, error, ""); Map metadataAsMap = new HashMap<>(); metadataAsMap.put(KNNConstants.KNN_ENGINE, knnEngine.getName()); metadataAsMap.put(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()); @@ -335,8 +418,13 @@ public void testFromResponseMap() { metadataAsMap.put(KNNConstants.MODEL_TIMESTAMP, timestamp); metadataAsMap.put(KNNConstants.MODEL_DESCRIPTION, description); metadataAsMap.put(KNNConstants.MODEL_ERROR, error); + metadataAsMap.put(KNNConstants.MODEL_NODE_ASSIGNMENT, nodeAssignment); ModelMetadata fromMap = ModelMetadata.getMetadataFromSourceMap(metadataAsMap); assertEquals(expected, fromMap); + + metadataAsMap.put(KNNConstants.MODEL_NODE_ASSIGNMENT, null); + assertEquals(expected2, fromMap); + } } diff --git a/src/test/java/org/opensearch/knn/indices/ModelStateTests.java b/src/test/java/org/opensearch/knn/indices/ModelStateTests.java index 1cc8a0e82..1527de539 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelStateTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelStateTests.java @@ -31,5 +31,8 @@ public void testStreams() throws IOException { public void testGetModelState() { assertEquals(ModelState.CREATED, ModelState.getModelState(ModelState.CREATED.getName())); + assertEquals(ModelState.TRAINING, ModelState.getModelState(ModelState.TRAINING.getName())); + assertEquals(ModelState.FAILED, ModelState.getModelState(ModelState.FAILED.getName())); + expectThrows(IllegalArgumentException.class, () -> ModelState.getModelState("throw-exception")); } } diff --git a/src/test/java/org/opensearch/knn/indices/ModelTests.java b/src/test/java/org/opensearch/knn/indices/ModelTests.java index fd3173431..c015e8d62 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelTests.java @@ -38,6 +38,7 @@ public void testInvalidConstructor() { ModelState.FAILED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ), null, @@ -57,6 +58,7 @@ public void testInvalidDimension() { ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ), new byte[16], @@ -73,6 +75,7 @@ public void testInvalidDimension() { ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ), new byte[16], @@ -89,6 +92,7 @@ public void testInvalidDimension() { ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ), new byte[16], @@ -106,6 +110,7 @@ public void testGetModelMetadata() { ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ); Model model = new Model(modelMetadata, new byte[16], "test-model"); @@ -122,6 +127,7 @@ public void testGetModelBlob() { ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ), modelBlob, @@ -140,6 +146,7 @@ public void testGetLength() { ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ), new byte[size], @@ -155,6 +162,7 @@ public void testGetLength() { ModelState.TRAINING, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ), null, @@ -166,7 +174,16 @@ public void testGetLength() { public void testSetModelBlob() { byte[] blob1 = "Hello blob 1".getBytes(); Model model = new Model( - new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""), + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.L1, + 2, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "", + "" + ), blob1, "test-model" ); @@ -182,17 +199,17 @@ public void testEquals() { String time = ZonedDateTime.now(ZoneOffset.UTC).toString(); Model model1 = new Model( - new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", ""), + new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", "", ""), new byte[16], "test-model-1" ); Model model2 = new Model( - new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", ""), + new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", "", ""), new byte[16], "test-model-1" ); Model model3 = new Model( - new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 2, ModelState.CREATED, time, "", ""), + new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 2, ModelState.CREATED, time, "", "", ""), new byte[16], "test-model-2" ); @@ -207,17 +224,17 @@ public void testHashCode() { String time = ZonedDateTime.now(ZoneOffset.UTC).toString(); Model model1 = new Model( - new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", ""), + new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", "", ""), new byte[16], "test-model-1" ); Model model2 = new Model( - new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", ""), + new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", "", ""), new byte[16], "test-model-1" ); Model model3 = new Model( - new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", ""), + new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", "", ""), new byte[16], "test-model-2" ); @@ -236,8 +253,18 @@ public void testModelFromSourceMap() { String timestamp = ZonedDateTime.now(ZoneOffset.UTC).toString(); String description = "test-description"; String error = "test-error"; + String nodeAssignment = "test-node"; - ModelMetadata metadata = new ModelMetadata(knnEngine, spaceType, dimension, modelState, timestamp, description, error); + ModelMetadata metadata = new ModelMetadata( + knnEngine, + spaceType, + dimension, + modelState, + timestamp, + description, + error, + nodeAssignment + ); Map modelAsMap = new HashMap<>(); modelAsMap.put(KNNConstants.MODEL_ID, modelID); modelAsMap.put(KNNConstants.KNN_ENGINE, knnEngine.getName()); @@ -247,6 +274,7 @@ public void testModelFromSourceMap() { modelAsMap.put(KNNConstants.MODEL_TIMESTAMP, timestamp); modelAsMap.put(KNNConstants.MODEL_DESCRIPTION, description); modelAsMap.put(KNNConstants.MODEL_ERROR, error); + modelAsMap.put(KNNConstants.MODEL_NODE_ASSIGNMENT, nodeAssignment); modelAsMap.put(KNNConstants.MODEL_BLOB_PARAMETER, "aGVsbG8="); byte[] blob1 = "hello".getBytes(); diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestSearchModelHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestSearchModelHandlerIT.java index 7dbe8d950..ae3e71528 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestSearchModelHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestSearchModelHandlerIT.java @@ -48,7 +48,7 @@ public class RestSearchModelHandlerIT extends KNNRestTestCase { private ModelMetadata getModelMetadata() { - return new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, 4, ModelState.CREATED, "2021-03-27", "test model", ""); + return new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, 4, ModelState.CREATED, "2021-03-27", "test model", "", ""); } public void testNotSupportedParams() throws IOException { diff --git a/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java b/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java index 56788840f..05adb1cf4 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java @@ -26,7 +26,7 @@ public class GetModelResponseTests extends KNNTestCase { private ModelMetadata getModelMetadata(ModelState state) { - return new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, 4, state, "2021-03-27 10:15:30 AM +05:30", "test model", ""); + return new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, 4, state, "2021-03-27 10:15:30 AM +05:30", "test model", "", ""); } public void testStreams() throws IOException { @@ -46,7 +46,7 @@ public void testXContent() throws IOException { Model model = new Model(getModelMetadata(ModelState.CREATED), testModelBlob, modelId); GetModelResponse getModelResponse = new GetModelResponse(model); String expectedResponseString = - "{\"model_id\":\"test-model\",\"model_blob\":\"aGVsbG8=\",\"state\":\"created\",\"timestamp\":\"2021-03-27 10:15:30 AM +05:30\",\"description\":\"test model\",\"error\":\"\",\"space_type\":\"l2\",\"dimension\":4,\"engine\":\"nmslib\"}"; + "{\"model_id\":\"test-model\",\"model_blob\":\"aGVsbG8=\",\"state\":\"created\",\"timestamp\":\"2021-03-27 10:15:30 AM +05:30\",\"description\":\"test model\",\"error\":\"\",\"space_type\":\"l2\",\"dimension\":4,\"engine\":\"nmslib\",\"training_node_assignment\":\"\"}"; XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); getModelResponse.toXContent(xContentBuilder, null); assertEquals(expectedResponseString, xContentBuilder.toString()); @@ -57,7 +57,7 @@ public void testXContentWithNoModelBlob() throws IOException { Model model = new Model(getModelMetadata(ModelState.FAILED), null, modelId); GetModelResponse getModelResponse = new GetModelResponse(model); String expectedResponseString = - "{\"model_id\":\"test-model\",\"model_blob\":\"\",\"state\":\"failed\",\"timestamp\":\"2021-03-27 10:15:30 AM +05:30\",\"description\":\"test model\",\"error\":\"\",\"space_type\":\"l2\",\"dimension\":4,\"engine\":\"nmslib\"}"; + "{\"model_id\":\"test-model\",\"model_blob\":\"\",\"state\":\"failed\",\"timestamp\":\"2021-03-27 10:15:30 AM +05:30\",\"description\":\"test model\",\"error\":\"\",\"space_type\":\"l2\",\"dimension\":4,\"engine\":\"nmslib\",\"training_node_assignment\":\"\"}"; XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); getModelResponse.toXContent(xContentBuilder, null); assertEquals(expectedResponseString, xContentBuilder.toString()); diff --git a/src/test/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheTransportActionTests.java index ae89d83e1..5d30f54bb 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheTransportActionTests.java @@ -68,7 +68,7 @@ public void testNodeOperation_modelInCache() throws ExecutionException, Interrup ModelDao modelDao = mock(ModelDao.class); String modelId = "test-model-id"; Model model = new Model( - new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 16, ModelState.CREATED, "timestamp", "description", ""), + new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 16, ModelState.CREATED, "timestamp", "description", "", ""), new byte[128], modelId ); diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouteDecisionInfoTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouteDecisionInfoTransportActionTests.java index 15d84638d..9f64afebb 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouteDecisionInfoTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouteDecisionInfoTransportActionTests.java @@ -19,15 +19,14 @@ import org.opensearch.knn.KNNSingleNodeTestCase; import org.opensearch.knn.indices.Model; import org.opensearch.knn.indices.ModelDao; +import org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.knn.indices.ModelState; import org.opensearch.knn.training.TrainingJob; import org.opensearch.knn.training.TrainingJobRunner; import org.opensearch.threadpool.ThreadPool; import java.io.IOException; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.TimeUnit; +import java.util.concurrent.*; import static org.mockito.Mockito.any; import static org.mockito.Mockito.doAnswer; @@ -51,7 +50,7 @@ public void teardown() { } @SuppressWarnings("unchecked") - public void testNodeOperation() throws IOException, InterruptedException { + public void testNodeOperation() throws IOException, InterruptedException, ExecutionException { // Ensure initial value of train job count is 0 TrainingJobRouteDecisionInfoTransportAction action = node().injector() .getInstance(TrainingJobRouteDecisionInfoTransportAction.class); @@ -64,12 +63,16 @@ public void testNodeOperation() throws IOException, InterruptedException { // Setup mocked training job String modelId = "model-id"; Model model = mock(Model.class); + ModelMetadata modelMetadata = mock(ModelMetadata.class); + when(modelMetadata.getState()).thenReturn(ModelState.TRAINING); + when(model.getModelMetadata()).thenReturn(modelMetadata); TrainingJob trainingJob = mock(TrainingJob.class); when(trainingJob.getModelId()).thenReturn(modelId); when(trainingJob.getModel()).thenReturn(model); doAnswer(invocationOnMock -> null).when(trainingJob).run(); ModelDao modelDao = mock(ModelDao.class); + when(modelDao.get(modelId)).thenReturn(model); // Here we check to make sure there is a running job doAnswer(invocationOnMock -> { diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java index 09ca6c73b..7465ccc58 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java @@ -167,6 +167,7 @@ public void testValidation_invalid_modelIdAlreadyExists() { ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ); when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); diff --git a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequestTests.java b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequestTests.java index f38273f15..a41ca900a 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequestTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequestTests.java @@ -39,6 +39,7 @@ public void testStreams() throws IOException { ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ); UpdateModelMetadataRequest updateModelMetadataRequest = new UpdateModelMetadataRequest(modelId, isRemoveRequest, modelMetadata); @@ -62,6 +63,7 @@ public void testValidate() { ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ); @@ -100,6 +102,7 @@ public void testGetModelMetadata() { ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ); UpdateModelMetadataRequest updateModelMetadataRequest = new UpdateModelMetadataRequest("test", true, modelMetadata); diff --git a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java index eb3ecf168..11961f6f5 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java @@ -65,6 +65,7 @@ public void testClusterManagerOperation() throws InterruptedException { ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", + "", "" ); diff --git a/src/test/java/org/opensearch/knn/training/TrainingJobClusterStateListenerTests.java b/src/test/java/org/opensearch/knn/training/TrainingJobClusterStateListenerTests.java new file mode 100644 index 000000000..7994e73d2 --- /dev/null +++ b/src/test/java/org/opensearch/knn/training/TrainingJobClusterStateListenerTests.java @@ -0,0 +1,179 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.training; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.cluster.ClusterChangedEvent; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.action.ActionListener; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.indices.Model; +import org.opensearch.knn.indices.ModelDao; +import org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.knn.indices.ModelState; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.threadpool.ThreadPool; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; +import static org.opensearch.knn.common.KNNConstants.TRAIN_THREAD_POOL; + +public class TrainingJobClusterStateListenerTests extends KNNTestCase { + public void testClusterChanged() throws InterruptedException { + ExecutorService executorService = Executors.newSingleThreadExecutor(); + + TrainingJobClusterStateListener trainingJobClusterStateListener = TrainingJobClusterStateListener.getInstance(); + + ThreadPool threadPool = mock(ThreadPool.class); + when(threadPool.executor(TRAIN_THREAD_POOL)).thenReturn(executorService); + doAnswer(invocationOnMock -> { return null; }).when(threadPool) + .schedule(any(Runnable.class), any(TimeValue.class), any(String.class)); + + ModelDao modelDao = mock(ModelDao.class); + ClusterChangedEvent clusterChangedEvent = mock(ClusterChangedEvent.class); + when(clusterChangedEvent.localNodeClusterManager()).thenReturn(true); + when(clusterChangedEvent.isNewCluster()).thenReturn(true); + + TrainingJobClusterStateListener.initialize(threadPool, modelDao, clusterService); + + trainingJobClusterStateListener.clusterChanged(clusterChangedEvent); + + verify(threadPool, times(1)).schedule(any(Runnable.class), any(TimeValue.class), any(String.class)); + + when(clusterChangedEvent.isNewCluster()).thenReturn(false); + when(clusterChangedEvent.nodesRemoved()).thenReturn(true); + DiscoveryNodes.Delta delta = mock(DiscoveryNodes.Delta.class); + List nodes = new ArrayList<>(); + when(clusterChangedEvent.nodesDelta()).thenReturn(delta); + when(delta.removedNodes()).thenReturn(nodes); + + trainingJobClusterStateListener.clusterChanged(clusterChangedEvent); + + verify(threadPool, times(2)).schedule(any(Runnable.class), any(TimeValue.class), any(String.class)); + verify(clusterChangedEvent, times(1)).nodesDelta(); + + when(clusterChangedEvent.nodesRemoved()).thenReturn(false); + trainingJobClusterStateListener.clusterChanged(clusterChangedEvent); + verify(threadPool, times(2)).schedule(any(Runnable.class), any(TimeValue.class), any(String.class)); + + when(clusterChangedEvent.localNodeClusterManager()).thenReturn(false); + trainingJobClusterStateListener.clusterChanged(clusterChangedEvent); + verify(threadPool, times(2)).schedule(any(Runnable.class), any(TimeValue.class), any(String.class)); + + executorService.shutdown(); + executorService.awaitTermination(10, TimeUnit.SECONDS); + } + + public void testUpdateModelsNewCluster() throws IOException, InterruptedException, ExecutionException { + ExecutorService executorService = Executors.newSingleThreadExecutor(); + + TrainingJobClusterStateListener trainingJobClusterStateListener = TrainingJobClusterStateListener.getInstance(); + + ThreadPool threadPool = mock(ThreadPool.class); + when(threadPool.executor(TRAIN_THREAD_POOL)).thenReturn(executorService); + + String modelId = "test-model-id"; + Model model = mock(Model.class); + ModelMetadata modelMetadata = mock(ModelMetadata.class); + when(modelMetadata.getState()).thenReturn(ModelState.TRAINING); + when(model.getModelMetadata()).thenReturn(modelMetadata); + ModelDao modelDao = mock(ModelDao.class); + when(modelDao.isCreated()).thenReturn(true); + when(modelDao.get(modelId)).thenReturn(model); + doAnswer(invocationOnMock -> { + SearchResponse searchResponse = mock(SearchResponse.class); + SearchHits searchHits = mock(SearchHits.class); + when(searchResponse.getHits()).thenReturn(searchHits); + SearchHit searchHit = mock(SearchHit.class); + when(searchHit.getId()).thenReturn(modelId); + SearchHit[] searchHitArray = new SearchHit[1]; + searchHitArray[0] = searchHit; + when(searchHits.getHits()).thenReturn(searchHitArray); + ((ActionListener) invocationOnMock.getArguments()[1]).onResponse(searchResponse); + return null; + }).when(modelDao).search(any(SearchRequest.class), any(ActionListener.class)); + doAnswer(invocationOnMock -> { return null; }).when(modelDao).update(any(Model.class), any(ActionListener.class)); + + TrainingJobClusterStateListener.initialize(threadPool, modelDao, clusterService); + + trainingJobClusterStateListener.updateModelsNewCluster(); + + executorService.shutdown(); + executorService.awaitTermination(10, TimeUnit.SECONDS); + + verify(modelMetadata, times(1)).setState(ModelState.FAILED); + verify(modelMetadata, times(1)).setError("Training failed to complete as cluster crashed"); + verify(modelDao, times(1)).update(any(Model.class), any(ActionListener.class)); + } + + public void testUpdateModelsNodesRemoved() throws IOException, InterruptedException, ExecutionException { + ExecutorService executorService = Executors.newSingleThreadExecutor(); + + TrainingJobClusterStateListener trainingJobClusterStateListener = TrainingJobClusterStateListener.getInstance(); + + ThreadPool threadPool = mock(ThreadPool.class); + when(threadPool.executor(TRAIN_THREAD_POOL)).thenReturn(executorService); + + String modelId = "test-model-id"; + Model model = mock(Model.class); + ModelMetadata modelMetadata = mock(ModelMetadata.class); + when(modelMetadata.getState()).thenReturn(ModelState.TRAINING); + when(modelMetadata.getNodeAssignment()).thenReturn("test-node-model-match"); + when(model.getModelMetadata()).thenReturn(modelMetadata); + ModelDao modelDao = mock(ModelDao.class); + when(modelDao.isCreated()).thenReturn(true); + when(modelDao.get(modelId)).thenReturn(model); + DiscoveryNode node1 = mock(DiscoveryNode.class); + when(node1.getEphemeralId()).thenReturn("test-node-model-match"); + DiscoveryNode node2 = mock(DiscoveryNode.class); + when(node2.getEphemeralId()).thenReturn("test-node-not-model-match"); + List nodes = new ArrayList(); + nodes.add(node1); + nodes.add(node2); + doAnswer(invocationOnMock -> { + SearchResponse searchResponse = mock(SearchResponse.class); + SearchHits searchHits = mock(SearchHits.class); + when(searchResponse.getHits()).thenReturn(searchHits); + SearchHit searchHit = mock(SearchHit.class); + when(searchHit.getId()).thenReturn(modelId); + SearchHit[] searchHitArray = new SearchHit[1]; + searchHitArray[0] = searchHit; + when(searchHits.getHits()).thenReturn(searchHitArray); + ((ActionListener) invocationOnMock.getArguments()[1]).onResponse(searchResponse); + return null; + }).when(modelDao).search(any(SearchRequest.class), any(ActionListener.class)); + doAnswer(invocationOnMock -> { return null; }).when(modelDao).update(any(Model.class), any(ActionListener.class)); + + TrainingJobClusterStateListener.initialize(threadPool, modelDao, clusterService); + + trainingJobClusterStateListener.updateModelsNodesRemoved(nodes); + + executorService.shutdown(); + executorService.awaitTermination(10, TimeUnit.SECONDS); + + verify(modelMetadata, times(1)).setState(ModelState.FAILED); + verify(modelMetadata, times(1)).setError("Training failed to complete as node dropped"); + verify(modelDao, times(1)).update(any(Model.class), any(ActionListener.class)); + } +} diff --git a/src/test/java/org/opensearch/knn/training/TrainingJobRunnerTests.java b/src/test/java/org/opensearch/knn/training/TrainingJobRunnerTests.java index 9acdc7b36..4876b1562 100644 --- a/src/test/java/org/opensearch/knn/training/TrainingJobRunnerTests.java +++ b/src/test/java/org/opensearch/knn/training/TrainingJobRunnerTests.java @@ -17,13 +17,12 @@ import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.indices.Model; import org.opensearch.knn.indices.ModelDao; +import org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.knn.indices.ModelState; import org.opensearch.threadpool.ThreadPool; import java.io.IOException; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.RejectedExecutionException; -import java.util.concurrent.TimeUnit; +import java.util.concurrent.*; import static org.mockito.Mockito.any; import static org.mockito.Mockito.doAnswer; @@ -37,7 +36,7 @@ public class TrainingJobRunnerTests extends KNNTestCase { @SuppressWarnings("unchecked") - public void testExecute_success() throws IOException, InterruptedException { + public void testExecute_success() throws IOException, InterruptedException, ExecutionException { // Test makes sure the correct execution logic follows on successful run ExecutorService executorService = Executors.newSingleThreadExecutor(); @@ -48,6 +47,9 @@ public void testExecute_success() throws IOException, InterruptedException { String modelId = "test-model-id"; Model model = mock(Model.class); + ModelMetadata modelMetadata = mock(ModelMetadata.class); + when(modelMetadata.getState()).thenReturn(ModelState.TRAINING); + when(model.getModelMetadata()).thenReturn(modelMetadata); TrainingJob trainingJob = mock(TrainingJob.class); when(trainingJob.getModelId()).thenReturn(modelId); when(trainingJob.getModel()).thenReturn(model); @@ -63,6 +65,7 @@ public void testExecute_success() throws IOException, InterruptedException { // After put finishes, it should call the onResponse function that will call responseListener and then kickoff // training. ModelDao modelDao = mock(ModelDao.class); + when(modelDao.get(modelId)).thenReturn(model); doAnswer(invocationOnMock -> { assertEquals(1, trainingJobRunner.getJobCount()); // Make sure job count is correct IndexResponse indexResponse = new IndexResponse(new ShardId(MODEL_INDEX_NAME, "uuid", 0), modelId, 0, 0, 0, true); @@ -88,7 +91,7 @@ public void testExecute_success() throws IOException, InterruptedException { } @SuppressWarnings("unchecked") - public void testExecute_failure_rejected() throws IOException, InterruptedException { + public void testExecute_failure_rejected() throws IOException, InterruptedException, ExecutionException { // This test makes sure we reject another request when one is ongoing. To do this, we call // trainingJobRunner.execute(trainingJob, responseListener) in the mocked modeldao.update. At this point, // the call should produce a failure because a training job is already ongoing. @@ -100,6 +103,9 @@ public void testExecute_failure_rejected() throws IOException, InterruptedExcept String modelId = "test-model-id"; Model model = mock(Model.class); + ModelMetadata modelMetadata = mock(ModelMetadata.class); + when(modelMetadata.getState()).thenReturn(ModelState.TRAINING); + when(model.getModelMetadata()).thenReturn(modelMetadata); TrainingJob trainingJob = mock(TrainingJob.class); when(trainingJob.getModelId()).thenReturn(modelId); when(trainingJob.getModel()).thenReturn(model); @@ -115,6 +121,7 @@ public void testExecute_failure_rejected() throws IOException, InterruptedExcept // After put finishes, it should call the onResponse function that will call responseListener and then kickoff // training. ModelDao modelDao = mock(ModelDao.class); + when(modelDao.get(modelId)).thenReturn(model); doAnswer(invocationOnMock -> { IndexResponse indexResponse = new IndexResponse(new ShardId(MODEL_INDEX_NAME, "uuid", 0), modelId, 0, 0, 0, true); ((ActionListener) invocationOnMock.getArguments()[1]).onResponse(indexResponse); diff --git a/src/test/java/org/opensearch/knn/training/TrainingJobTests.java b/src/test/java/org/opensearch/knn/training/TrainingJobTests.java index b15bf8207..daf9645fa 100644 --- a/src/test/java/org/opensearch/knn/training/TrainingJobTests.java +++ b/src/test/java/org/opensearch/knn/training/TrainingJobTests.java @@ -52,7 +52,8 @@ public void testGetModelId() { mock(NativeMemoryEntryContext.TrainingDataEntryContext.class), mock(NativeMemoryEntryContext.AnonymousEntryContext.class), 10, - "" + "", + "test-node" ); assertEquals(modelId, trainingJob.getModelId()); @@ -62,8 +63,9 @@ public void testGetModel() { SpaceType spaceType = SpaceType.INNER_PRODUCT; KNNEngine knnEngine = KNNEngine.DEFAULT; int dimension = 10; - String desciption = "test description"; + String description = "test description"; String error = ""; + String nodeAssignment = "test-node"; KNNMethodContext knnMethodContext = mock(KNNMethodContext.class); when(knnMethodContext.getKnnEngine()).thenReturn(knnEngine); @@ -77,7 +79,8 @@ public void testGetModel() { mock(NativeMemoryEntryContext.TrainingDataEntryContext.class), mock(NativeMemoryEntryContext.AnonymousEntryContext.class), dimension, - desciption + description, + nodeAssignment ); Model model = new Model( @@ -87,8 +90,9 @@ public void testGetModel() { dimension, ModelState.TRAINING, trainingJob.getModel().getModelMetadata().getTimestamp(), - desciption, - error + description, + error, + nodeAssignment ), null, modelID @@ -159,7 +163,9 @@ public void testRun_success() throws IOException, ExecutionException { trainingDataEntryContext, modelContext, dimension, - "" + "", + "test-node" + ); trainingJob.run(); @@ -235,7 +241,9 @@ public void testRun_failure_onGetTrainingDataAllocation() throws ExecutionExcept trainingDataEntryContext, modelContext, dimension, - "" + "", + + "test-node" ); trainingJob.run(); @@ -301,7 +309,9 @@ public void testRun_failure_onGetModelAnonymousAllocation() throws ExecutionExce trainingDataEntryContext, modelContext, dimension, - "" + "", + + "test-node" ); trainingJob.run(); @@ -366,7 +376,8 @@ public void testRun_failure_closedTrainingDataAllocation() throws ExecutionExcep trainingDataEntryContext, mock(NativeMemoryEntryContext.AnonymousEntryContext.class), dimension, - "" + "", + "test-node" ); trainingJob.run(); @@ -438,7 +449,8 @@ public void testRun_failure_notEnoughTrainingData() throws ExecutionException { trainingDataEntryContext, modelContext, dimension, - "" + "", + "test-node" ); trainingJob.run(); diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index e88620a82..93f6885ce 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -77,6 +77,7 @@ import static org.opensearch.knn.common.KNNConstants.MODEL_BLOB_PARAMETER; import static org.opensearch.knn.common.KNNConstants.MODEL_DESCRIPTION; import static org.opensearch.knn.common.KNNConstants.MODEL_ERROR; +import static org.opensearch.knn.common.KNNConstants.MODEL_NODE_ASSIGNMENT; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_MAPPING_PATH; import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME; @@ -763,6 +764,7 @@ protected void addModelToSystemIndex(String modelId, ModelMetadata modelMetadata .field(MODEL_TIMESTAMP, modelMetadata.getTimestamp()) .field(MODEL_DESCRIPTION, modelMetadata.getDescription()) .field(MODEL_ERROR, modelMetadata.getError()) + .field(MODEL_NODE_ASSIGNMENT, modelMetadata.getNodeAssignment()) .endObject(); request.setJsonEntity(builder.toString());