diff --git a/CHANGELOG.md b/CHANGELOG.md index 3a7e1d6ac..ab497ef84 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Add KnnCircuitBreakerException and modify exception message [#1688](https://github.com/opensearch-project/k-NN/pull/1688) * Add stats for radial search [#1684](https://github.com/opensearch-project/k-NN/pull/1684) * Support script score when doc value is disabled and fix misusing DISI [#1696](https://github.com/opensearch-project/k-NN/pull/1696) +* Add validation for pq m parameter before training starts [#1713](https://github.com/opensearch-project/k-NN/pull/1713) ### Bug Fixes * Block commas in model description [#1692](https://github.com/opensearch-project/k-NN/pull/1692) ### Infrastructure diff --git a/src/main/java/org/opensearch/knn/index/KNNMethod.java b/src/main/java/org/opensearch/knn/index/KNNMethod.java index 2d3672d87..7abd2ce39 100644 --- a/src/main/java/org/opensearch/knn/index/KNNMethod.java +++ b/src/main/java/org/opensearch/knn/index/KNNMethod.java @@ -15,6 +15,7 @@ import lombok.Getter; import org.opensearch.common.ValidationException; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.training.VectorSpaceInfo; import java.util.ArrayList; import java.util.Arrays; @@ -41,7 +42,7 @@ public class KNNMethod { * @param space to be checked * @return true if the space is supported; false otherwise */ - public boolean containsSpace(SpaceType space) { + public boolean isSpaceTypeSupported(SpaceType space) { return spaces.contains(space); } @@ -53,7 +54,7 @@ public boolean containsSpace(SpaceType space) { */ public ValidationException validate(KNNMethodContext knnMethodContext) { List errorMessages = new ArrayList<>(); - if (!containsSpace(knnMethodContext.getSpaceType())) { + if (!isSpaceTypeSupported(knnMethodContext.getSpaceType())) { errorMessages.add( String.format( "\"%s\" configuration does not support space type: " + 
"\"%s\".", @@ -77,6 +78,42 @@ public ValidationException validate(KNNMethodContext knnMethodContext) { return validationException; } + /** + * Validate that the configured KNNMethodContext is valid for this method, using additional data not present in the method context + * + * @param knnMethodContext to be validated + * @param vectorSpaceInfo additional data not present in the method context + * @return ValidationException produced by validation errors; null if no validations errors. + */ + public ValidationException validateWithData(KNNMethodContext knnMethodContext, VectorSpaceInfo vectorSpaceInfo) { + List errorMessages = new ArrayList<>(); + if (!isSpaceTypeSupported(knnMethodContext.getSpaceType())) { + errorMessages.add( + String.format( + "\"%s\" configuration does not support space type: " + "\"%s\".", + this.methodComponent.getName(), + knnMethodContext.getSpaceType().getValue() + ) + ); + } + + ValidationException methodValidation = methodComponent.validateWithData( + knnMethodContext.getMethodComponentContext(), + vectorSpaceInfo + ); + if (methodValidation != null) { + errorMessages.addAll(methodValidation.validationErrors()); + } + + if (errorMessages.isEmpty()) { + return null; + } + + ValidationException validationException = new ValidationException(); + validationException.addValidationErrors(errorMessages); + return validationException; + } + /** * returns whether training is required or not * diff --git a/src/main/java/org/opensearch/knn/index/KNNMethodContext.java b/src/main/java/org/opensearch/knn/index/KNNMethodContext.java index d4df713c2..ce48b06be 100644 --- a/src/main/java/org/opensearch/knn/index/KNNMethodContext.java +++ b/src/main/java/org/opensearch/knn/index/KNNMethodContext.java @@ -30,6 +30,7 @@ import java.util.stream.Collectors; import org.apache.commons.lang.builder.EqualsBuilder; import org.apache.commons.lang.builder.HashCodeBuilder; +import org.opensearch.knn.training.VectorSpaceInfo; import static 
org.opensearch.knn.common.KNNConstants.KNN_ENGINE; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; @@ -86,6 +87,16 @@ public ValidationException validate() { return knnEngine.validateMethod(this); } + /** + * This method uses the knnEngine to validate that the method is compatible with the engine, using additional data not present in the method context + * + * @param vectorSpaceInfo additional data not present in the method context + * @return ValidationException produced by validation errors; null if no validations errors. + */ + public ValidationException validateWithData(VectorSpaceInfo vectorSpaceInfo) { + return knnEngine.validateMethodWithData(this, vectorSpaceInfo); + } + /** * This method returns whether training is requires or not from knnEngine * diff --git a/src/main/java/org/opensearch/knn/index/MethodComponent.java b/src/main/java/org/opensearch/knn/index/MethodComponent.java index f2e2d878e..256d55ee5 100644 --- a/src/main/java/org/opensearch/knn/index/MethodComponent.java +++ b/src/main/java/org/opensearch/knn/index/MethodComponent.java @@ -17,6 +17,7 @@ import org.opensearch.common.ValidationException; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.util.IndexHyperParametersUtil; +import org.opensearch.knn.training.VectorSpaceInfo; import java.util.ArrayList; import java.util.HashMap; @@ -102,6 +103,43 @@ public ValidationException validate(MethodComponentContext methodComponentContex return validationException; } + /** + * Validate that the methodComponentContext is a valid configuration for this methodComponent, using additional data not present in the method component context + * + * @param methodComponentContext to be validated + * @param vectorSpaceInfo additional data not present in the method component context + * @return ValidationException produced by validation errors; null if no validations errors. 
+ */ + public ValidationException validateWithData(MethodComponentContext methodComponentContext, VectorSpaceInfo vectorSpaceInfo) { + Map providedParameters = methodComponentContext.getParameters(); + List errorMessages = new ArrayList<>(); + + if (providedParameters == null) { + return null; + } + + ValidationException parameterValidation; + for (Map.Entry parameter : providedParameters.entrySet()) { + if (!parameters.containsKey(parameter.getKey())) { + errorMessages.add(String.format("Invalid parameter for method \"%s\".", getName())); + continue; + } + + parameterValidation = parameters.get(parameter.getKey()).validateWithData(parameter.getValue(), vectorSpaceInfo); + if (parameterValidation != null) { + errorMessages.addAll(parameterValidation.validationErrors()); + } + } + + if (errorMessages.isEmpty()) { + return null; + } + + ValidationException validationException = new ValidationException(); + validationException.addValidationErrors(errorMessages); + return validationException; + } + /** * gets requiresTraining value * diff --git a/src/main/java/org/opensearch/knn/index/Parameter.java b/src/main/java/org/opensearch/knn/index/Parameter.java index e223909d5..a4520636e 100644 --- a/src/main/java/org/opensearch/knn/index/Parameter.java +++ b/src/main/java/org/opensearch/knn/index/Parameter.java @@ -12,8 +12,10 @@ package org.opensearch.knn.index; import org.opensearch.common.ValidationException; +import org.opensearch.knn.training.VectorSpaceInfo; import java.util.Map; +import java.util.function.BiFunction; import java.util.function.Predicate; /** @@ -26,6 +28,7 @@ public abstract class Parameter { private String name; private T defaultValue; protected Predicate validator; + protected BiFunction validatorWithData; /** * Constructor @@ -38,6 +41,14 @@ public Parameter(String name, T defaultValue, Predicate validator) { this.name = name; this.defaultValue = defaultValue; this.validator = validator; + this.validatorWithData = null; + } + + public 
Parameter(String name, T defaultValue, Predicate validator, BiFunction validatorWithData) { + this.name = name; + this.defaultValue = defaultValue; + this.validator = validator; + this.validatorWithData = validatorWithData; } /** @@ -66,6 +77,15 @@ public T getDefaultValue() { */ public abstract ValidationException validate(Object value); + /** + * Check if the value passed in is valid, using additional data not present in the value + * + * @param value to be checked + * @param vectorSpaceInfo additional data not present in the value + * @return ValidationException produced by validation errors; null if no validations errors. + */ + public abstract ValidationException validateWithData(Object value, VectorSpaceInfo vectorSpaceInfo); + /** * Boolean method parameter */ @@ -74,12 +94,23 @@ public BooleanParameter(String name, Boolean defaultValue, Predicate va super(name, defaultValue, validator); } + public BooleanParameter( + String name, + Boolean defaultValue, + Predicate validator, + BiFunction validatorWithData + ) { + super(name, defaultValue, validator, validatorWithData); + } + @Override public ValidationException validate(Object value) { ValidationException validationException = null; if (!(value instanceof Boolean)) { validationException = new ValidationException(); - validationException.addValidationError(String.format("value not of type Boolean for Boolean parameter [%s].", getName())); + validationException.addValidationError( + String.format("value is not an instance of Boolean for Boolean parameter [%s].", getName()) + ); return validationException; } @@ -89,6 +120,27 @@ public ValidationException validate(Object value) { } return validationException; } + + @Override + public ValidationException validateWithData(Object value, VectorSpaceInfo vectorSpaceInfo) { + ValidationException validationException = null; + if (!(value instanceof Boolean)) { + validationException = new ValidationException(); + 
validationException.addValidationError(String.format("value is not an instance of Boolean for Boolean parameter [%s].", getName())); + return validationException; + } + + if (validatorWithData == null) { + return null; + } + + if (!validatorWithData.apply((Boolean) value, vectorSpaceInfo)) { + validationException = new ValidationException(); + validationException.addValidationError(String.format("parameter validation failed for Boolean parameter [%s].", getName())); + } + + return validationException; + } } /** @@ -99,6 +151,15 @@ public IntegerParameter(String name, Integer defaultValue, Predicate va super(name, defaultValue, validator); } + public IntegerParameter( + String name, + Integer defaultValue, + Predicate validator, + BiFunction validatorWithData + ) { + super(name, defaultValue, validator, validatorWithData); + } + @Override public ValidationException validate(Object value) { ValidationException validationException = null; @@ -118,6 +179,29 @@ public ValidationException validate(Object value) { } return validationException; } + + @Override + public ValidationException validateWithData(Object value, VectorSpaceInfo vectorSpaceInfo) { + ValidationException validationException = null; + if (!(value instanceof Integer)) { + validationException = new ValidationException(); + validationException.addValidationError( + String.format("value is not an instance of Integer for Integer parameter [%s].", getName()) + ); + return validationException; + } + + if (validatorWithData == null) { + return null; + } + + if (!validatorWithData.apply((Integer) value, vectorSpaceInfo)) { + validationException = new ValidationException(); + validationException.addValidationError(String.format("parameter validation failed for Integer parameter [%s].", getName())); + } + + return validationException; + } } /** @@ -136,6 +220,15 @@ public StringParameter(String name, String defaultValue, Predicate valid super(name, defaultValue, validator); } + public StringParameter( + String name, + 
String defaultValue, + Predicate validator, + BiFunction validatorWithData + ) { + super(name, defaultValue, validator, validatorWithData); + } + /** * Check if the value passed in is valid * @@ -161,6 +254,29 @@ public ValidationException validate(Object value) { } return validationException; } + + @Override + public ValidationException validateWithData(Object value, VectorSpaceInfo vectorSpaceInfo) { + ValidationException validationException = null; + if (!(value instanceof String)) { + validationException = new ValidationException(); + validationException.addValidationError( + String.format("value is not an instance of String for String parameter [%s].", getName()) + ); + return validationException; + } + + if (validatorWithData == null) { + return null; + } + + if (!validatorWithData.apply((String) value, vectorSpaceInfo)) { + validationException = new ValidationException(); + validationException.addValidationError(String.format("parameter validation failed for String parameter [%s].", getName())); + } + + return validationException; + } } /** @@ -190,6 +306,12 @@ public MethodComponentContextParameter( } return methodComponents.get(methodComponentContext.getName()).validate(methodComponentContext) == null; + }, (methodComponentContext, vectorSpaceInfo) -> { + if (!methodComponents.containsKey(methodComponentContext.getName())) { + return false; + } + return methodComponents.get(methodComponentContext.getName()) + .validateWithData(methodComponentContext, vectorSpaceInfo) == null; }); this.methodComponents = methodComponents; } @@ -216,6 +338,31 @@ public ValidationException validate(Object value) { return validationException; } + @Override + public ValidationException validateWithData(Object value, VectorSpaceInfo vectorSpaceInfo) { + ValidationException validationException = null; + if (!(value instanceof MethodComponentContext)) { + validationException = new ValidationException(); + validationException.addValidationError( + String.format("value is not an 
instance of MethodComponentContext for MethodComponentContext parameter [%s].", getName()) + ); + return validationException; + } + + if (validatorWithData == null) { + return null; + } + + if (!validatorWithData.apply((MethodComponentContext) value, vectorSpaceInfo)) { + validationException = new ValidationException(); + validationException.addValidationError( + String.format("parameter validation failed for MethodComponentContext parameter [%s].", getName()) + ); + } + + return validationException; + } + /** * Get method component by name * diff --git a/src/main/java/org/opensearch/knn/index/util/AbstractKNNLibrary.java b/src/main/java/org/opensearch/knn/index/util/AbstractKNNLibrary.java index f97d18810..0fe311094 100644 --- a/src/main/java/org/opensearch/knn/index/util/AbstractKNNLibrary.java +++ b/src/main/java/org/opensearch/knn/index/util/AbstractKNNLibrary.java @@ -11,6 +11,7 @@ import org.opensearch.common.ValidationException; import org.opensearch.knn.index.KNNMethod; import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.training.VectorSpaceInfo; import java.util.Map; @@ -39,6 +40,12 @@ public ValidationException validateMethod(KNNMethodContext knnMethodContext) { return getMethod(methodName).validate(knnMethodContext); } + @Override + public ValidationException validateMethodWithData(KNNMethodContext knnMethodContext, VectorSpaceInfo vectorSpaceInfo) { + String methodName = knnMethodContext.getMethodComponentContext().getName(); + return getMethod(methodName).validateWithData(knnMethodContext, vectorSpaceInfo); + } + @Override public boolean isTrainingRequired(KNNMethodContext knnMethodContext) { String methodName = knnMethodContext.getMethodComponentContext().getName(); diff --git a/src/main/java/org/opensearch/knn/index/util/Faiss.java b/src/main/java/org/opensearch/knn/index/util/Faiss.java index efd8a637c..bbb58bf1e 100644 --- a/src/main/java/org/opensearch/knn/index/util/Faiss.java +++ b/src/main/java/org/opensearch/knn/index/util/Faiss.java @@ 
-109,9 +109,6 @@ class Faiss extends NativeLibrary { .build() ); - // TODO: To think about in future: for PQ, if dimension is not divisible by code count, PQ will fail. Right now, - // we do not have a way to base validation off of dimension. Failure will happen during training in JNI. - // Define methods supported by faiss. See issue here: https://github.com/opensearch-project/k-NN/issues/1075 private final static Map HNSW_ENCODERS = ImmutableMap.builder() .putAll( ImmutableMap.of( @@ -122,7 +119,8 @@ class Faiss extends NativeLibrary { new Parameter.IntegerParameter( ENCODER_PARAMETER_PQ_M, ENCODER_PARAMETER_PQ_CODE_COUNT_DEFAULT, - v -> v > 0 && v < ENCODER_PARAMETER_PQ_CODE_COUNT_LIMIT + v -> v > 0 && v < ENCODER_PARAMETER_PQ_CODE_COUNT_LIMIT, + (v, vectorSpaceInfo) -> v > 0 && vectorSpaceInfo.getDimension() % v == 0 /* v > 0 is checked first: m == 0 would make % throw ArithmeticException */ ) ) .addParameter( @@ -161,7 +159,8 @@ class Faiss extends NativeLibrary { new Parameter.IntegerParameter( ENCODER_PARAMETER_PQ_M, ENCODER_PARAMETER_PQ_CODE_COUNT_DEFAULT, - v -> v > 0 && v < ENCODER_PARAMETER_PQ_CODE_COUNT_LIMIT + v -> v > 0 && v < ENCODER_PARAMETER_PQ_CODE_COUNT_LIMIT, + (v, vectorSpaceInfo) -> v > 0 && vectorSpaceInfo.getDimension() % v == 0 /* v > 0 is checked first: m == 0 would make % throw ArithmeticException */ ) ) .addParameter( diff --git a/src/main/java/org/opensearch/knn/index/util/KNNEngine.java b/src/main/java/org/opensearch/knn/index/util/KNNEngine.java index e282c69db..556785783 100644 --- a/src/main/java/org/opensearch/knn/index/util/KNNEngine.java +++ b/src/main/java/org/opensearch/knn/index/util/KNNEngine.java @@ -10,6 +10,7 @@ import org.opensearch.knn.index.KNNMethod; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.training.VectorSpaceInfo; import java.util.List; import java.util.Map; @@ -168,6 +169,11 @@ public ValidationException validateMethod(KNNMethodContext knnMethodContext) { return knnLibrary.validateMethod(knnMethodContext); } + @Override + public ValidationException validateMethodWithData(KNNMethodContext knnMethodContext, 
VectorSpaceInfo vectorSpaceInfo) { + return knnLibrary.validateMethodWithData(knnMethodContext, vectorSpaceInfo); + } + @Override public boolean isTrainingRequired(KNNMethodContext knnMethodContext) { return knnLibrary.isTrainingRequired(knnMethodContext); diff --git a/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java b/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java index f837566b8..cac5af2bb 100644 --- a/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java +++ b/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java @@ -15,6 +15,7 @@ import org.opensearch.knn.index.KNNMethod; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.training.VectorSpaceInfo; import java.util.Collections; import java.util.List; @@ -97,6 +98,16 @@ public interface KNNLibrary { */ ValidationException validateMethod(KNNMethodContext knnMethodContext); + /** + * Validate the knnMethodContext for the given library, using additional data not present in the method context. A ValidationException is returned (not thrown) if the method is + * deemed invalid. + * + * @param knnMethodContext to be validated + * @param vectorSpaceInfo additional data not present in the method context + * @return ValidationException produced by validation errors; null if there are no validation errors. + */ + ValidationException validateMethodWithData(KNNMethodContext knnMethodContext, VectorSpaceInfo vectorSpaceInfo); + /** * Returns whether training is required or not from knnMethodContext for the given library. 
* diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java index 9035a8e84..5f3913ac5 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java @@ -22,6 +22,7 @@ import org.opensearch.knn.index.IndexUtil; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.indices.ModelDao; +import org.opensearch.knn.training.VectorSpaceInfo; import java.io.IOException; @@ -281,6 +282,12 @@ public ActionRequestValidationException validate() { exception.addValidationErrors(validationException.validationErrors()); } + validationException = this.knnMethodContext.validateWithData(new VectorSpaceInfo(dimension)); + if (validationException != null) { + exception = exception == null ? new ActionRequestValidationException() : exception; /* reuse the instance so errors added by the preceding check are not discarded */ + exception.addValidationErrors(validationException.validationErrors()); + } + if (!this.knnMethodContext.isTrainingRequired()) { exception = exception == null ? new ActionRequestValidationException() : exception; exception.addValidationError("Method does not require training."); diff --git a/src/main/java/org/opensearch/knn/training/VectorSpaceInfo.java b/src/main/java/org/opensearch/knn/training/VectorSpaceInfo.java new file mode 100644 index 000000000..13843486d --- /dev/null +++ b/src/main/java/org/opensearch/knn/training/VectorSpaceInfo.java @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.knn.training; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.Setter; + +/** + * A data spec containing relevant information for validation. + */ +@Getter +@Setter +@AllArgsConstructor +public class VectorSpaceInfo { + private int dimension; +} diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index 145fa2cff..b018740bc 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -1311,8 +1311,8 @@ public void testSharedIndexState_whenOneIndexDeleted_thenSecondIndexIsStillSearc .startObject(METHOD_ENCODER_PARAMETER) .field(NAME, ENCODER_PQ) .startObject(PARAMETERS) - .field(ENCODER_PARAMETER_PQ_M, pqCodeSize) - .field(ENCODER_PARAMETER_PQ_CODE_SIZE, pqM) + .field(ENCODER_PARAMETER_PQ_M, pqM) + .field(ENCODER_PARAMETER_PQ_CODE_SIZE, pqCodeSize) .endObject() .endObject() .endObject() @@ -1648,6 +1648,138 @@ public void testFiltering_whenUsingFaissExactSearchWithIP_thenMatchExpectedScore } } + @SneakyThrows + public void testHNSW_InvalidPQM_thenFail() { + String trainingIndexName = "training-index"; + String trainingFieldName = "training-field"; + + String modelId = "test-model"; + String modelDescription = "test model"; + + List mValues = ImmutableList.of(16, 32, 64, 128); + int invalidPQM = 3; + + // training data needs to be at least equal to the number of centroids for PQ + // which is 2^8 = 256. 
8 because thats the only valid code_size for HNSWPQ + int trainingDataCount = 256; + + SpaceType spaceType = SpaceType.L2; + + Integer dimension = testData.indexData.vectors[0].length; + + /* + * Builds the below json: + * { + * "name": "hnsw", + * "engine": "faiss", + * "space_type": "l2", + * "parameters": { + * "encoder": { + * "name": "pq", + * "parameters": { + * "m": 3 + * } + * } + * } + * } + */ + + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, METHOD_HNSW) + .field(KNN_ENGINE, FAISS_NAME) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .startObject(PARAMETERS) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, ENCODER_PQ) + .startObject(PARAMETERS) + .field(ENCODER_PARAMETER_PQ_M, invalidPQM) + .endObject() + .endObject() + .endObject() + .endObject(); + Map in = xContentBuilderToMap(xContentBuilder); + + createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension); + ResponseException re = expectThrows( + ResponseException.class, + () -> ingestDataAndTrainModel(modelId, trainingIndexName, trainingFieldName, dimension, modelDescription, in, trainingDataCount) + ); + assertTrue( + re.getMessage().contains("Validation Failed: 1: parameter validation failed for MethodComponentContext parameter [encoder].;") + ); + } + + @SneakyThrows + public void testIVF_InvalidPQM_thenFail() { + String trainingIndexName = "training-index"; + String trainingFieldName = "training-field"; + + String modelId = "test-model"; + String modelDescription = "test model"; + + List mValues = ImmutableList.of(16, 32, 64, 128); + int invalidPQM = 3; + + // training data needs to be at least equal to the number of centroids for PQ + // which is 2^8 = 256. 
+ int trainingDataCount = 256; + + int dimension = testData.indexData.vectors[0].length; + SpaceType spaceType = SpaceType.L2; + int ivfNlist = 4; + int ivfNprobes = 4; + int pqCodeSize = 8; + + /* + * Builds the below json: + * { + * "name": "ivf", + * "engine": "faiss", + * "space_type": "l2", + * "parameters": { + * "nprobes": 4, + * "nlist": 4, + * "encoder": { + * "name": "pq", + * "parameters": { + * "m": 3, + * "code_size": 8 + * } + * } + * } + * } + */ + + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, METHOD_IVF) + .field(KNN_ENGINE, FAISS_NAME) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_NPROBES, ivfNprobes) + .field(METHOD_PARAMETER_NLIST, ivfNlist) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, ENCODER_PQ) + .startObject(PARAMETERS) + .field(ENCODER_PARAMETER_PQ_M, invalidPQM) + .field(ENCODER_PARAMETER_PQ_CODE_SIZE, pqCodeSize) + .endObject() + .endObject() + .endObject() + .endObject(); + Map in = xContentBuilderToMap(xContentBuilder); + + createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension); + ResponseException re = expectThrows( + ResponseException.class, + () -> ingestDataAndTrainModel(modelId, trainingIndexName, trainingFieldName, dimension, modelDescription, in, trainingDataCount) + ); + assertTrue( + re.getMessage().contains("Validation Failed: 1: parameter validation failed for MethodComponentContext parameter [encoder].;") + ); + } + protected void setupKNNIndexForFilterQuery() throws Exception { // Create Mappings XContentBuilder builder = XContentFactory.jsonBuilder() diff --git a/src/test/java/org/opensearch/knn/index/KNNMethodTests.java b/src/test/java/org/opensearch/knn/index/KNNMethodTests.java index d4dd989f7..607ca849e 100644 --- a/src/test/java/org/opensearch/knn/index/KNNMethodTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNMethodTests.java @@ -17,6 +17,7 @@ import 
org.opensearch.common.xcontent.XContentFactory; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.util.KNNEngine; +import org.opensearch.knn.training.VectorSpaceInfo; import java.io.IOException; import java.util.HashMap; @@ -44,9 +45,9 @@ public void testHasSpace() { KNNMethod knnMethod = KNNMethod.Builder.builder(MethodComponent.Builder.builder(name).build()) .addSpaces(SpaceType.L2, SpaceType.COSINESIMIL) .build(); - assertTrue(knnMethod.containsSpace(SpaceType.L2)); - assertTrue(knnMethod.containsSpace(SpaceType.COSINESIMIL)); - assertFalse(knnMethod.containsSpace(SpaceType.INNER_PRODUCT)); + assertTrue(knnMethod.isSpaceTypeSupported(SpaceType.L2)); + assertTrue(knnMethod.isSpaceTypeSupported(SpaceType.COSINESIMIL)); + assertFalse(knnMethod.isSpaceTypeSupported(SpaceType.INNER_PRODUCT)); } /** @@ -93,6 +94,52 @@ public void testValidate() throws IOException { assertNull(knnMethod.validate(knnMethodContext3)); } + /** + * Test KNNMethod validateWithData + */ + public void testValidateWithData() throws IOException { + String methodName = "test-method"; + KNNMethod knnMethod = KNNMethod.Builder.builder(MethodComponent.Builder.builder(methodName).build()) + .addSpaces(SpaceType.L2) + .build(); + + VectorSpaceInfo testVectorSpaceInfo = new VectorSpaceInfo(4); + + // Invalid space + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, methodName) + .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.INNER_PRODUCT.getValue()) + .endObject(); + Map in = xContentBuilderToMap(xContentBuilder); + KNNMethodContext knnMethodContext1 = KNNMethodContext.parse(in); + assertNotNull(knnMethod.validateWithData(knnMethodContext1, testVectorSpaceInfo)); + + // Invalid methodComponent + xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, methodName) + .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue()) + .startObject(PARAMETERS) + .field("invalid", "invalid") + .endObject() + 
.endObject(); + in = xContentBuilderToMap(xContentBuilder); + KNNMethodContext knnMethodContext2 = KNNMethodContext.parse(in); + + assertNotNull(knnMethod.validateWithData(knnMethodContext2, testVectorSpaceInfo)); + + // Valid everything + xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, methodName) + .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue()) + .endObject(); + in = xContentBuilderToMap(xContentBuilder); + KNNMethodContext knnMethodContext3 = KNNMethodContext.parse(in); + assertNull(knnMethod.validateWithData(knnMethodContext3, testVectorSpaceInfo)); + } + public void testGetAsMap() { SpaceType spaceType = SpaceType.DEFAULT; String methodName = "test-method"; @@ -122,6 +169,6 @@ public void testBuilder() { builder.addSpaces(SpaceType.L2); knnMethod = builder.build(); - assertTrue(knnMethod.containsSpace(SpaceType.L2)); + assertTrue(knnMethod.isSpaceTypeSupported(SpaceType.L2)); } } diff --git a/src/test/java/org/opensearch/knn/index/ParameterTests.java b/src/test/java/org/opensearch/knn/index/ParameterTests.java index 08decd592..2f3f19727 100644 --- a/src/test/java/org/opensearch/knn/index/ParameterTests.java +++ b/src/test/java/org/opensearch/knn/index/ParameterTests.java @@ -17,6 +17,7 @@ import org.opensearch.knn.index.Parameter.IntegerParameter; import org.opensearch.knn.index.Parameter.StringParameter; import org.opensearch.knn.index.Parameter.MethodComponentContextParameter; +import org.opensearch.knn.training.VectorSpaceInfo; import java.util.Map; @@ -31,6 +32,12 @@ public void testGetDefaultValue() { public ValidationException validate(Object value) { return null; } + + @Override + public ValidationException validateWithData(Object value, VectorSpaceInfo vectorSpaceInfo) { + return null; + } + }; assertEquals(defaultValue, parameter.getDefaultValue()); @@ -52,6 +59,29 @@ public void testIntegerParameter_validate() { assertNull(parameter.validate(12)); } + /** + * Test integer parameter validate + */ + public 
void testIntegerParameter_validateWithData() { + final IntegerParameter parameter = new IntegerParameter( + "test", + 1, + v -> v > 0, + (v, vectorSpaceInfo) -> v > vectorSpaceInfo.getDimension() + ); + + VectorSpaceInfo testVectorSpaceInfo = new VectorSpaceInfo(0); + + // Invalid type + assertNotNull(parameter.validateWithData("String", testVectorSpaceInfo)); + + // Invalid value + assertNotNull(parameter.validateWithData(-1, testVectorSpaceInfo)); + + // valid value + assertNull(parameter.validateWithData(12, testVectorSpaceInfo)); + } + public void testStringParameter_validate() { final StringParameter parameter = new StringParameter("test_parameter", "default_value", v -> "test".equals(v)); @@ -65,6 +95,36 @@ public void testStringParameter_validate() { assertNull(parameter.validate("test")); } + public void testStringParameter_validateWithData() { + final StringParameter parameter = new StringParameter( + "test_parameter", + "default_value", + v -> "test".equals(v), + (v, vectorSpaceInfo) -> { + if (vectorSpaceInfo.getDimension() > 0) { + return "test".equals(v); + } + return false; + } + ); + + VectorSpaceInfo testVectorSpaceInfo = new VectorSpaceInfo(1); + + // Invalid type + assertNotNull(parameter.validateWithData(5, testVectorSpaceInfo)); + + // null + assertNotNull(parameter.validateWithData(null, testVectorSpaceInfo)); + + // valid value + assertNull(parameter.validateWithData("test", testVectorSpaceInfo)); + + testVectorSpaceInfo.setDimension(0); + + // invalid value + assertNotNull(parameter.validateWithData("test", testVectorSpaceInfo)); + } + public void testMethodComponentContextParameter_validate() { String methodComponentName1 = "method-1"; String parameterKey1 = "parameter_key_1"; @@ -109,6 +169,55 @@ public void testMethodComponentContextParameter_validate() { assertNull(parameter.validate(methodComponentContext)); } + public void testMethodComponentContextParameter_validateWithData() { + String methodComponentName1 = "method-1"; + String 
parameterKey1 = "parameter_key_1"; + Integer parameterValue1 = 12; + + Map defaultParameterMap = ImmutableMap.of(parameterKey1, parameterValue1); + MethodComponentContext methodComponentContext = new MethodComponentContext(methodComponentName1, defaultParameterMap); + + Map methodComponentMap = ImmutableMap.of( + methodComponentName1, + MethodComponent.Builder.builder(parameterKey1) + .addParameter( + parameterKey1, + new IntegerParameter(parameterKey1, 1, v -> v > 0, (v, vectorSpaceInfo) -> v > vectorSpaceInfo.getDimension()) + ) + .build() + ); + + final MethodComponentContextParameter parameter = new MethodComponentContextParameter( + "test", + methodComponentContext, + methodComponentMap + ); + + VectorSpaceInfo testVectorSpaceInfo = new VectorSpaceInfo(0); + + // Invalid type + assertNotNull(parameter.validateWithData(17, testVectorSpaceInfo)); + assertNotNull(parameter.validateWithData("invalid-value", testVectorSpaceInfo)); + + // Invalid value + String invalidMethodComponentName = "invalid-method"; + MethodComponentContext invalidMethodComponentContext1 = new MethodComponentContext(invalidMethodComponentName, defaultParameterMap); + assertNotNull(parameter.validateWithData(invalidMethodComponentContext1, testVectorSpaceInfo)); + + String invalidParameterKey = "invalid-parameter"; + Map invalidParameterMap1 = ImmutableMap.of(invalidParameterKey, parameterValue1); + MethodComponentContext invalidMethodComponentContext2 = new MethodComponentContext(methodComponentName1, invalidParameterMap1); + assertNotNull(parameter.validateWithData(invalidMethodComponentContext2, testVectorSpaceInfo)); + + String invalidParameterValue = "invalid-value"; + Map invalidParameterMap2 = ImmutableMap.of(parameterKey1, invalidParameterValue); + MethodComponentContext invalidMethodComponentContext3 = new MethodComponentContext(methodComponentName1, invalidParameterMap2); + assertNotNull(parameter.validateWithData(invalidMethodComponentContext3, testVectorSpaceInfo)); + + // valid 
value + assertNull(parameter.validateWithData(methodComponentContext, testVectorSpaceInfo)); + } + public void testMethodComponentContextParameter_getMethodComponent() { String methodComponentName1 = "method-1"; String parameterKey1 = "parameter_key_1"; diff --git a/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java b/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java index cff4d5805..46240e830 100644 --- a/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java +++ b/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java @@ -16,6 +16,7 @@ import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.util.KNNLibrary; +import org.opensearch.knn.training.VectorSpaceInfo; import org.opensearch.test.OpenSearchTestCase; import java.util.Map; @@ -78,6 +79,11 @@ public ValidationException validateMethod(KNNMethodContext knnMethodContext) { return null; } + @Override + public ValidationException validateMethodWithData(KNNMethodContext knnMethodContext, VectorSpaceInfo vectorSpaceInfo) { + return null; + } + @Override public boolean isTrainingRequired(KNNMethodContext knnMethodContext) { return false; diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 6ba1eae65..fa6a13f2f 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -1430,7 +1430,7 @@ public void assertTrainingFails(String modelId, int attempts, int delayInMillis) assertNotEquals(ModelState.CREATED, modelState); } - fail("Training did not succeed after " + attempts + " attempts with a delay of " + delayInMillis + " ms."); + fail("Training did not fail after " + attempts + " attempts with a delay of " + 
delayInMillis + " ms."); } protected boolean systemIndexExists(final String indexName) throws IOException {