diff --git a/src/main/java/org/opensearch/knn/index/IndexUtil.java b/src/main/java/org/opensearch/knn/index/IndexUtil.java index d5f6c39d0..a80464cdb 100644 --- a/src/main/java/org/opensearch/knn/index/IndexUtil.java +++ b/src/main/java/org/opensearch/knn/index/IndexUtil.java @@ -11,7 +11,15 @@ package org.opensearch.knn.index; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.MappingMetadata; +import org.opensearch.common.ValidationException; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.indices.ModelDao; +import org.opensearch.knn.indices.ModelMetadata; + import java.io.File; +import java.util.Map; import static org.opensearch.knn.common.KNNConstants.BYTES_PER_KILOBYTES; @@ -34,4 +42,118 @@ public static int getFileSizeInKB(String filePath) { return Math.toIntExact((file.length() / BYTES_PER_KILOBYTES) + 1L); // Add one so that integer division rounds up } + + /** + * Validate that a field is a k-NN vector field and has the expected dimension + * + * @param indexMetadata metadata for index to validate + * @param field field name to validate + * @param expectedDimension expected dimension of the field. If this value is negative, dimension will not be + * checked + * @param modelDao used to look up dimension if field uses a model for initialization. 
Can be null if + * expectedDimension is negative + * @return ValidationException exception produced by field validation + */ + @SuppressWarnings("unchecked") + public static ValidationException validateKnnField(IndexMetadata indexMetadata, String field, int expectedDimension, + ModelDao modelDao) { + // Index metadata should not be null + if (indexMetadata == null) { + throw new IllegalArgumentException("IndexMetadata should not be null"); + } + + ValidationException exception = new ValidationException(); + + // Check the mapping + MappingMetadata mappingMetadata = indexMetadata.mapping(); + if (mappingMetadata == null) { + exception.addValidationError("Invalid index. Index does not contain a mapping"); + return exception; + } + + // The mapping output *should* look like this: + // "{properties={field={type=knn_vector, dimension=8}}}" + Map properties = (Map)mappingMetadata.getSourceAsMap().get("properties"); + + if (properties == null) { + exception.addValidationError("Properties in map does not exists. This is unexpected"); + return exception; + } + + Object fieldMapping = properties.get(field); + + // Check field existence + if (fieldMapping == null) { + exception.addValidationError(String.format("Field \"%s\" does not exist.", field)); + return exception; + } + + // Check if field is a map. 
If not, that is a problem + if (!(fieldMapping instanceof Map)) { + exception.addValidationError(String.format("Field info for \"%s\" is not a map.", field)); + return exception; + } + + Map fieldMap = (Map) fieldMapping; + + // Check fields type is knn_vector + Object type = fieldMap.get("type"); + + if (!(type instanceof String) || !KNNVectorFieldMapper.CONTENT_TYPE.equals(type)) { + exception.addValidationError(String.format("Field \"%s\" is not of type %s.", field, + KNNVectorFieldMapper.CONTENT_TYPE)); + return exception; + } + + // Return if dimension does not need to be checked + if (expectedDimension < 0) { + return null; + } + + // Check that the dimension of the method passed in matches that of the model + Object dimension = fieldMap.get(KNNConstants.DIMENSION); + + // If dimension is null, the training index/field could use a model. In this case, we need to get the model id + // for the index and then fetch its dimension from the models metadata + if (dimension == null) { + + String modelId = (String) fieldMap.get(KNNConstants.MODEL_ID); + + if (modelId == null) { + exception.addValidationError(String.format("Field \"%s\" does not have a dimension set.", field)); + return exception; + } + + if (modelDao == null) { + throw new IllegalArgumentException(String.format("Field \"%s\" uses model. 
modelDao cannot be null.", + field)); + } + + ModelMetadata modelMetadata = modelDao.getMetadata(modelId); + if (modelMetadata == null) { + exception.addValidationError(String.format("Model \"%s\" for field \"%s\" does not exist.", modelId, + field)); + return exception; + } + + dimension = modelMetadata.getDimension(); + if ((Integer) dimension != expectedDimension) { + exception.addValidationError(String.format("Field \"%s\" has dimension %d, which is different from " + + "dimension specified in the training request: %d", field, dimension, + expectedDimension)); + return exception; + } + + return null; + } + + // If the dimension was found in training fields mapping, check that it equals the models proposed dimension. + if ((Integer) dimension != expectedDimension) { + exception.addValidationError(String.format("Field \"%s\" has dimension %d, which is different from " + + "dimension specified in the training request: %d", field, dimension, expectedDimension)); + return exception; + } + + return null; + } } diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryEntryContext.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryEntryContext.java index 763e6f782..68d285c6d 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryEntryContext.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryEntryContext.java @@ -11,7 +11,7 @@ package org.opensearch.knn.index.memory; -import org.opensearch.indices.IndicesService; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.knn.index.IndexUtil; import java.io.IOException; @@ -129,7 +129,7 @@ public static class TrainingDataEntryContext extends NativeMemoryEntryContext properties = (Map)indexMetadata.mapping().getSourceAsMap() - .get("properties"); - - if (properties == null) { - exception = exception == null ? new ActionRequestValidationException() : exception; - exception.addValidationError("Properties in map does not exists. 
This is unexpected"); - return exception; - } - - Object trainingFieldMapping = properties.get(trainingField); - - // Check field existence - if (trainingFieldMapping == null) { - exception = exception == null ? new ActionRequestValidationException() : exception; - exception.addValidationError(String.format("Field \"%s\" does not exist.", this.trainingField)); - return exception; - } - - // Check if field is a map. If not, that is a problem - if (!(trainingFieldMapping instanceof Map)) { - exception = exception == null ? new ActionRequestValidationException() : exception; - exception.addValidationError(String.format("Field info for \"%s\" is not a map.", this.trainingField)); - return exception; - } - - Map trainingFieldMap = (Map) trainingFieldMapping; - - // Check fields type is knn_vector - Object type = trainingFieldMap.get("type"); - - if (!(type instanceof String) || !KNNVectorFieldMapper.CONTENT_TYPE.equals(type)) { - exception = exception == null ? new ActionRequestValidationException() : exception; - exception.addValidationError(String.format("Field \"%s\" is not of type %s.", this.trainingField, - KNNVectorFieldMapper.CONTENT_TYPE)); - return exception; - } - - // Check that the dimension of the method passed in matches that of the model - Object dimension = trainingFieldMap.get(KNNConstants.DIMENSION); - - // If dimension is null, the training index/field could use a model. In this case, we need to get the model id - // for the index and then fetch its dimension from the models metadata - if (dimension == null) { - String modelId = (String) trainingFieldMap.get(KNNConstants.MODEL_ID); - - if (modelId == null) { - exception = exception == null ? 
new ActionRequestValidationException() : exception; - exception.addValidationError(String.format("Field \"%s\" does not have a dimension set.", - this.trainingField)); - return exception; - } - - ModelMetadata modelMetadata = modelDao.getMetadata(modelId); - if (modelMetadata == null) { - exception = exception == null ? new ActionRequestValidationException() : exception; - exception.addValidationError(String.format("Model \"%s\" for field \"%s\" does not exist.", modelId, - this.trainingField)); - return exception; - } - - dimension = modelMetadata.getDimension(); - if ((Integer) dimension != this.dimension) { - exception = exception == null ? new ActionRequestValidationException() : exception; - exception.addValidationError(String.format("Field \"%s\" has dimension %d, which is different from " + - "dimension specified in the training request: %d", this.trainingField, dimension, - this.dimension)); - return exception; - } - - return exception; - } - - // If the dimension was found in training fields mapping, check that it equals the models proposed dimension. - if ((Integer) dimension != this.dimension) { - exception = exception == null ? 
new ActionRequestValidationException() : exception; - exception.addValidationError(String.format("Field \"%s\" has dimension %d, which is different from " + - "dimension specified in the training request: %d", this.trainingField, dimension, this.dimension)); - return exception; - } - - return exception; - } } diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java index 2645ddba3..2e6b7d701 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java @@ -14,8 +14,8 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; -import org.opensearch.indices.IndicesService; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.memory.NativeMemoryEntryContext; import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy; @@ -32,14 +32,14 @@ */ public class TrainingModelTransportAction extends HandledTransportAction { - private final IndicesService indicesService; + private final ClusterService clusterService; @Inject public TrainingModelTransportAction(TransportService transportService, ActionFilters actionFilters, - IndicesService indicesService) { + ClusterService clusterService) { super(TrainingModelAction.NAME, transportService, actionFilters, TrainingModelRequest::new); - this.indicesService = indicesService; + this.clusterService = clusterService; } @Override @@ -52,7 +52,7 @@ protected void doExecute(Task task, TrainingModelRequest request, request.getTrainingIndex(), request.getTrainingField(), NativeMemoryLoadStrategy.TrainingLoadStrategy.getInstance(), - indicesService, 
+ clusterService, request.getMaximumVectorCount(), request.getSearchSize() ); diff --git a/src/main/java/org/opensearch/knn/training/TrainingJob.java b/src/main/java/org/opensearch/knn/training/TrainingJob.java index 0f395312c..288823fb7 100644 --- a/src/main/java/org/opensearch/knn/training/TrainingJob.java +++ b/src/main/java/org/opensearch/knn/training/TrainingJob.java @@ -25,6 +25,7 @@ import org.opensearch.knn.indices.Model; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelState; +import org.opensearch.knn.plugin.stats.KNNCounter; import java.time.ZoneOffset; import java.time.ZonedDateTime; @@ -126,6 +127,8 @@ public void run() { nativeMemoryCacheManager.invalidate(trainingDataEntryContext.getKey()); } + KNNCounter.TRAINING_ERRORS.increment(); + return; } @@ -148,6 +151,8 @@ public void run() { nativeMemoryCacheManager.invalidate(modelAnonymousEntryContext.getKey()); } + KNNCounter.TRAINING_ERRORS.increment(); + return; } @@ -184,6 +189,9 @@ public void run() { modelMetadata.setState(ModelState.FAILED); modelMetadata.setError("Failed to execute training. 
May be caused by an invalid method definition or " + "not enough memory to perform training."); + + KNNCounter.TRAINING_ERRORS.increment(); + } finally { // Invalidate right away so we dont run into any big memory problems trainingDataAllocation.readUnlock(); diff --git a/src/main/java/org/opensearch/knn/training/VectorReader.java b/src/main/java/org/opensearch/knn/training/VectorReader.java index 5df648076..6392ecabe 100644 --- a/src/main/java/org/opensearch/knn/training/VectorReader.java +++ b/src/main/java/org/opensearch/knn/training/VectorReader.java @@ -19,12 +19,11 @@ import org.opensearch.action.search.SearchScrollRequestBuilder; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.ValidationException; import org.opensearch.common.unit.TimeValue; -import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.query.ExistsQueryBuilder; -import org.opensearch.indices.IndicesService; -import org.opensearch.knn.index.KNNVectorFieldMapper; +import org.opensearch.knn.index.IndexUtil; import org.opensearch.search.SearchHit; import org.opensearch.search.sort.SortOrder; @@ -51,6 +50,7 @@ public VectorReader(Client client) { /** * Read vectors from a provided index/field and pass them to vectorConsumer that will do something with them. 
* + * @param clusterService cluster service to get information about the index * @param indexName name of index containing vectors * @param fieldName name of field containing vectors * @param maxVectorCount maximum number of vectors to return @@ -58,7 +58,7 @@ public VectorReader(Client client) { * @param vectorConsumer consumer used to do something with the collected vectors after each search * @param listener ActionListener that should be called once all search operations complete */ - public void read(IndicesService indicesService, String indexName, String fieldName, int maxVectorCount, + public void read(ClusterService clusterService, String indexName, String fieldName, int maxVectorCount, int searchSize, Consumer> vectorConsumer, ActionListener listener) { ValidationException validationException = null; @@ -74,19 +74,17 @@ public void read(IndicesService indicesService, String indexName, String fieldNa validationException.addValidationError("searchSize must be > 0 and <= 10000"); } - IndexMetadata indexMetadata = indicesService.clusterService().state().metadata().index(indexName); + IndexMetadata indexMetadata = clusterService.state().metadata().index(indexName); if (indexMetadata == null) { validationException = validationException == null ? new ValidationException() : validationException; validationException.addValidationError("index \"" + indexName + "\" does not exist"); throw validationException; } - MappedFieldType fieldType = indicesService.indexServiceSafe(indexMetadata.getIndex()).mapperService() - .fieldType(fieldName); - - if (!(fieldType instanceof KNNVectorFieldMapper.KNNVectorFieldType)) { + ValidationException fieldValidationException = IndexUtil.validateKnnField(indexMetadata, fieldName, -1, null); + if (fieldValidationException != null) { validationException = validationException == null ? 
new ValidationException() : validationException; - validationException.addValidationError("field \"" + fieldName + "\" must be of type KNNVectorFieldType"); + validationException.addValidationErrors(fieldValidationException.validationErrors()); } if (validationException != null) { diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryEntryContextTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryEntryContextTests.java index f94903382..b8fd05d1e 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryEntryContextTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryEntryContextTests.java @@ -12,6 +12,7 @@ package org.opensearch.knn.index.memory; import com.google.common.collect.ImmutableMap; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.indices.IndicesService; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.IndexUtil; @@ -198,18 +199,18 @@ public void testTrainingDataEntryContext_getSearchSize() { } public void testTrainingDataEntryContext_getIndicesService() { - IndicesService indicesService = mock(IndicesService.class); + ClusterService clusterService = mock(ClusterService.class); NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext = new NativeMemoryEntryContext.TrainingDataEntryContext( 0, "test", "test", null, - indicesService, + clusterService, 0, 0 ); - assertEquals(indicesService, trainingDataEntryContext.getIndicesService()); + assertEquals(clusterService, trainingDataEntryContext.getClusterService()); } private static class TestNativeMemoryAllocation implements NativeMemoryAllocation { diff --git a/src/test/java/org/opensearch/knn/training/VectorReaderTests.java b/src/test/java/org/opensearch/knn/training/VectorReaderTests.java index afbd818f2..2b88603d9 100644 --- a/src/test/java/org/opensearch/knn/training/VectorReaderTests.java +++ b/src/test/java/org/opensearch/knn/training/VectorReaderTests.java @@ -14,8 
@@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.action.ActionListener; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.ValidationException; -import org.opensearch.indices.IndicesService; import org.opensearch.knn.KNNSingleNodeTestCase; import java.io.IOException; @@ -61,13 +61,13 @@ public void testRead_valid_completeIndex() throws InterruptedException, Executio } // Configure VectorReader - IndicesService indicesService = node().injector().getInstance(IndicesService.class); + ClusterService clusterService = node().injector().getInstance(ClusterService.class); VectorReader vectorReader = new VectorReader(client()); // Read all vectors and confirm they match vectors TestVectorConsumer testVectorConsumer = new TestVectorConsumer(); final CountDownLatch inProgressLatch1 = new CountDownLatch(1); - vectorReader.read(indicesService, indexName, fieldName, 10000, 10, testVectorConsumer, + vectorReader.read(clusterService, indexName, fieldName, 10000, 10, testVectorConsumer, ActionListener.wrap(response -> inProgressLatch1.countDown(), e -> fail(e.toString()))); assertTrue(inProgressLatch1.await(100, TimeUnit.SECONDS)); @@ -115,13 +115,13 @@ public void testRead_valid_incompleteIndex() throws InterruptedException, Execut } // Configure VectorReader - IndicesService indicesService = node().injector().getInstance(IndicesService.class); + ClusterService clusterService = node().injector().getInstance(ClusterService.class); VectorReader vectorReader = new VectorReader(client()); // Read all vectors and confirm they match vectors TestVectorConsumer testVectorConsumer = new TestVectorConsumer(); final CountDownLatch inProgressLatch1 = new CountDownLatch(1); - vectorReader.read(indicesService, indexName, fieldName, 10000, 10, testVectorConsumer, + vectorReader.read(clusterService, indexName, fieldName, 10000, 10, testVectorConsumer, ActionListener.wrap(response -> 
inProgressLatch1.countDown(), e -> fail(e.toString()))); assertTrue(inProgressLatch1.await(100, TimeUnit.SECONDS)); @@ -160,13 +160,13 @@ public void testRead_valid_OnlyGetMaxVectors() throws InterruptedException, Exec } // Configure VectorReader - IndicesService indicesService = node().injector().getInstance(IndicesService.class); + ClusterService clusterService = node().injector().getInstance(ClusterService.class); VectorReader vectorReader = new VectorReader(client()); // Read maxNumVectorsRead vectors TestVectorConsumer testVectorConsumer = new TestVectorConsumer(); final CountDownLatch inProgressLatch1 = new CountDownLatch(1); - vectorReader.read(indicesService, indexName, fieldName, maxNumVectorsRead, 10, testVectorConsumer, + vectorReader.read(clusterService, indexName, fieldName, maxNumVectorsRead, 10, testVectorConsumer, ActionListener.wrap(response -> inProgressLatch1.countDown(), e -> fail(e.toString()))); assertTrue(inProgressLatch1.await(100, TimeUnit.SECONDS)); @@ -186,10 +186,10 @@ public void testRead_invalid_maxVectorCount() { createKnnIndexMapping(indexName, fieldName, dim); // Configure VectorReader - IndicesService indicesService = node().injector().getInstance(IndicesService.class); + ClusterService clusterService = node().injector().getInstance(ClusterService.class); VectorReader vectorReader = new VectorReader(client()); - expectThrows(ValidationException.class, () -> vectorReader.read(indicesService, indexName, fieldName, -10, 10, null, null)); + expectThrows(ValidationException.class, () -> vectorReader.read(clusterService, indexName, fieldName, -10, 10, null, null)); } public void testRead_invalid_searchSize() { @@ -203,14 +203,14 @@ public void testRead_invalid_searchSize() { createKnnIndexMapping(indexName, fieldName, dim); // Configure VectorReader - IndicesService indicesService = node().injector().getInstance(IndicesService.class); + ClusterService clusterService = node().injector().getInstance(ClusterService.class); VectorReader 
vectorReader = new VectorReader(client()); // Search size is negative - expectThrows(ValidationException.class, () -> vectorReader.read(indicesService, indexName, fieldName, 100, -10, null, null)); + expectThrows(ValidationException.class, () -> vectorReader.read(clusterService, indexName, fieldName, 100, -10, null, null)); // Search size is greater than 10000 - expectThrows(ValidationException.class, () -> vectorReader.read(indicesService, indexName, fieldName, 100, 20000, null, null)); + expectThrows(ValidationException.class, () -> vectorReader.read(clusterService, indexName, fieldName, 100, 20000, null, null)); } public void testRead_invalid_indexDoesNotExist() { @@ -219,11 +219,11 @@ public void testRead_invalid_indexDoesNotExist() { String fieldName = "test-field"; // Configure VectorReader - IndicesService indicesService = node().injector().getInstance(IndicesService.class); + ClusterService clusterService = node().injector().getInstance(ClusterService.class); VectorReader vectorReader = new VectorReader(client()); // Should throw a validation exception because index does not exist - expectThrows(ValidationException.class, () -> vectorReader.read(indicesService, indexName, fieldName, 10000, 10, null, null)); + expectThrows(ValidationException.class, () -> vectorReader.read(clusterService, indexName, fieldName, 10000, 10, null, null)); } public void testRead_invalid_fieldDoesNotExist() { @@ -233,11 +233,11 @@ public void testRead_invalid_fieldDoesNotExist() { createIndex(indexName); // Configure VectorReader - IndicesService indicesService = node().injector().getInstance(IndicesService.class); + ClusterService clusterService = node().injector().getInstance(ClusterService.class); VectorReader vectorReader = new VectorReader(client()); // Should throw a validation exception because field is not k-NN - expectThrows(ValidationException.class, () -> vectorReader.read(indicesService, indexName, fieldName, 10000, 10, null, null)); + 
expectThrows(ValidationException.class, () -> vectorReader.read(clusterService, indexName, fieldName, 10000, 10, null, null)); } public void testRead_invalid_fieldIsNotKnn() throws InterruptedException, ExecutionException, IOException { @@ -248,11 +248,11 @@ public void testRead_invalid_fieldIsNotKnn() throws InterruptedException, Execut addDoc(indexName, "test-id", fieldName, "dummy"); // Configure VectorReader - IndicesService indicesService = node().injector().getInstance(IndicesService.class); + ClusterService clusterService = node().injector().getInstance(ClusterService.class); VectorReader vectorReader = new VectorReader(client()); // Should throw a validation exception because field does not exist - expectThrows(ValidationException.class, () -> vectorReader.read(indicesService, indexName, fieldName, 10000, 10, null, null)); + expectThrows(ValidationException.class, () -> vectorReader.read(clusterService, indexName, fieldName, 10000, 10, null, null)); } private static class TestVectorConsumer implements Consumer> {