diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index 3dc04ba45..fdbe0bb20 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -9,7 +9,7 @@ public class KNNConstants { // shared across library constants public static final String DIMENSION = "dimension"; public static final String KNN_ENGINE = "engine"; - public static final String KNN_METHOD= "method"; + public static final String KNN_METHOD = "method"; public static final String NAME = "name"; public static final String PARAMETERS = "parameters"; public static final String METHOD_HNSW = "hnsw"; diff --git a/src/main/java/org/opensearch/knn/index/IndexUtil.java b/src/main/java/org/opensearch/knn/index/IndexUtil.java index 1782d3384..f0d9ee944 100644 --- a/src/main/java/org/opensearch/knn/index/IndexUtil.java +++ b/src/main/java/org/opensearch/knn/index/IndexUtil.java @@ -61,8 +61,12 @@ public static int getFileSizeInKB(String filePath) { * @return ValidationException exception produced by field validation */ @SuppressWarnings("unchecked") - public static ValidationException validateKnnField(IndexMetadata indexMetadata, String field, int expectedDimension, - ModelDao modelDao) { + public static ValidationException validateKnnField( + IndexMetadata indexMetadata, + String field, + int expectedDimension, + ModelDao modelDao + ) { // Index metadata should not be null if (indexMetadata == null) { throw new IllegalArgumentException("IndexMetadata should not be null"); @@ -78,8 +82,8 @@ public static ValidationException validateKnnField(IndexMetadata indexMetadata, } // The mapping output *should* look like this: - // "{properties={field={type=knn_vector, dimension=8}}}" - Map properties = (Map)mappingMetadata.getSourceAsMap().get("properties"); + // "{properties={field={type=knn_vector, dimension=8}}}" + Map properties = (Map) mappingMetadata.getSourceAsMap().get("properties"); if (properties == null) { exception.addValidationError("Properties in map does not exists. This is unexpected"); @@ -106,8 +110,7 @@ public static ValidationException validateKnnField(IndexMetadata indexMetadata, Object type = fieldMap.get("type"); if (!(type instanceof String) || !KNNVectorFieldMapper.CONTENT_TYPE.equals(type)) { - exception.addValidationError(String.format("Field \"%s\" is not of type %s.", field, - KNNVectorFieldMapper.CONTENT_TYPE)); + exception.addValidationError(String.format("Field \"%s\" is not of type %s.", field, KNNVectorFieldMapper.CONTENT_TYPE)); return exception; } @@ -131,22 +134,25 @@ public static ValidationException validateKnnField(IndexMetadata indexMetadata, } if (modelDao == null) { - throw new IllegalArgumentException(String.format("Field \"%s\" uses model. modelDao cannot be null.", - field)); + throw new IllegalArgumentException(String.format("Field \"%s\" uses model. modelDao cannot be null.", field)); } ModelMetadata modelMetadata = modelDao.getMetadata(modelId); if (modelMetadata == null) { - exception.addValidationError(String.format("Model \"%s\" for field \"%s\" does not exist.", modelId, - field)); + exception.addValidationError(String.format("Model \"%s\" for field \"%s\" does not exist.", modelId, field)); return exception; } dimension = modelMetadata.getDimension(); if ((Integer) dimension != expectedDimension) { - exception.addValidationError(String.format("Field \"%s\" has dimension %d, which is different from " + - "dimension specified in the training request: %d", field, dimension, - expectedDimension)); + exception.addValidationError( + String.format( + "Field \"%s\" has dimension %d, which is different from " + "dimension specified in the training request: %d", + field, + dimension, + expectedDimension + ) + ); return exception; } @@ -155,8 +161,14 @@ public static ValidationException validateKnnField(IndexMetadata indexMetadata, // If the dimension was found in training fields mapping, check that it equals the models proposed dimension. if ((Integer) dimension != expectedDimension) { - exception.addValidationError(String.format("Field \"%s\" has dimension %d, which is different from " + - "dimension specified in the training request: %d", field, dimension, expectedDimension)); + exception.addValidationError( + String.format( + "Field \"%s\" has dimension %d, which is different from " + "dimension specified in the training request: %d", + field, + dimension, + expectedDimension + ) + ); return exception; } @@ -172,9 +184,7 @@ public static ValidationException validateKnnField(IndexMetadata indexMetadata, * @return load parameters that will be passed to the JNI. */ public static Map getParametersAtLoading(SpaceType spaceType, KNNEngine knnEngine, String indexName) { - Map loadParameters = Maps.newHashMap(ImmutableMap.of( - SPACE_TYPE, spaceType.getValue() - )); + Map loadParameters = Maps.newHashMap(ImmutableMap.of(SPACE_TYPE, spaceType.getValue())); // For nmslib, we need to add the dynamic ef_search parameter that needs to be passed in when the // hnsw graphs are loaded into memory diff --git a/src/main/java/org/opensearch/knn/index/KNNCircuitBreaker.java b/src/main/java/org/opensearch/knn/index/KNNCircuitBreaker.java index 3ab19139a..5375beddd 100644 --- a/src/main/java/org/opensearch/knn/index/KNNCircuitBreaker.java +++ b/src/main/java/org/opensearch/knn/index/KNNCircuitBreaker.java @@ -27,15 +27,14 @@ */ public class KNNCircuitBreaker { private static Logger logger = LogManager.getLogger(KNNCircuitBreaker.class); - public static int CB_TIME_INTERVAL = 2*60; // seconds + public static int CB_TIME_INTERVAL = 2 * 60; // seconds private static KNNCircuitBreaker INSTANCE; private ThreadPool threadPool; private ClusterService clusterService; private Client client; - private KNNCircuitBreaker() { - } + private KNNCircuitBreaker() {} public static synchronized KNNCircuitBreaker getInstance() { if (INSTANCE == null) { @@ -60,10 +59,10 @@ public void initialize(ThreadPool threadPool, ClusterService clusterService, Cli NativeMemoryCacheManager nativeMemoryCacheManager = NativeMemoryCacheManager.getInstance(); Runnable runnable = () -> { if (nativeMemoryCacheManager.isCacheCapacityReached() && clusterService.localNode().isDataNode()) { - long currentSizeKiloBytes = nativeMemoryCacheManager.getCacheSizeInKilobytes(); + long currentSizeKiloBytes = nativeMemoryCacheManager.getCacheSizeInKilobytes(); long circuitBreakerLimitSizeKiloBytes = KNNSettings.getCircuitBreakerLimit().getKb(); - long circuitBreakerUnsetSizeKiloBytes = (long) ((KNNSettings.getCircuitBreakerUnsetPercentage()/100) - * circuitBreakerLimitSizeKiloBytes); + long circuitBreakerUnsetSizeKiloBytes = (long) ((KNNSettings.getCircuitBreakerUnsetPercentage() / 100) + * circuitBreakerLimitSizeKiloBytes); /** * Unset capacityReached flag if currentSizeBytes is less than circuitBreakerUnsetSizeBytes */ @@ -76,7 +75,7 @@ public void initialize(ThreadPool threadPool, ClusterService clusterService, Cli if (KNNSettings.isCircuitBreakerTriggered() && clusterService.state().nodes().isLocalNodeElectedMaster()) { KNNStatsRequest knnStatsRequest = new KNNStatsRequest(KNNStatsConfig.KNN_STATS.keySet()); knnStatsRequest.addStat(StatNames.CACHE_CAPACITY_REACHED.getName()); - knnStatsRequest.timeout(new TimeValue(1000*10)); // 10 second timeout + knnStatsRequest.timeout(new TimeValue(1000 * 10)); // 10 second timeout try { KNNStatsResponse knnStatsResponse = client.execute(KNNStatsAction.INSTANCE, knnStatsRequest).get(); @@ -90,11 +89,16 @@ public void initialize(ThreadPool threadPool, ClusterService clusterService, Cli } if (!nodesAtMaxCapacity.isEmpty()) { - logger.info("[KNN] knn.circuit_breaker.triggered stays set. Nodes at max cache capacity: " - + String.join(",", nodesAtMaxCapacity) + "."); + logger.info( + "[KNN] knn.circuit_breaker.triggered stays set. Nodes at max cache capacity: " + + String.join(",", nodesAtMaxCapacity) + + "." + ); } else { - logger.info("[KNN] Cache capacity below 75% of the circuit breaker limit for all nodes." + - " Unsetting knn.circuit_breaker.triggered flag."); + logger.info( + "[KNN] Cache capacity below 75% of the circuit breaker limit for all nodes." + + " Unsetting knn.circuit_breaker.triggered flag." + ); KNNSettings.state().updateCircuitBreakerSettings(false); } } catch (Exception e) { diff --git a/src/main/java/org/opensearch/knn/index/KNNMethod.java b/src/main/java/org/opensearch/knn/index/KNNMethod.java index 79ae9c19c..da2d9c455 100644 --- a/src/main/java/org/opensearch/knn/index/KNNMethod.java +++ b/src/main/java/org/opensearch/knn/index/KNNMethod.java @@ -79,8 +79,13 @@ public Set getSpaces() { public ValidationException validate(KNNMethodContext knnMethodContext) { List errorMessages = new ArrayList<>(); if (!containsSpace(knnMethodContext.getSpaceType())) { - errorMessages.add(String.format("\"%s\" configuration does not support space type: " + - "\"%s\".", this.methodComponent.getName(), knnMethodContext.getSpaceType().getValue())); + errorMessages.add( + String.format( + "\"%s\" configuration does not support space type: " + "\"%s\".", + this.methodComponent.getName(), + knnMethodContext.getSpaceType().getValue() + ) + ); } ValidationException methodValidation = methodComponent.validate(knnMethodContext.getMethodComponent()); @@ -88,7 +93,7 @@ public ValidationException validate(KNNMethodContext knnMethodContext) { errorMessages.addAll(methodValidation.validationErrors()); } - if(errorMessages.isEmpty()) { + if (errorMessages.isEmpty()) { return null; } @@ -130,7 +135,6 @@ public Map getAsMap(KNNMethodContext knnMethodContext) { return parameterMap; } - /** * Builder for KNNMethod */ @@ -160,7 +164,7 @@ private Builder(MethodComponent methodComponent) { * @param spaceTypes to be added * @return Builder */ - public Builder addSpaces(SpaceType ...spaceTypes) { + public Builder addSpaces(SpaceType... spaceTypes) { spaces.addAll(Arrays.asList(spaceTypes)); return this; } diff --git a/src/main/java/org/opensearch/knn/index/KNNQuery.java b/src/main/java/org/opensearch/knn/index/KNNQuery.java index f3bf23aee..7306fc3a1 100644 --- a/src/main/java/org/opensearch/knn/index/KNNQuery.java +++ b/src/main/java/org/opensearch/knn/index/KNNQuery.java @@ -41,7 +41,9 @@ public int getK() { return this.k; } - public String getIndexName() { return this.indexName; } + public String getIndexName() { + return this.indexName; + } /** * Constructs Weight implementation for this query @@ -71,8 +73,7 @@ public int hashCode() { @Override public boolean equals(Object other) { - return sameClassAs(other) && - equalsTo(getClass().cast(other)); + return sameClassAs(other) && equalsTo(getClass().cast(other)); } private boolean equalsTo(KNNQuery other) { diff --git a/src/main/java/org/opensearch/knn/index/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/KNNQueryBuilder.java index 750f87e3a..d573b9194 100644 --- a/src/main/java/org/opensearch/knn/index/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/KNNQueryBuilder.java @@ -132,12 +132,16 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep } else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { queryName = parser.text(); } else { - throw new ParsingException(parser.getTokenLocation(), - "[" + NAME + "] query does not support [" + currentFieldName + "]"); + throw new ParsingException( + parser.getTokenLocation(), + "[" + NAME + "] query does not support [" + currentFieldName + "]" + ); } } else { - throw new ParsingException(parser.getTokenLocation(), - "[" + NAME + "] unknown token [" + token + "] after [" + currentFieldName + "]"); + throw new ParsingException( + parser.getTokenLocation(), + "[" + NAME + "] unknown token [" + token + "] after [" + currentFieldName + "]" + ); } } } else { @@ -218,8 +222,9 @@ protected Query doToQuery(QueryShardContext context) throws IOException { } if (dimension != vector.length) { - throw new IllegalArgumentException("Query vector has invalid dimension: " + vector.length + - ". Dimension should be: " + dimension); + throw new IllegalArgumentException( + "Query vector has invalid dimension: " + vector.length + ". Dimension should be: " + dimension + ); } return new KNNQuery(this.fieldName, vector, k, context.index().getName()); @@ -227,9 +232,7 @@ protected Query doToQuery(QueryShardContext context) throws IOException { @Override protected boolean doEquals(KNNQueryBuilder other) { - return Objects.equals(fieldName, other.fieldName) && - Objects.equals(vector, other.vector) && - Objects.equals(k, other.k); + return Objects.equals(fieldName, other.fieldName) && Objects.equals(vector, other.vector) && Objects.equals(k, other.k); } @Override diff --git a/src/main/java/org/opensearch/knn/index/KNNScorer.java b/src/main/java/org/opensearch/knn/index/KNNScorer.java index e40b5596b..edef5fdd4 100644 --- a/src/main/java/org/opensearch/knn/index/KNNScorer.java +++ b/src/main/java/org/opensearch/knn/index/KNNScorer.java @@ -48,8 +48,7 @@ public float getMaxScore(int upTo) throws IOException { public float score() { assert docID() != DocIdSetIterator.NO_MORE_DOCS; Float score = scores.get(docID()); - if (score == null) - throw new RuntimeException("Null score for the docID: " + docID()); + if (score == null) throw new RuntimeException("Null score for the docID: " + docID()); return score; } @@ -58,4 +57,3 @@ public int docID() { return docIdsIter.docID(); } } - diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index 2f3ec91e3..e0e70310f 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -90,11 +90,13 @@ public class KNNSettings { * Settings Definition */ - public static final Setting INDEX_KNN_SPACE_TYPE = Setting.simpleString(KNN_SPACE_TYPE, - INDEX_KNN_DEFAULT_SPACE_TYPE, - new SpaceTypeValidator(), - IndexScope, - Setting.Property.Deprecated); + public static final Setting INDEX_KNN_SPACE_TYPE = Setting.simpleString( + KNN_SPACE_TYPE, + INDEX_KNN_DEFAULT_SPACE_TYPE, + new SpaceTypeValidator(), + IndexScope, + Setting.Property.Deprecated + ); /** * M - the number of bi-directional links created for every new element during construction. @@ -102,76 +104,87 @@ public class KNNSettings { * dimensionality and/or high recall, while low M work better for datasets with low intrinsic dimensionality and/or low recalls. * The parameter also determines the algorithm's memory consumption, which is roughly M * 8-10 bytes per stored element. */ - public static final Setting INDEX_KNN_ALGO_PARAM_M_SETTING = Setting.intSetting(KNN_ALGO_PARAM_M, - INDEX_KNN_DEFAULT_ALGO_PARAM_M, - 2, - IndexScope, - Setting.Property.Deprecated); + public static final Setting INDEX_KNN_ALGO_PARAM_M_SETTING = Setting.intSetting( + KNN_ALGO_PARAM_M, + INDEX_KNN_DEFAULT_ALGO_PARAM_M, + 2, + IndexScope, + Setting.Property.Deprecated + ); /** * ef or efSearch - the size of the dynamic list for the nearest neighbors (used during the search). * Higher ef leads to more accurate but slower search. ef cannot be set lower than the number of queried nearest neighbors k. * The value ef can be anything between k and the size of the dataset. */ - public static final Setting INDEX_KNN_ALGO_PARAM_EF_SEARCH_SETTING = Setting.intSetting(KNN_ALGO_PARAM_EF_SEARCH, - INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH, - 2, - IndexScope, - Dynamic); + public static final Setting INDEX_KNN_ALGO_PARAM_EF_SEARCH_SETTING = Setting.intSetting( + KNN_ALGO_PARAM_EF_SEARCH, + INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH, + 2, + IndexScope, + Dynamic + ); /** * ef_constrution - the parameter has the same meaning as ef, but controls the index_time/index_accuracy. * Bigger ef_construction leads to longer construction(more indexing time), but better index quality. */ - public static final Setting INDEX_KNN_ALGO_PARAM_EF_CONSTRUCTION_SETTING = Setting.intSetting(KNN_ALGO_PARAM_EF_CONSTRUCTION, - INDEX_KNN_DEFAULT_ALGO_PARAM_EF_CONSTRUCTION, - 2, - IndexScope, - Setting.Property.Deprecated); + public static final Setting INDEX_KNN_ALGO_PARAM_EF_CONSTRUCTION_SETTING = Setting.intSetting( + KNN_ALGO_PARAM_EF_CONSTRUCTION, + INDEX_KNN_DEFAULT_ALGO_PARAM_EF_CONSTRUCTION, + 2, + IndexScope, + Setting.Property.Deprecated + ); public static final Setting MODEL_INDEX_NUMBER_OF_SHARDS_SETTING = Setting.intSetting( - MODEL_INDEX_NUMBER_OF_SHARDS, - 1, - 1, - Setting.Property.NodeScope, - Setting.Property.Dynamic); + MODEL_INDEX_NUMBER_OF_SHARDS, + 1, + 1, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); public static final Setting MODEL_INDEX_NUMBER_OF_REPLICAS_SETTING = Setting.intSetting( - MODEL_INDEX_NUMBER_OF_REPLICAS, - 1, - 0, - Setting.Property.NodeScope, - Setting.Property.Dynamic); + MODEL_INDEX_NUMBER_OF_REPLICAS, + 1, + 0, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); public static final Setting MODEL_CACHE_SIZE_LIMIT_SETTING = new Setting<>( - MODEL_CACHE_SIZE_LIMIT, - percentageAsString(KNN_DEFAULT_MODEL_CACHE_SIZE_LIMIT_PERCENTAGE), - (s) -> { - ByteSizeValue userDefinedLimit = parseBytesSizeValueOrHeapRatio(s, MODEL_CACHE_SIZE_LIMIT); - - // parseBytesSizeValueOrHeapRatio will make sure that the value entered falls between 0 and 100% of the - // JVM heap. However, we want the maximum percentage of the heap to be much smaller. So, we add - // some additional validation here before returning - ByteSizeValue jvmHeapSize = JvmInfo.jvmInfo().getMem().getHeapMax(); - if ((userDefinedLimit.getKbFrac() / jvmHeapSize.getKbFrac()) > percentageAsFraction(KNN_MAX_MODEL_CACHE_SIZE_LIMIT_PERCENTAGE)) { - throw new OpenSearchParseException("{} ({} KB) cannot exceed {}% of the heap ({} KB).", - MODEL_CACHE_SIZE_LIMIT, - userDefinedLimit.getKb(), - KNN_MAX_MODEL_CACHE_SIZE_LIMIT_PERCENTAGE, - jvmHeapSize.getKb()); - } + MODEL_CACHE_SIZE_LIMIT, + percentageAsString(KNN_DEFAULT_MODEL_CACHE_SIZE_LIMIT_PERCENTAGE), + (s) -> { + ByteSizeValue userDefinedLimit = parseBytesSizeValueOrHeapRatio(s, MODEL_CACHE_SIZE_LIMIT); + + // parseBytesSizeValueOrHeapRatio will make sure that the value entered falls between 0 and 100% of the + // JVM heap. However, we want the maximum percentage of the heap to be much smaller. So, we add + // some additional validation here before returning + ByteSizeValue jvmHeapSize = JvmInfo.jvmInfo().getMem().getHeapMax(); + if ((userDefinedLimit.getKbFrac() / jvmHeapSize.getKbFrac()) > percentageAsFraction( + KNN_MAX_MODEL_CACHE_SIZE_LIMIT_PERCENTAGE + )) { + throw new OpenSearchParseException( + "{} ({} KB) cannot exceed {}% of the heap ({} KB).", + MODEL_CACHE_SIZE_LIMIT, + userDefinedLimit.getKb(), + KNN_MAX_MODEL_CACHE_SIZE_LIMIT_PERCENTAGE, + jvmHeapSize.getKb() + ); + } - return userDefinedLimit; - }, - Setting.Property.NodeScope, - Setting.Property.Dynamic + return userDefinedLimit; + }, + Setting.Property.NodeScope, + Setting.Property.Dynamic ); /** * This setting identifies KNN index. */ - public static final Setting IS_KNN_INDEX_SETTING = Setting.boolSetting(KNN_INDEX, false, IndexScope); - + public static final Setting IS_KNN_INDEX_SETTING = Setting.boolSetting(KNN_INDEX, false, IndexScope); /** * index_thread_quantity - the parameter specifies how many threads the nms library should use to create the graph. @@ -180,29 +193,34 @@ public class KNNSettings { * this could lead to NUM_CORES^2 threads running and could lead to 100% CPU utilization. This setting allows users to * configure number of threads for graph construction. */ - public static final Setting KNN_ALGO_PARAM_INDEX_THREAD_QTY_SETTING = Setting.intSetting(KNN_ALGO_PARAM_INDEX_THREAD_QTY, - KNN_DEFAULT_ALGO_PARAM_INDEX_THREAD_QTY, - 1, - INDEX_THREAD_QTY_MAX, - NodeScope, - Dynamic); - - public static final Setting KNN_CIRCUIT_BREAKER_TRIGGERED_SETTING = Setting.boolSetting(KNN_CIRCUIT_BREAKER_TRIGGERED, - false, - NodeScope, - Dynamic); - - public static final Setting KNN_CIRCUIT_BREAKER_UNSET_PERCENTAGE_SETTING = Setting.doubleSetting( - KNN_CIRCUIT_BREAKER_UNSET_PERCENTAGE, - KNN_DEFAULT_CIRCUIT_BREAKER_UNSET_PERCENTAGE, - 0, - 100, - NodeScope, - Dynamic); + public static final Setting KNN_ALGO_PARAM_INDEX_THREAD_QTY_SETTING = Setting.intSetting( + KNN_ALGO_PARAM_INDEX_THREAD_QTY, + KNN_DEFAULT_ALGO_PARAM_INDEX_THREAD_QTY, + 1, + INDEX_THREAD_QTY_MAX, + NodeScope, + Dynamic + ); + + public static final Setting KNN_CIRCUIT_BREAKER_TRIGGERED_SETTING = Setting.boolSetting( + KNN_CIRCUIT_BREAKER_TRIGGERED, + false, + NodeScope, + Dynamic + ); + + public static final Setting KNN_CIRCUIT_BREAKER_UNSET_PERCENTAGE_SETTING = Setting.doubleSetting( + KNN_CIRCUIT_BREAKER_UNSET_PERCENTAGE, + KNN_DEFAULT_CIRCUIT_BREAKER_UNSET_PERCENTAGE, + 0, + 100, + NodeScope, + Dynamic + ); /** * Dynamic settings */ - public static Map> dynamicCacheSettings = new HashMap>() { + public static Map> dynamicCacheSettings = new HashMap>() { { /** * KNN plugin enable/disable setting @@ -212,17 +230,20 @@ public class KNNSettings { /** * Weight circuit breaker settings */ - put(KNN_MEMORY_CIRCUIT_BREAKER_ENABLED, Setting.boolSetting(KNN_MEMORY_CIRCUIT_BREAKER_ENABLED,true, - NodeScope, Dynamic)); - put(KNN_MEMORY_CIRCUIT_BREAKER_LIMIT, knnMemoryCircuitBreakerSetting(KNN_MEMORY_CIRCUIT_BREAKER_LIMIT, "50%", - NodeScope, Dynamic)); + put(KNN_MEMORY_CIRCUIT_BREAKER_ENABLED, Setting.boolSetting(KNN_MEMORY_CIRCUIT_BREAKER_ENABLED, true, NodeScope, Dynamic)); + put( + KNN_MEMORY_CIRCUIT_BREAKER_LIMIT, + knnMemoryCircuitBreakerSetting(KNN_MEMORY_CIRCUIT_BREAKER_LIMIT, "50%", NodeScope, Dynamic) + ); /** * Cache expiry time settings */ put(KNN_CACHE_ITEM_EXPIRY_ENABLED, Setting.boolSetting(KNN_CACHE_ITEM_EXPIRY_ENABLED, false, NodeScope, Dynamic)); - put(KNN_CACHE_ITEM_EXPIRY_TIME_MINUTES, Setting.positiveTimeSetting(KNN_CACHE_ITEM_EXPIRY_TIME_MINUTES, - TimeValue.timeValueHours(3), NodeScope, Dynamic)); + put( + KNN_CACHE_ITEM_EXPIRY_TIME_MINUTES, + Setting.positiveTimeSetting(KNN_CACHE_ITEM_EXPIRY_TIME_MINUTES, TimeValue.timeValueHours(3), NodeScope, Dynamic) + ); } }; @@ -243,38 +264,33 @@ public static synchronized KNNSettings state() { public void setSettingsUpdateConsumers() { for (Setting setting : dynamicCacheSettings.values()) { - clusterService.getClusterSettings().addSettingsUpdateConsumer( - setting, - newVal -> { - logger.debug("The value of setting [{}] changed to [{}]", setting.getKey(), newVal); - latestSettings.put(setting.getKey(), newVal); - - // Rebuild the cache with updated limit - NativeMemoryCacheManager.getInstance().rebuildCache(); - }); + clusterService.getClusterSettings().addSettingsUpdateConsumer(setting, newVal -> { + logger.debug("The value of setting [{}] changed to [{}]", setting.getKey(), newVal); + latestSettings.put(setting.getKey(), newVal); + + // Rebuild the cache with updated limit + NativeMemoryCacheManager.getInstance().rebuildCache(); + }); } /** * We do not have to rebuild the cache for below settings */ - clusterService.getClusterSettings().addSettingsUpdateConsumer( + clusterService.getClusterSettings() + .addSettingsUpdateConsumer( KNN_CIRCUIT_BREAKER_TRIGGERED_SETTING, - newVal -> { - latestSettings.put(KNN_CIRCUIT_BREAKER_TRIGGERED, newVal); - } - ); - clusterService.getClusterSettings().addSettingsUpdateConsumer( + newVal -> { latestSettings.put(KNN_CIRCUIT_BREAKER_TRIGGERED, newVal); } + ); + clusterService.getClusterSettings() + .addSettingsUpdateConsumer( KNN_CIRCUIT_BREAKER_UNSET_PERCENTAGE_SETTING, - newVal -> { - latestSettings.put(KNN_CIRCUIT_BREAKER_UNSET_PERCENTAGE, newVal); - } - ); - clusterService.getClusterSettings().addSettingsUpdateConsumer( + newVal -> { latestSettings.put(KNN_CIRCUIT_BREAKER_UNSET_PERCENTAGE, newVal); } + ); + clusterService.getClusterSettings() + .addSettingsUpdateConsumer( KNN_ALGO_PARAM_INDEX_THREAD_QTY_SETTING, - newVal -> { - latestSettings.put(KNN_ALGO_PARAM_INDEX_THREAD_QTY, newVal); - } - ); + newVal -> { latestSettings.put(KNN_ALGO_PARAM_INDEX_THREAD_QTY, newVal); } + ); } /** @@ -310,19 +326,20 @@ public Setting getSetting(String key) { } public List> getSettings() { - List> settings = Arrays.asList(INDEX_KNN_SPACE_TYPE, - INDEX_KNN_ALGO_PARAM_M_SETTING, - INDEX_KNN_ALGO_PARAM_EF_CONSTRUCTION_SETTING, - INDEX_KNN_ALGO_PARAM_EF_SEARCH_SETTING, - KNN_ALGO_PARAM_INDEX_THREAD_QTY_SETTING, - KNN_CIRCUIT_BREAKER_TRIGGERED_SETTING, - KNN_CIRCUIT_BREAKER_UNSET_PERCENTAGE_SETTING, - IS_KNN_INDEX_SETTING, - MODEL_INDEX_NUMBER_OF_SHARDS_SETTING, - MODEL_INDEX_NUMBER_OF_REPLICAS_SETTING, - MODEL_CACHE_SIZE_LIMIT_SETTING); - return Stream.concat(settings.stream(), dynamicCacheSettings.values().stream()) - .collect(Collectors.toList()); + List> settings = Arrays.asList( + INDEX_KNN_SPACE_TYPE, + INDEX_KNN_ALGO_PARAM_M_SETTING, + INDEX_KNN_ALGO_PARAM_EF_CONSTRUCTION_SETTING, + INDEX_KNN_ALGO_PARAM_EF_SEARCH_SETTING, + KNN_ALGO_PARAM_INDEX_THREAD_QTY_SETTING, + KNN_CIRCUIT_BREAKER_TRIGGERED_SETTING, + KNN_CIRCUIT_BREAKER_UNSET_PERCENTAGE_SETTING, + IS_KNN_INDEX_SETTING, + MODEL_INDEX_NUMBER_OF_SHARDS_SETTING, + MODEL_INDEX_NUMBER_OF_REPLICAS_SETTING, + MODEL_CACHE_SIZE_LIMIT_SETTING + ); + return Stream.concat(settings.stream(), dynamicCacheSettings.values().stream()).collect(Collectors.toList()); } public static boolean isKNNPluginEnabled() { @@ -390,22 +407,25 @@ public static ByteSizeValue parseknnMemoryCircuitBreakerValue(String sValue, Str */ public synchronized void updateCircuitBreakerSettings(boolean flag) { ClusterUpdateSettingsRequest clusterUpdateSettingsRequest = new ClusterUpdateSettingsRequest(); - Settings circuitBreakerSettings = Settings.builder() - .put(KNNSettings.KNN_CIRCUIT_BREAKER_TRIGGERED, flag) - .build(); + Settings circuitBreakerSettings = Settings.builder().put(KNNSettings.KNN_CIRCUIT_BREAKER_TRIGGERED, flag).build(); clusterUpdateSettingsRequest.persistentSettings(circuitBreakerSettings); - client.admin().cluster().updateSettings(clusterUpdateSettingsRequest, - new ActionListener() { + client.admin().cluster().updateSettings(clusterUpdateSettingsRequest, new ActionListener() { @Override public void onResponse(ClusterUpdateSettingsResponse clusterUpdateSettingsResponse) { - logger.debug("Cluster setting {}, acknowledged: {} ", - clusterUpdateSettingsRequest.persistentSettings(), - clusterUpdateSettingsResponse.isAcknowledged()); + logger.debug( + "Cluster setting {}, acknowledged: {} ", + clusterUpdateSettingsRequest.persistentSettings(), + clusterUpdateSettingsResponse.isAcknowledged() + ); } + @Override public void onFailure(Exception e) { - logger.info("Exception while updating circuit breaker setting {} to {}", - clusterUpdateSettingsRequest.persistentSettings(), e.getMessage()); + logger.info( + "Exception while updating circuit breaker setting {} to {}", + clusterUpdateSettingsRequest.persistentSettings(), + e.getMessage() + ); } }); } @@ -425,14 +445,15 @@ public static int getEfSearchParam(String index) { * @return spaceType name in KNN plugin */ public static String getSpaceType(String index) { - return KNNSettings.state().clusterService.state().getMetadata() - .index(index).getSettings().get(KNN_SPACE_TYPE, SpaceType.DEFAULT.getValue()); + return KNNSettings.state().clusterService.state() + .getMetadata() + .index(index) + .getSettings() + .get(KNN_SPACE_TYPE, SpaceType.DEFAULT.getValue()); } public static int getIndexSettingValue(String index, String settingName, int defaultValue) { - return KNNSettings.state().clusterService.state().getMetadata() - .index(index).getSettings() - .getAsInt(settingName, defaultValue); + return KNNSettings.state().clusterService.state().getMetadata().index(index).getSettings().getAsInt(settingName, defaultValue); } public void setClusterService(ClusterService clusterService) { @@ -441,7 +462,8 @@ public void setClusterService(ClusterService clusterService) { static class SpaceTypeValidator implements Setting.Validator { - @Override public void validate(String value) { + @Override + public void validate(String value) { try { SpaceType.getSpace(value); } catch (IllegalArgumentException ex) { @@ -451,14 +473,12 @@ static class SpaceTypeValidator implements Setting.Validator { } public void onIndexModule(IndexModule module) { - module.addSettingsUpdateConsumer( - INDEX_KNN_ALGO_PARAM_EF_SEARCH_SETTING, - newVal -> { - logger.debug("The value of [KNN] setting [{}] changed to [{}]", KNN_ALGO_PARAM_EF_SEARCH, newVal); - latestSettings.put(KNN_ALGO_PARAM_EF_SEARCH, newVal); - // TODO: replace cache-rebuild with index reload into the cache - NativeMemoryCacheManager.getInstance().rebuildCache(); - }); + module.addSettingsUpdateConsumer(INDEX_KNN_ALGO_PARAM_EF_SEARCH_SETTING, newVal -> { + logger.debug("The value of [KNN] setting [{}] changed to [{}]", KNN_ALGO_PARAM_EF_SEARCH, newVal); + latestSettings.put(KNN_ALGO_PARAM_EF_SEARCH, newVal); + // TODO: replace cache-rebuild with index reload into the cache + NativeMemoryCacheManager.getInstance().rebuildCache(); + }); } private static String percentageAsString(Integer percentage) { diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java b/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java index 7b0cc229e..5f522e3de 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java @@ -40,12 +40,12 @@ public ScriptDocValues getScriptValues() { BinaryDocValues values = DocValues.getBinary(reader, fieldName); return new KNNVectorScriptDocValues(values, fieldName); } catch (IOException e) { - throw new IllegalStateException("Cannot load doc values for knn vector field: "+fieldName, e); + throw new IllegalStateException("Cannot load doc values for knn vector field: " + fieldName, e); } } @Override public SortedBinaryDocValues getBytesValues() { - throw new UnsupportedOperationException("knn vector field '"+ fieldName + "' doesn't support sorting"); + throw new UnsupportedOperationException("knn vector field '" + fieldName + "' doesn't support sorting"); } } diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/KNNVectorFieldMapper.java index f6af53e1a..d3af2101c 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorFieldMapper.java @@ -91,69 +91,63 @@ private static KNNVectorFieldMapper toType(FieldMapper in) { public static class Builder extends ParametrizedFieldMapper.Builder { protected Boolean ignoreMalformed; - protected final Parameter stored = Parameter.boolParam("store", false, - m -> toType(m).stored, false); - protected final Parameter hasDocValues = Parameter.boolParam("doc_values", false, - m -> toType(m).hasDocValues, true); - protected final Parameter dimension = new Parameter<>(KNNConstants.DIMENSION, false, - () -> -1, - (n, c, o) -> { - if (o == null) { - throw new IllegalArgumentException("Dimension cannot be null"); - } - int value = XContentMapValues.nodeIntegerValue(o); - if (value > MAX_DIMENSION) { - throw new IllegalArgumentException("Dimension value cannot be greater than " + - MAX_DIMENSION + " for vector: " + name); - } - - if (value <= 0) { - throw new IllegalArgumentException("Dimension value must be greater than 0 " + - "for vector: " + name); - } - return value; - }, m -> toType(m).dimension); + protected final Parameter stored = Parameter.boolParam("store", false, m -> toType(m).stored, false); + protected final Parameter hasDocValues = Parameter.boolParam("doc_values", false, m -> toType(m).hasDocValues, true); + protected final Parameter dimension = new Parameter<>(KNNConstants.DIMENSION, false, () -> -1, (n, c, o) -> { + if (o == null) { + throw new IllegalArgumentException("Dimension cannot be null"); + } + int value = XContentMapValues.nodeIntegerValue(o); + if (value > MAX_DIMENSION) { + throw new IllegalArgumentException("Dimension value cannot be greater than " + MAX_DIMENSION + " for vector: " + name); + } + + if (value <= 0) { + throw new IllegalArgumentException("Dimension value must be greater than 0 " + "for vector: " + name); + } + return value; + }, m -> toType(m).dimension); /** * modelId provides a way for a user to generate the underlying library indices from an already serialized * model template index. If this parameter is set, it will take precedence. This parameter is only relevant for * library indices that require training. */ - protected final Parameter modelId = Parameter.stringParam(KNNConstants.MODEL_ID, false, - m -> toType(m).modelId, null); + protected final Parameter modelId = Parameter.stringParam(KNNConstants.MODEL_ID, false, m -> toType(m).modelId, null); /** * knnMethodContext parameter allows a user to define their k-NN library index configuration. Defaults to an L2 * hnsw default engine index without any parameters set */ - protected final Parameter knnMethodContext = new Parameter<>(KNN_METHOD, false, - () -> null, - (n, c, o) -> KNNMethodContext.parse(o), m -> toType(m).knnMethod) - .setSerializer(((b, n, v) ->{ - b.startObject(n); - v.toXContent(b, ToXContent.EMPTY_PARAMS); - b.endObject(); - }), m -> m.getMethodComponent().getName()) - .setValidator(v -> { - if (v == null) - return; - - ValidationException validationException = null; - if (v.isTrainingRequired()){ - validationException = new ValidationException(); - validationException.addValidationError(String.format("\"%s\" requires training.", KNN_METHOD)); - } - - ValidationException methodValidation = v.validate(); - if (methodValidation != null) { - validationException = validationException == null ? new ValidationException() : validationException; - validationException.addValidationErrors(methodValidation.validationErrors()); - } - - if (validationException != null) { - throw validationException; - } - }); + protected final Parameter knnMethodContext = new Parameter<>( + KNN_METHOD, + false, + () -> null, + (n, c, o) -> KNNMethodContext.parse(o), + m -> toType(m).knnMethod + ).setSerializer(((b, n, v) -> { + b.startObject(n); + v.toXContent(b, ToXContent.EMPTY_PARAMS); + b.endObject(); + }), m -> m.getMethodComponent().getName()).setValidator(v -> { + if (v == null) return; + + ValidationException validationException = null; + if (v.isTrainingRequired()) { + validationException = new ValidationException(); + validationException.addValidationError(String.format("\"%s\" requires training.", KNN_METHOD)); + } + + ValidationException methodValidation = v.validate(); + if (methodValidation != null) { + validationException = validationException == null ? new ValidationException() : validationException; + validationException.addValidationErrors(methodValidation.validationErrors()); + } + + if (validationException != null) { + throw validationException; + } + }); protected final Parameter> meta = Parameter.metaParam(); @@ -211,14 +205,16 @@ public KNNVectorFieldMapper build(BuilderContext context) { KNNMethodContext knnMethodContext = this.knnMethodContext.getValue(); if (knnMethodContext != null) { - return new MethodFieldMapper(name, - new KNNVectorFieldType(buildFullName(context), meta.getValue(), dimension.getValue()), - multiFieldsBuilder.build(this, context), - copyTo.build(), - ignoreMalformed(context), - stored.get(), - hasDocValues.get(), - knnMethodContext); + return new MethodFieldMapper( + name, + new KNNVectorFieldType(buildFullName(context), meta.getValue(), dimension.getValue()), + multiFieldsBuilder.build(this, context), + copyTo.build(), + ignoreMalformed(context), + stored.get(), + hasDocValues.get(), + knnMethodContext + ); } String modelIdAsString = this.modelId.get(); @@ -229,15 +225,16 @@ public KNNVectorFieldMapper build(BuilderContext context) { // safely. So, we are unable to validate the model. The model gets validated during ingestion. return new ModelFieldMapper( - name, - new KNNVectorFieldType(buildFullName(context), meta.getValue(), -1, modelIdAsString), - multiFieldsBuilder.build(this, context), - copyTo.build(), - ignoreMalformed(context), - stored.get(), - hasDocValues.get(), - modelDao, - modelIdAsString); + name, + new KNNVectorFieldType(buildFullName(context), meta.getValue(), -1, modelIdAsString), + multiFieldsBuilder.build(this, context), + copyTo.build(), + ignoreMalformed(context), + stored.get(), + hasDocValues.get(), + modelDao, + modelIdAsString + ); } // Build legacy @@ -253,16 +250,18 @@ public KNNVectorFieldMapper build(BuilderContext context) { this.efConstruction = LegacyFieldMapper.getEfConstruction(context.indexSettings()); } - return new LegacyFieldMapper(name, - new KNNVectorFieldType(buildFullName(context), meta.getValue(), dimension.getValue()), - multiFieldsBuilder.build(this, context), - copyTo.build(), - ignoreMalformed(context), - stored.get(), - hasDocValues.get(), - spaceType, - m, - efConstruction); + return new LegacyFieldMapper( + name, + new KNNVectorFieldType(buildFullName(context), meta.getValue(), dimension.getValue()), + multiFieldsBuilder.build(this, context), + copyTo.build(), + ignoreMalformed(context), + stored.get(), + hasDocValues.get(), + spaceType, + m, + efConstruction + ); } } @@ -277,17 +276,16 @@ public TypeParser(Supplier modelDaoSupplier) { } @Override - public Mapper.Builder parse(String name, Map node, ParserContext parserContext) - throws MapperParsingException { + public Mapper.Builder parse(String name, Map node, ParserContext parserContext) throws MapperParsingException { Builder builder = new KNNVectorFieldMapper.Builder(name, modelDaoSupplier.get()); builder.parse(name, parserContext, node); - // All parsing + // All parsing // is done before any mappers are built. Therefore, validation should be done during parsing // so that it can fail early. if (builder.knnMethodContext.get() != null && builder.modelId.get() != null) { - throw new IllegalArgumentException("Method and model can not be both specified in the mapping: " - + name); + throw new IllegalArgumentException("Method and model can not be both specified in the mapping: " + name); } // Dimension should not be null unless modelId is used @@ -331,8 +329,10 @@ public Query existsQuery(QueryShardContext context) { @Override public Query termQuery(Object value, QueryShardContext context) { - throw new QueryShardException(context, "KNN vector do not support exact searching, use KNN queries " + - "instead: [" + name() + "]"); + throw new QueryShardException( + context, + "KNN vector do not support exact searching, use KNN queries " + "instead: [" + name() + "]" + ); } public int getDimension() { @@ -362,10 +362,16 @@ public IndexFieldData.Builder fielddataBuilder(String fullyQualifiedIndexName, S protected KNNMethodContext knnMethod; protected String modelId; - public KNNVectorFieldMapper(String simpleName, KNNVectorFieldType mappedFieldType, MultiFields multiFields, - CopyTo copyTo, Explicit ignoreMalformed, boolean stored, - boolean hasDocValues) { - super(simpleName, mappedFieldType, multiFields, copyTo); + public KNNVectorFieldMapper( + String simpleName, + KNNVectorFieldType mappedFieldType, + MultiFields multiFields, + CopyTo copyTo, + Explicit ignoreMalformed, + boolean stored, + boolean hasDocValues + ) { + super(simpleName, mappedFieldType, multiFields, copyTo); this.ignoreMalformed = ignoreMalformed; this.stored = stored; this.hasDocValues = hasDocValues; @@ -389,13 +395,13 @@ protected void parseCreateField(ParseContext context) throws IOException { protected void parseCreateField(ParseContext context, int dimension) throws IOException { if (!KNNSettings.isKNNPluginEnabled()) { - throw new IllegalStateException("KNN plugin is disabled. To enable " + - "update knn.plugin.enabled setting to true"); + throw new IllegalStateException("KNN plugin is disabled. To enable " + "update knn.plugin.enabled setting to true"); } if (KNNSettings.isCircuitBreakerTriggered()) { - throw new IllegalStateException("Indexing knn vector fields is rejected as circuit breaker triggered." + - " Check _opendistro/_knn/stats for detailed state"); + throw new IllegalStateException( + "Indexing knn vector fields is rejected as circuit breaker triggered." + " Check _opendistro/_knn/stats for detailed state" + ); } context.path().add(simpleName()); @@ -438,8 +444,7 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx } if (dimension != vector.size()) { - String errorMessage = String.format("Vector dimension mismatch. Expected: %d, Given: %d", dimension, - vector.size()); + String errorMessage = String.format("Vector dimension mismatch. Expected: %d, Given: %d", dimension, vector.size()); throw new IllegalArgumentException(errorMessage); } @@ -498,7 +503,7 @@ public static class Defaults { FIELD_TYPE.setTokenized(false); FIELD_TYPE.setIndexOptions(IndexOptions.NONE); FIELD_TYPE.setDocValuesType(DocValuesType.BINARY); - FIELD_TYPE.putAttribute(KNN_FIELD, "true"); //This attribute helps to determine knn field type + FIELD_TYPE.putAttribute(KNN_FIELD, "true"); // This attribute helps to determine knn field type FIELD_TYPE.freeze(); } } @@ -512,9 +517,18 @@ protected static class LegacyFieldMapper extends KNNVectorFieldMapper { protected String m; protected String efConstruction; - private LegacyFieldMapper(String simpleName, KNNVectorFieldType mappedFieldType, MultiFields multiFields, - CopyTo copyTo, Explicit ignoreMalformed, boolean stored, - boolean hasDocValues, String spaceType, String m, String efConstruction) { + private LegacyFieldMapper( + String simpleName, + KNNVectorFieldType mappedFieldType, + MultiFields multiFields, + CopyTo copyTo, + Explicit ignoreMalformed, + boolean stored, + boolean hasDocValues, + String spaceType, + String m, + String efConstruction + ) { super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues); this.spaceType = spaceType; @@ -540,33 +554,45 @@ public ParametrizedFieldMapper.Builder getMergeBuilder() { } static String getSpaceType(Settings indexSettings) { - String spaceType = indexSettings.get(KNNSettings.INDEX_KNN_SPACE_TYPE.getKey()); + String spaceType = indexSettings.get(KNNSettings.INDEX_KNN_SPACE_TYPE.getKey()); if (spaceType == null) { - logger.info("[KNN] The setting \"" + METHOD_PARAMETER_SPACE_TYPE + "\" was not set for the index. " + - "Likely caused by recent version upgrade. Setting the setting to the default value=" - + KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE); + logger.info( + "[KNN] The setting \"" + + METHOD_PARAMETER_SPACE_TYPE + + "\" was not set for the index. " + + "Likely caused by recent version upgrade. Setting the setting to the default value=" + + KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE + ); return KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE; } return spaceType; } static String getM(Settings indexSettings) { - String m = indexSettings.get(KNNSettings.INDEX_KNN_ALGO_PARAM_M_SETTING.getKey()); + String m = indexSettings.get(KNNSettings.INDEX_KNN_ALGO_PARAM_M_SETTING.getKey()); if (m == null) { - logger.info("[KNN] The setting \"" + HNSW_ALGO_M + "\" was not set for the index. " + - "Likely caused by recent version upgrade. Setting the setting to the default value=" - + KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_M); + logger.info( + "[KNN] The setting \"" + + HNSW_ALGO_M + + "\" was not set for the index. " + + "Likely caused by recent version upgrade. Setting the setting to the default value=" + + KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_M + ); return String.valueOf(KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_M); } return m; } static String getEfConstruction(Settings indexSettings) { - String efConstruction = indexSettings.get(KNNSettings.INDEX_KNN_ALGO_PARAM_EF_CONSTRUCTION_SETTING.getKey()); + String efConstruction = indexSettings.get(KNNSettings.INDEX_KNN_ALGO_PARAM_EF_CONSTRUCTION_SETTING.getKey()); if (efConstruction == null) { - logger.info("[KNN] The setting \"" + HNSW_ALGO_EF_CONSTRUCTION + "\" was not set for" + - " the index. Likely caused by recent version upgrade. Setting the setting to the default value=" - + KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_CONSTRUCTION); + logger.info( + "[KNN] The setting \"" + + HNSW_ALGO_EF_CONSTRUCTION + + "\" was not set for" + + " the index. Likely caused by recent version upgrade. Setting the setting to the default value=" + + KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_CONSTRUCTION + ); return String.valueOf(KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_CONSTRUCTION); } return efConstruction; @@ -578,9 +604,16 @@ static String getEfConstruction(Settings indexSettings) { */ protected static class MethodFieldMapper extends KNNVectorFieldMapper { - private MethodFieldMapper(String simpleName, KNNVectorFieldType mappedFieldType, MultiFields multiFields, - CopyTo copyTo, Explicit ignoreMalformed, boolean stored, - boolean hasDocValues, KNNMethodContext knnMethodContext) { + private MethodFieldMapper( + String simpleName, + KNNVectorFieldType mappedFieldType, + MultiFields multiFields, + CopyTo copyTo, + Explicit ignoreMalformed, + boolean stored, + boolean hasDocValues, + KNNMethodContext knnMethodContext + ) { super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues); @@ -595,8 +628,10 @@ private MethodFieldMapper(String simpleName, KNNVectorFieldType mappedFieldType, this.fieldType.putAttribute(KNN_ENGINE, knnEngine.getName()); try { - this.fieldType.putAttribute(PARAMETERS, Strings.toString(XContentFactory.jsonBuilder() - .map(knnEngine.getMethodAsMap(knnMethodContext)))); + this.fieldType.putAttribute( + PARAMETERS, + Strings.toString(XContentFactory.jsonBuilder().map(knnEngine.getMethodAsMap(knnMethodContext))) + ); } catch (IOException ioe) { throw new RuntimeException("Unable to create KNNVectorFieldMapper: " + ioe); } @@ -610,9 +645,17 @@ private MethodFieldMapper(String simpleName, KNNVectorFieldType mappedFieldType, */ protected static class ModelFieldMapper extends KNNVectorFieldMapper { - private ModelFieldMapper(String simpleName, KNNVectorFieldType mappedFieldType, MultiFields multiFields, - CopyTo copyTo, Explicit ignoreMalformed, boolean stored, - boolean hasDocValues, ModelDao modelDao, String modelId) { + private ModelFieldMapper( + String simpleName, + KNNVectorFieldType mappedFieldType, + MultiFields multiFields, + CopyTo copyTo, + Explicit ignoreMalformed, + boolean stored, + boolean hasDocValues, + ModelDao modelDao, + String modelId + ) { super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues); this.modelId = modelId; @@ -631,10 +674,17 @@ protected void parseCreateField(ParseContext context) throws IOException { ModelMetadata modelMetadata = this.modelDao.getMetadata(modelId); if (modelMetadata == null) { - throw new IllegalStateException("Model \"" + modelId + "\" from " + - context.mapperService().index().getName() + "'s mapping does not exist. Because the " + - "\"" + MODEL_ID + "\" parameter is not updateable, this index will need to " + - "be recreated with a valid model."); + throw new IllegalStateException( + "Model \"" + + modelId + + "\" from " + + context.mapperService().index().getName() + + "'s mapping does not exist. Because the " + + "\"" + + MODEL_ID + + "\" parameter is not updateable, this index will need to " + + "be recreated with a valid model." + ); } parseCreateField(context, modelMetadata.getDimension()); diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorIndexFieldData.java b/src/main/java/org/opensearch/knn/index/KNNVectorIndexFieldData.java index 20fe1add1..367cfae53 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorIndexFieldData.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorIndexFieldData.java @@ -54,19 +54,23 @@ public SortField sortField(Object missingValue, MultiValueMode sortMode, XFieldC @Override public BucketedSort newBucketedSort( - BigArrays bigArrays, Object missingValue, - MultiValueMode sortMode, XFieldComparatorSource.Nested nested, - SortOrder sortOrder, DocValueFormat format, int bucketSize, BucketedSort.ExtraData extra) { + BigArrays bigArrays, + Object missingValue, + MultiValueMode sortMode, + XFieldComparatorSource.Nested nested, + SortOrder sortOrder, + DocValueFormat format, + int bucketSize, + BucketedSort.ExtraData extra + ) { throw new UnsupportedOperationException("knn vector field doesn't support this operation"); } - public static class Builder implements IndexFieldData.Builder { private final String name; private final ValuesSourceType valuesSourceType; - public Builder(String name, ValuesSourceType valuesSourceType) { this.name = name; this.valuesSourceType = valuesSourceType; diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java index 69563b06f..0c8240dd4 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java @@ -38,9 +38,12 @@ public void setNextDocId(int docId) throws IOException { public float[] getValue() { if (!docExists) { String errorMessage = String.format( - "One of the document doesn't have a value for field '%s'. " + - "This can be avoided by checking if a document has a value for the field or not " + - "by doc['%s'].size() == 0 ? 0 : {your script}",fieldName,fieldName); + "One of the document doesn't have a value for field '%s'. " + + "This can be avoided by checking if a document has a value for the field or not " + + "by doc['%s'].size() == 0 ? 0 : {your script}", + fieldName, + fieldName + ); throw new IllegalStateException(errorMessage); } try { diff --git a/src/main/java/org/opensearch/knn/index/KNNWeight.java b/src/main/java/org/opensearch/knn/index/KNNWeight.java index 6c8c41cb9..909e7222f 100644 --- a/src/main/java/org/opensearch/knn/index/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/KNNWeight.java @@ -5,7 +5,6 @@ package org.opensearch.knn.index; -import com.google.common.collect.ImmutableMap; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.memory.NativeMemoryAllocation; @@ -36,7 +35,6 @@ import java.nio.file.Path; import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; @@ -78,113 +76,120 @@ public Explanation explain(LeafReaderContext context, int doc) { } @Override - public void extractTerms(Set terms) { - } + public void extractTerms(Set terms) {} @Override public Scorer scorer(LeafReaderContext context) throws IOException { - SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(context.reader()); - String directory = ((FSDirectory) FilterDirectory.unwrap(reader.directory())).getDirectory().toString(); - - FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); - - if (fieldInfo == null) { - logger.debug("[KNN] Field info not found for {}:{}", knnQuery.getField(), - reader.getSegmentName()); - return null; - } - - KNNEngine knnEngine; - SpaceType spaceType; - - // Check if a modelId exists. If so, the space type and engine will need to be picked up from the model's - // metadata. - String modelId = fieldInfo.getAttribute(MODEL_ID); - if (modelId != null) { - ModelMetadata modelMetadata = modelDao.getMetadata(modelId); - if (modelMetadata == null) { - throw new RuntimeException("Model \"" + modelId + "\" does not exist."); - } - - knnEngine = modelMetadata.getKnnEngine(); - spaceType = modelMetadata.getSpaceType(); - } else { - String engineName = fieldInfo.attributes().getOrDefault(KNN_ENGINE, KNNEngine.NMSLIB.getName()); - knnEngine = KNNEngine.getEngine(engineName); - String spaceTypeName = fieldInfo.attributes().getOrDefault(SPACE_TYPE, SpaceType.L2.getValue()); - spaceType = SpaceType.getSpace(spaceTypeName); - } - - /* - * In case of compound file, extension would be + c otherwise - */ - String engineExtension = reader.getSegmentInfo().info.getUseCompoundFile() - ? knnEngine.getExtension() + KNNConstants.COMPOUND_EXTENSION : knnEngine.getExtension(); - String engineSuffix = knnQuery.getField() + engineExtension; - List engineFiles = reader.getSegmentInfo().files().stream() - .filter(fileName -> fileName.endsWith(engineSuffix)) - .collect(Collectors.toList()); - - if(engineFiles.isEmpty()) { - logger.debug("[KNN] No engine index found for field {} for segment {}", - knnQuery.getField(), reader.getSegmentName()); - return null; + SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(context.reader()); + String directory = ((FSDirectory) FilterDirectory.unwrap(reader.directory())).getDirectory().toString(); + + FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); + + if (fieldInfo == null) { + logger.debug("[KNN] Field info not found for {}:{}", knnQuery.getField(), reader.getSegmentName()); + return null; + } + + KNNEngine knnEngine; + SpaceType spaceType; + + // Check if a modelId exists. If so, the space type and engine will need to be picked up from the model's + // metadata. + String modelId = fieldInfo.getAttribute(MODEL_ID); + if (modelId != null) { + ModelMetadata modelMetadata = modelDao.getMetadata(modelId); + if (modelMetadata == null) { + throw new RuntimeException("Model \"" + modelId + "\" does not exist."); } - Path indexPath = PathUtils.get(directory, engineFiles.get(0)); - final KNNQueryResult[] results; - KNNCounter.GRAPH_QUERY_REQUESTS.increment(); - - // We need to first get index allocation - NativeMemoryAllocation indexAllocation; - try { - indexAllocation = nativeMemoryCacheManager.get( - new NativeMemoryEntryContext.IndexEntryContext( - indexPath.toString(), - NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(), - getParametersAtLoading(spaceType, knnEngine, knnQuery.getIndexName()), - knnQuery.getIndexName() - ), true); - } catch (ExecutionException e) { - GRAPH_QUERY_ERRORS.increment(); - throw new RuntimeException(e); + knnEngine = modelMetadata.getKnnEngine(); + spaceType = modelMetadata.getSpaceType(); + } else { + String engineName = fieldInfo.attributes().getOrDefault(KNN_ENGINE, KNNEngine.NMSLIB.getName()); + knnEngine = KNNEngine.getEngine(engineName); + String spaceTypeName = fieldInfo.attributes().getOrDefault(SPACE_TYPE, SpaceType.L2.getValue()); + spaceType = SpaceType.getSpace(spaceTypeName); + } + + /* + * In case of compound file, extension would be + c otherwise + */ + String engineExtension = reader.getSegmentInfo().info.getUseCompoundFile() + ? knnEngine.getExtension() + KNNConstants.COMPOUND_EXTENSION + : knnEngine.getExtension(); + String engineSuffix = knnQuery.getField() + engineExtension; + List engineFiles = reader.getSegmentInfo() + .files() + .stream() + .filter(fileName -> fileName.endsWith(engineSuffix)) + .collect(Collectors.toList()); + + if (engineFiles.isEmpty()) { + logger.debug("[KNN] No engine index found for field {} for segment {}", knnQuery.getField(), reader.getSegmentName()); + return null; + } + + Path indexPath = PathUtils.get(directory, engineFiles.get(0)); + final KNNQueryResult[] results; + KNNCounter.GRAPH_QUERY_REQUESTS.increment(); + + // We need to first get index allocation + NativeMemoryAllocation indexAllocation; + try { + indexAllocation = nativeMemoryCacheManager.get( + new NativeMemoryEntryContext.IndexEntryContext( + indexPath.toString(), + NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(), + getParametersAtLoading(spaceType, knnEngine, knnQuery.getIndexName()), + knnQuery.getIndexName() + ), + true + ); + } catch (ExecutionException e) { + GRAPH_QUERY_ERRORS.increment(); + throw new RuntimeException(e); + } + + // Now that we have the allocation, we need to readLock it + indexAllocation.readLock(); + + try { + if (indexAllocation.isClosed()) { + throw new RuntimeException("Index has already been closed"); } - // Now that we have the allocation, we need to readLock it - indexAllocation.readLock(); - - try { - if (indexAllocation.isClosed()) { - throw new RuntimeException("Index has already been closed"); - } - - results = JNIService.queryIndex(indexAllocation.getMemoryAddress(), knnQuery.getQueryVector(), knnQuery.getK(), knnEngine.getName()); - } catch (Exception e) { - GRAPH_QUERY_ERRORS.increment(); - throw new RuntimeException(e); - } finally { - indexAllocation.readUnlock(); - } - - /* - * Scores represent the distance of the documents with respect to given query vector. - * Lesser the score, the closer the document is to the query vector. - * Since by default results are retrieved in the descending order of scores, to get the nearest - * neighbors we are inverting the scores. - */ - if (results.length == 0) { - logger.debug("[KNN] Query yielded 0 results"); - return null; - } - - Map scores = Arrays.stream(results).collect( - Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), spaceType))); - int maxDoc = Collections.max(scores.keySet()) + 1; - DocIdSetBuilder docIdSetBuilder = new DocIdSetBuilder(maxDoc); - DocIdSetBuilder.BulkAdder setAdder = docIdSetBuilder.grow(maxDoc); - Arrays.stream(results).forEach(result -> setAdder.add(result.getId())); - DocIdSetIterator docIdSetIter = docIdSetBuilder.build().iterator(); - return new KNNScorer(this, docIdSetIter, scores, boost); + results = JNIService.queryIndex( + indexAllocation.getMemoryAddress(), + knnQuery.getQueryVector(), + knnQuery.getK(), + knnEngine.getName() + ); + } catch (Exception e) { + GRAPH_QUERY_ERRORS.increment(); + throw new RuntimeException(e); + } finally { + indexAllocation.readUnlock(); + } + + /* + * Scores represent the distance of the documents with respect to given query vector. + * Lesser the score, the closer the document is to the query vector. + * Since by default results are retrieved in the descending order of scores, to get the nearest + * neighbors we are inverting the scores. + */ + if (results.length == 0) { + logger.debug("[KNN] Query yielded 0 results"); + return null; + } + + Map scores = Arrays.stream(results) + .collect(Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), spaceType))); + int maxDoc = Collections.max(scores.keySet()) + 1; + DocIdSetBuilder docIdSetBuilder = new DocIdSetBuilder(maxDoc); + DocIdSetBuilder.BulkAdder setAdder = docIdSetBuilder.grow(maxDoc); + Arrays.stream(results).forEach(result -> setAdder.add(result.getId())); + DocIdSetIterator docIdSetIter = docIdSetBuilder.build().iterator(); + return new KNNScorer(this, docIdSetIter, scores, boost); } @Override @@ -193,9 +198,7 @@ public boolean isCacheable(LeafReaderContext context) { } public static float normalizeScore(float score) { - if (score >= 0) - return 1 / (1 + score); + if (score >= 0) return 1 / (1 + score); return -score + 1; } } - diff --git a/src/main/java/org/opensearch/knn/index/MethodComponent.java b/src/main/java/org/opensearch/knn/index/MethodComponent.java index 432f8dfee..d7957d74f 100644 --- a/src/main/java/org/opensearch/knn/index/MethodComponent.java +++ b/src/main/java/org/opensearch/knn/index/MethodComponent.java @@ -63,7 +63,6 @@ public Map> getParameters() { return parameters; } - /** * Parse methodComponentContext into a map that the library can use to configure the method * @@ -107,7 +106,7 @@ public ValidationException validate(MethodComponentContext methodComponentContex } } - if(errorMessages.isEmpty()) { + if (errorMessages.isEmpty()) { return null; } @@ -116,7 +115,6 @@ public ValidationException validate(MethodComponentContext methodComponentContex return validationException; } - /** * gets requiresTraining value * @@ -170,23 +168,23 @@ public boolean isTrainingRequired(MethodComponentContext methodComponentContext) public int estimateOverheadInKB(MethodComponentContext methodComponentContext, int dimension) { // Assume we have the following KNNMethodContext: // "method": { - // "name":"METHOD_1", - // "engine":"faiss", - // "space_type": "l2", - // "parameters":{ - // "P1":1, - // "P2":{ - // "name":"METHOD_2", - // "parameters":{ - // "P3":2 - // } - // } - // } + // "name":"METHOD_1", + // "engine":"faiss", + // "space_type": "l2", + // "parameters":{ + // "P1":1, + // "P2":{ + // "name":"METHOD_2", + // "parameters":{ + // "P3":2 + // } + // } + // } // } // // First, we get the overhead estimate of METHOD_1. Then, we add the overhead // estimate for METHOD_2 by looping over parameters of METHOD_1. - + long size = overheadInKBEstimator.apply(this, methodComponentContext, dimension); // Check if any of the parameters add overhead @@ -274,7 +272,7 @@ public Builder setMapGenerator(BiFunction getParameterMapWithDefaultsAdded(MethodComponentContext methodComponentContext, - MethodComponent methodComponent) { + public static Map getParameterMapWithDefaultsAdded( + MethodComponentContext methodComponentContext, + MethodComponent methodComponent + ) { Map parametersWithDefaultsMap = new HashMap<>(); Map userProvidedParametersMap = methodComponentContext.getParameters(); for (Parameter parameter : methodComponent.getParameters().values()) { diff --git a/src/main/java/org/opensearch/knn/index/Parameter.java b/src/main/java/org/opensearch/knn/index/Parameter.java index 1bb06fd98..4d69e7838 100644 --- a/src/main/java/org/opensearch/knn/index/Parameter.java +++ b/src/main/java/org/opensearch/knn/index/Parameter.java @@ -70,8 +70,7 @@ public T getDefaultValue() { * Integer method parameter */ public static class IntegerParameter extends Parameter { - public IntegerParameter(String name, Integer defaultValue, Predicate validator) - { + public IntegerParameter(String name, Integer defaultValue, Predicate validator) { super(name, defaultValue, validator); } @@ -80,21 +79,22 @@ public ValidationException validate(Object value) { ValidationException validationException = null; if (!(value instanceof Integer)) { validationException = new ValidationException(); - validationException.addValidationError(String.format("Value not of type Integer for Integer " + - "parameter \"%s\".", getName())); + validationException.addValidationError( + String.format("Value not of type Integer for Integer " + "parameter \"%s\".", getName()) + ); return validationException; } if (!validator.test((Integer) value)) { validationException = new ValidationException(); - validationException.addValidationError(String.format("Parameter validation failed for Integer " + - "parameter \"%s\".", getName())); + validationException.addValidationError( + String.format("Parameter validation failed for Integer " + "parameter \"%s\".", getName()) + ); } return validationException; } } - /** * MethodContext parameter. Some methods require sub-methods in order to implement some kind of functionality. For * instance, faiss methods can contain an encoder along side the approximate nearest neighbor function to compress @@ -111,9 +111,11 @@ public static class MethodComponentContextParameter extends Parameter methodComponents) { + public MethodComponentContextParameter( + String name, + MethodComponentContext defaultValue, + Map methodComponents + ) { super(name, defaultValue, methodComponentContext -> { if (!methodComponents.containsKey(methodComponentContext.getName())) { return false; @@ -129,16 +131,18 @@ public ValidationException validate(Object value) { ValidationException validationException = null; if (!(value instanceof MethodComponentContext)) { validationException = new ValidationException(); - validationException.addValidationError(String.format("Value not of type MethodComponentContext for" + - " MethodComponentContext parameter \"%s\".", getName())); + validationException.addValidationError( + String.format("Value not of type MethodComponentContext for" + " MethodComponentContext parameter \"%s\".", getName()) + ); return validationException; } if (!validator.test((MethodComponentContext) value)) { validationException = new ValidationException(); validationException.addValidationError("Parameter validation failed."); - validationException.addValidationError(String.format("Parameter validation failed for " + - "MethodComponentContext parameter \"%s\".", getName())); + validationException.addValidationError( + String.format("Parameter validation failed for " + "MethodComponentContext parameter \"%s\".", getName()) + ); } return validationException; diff --git a/src/main/java/org/opensearch/knn/index/codec/BinaryDocValuesSub.java b/src/main/java/org/opensearch/knn/index/codec/BinaryDocValuesSub.java index 67a476c8d..73fef0349 100644 --- a/src/main/java/org/opensearch/knn/index/codec/BinaryDocValuesSub.java +++ b/src/main/java/org/opensearch/knn/index/codec/BinaryDocValuesSub.java @@ -35,4 +35,4 @@ public BinaryDocValuesSub(MergeState.DocMap docMap, BinaryDocValues values) { public int nextDoc() throws IOException { return values.nextDoc(); } -} \ No newline at end of file +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80Codec.java index 59655762e..de9affe86 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80Codec.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80Codec.java @@ -52,8 +52,7 @@ public DocValuesFormat getDocValuesFormatForField(String field) { * This function returns the Lucene80 Codec. */ public Codec getDelegatee() { - if (lucene80Codec == null) - lucene80Codec = Codec.forName(LUCENE_80); + if (lucene80Codec == null) lucene80Codec = Codec.forName(LUCENE_80); return lucene80Codec; } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormat.java index ca581b6ae..18e55bf1a 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormat.java @@ -28,8 +28,7 @@ public class KNN80CompoundFormat extends CompoundFormat { private final Logger logger = LogManager.getLogger(KNN80CompoundFormat.class); - public KNN80CompoundFormat() { - } + public KNN80CompoundFormat() {} @Override public CompoundDirectory getCompoundReader(Directory dir, SegmentInfo si, IOContext context) throws IOException { @@ -44,14 +43,12 @@ public void write(Directory dir, SegmentInfo si, IOContext context) throws IOExc Codec.getDefault().compoundFormat().write(dir, si, context); } - private void writeEngineFiles(Directory dir, SegmentInfo si, IOContext context, String engineExtension) - throws IOException { + private void writeEngineFiles(Directory dir, SegmentInfo si, IOContext context, String engineExtension) throws IOException { /* * If engine file present, remove it from the compounding file list to avoid header/footer checks * and create a new compounding file format with extension engine + c. */ - Set engineFiles = si.files().stream().filter(file -> file.endsWith(engineExtension)) - .collect(Collectors.toSet()); + Set engineFiles = si.files().stream().filter(file -> file.endsWith(engineExtension)).collect(Collectors.toSet()); Set segmentFiles = new HashSet<>(si.files()); diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java index d4bc662c5..807e7c9da 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java @@ -94,10 +94,14 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer) KNNEngine knnEngine = model.getModelMetadata().getKnnEngine(); - engineFileName = buildEngineFileName(state.segmentInfo.name, knnEngine.getLatestBuildVersion(), - field.name, knnEngine.getExtension()); - indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), - engineFileName).toString(); + engineFileName = buildEngineFileName( + state.segmentInfo.name, + knnEngine.getLatestBuildVersion(), + field.name, + knnEngine.getExtension() + ); + indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), engineFileName) + .toString(); tmpEngineFileName = engineFileName + TEMP_SUFFIX; String tempIndexPath = indexPath + TEMP_SUFFIX; @@ -112,10 +116,14 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer) String engineName = field.attributes().getOrDefault(KNNConstants.KNN_ENGINE, KNNEngine.DEFAULT.getName()); KNNEngine knnEngine = KNNEngine.getEngine(engineName); - engineFileName = buildEngineFileName(state.segmentInfo.name, knnEngine.getLatestBuildVersion(), - field.name, knnEngine.getExtension()); - indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), - engineFileName).toString(); + engineFileName = buildEngineFileName( + state.segmentInfo.name, + knnEngine.getLatestBuildVersion(), + field.name, + knnEngine.getExtension() + ); + indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), engineFileName) + .toString(); tmpEngineFileName = engineFileName + TEMP_SUFFIX; String tempIndexPath = indexPath + TEMP_SUFFIX; @@ -131,10 +139,12 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer) * existing file will miss calculating checksum for the serialized graph * bytes and result in index corruption issues. */ - //TODO: I think this can be refactored to avoid this copy and then write + // TODO: I think this can be refactored to avoid this copy and then write // https://github.com/opendistro-for-elasticsearch/k-NN/issues/330 - try (IndexInput is = state.directory.openInput(tmpEngineFileName, state.context); - IndexOutput os = state.directory.createOutput(engineFileName, state.context)) { + try ( + IndexInput is = state.directory.openInput(tmpEngineFileName, state.context); + IndexOutput os = state.directory.createOutput(engineFileName, state.context) + ) { os.copyBytes(is, is.length()); CodecUtil.writeFooter(os); } catch (Exception ex) { @@ -146,29 +156,26 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer) } } - private void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, - String indexPath) { - Map parameters = ImmutableMap.of(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue( - KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); - AccessController.doPrivileged( - (PrivilegedAction) () -> { - JNIService.createIndexFromTemplate(pair.docs, pair.vectors, indexPath, model, parameters, - knnEngine.getName()); - return null; - } + private void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) { + Map parameters = ImmutableMap.of( + KNNConstants.INDEX_THREAD_QTY, + KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY) ); + AccessController.doPrivileged((PrivilegedAction) () -> { + JNIService.createIndexFromTemplate(pair.docs, pair.vectors, indexPath, model, parameters, knnEngine.getName()); + return null; + }); } - private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pair, KNNEngine knnEngine, - String indexPath) throws IOException { + private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) + throws IOException { Map parameters = new HashMap<>(); Map fieldAttributes = fieldInfo.attributes(); String parametersString = fieldAttributes.get(KNNConstants.PARAMETERS); // parametersString will be null when legacy mapper is used if (parametersString == null) { - parameters.put(KNNConstants.SPACE_TYPE, fieldAttributes.getOrDefault(KNNConstants.SPACE_TYPE, - SpaceType.DEFAULT.getValue())); + parameters.put(KNNConstants.SPACE_TYPE, fieldAttributes.getOrDefault(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue())); String efConstruction = fieldAttributes.get(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION); Map algoParams = new HashMap<>(); @@ -183,22 +190,20 @@ private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pa parameters.put(PARAMETERS, algoParams); } else { parameters.putAll( - XContentFactory.xContent(XContentType.JSON).createParser(NamedXContentRegistry.EMPTY, - DeprecationHandler.THROW_UNSUPPORTED_OPERATION, parametersString).map() + XContentFactory.xContent(XContentType.JSON) + .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, parametersString) + .map() ); } // Used to determine how many threads to use when indexing - parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue( - KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); + parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); // Pass the path for the nms library to save the file - AccessController.doPrivileged( - (PrivilegedAction) () -> { - JNIService.createIndex(pair.docs, pair.vectors, indexPath, parameters, knnEngine.getName()); - return null; - } - ); + AccessController.doPrivileged((PrivilegedAction) () -> { + JNIService.createIndex(pair.docs, pair.vectors, indexPath, parameters, knnEngine.getName()); + return null; + }); } /** diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN84Codec/KNN84Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN84Codec/KNN84Codec.java index a50f396a4..97a5fe029 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN84Codec/KNN84Codec.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN84Codec/KNN84Codec.java @@ -42,7 +42,7 @@ public KNN84Codec() { super(KNN_84); // Note that DocValuesFormat can use old Codec's DocValuesFormat. For instance Lucene84 uses Lucene80 // DocValuesFormat. Refer to defaultDVFormat in LuceneXXCodec.java to find out which version it uses - this.docValuesFormat = new KNN80DocValuesFormat(); + this.docValuesFormat = new KNN80DocValuesFormat(); this.perFieldDocValuesFormat = new PerFieldDocValuesFormat() { @Override public DocValuesFormat getDocValuesFormatForField(String field) { @@ -56,8 +56,7 @@ public DocValuesFormat getDocValuesFormatForField(String field) { * This function returns the Lucene84 Codec. */ public Codec getDelegatee() { - if (lucene84Codec == null) - lucene84Codec = Codec.forName(LUCENE_84); + if (lucene84Codec == null) lucene84Codec = Codec.forName(LUCENE_84); return lucene84Codec; } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN86Codec/KNN86Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN86Codec/KNN86Codec.java index 70c75e09b..ad0fbff06 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN86Codec/KNN86Codec.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN86Codec/KNN86Codec.java @@ -43,7 +43,7 @@ public KNN86Codec() { super(KNN_86); // Note that DocValuesFormat can use old Codec's DocValuesFormat. For instance Lucene84 uses Lucene80 // DocValuesFormat. Refer to defaultDVFormat in LuceneXXCodec.java to find out which version it uses - this.docValuesFormat = new KNN80DocValuesFormat(); + this.docValuesFormat = new KNN80DocValuesFormat(); this.perFieldDocValuesFormat = new PerFieldDocValuesFormat() { @Override public DocValuesFormat getDocValuesFormatForField(String field) { @@ -57,8 +57,7 @@ public DocValuesFormat getDocValuesFormatForField(String field) { * This function returns the Lucene84 Codec. */ public Codec getDelegatee() { - if (lucene86Codec == null) - lucene86Codec = Codec.forName(LUCENE_86); + if (lucene86Codec == null) lucene86Codec = Codec.forName(LUCENE_86); return lucene86Codec; } @@ -73,7 +72,6 @@ public DocValuesFormat docValuesFormat() { * approach of manually overriding. */ - public void setPostingsFormat(PostingsFormat postingsFormat) { this.postingsFormat = postingsFormat; } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN87Codec/KNN87Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN87Codec/KNN87Codec.java index 6e2e897e0..3001acb70 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN87Codec/KNN87Codec.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN87Codec/KNN87Codec.java @@ -43,7 +43,7 @@ public KNN87Codec() { super(KNN_87); // Note that DocValuesFormat can use old Codec's DocValuesFormat. For instance Lucene84 uses Lucene80 // DocValuesFormat. Refer to defaultDVFormat in LuceneXXCodec.java to find out which version it uses - this.docValuesFormat = new KNN80DocValuesFormat(); + this.docValuesFormat = new KNN80DocValuesFormat(); this.perFieldDocValuesFormat = new PerFieldDocValuesFormat() { @Override public DocValuesFormat getDocValuesFormatForField(String field) { @@ -57,8 +57,7 @@ public DocValuesFormat getDocValuesFormatForField(String field) { * This function returns the Codec. */ public Codec getDelegatee() { - if (lucene87Codec == null) - lucene87Codec = Codec.forName(LUCENE_87); + if (lucene87Codec == null) lucene87Codec = Codec.forName(LUCENE_87); return lucene87Codec; } @@ -73,7 +72,6 @@ public DocValuesFormat docValuesFormat() { * approach of manually overriding. */ - public void setPostingsFormat(PostingsFormat postingsFormat) { this.postingsFormat = postingsFormat; } diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNVectorAsArraySerializer.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNVectorAsArraySerializer.java index f3a0803cb..751a229db 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/KNNVectorAsArraySerializer.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNVectorAsArraySerializer.java @@ -18,8 +18,10 @@ public class KNNVectorAsArraySerializer implements KNNVectorSerializer { @Override public byte[] floatToByteArray(float[] input) { byte[] bytes; - try (ByteArrayOutputStream byteStream = new ByteArrayOutputStream(); - ObjectOutputStream objectStream = new ObjectOutputStream(byteStream);) { + try ( + ByteArrayOutputStream byteStream = new ByteArrayOutputStream(); + ObjectOutputStream objectStream = new ObjectOutputStream(byteStream); + ) { objectStream.writeObject(input); bytes = byteStream.toByteArray(); } catch (IOException e) { diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNVectorSerializer.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNVectorSerializer.java index 75fb4f4a4..35f1ff5be 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/KNNVectorSerializer.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNVectorSerializer.java @@ -6,7 +6,6 @@ package org.opensearch.knn.index.codec.util; import java.io.ByteArrayInputStream; -import java.io.IOException; /** * Interface abstracts the vector serializer object that is responsible for serialization and de-serialization of k-NN vector diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNVectorSerializerFactory.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNVectorSerializerFactory.java index bdb131f60..f02da0949 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/KNNVectorSerializerFactory.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNVectorSerializerFactory.java @@ -21,8 +21,10 @@ */ public class KNNVectorSerializerFactory { private static Map VECTOR_SERIALIZER_BY_TYPE = ImmutableMap.of( - ARRAY, new KNNVectorAsArraySerializer(), - COLLECTION_OF_FLOATS, new KNNVectorAsCollectionOfFloatsSerializer() + ARRAY, + new KNNVectorAsArraySerializer(), + COLLECTION_OF_FLOATS, + new KNNVectorAsCollectionOfFloatsSerializer() ); private static final int ARRAY_HEADER_OFFSET = 27; @@ -34,13 +36,12 @@ COLLECTION_OF_FLOATS, new KNNVectorAsCollectionOfFloatsSerializer() * here. */ private static final byte[] SERIALIZATION_PROTOCOL_HEADER_PREFIX = new byte[] { - highByte(ObjectStreamConstants.STREAM_MAGIC), - lowByte(ObjectStreamConstants.STREAM_MAGIC), - highByte(ObjectStreamConstants.STREAM_VERSION), - lowByte(ObjectStreamConstants.STREAM_VERSION), - ObjectStreamConstants.TC_ARRAY, - ObjectStreamConstants.TC_CLASSDESC - }; + highByte(ObjectStreamConstants.STREAM_MAGIC), + lowByte(ObjectStreamConstants.STREAM_MAGIC), + highByte(ObjectStreamConstants.STREAM_VERSION), + lowByte(ObjectStreamConstants.STREAM_VERSION), + ObjectStreamConstants.TC_ARRAY, + ObjectStreamConstants.TC_CLASSDESC }; public static KNNVectorSerializer getSerializerBySerializationMode(final SerializationMode serializationMode) { return VECTOR_SERIALIZER_BY_TYPE.getOrDefault(serializationMode, new KNNVectorAsCollectionOfFloatsSerializer()); @@ -63,7 +64,7 @@ private static SerializationMode serializerModeFromStream(ByteArrayInputStream b final byte[] byteArray = new byte[SERIALIZATION_PROTOCOL_HEADER_PREFIX.length]; byteStream.read(byteArray, 0, SERIALIZATION_PROTOCOL_HEADER_PREFIX.length); byteStream.reset(); - //checking if stream protocol grammar in header is valid for serialized array + // checking if stream protocol grammar in header is valid for serialized array if (Arrays.equals(SERIALIZATION_PROTOCOL_HEADER_PREFIX, byteArray)) { int numberOfAvailableBytesAfterHeader = numberOfAvailableBytesInStream - ARRAY_HEADER_OFFSET; return getSerializerOrThrowError(numberOfAvailableBytesAfterHeader, ARRAY); @@ -75,16 +76,17 @@ private static SerializationMode getSerializerOrThrowError(int numberOfRemaining if (numberOfRemainingBytes % BYTES_IN_FLOAT == 0) { return serializationMode; } - throw new IllegalArgumentException(String.format("Byte stream cannot be deserialized to array of floats due to invalid length %d", numberOfRemainingBytes)); + throw new IllegalArgumentException( + String.format("Byte stream cannot be deserialized to array of floats due to invalid length %d", numberOfRemainingBytes) + ); } private static byte highByte(short shortValue) { - return (byte) (shortValue>> BITS_IN_ONE_BYTE); + return (byte) (shortValue >> BITS_IN_ONE_BYTE); } private static byte lowByte(short shortValue) { return (byte) shortValue; } - } diff --git a/src/main/java/org/opensearch/knn/index/codec/util/SerializationMode.java b/src/main/java/org/opensearch/knn/index/codec/util/SerializationMode.java index 1a68393a7..1fb82cbfe 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/SerializationMode.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/SerializationMode.java @@ -6,5 +6,6 @@ package org.opensearch.knn.index.codec.util; public enum SerializationMode { - ARRAY, COLLECTION_OF_FLOATS + ARRAY, + COLLECTION_OF_FLOATS } diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java index fa6044f46..9279c816f 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java @@ -102,8 +102,15 @@ class IndexAllocation implements NativeMemoryAllocation { * @param openSearchIndexName Name of OpenSearch index this index is associated with * @param watcherHandle Handle for watching index file */ - IndexAllocation(ExecutorService executorService, long memoryAddress, int size, KNNEngine knnEngine, - String indexPath, String openSearchIndexName, WatcherHandle watcherHandle) { + IndexAllocation( + ExecutorService executorService, + long memoryAddress, + int size, + KNNEngine knnEngine, + String indexPath, + String openSearchIndexName, + WatcherHandle watcherHandle + ) { this.executor = executorService; this.closed = false; this.knnEngine = knnEngine; diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java index 5672b4a94..efdc4fd31 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java @@ -69,18 +69,18 @@ public static synchronized NativeMemoryCacheManager getInstance() { private void initialize() { CacheBuilder cacheBuilder = CacheBuilder.newBuilder() - .recordStats() - .concurrencyLevel(1) - .removalListener(this::onRemoval); + .recordStats() + .concurrencyLevel(1) + .removalListener(this::onRemoval); - if(KNNSettings.state().getSettingValue(KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_ENABLED)) { + if (KNNSettings.state().getSettingValue(KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_ENABLED)) { maxWeight = KNNSettings.getCircuitBreakerLimit().getKb(); cacheBuilder.maximumWeight(maxWeight).weigher((k, v) -> v.getSizeInKB()); } - if(KNNSettings.state().getSettingValue(KNNSettings.KNN_CACHE_ITEM_EXPIRY_ENABLED)) { - long expiryTime = ((TimeValue) KNNSettings.state() - .getSettingValue(KNNSettings.KNN_CACHE_ITEM_EXPIRY_TIME_MINUTES)).getMinutes(); + if (KNNSettings.state().getSettingValue(KNNSettings.KNN_CACHE_ITEM_EXPIRY_ENABLED)) { + long expiryTime = ((TimeValue) KNNSettings.state().getSettingValue(KNNSettings.KNN_CACHE_ITEM_EXPIRY_TIME_MINUTES)) + .getMinutes(); cacheBuilder.expireAfterAccess(expiryTime, TimeUnit.MINUTES); } @@ -95,12 +95,12 @@ private void initialize() { public synchronized void rebuildCache() { logger.info("KNN Cache rebuilding."); - //TODO: Does this really need to be executed with an executor? Also, does invalidateAll really need to be + // TODO: Does this really need to be executed with an executor? Also, does invalidateAll really need to be // called? executor.execute(() -> { cache.invalidateAll(); - initialize(); } - ); + initialize(); + }); } @Override @@ -132,10 +132,12 @@ public Float getCacheSizeAsPercentage() { * @return current size of the cache */ public long getIndicesSizeInKilobytes() { - return cache.asMap().values().stream() - .filter(nativeMemoryAllocation -> nativeMemoryAllocation instanceof NativeMemoryAllocation.IndexAllocation) - .mapToLong(NativeMemoryAllocation::getSizeInKB) - .sum(); + return cache.asMap() + .values() + .stream() + .filter(nativeMemoryAllocation -> nativeMemoryAllocation instanceof NativeMemoryAllocation.IndexAllocation) + .mapToLong(NativeMemoryAllocation::getSizeInKB) + .sum(); } /** @@ -155,11 +157,15 @@ public Float getIndicesSizeAsPercentage() { */ public Long getIndexSizeInKilobytes(final String indexName) { Validate.notNull(indexName, "Index name cannot be null"); - return cache.asMap().values().stream() - .filter(nativeMemoryAllocation -> nativeMemoryAllocation instanceof NativeMemoryAllocation.IndexAllocation) - .filter(indexAllocation -> indexName.equals(((NativeMemoryAllocation.IndexAllocation) indexAllocation).getOpenSearchIndexName())) - .mapToLong(NativeMemoryAllocation::getSizeInKB) - .sum(); + return cache.asMap() + .values() + .stream() + .filter(nativeMemoryAllocation -> nativeMemoryAllocation instanceof NativeMemoryAllocation.IndexAllocation) + .filter( + indexAllocation -> indexName.equals(((NativeMemoryAllocation.IndexAllocation) indexAllocation).getOpenSearchIndexName()) + ) + .mapToLong(NativeMemoryAllocation::getSizeInKB) + .sum(); } /** @@ -180,12 +186,15 @@ public Float getIndexSizeAsPercentage(final String indexName) { */ public long getTrainingSizeInKilobytes() { // Currently, all allocations that are not index allocations will be for training. - return cache.asMap().values().stream() - .filter(nativeMemoryAllocation -> - nativeMemoryAllocation instanceof NativeMemoryAllocation.TrainingDataAllocation || - nativeMemoryAllocation instanceof NativeMemoryAllocation.AnonymousAllocation) - .mapToLong(NativeMemoryAllocation::getSizeInKB) - .sum(); + return cache.asMap() + .values() + .stream() + .filter( + nativeMemoryAllocation -> nativeMemoryAllocation instanceof NativeMemoryAllocation.TrainingDataAllocation + || nativeMemoryAllocation instanceof NativeMemoryAllocation.AnonymousAllocation + ) + .mapToLong(NativeMemoryAllocation::getSizeInKB) + .sum(); } /** @@ -214,12 +223,16 @@ public long getMaxCacheSizeInKilobytes() { */ public int getIndexGraphCount(String indexName) { Validate.notNull(indexName, "Index name cannot be null"); - return Long.valueOf(cache.asMap().values().stream() - .filter(nativeMemoryAllocation -> - nativeMemoryAllocation instanceof NativeMemoryAllocation.IndexAllocation) - .filter(indexAllocation -> indexName.equals(((NativeMemoryAllocation.IndexAllocation) indexAllocation) - .getOpenSearchIndexName())) - .count()).intValue(); + return Long.valueOf( + cache.asMap() + .values() + .stream() + .filter(nativeMemoryAllocation -> nativeMemoryAllocation instanceof NativeMemoryAllocation.IndexAllocation) + .filter( + indexAllocation -> indexName.equals(((NativeMemoryAllocation.IndexAllocation) indexAllocation).getOpenSearchIndexName()) + ) + .count() + ).intValue(); } /** @@ -239,17 +252,22 @@ public CacheStats getCacheStats() { * @return NativeMemoryAllocation associated with nativeMemoryEntryContext * @throws ExecutionException if there is an exception when loading from the cache */ - public NativeMemoryAllocation get(NativeMemoryEntryContext nativeMemoryEntryContext, - boolean isAbleToTriggerEviction) throws ExecutionException { - if (!isAbleToTriggerEviction && - !cache.asMap().containsKey(nativeMemoryEntryContext.getKey()) && - maxWeight - getCacheSizeInKilobytes() - nativeMemoryEntryContext.calculateSizeInKB() <= 0 - ) { + public NativeMemoryAllocation get(NativeMemoryEntryContext nativeMemoryEntryContext, boolean isAbleToTriggerEviction) + throws ExecutionException { + if (!isAbleToTriggerEviction + && !cache.asMap().containsKey(nativeMemoryEntryContext.getKey()) + && maxWeight - getCacheSizeInKilobytes() - nativeMemoryEntryContext.calculateSizeInKB() <= 0) { throw new OutOfNativeMemoryException( - "Entry cannot be loaded into cache because it would not fit. " + - "Entry size: " + nativeMemoryEntryContext.calculateSizeInKB() + " KB " + - "Current Cache Size: " + getCacheSizeInKilobytes() + " KB " + - "Max Cache Size: " + maxWeight); + "Entry cannot be loaded into cache because it would not fit. " + + "Entry size: " + + nativeMemoryEntryContext.calculateSizeInKB() + + " KB " + + "Current Cache Size: " + + getCacheSizeInKilobytes() + + " KB " + + "Max Cache Size: " + + maxWeight + ); } return cache.get(nativeMemoryEntryContext.getKey(), nativeMemoryEntryContext::load); @@ -306,17 +324,14 @@ public Map> getIndicesCacheStats() { Map indexMap = statValues.computeIfAbsent(indexName, name -> new HashMap<>()); indexMap.computeIfAbsent(GRAPH_COUNT, key -> getIndexGraphCount(indexName)); - indexMap.computeIfAbsent(StatNames.GRAPH_MEMORY_USAGE.getName(), key -> - getIndexSizeInKilobytes(indexName)); - indexMap.computeIfAbsent(StatNames.GRAPH_MEMORY_USAGE_PERCENTAGE.getName(), key -> - getIndexSizeAsPercentage(indexName)); + indexMap.computeIfAbsent(StatNames.GRAPH_MEMORY_USAGE.getName(), key -> getIndexSizeInKilobytes(indexName)); + indexMap.computeIfAbsent(StatNames.GRAPH_MEMORY_USAGE_PERCENTAGE.getName(), key -> getIndexSizeAsPercentage(indexName)); } } return statValues; } - private void onRemoval(RemovalNotification removalNotification) { NativeMemoryAllocation nativeMemoryAllocation = removalNotification.getValue(); nativeMemoryAllocation.close(); @@ -326,8 +341,7 @@ private void onRemoval(RemovalNotification remov setCacheCapacityReached(true); } - logger.debug("[KNN] Cache evicted. Key {}, Reason: {}", removalNotification.getKey(), - removalNotification.getCause()); + logger.debug("[KNN] Cache evicted. Key {}, Reason: {}", removalNotification.getKey(), removalNotification.getCause()); } private Float getSizeAsPercentage(long size) { diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryEntryContext.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryEntryContext.java index 68d285c6d..13f8dae10 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryEntryContext.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryEntryContext.java @@ -25,6 +25,7 @@ public abstract class NativeMemoryEntryContext { protected final String key; + /** * Constructor * @@ -71,10 +72,12 @@ public static class IndexEntryContext extends NativeMemoryEntryContext parameters, - String openSearchIndexName) { + public IndexEntryContext( + String indexPath, + NativeMemoryLoadStrategy.IndexLoadStrategy indexLoadStrategy, + Map parameters, + String openSearchIndexName + ) { super(indexPath); this.indexLoadStrategy = indexLoadStrategy; this.openSearchIndexName = openSearchIndexName; @@ -146,13 +149,15 @@ public static class TrainingDataEntryContext extends NativeMemoryEntryContext, Closeable { + class IndexLoadStrategy + implements + NativeMemoryLoadStrategy, + Closeable { private static IndexLoadStrategy INSTANCE; @@ -82,26 +84,26 @@ public void onFileDeleted(Path indexFilePath) { } @Override - public NativeMemoryAllocation.IndexAllocation load(NativeMemoryEntryContext.IndexEntryContext - indexEntryContext) throws IOException { + public NativeMemoryAllocation.IndexAllocation load(NativeMemoryEntryContext.IndexEntryContext indexEntryContext) + throws IOException { Path indexPath = Paths.get(indexEntryContext.getKey()); FileWatcher fileWatcher = new FileWatcher(indexPath); fileWatcher.addListener(indexFileOnDeleteListener); fileWatcher.init(); KNNEngine knnEngine = KNNEngine.getEngineNameFromPath(indexPath.toString()); - long memoryAddress = JNIService.loadIndex(indexPath.toString(), indexEntryContext.getParameters(), - knnEngine.getName()); + long memoryAddress = JNIService.loadIndex(indexPath.toString(), indexEntryContext.getParameters(), knnEngine.getName()); final WatcherHandle watcherHandle = resourceWatcherService.add(fileWatcher); return new NativeMemoryAllocation.IndexAllocation( - executor, - memoryAddress, - indexEntryContext.calculateSizeInKB(), - knnEngine, - indexPath.toString(), - indexEntryContext.getOpenSearchIndexName(), - watcherHandle); + executor, + memoryAddress, + indexEntryContext.calculateSizeInKB(), + knnEngine, + indexPath.toString(), + indexEntryContext.getOpenSearchIndexName(), + watcherHandle + ); } @Override @@ -110,8 +112,10 @@ public void close() { } } - class TrainingLoadStrategy implements NativeMemoryLoadStrategy, Closeable { + class TrainingLoadStrategy + implements + NativeMemoryLoadStrategy, + Closeable { private static TrainingLoadStrategy INSTANCE; @@ -144,11 +148,15 @@ private TrainingLoadStrategy() { } @Override - public NativeMemoryAllocation.TrainingDataAllocation load(NativeMemoryEntryContext.TrainingDataEntryContext - nativeMemoryEntryContext) { + public NativeMemoryAllocation.TrainingDataAllocation load( + NativeMemoryEntryContext.TrainingDataEntryContext nativeMemoryEntryContext + ) { // Generate an empty training data allocation with the appropriate size - NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation = new NativeMemoryAllocation - .TrainingDataAllocation(executor, 0, nativeMemoryEntryContext.calculateSizeInKB()); + NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation = new NativeMemoryAllocation.TrainingDataAllocation( + executor, + 0, + nativeMemoryEntryContext.calculateSizeInKB() + ); // Start loading all training data. Once the data has been loaded, release the lock TrainingDataConsumer trainingDataConsumer = new TrainingDataConsumer(trainingDataAllocation); @@ -156,24 +164,20 @@ public NativeMemoryAllocation.TrainingDataAllocation load(NativeMemoryEntryConte trainingDataAllocation.writeLock(); vectorReader.read( - nativeMemoryEntryContext.getClusterService(), - nativeMemoryEntryContext.getTrainIndexName(), - nativeMemoryEntryContext.getTrainFieldName(), - nativeMemoryEntryContext.getMaxVectorCount(), - nativeMemoryEntryContext.getSearchSize(), - trainingDataConsumer, - ActionListener.wrap( - response -> trainingDataAllocation.writeUnlock(), - ex -> { - // Close unsafe will assume that the caller passes control of the writelock to it. It - // will then handle releasing the write lock once the close operations finish. - trainingDataAllocation.closeUnsafe(); - throw new RuntimeException(ex); - } - ) + nativeMemoryEntryContext.getClusterService(), + nativeMemoryEntryContext.getTrainIndexName(), + nativeMemoryEntryContext.getTrainFieldName(), + nativeMemoryEntryContext.getMaxVectorCount(), + nativeMemoryEntryContext.getSearchSize(), + trainingDataConsumer, + ActionListener.wrap(response -> trainingDataAllocation.writeUnlock(), ex -> { + // Close unsafe will assume that the caller passes control of the writelock to it. It + // will then handle releasing the write lock once the close operations finish. + trainingDataAllocation.closeUnsafe(); + throw new RuntimeException(ex); + }) ); - // The write lock is acquired before the trainingDataAllocation is returned and not released until the // loading has completed. The calling thread will need to obtain a read lock in order to proceed, which // will not be possible until the write lock is released. @@ -186,8 +190,10 @@ public void close() throws IOException { } } - class AnonymousLoadStrategy implements NativeMemoryLoadStrategy, Closeable { + class AnonymousLoadStrategy + implements + NativeMemoryLoadStrategy, + Closeable { private static AnonymousLoadStrategy INSTANCE; diff --git a/src/main/java/org/opensearch/knn/index/util/KNNEngine.java b/src/main/java/org/opensearch/knn/index/util/KNNEngine.java index 8c7023542..365226a01 100644 --- a/src/main/java/org/opensearch/knn/index/util/KNNEngine.java +++ b/src/main/java/org/opensearch/knn/index/util/KNNEngine.java @@ -21,7 +21,6 @@ import static org.opensearch.knn.common.KNNConstants.FAISS_NAME; import static org.opensearch.knn.common.KNNConstants.NMSLIB_NAME; - /** * KNNEngine provides the functionality to validate and transform user defined indices into information that can be * passed to the respective k-NN library's JNI layer. @@ -53,7 +52,7 @@ public enum KNNEngine implements KNNLibrary { * @return KNNEngine corresponding to name */ public static KNNEngine getEngine(String name) { - if (NMSLIB.getName().equalsIgnoreCase(name)){ + if (NMSLIB.getName().equalsIgnoreCase(name)) { return NMSLIB; } @@ -71,13 +70,11 @@ public static KNNEngine getEngine(String name) { * @return KNNEngine corresponding to path */ public static KNNEngine getEngineNameFromPath(String path) { - if (path.endsWith(KNNEngine.NMSLIB.getExtension()) - || path.endsWith(KNNEngine.NMSLIB.getCompoundExtension())) { + if (path.endsWith(KNNEngine.NMSLIB.getExtension()) || path.endsWith(KNNEngine.NMSLIB.getCompoundExtension())) { return KNNEngine.NMSLIB; } - if (path.endsWith(KNNEngine.FAISS.getExtension()) - || path.endsWith(KNNEngine.FAISS.getCompoundExtension())) { + if (path.endsWith(KNNEngine.FAISS.getExtension()) || path.endsWith(KNNEngine.FAISS.getCompoundExtension())) { return KNNEngine.FAISS; } @@ -133,7 +130,6 @@ public boolean isTrainingRequired(KNNMethodContext knnMethodContext) { return knnLibrary.isTrainingRequired(knnMethodContext); } - @Override public Map getMethodAsMap(KNNMethodContext knnMethodContext) { return knnLibrary.getMethodAsMap(knnMethodContext); diff --git a/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java b/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java index b39ab7aa6..767a9ad61 100644 --- a/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java +++ b/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java @@ -9,7 +9,6 @@ * GitHub history for details. */ - package org.opensearch.knn.index.util; import org.opensearch.common.ValidationException; @@ -131,7 +130,7 @@ public interface KNNLibrary { * @param dimension to estimate size for * @return size overhead estimate in KB */ - int estimateOverheadInKB (KNNMethodContext knnMethodContext, int dimension); + int estimateOverheadInKB(KNNMethodContext knnMethodContext, int dimension); /** * Generate method as map that can be used to configure the knn index from the jni @@ -176,9 +175,13 @@ abstract class NativeLibrary implements KNNLibrary { * @param latestLibraryVersion String representation of latest version of the library * @param extension String representing the extension that library files should use */ - public NativeLibrary(Map methods, Map> scoreTranslation, - String latestLibraryBuildVersion, String latestLibraryVersion, String extension) - { + public NativeLibrary( + Map methods, + Map> scoreTranslation, + String latestLibraryBuildVersion, + String latestLibraryVersion, + String extension + ) { this.methods = methods; this.scoreTranslation = scoreTranslation; this.latestLibraryBuildVersion = latestLibraryBuildVersion; @@ -248,8 +251,7 @@ public Map getMethodAsMap(KNNMethodContext knnMethodContext) { KNNMethod knnMethod = methods.get(knnMethodContext.getMethodComponent().getName()); if (knnMethod == null) { - throw new IllegalArgumentException("Invalid method name: " - + knnMethodContext.getMethodComponent().getName()); + throw new IllegalArgumentException("Invalid method name: " + knnMethodContext.getMethodComponent().getName()); } return knnMethod.getAsMap(knnMethodContext); @@ -277,23 +279,32 @@ class Nmslib extends NativeLibrary { public final static String EXTENSION = ".hnsw"; public final static Map METHODS = ImmutableMap.of( - METHOD_HNSW, - KNNMethod.Builder.builder( - MethodComponent.Builder.builder(HNSW_LIB_NAME) - .addParameter(METHOD_PARAMETER_M, new Parameter.IntegerParameter( - METHOD_PARAMETER_M, KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_M, v -> v > 0)) - .addParameter(METHOD_PARAMETER_EF_CONSTRUCTION, new Parameter.IntegerParameter( - METHOD_PARAMETER_EF_CONSTRUCTION, - KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_CONSTRUCTION, v -> v > 0)) - .build()) - .addSpaces(SpaceType.L2, SpaceType.L1, SpaceType.LINF, SpaceType.COSINESIMIL, - SpaceType.INNER_PRODUCT) - .build() + METHOD_HNSW, + KNNMethod.Builder.builder( + MethodComponent.Builder.builder(HNSW_LIB_NAME) + .addParameter( + METHOD_PARAMETER_M, + new Parameter.IntegerParameter(METHOD_PARAMETER_M, KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_M, v -> v > 0) + ) + .addParameter( + METHOD_PARAMETER_EF_CONSTRUCTION, + new Parameter.IntegerParameter( + METHOD_PARAMETER_EF_CONSTRUCTION, + KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_CONSTRUCTION, + v -> v > 0 + ) + ) + .build() + ).addSpaces(SpaceType.L2, SpaceType.L1, SpaceType.LINF, SpaceType.COSINESIMIL, SpaceType.INNER_PRODUCT).build() ); - - public final static Nmslib INSTANCE = new Nmslib(METHODS, Collections.emptyMap(), - Version.LATEST.getBuildVersion(), Version.LATEST.indexLibraryVersion(), EXTENSION); + public final static Nmslib INSTANCE = new Nmslib( + METHODS, + Collections.emptyMap(), + Version.LATEST.getBuildVersion(), + Version.LATEST.indexLibraryVersion(), + EXTENSION + ); /** * Constructor for Nmslib @@ -304,8 +315,13 @@ class Nmslib extends NativeLibrary { * @param latestLibraryVersion String representation of latest version of the library * @param extension String representing the extension that library files should use */ - private Nmslib(Map methods, Map> scoreTranslation, - String latestLibraryBuildVersion, String latestLibraryVersion, String extension) { + private Nmslib( + Map methods, + Map> scoreTranslation, + String latestLibraryBuildVersion, + String latestLibraryVersion, + String extension + ) { super(methods, scoreTranslation, latestLibraryBuildVersion, latestLibraryVersion, extension); } @@ -313,7 +329,7 @@ public enum Version { /** * Latest available nmslib version */ - V2011("2011"){ + V2011("2011") { @Override public String indexLibraryVersion() { return KNNConstants.NMSLIB_JNI_LIBRARY_NAME; @@ -334,7 +350,9 @@ public String indexLibraryVersion() { */ public abstract String indexLibraryVersion(); - public String getBuildVersion() { return buildVersion; } + public String getBuildVersion() { + return buildVersion; + } } } @@ -345,134 +363,186 @@ class Faiss extends NativeLibrary { // Map that overrides OpenSearch score translation by space type of scores returned by faiss public final static Map> SCORE_TRANSLATIONS = ImmutableMap.of( - SpaceType.INNER_PRODUCT, rawScore -> SpaceType.INNER_PRODUCT.scoreTranslation(-1*rawScore) + SpaceType.INNER_PRODUCT, + rawScore -> SpaceType.INNER_PRODUCT.scoreTranslation(-1 * rawScore) ); // Define encoders supported by faiss public final static MethodComponentContext ENCODER_DEFAULT = new MethodComponentContext( - KNNConstants.ENCODER_FLAT, Collections.emptyMap()); + KNNConstants.ENCODER_FLAT, + Collections.emptyMap() + ); - //TODO: To think about in future: for PQ, if dimension is not divisible by code count, PQ will fail. Right now, + // TODO: To think about in future: for PQ, if dimension is not divisible by code count, PQ will fail. Right now, // we do not have a way to base validation off of dimension. Failure will happen during training in JNI. public final static Map encoderComponents = ImmutableMap.of( - KNNConstants.ENCODER_FLAT, MethodComponent.Builder.builder(KNNConstants.ENCODER_FLAT) - .setMapGenerator(((methodComponent, methodComponentContext) -> - MethodAsMapBuilder.builder(KNNConstants.FAISS_FLAT_DESCRIPTION, methodComponent, - methodComponentContext).build())).build(), - KNNConstants.ENCODER_PQ, MethodComponent.Builder.builder(KNNConstants.ENCODER_PQ) - .addParameter(ENCODER_PARAMETER_PQ_M, - new Parameter.IntegerParameter(ENCODER_PARAMETER_PQ_M, - ENCODER_PARAMETER_PQ_CODE_COUNT_DEFAULT, v -> v > 0 - && v < ENCODER_PARAMETER_PQ_CODE_COUNT_LIMIT)) - .addParameter(ENCODER_PARAMETER_PQ_CODE_SIZE, - new Parameter.IntegerParameter(ENCODER_PARAMETER_PQ_CODE_SIZE, - ENCODER_PARAMETER_PQ_CODE_SIZE_DEFAULT, v -> v > 0 - && v < ENCODER_PARAMETER_PQ_CODE_SIZE_LIMIT)) - .setRequiresTraining(true) - .setMapGenerator(((methodComponent, methodComponentContext) -> - MethodAsMapBuilder.builder(FAISS_PQ_DESCRIPTION, methodComponent, methodComponentContext) - .addParameter(ENCODER_PARAMETER_PQ_M, "", "") - .addParameter(ENCODER_PARAMETER_PQ_CODE_SIZE, "x", "") - .build())) - .setOverheadInKBEstimator((methodComponent, methodComponentContext, dimension) -> { - // Size estimate formula: (4 * d * 2^code_size) / 1024 + 1 - - // Get value of code size passed in by user - Object codeSizeObject = methodComponentContext.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE); - - // If not specified, get default value of code size - if (codeSizeObject == null) { - Object codeSizeParameter = methodComponent.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE); - if (codeSizeParameter == null) { - throw new IllegalStateException(ENCODER_PARAMETER_PQ_CODE_SIZE + " is not a valid " + - " parameter. This is a bug."); - } - - codeSizeObject = ((Parameter) codeSizeParameter).getDefaultValue(); - } - - if (!(codeSizeObject instanceof Integer)) { - throw new IllegalStateException(ENCODER_PARAMETER_PQ_CODE_SIZE + " must be " + - "an integer."); - } - - int codeSize = (Integer) codeSizeObject; - return ((4L * (1 << codeSize) * dimension) / BYTES_PER_KILOBYTES) + 1; - }) - .build() + KNNConstants.ENCODER_FLAT, + MethodComponent.Builder.builder(KNNConstants.ENCODER_FLAT) + .setMapGenerator( + ((methodComponent, methodComponentContext) -> MethodAsMapBuilder.builder( + KNNConstants.FAISS_FLAT_DESCRIPTION, + methodComponent, + methodComponentContext + ).build()) + ) + .build(), + KNNConstants.ENCODER_PQ, + MethodComponent.Builder.builder(KNNConstants.ENCODER_PQ) + .addParameter( + ENCODER_PARAMETER_PQ_M, + new Parameter.IntegerParameter( + ENCODER_PARAMETER_PQ_M, + ENCODER_PARAMETER_PQ_CODE_COUNT_DEFAULT, + v -> v > 0 && v < ENCODER_PARAMETER_PQ_CODE_COUNT_LIMIT + ) + ) + .addParameter( + ENCODER_PARAMETER_PQ_CODE_SIZE, + new Parameter.IntegerParameter( + ENCODER_PARAMETER_PQ_CODE_SIZE, + ENCODER_PARAMETER_PQ_CODE_SIZE_DEFAULT, + v -> v > 0 && v < ENCODER_PARAMETER_PQ_CODE_SIZE_LIMIT + ) + ) + .setRequiresTraining(true) + .setMapGenerator( + ((methodComponent, methodComponentContext) -> MethodAsMapBuilder.builder( + FAISS_PQ_DESCRIPTION, + methodComponent, + methodComponentContext + ).addParameter(ENCODER_PARAMETER_PQ_M, "", "").addParameter(ENCODER_PARAMETER_PQ_CODE_SIZE, "x", "").build()) + ) + .setOverheadInKBEstimator((methodComponent, methodComponentContext, dimension) -> { + // Size estimate formula: (4 * d * 2^code_size) / 1024 + 1 + + // Get value of code size passed in by user + Object codeSizeObject = methodComponentContext.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE); + + // If not specified, get default value of code size + if (codeSizeObject == null) { + Object codeSizeParameter = methodComponent.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE); + if (codeSizeParameter == null) { + throw new IllegalStateException( + ENCODER_PARAMETER_PQ_CODE_SIZE + " is not a valid " + " parameter. This is a bug." + ); + } + + codeSizeObject = ((Parameter) codeSizeParameter).getDefaultValue(); + } + + if (!(codeSizeObject instanceof Integer)) { + throw new IllegalStateException(ENCODER_PARAMETER_PQ_CODE_SIZE + " must be " + "an integer."); + } + + int codeSize = (Integer) codeSizeObject; + return ((4L * (1 << codeSize) * dimension) / BYTES_PER_KILOBYTES) + 1; + }) + .build() ); // Define methods supported by faiss public final static Map METHODS = ImmutableMap.of( - METHOD_HNSW, KNNMethod.Builder.builder(MethodComponent.Builder.builder(METHOD_HNSW) - .addParameter(METHOD_PARAMETER_M, - new Parameter.IntegerParameter(METHOD_PARAMETER_M, - KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_M, v -> v > 0)) - .addParameter(METHOD_PARAMETER_EF_CONSTRUCTION, - new Parameter.IntegerParameter(METHOD_PARAMETER_EF_CONSTRUCTION, - KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_CONSTRUCTION, v -> v > 0)) - .addParameter(METHOD_PARAMETER_EF_SEARCH, - new Parameter.IntegerParameter(METHOD_PARAMETER_EF_SEARCH, - KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH, v -> v > 0)) - .addParameter(METHOD_ENCODER_PARAMETER, - new Parameter.MethodComponentContextParameter(METHOD_ENCODER_PARAMETER, - ENCODER_DEFAULT, encoderComponents)) - .setMapGenerator(((methodComponent, methodComponentContext) -> - MethodAsMapBuilder.builder(FAISS_HNSW_DESCRIPTION, methodComponent, methodComponentContext) - .addParameter(METHOD_PARAMETER_M, "", "") - .addParameter(METHOD_ENCODER_PARAMETER, ",", "") - .build())) - .build()) - .addSpaces(SpaceType.L2, SpaceType.INNER_PRODUCT).build(), - METHOD_IVF, KNNMethod.Builder.builder(MethodComponent.Builder.builder(METHOD_IVF) - .addParameter(METHOD_PARAMETER_NPROBES, - new Parameter.IntegerParameter(METHOD_PARAMETER_NPROBES, - METHOD_PARAMETER_NPROBES_DEFAULT, v -> v > 0 - && v < METHOD_PARAMETER_NPROBES_LIMIT)) - .addParameter(METHOD_PARAMETER_NLIST, - new Parameter.IntegerParameter(METHOD_PARAMETER_NLIST, METHOD_PARAMETER_NLIST_DEFAULT, - v -> v > 0 && v < METHOD_PARAMETER_NLIST_LIMIT)) - .addParameter(METHOD_ENCODER_PARAMETER, - new Parameter.MethodComponentContextParameter(METHOD_ENCODER_PARAMETER, - ENCODER_DEFAULT, encoderComponents)) - .setRequiresTraining(true) - .setMapGenerator(((methodComponent, methodComponentContext) -> - MethodAsMapBuilder.builder(FAISS_IVF_DESCRIPTION, methodComponent, methodComponentContext) - .addParameter(METHOD_PARAMETER_NLIST, "", "") - .addParameter(METHOD_ENCODER_PARAMETER, ",", "") - .build())) - .setOverheadInKBEstimator((methodComponent, methodComponentContext, dimension) -> { - // Size estimate formula: (4 * nlists * d) / 1024 + 1 - - // Get value of nlists passed in by user - Object nlistObject = methodComponentContext.getParameters().get(METHOD_PARAMETER_NLIST); - - // If not specified, get default value of nlist - if (nlistObject == null) { - Object nlistParameter = methodComponent.getParameters().get(METHOD_PARAMETER_NLIST); - if (nlistParameter == null) { - throw new IllegalStateException(METHOD_PARAMETER_NLIST + " is not a valid " + - " parameter. This is a bug."); - } - - nlistObject = ((Parameter) nlistParameter).getDefaultValue(); + METHOD_HNSW, + KNNMethod.Builder.builder( + MethodComponent.Builder.builder(METHOD_HNSW) + .addParameter( + METHOD_PARAMETER_M, + new Parameter.IntegerParameter(METHOD_PARAMETER_M, KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_M, v -> v > 0) + ) + .addParameter( + METHOD_PARAMETER_EF_CONSTRUCTION, + new Parameter.IntegerParameter( + METHOD_PARAMETER_EF_CONSTRUCTION, + KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_CONSTRUCTION, + v -> v > 0 + ) + ) + .addParameter( + METHOD_PARAMETER_EF_SEARCH, + new Parameter.IntegerParameter( + METHOD_PARAMETER_EF_SEARCH, + KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH, + v -> v > 0 + ) + ) + .addParameter( + METHOD_ENCODER_PARAMETER, + new Parameter.MethodComponentContextParameter(METHOD_ENCODER_PARAMETER, ENCODER_DEFAULT, encoderComponents) + ) + .setMapGenerator( + ((methodComponent, methodComponentContext) -> MethodAsMapBuilder.builder( + FAISS_HNSW_DESCRIPTION, + methodComponent, + methodComponentContext + ).addParameter(METHOD_PARAMETER_M, "", "").addParameter(METHOD_ENCODER_PARAMETER, ",", "").build()) + ) + .build() + ).addSpaces(SpaceType.L2, SpaceType.INNER_PRODUCT).build(), + METHOD_IVF, + KNNMethod.Builder.builder( + MethodComponent.Builder.builder(METHOD_IVF) + .addParameter( + METHOD_PARAMETER_NPROBES, + new Parameter.IntegerParameter( + METHOD_PARAMETER_NPROBES, + METHOD_PARAMETER_NPROBES_DEFAULT, + v -> v > 0 && v < METHOD_PARAMETER_NPROBES_LIMIT + ) + ) + .addParameter( + METHOD_PARAMETER_NLIST, + new Parameter.IntegerParameter( + METHOD_PARAMETER_NLIST, + METHOD_PARAMETER_NLIST_DEFAULT, + v -> v > 0 && v < METHOD_PARAMETER_NLIST_LIMIT + ) + ) + .addParameter( + METHOD_ENCODER_PARAMETER, + new Parameter.MethodComponentContextParameter(METHOD_ENCODER_PARAMETER, ENCODER_DEFAULT, encoderComponents) + ) + .setRequiresTraining(true) + .setMapGenerator( + ((methodComponent, methodComponentContext) -> MethodAsMapBuilder.builder( + FAISS_IVF_DESCRIPTION, + methodComponent, + methodComponentContext + ).addParameter(METHOD_PARAMETER_NLIST, "", "").addParameter(METHOD_ENCODER_PARAMETER, ",", "").build()) + ) + .setOverheadInKBEstimator((methodComponent, methodComponentContext, dimension) -> { + // Size estimate formula: (4 * nlists * d) / 1024 + 1 + + // Get value of nlists passed in by user + Object nlistObject = methodComponentContext.getParameters().get(METHOD_PARAMETER_NLIST); + + // If not specified, get default value of nlist + if (nlistObject == null) { + Object nlistParameter = methodComponent.getParameters().get(METHOD_PARAMETER_NLIST); + if (nlistParameter == null) { + throw new IllegalStateException(METHOD_PARAMETER_NLIST + " is not a valid " + " parameter. This is a bug."); } - if (!(nlistObject instanceof Integer)) { - throw new IllegalStateException(METHOD_PARAMETER_NLIST + " must be " + - "an integer."); - } + nlistObject = ((Parameter) nlistParameter).getDefaultValue(); + } - int centroids = (Integer) nlistObject; - return ((4L * centroids * dimension) / BYTES_PER_KILOBYTES) + 1; - }) - .build()) - .addSpaces(SpaceType.L2, SpaceType.INNER_PRODUCT).build() + if (!(nlistObject instanceof Integer)) { + throw new IllegalStateException(METHOD_PARAMETER_NLIST + " must be " + "an integer."); + } + + int centroids = (Integer) nlistObject; + return ((4L * centroids * dimension) / BYTES_PER_KILOBYTES) + 1; + }) + .build() + ).addSpaces(SpaceType.L2, SpaceType.INNER_PRODUCT).build() ); - public final static Faiss INSTANCE = new Faiss(METHODS, SCORE_TRANSLATIONS, - Version.LATEST.getBuildVersion(), Version.LATEST.indexLibraryVersion(), - KNNConstants.FAISS_EXTENSION); + public final static Faiss INSTANCE = new Faiss( + METHODS, + SCORE_TRANSLATIONS, + Version.LATEST.getBuildVersion(), + Version.LATEST.indexLibraryVersion(), + KNNConstants.FAISS_EXTENSION + ); /** * Constructor for Faiss @@ -483,9 +553,13 @@ class Faiss extends NativeLibrary { * @param latestLibraryVersion String representation of latest version of the library * @param extension String representing the extension that library files should use */ - private Faiss(Map methods, Map> scoreTranslation, String latestLibraryBuildVersion, String latestLibraryVersion, - String extension) { + private Faiss( + Map methods, + Map> scoreTranslation, + String latestLibraryBuildVersion, + String latestLibraryVersion, + String extension + ) { super(methods, scoreTranslation, latestLibraryBuildVersion, latestLibraryVersion, extension); } @@ -508,8 +582,7 @@ protected static class MethodAsMapBuilder { * @param methodComponent the method component that maps to this builder * @param initialMap the initial parameter map that will be modified */ - MethodAsMapBuilder(String baseDescription, MethodComponent methodComponent, - Map initialMap) { + MethodAsMapBuilder(String baseDescription, MethodComponent methodComponent, Map initialMap) { this.indexDescription = baseDescription; this.methodComponent = methodComponent; this.methodAsMap = initialMap; @@ -531,13 +604,16 @@ MethodAsMapBuilder addParameter(String parameterName, String prefix, String suff // into the index description string faiss uses to create the index. Map methodParameters = (Map) methodAsMap.get(PARAMETERS); Parameter parameter = methodComponent.getParameters().get(parameterName); - Object value = methodParameters.containsKey(parameterName) ? methodParameters.get(parameterName) : parameter.getDefaultValue(); + Object value = methodParameters.containsKey(parameterName) + ? methodParameters.get(parameterName) + : parameter.getDefaultValue(); // Recursion is needed if the parameter is a method component context itself. if (parameter instanceof Parameter.MethodComponentContextParameter) { MethodComponentContext subMethodComponentContext = (MethodComponentContext) value; - MethodComponent subMethodComponent = ((Parameter.MethodComponentContextParameter) parameter) - .getMethodComponent(subMethodComponentContext.getName()); + MethodComponent subMethodComponent = ((Parameter.MethodComponentContextParameter) parameter).getMethodComponent( + subMethodComponentContext.getName() + ); Map subMethodAsMap = subMethodComponent.getAsMap(subMethodComponentContext); indexDescription += subMethodAsMap.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER); @@ -566,8 +642,11 @@ Map build() { return methodAsMap; } - static MethodAsMapBuilder builder(String baseDescription, MethodComponent methodComponent, - MethodComponentContext methodComponentContext) { + static MethodAsMapBuilder builder( + String baseDescription, + MethodComponent methodComponent, + MethodComponentContext methodComponentContext + ) { Map initialMap = new HashMap<>(); initialMap.put(NAME, methodComponent.getName()); initialMap.put(PARAMETERS, MethodComponent.getParameterMapWithDefaultsAdded(methodComponentContext, methodComponent)); @@ -583,7 +662,7 @@ private enum Version { /** * Latest available nmslib version */ - V165("165"){ + V165("165") { @Override public String indexLibraryVersion() { return KNNConstants.FAISS_JNI_LIBRARY_NAME; @@ -604,7 +683,9 @@ public String indexLibraryVersion() { */ abstract String indexLibraryVersion(); - String getBuildVersion() { return buildVersion; } + String getBuildVersion() { + return buildVersion; + } } } } diff --git a/src/main/java/org/opensearch/knn/indices/Model.java b/src/main/java/org/opensearch/knn/indices/Model.java index 03195f00b..ec486bf7a 100644 --- a/src/main/java/org/opensearch/knn/indices/Model.java +++ b/src/main/java/org/opensearch/knn/indices/Model.java @@ -46,8 +46,9 @@ public Model(ModelMetadata modelMetadata, @Nullable byte[] modelBlob, @NonNull S this.modelMetadata = Objects.requireNonNull(modelMetadata, "modelMetadata must not be null"); if (ModelState.CREATED.equals(this.modelMetadata.getState()) && modelBlob == null) { - throw new IllegalArgumentException("Cannot construct model in state CREATED when model binary is null. " + - "State must be either TRAINING or FAILED"); + throw new IllegalArgumentException( + "Cannot construct model in state CREATED when model binary is null. " + "State must be either TRAINING or FAILED" + ); } this.modelBlob = new AtomicReference<>(modelBlob); @@ -55,7 +56,7 @@ public Model(ModelMetadata modelMetadata, @Nullable byte[] modelBlob, @NonNull S } private byte[] readOptionalModelBlob(StreamInput in) throws IOException { - return in.readBoolean() ? in.readByteArray(): null; + return in.readBoolean() ? in.readByteArray() : null; } /** @@ -69,7 +70,6 @@ public Model(StreamInput in) throws IOException { this.modelID = in.readString(); } - /** * getter for model's metadata * @@ -120,10 +120,8 @@ public synchronized void setModelBlob(byte[] modelBlob) { @Override public boolean equals(Object obj) { - if (this == obj) - return true; - if (obj == null || getClass() != obj.getClass()) - return false; + if (this == obj) return true; + if (obj == null || getClass() != obj.getClass()) return false; Model other = (Model) obj; return other.getModelID().equals(this.getModelID()); } @@ -147,7 +145,7 @@ public static Model getModelFromSourceMap(Map sourceMap) { } private void writeOptionalModelBlob(StreamOutput output) throws IOException { - if(getModelBlob() == null){ + if (getModelBlob() == null) { output.writeBoolean(false); return; } @@ -175,11 +173,11 @@ private static String getModelIDFromResponse(Map responseMap) { return (String) modelId; } - private static byte[] getModelBlobFromResponse(Map responseMap){ + private static byte[] getModelBlobFromResponse(Map responseMap) { Object blob = responseMap.get(KNNConstants.MODEL_BLOB_PARAMETER); // If byte blob is not there, it means that the state has not yet been updated to CREATED. - if(blob == null){ + if (blob == null) { return null; } return Base64.getDecoder().decode((String) blob); diff --git a/src/main/java/org/opensearch/knn/indices/ModelCache.java b/src/main/java/org/opensearch/knn/indices/ModelCache.java index 2e56c613e..bcc835490 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelCache.java +++ b/src/main/java/org/opensearch/knn/indices/ModelCache.java @@ -27,7 +27,6 @@ import static org.opensearch.knn.common.KNNConstants.MODEL_CACHE_EXPIRE_AFTER_ACCESS_TIME_MINUTES; import static org.opensearch.knn.index.KNNSettings.MODEL_CACHE_SIZE_LIMIT_SETTING; - public final class ModelCache { private static Logger logger = LogManager.getLogger(ModelCache.class); @@ -82,12 +81,12 @@ protected ModelCache() { private void initCache() { CacheBuilder cacheBuilder = CacheBuilder.newBuilder() - .recordStats() - .concurrencyLevel(1) - .removalListener(this::onRemoval) - .maximumWeight(cacheSizeInKB) - .expireAfterAccess(MODEL_CACHE_EXPIRE_AFTER_ACCESS_TIME_MINUTES, TimeUnit.MINUTES) - .weigher((k, v) -> Math.toIntExact(getModelLengthInKB(v))); + .recordStats() + .concurrencyLevel(1) + .removalListener(this::onRemoval) + .maximumWeight(cacheSizeInKB) + .expireAfterAccess(MODEL_CACHE_EXPIRE_AFTER_ACCESS_TIME_MINUTES, TimeUnit.MINUTES) + .weigher((k, v) -> Math.toIntExact(getModelLengthInKB(v))); cache = cacheBuilder.build(); } @@ -97,8 +96,7 @@ private void onRemoval(RemovalNotification removalNotification) { updateEvictedDueToSizeAt(); } - logger.info("[KNN] Model Cache evicted. Key {}, Reason: {}", removalNotification.getKey(), - removalNotification.getCause()); + logger.info("[KNN] Model Cache evicted. Key {}, Reason: {}", removalNotification.getKey(), removalNotification.getCause()); } public Instant getEvictedDueToSizeAt() { @@ -129,8 +127,7 @@ public Model get(String modelId) { * @return total weight */ public long getTotalWeightInKB() { - return cache.asMap().values().stream().map(this::getModelLengthInKB) - .reduce(0L, Long::sum); + return cache.asMap().values().stream().map(this::getModelLengthInKB).reduce(0L, Long::sum); } /** diff --git a/src/main/java/org/opensearch/knn/indices/ModelDao.java b/src/main/java/org/opensearch/knn/indices/ModelDao.java index 4d193a560..8ad8c60f1 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelDao.java +++ b/src/main/java/org/opensearch/knn/indices/ModelDao.java @@ -55,8 +55,6 @@ import java.net.URL; import java.util.Base64; import java.util.HashMap; -import java.util.List; -import java.util.Locale; import java.util.Map; import java.util.concurrent.ExecutionException; @@ -191,10 +189,9 @@ private OpenSearchKNNModelDao() { numberOfShards = MODEL_INDEX_NUMBER_OF_SHARDS_SETTING.get(settings); numberOfReplicas = MODEL_INDEX_NUMBER_OF_REPLICAS_SETTING.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(MODEL_INDEX_NUMBER_OF_SHARDS_SETTING, - it -> numberOfShards = it); - clusterService.getClusterSettings().addSettingsUpdateConsumer(MODEL_INDEX_NUMBER_OF_REPLICAS_SETTING, - it -> numberOfReplicas = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MODEL_INDEX_NUMBER_OF_SHARDS_SETTING, it -> numberOfShards = it); + clusterService.getClusterSettings() + .addSettingsUpdateConsumer(MODEL_INDEX_NUMBER_OF_REPLICAS_SETTING, it -> numberOfReplicas = it); } @Override @@ -203,13 +200,13 @@ public void create(ActionListener actionListener) throws IO return; } - CreateIndexRequest request = new CreateIndexRequest(MODEL_INDEX_NAME) - .mapping("_doc", getMapping(), XContentType.JSON) - .settings(Settings.builder() - .put("index.hidden", true) - .put("index.number_of_shards", this.numberOfShards) - .put("index.number_of_replicas", this.numberOfReplicas) - ); + CreateIndexRequest request = new CreateIndexRequest(MODEL_INDEX_NAME).mapping("_doc", getMapping(), XContentType.JSON) + .settings( + Settings.builder() + .put("index.hidden", true) + .put("index.number_of_shards", this.numberOfShards) + .put("index.number_of_replicas", this.numberOfReplicas) + ); client.admin().indices().create(request, actionListener); } @@ -224,7 +221,7 @@ public boolean isCreated() { * @return ClusterHealthStatus of model index */ @Override - public ClusterHealthStatus getHealthStatus() throws IndexNotFoundException{ + public ClusterHealthStatus getHealthStatus() throws IndexNotFoundException { if (!isCreated()) { throw new IndexNotFoundException(MODEL_INDEX_NAME); } @@ -242,13 +239,12 @@ public void put(Model model, ActionListener listener) throws IOEx } @Override - public void update(Model model, ActionListener listener) - throws IOException { + public void update(Model model, ActionListener listener) throws IOException { putInternal(model, listener, DocWriteRequest.OpType.INDEX); } - private void putInternal(Model model, ActionListener listener, - DocWriteRequest.OpType requestOpType) throws IOException { + private void putInternal(Model model, ActionListener listener, DocWriteRequest.OpType requestOpType) + throws IOException { if (model == null) { throw new IllegalArgumentException("Model cannot be null"); @@ -256,16 +252,18 @@ private void putInternal(Model model, ActionListener listener, ModelMetadata modelMetadata = model.getModelMetadata(); - Map parameters = new HashMap() {{ - put(KNNConstants.MODEL_ID, model.getModelID()); - put(KNNConstants.KNN_ENGINE, modelMetadata.getKnnEngine().getName()); - put(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, modelMetadata.getSpaceType().getValue()); - put(KNNConstants.DIMENSION, modelMetadata.getDimension()); - put(KNNConstants.MODEL_STATE, modelMetadata.getState().getName()); - put(KNNConstants.MODEL_TIMESTAMP, modelMetadata.getTimestamp()); - put(KNNConstants.MODEL_DESCRIPTION, modelMetadata.getDescription()); - put(KNNConstants.MODEL_ERROR, modelMetadata.getError()); - }}; + Map parameters = new HashMap() { + { + put(KNNConstants.MODEL_ID, model.getModelID()); + put(KNNConstants.KNN_ENGINE, modelMetadata.getKnnEngine().getName()); + put(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, modelMetadata.getSpaceType().getValue()); + put(KNNConstants.DIMENSION, modelMetadata.getDimension()); + put(KNNConstants.MODEL_STATE, modelMetadata.getState().getName()); + put(KNNConstants.MODEL_TIMESTAMP, modelMetadata.getTimestamp()); + put(KNNConstants.MODEL_DESCRIPTION, modelMetadata.getDescription()); + put(KNNConstants.MODEL_ERROR, modelMetadata.getError()); + } + }; byte[] modelBlob = model.getModelBlob(); @@ -290,23 +288,23 @@ private void putInternal(Model model, ActionListener listener, // After metadata update finishes, remove item from every node's cache if necessary. If no model id is // passed then nothing needs to be removed from the cache ActionListener onMetaListener; - onMetaListener = ActionListener.wrap(indexResponse -> client.execute( + onMetaListener = ActionListener.wrap( + indexResponse -> client.execute( RemoveModelFromCacheAction.INSTANCE, new RemoveModelFromCacheRequest(model.getModelID()), - ActionListener.wrap( - removeModelFromCacheResponse -> { - if (!removeModelFromCacheResponse.hasFailures()) { - listener.onResponse(indexResponse); - return; - } - - String failureMessage = buildRemoveModelErrorMessage(model.getModelID(), - removeModelFromCacheResponse); - - listener.onFailure(new RuntimeException(failureMessage)); - }, listener::onFailure - ) - ), listener::onFailure); + ActionListener.wrap(removeModelFromCacheResponse -> { + if (!removeModelFromCacheResponse.hasFailures()) { + listener.onResponse(indexResponse); + return; + } + + String failureMessage = buildRemoveModelErrorMessage(model.getModelID(), removeModelFromCacheResponse); + + listener.onFailure(new RuntimeException(failureMessage)); + }, listener::onFailure) + ), + listener::onFailure + ); // After the model is indexed, update metadata only if the model is in CREATED state ActionListener onIndexListener; @@ -318,24 +316,29 @@ private void putInternal(Model model, ActionListener listener, // Create the model index if it does not already exist if (!isCreated()) { - create(ActionListener.wrap(createIndexResponse -> indexRequestBuilder.execute(onIndexListener), - onIndexListener::onFailure)); + create( + ActionListener.wrap(createIndexResponse -> indexRequestBuilder.execute(onIndexListener), onIndexListener::onFailure) + ); return; } indexRequestBuilder.execute(onIndexListener); } - private ActionListener getUpdateModelMetadataListener(ModelMetadata modelMetadata, - ActionListener listener) { - return ActionListener.wrap(indexResponse -> client.execute( + private ActionListener getUpdateModelMetadataListener( + ModelMetadata modelMetadata, + ActionListener listener + ) { + return ActionListener.wrap( + indexResponse -> client.execute( UpdateModelMetadataAction.INSTANCE, new UpdateModelMetadataRequest(indexResponse.getId(), false, modelMetadata), // Here we wrap the IndexResponse listener around an AcknowledgedListener. This allows us // to pass the indexResponse back up. - ActionListener.wrap(acknowledgedResponse -> listener.onResponse(indexResponse), - listener::onFailure) - ), listener::onFailure); + ActionListener.wrap(acknowledgedResponse -> listener.onResponse(indexResponse), listener::onFailure) + ), + listener::onFailure + ); } @Override @@ -343,15 +346,13 @@ public Model get(String modelId) throws ExecutionException, InterruptedException /* GET //?_local */ - GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME) - .setId(modelId) - .setPreference("_local"); + GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId) + .setPreference("_local"); GetResponse getResponse = getRequestBuilder.execute().get(); Map responseMap = getResponse.getSourceAsMap(); return Model.getModelFromSourceMap(responseMap); } - /** * Get a model from the system index. Non-blocking. * @@ -364,14 +365,13 @@ public void get(String modelId, ActionListener actionListener) /* GET //?_local */ - GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME) - .setId(modelId) + GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId) .setPreference("_local"); getRequestBuilder.execute(ActionListener.wrap(response -> { - if(response.isSourceEmpty()){ + if (response.isSourceEmpty()) { String errorMessage = String.format("Model \" %s \" does not exist", modelId); - actionListener.onFailure(new ResourceNotFoundException(modelId,errorMessage)); + actionListener.onFailure(new ResourceNotFoundException(modelId, errorMessage)); return; } final Map responseMap = response.getSourceAsMap(); @@ -398,23 +398,22 @@ public ModelMetadata getMetadata(String modelId) { IndexMetadata indexMetadata = clusterService.state().metadata().index(MODEL_INDEX_NAME); if (indexMetadata == null) { - logger.debug("ModelMetadata for model \"" + modelId + "\" is null. " + MODEL_INDEX_NAME + - " index does not exist."); + logger.debug("ModelMetadata for model \"" + modelId + "\" is null. " + MODEL_INDEX_NAME + " index does not exist."); return null; } Map models = indexMetadata.getCustomData(MODEL_METADATA_FIELD); if (models == null) { - logger.debug("ModelMetadata for model \"" + modelId + "\" is null. " + MODEL_INDEX_NAME + - "'s custom metadata does not exist."); + logger.debug( + "ModelMetadata for model \"" + modelId + "\" is null. " + MODEL_INDEX_NAME + "'s custom metadata does not exist." + ); return null; } String modelMetadata = models.get(modelId); if (modelMetadata == null) { - logger.debug("ModelMetadata for model \"" + modelId + "\" is null. Model \"" + modelId + "\" does " + - "not exist."); + logger.debug("ModelMetadata for model \"" + modelId + "\" is null. Model \"" + modelId + "\" does " + "not exist."); return null; } @@ -434,69 +433,57 @@ private String getMapping() throws IOException { public void delete(String modelId, ActionListener listener) { // If the index is not created, there is no need to delete the model if (!isCreated()) { - logger.error("Cannot delete model \"" + modelId + "\". Model index "+ MODEL_INDEX_NAME + "does not exist."); + logger.error("Cannot delete model \"" + modelId + "\". Model index " + MODEL_INDEX_NAME + "does not exist."); String errorMessage = String.format("Cannot delete model \"%s\". Model index does not exist", modelId); listener.onResponse(new DeleteModelResponse(modelId, "failed", errorMessage)); return; } // Setup delete model request - DeleteRequestBuilder deleteRequestBuilder = new DeleteRequestBuilder(client, DeleteAction.INSTANCE, - MODEL_INDEX_NAME); + DeleteRequestBuilder deleteRequestBuilder = new DeleteRequestBuilder(client, DeleteAction.INSTANCE, MODEL_INDEX_NAME); deleteRequestBuilder.setId(modelId); deleteRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); // On model deletion from the index, remove the model from all nodes' model cache ActionListener onModelDeleteListener = ActionListener.wrap(deleteResponse -> { // If model is not deleted, return with error message - if(deleteResponse.getResult() != DocWriteResponse.Result.DELETED) { + if (deleteResponse.getResult() != DocWriteResponse.Result.DELETED) { String errorMessage = String.format("Model \" %s \" does not exist", modelId); - listener.onResponse(new DeleteModelResponse(modelId, deleteResponse.getResult().getLowercase(), - errorMessage)); + listener.onResponse(new DeleteModelResponse(modelId, deleteResponse.getResult().getLowercase(), errorMessage)); return; } // After model is deleted from the index, make sure the model is evicted from every cache in the // cluster client.execute( - RemoveModelFromCacheAction.INSTANCE, - new RemoveModelFromCacheRequest(modelId), - ActionListener.wrap( - removeModelFromCacheResponse -> { - - if (!removeModelFromCacheResponse.hasFailures()) { - listener.onResponse( - new DeleteModelResponse( - modelId, - deleteResponse.getResult().getLowercase(), - null - ) - ); - return; - } - - String failureMessage = buildRemoveModelErrorMessage(modelId, - removeModelFromCacheResponse); - - listener.onResponse(new DeleteModelResponse(modelId, "failed", - failureMessage)); - - }, e -> listener.onResponse( - new DeleteModelResponse(modelId, "failed", e.getMessage()) - ) - ) + RemoveModelFromCacheAction.INSTANCE, + new RemoveModelFromCacheRequest(modelId), + ActionListener.wrap(removeModelFromCacheResponse -> { + + if (!removeModelFromCacheResponse.hasFailures()) { + listener.onResponse(new DeleteModelResponse(modelId, deleteResponse.getResult().getLowercase(), null)); + return; + } + + String failureMessage = buildRemoveModelErrorMessage(modelId, removeModelFromCacheResponse); + + listener.onResponse(new DeleteModelResponse(modelId, "failed", failureMessage)); + + }, e -> listener.onResponse(new DeleteModelResponse(modelId, "failed", e.getMessage()))) ); }, e -> listener.onResponse(new DeleteModelResponse(modelId, "failed", e.getMessage()))); // On model metadata removal, delete the model from the index - ActionListener onMetadataUpdateListener = ActionListener.wrap(acknowledgedResponse -> - deleteRequestBuilder.execute(onModelDeleteListener), listener::onFailure); + ActionListener onMetadataUpdateListener = ActionListener.wrap( + acknowledgedResponse -> deleteRequestBuilder.execute(onModelDeleteListener), + listener::onFailure + ); // Remove the metadata asynchronously client.execute( - UpdateModelMetadataAction.INSTANCE, - new UpdateModelMetadataRequest(modelId, true, null), - onMetadataUpdateListener + UpdateModelMetadataAction.INSTANCE, + new UpdateModelMetadataRequest(modelId, true, null), + onMetadataUpdateListener ); } @@ -505,12 +492,11 @@ private String buildRemoveModelErrorMessage(String modelId, RemoveModelFromCache StringBuilder stringBuilder = new StringBuilder(failureMessage); for (FailedNodeException nodeException : response.failures()) { - stringBuilder - .append("Node \"") - .append(nodeException.nodeId()) - .append("\" ") - .append(nodeException.getMessage()) - .append("; "); + stringBuilder.append("Node \"") + .append(nodeException.nodeId()) + .append("\" ") + .append(nodeException.getMessage()) + .append("; "); } return stringBuilder.toString(); diff --git a/src/main/java/org/opensearch/knn/indices/ModelMetadata.java b/src/main/java/org/opensearch/knn/indices/ModelMetadata.java index b2c81aef8..9aa0d133b 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelMetadata.java +++ b/src/main/java/org/opensearch/knn/indices/ModelMetadata.java @@ -23,7 +23,6 @@ import org.opensearch.knn.index.util.KNNEngine; import java.io.IOException; -import java.util.Base64; import java.util.Map; import java.util.Objects; import java.util.concurrent.atomic.AtomicReference; @@ -31,7 +30,6 @@ import static org.opensearch.knn.common.KNNConstants.DIMENSION; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; -import static org.opensearch.knn.common.KNNConstants.MODEL_BLOB_PARAMETER; import static org.opensearch.knn.common.KNNConstants.MODEL_DESCRIPTION; import static org.opensearch.knn.common.KNNConstants.MODEL_ERROR; import static org.opensearch.knn.common.KNNConstants.MODEL_STATE; @@ -80,13 +78,21 @@ public ModelMetadata(StreamInput in) throws IOException { * @param description information about the model * @param error error message associated with model */ - public ModelMetadata(KNNEngine knnEngine, SpaceType spaceType, int dimension, ModelState modelState, - String timestamp, String description, String error) { + public ModelMetadata( + KNNEngine knnEngine, + SpaceType spaceType, + int dimension, + ModelState modelState, + String timestamp, + String description, + String error + ) { this.knnEngine = Objects.requireNonNull(knnEngine, "knnEngine must not be null"); this.spaceType = Objects.requireNonNull(spaceType, "spaceType must not be null"); if (dimension <= 0 || dimension >= MAX_DIMENSION) { - throw new IllegalArgumentException("Dimension \"" + dimension + "\" is invalid. Value must be greater " + - "than 0 and less than " + MAX_DIMENSION); + throw new IllegalArgumentException( + "Dimension \"" + dimension + "\" is invalid. Value must be greater " + "than 0 and less than " + MAX_DIMENSION + ); } this.dimension = dimension; @@ -179,16 +185,22 @@ public synchronized void setError(String error) { @Override public String toString() { - return String.join(DELIMITER, knnEngine.getName(), spaceType.getValue(), Integer.toString(dimension), - getState().toString(), timestamp, description, error); + return String.join( + DELIMITER, + knnEngine.getName(), + spaceType.getValue(), + Integer.toString(dimension), + getState().toString(), + timestamp, + description, + error + ); } @Override public boolean equals(Object obj) { - if (this == obj) - return true; - if (obj == null || getClass() != obj.getClass()) - return false; + if (this == obj) return true; + if (obj == null || getClass() != obj.getClass()) return false; ModelMetadata other = (ModelMetadata) obj; EqualsBuilder equalsBuilder = new EqualsBuilder(); @@ -205,8 +217,14 @@ public boolean equals(Object obj) { @Override public int hashCode() { - return new HashCodeBuilder().append(getKnnEngine()).append(getSpaceType()).append(getDimension()) - .append(getState()).append(getTimestamp()).append(getDescription()).append(getError()).toHashCode(); + return new HashCodeBuilder().append(getKnnEngine()) + .append(getSpaceType()) + .append(getDimension()) + .append(getState()) + .append(getTimestamp()) + .append(getDescription()) + .append(getError()) + .toHashCode(); } /** @@ -219,8 +237,10 @@ public static ModelMetadata fromString(String modelMetadataString) { String[] modelMetadataArray = modelMetadataString.split(DELIMITER, -1); if (modelMetadataArray.length != 7) { - throw new IllegalArgumentException("Illegal format for model metadata. Must be of the form " + - "\",,,,,,\"."); + throw new IllegalArgumentException( + "Illegal format for model metadata. Must be of the form " + + "\",,,,,,\"." + ); } KNNEngine knnEngine = KNNEngine.getEngine(modelMetadataArray[0]); @@ -235,15 +255,13 @@ public static ModelMetadata fromString(String modelMetadataString) { } private static String objectToString(Object value) { - if(value == null) - return null; - return (String)value; + if (value == null) return null; + return (String) value; } private static Integer objectToInteger(Object value) { - if(value == null) - return null; - return (Integer)value; + if (value == null) return null; + return (Integer) value; } /** @@ -252,18 +270,24 @@ private static Integer objectToInteger(Object value) { * @param modelSourceMap Map to be parsed * @return ModelMetadata instance */ - public static ModelMetadata getMetadataFromSourceMap(final Map modelSourceMap){ + public static ModelMetadata getMetadataFromSourceMap(final Map modelSourceMap) { Object engine = modelSourceMap.get(KNNConstants.KNN_ENGINE); Object space = modelSourceMap.get(KNNConstants.METHOD_PARAMETER_SPACE_TYPE); Object dimension = modelSourceMap.get(KNNConstants.DIMENSION); Object state = modelSourceMap.get(KNNConstants.MODEL_STATE); - Object timestamp = modelSourceMap.get(KNNConstants.MODEL_TIMESTAMP); + Object timestamp = modelSourceMap.get(KNNConstants.MODEL_TIMESTAMP); Object description = modelSourceMap.get(KNNConstants.MODEL_DESCRIPTION); Object error = modelSourceMap.get(KNNConstants.MODEL_ERROR); - ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.getEngine(objectToString(engine)), - SpaceType.getSpace(objectToString( space)), objectToInteger(dimension), ModelState.getModelState(objectToString(state)), - objectToString(timestamp), objectToString(description), objectToString( error)); + ModelMetadata modelMetadata = new ModelMetadata( + KNNEngine.getEngine(objectToString(engine)), + SpaceType.getSpace(objectToString(space)), + objectToInteger(dimension), + ModelState.getModelState(objectToString(state)), + objectToString(timestamp), + objectToString(description), + objectToString(error) + ); return modelMetadata; } diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index 09be29a6a..3f90e33d1 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -57,8 +57,13 @@ class FaissService { * @param templateIndex empty template index * @param parameters additional build time parameters */ - public static native void createIndexFromTemplate(int[] ids, float[][] data, String indexPath, byte[] templateIndex, - Map parameters); + public static native void createIndexFromTemplate( + int[] ids, + float[][] data, + String indexPath, + byte[] templateIndex, + Map parameters + ); /** * Load an index into memory @@ -97,8 +102,7 @@ public static native void createIndexFromTemplate(int[] ids, float[][] data, Str * @param trainVectorsPointer pointer to where training vectors are stored in native memory * @return bytes array of trained template index */ - public static native byte[] trainIndex(Map indexParameters, int dimension, - long trainVectorsPointer); + public static native byte[] trainIndex(Map indexParameters, int dimension, long trainVectorsPointer); /** * Transfer vectors from Java to native diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index 60ecb1830..7afc312ac 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -30,8 +30,7 @@ public class JNIService { * @param parameters parameters to build index * @param engineName name of engine to build index for */ - public static void createIndex(int[] ids, float[][] data, String indexPath, Map parameters, - String engineName) { + public static void createIndex(int[] ids, float[][] data, String indexPath, Map parameters, String engineName) { if (KNNEngine.NMSLIB.getName().equals(engineName)) { NmslibService.createIndex(ids, data, indexPath, parameters); return; @@ -55,8 +54,14 @@ public static void createIndex(int[] ids, float[][] data, String indexPath, Map< * @param parameters parameters to build index * @param engineName name of engine to build index for */ - public static void createIndexFromTemplate(int[] ids, float[][] data, String indexPath, byte[] templateIndex, - Map parameters, String engineName) { + public static void createIndexFromTemplate( + int[] ids, + float[][] data, + String indexPath, + byte[] templateIndex, + Map parameters, + String engineName + ) { if (KNNEngine.FAISS.getName().equals(engineName)) { FaissService.createIndexFromTemplate(ids, data, indexPath, templateIndex, parameters); return; @@ -135,8 +140,7 @@ public static void free(long indexPointer, String engineName) { * @param engineName engine to perform the training * @return bytes array of trained template index */ - public static byte[] trainIndex(Map indexParameters, int dimension, long trainVectorsPointer, - String engineName) { + public static byte[] trainIndex(Map indexParameters, int dimension, long trainVectorsPointer, String engineName) { if (KNNEngine.FAISS.getName().equals(engineName)) { return FaissService.trainIndex(indexParameters, dimension, trainVectorsPointer); } diff --git a/src/main/java/org/opensearch/knn/plugin/KNNCodecService.java b/src/main/java/org/opensearch/knn/plugin/KNNCodecService.java index c67f27923..e5b06112d 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNCodecService.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNCodecService.java @@ -36,6 +36,6 @@ public Codec codec(String name) { } public void setPostingsFormat(PostingsFormat postingsFormat) { - ((KNN87Codec)codec("")).setPostingsFormat(postingsFormat); + ((KNN87Codec) codec("")).setPostingsFormat(postingsFormat); } } diff --git a/src/main/java/org/opensearch/knn/plugin/KNNEngineFactory.java b/src/main/java/org/opensearch/knn/plugin/KNNEngineFactory.java index c5da36864..03b3f2a4d 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNEngineFactory.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNEngineFactory.java @@ -20,14 +20,30 @@ class KNNEngineFactory implements EngineFactory { @Override public Engine newReadWriteEngine(EngineConfig config) { codecService.setPostingsFormat(config.getCodec().postingsFormat()); - EngineConfig engineConfig = new EngineConfig(config.getShardId(), - config.getThreadPool(), config.getIndexSettings(), config.getWarmer(), config.getStore(), - config.getMergePolicy(), config.getAnalyzer(), config.getSimilarity(), codecService, - config.getEventListener(), config.getQueryCache(), config.getQueryCachingPolicy(), - config.getTranslogConfig(), config.getFlushMergesAfter(), config.getExternalRefreshListener(), - config.getInternalRefreshListener(), config.getIndexSort(), config.getCircuitBreakerService(), - config.getGlobalCheckpointSupplier(), config.retentionLeasesSupplier(), config.getPrimaryTermSupplier(), - config.getTombstoneDocSupplier()); + EngineConfig engineConfig = new EngineConfig( + config.getShardId(), + config.getThreadPool(), + config.getIndexSettings(), + config.getWarmer(), + config.getStore(), + config.getMergePolicy(), + config.getAnalyzer(), + config.getSimilarity(), + codecService, + config.getEventListener(), + config.getQueryCache(), + config.getQueryCachingPolicy(), + config.getTranslogConfig(), + config.getFlushMergesAfter(), + config.getExternalRefreshListener(), + config.getInternalRefreshListener(), + config.getIndexSort(), + config.getCircuitBreakerService(), + config.getGlobalCheckpointSupplier(), + config.retentionLeasesSupplier(), + config.getPrimaryTermSupplier(), + config.getTombstoneDocSupplier() + ); return new InternalEngine(engineConfig); } } diff --git a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java index 122b77929..42fa46e10 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java @@ -138,8 +138,10 @@ public class KNNPlugin extends Plugin implements MapperPlugin, SearchPlugin, Act @Override public Map getMappers() { - return Collections.singletonMap(KNNVectorFieldMapper.CONTENT_TYPE, new KNNVectorFieldMapper.TypeParser( - ModelDao.OpenSearchKNNModelDao::getInstance)); + return Collections.singletonMap( + KNNVectorFieldMapper.CONTENT_TYPE, + new KNNVectorFieldMapper.TypeParser(ModelDao.OpenSearchKNNModelDao::getInstance) + ); } @Override @@ -148,12 +150,19 @@ public List> getQueries() { } @Override - public Collection createComponents(Client client, ClusterService clusterService, ThreadPool threadPool, - ResourceWatcherService resourceWatcherService, ScriptService scriptService, - NamedXContentRegistry xContentRegistry, Environment environment, - NodeEnvironment nodeEnvironment, NamedWriteableRegistry namedWriteableRegistry, - IndexNameExpressionResolver indexNameExpressionResolver, - Supplier repositoriesServiceSupplier) { + public Collection createComponents( + Client client, + ClusterService clusterService, + ThreadPool threadPool, + ResourceWatcherService resourceWatcherService, + ScriptService scriptService, + NamedXContentRegistry xContentRegistry, + Environment environment, + NodeEnvironment nodeEnvironment, + NamedWriteableRegistry namedWriteableRegistry, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier repositoriesServiceSupplier + ) { this.clusterService = clusterService; // Initialize Native Memory loading strategies @@ -178,25 +187,35 @@ public List> getSettings() { return KNNSettings.state().getSettings(); } - public List getRestHandlers(Settings settings, - RestController restController, - ClusterSettings clusterSettings, - IndexScopedSettings indexScopedSettings, - SettingsFilter settingsFilter, - IndexNameExpressionResolver indexNameExpressionResolver, - Supplier nodesInCluster) { + public List getRestHandlers( + Settings settings, + RestController restController, + ClusterSettings clusterSettings, + IndexScopedSettings indexScopedSettings, + SettingsFilter settingsFilter, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier nodesInCluster + ) { RestKNNStatsHandler restKNNStatsHandler = new RestKNNStatsHandler(settings, restController, knnStats); - RestKNNWarmupHandler restKNNWarmupHandler = new RestKNNWarmupHandler(settings, restController, clusterService, - indexNameExpressionResolver); + RestKNNWarmupHandler restKNNWarmupHandler = new RestKNNWarmupHandler( + settings, + restController, + clusterService, + indexNameExpressionResolver + ); RestGetModelHandler restGetModelHandler = new RestGetModelHandler(); RestDeleteModelHandler restDeleteModelHandler = new RestDeleteModelHandler(); RestTrainModelHandler restTrainModelHandler = new RestTrainModelHandler(); RestSearchModelHandler restSearchModelHandler = new RestSearchModelHandler(); return ImmutableList.of( - restKNNStatsHandler, restKNNWarmupHandler, restGetModelHandler, restDeleteModelHandler, - restTrainModelHandler, restSearchModelHandler + restKNNStatsHandler, + restKNNWarmupHandler, + restGetModelHandler, + restDeleteModelHandler, + restTrainModelHandler, + restSearchModelHandler ); } @@ -206,17 +225,16 @@ public List getRestHandlers(Settings settings, @Override public List> getActions() { return Arrays.asList( - new ActionHandler<>(KNNStatsAction.INSTANCE, KNNStatsTransportAction.class), - new ActionHandler<>(KNNWarmupAction.INSTANCE, KNNWarmupTransportAction.class), - new ActionHandler<>(UpdateModelMetadataAction.INSTANCE, UpdateModelMetadataTransportAction.class), - new ActionHandler<>(TrainingJobRouteDecisionInfoAction.INSTANCE, - TrainingJobRouteDecisionInfoTransportAction.class), - new ActionHandler<>(GetModelAction.INSTANCE, GetModelTransportAction.class), - new ActionHandler<>(DeleteModelAction.INSTANCE, DeleteModelTransportAction.class), - new ActionHandler<>(TrainingJobRouterAction.INSTANCE, TrainingJobRouterTransportAction.class), - new ActionHandler<>(TrainingModelAction.INSTANCE, TrainingModelTransportAction.class), - new ActionHandler<>(RemoveModelFromCacheAction.INSTANCE, RemoveModelFromCacheTransportAction.class), - new ActionHandler<>(SearchModelAction.INSTANCE, SearchModelTransportAction.class) + new ActionHandler<>(KNNStatsAction.INSTANCE, KNNStatsTransportAction.class), + new ActionHandler<>(KNNWarmupAction.INSTANCE, KNNWarmupTransportAction.class), + new ActionHandler<>(UpdateModelMetadataAction.INSTANCE, UpdateModelMetadataTransportAction.class), + new ActionHandler<>(TrainingJobRouteDecisionInfoAction.INSTANCE, TrainingJobRouteDecisionInfoTransportAction.class), + new ActionHandler<>(GetModelAction.INSTANCE, GetModelTransportAction.class), + new ActionHandler<>(DeleteModelAction.INSTANCE, DeleteModelTransportAction.class), + new ActionHandler<>(TrainingJobRouterAction.INSTANCE, TrainingJobRouterTransportAction.class), + new ActionHandler<>(TrainingModelAction.INSTANCE, TrainingModelTransportAction.class), + new ActionHandler<>(RemoveModelFromCacheAction.INSTANCE, RemoveModelFromCacheTransportAction.class), + new ActionHandler<>(SearchModelAction.INSTANCE, SearchModelTransportAction.class) ); } @@ -267,15 +285,6 @@ public ScriptEngine getScriptEngine(Settings settings, Collection> getExecutorBuilders(Settings settings) { - return ImmutableList.of( - new FixedExecutorBuilder( - settings, - TRAIN_THREAD_POOL, - 1, - 1, - KNN_THREAD_POOL_PREFIX, - false - ) - ); + return ImmutableList.of(new FixedExecutorBuilder(settings, TRAIN_THREAD_POOL, 1, 1, KNN_THREAD_POOL_PREFIX, false)); } } diff --git a/src/main/java/org/opensearch/knn/plugin/rest/RestDeleteModelHandler.java b/src/main/java/org/opensearch/knn/plugin/rest/RestDeleteModelHandler.java index d4f55941c..37074d128 100644 --- a/src/main/java/org/opensearch/knn/plugin/rest/RestDeleteModelHandler.java +++ b/src/main/java/org/opensearch/knn/plugin/rest/RestDeleteModelHandler.java @@ -43,13 +43,9 @@ public String getName() { @Override public List routes() { - return ImmutableList - .of( - new Route( - RestRequest.Method.DELETE, - String.format(Locale.ROOT, "%s/%s/{%s}", KNNPlugin.KNN_BASE_URI, MODELS, MODEL_ID) - ) - ); + return ImmutableList.of( + new Route(RestRequest.Method.DELETE, String.format(Locale.ROOT, "%s/%s/{%s}", KNNPlugin.KNN_BASE_URI, MODELS, MODEL_ID)) + ); } /** diff --git a/src/main/java/org/opensearch/knn/plugin/rest/RestGetModelHandler.java b/src/main/java/org/opensearch/knn/plugin/rest/RestGetModelHandler.java index 342c4b350..09f2daab2 100644 --- a/src/main/java/org/opensearch/knn/plugin/rest/RestGetModelHandler.java +++ b/src/main/java/org/opensearch/knn/plugin/rest/RestGetModelHandler.java @@ -27,6 +27,7 @@ import static org.opensearch.knn.common.KNNConstants.MODELS; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; + /** * Rest Handler for get model api endpoint. */ @@ -41,13 +42,9 @@ public String getName() { @Override public List routes() { - return ImmutableList - .of( - new Route( - RestRequest.Method.GET, - String.format(Locale.ROOT, "%s/%s/{%s}", KNNPlugin.KNN_BASE_URI, MODELS, MODEL_ID) - ) - ); + return ImmutableList.of( + new Route(RestRequest.Method.GET, String.format(Locale.ROOT, "%s/%s/{%s}", KNNPlugin.KNN_BASE_URI, MODELS, MODEL_ID)) + ); } @Override @@ -58,7 +55,6 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient } GetModelRequest getModelRequest = new GetModelRequest(modelID); - return channel -> client - .execute(GetModelAction.INSTANCE, getModelRequest, new RestToXContentListener<>(channel)); + return channel -> client.execute(GetModelAction.INSTANCE, getModelRequest, new RestToXContentListener<>(channel)); } } diff --git a/src/main/java/org/opensearch/knn/plugin/rest/RestKNNStatsHandler.java b/src/main/java/org/opensearch/knn/plugin/rest/RestKNNStatsHandler.java index efb20692d..d5e991858 100644 --- a/src/main/java/org/opensearch/knn/plugin/rest/RestKNNStatsHandler.java +++ b/src/main/java/org/opensearch/knn/plugin/rest/RestKNNStatsHandler.java @@ -55,7 +55,6 @@ public String getName() { return NAME; } - private List getStatsPath() { List statsPath = new ArrayList<>(); statsPath.add("/{nodeId}/stats/"); @@ -66,24 +65,24 @@ private List getStatsPath() { } private Map getUrlPathByLegacyUrlPathMap() { - return getStatsPath().stream().collect( - Collectors.toMap(path -> KNNPlugin.LEGACY_KNN_BASE_URI + path, path -> KNNPlugin.KNN_BASE_URI + path) - ); + return getStatsPath().stream() + .collect(Collectors.toMap(path -> KNNPlugin.LEGACY_KNN_BASE_URI + path, path -> KNNPlugin.KNN_BASE_URI + path)); } @Override public List routes() { - return ImmutableList.of(); + return ImmutableList.of(); } @Override public List replacedRoutes() { - return getUrlPathByLegacyUrlPathMap().entrySet().stream().map( - e -> new ReplacedRoute(RestRequest.Method.GET, e.getValue(), RestRequest.Method.GET, e.getKey()) - ).collect(Collectors.toList()); + return getUrlPathByLegacyUrlPathMap().entrySet() + .stream() + .map(e -> new ReplacedRoute(RestRequest.Method.GET, e.getValue(), RestRequest.Method.GET, e.getKey())) + .collect(Collectors.toList()); } - @Override + @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) { // From restrequest, create a knnStatsRequest KNNStatsRequest knnStatsRequest = getRequest(request); @@ -130,8 +129,7 @@ private KNNStatsRequest getRequest(RestRequest request) { } if (!invalidStats.isEmpty()) { - throw new IllegalArgumentException(unrecognized(request, invalidStats, - knnStatsRequest.getStatsToBeRetrieved(), "stat")); + throw new IllegalArgumentException(unrecognized(request, invalidStats, knnStatsRequest.getStatsToBeRetrieved(), "stat")); } } diff --git a/src/main/java/org/opensearch/knn/plugin/rest/RestKNNWarmupHandler.java b/src/main/java/org/opensearch/knn/plugin/rest/RestKNNWarmupHandler.java index 0f563d3d9..f457d6782 100644 --- a/src/main/java/org/opensearch/knn/plugin/rest/RestKNNWarmupHandler.java +++ b/src/main/java/org/opensearch/knn/plugin/rest/RestKNNWarmupHandler.java @@ -41,13 +41,16 @@ public class RestKNNWarmupHandler extends BaseRestHandler { private IndexNameExpressionResolver indexNameExpressionResolver; private ClusterService clusterService; - public RestKNNWarmupHandler(Settings settings, RestController controller, ClusterService clusterService, - IndexNameExpressionResolver indexNameExpressionResolver) { + public RestKNNWarmupHandler( + Settings settings, + RestController controller, + ClusterService clusterService, + IndexNameExpressionResolver indexNameExpressionResolver + ) { this.clusterService = clusterService; this.indexNameExpressionResolver = indexNameExpressionResolver; } - @Override public String getName() { return NAME; @@ -62,23 +65,24 @@ public List routes() { public List replacedRoutes() { return ImmutableList.of( new ReplacedRoute( - RestRequest.Method.GET, KNNPlugin.KNN_BASE_URI + URL_PATH, - RestRequest.Method.GET, KNNPlugin.LEGACY_KNN_BASE_URI + URL_PATH) + RestRequest.Method.GET, + KNNPlugin.KNN_BASE_URI + URL_PATH, + RestRequest.Method.GET, + KNNPlugin.LEGACY_KNN_BASE_URI + URL_PATH + ) ); } @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) { KNNWarmupRequest knnWarmupRequest = createKNNWarmupRequest(request); - logger.info("[KNN] Warmup started for the following indices: " - + String.join(",", knnWarmupRequest.indices())); + logger.info("[KNN] Warmup started for the following indices: " + String.join(",", knnWarmupRequest.indices())); return channel -> client.execute(KNNWarmupAction.INSTANCE, knnWarmupRequest, new RestToXContentListener<>(channel)); } private KNNWarmupRequest createKNNWarmupRequest(RestRequest request) { String[] indexNames = Strings.splitStringByCommaToArray(request.param("index")); - Index[] indices = indexNameExpressionResolver.concreteIndices(clusterService.state(), strictExpandOpen(), - indexNames); + Index[] indices = indexNameExpressionResolver.concreteIndices(clusterService.state(), strictExpandOpen(), indexNames); List invalidIndexNames = new ArrayList<>(); Arrays.stream(indices).forEach(index -> { @@ -88,8 +92,10 @@ private KNNWarmupRequest createKNNWarmupRequest(RestRequest request) { }); if (invalidIndexNames.size() != 0) { - throw new KNNInvalidIndicesException(invalidIndexNames, - "Warm up request rejected. One or more indices have 'index.knn' set to false."); + throw new KNNInvalidIndicesException( + invalidIndexNames, + "Warm up request rejected. One or more indices have 'index.knn' set to false." + ); } return new KNNWarmupRequest(indexNames); diff --git a/src/main/java/org/opensearch/knn/plugin/rest/RestSearchModelHandler.java b/src/main/java/org/opensearch/knn/plugin/rest/RestSearchModelHandler.java index 915f5717d..eeb510b9d 100644 --- a/src/main/java/org/opensearch/knn/plugin/rest/RestSearchModelHandler.java +++ b/src/main/java/org/opensearch/knn/plugin/rest/RestSearchModelHandler.java @@ -12,27 +12,15 @@ package org.opensearch.knn.plugin.rest; import com.google.common.collect.ImmutableList; -import org.opensearch.action.search.SearchAction; import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; import org.opensearch.client.node.NodeClient; -import org.opensearch.common.bytes.BytesReference; -import org.opensearch.common.xcontent.ToXContentObject; -import org.opensearch.common.xcontent.XContentBuilder; -import org.opensearch.knn.indices.Model; import org.opensearch.knn.plugin.KNNPlugin; import org.opensearch.knn.plugin.transport.SearchModelAction; import org.opensearch.rest.BaseRestHandler; -import org.opensearch.rest.BytesRestResponse; -import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; -import org.opensearch.rest.RestResponse; -import org.opensearch.rest.RestStatus; import org.opensearch.rest.action.RestCancellableNodeClient; -import org.opensearch.rest.action.RestResponseListener; import org.opensearch.rest.action.RestToXContentListener; import org.opensearch.rest.action.search.RestSearchAction; -import org.opensearch.search.SearchHit; import java.io.IOException; import java.util.ArrayList; @@ -41,10 +29,7 @@ import java.util.Locale; import java.util.function.IntConsumer; -import static org.opensearch.common.xcontent.ToXContent.EMPTY_PARAMS; -import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder; import static org.opensearch.knn.common.KNNConstants.MODELS; -import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME; /** * Rest Handler for search model api endpoint. @@ -53,9 +38,9 @@ public class RestSearchModelHandler extends BaseRestHandler { private final static String NAME = "knn_search_model_action"; private static final String SEARCH = "_search"; - //Add params that are not fit to be part of model search + // Add params that are not fit to be part of model search public List UNSUPPORTED_PARAM_LIST = Arrays.asList( - "index" //we don't want to search across all indices + "index" // we don't want to search across all indices ); @Override @@ -65,17 +50,10 @@ public String getName() { @Override public List routes() { - return ImmutableList - .of( - new Route( - RestRequest.Method.GET, - String.format(Locale.ROOT, "%s/%s/%s", KNNPlugin.KNN_BASE_URI, MODELS, SEARCH) - ), - new Route( - RestRequest.Method.POST, - String.format(Locale.ROOT, "%s/%s/%s", KNNPlugin.KNN_BASE_URI, MODELS, SEARCH) - ) - ); + return ImmutableList.of( + new Route(RestRequest.Method.GET, String.format(Locale.ROOT, "%s/%s/%s", KNNPlugin.KNN_BASE_URI, MODELS, SEARCH)), + new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/%s/%s", KNNPlugin.KNN_BASE_URI, MODELS, SEARCH)) + ); } private void checkUnSupportedParamsExists(RestRequest request) { @@ -86,8 +64,7 @@ private void checkUnSupportedParamsExists(RestRequest request) { invalidParam.add(param); } }); - if (invalidParam.isEmpty()) - return; + if (invalidParam.isEmpty()) return; String errorMessage = "request contains an unrecognized parameter: [ " + String.join(",", invalidParam) + " ]"; throw new IllegalArgumentException(errorMessage); } @@ -97,8 +74,9 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli checkUnSupportedParamsExists(request); SearchRequest searchRequest = new SearchRequest(); IntConsumer setSize = size -> searchRequest.source().size(size); - request.withContentOrSourceParamParserOrNull(parser -> - RestSearchAction.parseSearchRequest(searchRequest, request, parser, client.getNamedWriteableRegistry(), setSize)); + request.withContentOrSourceParamParserOrNull( + parser -> RestSearchAction.parseSearchRequest(searchRequest, request, parser, client.getNamedWriteableRegistry(), setSize) + ); return channel -> { RestCancellableNodeClient cancelClient = new RestCancellableNodeClient(client, request.getHttpChannel()); diff --git a/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java b/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java index a5b2f13e4..9ddf7410b 100644 --- a/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java +++ b/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java @@ -55,26 +55,17 @@ public String getName() { @Override public List routes() { - return ImmutableList - .of( - new Route( - RestRequest.Method.POST, - String.format(Locale.ROOT, "%s/%s/{%s}/_train", KNNPlugin.KNN_BASE_URI, MODELS, - MODEL_ID) - ), - new Route( - RestRequest.Method.POST, - String.format(Locale.ROOT, "%s/%s/_train", KNNPlugin.KNN_BASE_URI, MODELS) - ) - ); + return ImmutableList.of( + new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/%s/{%s}/_train", KNNPlugin.KNN_BASE_URI, MODELS, MODEL_ID)), + new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/%s/_train", KNNPlugin.KNN_BASE_URI, MODELS)) + ); } @Override protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { TrainingModelRequest trainingModelRequest = createTransportRequest(restRequest); - return channel -> client.execute(TrainingJobRouterAction.INSTANCE, trainingModelRequest, - new RestToXContentListener<>(channel)); + return channel -> client.execute(TrainingJobRouterAction.INSTANCE, trainingModelRequest, new RestToXContentListener<>(channel)); } private TrainingModelRequest createTransportRequest(RestRequest restRequest) throws IOException { @@ -114,8 +105,7 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr } else if (MODEL_DESCRIPTION.equals(fieldName) && ensureNotSet(fieldName, description)) { description = parser.textOrNull(); } else { - throw new IllegalArgumentException("Unable to parse token. \"" + fieldName + "\" is not a valid " + - "parameter."); + throw new IllegalArgumentException("Unable to parse token. \"" + fieldName + "\" is not a valid " + "parameter."); } } @@ -130,8 +120,15 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr description = ""; } - TrainingModelRequest trainingModelRequest = new TrainingModelRequest(modelId, knnMethodContext, dimension, - trainingIndex, trainingField, preferredNodeId, description); + TrainingModelRequest trainingModelRequest = new TrainingModelRequest( + modelId, + knnMethodContext, + dimension, + trainingIndex, + trainingField, + preferredNodeId, + description + ); if (maximumVectorCount != DEFAULT_NOT_SET_INT_VALUE) { trainingModelRequest.setMaximumVectorCount(maximumVectorCount); diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScript.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScript.java index 6c8ba20c1..f190a3e1d 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScript.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScript.java @@ -26,8 +26,14 @@ public abstract class KNNScoreScript extends ScoreScript { protected final String field; protected final BiFunction scoringMethod; - public KNNScoreScript(Map params, T queryValue, String field, - BiFunction scoringMethod, SearchLookup lookup, LeafReaderContext leafContext) { + public KNNScoreScript( + Map params, + T queryValue, + String field, + BiFunction scoringMethod, + SearchLookup lookup, + LeafReaderContext leafContext + ) { super(params, lookup, leafContext); this.queryValue = queryValue; this.field = field; @@ -39,9 +45,14 @@ public KNNScoreScript(Map params, T queryValue, String field, * expected to be Longs. */ public static class LongType extends KNNScoreScript { - public LongType(Map params, Long queryValue, String field, - BiFunction scoringMethod, SearchLookup lookup, - LeafReaderContext leafContext) { + public LongType( + Map params, + Long queryValue, + String field, + BiFunction scoringMethod, + SearchLookup lookup, + LeafReaderContext leafContext + ) { super(params, queryValue, field, scoringMethod, lookup, leafContext); } @@ -67,9 +78,14 @@ public double execute(ScoreScript.ExplanationHolder explanationHolder) { * are expected to be BigInteger. */ public static class BigIntegerType extends KNNScoreScript { - public BigIntegerType(Map params, BigInteger queryValue, String field, - BiFunction scoringMethod, SearchLookup lookup, - LeafReaderContext leafContext) { + public BigIntegerType( + Map params, + BigInteger queryValue, + String field, + BiFunction scoringMethod, + SearchLookup lookup, + LeafReaderContext leafContext + ) { super(params, queryValue, field, scoringMethod, lookup, leafContext); } @@ -96,9 +112,14 @@ public double execute(ScoreScript.ExplanationHolder explanationHolder) { */ public static class KNNVectorType extends KNNScoreScript { - public KNNVectorType(Map params, float[] queryValue, String field, - BiFunction scoringMethod, SearchLookup lookup, - LeafReaderContext leafContext) throws IOException { + public KNNVectorType( + Map params, + float[] queryValue, + String field, + BiFunction scoringMethod, + SearchLookup lookup, + LeafReaderContext leafContext + ) throws IOException { super(params, queryValue, field, scoringMethod, lookup, leafContext); } diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScriptFactory.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScriptFactory.java index 84cb62754..b686a20f0 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScriptFactory.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScriptFactory.java @@ -29,17 +29,19 @@ public KNNScoreScriptFactory(Map params, SearchLookup lookup) { this.similaritySpace = getValue(params, "space_type").toString(); this.query = getValue(params, "query_value"); - this.knnScoringSpace = KNNScoringSpaceFactory.create(this.similaritySpace, this.query, - lookup.doc().mapperService().fieldType(this.field)); + this.knnScoringSpace = KNNScoringSpaceFactory.create( + this.similaritySpace, + this.query, + lookup.doc().mapperService().fieldType(this.field) + ); } private Object getValue(Map params, String fieldName) { final Object value = params.get(fieldName); - if (value != null) - return value; + if (value != null) return value; KNNCounter.SCRIPT_QUERY_ERRORS.increment(); - throw new IllegalArgumentException("Missing parameter ["+ fieldName +"]"); + throw new IllegalArgumentException("Missing parameter [" + fieldName + "]"); } @Override diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringScriptEngine.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringScriptEngine.java index ab118111e..2d6b5f69c 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringScriptEngine.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringScriptEngine.java @@ -31,8 +31,7 @@ public FactoryType compile(String name, String code, ScriptContext KNNCounter.SCRIPT_COMPILATIONS.increment(); if (!ScoreScript.CONTEXT.equals(context)) { KNNCounter.SCRIPT_COMPILATION_ERRORS.increment(); - throw new IllegalArgumentException(getType() + " KNN scoring scripts cannot be used for context [" - + context.name + "]"); + throw new IllegalArgumentException(getType() + " KNN scoring scripts cannot be used for context [" + context.name + "]"); } // we use the script "source" as the script identifier if (!SCRIPT_SOURCE.equals(code)) { diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java index 653bb5cef..3e066b7c9 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java @@ -25,7 +25,6 @@ import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.parseToFloatArray; import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.parseToLong; - public interface KNNScoringSpace { /** * Return the correct scoring script for a given query. The scoring script @@ -37,8 +36,7 @@ public interface KNNScoringSpace { * @return ScoreScript for this query * @throws IOException throws IOException if ScoreScript cannot be constructed */ - ScoreScript getScoreScript(Map params, String field, SearchLookup lookup, LeafReaderContext ctx) - throws IOException; + ScoreScript getScoreScript(Map params, String field, SearchLookup lookup, LeafReaderContext ctx) throws IOException; class L2 implements KNNScoringSpace { @@ -53,19 +51,16 @@ class L2 implements KNNScoringSpace { */ public L2(Object query, MappedFieldType fieldType) { if (!isKNNVectorFieldType(fieldType)) { - throw new IllegalArgumentException("Incompatible field_type for l2 space. The field type must " + - "be knn_vector."); + throw new IllegalArgumentException("Incompatible field_type for l2 space. The field type must " + "be knn_vector."); } - this.processedQuery = parseToFloatArray(query, - ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension()); + this.processedQuery = parseToFloatArray(query, ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension()); this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l2Squared(q, v)); } - public ScoreScript getScoreScript(Map params, String field, SearchLookup lookup, - LeafReaderContext ctx) throws IOException { - return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, - ctx); + public ScoreScript getScoreScript(Map params, String field, SearchLookup lookup, LeafReaderContext ctx) + throws IOException { + return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx); } } @@ -83,21 +78,17 @@ class CosineSimilarity implements KNNScoringSpace { */ public CosineSimilarity(Object query, MappedFieldType fieldType) { if (!isKNNVectorFieldType(fieldType)) { - throw new IllegalArgumentException("Incompatible field_type for cosine space. The field type must " + - "be knn_vector."); + throw new IllegalArgumentException("Incompatible field_type for cosine space. The field type must " + "be knn_vector."); } - this.processedQuery = parseToFloatArray(query, - ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension()); + this.processedQuery = parseToFloatArray(query, ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension()); float qVectorSquaredMagnitude = getVectorMagnitudeSquared(this.processedQuery); - this.scoringMethod = (float[] q, float[] v) -> 1 + KNNScoringUtil.cosinesimilOptimized(q, v, - qVectorSquaredMagnitude); + this.scoringMethod = (float[] q, float[] v) -> 1 + KNNScoringUtil.cosinesimilOptimized(q, v, qVectorSquaredMagnitude); } - public ScoreScript getScoreScript(Map params, String field, SearchLookup lookup, - LeafReaderContext ctx) throws IOException { - return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, - ctx); + public ScoreScript getScoreScript(Map params, String field, SearchLookup lookup, LeafReaderContext ctx) + throws IOException { + return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx); } } @@ -119,24 +110,36 @@ public HammingBit(Object query, MappedFieldType fieldType) { this.scoringMethod = (Long q, Long v) -> 1.0f / (1 + KNNScoringUtil.calculateHammingBit(q, v)); } else if (isBinaryFieldType(fieldType)) { this.processedQuery = parseToBigInteger(query); - this.scoringMethod = (BigInteger q, BigInteger v) -> - 1.0f / (1 + KNNScoringUtil.calculateHammingBit(q, v)); + this.scoringMethod = (BigInteger q, BigInteger v) -> 1.0f / (1 + KNNScoringUtil.calculateHammingBit(q, v)); } else { - throw new IllegalArgumentException("Incompatible field_type for hamming space. The field type must " + - "of type long or binary."); + throw new IllegalArgumentException( + "Incompatible field_type for hamming space. The field type must " + "of type long or binary." + ); } } @SuppressWarnings("unchecked") - public ScoreScript getScoreScript(Map params, String field, SearchLookup lookup, - LeafReaderContext ctx) throws IOException { + public ScoreScript getScoreScript(Map params, String field, SearchLookup lookup, LeafReaderContext ctx) + throws IOException { if (this.processedQuery instanceof Long) { - return new KNNScoreScript.LongType(params, (Long) this.processedQuery, field, - (BiFunction) this.scoringMethod, lookup, ctx); + return new KNNScoreScript.LongType( + params, + (Long) this.processedQuery, + field, + (BiFunction) this.scoringMethod, + lookup, + ctx + ); } - return new KNNScoreScript.BigIntegerType(params, (BigInteger) this.processedQuery, field, - (BiFunction) this.scoringMethod, lookup, ctx); + return new KNNScoreScript.BigIntegerType( + params, + (BigInteger) this.processedQuery, + field, + (BiFunction) this.scoringMethod, + lookup, + ctx + ); } } @@ -153,19 +156,16 @@ class L1 implements KNNScoringSpace { */ public L1(Object query, MappedFieldType fieldType) { if (!isKNNVectorFieldType(fieldType)) { - throw new IllegalArgumentException("Incompatible field_type for l1 space. The field type must " + - "be knn_vector."); + throw new IllegalArgumentException("Incompatible field_type for l1 space. The field type must " + "be knn_vector."); } - this.processedQuery = parseToFloatArray(query, - ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension()); + this.processedQuery = parseToFloatArray(query, ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension()); this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l1Norm(q, v)); } - public ScoreScript getScoreScript(Map params, String field, SearchLookup lookup, - LeafReaderContext ctx) throws IOException { - return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, - ctx); + public ScoreScript getScoreScript(Map params, String field, SearchLookup lookup, LeafReaderContext ctx) + throws IOException { + return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx); } } @@ -182,19 +182,16 @@ class LInf implements KNNScoringSpace { */ public LInf(Object query, MappedFieldType fieldType) { if (!isKNNVectorFieldType(fieldType)) { - throw new IllegalArgumentException("Incompatible field_type for l-inf space. The field type must " + - "be knn_vector."); + throw new IllegalArgumentException("Incompatible field_type for l-inf space. The field type must " + "be knn_vector."); } - this.processedQuery = parseToFloatArray(query, - ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension()); + this.processedQuery = parseToFloatArray(query, ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension()); this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.lInfNorm(q, v)); } - public ScoreScript getScoreScript(Map params, String field, SearchLookup lookup, - LeafReaderContext ctx) throws IOException { - return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, - ctx); + public ScoreScript getScoreScript(Map params, String field, SearchLookup lookup, LeafReaderContext ctx) + throws IOException { + return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx); } } @@ -211,19 +208,19 @@ class InnerProd implements KNNScoringSpace { */ public InnerProd(Object query, MappedFieldType fieldType) { if (!isKNNVectorFieldType(fieldType)) { - throw new IllegalArgumentException("Incompatible field_type for innerproduct space. The field type must " + - "be knn_vector."); + throw new IllegalArgumentException( + "Incompatible field_type for innerproduct space. The field type must " + "be knn_vector." + ); } - this.processedQuery = parseToFloatArray(query, - ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension()); + this.processedQuery = parseToFloatArray(query, ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension()); this.scoringMethod = (float[] q, float[] v) -> KNNWeight.normalizeScore(-KNNScoringUtil.innerProduct(q, v)); } @Override - public ScoreScript getScoreScript(Map params, String field, SearchLookup lookup, LeafReaderContext ctx) throws IOException { - return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, - ctx); + public ScoreScript getScoreScript(Map params, String field, SearchLookup lookup, LeafReaderContext ctx) + throws IOException { + return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx); } } } diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java index 47d10f84b..4d64b5b96 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java @@ -27,7 +27,7 @@ public class KNNScoringSpaceUtil { */ public static boolean isLongFieldType(MappedFieldType fieldType) { return fieldType instanceof NumberFieldMapper.NumberFieldType - && ((NumberFieldMapper.NumberFieldType) fieldType).numericType() == LONG.numericType(); + && ((NumberFieldMapper.NumberFieldType) fieldType).numericType() == LONG.numericType(); } /** @@ -89,8 +89,9 @@ public static float[] parseToFloatArray(Object object, int expectedDimensions) { float[] floatArray = convertVectorToPrimitive(object); if (expectedDimensions != floatArray.length) { KNNCounter.SCRIPT_QUERY_ERRORS.increment(); - throw new IllegalStateException("Object's dimension=" + floatArray.length + " does not match the " + - "expected dimension=" + expectedDimensions + "."); + throw new IllegalStateException( + "Object's dimension=" + floatArray.length + " does not match the " + "expected dimension=" + expectedDimensions + "." + ); } return floatArray; } diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java index 7904fdad0..5ec462933 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java @@ -10,7 +10,6 @@ import org.apache.logging.log4j.Logger; import java.math.BigInteger; -import java.lang.Math; import java.util.List; import java.util.Objects; @@ -28,14 +27,15 @@ private static void requireEqualDimension(final float[] queryVector, final float Objects.requireNonNull(queryVector); Objects.requireNonNull(inputVector); if (queryVector.length != inputVector.length) { - String errorMessage = String.format("query vector dimension mismatch. Expected: %d, Given: %d", - inputVector.length, queryVector.length); + String errorMessage = String.format( + "query vector dimension mismatch. Expected: %d, Given: %d", + inputVector.length, + queryVector.length + ); throw new IllegalArgumentException(errorMessage); } } - - /** * This method calculates L2 squared distance between query vector * and input vector @@ -126,8 +126,7 @@ public static float cosinesimilOptimized(float[] queryVector, float[] inputVecto * @param queryVectorMagnitude the magnitude of the query vector. * @return cosine score */ - public static float cosineSimilarity( - List queryVector, KNNVectorScriptDocValues docValues, Number queryVectorMagnitude) { + public static float cosineSimilarity(List queryVector, KNNVectorScriptDocValues docValues, Number queryVectorMagnitude) { return cosinesimilOptimized(toFloat(queryVector), docValues.getValue(), queryVectorMagnitude.floatValue()); } @@ -176,7 +175,6 @@ public static float cosineSimilarity(List queryVector, KNNVectorScriptDo return cosinesimil(toFloat(queryVector), docValues.getValue()); } - /** * This method calculates hamming distance on 2 BigIntegers * diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNWhitelistExtension.java b/src/main/java/org/opensearch/knn/plugin/script/KNNWhitelistExtension.java index ccf9d8274..52fedac6e 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNWhitelistExtension.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNWhitelistExtension.java @@ -17,8 +17,7 @@ public class KNNWhitelistExtension implements PainlessExtension { - private static final Whitelist WHITELIST = - WhitelistLoader.loadFromResourceFiles(KNNWhitelistExtension.class, "knn_whitelist.txt"); + private static final Whitelist WHITELIST = WhitelistLoader.loadFromResourceFiles(KNNWhitelistExtension.class, "knn_whitelist.txt"); @Override public Map, List> getContextWhitelists() { diff --git a/src/main/java/org/opensearch/knn/plugin/stats/KNNCounter.java b/src/main/java/org/opensearch/knn/plugin/stats/KNNCounter.java index 47b1d3414..d933ce66d 100644 --- a/src/main/java/org/opensearch/knn/plugin/stats/KNNCounter.java +++ b/src/main/java/org/opensearch/knn/plugin/stats/KNNCounter.java @@ -68,4 +68,4 @@ public void increment() { public void set(long value) { count.set(value); } -} \ No newline at end of file +} diff --git a/src/main/java/org/opensearch/knn/plugin/stats/KNNStat.java b/src/main/java/org/opensearch/knn/plugin/stats/KNNStat.java index b8b2df179..230b55881 100644 --- a/src/main/java/org/opensearch/knn/plugin/stats/KNNStat.java +++ b/src/main/java/org/opensearch/knn/plugin/stats/KNNStat.java @@ -30,7 +30,9 @@ public KNNStat(Boolean clusterLevel, Supplier supplier) { * * @return boolean that is true if the stat is clusterLevel; false otherwise */ - public Boolean isClusterLevel() { return clusterLevel; } + public Boolean isClusterLevel() { + return clusterLevel; + } /** * Get the value of the statistic @@ -40,4 +42,4 @@ public KNNStat(Boolean clusterLevel, Supplier supplier) { public T getValue() { return supplier.get(); } -} \ No newline at end of file +} diff --git a/src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java b/src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java index 32ea5e816..a2b35083a 100644 --- a/src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java +++ b/src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java @@ -75,4 +75,4 @@ private Map> getClusterOrNodeStats(Boolean getClusterStats) { } return statsMap; } -} \ No newline at end of file +} diff --git a/src/main/java/org/opensearch/knn/plugin/stats/KNNStatsConfig.java b/src/main/java/org/opensearch/knn/plugin/stats/KNNStatsConfig.java index bf1b60d9f..56089ed84 100644 --- a/src/main/java/org/opensearch/knn/plugin/stats/KNNStatsConfig.java +++ b/src/main/java/org/opensearch/knn/plugin/stats/KNNStatsConfig.java @@ -25,65 +25,65 @@ import java.util.Map; public class KNNStatsConfig { - public static Map> KNN_STATS = ImmutableMap.>builder() - .put(StatNames.HIT_COUNT.getName(), new KNNStat<>(false, - new KNNInnerCacheStatsSupplier(CacheStats::hitCount))) - .put(StatNames.MISS_COUNT.getName(), new KNNStat<>(false, - new KNNInnerCacheStatsSupplier(CacheStats::missCount))) - .put(StatNames.LOAD_SUCCESS_COUNT.getName(), new KNNStat<>(false, - new KNNInnerCacheStatsSupplier(CacheStats::loadSuccessCount))) - .put(StatNames.LOAD_EXCEPTION_COUNT.getName(), new KNNStat<>(false, - new KNNInnerCacheStatsSupplier(CacheStats::loadExceptionCount))) - .put(StatNames.TOTAL_LOAD_TIME.getName(), new KNNStat<>(false, - new KNNInnerCacheStatsSupplier(CacheStats::totalLoadTime))) - .put(StatNames.EVICTION_COUNT.getName(), new KNNStat<>(false, - new KNNInnerCacheStatsSupplier(CacheStats::evictionCount))) - .put(StatNames.GRAPH_MEMORY_USAGE.getName(), new KNNStat<>(false, - new NativeMemoryCacheManagerSupplier<>(NativeMemoryCacheManager::getIndicesSizeInKilobytes))) - .put(StatNames.GRAPH_MEMORY_USAGE_PERCENTAGE.getName(), new KNNStat<>(false, - new NativeMemoryCacheManagerSupplier<>(NativeMemoryCacheManager::getIndicesSizeAsPercentage))) - .put(StatNames.INDICES_IN_CACHE.getName(), new KNNStat<>(false, - new NativeMemoryCacheManagerSupplier<>(NativeMemoryCacheManager::getIndicesCacheStats))) - .put(StatNames.CACHE_CAPACITY_REACHED.getName(), new KNNStat<>(false, - new NativeMemoryCacheManagerSupplier<>(NativeMemoryCacheManager::isCacheCapacityReached))) - .put(StatNames.GRAPH_QUERY_ERRORS.getName(), new KNNStat<>(false, - new KNNCounterSupplier(KNNCounter.GRAPH_QUERY_ERRORS))) - .put(StatNames.GRAPH_QUERY_REQUESTS.getName(), new KNNStat<>(false, - new KNNCounterSupplier(KNNCounter.GRAPH_QUERY_REQUESTS))) - .put(StatNames.GRAPH_INDEX_ERRORS.getName(), new KNNStat<>(false, - new KNNCounterSupplier(KNNCounter.GRAPH_INDEX_ERRORS))) - .put(StatNames.GRAPH_INDEX_REQUESTS.getName(), new KNNStat<>(false, - new KNNCounterSupplier(KNNCounter.GRAPH_INDEX_REQUESTS))) - .put(StatNames.CIRCUIT_BREAKER_TRIGGERED.getName(), new KNNStat<>(true, - new KNNCircuitBreakerSupplier())) - .put(StatNames.MODEL_INDEX_STATUS.getName(), new KNNStat<>(true, - new ModelIndexStatusSupplier<>(ModelDao::getHealthStatus))) - .put(StatNames.KNN_QUERY_REQUESTS.getName(), new KNNStat<>(false, - new KNNCounterSupplier(KNNCounter.KNN_QUERY_REQUESTS))) - .put(StatNames.SCRIPT_COMPILATIONS.getName(), new KNNStat<>(false, - new KNNCounterSupplier(KNNCounter.SCRIPT_COMPILATIONS))) - .put(StatNames.SCRIPT_COMPILATION_ERRORS.getName(), new KNNStat<>(false, - new KNNCounterSupplier(KNNCounter.SCRIPT_COMPILATION_ERRORS))) - .put(StatNames.SCRIPT_QUERY_REQUESTS.getName(), new KNNStat<>(false, - new KNNCounterSupplier(KNNCounter.SCRIPT_QUERY_REQUESTS))) - .put(StatNames.SCRIPT_QUERY_ERRORS.getName(), new KNNStat<>(false, - new KNNCounterSupplier(KNNCounter.SCRIPT_QUERY_ERRORS))) - .put(StatNames.INDEXING_FROM_MODEL_DEGRADED.getName(), new KNNStat<>(false, - new EventOccurredWithinThresholdSupplier( - new ModelIndexingDegradingSupplier(ModelCache::getEvictedDueToSizeAt), - KNNConstants.MODEL_CACHE_CAPACITY_ATROPHY_THRESHOLD_IN_MINUTES, - ChronoUnit.MINUTES))) - .put(StatNames.FAISS_LOADED.getName(), new KNNStat<>(false, - new LibraryInitializedSupplier(KNNEngine.FAISS))) - .put(StatNames.NMSLIB_LOADED.getName(), new KNNStat<>(false, - new LibraryInitializedSupplier(KNNEngine.NMSLIB))) - .put(StatNames.TRAINING_REQUESTS.getName(), new KNNStat<>(false, - new KNNCounterSupplier(KNNCounter.TRAINING_REQUESTS))) - .put(StatNames.TRAINING_ERRORS.getName(), new KNNStat<>(false, - new KNNCounterSupplier(KNNCounter.TRAINING_ERRORS))) - .put(StatNames.TRAINING_MEMORY_USAGE.getName(), new KNNStat<>(false, - new NativeMemoryCacheManagerSupplier<>(NativeMemoryCacheManager::getTrainingSizeInKilobytes))) - .put(StatNames.TRAINING_MEMORY_USAGE_PERCENTAGE.getName(), new KNNStat<>(false, - new NativeMemoryCacheManagerSupplier<>(NativeMemoryCacheManager::getTrainingSizeAsPercentage))) - .build(); + public static Map> KNN_STATS = ImmutableMap.>builder() + .put(StatNames.HIT_COUNT.getName(), new KNNStat<>(false, new KNNInnerCacheStatsSupplier(CacheStats::hitCount))) + .put(StatNames.MISS_COUNT.getName(), new KNNStat<>(false, new KNNInnerCacheStatsSupplier(CacheStats::missCount))) + .put(StatNames.LOAD_SUCCESS_COUNT.getName(), new KNNStat<>(false, new KNNInnerCacheStatsSupplier(CacheStats::loadSuccessCount))) + .put(StatNames.LOAD_EXCEPTION_COUNT.getName(), new KNNStat<>(false, new KNNInnerCacheStatsSupplier(CacheStats::loadExceptionCount))) + .put(StatNames.TOTAL_LOAD_TIME.getName(), new KNNStat<>(false, new KNNInnerCacheStatsSupplier(CacheStats::totalLoadTime))) + .put(StatNames.EVICTION_COUNT.getName(), new KNNStat<>(false, new KNNInnerCacheStatsSupplier(CacheStats::evictionCount))) + .put( + StatNames.GRAPH_MEMORY_USAGE.getName(), + new KNNStat<>(false, new NativeMemoryCacheManagerSupplier<>(NativeMemoryCacheManager::getIndicesSizeInKilobytes)) + ) + .put( + StatNames.GRAPH_MEMORY_USAGE_PERCENTAGE.getName(), + new KNNStat<>(false, new NativeMemoryCacheManagerSupplier<>(NativeMemoryCacheManager::getIndicesSizeAsPercentage)) + ) + .put( + StatNames.INDICES_IN_CACHE.getName(), + new KNNStat<>(false, new NativeMemoryCacheManagerSupplier<>(NativeMemoryCacheManager::getIndicesCacheStats)) + ) + .put( + StatNames.CACHE_CAPACITY_REACHED.getName(), + new KNNStat<>(false, new NativeMemoryCacheManagerSupplier<>(NativeMemoryCacheManager::isCacheCapacityReached)) + ) + .put(StatNames.GRAPH_QUERY_ERRORS.getName(), new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.GRAPH_QUERY_ERRORS))) + .put(StatNames.GRAPH_QUERY_REQUESTS.getName(), new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.GRAPH_QUERY_REQUESTS))) + .put(StatNames.GRAPH_INDEX_ERRORS.getName(), new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.GRAPH_INDEX_ERRORS))) + .put(StatNames.GRAPH_INDEX_REQUESTS.getName(), new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.GRAPH_INDEX_REQUESTS))) + .put(StatNames.CIRCUIT_BREAKER_TRIGGERED.getName(), new KNNStat<>(true, new KNNCircuitBreakerSupplier())) + .put(StatNames.MODEL_INDEX_STATUS.getName(), new KNNStat<>(true, new ModelIndexStatusSupplier<>(ModelDao::getHealthStatus))) + .put(StatNames.KNN_QUERY_REQUESTS.getName(), new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.KNN_QUERY_REQUESTS))) + .put(StatNames.SCRIPT_COMPILATIONS.getName(), new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.SCRIPT_COMPILATIONS))) + .put( + StatNames.SCRIPT_COMPILATION_ERRORS.getName(), + new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.SCRIPT_COMPILATION_ERRORS)) + ) + .put(StatNames.SCRIPT_QUERY_REQUESTS.getName(), new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.SCRIPT_QUERY_REQUESTS))) + .put(StatNames.SCRIPT_QUERY_ERRORS.getName(), new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.SCRIPT_QUERY_ERRORS))) + .put( + StatNames.INDEXING_FROM_MODEL_DEGRADED.getName(), + new KNNStat<>( + false, + new EventOccurredWithinThresholdSupplier( + new ModelIndexingDegradingSupplier(ModelCache::getEvictedDueToSizeAt), + KNNConstants.MODEL_CACHE_CAPACITY_ATROPHY_THRESHOLD_IN_MINUTES, + ChronoUnit.MINUTES + ) + ) + ) + .put(StatNames.FAISS_LOADED.getName(), new KNNStat<>(false, new LibraryInitializedSupplier(KNNEngine.FAISS))) + .put(StatNames.NMSLIB_LOADED.getName(), new KNNStat<>(false, new LibraryInitializedSupplier(KNNEngine.NMSLIB))) + .put(StatNames.TRAINING_REQUESTS.getName(), new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.TRAINING_REQUESTS))) + .put(StatNames.TRAINING_ERRORS.getName(), new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.TRAINING_ERRORS))) + .put( + StatNames.TRAINING_MEMORY_USAGE.getName(), + new KNNStat<>(false, new NativeMemoryCacheManagerSupplier<>(NativeMemoryCacheManager::getTrainingSizeInKilobytes)) + ) + .put( + StatNames.TRAINING_MEMORY_USAGE_PERCENTAGE.getName(), + new KNNStat<>(false, new NativeMemoryCacheManagerSupplier<>(NativeMemoryCacheManager::getTrainingSizeAsPercentage)) + ) + .build(); } diff --git a/src/main/java/org/opensearch/knn/plugin/stats/StatNames.java b/src/main/java/org/opensearch/knn/plugin/stats/StatNames.java index 0f9e97e61..f807c3eaa 100644 --- a/src/main/java/org/opensearch/knn/plugin/stats/StatNames.java +++ b/src/main/java/org/opensearch/knn/plugin/stats/StatNames.java @@ -43,14 +43,18 @@ public enum StatNames { private String name; - StatNames(String name) { this.name = name; } + StatNames(String name) { + this.name = name; + } /** * Get stat name * * @return name */ - public String getName() { return name; } + public String getName() { + return name; + } /** * Get all stat names @@ -65,4 +69,4 @@ public static Set getNames() { } return names; } -} \ No newline at end of file +} diff --git a/src/main/java/org/opensearch/knn/plugin/stats/suppliers/EventOccurredWithinThresholdSupplier.java b/src/main/java/org/opensearch/knn/plugin/stats/suppliers/EventOccurredWithinThresholdSupplier.java index 16e5cdc25..d6440f013 100644 --- a/src/main/java/org/opensearch/knn/plugin/stats/suppliers/EventOccurredWithinThresholdSupplier.java +++ b/src/main/java/org/opensearch/knn/plugin/stats/suppliers/EventOccurredWithinThresholdSupplier.java @@ -40,10 +40,10 @@ public EventOccurredWithinThresholdSupplier(Supplier supplier, long thr public Boolean get() { Instant lastSeenAt = supplier.get(); - if (lastSeenAt == null) //Event never happened + if (lastSeenAt == null) // Event never happened return false; Instant expiringAt = lastSeenAt.plus(threshold, unit); - //if expiration is greater than current instant, then event occurred + // if expiration is greater than current instant, then event occurred if (expiringAt.compareTo(Instant.now()) > 0) { return true; } diff --git a/src/main/java/org/opensearch/knn/plugin/stats/suppliers/KNNCircuitBreakerSupplier.java b/src/main/java/org/opensearch/knn/plugin/stats/suppliers/KNNCircuitBreakerSupplier.java index 170b84613..32b78e7cc 100644 --- a/src/main/java/org/opensearch/knn/plugin/stats/suppliers/KNNCircuitBreakerSupplier.java +++ b/src/main/java/org/opensearch/knn/plugin/stats/suppliers/KNNCircuitBreakerSupplier.java @@ -4,6 +4,7 @@ */ package org.opensearch.knn.plugin.stats.suppliers; + import org.opensearch.knn.index.KNNSettings; import java.util.function.Supplier; @@ -22,4 +23,4 @@ public KNNCircuitBreakerSupplier() {} public Boolean get() { return KNNSettings.isCircuitBreakerTriggered(); } -} \ No newline at end of file +} diff --git a/src/main/java/org/opensearch/knn/plugin/stats/suppliers/KNNCounterSupplier.java b/src/main/java/org/opensearch/knn/plugin/stats/suppliers/KNNCounterSupplier.java index 5a8128273..d0c66e41d 100644 --- a/src/main/java/org/opensearch/knn/plugin/stats/suppliers/KNNCounterSupplier.java +++ b/src/main/java/org/opensearch/knn/plugin/stats/suppliers/KNNCounterSupplier.java @@ -28,4 +28,4 @@ public KNNCounterSupplier(KNNCounter knnCounter) { public Long get() { return knnCounter.getCount(); } -} \ No newline at end of file +} diff --git a/src/main/java/org/opensearch/knn/plugin/stats/suppliers/KNNInnerCacheStatsSupplier.java b/src/main/java/org/opensearch/knn/plugin/stats/suppliers/KNNInnerCacheStatsSupplier.java index b0bce3498..60a61f5f8 100644 --- a/src/main/java/org/opensearch/knn/plugin/stats/suppliers/KNNInnerCacheStatsSupplier.java +++ b/src/main/java/org/opensearch/knn/plugin/stats/suppliers/KNNInnerCacheStatsSupplier.java @@ -30,4 +30,4 @@ public KNNInnerCacheStatsSupplier(Function getter) { public Long get() { return getter.apply(NativeMemoryCacheManager.getInstance().getCacheStats()); } -} \ No newline at end of file +} diff --git a/src/main/java/org/opensearch/knn/plugin/stats/suppliers/ModelIndexStatusSupplier.java b/src/main/java/org/opensearch/knn/plugin/stats/suppliers/ModelIndexStatusSupplier.java index 01f8bec39..e958dcb64 100644 --- a/src/main/java/org/opensearch/knn/plugin/stats/suppliers/ModelIndexStatusSupplier.java +++ b/src/main/java/org/opensearch/knn/plugin/stats/suppliers/ModelIndexStatusSupplier.java @@ -14,11 +14,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.ResourceNotFoundException; -import org.opensearch.cluster.health.ClusterHealthStatus; -import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.indices.ModelDao; -import java.util.Locale; import java.util.function.Function; import java.util.function.Supplier; @@ -37,9 +34,9 @@ public ModelIndexStatusSupplier(Function getter) { @Override public T get() { - try{ + try { return getter.apply(ModelDao.OpenSearchKNNModelDao.getInstance()); - } catch (ResourceNotFoundException e) { //catch to prevent exception to be raised. + } catch (ResourceNotFoundException e) { // catch to prevent exception to be raised. logger.info(e.getMessage()); return null; // to let consumer knows that no value is available for getter. } diff --git a/src/main/java/org/opensearch/knn/plugin/stats/suppliers/ModelIndexingDegradingSupplier.java b/src/main/java/org/opensearch/knn/plugin/stats/suppliers/ModelIndexingDegradingSupplier.java index f53b06c27..f099d9f59 100644 --- a/src/main/java/org/opensearch/knn/plugin/stats/suppliers/ModelIndexingDegradingSupplier.java +++ b/src/main/java/org/opensearch/knn/plugin/stats/suppliers/ModelIndexingDegradingSupplier.java @@ -24,6 +24,7 @@ public class ModelIndexingDegradingSupplier implements Supplier { public ModelIndexingDegradingSupplier(Function getter) { this.getter = getter; } + @Override public Instant get() { return getter.apply(ModelCache.getInstance()); diff --git a/src/main/java/org/opensearch/knn/plugin/stats/suppliers/NativeMemoryCacheManagerSupplier.java b/src/main/java/org/opensearch/knn/plugin/stats/suppliers/NativeMemoryCacheManagerSupplier.java index 460b79ce5..52d52b399 100644 --- a/src/main/java/org/opensearch/knn/plugin/stats/suppliers/NativeMemoryCacheManagerSupplier.java +++ b/src/main/java/org/opensearch/knn/plugin/stats/suppliers/NativeMemoryCacheManagerSupplier.java @@ -35,4 +35,4 @@ public NativeMemoryCacheManagerSupplier(Function ge public T get() { return getter.apply(NativeMemoryCacheManager.getInstance()); } -} \ No newline at end of file +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelAction.java b/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelAction.java index fc880fef0..f58728368 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelAction.java @@ -12,12 +12,10 @@ package org.opensearch.knn.plugin.transport; import org.opensearch.action.ActionType; -import org.opensearch.action.delete.DeleteResponse; import org.opensearch.common.io.stream.Writeable; public class DeleteModelAction extends ActionType { - public static final DeleteModelAction INSTANCE = new DeleteModelAction(); public static final String NAME = "cluster:admin/knn_delete_model_action"; diff --git a/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelRequest.java index 1f1c5315c..792ccc543 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelRequest.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelRequest.java @@ -43,7 +43,7 @@ public void writeTo(StreamOutput out) throws IOException { @Override public ActionRequestValidationException validate() { - if(Strings.hasText(modelID)) { + if (Strings.hasText(modelID)) { return null; } return addValidationError("Model id cannot be empty ", null); diff --git a/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelResponse.java b/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelResponse.java index 6510d3e2e..b6f330d83 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelResponse.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelResponse.java @@ -11,7 +11,6 @@ package org.opensearch.knn.plugin.transport; import org.opensearch.action.ActionResponse; -import org.opensearch.action.DocWriteResponse; import org.opensearch.common.Strings; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; @@ -73,7 +72,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); builder.field(MODEL_ID, getModelID()); builder.field(RESULT, getResult()); - if (Strings.hasText(errorMessage)){ + if (Strings.hasText(errorMessage)) { builder.field(ERROR_MSG, getErrorMessage()); } builder.endObject(); diff --git a/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelTransportAction.java index d61b7b055..ee7f9e939 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelTransportAction.java @@ -12,7 +12,6 @@ package org.opensearch.knn.plugin.transport; import org.opensearch.action.ActionListener; -import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.common.inject.Inject; @@ -22,7 +21,6 @@ public class DeleteModelTransportAction extends HandledTransportAction { - private final ModelDao modelDao; @Inject diff --git a/src/main/java/org/opensearch/knn/plugin/transport/GetModelRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/GetModelRequest.java index 050664d53..e2e03d041 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/GetModelRequest.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/GetModelRequest.java @@ -26,6 +26,7 @@ public class GetModelRequest extends ActionRequest { private String modelID; + /** * Constructor * diff --git a/src/main/java/org/opensearch/knn/plugin/transport/GetModelResponse.java b/src/main/java/org/opensearch/knn/plugin/transport/GetModelResponse.java index 14a9d84dd..6c61befef 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/GetModelResponse.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/GetModelResponse.java @@ -57,7 +57,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return model.toXContent(builder, params); } - @Override public void writeTo(StreamOutput output) throws IOException { model.writeTo(output); diff --git a/src/main/java/org/opensearch/knn/plugin/transport/GetModelTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/GetModelTransportAction.java index 7066f6b81..05cb12742 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/GetModelTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/GetModelTransportAction.java @@ -35,7 +35,6 @@ public GetModelTransportAction(TransportService transportService, ActionFilters this.modelDao = ModelDao.OpenSearchKNNModelDao.getInstance(); } - @Override protected void doExecute(Task task, GetModelRequest request, ActionListener actionListener) { String modelID = request.getModelID(); diff --git a/src/main/java/org/opensearch/knn/plugin/transport/KNNStatsAction.java b/src/main/java/org/opensearch/knn/plugin/transport/KNNStatsAction.java index cb2e21c82..ccafb00d5 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/KNNStatsAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/KNNStatsAction.java @@ -27,4 +27,4 @@ private KNNStatsAction() { public Writeable.Reader getResponseReader() { return KNNStatsResponse::new; } -} \ No newline at end of file +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/KNNStatsNodeRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/KNNStatsNodeRequest.java index d675cc9dc..ba66784f8 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/KNNStatsNodeRequest.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/KNNStatsNodeRequest.java @@ -58,4 +58,4 @@ public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); request.writeTo(out); } -} \ No newline at end of file +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/KNNStatsNodeResponse.java b/src/main/java/org/opensearch/knn/plugin/transport/KNNStatsNodeResponse.java index 8a23614c4..5bdf03ea7 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/KNNStatsNodeResponse.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/KNNStatsNodeResponse.java @@ -30,7 +30,7 @@ public class KNNStatsNodeResponse extends BaseNodeResponse implements ToXContent */ public KNNStatsNodeResponse(StreamInput in) throws IOException { super(in); - this.statsMap = in.readMap(StreamInput::readString, StreamInput::readGenericValue); + this.statsMap = in.readMap(StreamInput::readString, StreamInput::readGenericValue); } /** @@ -86,4 +86,4 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } -} \ No newline at end of file +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/KNNStatsRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/KNNStatsRequest.java index a841e47bd..500c36203 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/KNNStatsRequest.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/KNNStatsRequest.java @@ -100,4 +100,4 @@ public void writeTo(StreamOutput out) throws IOException { out.writeStringCollection(validStats); out.writeStringCollection(statsToBeRetrieved); } -} \ No newline at end of file +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/KNNStatsResponse.java b/src/main/java/org/opensearch/knn/plugin/transport/KNNStatsResponse.java index 8df9e0f3d..0679eefaa 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/KNNStatsResponse.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/KNNStatsResponse.java @@ -45,8 +45,12 @@ public KNNStatsResponse(StreamInput in) throws IOException { * @param failures List of failures from nodes * @param clusterStats Cluster level stats only obtained from a single node */ - public KNNStatsResponse(ClusterName clusterName, List nodes, List failures, - Map clusterStats) { + public KNNStatsResponse( + ClusterName clusterName, + List nodes, + List failures, + Map clusterStats + ) { super(clusterName, nodes, failures); this.clusterStats = clusterStats; } @@ -88,4 +92,4 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.endObject(); return builder; } -} \ No newline at end of file +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/KNNStatsTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/KNNStatsTransportAction.java index 965307680..0189bf88d 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/KNNStatsTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/KNNStatsTransportAction.java @@ -25,8 +25,11 @@ /** * KNNStatsTransportAction contains the logic to extract the stats from the nodes */ -public class KNNStatsTransportAction extends TransportNodesAction { +public class KNNStatsTransportAction extends TransportNodesAction< + KNNStatsRequest, + KNNStatsResponse, + KNNStatsNodeRequest, + KNNStatsNodeResponse> { private KNNStats knnStats; @@ -41,20 +44,32 @@ public class KNNStatsTransportAction extends TransportNodesAction responses, - List failures) { + protected KNNStatsResponse newResponse( + KNNStatsRequest request, + List responses, + List failures + ) { Map clusterStats = new HashMap<>(); Set statsToBeRetrieved = request.getStatsToBeRetrieved(); @@ -65,12 +80,7 @@ protected KNNStatsResponse newResponse(KNNStatsRequest request, List shardFailures) { + public KNNWarmupResponse( + int totalShards, + int successfulShards, + int failedShards, + List shardFailures + ) { super(totalShards, successfulShards, failedShards, shardFailures); } } diff --git a/src/main/java/org/opensearch/knn/plugin/transport/KNNWarmupTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/KNNWarmupTransportAction.java index 689e4682b..79c0315fd 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/KNNWarmupTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/KNNWarmupTransportAction.java @@ -32,18 +32,32 @@ * all shards across the cluster for the given indices. For each shard, shardOperation will be called and the * warmup will take place. */ -public class KNNWarmupTransportAction extends TransportBroadcastByNodeAction { +public class KNNWarmupTransportAction extends TransportBroadcastByNodeAction< + KNNWarmupRequest, + KNNWarmupResponse, + TransportBroadcastByNodeAction.EmptyResult> { public static Logger logger = LogManager.getLogger(KNNWarmupTransportAction.class); private IndicesService indicesService; @Inject - public KNNWarmupTransportAction(ClusterService clusterService, TransportService transportService, IndicesService indicesService, - ActionFilters actionFilters, IndexNameExpressionResolver indexNameExpressionResolver) { - super(KNNWarmupAction.NAME, clusterService, transportService, actionFilters, indexNameExpressionResolver, - KNNWarmupRequest::new, ThreadPool.Names.SEARCH); + public KNNWarmupTransportAction( + ClusterService clusterService, + TransportService transportService, + IndicesService indicesService, + ActionFilters actionFilters, + IndexNameExpressionResolver indexNameExpressionResolver + ) { + super( + KNNWarmupAction.NAME, + clusterService, + transportService, + actionFilters, + indexNameExpressionResolver, + KNNWarmupRequest::new, + ThreadPool.Names.SEARCH + ); this.indicesService = indicesService; } @@ -53,10 +67,15 @@ protected EmptyResult readShardResult(StreamInput in) throws IOException { } @Override - protected KNNWarmupResponse newResponse(KNNWarmupRequest request, int totalShards, int successfulShards, - int failedShards, List emptyResults, - List shardFailures, - ClusterState clusterState) { + protected KNNWarmupResponse newResponse( + KNNWarmupRequest request, + int totalShards, + int successfulShards, + int failedShards, + List emptyResults, + List shardFailures, + ClusterState clusterState + ) { return new KNNWarmupResponse(totalShards, successfulShards, failedShards, shardFailures); } @@ -67,8 +86,9 @@ protected KNNWarmupRequest readRequestFrom(StreamInput in) throws IOException { @Override protected EmptyResult shardOperation(KNNWarmupRequest request, ShardRouting shardRouting) throws IOException { - KNNIndexShard knnIndexShard = new KNNIndexShard(indicesService.indexServiceSafe(shardRouting.shardId() - .getIndex()).getShard(shardRouting.shardId().id())); + KNNIndexShard knnIndexShard = new KNNIndexShard( + indicesService.indexServiceSafe(shardRouting.shardId().getIndex()).getShard(shardRouting.shardId().id()) + ); knnIndexShard.warmup(); return EmptyResult.INSTANCE; } diff --git a/src/main/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheAction.java b/src/main/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheAction.java index 3b8a0e1e3..c2ce43069 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheAction.java @@ -20,8 +20,7 @@ public class RemoveModelFromCacheAction extends ActionType { public static final String NAME = "cluster:admin/knn_remove_model_from_cache_action"; - public static final RemoveModelFromCacheAction INSTANCE = new RemoveModelFromCacheAction(NAME, - RemoveModelFromCacheResponse::new); + public static final RemoveModelFromCacheAction INSTANCE = new RemoveModelFromCacheAction(NAME, RemoveModelFromCacheResponse::new); /** * Constructor @@ -29,8 +28,7 @@ public class RemoveModelFromCacheAction extends ActionType responseReader) { + public RemoveModelFromCacheAction(String name, Writeable.Reader responseReader) { super(name, responseReader); } } diff --git a/src/main/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheResponse.java b/src/main/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheResponse.java index e3b531438..3293573ce 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheResponse.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheResponse.java @@ -32,9 +32,11 @@ public class RemoveModelFromCacheResponse extends BaseNodesResponse nodes, - List failures) { + public RemoveModelFromCacheResponse( + ClusterName clusterName, + List nodes, + List failures + ) { super(clusterName, nodes, failures); } @@ -45,8 +47,7 @@ public RemoveModelFromCacheResponse(ClusterName clusterName, * @throws IOException thrown when input stream cannot be read */ public RemoveModelFromCacheResponse(StreamInput in) throws IOException { - super(new ClusterName(in), in.readList(RemoveModelFromCacheNodeResponse::new), - in.readList(FailedNodeException::new)); + super(new ClusterName(in), in.readList(RemoveModelFromCacheNodeResponse::new), in.readList(FailedNodeException::new)); } @Override diff --git a/src/main/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheTransportAction.java index ff5dd8c7c..92938ed3c 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheTransportAction.java @@ -29,26 +29,40 @@ /** * Transport action to remove models from some or all nodes in the clusters caches */ -public class RemoveModelFromCacheTransportAction extends - TransportNodesAction { +public class RemoveModelFromCacheTransportAction extends TransportNodesAction< + RemoveModelFromCacheRequest, + RemoveModelFromCacheResponse, + RemoveModelFromCacheNodeRequest, + RemoveModelFromCacheNodeResponse> { private static Logger logger = LogManager.getLogger(RemoveModelFromCacheTransportAction.class); @Inject - public RemoveModelFromCacheTransportAction(ThreadPool threadPool, - ClusterService clusterService, - TransportService transportService, - ActionFilters actionFilters) { - super(RemoveModelFromCacheAction.NAME, threadPool, clusterService, transportService, actionFilters, - RemoveModelFromCacheRequest::new, RemoveModelFromCacheNodeRequest::new, - ThreadPool.Names.SAME, RemoveModelFromCacheNodeResponse.class); + public RemoveModelFromCacheTransportAction( + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters + ) { + super( + RemoveModelFromCacheAction.NAME, + threadPool, + clusterService, + transportService, + actionFilters, + RemoveModelFromCacheRequest::new, + RemoveModelFromCacheNodeRequest::new, + ThreadPool.Names.SAME, + RemoveModelFromCacheNodeResponse.class + ); } @Override - protected RemoveModelFromCacheResponse newResponse(RemoveModelFromCacheRequest nodesRequest, - List responses, - List failures) { + protected RemoveModelFromCacheResponse newResponse( + RemoveModelFromCacheRequest nodesRequest, + List responses, + List failures + ) { return new RemoveModelFromCacheResponse(clusterService.getClusterName(), responses, failures); } @@ -64,8 +78,7 @@ protected RemoveModelFromCacheNodeResponse newNodeResponse(StreamInput in) throw @Override protected RemoveModelFromCacheNodeResponse nodeOperation(RemoveModelFromCacheNodeRequest nodeRequest) { - logger.debug("[KNN] Removing model \"" + nodeRequest.getModelId() + "\" on node \"" + - clusterService.localNode().getId() + "."); + logger.debug("[KNN] Removing model \"" + nodeRequest.getModelId() + "\" on node \"" + clusterService.localNode().getId() + "."); ModelCache.getInstance().remove(nodeRequest.getModelId()); return new RemoveModelFromCacheNodeResponse(clusterService.localNode()); } diff --git a/src/main/java/org/opensearch/knn/plugin/transport/SearchModelTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/SearchModelTransportAction.java index ed5411b3a..4d9f67059 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/SearchModelTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/SearchModelTransportAction.java @@ -27,10 +27,7 @@ public class SearchModelTransportAction extends HandledTransportAction { public static final String NAME = "cluster:admin/knn_training_job_route_decision_info_action"; - public static final TrainingJobRouteDecisionInfoAction INSTANCE = new TrainingJobRouteDecisionInfoAction(NAME, - TrainingJobRouteDecisionInfoResponse::new); + public static final TrainingJobRouteDecisionInfoAction INSTANCE = new TrainingJobRouteDecisionInfoAction( + NAME, + TrainingJobRouteDecisionInfoResponse::new + ); /** * Constructor. @@ -30,8 +32,7 @@ public class TrainingJobRouteDecisionInfoAction extends ActionType responseReader) { + public TrainingJobRouteDecisionInfoAction(String name, Writeable.Reader responseReader) { super(name, responseReader); } } diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouteDecisionInfoResponse.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouteDecisionInfoResponse.java index 2fae3f2a0..4fe50410c 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouteDecisionInfoResponse.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouteDecisionInfoResponse.java @@ -29,7 +29,8 @@ * Aggregated response for training job route decision info */ public class TrainingJobRouteDecisionInfoResponse extends BaseNodesResponse - implements ToXContentObject { + implements + ToXContentObject { /** * Constructor @@ -48,9 +49,11 @@ public TrainingJobRouteDecisionInfoResponse(StreamInput in) throws IOException { * @param nodes List of KNNStatsNodeResponses * @param failures List of failures from nodes */ - public TrainingJobRouteDecisionInfoResponse(ClusterName clusterName, - List nodes, - List failures) { + public TrainingJobRouteDecisionInfoResponse( + ClusterName clusterName, + List nodes, + List failures + ) { super(clusterName, nodes, failures); } diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouteDecisionInfoTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouteDecisionInfoTransportAction.java index 7f29a5f52..407036c86 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouteDecisionInfoTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouteDecisionInfoTransportAction.java @@ -28,9 +28,11 @@ * Broadcasts request to collect training job route decision info from all nodes and aggregates it into a single * response. */ -public class TrainingJobRouteDecisionInfoTransportAction extends - TransportNodesAction { +public class TrainingJobRouteDecisionInfoTransportAction extends TransportNodesAction< + TrainingJobRouteDecisionInfoRequest, + TrainingJobRouteDecisionInfoResponse, + TrainingJobRouteDecisionInfoNodeRequest, + TrainingJobRouteDecisionInfoNodeResponse> { /** * Constructor * @@ -41,25 +43,31 @@ public class TrainingJobRouteDecisionInfoTransportAction extends */ @Inject public TrainingJobRouteDecisionInfoTransportAction( - ThreadPool threadPool, - ClusterService clusterService, - TransportService transportService, - ActionFilters actionFilters + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters ) { - super(TrainingJobRouteDecisionInfoAction.NAME, threadPool, clusterService, transportService, actionFilters, - TrainingJobRouteDecisionInfoRequest::new, TrainingJobRouteDecisionInfoNodeRequest::new, - ThreadPool.Names.MANAGEMENT, TrainingJobRouteDecisionInfoNodeResponse.class); + super( + TrainingJobRouteDecisionInfoAction.NAME, + threadPool, + clusterService, + transportService, + actionFilters, + TrainingJobRouteDecisionInfoRequest::new, + TrainingJobRouteDecisionInfoNodeRequest::new, + ThreadPool.Names.MANAGEMENT, + TrainingJobRouteDecisionInfoNodeResponse.class + ); } @Override - protected TrainingJobRouteDecisionInfoResponse newResponse(TrainingJobRouteDecisionInfoRequest request, - List responses, - List failures) { - return new TrainingJobRouteDecisionInfoResponse( - clusterService.getClusterName(), - responses, - failures - ); + protected TrainingJobRouteDecisionInfoResponse newResponse( + TrainingJobRouteDecisionInfoRequest request, + List responses, + List failures + ) { + return new TrainingJobRouteDecisionInfoResponse(clusterService.getClusterName(), responses, failures); } @Override @@ -74,7 +82,6 @@ protected TrainingJobRouteDecisionInfoNodeResponse newNodeResponse(StreamInput i @Override protected TrainingJobRouteDecisionInfoNodeResponse nodeOperation(TrainingJobRouteDecisionInfoNodeRequest request) { - return new TrainingJobRouteDecisionInfoNodeResponse(clusterService.localNode(), - TrainingJobRunner.getInstance().getJobCount()); + return new TrainingJobRouteDecisionInfoNodeResponse(clusterService.localNode(), TrainingJobRunner.getInstance().getJobCount()); } } diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterAction.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterAction.java index 401f924aa..d6a263a4b 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterAction.java @@ -20,8 +20,7 @@ public class TrainingJobRouterAction extends ActionType { public static final String NAME = "cluster:admin/knn_training_job_router_action"; - public static final TrainingJobRouterAction INSTANCE = new TrainingJobRouterAction(NAME, - TrainingModelResponse::new); + public static final TrainingJobRouterAction INSTANCE = new TrainingJobRouterAction(NAME, TrainingModelResponse::new); private TrainingJobRouterAction(String name, Writeable.Reader trainingModelResponseReader) { super(name, trainingModelResponseReader); diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java index 0a24c0c47..774029c58 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java @@ -28,8 +28,6 @@ import org.opensearch.transport.TransportRequestOptions; import org.opensearch.transport.TransportService; -import java.util.concurrent.RejectedExecutionException; - import static org.opensearch.knn.common.KNNConstants.BYTES_PER_KILOBYTES; import static org.opensearch.search.internal.SearchContext.DEFAULT_TERMINATE_AFTER; @@ -43,9 +41,12 @@ public class TrainingJobRouterTransportAction extends HandledTransportAction listener) { + protected void doExecute(Task task, TrainingModelRequest request, ActionListener listener) { // Get the size of the training request and then route the request. We get/set this here, as opposed to in // TrainingModelTransportAction, because in the future, we may want to use size to factor into our routing // decision. @@ -66,20 +66,27 @@ protected void doExecute(Task task, TrainingModelRequest request, protected void routeRequest(TrainingModelRequest request, ActionListener listener) { // Pick a node and then use the transport service to forward the request - client.execute(TrainingJobRouteDecisionInfoAction.INSTANCE, new TrainingJobRouteDecisionInfoRequest(), - ActionListener.wrap(response -> { - DiscoveryNode node = selectNode(request.getPreferredNodeId(), response); - - if (node == null) { - ValidationException exception = new ValidationException(); - exception.addValidationError("Cluster does not have capacity to train"); - listener.onFailure(exception); - return; - } - - transportService.sendRequest(node, TrainingModelAction.NAME, request, TransportRequestOptions.EMPTY, - new ActionListenerResponseHandler<>(listener, TrainingModelResponse::new)); - }, listener::onFailure) + client.execute( + TrainingJobRouteDecisionInfoAction.INSTANCE, + new TrainingJobRouteDecisionInfoRequest(), + ActionListener.wrap(response -> { + DiscoveryNode node = selectNode(request.getPreferredNodeId(), response); + + if (node == null) { + ValidationException exception = new ValidationException(); + exception.addValidationError("Cluster does not have capacity to train"); + listener.onFailure(exception); + return; + } + + transportService.sendRequest( + node, + TrainingModelAction.NAME, + request, + TransportRequestOptions.EMPTY, + new ActionListenerResponseHandler<>(listener, TrainingModelResponse::new) + ); + }, listener::onFailure) ); } @@ -93,7 +100,7 @@ protected DiscoveryNode selectNode(String preferredNode, TrainingJobRouteDecisio for (TrainingJobRouteDecisionInfoNodeResponse response : jobInfo.getNodes()) { currentNode = response.getNode(); - if(!eligibleNodes.containsKey(currentNode.getId())) { + if (!eligibleNodes.containsKey(currentNode.getId())) { continue; } @@ -138,6 +145,6 @@ protected void getTrainingIndexSizeInKB(TrainingModelRequest trainingModelReques */ public static int estimateVectorSetSizeInKB(long vectorCount, int dimension) { // Ensure we do not overflow the int on estimate - return Math.toIntExact(((Float.BYTES * dimension * vectorCount) / BYTES_PER_KILOBYTES ) + 1L); + return Math.toIntExact(((Float.BYTES * dimension * vectorCount) / BYTES_PER_KILOBYTES) + 1L); } } diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelAction.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelAction.java index 80bac9d32..b82107407 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelAction.java @@ -17,8 +17,7 @@ public class TrainingModelAction extends ActionType { public static final String NAME = "cluster:admin/knn_training_model_action"; - public static final TrainingModelAction INSTANCE = new TrainingModelAction(NAME, - TrainingModelResponse::new); + public static final TrainingModelAction INSTANCE = new TrainingModelAction(NAME, TrainingModelResponse::new); private TrainingModelAction(String name, Writeable.Reader trainingModelResponseReader) { super(name, trainingModelResponseReader); diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java index 1958dbcf8..ce9397905 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java @@ -57,8 +57,15 @@ public class TrainingModelRequest extends ActionRequest { * @param preferredNodeId Preferred node to execute training on. If null, the plugin will select the node. * @param description User provided description of their model */ - public TrainingModelRequest(String modelId, KNNMethodContext knnMethodContext, int dimension, String trainingIndex, - String trainingField, String preferredNodeId, String description) { + public TrainingModelRequest( + String modelId, + KNNMethodContext knnMethodContext, + int dimension, + String trainingIndex, + String trainingField, + String preferredNodeId, + String description + ) { super(); this.modelId = modelId; this.knnMethodContext = knnMethodContext; @@ -188,8 +195,9 @@ public int getMaximumVectorCount() { */ public void setMaximumVectorCount(int maximumVectorCount) { if (maximumVectorCount <= 0) { - throw new IllegalArgumentException(String.format("Maximum vector count %d is invalid. Maximum vector " + - "count must be greater than 0", maximumVectorCount)); + throw new IllegalArgumentException( + String.format("Maximum vector count %d is invalid. Maximum vector " + "count must be greater than 0", maximumVectorCount) + ); } this.maximumVectorCount = maximumVectorCount; } @@ -211,8 +219,9 @@ public int getSearchSize() { */ public void setSearchSize(int searchSize) { if (searchSize <= 0 || searchSize > 10000) { - throw new IllegalArgumentException(String.format("Search size %d is invalid. Search size must be " + - "between 0 and 10,000", searchSize)); + throw new IllegalArgumentException( + String.format("Search size %d is invalid. Search size must be " + "between 0 and 10,000", searchSize) + ); } this.searchSize = searchSize; } @@ -233,8 +242,9 @@ public int getTrainingDataSizeInKB() { */ void setTrainingDataSizeInKB(int trainingDataSizeInKB) { if (trainingDataSizeInKB <= 0) { - throw new IllegalArgumentException(String.format("Training data size %d is invalid. Training data size " + - "must be greater than 0", trainingDataSizeInKB)); + throw new IllegalArgumentException( + String.format("Training data size %d is invalid. Training data size " + "must be greater than 0", trainingDataSizeInKB) + ); } this.trainingDataSizeInKB = trainingDataSizeInKB; } @@ -271,8 +281,7 @@ public ActionRequestValidationException validate() { // Check if description is too long if (description != null && description.length() > KNNConstants.MAX_MODEL_DESCRIPTION_LENGTH) { exception = exception == null ? new ActionRequestValidationException() : exception; - exception.addValidationError("Description exceeds limit of " + KNNConstants.MAX_MODEL_DESCRIPTION_LENGTH + - " characters"); + exception.addValidationError("Description exceeds limit of " + KNNConstants.MAX_MODEL_DESCRIPTION_LENGTH + " characters"); } // Validate training index exists @@ -284,8 +293,7 @@ public ActionRequestValidationException validate() { } // Validate the training field - ValidationException fieldValidation = IndexUtil.validateKnnField(indexMetadata, this.trainingField, - this.dimension, modelDao); + ValidationException fieldValidation = IndexUtil.validateKnnField(indexMetadata, this.trainingField, this.dimension, modelDao); if (fieldValidation != null) { exception = exception == null ? new ActionRequestValidationException() : exception; exception.addValidationErrors(fieldValidation.validationErrors()); diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java index 2e6b7d701..a3c4be16e 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java @@ -35,43 +35,38 @@ public class TrainingModelTransportAction extends HandledTransportAction listener) { + protected void doExecute(Task task, TrainingModelRequest request, ActionListener listener) { - NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext = - new NativeMemoryEntryContext.TrainingDataEntryContext( - request.getTrainingDataSizeInKB(), - request.getTrainingIndex(), - request.getTrainingField(), - NativeMemoryLoadStrategy.TrainingLoadStrategy.getInstance(), - clusterService, - request.getMaximumVectorCount(), - request.getSearchSize() - ); + NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext = new NativeMemoryEntryContext.TrainingDataEntryContext( + request.getTrainingDataSizeInKB(), + request.getTrainingIndex(), + request.getTrainingField(), + NativeMemoryLoadStrategy.TrainingLoadStrategy.getInstance(), + clusterService, + request.getMaximumVectorCount(), + request.getSearchSize() + ); // Allocation representing size model will occupy in memory during training - NativeMemoryEntryContext.AnonymousEntryContext modelAnonymousEntryContext = - new NativeMemoryEntryContext.AnonymousEntryContext( - request.getKnnMethodContext().estimateOverheadInKB(request.getDimension()), - NativeMemoryLoadStrategy.AnonymousLoadStrategy.getInstance() - ); + NativeMemoryEntryContext.AnonymousEntryContext modelAnonymousEntryContext = new NativeMemoryEntryContext.AnonymousEntryContext( + request.getKnnMethodContext().estimateOverheadInKB(request.getDimension()), + NativeMemoryLoadStrategy.AnonymousLoadStrategy.getInstance() + ); TrainingJob trainingJob = new TrainingJob( - request.getModelId(), - request.getKnnMethodContext(), - NativeMemoryCacheManager.getInstance(), - trainingDataEntryContext, - modelAnonymousEntryContext, - request.getDimension(), - request.getDescription() + request.getModelId(), + request.getKnnMethodContext(), + NativeMemoryCacheManager.getInstance(), + trainingDataEntryContext, + modelAnonymousEntryContext, + request.getDimension(), + request.getDescription() ); KNNCounter.TRAINING_REQUESTS.increment(); @@ -81,10 +76,14 @@ protected void doExecute(Task task, TrainingModelRequest request, }); try { - TrainingJobRunner.getInstance().execute(trainingJob, ActionListener.wrap( - indexResponse -> wrappedListener.onResponse(new TrainingModelResponse(indexResponse.getId())), - wrappedListener::onFailure) - ); + TrainingJobRunner.getInstance() + .execute( + trainingJob, + ActionListener.wrap( + indexResponse -> wrappedListener.onResponse(new TrainingModelResponse(indexResponse.getId())), + wrappedListener::onFailure + ) + ); } catch (IOException e) { wrappedListener.onFailure(e); } diff --git a/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequest.java index 7a49c4e93..d0628a519 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequest.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequest.java @@ -104,7 +104,6 @@ public ModelMetadata getModelMetadata() { return modelMetadata; } - @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); diff --git a/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportAction.java index 54d437869..b0deb93d9 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportAction.java @@ -52,13 +52,22 @@ public class UpdateModelMetadataTransportAction extends TransportMasterNodeActio private UpdateModelMetadataExecutor updateModelMetadataExecutor; @Inject - public UpdateModelMetadataTransportAction(TransportService transportService, - ClusterService clusterService, - ThreadPool threadPool, - ActionFilters actionFilters, - IndexNameExpressionResolver indexNameExpressionResolver) { - super(UpdateModelMetadataAction.NAME, transportService, clusterService, threadPool, actionFilters, - UpdateModelMetadataRequest::new, indexNameExpressionResolver); + public UpdateModelMetadataTransportAction( + TransportService transportService, + ClusterService clusterService, + ThreadPool threadPool, + ActionFilters actionFilters, + IndexNameExpressionResolver indexNameExpressionResolver + ) { + super( + UpdateModelMetadataAction.NAME, + transportService, + clusterService, + threadPool, + actionFilters, + UpdateModelMetadataRequest::new, + indexNameExpressionResolver + ); this.updateModelMetadataExecutor = new UpdateModelMetadataExecutor(); } @@ -73,25 +82,29 @@ protected AcknowledgedResponse read(StreamInput streamInput) throws IOException } @Override - protected void masterOperation(UpdateModelMetadataRequest request, ClusterState clusterState, - ActionListener actionListener) { + protected void masterOperation( + UpdateModelMetadataRequest request, + ClusterState clusterState, + ActionListener actionListener + ) { // Master updates model metadata based on request parameters clusterService.submitStateUpdateTask( - PLUGIN_NAME, - new UpdateModelMetaDataTask(request.getModelId(), request.getModelMetadata(), request.isRemoveRequest()), - ClusterStateTaskConfig.build(Priority.NORMAL), - updateModelMetadataExecutor, - new ClusterStateTaskListener() { - @Override - public void onFailure(String s, Exception e) { - actionListener.onFailure(e); - } - - @Override - public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) { - actionListener.onResponse(new AcknowledgedResponse(true)); - } - }); + PLUGIN_NAME, + new UpdateModelMetaDataTask(request.getModelId(), request.getModelMetadata(), request.isRemoveRequest()), + ClusterStateTaskConfig.build(Priority.NORMAL), + updateModelMetadataExecutor, + new ClusterStateTaskListener() { + @Override + public void onFailure(String s, Exception e) { + actionListener.onFailure(e); + } + + @Override + public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) { + actionListener.onResponse(new AcknowledgedResponse(true)); + } + } + ); } @Override @@ -125,8 +138,7 @@ private static class UpdateModelMetaDataTask { private static class UpdateModelMetadataExecutor implements ClusterStateTaskExecutor { @Override - public ClusterTasksResult execute(ClusterState clusterState, - List list) { + public ClusterTasksResult execute(ClusterState clusterState, List list) { // Get the map of the models metadata IndexMetadata indexMetadata = clusterState.metadata().index(MODEL_INDEX_NAME); diff --git a/src/main/java/org/opensearch/knn/training/TrainingDataConsumer.java b/src/main/java/org/opensearch/knn/training/TrainingDataConsumer.java index 163abc30c..6732bd3f4 100644 --- a/src/main/java/org/opensearch/knn/training/TrainingDataConsumer.java +++ b/src/main/java/org/opensearch/knn/training/TrainingDataConsumer.java @@ -36,7 +36,11 @@ public TrainingDataConsumer(NativeMemoryAllocation.TrainingDataAllocation traini @Override public void accept(List floats) { - trainingDataAllocation.setMemoryAddress(JNIService.transferVectors(trainingDataAllocation.getMemoryAddress(), - floats.stream().map(ArrayUtils::toPrimitive).toArray(float[][]::new))); + trainingDataAllocation.setMemoryAddress( + JNIService.transferVectors( + trainingDataAllocation.getMemoryAddress(), + floats.stream().map(ArrayUtils::toPrimitive).toArray(float[][]::new) + ) + ); } } diff --git a/src/main/java/org/opensearch/knn/training/TrainingJob.java b/src/main/java/org/opensearch/knn/training/TrainingJob.java index 288823fb7..c83c69831 100644 --- a/src/main/java/org/opensearch/knn/training/TrainingJob.java +++ b/src/main/java/org/opensearch/knn/training/TrainingJob.java @@ -58,33 +58,34 @@ public class TrainingJob implements Runnable { * @param dimension model's dimension * @param description user provided description of the model. */ - public TrainingJob(String modelId, KNNMethodContext knnMethodContext, - NativeMemoryCacheManager nativeMemoryCacheManager, - NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext, - NativeMemoryEntryContext.AnonymousEntryContext modelAnonymousEntryContext, - int dimension, String description) { + public TrainingJob( + String modelId, + KNNMethodContext knnMethodContext, + NativeMemoryCacheManager nativeMemoryCacheManager, + NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext, + NativeMemoryEntryContext.AnonymousEntryContext modelAnonymousEntryContext, + int dimension, + String description + ) { // Generate random base64 string if one is not provided this.modelId = Strings.hasText(modelId) ? modelId : UUIDs.randomBase64UUID(); this.knnMethodContext = Objects.requireNonNull(knnMethodContext, "MethodContext cannot be null."); - this.nativeMemoryCacheManager = Objects.requireNonNull(nativeMemoryCacheManager, - "NativeMemoryCacheManager cannot be null."); - this.trainingDataEntryContext = Objects.requireNonNull(trainingDataEntryContext, - "TrainingDataEntryContext cannot be null."); - this.modelAnonymousEntryContext = Objects.requireNonNull(modelAnonymousEntryContext, - "AnonymousEntryContext cannot be null."); + this.nativeMemoryCacheManager = Objects.requireNonNull(nativeMemoryCacheManager, "NativeMemoryCacheManager cannot be null."); + this.trainingDataEntryContext = Objects.requireNonNull(trainingDataEntryContext, "TrainingDataEntryContext cannot be null."); + this.modelAnonymousEntryContext = Objects.requireNonNull(modelAnonymousEntryContext, "AnonymousEntryContext cannot be null."); this.model = new Model( - new ModelMetadata( - knnMethodContext.getEngine(), - knnMethodContext.getSpaceType(), - dimension, - ModelState.TRAINING, - ZonedDateTime.now(ZoneOffset.UTC).toString(), - description, - "" - ), - null, - this.modelId - ); + new ModelMetadata( + knnMethodContext.getEngine(), + knnMethodContext.getSpaceType(), + dimension, + ModelState.TRAINING, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + description, + "" + ), + null, + this.modelId + ); } /** @@ -120,8 +121,9 @@ public void run() { } catch (Exception e) { logger.error("Failed to get training data for model \"" + modelId + "\": " + e.getMessage()); modelMetadata.setState(ModelState.FAILED); - modelMetadata.setError("Failed to load training data into memory. " + - "Check if there is enough memory to perform the request."); + modelMetadata.setError( + "Failed to load training data into memory. " + "Check if there is enough memory to perform the request." + ); if (trainingDataAllocation != null) { nativeMemoryCacheManager.invalidate(trainingDataEntryContext.getKey()); @@ -141,8 +143,9 @@ public void run() { } catch (Exception e) { logger.error("Failed to allocate space in native memory for model \"" + modelId + "\": " + e.getMessage()); modelMetadata.setState(ModelState.FAILED); - modelMetadata.setError("Failed to allocate space in native memory for the model. " + - "Check if there is enough memory to perform the request."); + modelMetadata.setError( + "Failed to allocate space in native memory for the model. " + "Check if there is enough memory to perform the request." + ); trainingDataAllocation.readUnlock(); nativeMemoryCacheManager.invalidate(trainingDataEntryContext.getKey()); @@ -171,14 +174,16 @@ public void run() { } Map trainParameters = model.getModelMetadata().getKnnEngine().getMethodAsMap(knnMethodContext); - trainParameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue( - KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); + trainParameters.put( + KNNConstants.INDEX_THREAD_QTY, + KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY) + ); byte[] modelBlob = JNIService.trainIndex( - trainParameters, - model.getModelMetadata().getDimension(), - trainingDataAllocation.getMemoryAddress(), - model.getModelMetadata().getKnnEngine().getName() + trainParameters, + model.getModelMetadata().getDimension(), + trainingDataAllocation.getMemoryAddress(), + model.getModelMetadata().getKnnEngine().getName() ); // Once training finishes, update model @@ -187,8 +192,9 @@ public void run() { } catch (Exception e) { logger.error("Failed to run training job for model \"" + modelId + "\": " + e.getMessage()); modelMetadata.setState(ModelState.FAILED); - modelMetadata.setError("Failed to execute training. May be caused by an invalid method definition or " + - "not enough memory to perform training."); + modelMetadata.setError( + "Failed to execute training. May be caused by an invalid method definition or " + "not enough memory to perform training." + ); KNNCounter.TRAINING_ERRORS.increment(); diff --git a/src/main/java/org/opensearch/knn/training/TrainingJobRunner.java b/src/main/java/org/opensearch/knn/training/TrainingJobRunner.java index 77fdd2dc1..774500311 100644 --- a/src/main/java/org/opensearch/knn/training/TrainingJobRunner.java +++ b/src/main/java/org/opensearch/knn/training/TrainingJobRunner.java @@ -95,24 +95,17 @@ public void execute(TrainingJob trainingJob, ActionListener liste // Serialize model before training. The model should be in the training state and the model binary should be // null. This notifies users that their model is training, but not yet ready for use. try { - serializeModel( - trainingJob, - ActionListener.wrap( - indexResponse -> { - // Respond to the request with the initial index response - listener.onResponse(indexResponse); - train(trainingJob); - }, - exception -> { - // Serialization failed. Let listener handle the exception, but free up resources. - jobCount.decrementAndGet(); - semaphore.release(); - logger.error("Unable to initialize model serialization: " + exception.getMessage()); - listener.onFailure(exception); - } - ), - false - ); + serializeModel(trainingJob, ActionListener.wrap(indexResponse -> { + // Respond to the request with the initial index response + listener.onResponse(indexResponse); + train(trainingJob); + }, exception -> { + // Serialization failed. Let listener handle the exception, but free up resources. + jobCount.decrementAndGet(); + semaphore.release(); + logger.error("Unable to initialize model serialization: " + exception.getMessage()); + listener.onFailure(exception); + }), false); } catch (IOException ioe) { jobCount.decrementAndGet(); semaphore.release(); @@ -125,13 +118,11 @@ private void train(TrainingJob trainingJob) { // Listener for update model after training index action ActionListener loggingListener = ActionListener.wrap( - indexResponse -> logger.debug("[KNN] Model serialization update for \"" + - trainingJob.getModelId() + "\" was successful"), - e -> { - logger.error("[KNN] Model serialization update for \"" + trainingJob.getModelId() + - "\" failed: " + e.getMessage()); - KNNCounter.TRAINING_ERRORS.increment(); - } + indexResponse -> logger.debug("[KNN] Model serialization update for \"" + trainingJob.getModelId() + "\" was successful"), + e -> { + logger.error("[KNN] Model serialization update for \"" + trainingJob.getModelId() + "\" failed: " + e.getMessage()); + KNNCounter.TRAINING_ERRORS.increment(); + } ); try { @@ -143,8 +134,7 @@ private void train(TrainingJob trainingJob) { logger.error("Unable to serialize model \"" + trainingJob.getModelId() + "\": " + e.getMessage()); KNNCounter.TRAINING_ERRORS.increment(); } catch (Exception e) { - logger.error("Unable to complete training for \"" + trainingJob.getModelId() + "\": " - + e.getMessage()); + logger.error("Unable to complete training for \"" + trainingJob.getModelId() + "\": " + e.getMessage()); KNNCounter.TRAINING_ERRORS.increment(); } finally { jobCount.decrementAndGet(); @@ -170,8 +160,7 @@ private void train(TrainingJob trainingJob) { } } - private void serializeModel(TrainingJob trainingJob, ActionListener listener, boolean update) - throws IOException { + private void serializeModel(TrainingJob trainingJob, ActionListener listener, boolean update) throws IOException { if (update) { modelDao.update(trainingJob.getModel(), listener); } else { diff --git a/src/test/java/org/opensearch/knn/KNNRestTestCase.java b/src/test/java/org/opensearch/knn/KNNRestTestCase.java index 3afac33a8..620fa2ecc 100644 --- a/src/test/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/test/java/org/opensearch/knn/KNNRestTestCase.java @@ -95,8 +95,11 @@ public static void dumpCoverage() throws IOException, MalformedObjectNameExcepti String serverUrl = "service:jmx:rmi:///jndi/rmi://127.0.0.1:7777/jmxrmi"; try (JMXConnector connector = JMXConnectorFactory.connect(new JMXServiceURL(serverUrl))) { IProxy proxy = MBeanServerInvocationHandler.newProxyInstance( - connector.getMBeanServerConnection(), new ObjectName("org.jacoco:type=Runtime"), IProxy.class, - false); + connector.getMBeanServerConnection(), + new ObjectName("org.jacoco:type=Runtime"), + IProxy.class, + false + ); Path path = Paths.get(jacocoBuildPath + "/integTest.exec"); Files.write(path, proxy.getExecutionData(false)); @@ -127,7 +130,8 @@ protected void createKnnIndex(String index, Settings settings, String mapping) t } protected void createBasicKnnIndex(String index, String fieldName, int dimension) throws IOException { - String mapping = Strings.toString(XContentFactory.jsonBuilder() + String mapping = Strings.toString( + XContentFactory.jsonBuilder() .startObject() .startObject("properties") .startObject(fieldName) @@ -135,7 +139,8 @@ protected void createBasicKnnIndex(String index, String fieldName, int dimension .field("dimension", Integer.toString(dimension)) .endObject() .endObject() - .endObject()); + .endObject() + ); mapping = mapping.substring(1, mapping.length() - 1); createIndex(index, Settings.EMPTY, mapping); @@ -144,16 +149,12 @@ protected void createBasicKnnIndex(String index, String fieldName, int dimension /** * Run KNN Search on Index */ - protected Response searchKNNIndex(String index, KNNQueryBuilder knnQueryBuilder, int resultSize) throws - IOException { + protected Response searchKNNIndex(String index, KNNQueryBuilder knnQueryBuilder, int resultSize) throws IOException { XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject("query"); knnQueryBuilder.doXContent(builder, ToXContent.EMPTY_PARAMS); builder.endObject().endObject(); - Request request = new Request( - "POST", - "/" + index + "/_search" - ); + Request request = new Request("POST", "/" + index + "/_search"); request.addParameter("size", Integer.toString(resultSize)); request.addParameter("explain", Boolean.toString(true)); @@ -161,8 +162,7 @@ protected Response searchKNNIndex(String index, KNNQueryBuilder knnQueryBuilder, request.setJsonEntity(Strings.toString(builder)); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); return response; } @@ -170,13 +170,9 @@ protected Response searchKNNIndex(String index, KNNQueryBuilder knnQueryBuilder, /** * Run exists search */ - protected Response searchExists(String index, ExistsQueryBuilder existsQueryBuilder, int resultSize) throws - IOException { + protected Response searchExists(String index, ExistsQueryBuilder existsQueryBuilder, int resultSize) throws IOException { - Request request = new Request( - "POST", - "/" + index + "/_search" - ); + Request request = new Request("POST", "/" + index + "/_search"); XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject("query"); builder = XContentFactory.jsonBuilder().startObject(); @@ -187,8 +183,7 @@ protected Response searchExists(String index, ExistsQueryBuilder existsQueryBuil request.setJsonEntity(Strings.toString(builder)); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); return response; } @@ -198,32 +193,30 @@ protected Response searchExists(String index, ExistsQueryBuilder existsQueryBuil */ protected List parseSearchResponse(String responseBody, String fieldName) throws IOException { @SuppressWarnings("unchecked") - List hits = (List) ((Map) createParser(XContentType.JSON.xContent(), - responseBody).map().get("hits")).get("hits"); + List hits = (List) ((Map) createParser(XContentType.JSON.xContent(), responseBody).map() + .get("hits")).get("hits"); @SuppressWarnings("unchecked") List knnSearchResponses = hits.stream().map(hit -> { - @SuppressWarnings("unchecked") - Float[] vector = Arrays.stream( - ((ArrayList) ((Map) - ((Map) hit).get("_source")).get(fieldName)).toArray()) - .map(Object::toString) - .map(Float::valueOf) - .toArray(Float[]::new); - return new KNNResult((String) ((Map) hit).get("_id"), vector); - } - ).collect(Collectors.toList()); + @SuppressWarnings("unchecked") + Float[] vector = Arrays.stream( + ((ArrayList) ((Map) ((Map) hit).get("_source")).get(fieldName)).toArray() + ).map(Object::toString).map(Float::valueOf).toArray(Float[]::new); + return new KNNResult((String) ((Map) hit).get("_id"), vector); + }).collect(Collectors.toList()); return knnSearchResponses; } + protected List parseSearchResponseScore(String responseBody, String fieldName) throws IOException { @SuppressWarnings("unchecked") - List hits = (List) ((Map) createParser(XContentType.JSON.xContent(), - responseBody).map().get("hits")).get("hits"); + List hits = (List) ((Map) createParser(XContentType.JSON.xContent(), responseBody).map() + .get("hits")).get("hits"); @SuppressWarnings("unchecked") - List knnSearchResponses = hits.stream().map(hit -> - ((Double) ((Map) hit).get("_score")).floatValue()).collect(Collectors.toList()); + List knnSearchResponses = hits.stream() + .map(hit -> ((Double) ((Map) hit).get("_score")).floatValue()) + .collect(Collectors.toList()); return knnSearchResponses; } @@ -236,14 +229,10 @@ protected List parseSearchResponseScore(String responseBody, String field * Delete KNN index */ protected void deleteKNNIndex(String index) throws IOException { - Request request = new Request( - "DELETE", - "/" + index - ); + Request request = new Request("DELETE", "/" + index); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); } /** @@ -251,29 +240,28 @@ protected void deleteKNNIndex(String index) throws IOException { */ protected void putMappingRequest(String index, String mapping) throws IOException { // Put KNN mapping - Request request = new Request( - "PUT", - "/" + index + "/_mapping" - ); + Request request = new Request("PUT", "/" + index + "/_mapping"); request.setJsonEntity(mapping); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); } /** * Utility to create a Knn Index Mapping */ protected String createKnnIndexMapping(String fieldName, Integer dimensions) throws IOException { - return Strings.toString(XContentFactory.jsonBuilder().startObject() - .startObject("properties") - .startObject(fieldName) - .field("type", "knn_vector") - .field("dimension", dimensions.toString()) - .endObject() - .endObject() - .endObject()); + return Strings.toString( + XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field("dimension", dimensions.toString()) + .endObject() + .endObject() + .endObject() + ); } /** @@ -295,7 +283,6 @@ protected String createKnnIndexMapping(List fieldNames, List di return Strings.toString(xContentBuilder); } - /** * Get index mapping as map * @@ -304,34 +291,25 @@ protected String createKnnIndexMapping(List fieldNames, List di */ @SuppressWarnings("unchecked") public Map getIndexMappingAsMap(String index) throws IOException { - Request request = new Request( - "GET", - "/" + index + "/_mapping" - ); + Request request = new Request("GET", "/" + index + "/_mapping"); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); String responseBody = EntityUtils.toString(response.getEntity()); Map responseMap = createParser(XContentType.JSON.xContent(), responseBody).map(); - return (Map) ((Map) responseMap.get(index)).get("mappings"); } public int getDocCount(String indexName) throws IOException { - Request request = new Request( - "GET", - "/" + indexName + "/_count" - ); + Request request = new Request("GET", "/" + indexName + "/_count"); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); String responseBody = EntityUtils.toString(response.getEntity()); @@ -343,25 +321,17 @@ public int getDocCount(String indexName) throws IOException { * Force merge KNN index segments */ protected void forceMergeKnnIndex(String index) throws Exception { - Request request = new Request( - "POST", - "/" + index + "/_refresh" - ); + Request request = new Request("POST", "/" + index + "/_refresh"); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); - request = new Request( - "POST", - "/" + index + "/_forcemerge" - ); + request = new Request("POST", "/" + index + "/_forcemerge"); request.addParameter("max_num_segments", "1"); request.addParameter("flush", "true"); response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); TimeUnit.SECONDS.sleep(5); // To make sure force merge is completed } @@ -369,34 +339,22 @@ protected void forceMergeKnnIndex(String index) throws Exception { * Add a single KNN Doc to an index */ protected void addKnnDoc(String index, String docId, String fieldName, Object[] vector) throws IOException { - Request request = new Request( - "POST", - "/" + index + "/_doc/" + docId + "?refresh=true" - ); + Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true"); - XContentBuilder builder = XContentFactory.jsonBuilder().startObject() - .field(fieldName, vector) - .endObject(); + XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field(fieldName, vector).endObject(); request.setJsonEntity(Strings.toString(builder)); client().performRequest(request); - request = new Request( - "POST", - "/" + index + "/_refresh" - ); + request = new Request("POST", "/" + index + "/_refresh"); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); } /** * Add a single KNN Doc to an index with multiple fields */ protected void addKnnDoc(String index, String docId, List fieldNames, List vectors) throws IOException { - Request request = new Request( - "POST", - "/" + index + "/_doc/" + docId + "?refresh=true" - ); + Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true"); XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); for (int i = 0; i < fieldNames.size(); i++) { @@ -406,72 +364,51 @@ protected void addKnnDoc(String index, String docId, List fieldNames, Li request.setJsonEntity(Strings.toString(builder)); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.CREATED, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.CREATED, RestStatus.fromCode(response.getStatusLine().getStatusCode())); } /** * Add a single numeric field Doc to an index */ protected void addDocWithNumericField(String index, String docId, String fieldName, long value) throws IOException { - Request request = new Request( - "POST", - "/" + index + "/_doc/" + docId + "?refresh=true" - ); + Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true"); - XContentBuilder builder = XContentFactory.jsonBuilder().startObject() - .field(fieldName, value) - .endObject(); + XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field(fieldName, value).endObject(); request.setJsonEntity(Strings.toString(builder)); Response response = client().performRequest(request); - - assertEquals(request.getEndpoint() + ": failed", RestStatus.CREATED, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.CREATED, RestStatus.fromCode(response.getStatusLine().getStatusCode())); } /** * Add a single numeric field Doc to an index */ - protected void addDocWithBinaryField(String index, String docId, String fieldName, String base64String) - throws IOException { - Request request = new Request( - "POST", - "/" + index + "/_doc/" + docId + "?refresh=true" - ); + protected void addDocWithBinaryField(String index, String docId, String fieldName, String base64String) throws IOException { + Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true"); - XContentBuilder builder = XContentFactory.jsonBuilder().startObject() - .field(fieldName, base64String) - .endObject(); + XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field(fieldName, base64String).endObject(); request.setJsonEntity(Strings.toString(builder)); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.CREATED, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.CREATED, RestStatus.fromCode(response.getStatusLine().getStatusCode())); } /** * Update a KNN Doc with a new vector for the given fieldName */ protected void updateKnnDoc(String index, String docId, String fieldName, Object[] vector) throws IOException { - Request request = new Request( - "POST", - "/" + index + "/_doc/" + docId + "?refresh=true" - ); + Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true"); - XContentBuilder builder = XContentFactory.jsonBuilder().startObject() - .field(fieldName, vector) - .endObject(); + XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field(fieldName, vector).endObject(); request.setJsonEntity(Strings.toString(builder)); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); } /** @@ -479,28 +416,21 @@ protected void updateKnnDoc(String index, String docId, String fieldName, Object */ protected void deleteKnnDoc(String index, String docId) throws IOException { // Put KNN mapping - Request request = new Request( - "DELETE", - "/" + index + "/_doc/" + docId + "?refresh" - ); + Request request = new Request("DELETE", "/" + index + "/_doc/" + docId + "?refresh"); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); } /** * Retrieve document by index and document id */ protected Map getKnnDoc(final String index, final String docId) throws IOException { - final Request request = new Request( - "GET", - "/" + index + "/_doc/" + docId - ); + final Request request = new Request("GET", "/" + index + "/_doc/" + docId); final Response response = client().performRequest(request); - final Map responseMap = - createParser(XContentType.JSON.xContent(), EntityUtils.toString(response.getEntity())).map(); + final Map responseMap = createParser(XContentType.JSON.xContent(), EntityUtils.toString(response.getEntity())) + .map(); assertNotNull(responseMap); assertTrue((Boolean) responseMap.get(DOCUMENT_FIELD_FOUND)); @@ -531,11 +461,7 @@ protected void updateClusterSettings(String settingKey, Object value) throws Exc * Return default index settings for index creation */ protected Settings getKNNDefaultIndexSettings() { - return Settings.builder() - .put("number_of_shards", 1) - .put("number_of_replicas", 0) - .put("index.knn", true) - .build(); + return Settings.builder().put("number_of_shards", 1).put("number_of_replicas", 0).put("index.knn", true).build(); } /** @@ -545,8 +471,7 @@ protected Response getKnnStats(List nodeIds, List stats) throws return executeKnnStatRequest(nodeIds, stats, KNNPlugin.KNN_BASE_URI); } - protected Response executeKnnStatRequest( - List nodeIds, List stats, final String baseURI) throws IOException { + protected Response executeKnnStatRequest(List nodeIds, List stats, final String baseURI) throws IOException { String nodePrefix = ""; if (!nodeIds.isEmpty()) { nodePrefix = "/" + String.join(",", nodeIds); @@ -577,7 +502,6 @@ protected Response executeWarmupRequest(List indices, final String baseU return client().performRequest(request); } - /** * Parse KNN Cluster stats from response */ @@ -594,13 +518,13 @@ protected Map parseClusterStatsResponse(String responseBody) thr */ protected List> parseNodeStatsResponse(String responseBody) throws IOException { @SuppressWarnings("unchecked") - Map responseMap = (Map) createParser(XContentType.JSON.xContent(), - responseBody).map().get("nodes"); + Map responseMap = (Map) createParser(XContentType.JSON.xContent(), responseBody).map().get("nodes"); @SuppressWarnings("unchecked") - List> nodeResponses = responseMap.keySet().stream().map(key -> - (Map) responseMap.get(key) - ).collect(Collectors.toList()); + List> nodeResponses = responseMap.keySet() + .stream() + .map(key -> (Map) responseMap.get(key)) + .collect(Collectors.toList()); return nodeResponses; } @@ -610,10 +534,8 @@ protected List> parseNodeStatsResponse(String responseBody) */ @SuppressWarnings("unchecked") protected int parseTotalSearchHits(String searchResponseBody) throws IOException { - Map responseMap = (Map) createParser( - XContentType.JSON.xContent(), - searchResponseBody - ).map().get("hits"); + Map responseMap = (Map) createParser(XContentType.JSON.xContent(), searchResponseBody).map() + .get("hits"); return (int) ((Map) responseMap.get("total")).get("value"); } @@ -633,8 +555,9 @@ protected int getTotalGraphsInCache() throws IOException { return nodesStats.stream() .filter(nodeStats -> nodeStats.get(INDICES_IN_CACHE.getName()) != null) .map(nodeStats -> nodeStats.get(INDICES_IN_CACHE.getName())) - .mapToInt(nodeIndicesStats -> - ((Map>) nodeIndicesStats).values().stream() + .mapToInt( + nodeIndicesStats -> ((Map>) nodeIndicesStats).values() + .stream() .mapToInt(nodeIndexStats -> (int) nodeIndexStats.get(GRAPH_COUNT)) .sum() ) @@ -646,9 +569,9 @@ protected int getTotalGraphsInCache() throws IOException { */ protected String getIndexSettingByName(String indexName, String settingName) throws IOException { @SuppressWarnings("unchecked") - Map settings = - (Map) ((Map) getIndexSettings(indexName).get(indexName)) - .get("settings"); + Map settings = (Map) ((Map) getIndexSettings(indexName).get(indexName)).get( + "settings" + ); return (String) settings.get(settingName); } @@ -661,39 +584,33 @@ protected void createModelSystemIndex() throws IOException { String mapping = Resources.toString(url, Charsets.UTF_8); mapping = mapping.substring(1, mapping.length() - 1); - createIndex(MODEL_INDEX_NAME, Settings.builder() - .put("number_of_shards", 1) - .put("number_of_replicas", 0).build(), - mapping); + createIndex(MODEL_INDEX_NAME, Settings.builder().put("number_of_shards", 1).put("number_of_replicas", 0).build(), mapping); } protected void addModelToSystemIndex(String modelId, ModelMetadata modelMetadata, byte[] model) throws IOException { assertFalse(Strings.isNullOrEmpty(modelId)); String modelBase64 = Base64.getEncoder().encodeToString(model); - Request request = new Request( - "POST", - "/" + MODEL_INDEX_NAME + "/_doc/" + modelId + "?refresh=true" - ); + Request request = new Request("POST", "/" + MODEL_INDEX_NAME + "/_doc/" + modelId + "?refresh=true"); - XContentBuilder builder = XContentFactory.jsonBuilder().startObject() - .field(MODEL_ID, modelId) - .field(MODEL_STATE, modelMetadata.getState().getName()) - .field(KNN_ENGINE, modelMetadata.getKnnEngine().getName()) - .field(METHOD_PARAMETER_SPACE_TYPE, modelMetadata.getSpaceType().getValue()) - .field(DIMENSION, modelMetadata.getDimension()) - .field(MODEL_BLOB_PARAMETER, modelBase64) - .field(MODEL_TIMESTAMP, modelMetadata.getTimestamp()) - .field(MODEL_DESCRIPTION, modelMetadata.getDescription()) - .field(MODEL_ERROR, modelMetadata.getError()) - .endObject(); + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .field(MODEL_ID, modelId) + .field(MODEL_STATE, modelMetadata.getState().getName()) + .field(KNN_ENGINE, modelMetadata.getKnnEngine().getName()) + .field(METHOD_PARAMETER_SPACE_TYPE, modelMetadata.getSpaceType().getValue()) + .field(DIMENSION, modelMetadata.getDimension()) + .field(MODEL_BLOB_PARAMETER, modelBase64) + .field(MODEL_TIMESTAMP, modelMetadata.getTimestamp()) + .field(MODEL_DESCRIPTION, modelMetadata.getDescription()) + .field(MODEL_ERROR, modelMetadata.getError()) + .endObject(); request.setJsonEntity(Strings.toString(builder)); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.CREATED, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.CREATED, RestStatus.fromCode(response.getStatusLine().getStatusCode())); } /** @@ -719,13 +636,16 @@ protected void clearScriptCache() throws Exception { } protected Request constructScriptQueryRequest( - String indexName, QueryBuilder qb, Map params, String language, String source, int size) - throws Exception { + String indexName, + QueryBuilder qb, + Map params, + String language, + String source, + int size + ) throws Exception { Script script = new Script(Script.DEFAULT_SCRIPT_TYPE, language, source, params); ScriptScoreQueryBuilder sc = new ScriptScoreQueryBuilder(qb, script); - XContentBuilder builder = XContentFactory.jsonBuilder().startObject() - .field("size", size) - .startObject("query"); + XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field("size", size).startObject("query"); builder.startObject("script_score"); builder.field("query"); sc.query().toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -733,16 +653,12 @@ protected Request constructScriptQueryRequest( builder.endObject(); builder.endObject(); builder.endObject(); - Request request = new Request( - "POST", - "/" + indexName + "/_search" - ); + Request request = new Request("POST", "/" + indexName + "/_search"); request.setJsonEntity(Strings.toString(builder)); return request; } - protected Request constructKNNScriptQueryRequest(String indexName, QueryBuilder qb, Map params) - throws Exception { + protected Request constructKNNScriptQueryRequest(String indexName, QueryBuilder qb, Map params) throws Exception { Script script = new Script(Script.DEFAULT_SCRIPT_TYPE, KNNScoringScriptEngine.NAME, KNNScoringScriptEngine.SCRIPT_SOURCE, params); ScriptScoreQueryBuilder sc = new ScriptScoreQueryBuilder(qb, script); XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject("query"); @@ -753,22 +669,18 @@ protected Request constructKNNScriptQueryRequest(String indexName, QueryBuilder builder.endObject(); builder.endObject(); builder.endObject(); - Request request = new Request( - "POST", - "/" + indexName + "/_search" - ); + Request request = new Request("POST", "/" + indexName + "/_search"); request.setJsonEntity(Strings.toString(builder)); return request; } - protected Request constructKNNScriptQueryRequest(String indexName, QueryBuilder qb, Map params, - int size) throws Exception { + protected Request constructKNNScriptQueryRequest(String indexName, QueryBuilder qb, Map params, int size) + throws Exception { return constructScriptQueryRequest(indexName, qb, params, KNNScoringScriptEngine.NAME, KNNScoringScriptEngine.SCRIPT_SOURCE, size); } public Map xContentBuilderToMap(XContentBuilder xContentBuilder) { - return XContentHelper.convertToMap(BytesReference.bytes(xContentBuilder), true, - xContentBuilder.contentType()).v2(); + return XContentHelper.convertToMap(BytesReference.bytes(xContentBuilder), true, xContentBuilder.contentType()).v2(); } public void bulkIngestRandomVectors(String indexName, String fieldName, int numVectors, int dimension) throws IOException { @@ -778,32 +690,29 @@ public void bulkIngestRandomVectors(String indexName, String fieldName, int numV vector[j] = randomFloat(); } - addKnnDoc(indexName, String.valueOf(i+1), fieldName, Floats.asList(vector).toArray()); + addKnnDoc(indexName, String.valueOf(i + 1), fieldName, Floats.asList(vector).toArray()); } } - //Method that adds multiple documents into the index using Bulk API - public void bulkAddKnnDocs(String index, String fieldName, float[][] indexVectors, int docCount) throws IOException { - Request request = new Request( - "POST", - "/_bulk" - ); + // Method that adds multiple documents into the index using Bulk API + public void bulkAddKnnDocs(String index, String fieldName, float[][] indexVectors, int docCount) throws IOException { + Request request = new Request("POST", "/_bulk"); request.addParameter("refresh", "true"); StringBuilder sb = new StringBuilder(); for (int i = 0; i < docCount; i++) { sb.append("{ \"index\" : { \"_index\" : \"") - .append(index) - .append("\", \"_id\" : \"") - .append(i+1) - .append("\" } }\n") - .append("{ \"") - .append(fieldName) - .append("\" : ") - .append(Arrays.toString(indexVectors[i])) - .append(" }\n"); + .append(index) + .append("\", \"_id\" : \"") + .append(i + 1) + .append("\" } }\n") + .append("{ \"") + .append(fieldName) + .append("\" : ") + .append(Arrays.toString(indexVectors[i])) + .append(" }\n"); } request.setJsonEntity(sb.toString()); @@ -812,16 +721,13 @@ public void bulkAddKnnDocs(String index, String fieldName, float[][] indexVecto assertEquals(response.getStatusLine().getStatusCode(), 200); } - //Method that returns index vectors of the documents that were added before into the index + // Method that returns index vectors of the documents that were added before into the index public float[][] getIndexVectorsFromIndex(String testIndex, String testField, int docCount, int dimensions) throws IOException { float[][] vectors = new float[docCount][dimensions]; QueryBuilder qb = new MatchAllQueryBuilder(); - Request request = new Request( - "POST", - "/" + testIndex + "/_search" - ); + Request request = new Request("POST", "/" + testIndex + "/_search"); request.addParameter("size", Integer.toString(docCount)); XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); @@ -830,8 +736,7 @@ public float[][] getIndexVectorsFromIndex(String testIndex, String testField, in request.setJsonEntity(Strings.toString(builder)); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), testField); int i = 0; @@ -851,7 +756,7 @@ public List> bulkSearch(String testIndex, String testField, float[] for (int i = 0; i < queryVectors.length; i++) { KNNQueryBuilder knnQueryBuilderRecall = new KNNQueryBuilder(testField, queryVectors[i], k); - Response respRecall = searchKNNIndex(testIndex, knnQueryBuilderRecall,k); + Response respRecall = searchKNNIndex(testIndex, knnQueryBuilderRecall, k); List resultsRecall = parseSearchResponse(EntityUtils.toString(respRecall.getEntity()), testField); assertEquals(resultsRecall.size(), k); @@ -877,16 +782,23 @@ public List> bulkSearch(String testIndex, String testField, float[] * @return Response returned by the cluster * @throws IOException if request cannot be performed */ - public Response trainModel(String modelId, String trainingIndexName, String trainingFieldName, int dimension, - Map method, String description) throws IOException { + public Response trainModel( + String modelId, + String trainingIndexName, + String trainingFieldName, + int dimension, + Map method, + String description + ) throws IOException { - XContentBuilder builder = XContentFactory.jsonBuilder().startObject() - .field(TRAIN_INDEX_PARAMETER, trainingIndexName) - .field(TRAIN_FIELD_PARAMETER, trainingFieldName) - .field(DIMENSION, dimension) - .field(KNN_METHOD, method) - .field(MODEL_DESCRIPTION, description) - .endObject(); + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .field(TRAIN_INDEX_PARAMETER, trainingIndexName) + .field(TRAIN_FIELD_PARAMETER, trainingFieldName) + .field(DIMENSION, dimension) + .field(KNN_METHOD, method) + .field(MODEL_DESCRIPTION, description) + .endObject(); if (modelId == null) { modelId = ""; @@ -894,10 +806,7 @@ public Response trainModel(String modelId, String trainingIndexName, String trai modelId = "/" + modelId; } - Request request = new Request( - "POST", - "/_plugins/_knn/models" + modelId + "/_train" - ); + Request request = new Request("POST", "/_plugins/_knn/models" + modelId + "/_train"); request.setJsonEntity(Strings.toString(builder)); return client().performRequest(request); } @@ -924,16 +833,12 @@ public Response getModel(String modelId, List filters) throws IOExceptio filterString = "&filter_path=" + StringUtils.join(filters, ","); } - Request request = new Request( - "GET", - "/_plugins/_knn/models" + modelId + filterString - ); + Request request = new Request("GET", "/_plugins/_knn/models" + modelId + filterString); return client().performRequest(request); } - public void assertTrainingSucceeds(String modelId, int attempts, int delayInMillis) throws InterruptedException, - IOException { + public void assertTrainingSucceeds(String modelId, int attempts, int delayInMillis) throws InterruptedException, IOException { int attemptNum = 0; Response response; Map responseMap; @@ -944,10 +849,7 @@ public void assertTrainingSucceeds(String modelId, int attempts, int delayInMill response = getModel(modelId, null); - responseMap = createParser( - XContentType.JSON.xContent(), - EntityUtils.toString(response.getEntity()) - ).map(); + responseMap = createParser(XContentType.JSON.xContent(), EntityUtils.toString(response.getEntity())).map(); modelState = ModelState.getModelState((String) responseMap.get(MODEL_STATE)); if (modelState == ModelState.CREATED) { @@ -960,8 +862,7 @@ public void assertTrainingSucceeds(String modelId, int attempts, int delayInMill fail("Training did not succeed after " + attempts + " attempts with a delay of " + delayInMillis + " ms."); } - public void assertTrainingFails(String modelId, int attempts, int delayInMillis) throws InterruptedException, - IOException { + public void assertTrainingFails(String modelId, int attempts, int delayInMillis) throws InterruptedException, IOException { int attemptNum = 0; Response response; Map responseMap; @@ -972,10 +873,7 @@ public void assertTrainingFails(String modelId, int attempts, int delayInMillis) response = getModel(modelId, null); - responseMap = createParser( - XContentType.JSON.xContent(), - EntityUtils.toString(response.getEntity()) - ).map(); + responseMap = createParser(XContentType.JSON.xContent(), EntityUtils.toString(response.getEntity())).map(); modelState = ModelState.getModelState((String) responseMap.get(MODEL_STATE)); if (modelState == ModelState.FAILED) { diff --git a/src/test/java/org/opensearch/knn/KNNResult.java b/src/test/java/org/opensearch/knn/KNNResult.java index 81a878377..803c2ae72 100644 --- a/src/test/java/org/opensearch/knn/KNNResult.java +++ b/src/test/java/org/opensearch/knn/KNNResult.java @@ -18,7 +18,7 @@ public String getDocId() { return docId; } - public Float[] getVector() { + public Float[] getVector() { return vector; } -} \ No newline at end of file +} diff --git a/src/test/java/org/opensearch/knn/KNNTestCase.java b/src/test/java/org/opensearch/knn/KNNTestCase.java index 4ac6e700c..d5ac75287 100644 --- a/src/test/java/org/opensearch/knn/KNNTestCase.java +++ b/src/test/java/org/opensearch/knn/KNNTestCase.java @@ -36,7 +36,6 @@ public static void resetState() { } public Map xContentBuilderToMap(XContentBuilder xContentBuilder) { - return XContentHelper.convertToMap(BytesReference.bytes(xContentBuilder), true, - xContentBuilder.contentType()).v2(); + return XContentHelper.convertToMap(BytesReference.bytes(xContentBuilder), true, xContentBuilder.contentType()).v2(); } -} \ No newline at end of file +} diff --git a/src/test/java/org/opensearch/knn/ODFERestTestCase.java b/src/test/java/org/opensearch/knn/ODFERestTestCase.java index c0b206b53..30dd63df4 100644 --- a/src/test/java/org/opensearch/knn/ODFERestTestCase.java +++ b/src/test/java/org/opensearch/knn/ODFERestTestCase.java @@ -81,28 +81,27 @@ protected static void configureHttpsClient(RestClientBuilder builder, Settings s } builder.setDefaultHeaders(defaultHeaders); builder.setHttpClientConfigCallback(httpClientBuilder -> { - String userName = Optional - .ofNullable(System.getProperty("user")) - .orElseThrow(() -> new RuntimeException("user name is missing")); - String password = Optional - .ofNullable(System.getProperty("password")) - .orElseThrow(() -> new RuntimeException("password is missing")); + String userName = Optional.ofNullable(System.getProperty("user")) + .orElseThrow(() -> new RuntimeException("user name is missing")); + String password = Optional.ofNullable(System.getProperty("password")) + .orElseThrow(() -> new RuntimeException("password is missing")); CredentialsProvider credentialsProvider = new BasicCredentialsProvider(); credentialsProvider.setCredentials(AuthScope.ANY, new UsernamePasswordCredentials(userName, password)); try { - return httpClientBuilder - .setDefaultCredentialsProvider(credentialsProvider) - // disable the certificate since our testing cluster just uses the default security configuration - .setSSLHostnameVerifier(NoopHostnameVerifier.INSTANCE) - .setSSLContext(SSLContextBuilder.create().loadTrustMaterial(null, (chains, authType) -> true).build()); + return httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider) + // disable the certificate since our testing cluster just uses the default security configuration + .setSSLHostnameVerifier(NoopHostnameVerifier.INSTANCE) + .setSSLContext(SSLContextBuilder.create().loadTrustMaterial(null, (chains, authType) -> true).build()); } catch (Exception e) { throw new RuntimeException(e); } }); final String socketTimeoutString = settings.get(CLIENT_SOCKET_TIMEOUT); - final TimeValue socketTimeout = TimeValue - .parseTimeValue(socketTimeoutString == null ? "60s" : socketTimeoutString, CLIENT_SOCKET_TIMEOUT); + final TimeValue socketTimeout = TimeValue.parseTimeValue( + socketTimeoutString == null ? "60s" : socketTimeoutString, + CLIENT_SOCKET_TIMEOUT + ); builder.setRequestConfigCallback(conf -> conf.setSocketTimeout(Math.toIntExact(socketTimeout.getMillis()))); if (settings.hasValue(CLIENT_PATH_PREFIX)) { builder.setPathPrefix(settings.get(CLIENT_PATH_PREFIX)); @@ -123,13 +122,12 @@ protected void wipeAllODFEIndices() throws IOException { Response response = client().performRequest(new Request("GET", "/_cat/indices?format=json&expand_wildcards=all")); XContentType xContentType = XContentType.fromMediaTypeOrFormat(response.getEntity().getContentType().getValue()); try ( - XContentParser parser = xContentType - .xContent() - .createParser( - NamedXContentRegistry.EMPTY, - DeprecationHandler.THROW_UNSUPPORTED_OPERATION, - response.getEntity().getContent() - ) + XContentParser parser = xContentType.xContent() + .createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + response.getEntity().getContent() + ) ) { XContentParser.Token token = parser.nextToken(); List> parserList = null; @@ -148,12 +146,11 @@ protected void wipeAllODFEIndices() throws IOException { } } - private boolean skipDeleteIndex(String indexName){ - if (indexName != null && !OPENDISTRO_SECURITY.equals(indexName) && !indexName.matches(KNN_BWC_PREFIX+"(.*)")){ + private boolean skipDeleteIndex(String indexName) { + if (indexName != null && !OPENDISTRO_SECURITY.equals(indexName) && !indexName.matches(KNN_BWC_PREFIX + "(.*)")) { return false; } return true; } } - diff --git a/src/test/java/org/opensearch/knn/TestUtils.java b/src/test/java/org/opensearch/knn/TestUtils.java index f4968a255..92bf2f740 100644 --- a/src/test/java/org/opensearch/knn/TestUtils.java +++ b/src/test/java/org/opensearch/knn/TestUtils.java @@ -21,7 +21,6 @@ import java.io.IOException; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.plugin.script.KNNScoringUtil; -import org.opensearch.knn.plugin.stats.suppliers.ModelIndexStatusSupplier; import java.util.Comparator; import java.util.Random; import java.util.Set; @@ -36,7 +35,7 @@ class DistVector { public float dist; public String docID; - public DistVector (float dist, String docID) { + public DistVector(float dist, String docID) { this.dist = dist; this.docID = docID; } @@ -117,10 +116,10 @@ public static List> computeGroundTruthValues(float[][] indexVectors, } if (pq.size() < k) { - pq.add(new DistVector(dist, String.valueOf(j+1))); + pq.add(new DistVector(dist, String.valueOf(j + 1))); } else if (pq.peek().getDist() > dist) { pq.poll(); - pq.add(new DistVector(dist, String.valueOf(j+1))); + pq.add(new DistVector(dist, String.valueOf(j + 1))); } } @@ -137,7 +136,7 @@ public static List> computeGroundTruthValues(float[][] indexVectors, public static float[][] getQueryVectors(int queryCount, int dimensions, int docCount, boolean isStandard) { if (isStandard) { - return randomlyGenerateStandardVectors(queryCount, dimensions, docCount+1); + return randomlyGenerateStandardVectors(queryCount, dimensions, docCount + 1); } else { return generateRandomVectors(queryCount, dimensions); } @@ -169,8 +168,8 @@ public static double calculateRecallValue(List> searchResults, List recalls.add(recallVal / k); } - double sum = recalls.stream().reduce((a,b)->a+b).get(); - return sum/recalls.size(); + double sum = recalls.stream().reduce((a, b) -> a + b).get(); + return sum / recalls.size(); } /** @@ -192,14 +191,15 @@ private KNNCodecUtil.Pair readIndexData(String path) throws IOException { BufferedReader reader = new BufferedReader(new FileReader(path)); String line = reader.readLine(); while (line != null) { - Map doc = XContentFactory.xContent(XContentType.JSON).createParser( - NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, line).map(); + Map doc = XContentFactory.xContent(XContentType.JSON) + .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, line) + .map(); idsList.add((Integer) doc.get("id")); @SuppressWarnings("unchecked") ArrayList vector = (ArrayList) doc.get("vector"); Float[] floatArray = new Float[vector.size()]; - for (int i =0; i< vector.size(); i++) { + for (int i = 0; i < vector.size(); i++) { floatArray[i] = vector.get(i).floatValue(); } vectorsList.add(floatArray); @@ -208,7 +208,7 @@ private KNNCodecUtil.Pair readIndexData(String path) throws IOException { } reader.close(); - int[] idsArray = new int [idsList.size()]; + int[] idsArray = new int[idsList.size()]; float[][] vectorsArray = new float[vectorsList.size()][vectorsList.get(0).length]; for (int i = 0; i < idsList.size(); i++) { idsArray[i] = idsList.get(i); diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index 91f05f5a6..157eb0ce8 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -9,7 +9,6 @@ * GitHub history for details. */ - package org.opensearch.knn.index; import com.google.common.collect.ImmutableList; @@ -70,24 +69,24 @@ public void testEndToEnd_fromMethod() throws IOException, InterruptedException { // Create an index XContentBuilder builder = XContentFactory.jsonBuilder() - .startObject() - .startObject("properties") - .startObject(fieldName) - .field("type", "knn_vector") - .field("dimension", dimension) - .startObject(KNNConstants.KNN_METHOD) - .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) - .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) - .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) - .startObject(KNNConstants.PARAMETERS) - .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) - .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) - .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) - .endObject() - .endObject() - .endObject() - .endObject() - .endObject(); + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(KNNConstants.PARAMETERS) + .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); Map mappingMap = xContentBuilderToMap(builder); String mapping = Strings.toString(builder); @@ -97,8 +96,12 @@ public void testEndToEnd_fromMethod() throws IOException, InterruptedException { // Index the test data for (int i = 0; i < testData.indexData.docs.length; i++) { - addKnnDoc(indexName, Integer.toString(testData.indexData.docs[i]), fieldName, - Floats.asList(testData.indexData.vectors[i]).toArray()); + addKnnDoc( + indexName, + Integer.toString(testData.indexData.docs[i]), + fieldName, + Floats.asList(testData.indexData.vectors[i]).toArray() + ); } // Assert we have the right number of documents in the index @@ -115,8 +118,11 @@ public void testEndToEnd_fromMethod() throws IOException, InterruptedException { List actualScores = parseSearchResponseScore(responseBody, fieldName); for (int j = 0; j < k; j++) { float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList())); - assertEquals(KNNEngine.FAISS.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), - spaceType), actualScores.get(j), 0.0001); + assertEquals( + KNNEngine.FAISS.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType), + actualScores.get(j), + 0.0001 + ); } } @@ -130,7 +136,7 @@ public void testEndToEnd_fromMethod() throws IOException, InterruptedException { return; } - Thread.sleep(5*1000); + Thread.sleep(5 * 1000); } fail("Graphs are not getting evicted"); @@ -146,28 +152,28 @@ public void testDocUpdate() throws IOException { // Create an index XContentBuilder builder = XContentFactory.jsonBuilder() - .startObject() - .startObject("properties") - .startObject(fieldName) - .field("type", "knn_vector") - .field("dimension", dimension) - .startObject(KNNConstants.KNN_METHOD) - .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) - .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) - .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) - .endObject() - .endObject() - .endObject() - .endObject(); + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .endObject() + .endObject() + .endObject() + .endObject(); String mapping = Strings.toString(builder); createKnnIndex(indexName, mapping); - Float[] vector = {6.0f, 6.0f}; + Float[] vector = { 6.0f, 6.0f }; addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector); // update - Float[] updatedVector = {8.0f, 8.0f}; + Float[] updatedVector = { 8.0f, 8.0f }; updateKnnDoc(INDEX_NAME, "1", FIELD_NAME, updatedVector); } @@ -182,24 +188,24 @@ public void testDocDeletion() throws IOException { // Create an index XContentBuilder builder = XContentFactory.jsonBuilder() - .startObject() - .startObject("properties") - .startObject(fieldName) - .field("type", "knn_vector") - .field("dimension", dimension) - .startObject(KNNConstants.KNN_METHOD) - .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) - .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) - .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) - .endObject() - .endObject() - .endObject() - .endObject(); + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .endObject() + .endObject() + .endObject() + .endObject(); String mapping = Strings.toString(builder); createKnnIndex(indexName, mapping); - Float[] vector = {6.0f, 6.0f}; + Float[] vector = { 6.0f, 6.0f }; addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector); // delete knn doc @@ -219,14 +225,15 @@ public void testEndToEnd_fromModel() throws IOException, InterruptedException { bulkIngestRandomVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension); // Call train API - IVF with nlists = 1 is brute force, but will require training - XContentBuilder builder = XContentFactory.jsonBuilder().startObject() - .field(NAME, "ivf") - .field(KNN_ENGINE, "faiss") - .field(METHOD_PARAMETER_SPACE_TYPE, "l2") - .startObject(PARAMETERS) - .field(METHOD_PARAMETER_NLIST, 1) - .endObject() - .endObject(); + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, "ivf") + .field(KNN_ENGINE, "faiss") + .field(METHOD_PARAMETER_SPACE_TYPE, "l2") + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_NLIST, 1) + .endObject() + .endObject(); Map method = xContentBuilderToMap(builder); trainModel(modelId, trainingIndexName, trainingFieldName, dimension, method, "faiss test description"); @@ -237,14 +244,17 @@ public void testEndToEnd_fromModel() throws IOException, InterruptedException { // Create knn index from model String fieldName = "test-field-name"; String indexName = "test-index-name"; - String indexMapping = Strings.toString(XContentFactory.jsonBuilder().startObject() + String indexMapping = Strings.toString( + XContentFactory.jsonBuilder() + .startObject() .startObject("properties") .startObject(fieldName) .field("type", "knn_vector") .field(MODEL_ID, modelId) .endObject() .endObject() - .endObject()); + .endObject() + ); createKnnIndex(indexName, getKNNDefaultIndexSettings(), indexMapping); diff --git a/src/test/java/org/opensearch/knn/index/IndexUtilTests.java b/src/test/java/org/opensearch/knn/index/IndexUtilTests.java index e90e300f7..7013ef261 100644 --- a/src/test/java/org/opensearch/knn/index/IndexUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/IndexUtilTests.java @@ -47,9 +47,7 @@ public void testGetLoadParameters() { int efSearchValue = 413; // We use the constant for the setting here as opposed to the identifier of efSearch in nmslib jni - Map indexSettings = ImmutableMap.of( - KNN_ALGO_PARAM_EF_SEARCH, efSearchValue - ); + Map indexSettings = ImmutableMap.of(KNN_ALGO_PARAM_EF_SEARCH, efSearchValue); // Because ef search comes from an index setting, we need to mock the long line of calls to get those // index settings diff --git a/src/test/java/org/opensearch/knn/index/KNNCircuitBreakerIT.java b/src/test/java/org/opensearch/knn/index/KNNCircuitBreakerIT.java index 1e0f94c42..86b728ab9 100644 --- a/src/test/java/org/opensearch/knn/index/KNNCircuitBreakerIT.java +++ b/src/test/java/org/opensearch/knn/index/KNNCircuitBreakerIT.java @@ -35,10 +35,10 @@ private void tripCb() throws Exception { // Create index with 1 primary and numNodes-1 replicas so that the data will be on every node in the cluster int numNodes = Integer.parseInt(System.getProperty("cluster.number_of_nodes", "1")); Settings settings = Settings.builder() - .put("number_of_shards", 1) - .put("number_of_replicas", numNodes - 1) - .put("index.knn", true) - .build(); + .put("number_of_shards", 1) + .put("number_of_replicas", numNodes - 1) + .put("index.knn", true) + .build(); String indexName1 = INDEX_NAME + "1"; String indexName2 = INDEX_NAME + "2"; @@ -46,7 +46,7 @@ private void tripCb() throws Exception { createKnnIndex(indexName1, settings, createKnnIndexMapping(FIELD_NAME, 2)); createKnnIndex(indexName2, settings, createKnnIndexMapping(FIELD_NAME, 2)); - Float[] vector = {1.3f, 2.2f}; + Float[] vector = { 1.3f, 2.2f }; int docsInIndex = 5; // through testing, 7 is minimum number of docs to trip circuit breaker at 1kb for (int i = 0; i < docsInIndex; i++) { @@ -58,7 +58,7 @@ private void tripCb() throws Exception { forceMergeKnnIndex(indexName2); // Execute search on both indices - will cause eviction - float[] qvector = {1.9f, 2.4f}; + float[] qvector = { 1.9f, 2.4f }; int k = 10; // Ensure that each shard is searched over so that each Lucene segment gets loaded into memory @@ -68,13 +68,12 @@ private void tripCb() throws Exception { } // Give cluster 5 seconds to update settings and then assert that Cb get triggered - Thread.sleep(5*1000); // seconds + Thread.sleep(5 * 1000); // seconds assertTrue(isCbTripped()); } public boolean isCbTripped() throws Exception { - Response response = getKnnStats(Collections.emptyList(), - Collections.singletonList("circuit_breaker_triggered")); + Response response = getKnnStats(Collections.emptyList(), Collections.singletonList("circuit_breaker_triggered")); String responseBody = EntityUtils.toString(response.getEntity()); Map clusterStats = parseClusterStatsResponse(responseBody); return Boolean.parseBoolean(clusterStats.get("circuit_breaker_triggered").toString()); @@ -89,12 +88,12 @@ public void testCbUntrips() throws Exception { assertTrue(isCbTripped()); int backOffInterval = 5; // seconds - for (int i = 0; i < CB_TIME_INTERVAL; i+=backOffInterval) { + for (int i = 0; i < CB_TIME_INTERVAL; i += backOffInterval) { if (!isCbTripped()) { break; } - Thread.sleep(backOffInterval*1000); + Thread.sleep(backOffInterval * 1000); } assertFalse(isCbTripped()); } -} \ No newline at end of file +} diff --git a/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java b/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java index e1feb9f18..fa9583e3c 100644 --- a/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java @@ -45,14 +45,23 @@ public void testCreateIndexFromModel() throws IOException, InterruptedException // "Train" a faiss flat index - this really just creates an empty index that does brute force k-NN long vectorsPointer = JNIService.transferVectors(0, new float[0][0]); - byte [] modelBlob = JNIService.trainIndex(ImmutableMap.of( - INDEX_DESCRIPTION_PARAMETER, "Flat", - SPACE_TYPE, spaceType.getValue()), dimension, vectorsPointer, - KNNEngine.FAISS.getName()); + byte[] modelBlob = JNIService.trainIndex( + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, "Flat", SPACE_TYPE, spaceType.getValue()), + dimension, + vectorsPointer, + KNNEngine.FAISS.getName() + ); // Setup model - ModelMetadata modelMetadata = new ModelMetadata(knnEngine, spaceType, dimension, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""); + ModelMetadata modelMetadata = new ModelMetadata( + knnEngine, + spaceType, + dimension, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ); Model model = new Model(modelMetadata, modelBlob, modelId); @@ -65,33 +74,21 @@ public void testCreateIndexFromModel() throws IOException, InterruptedException String fieldName = "test-field"; modelDao.put(model, ActionListener.wrap(indexResponse -> { - CreateIndexRequestBuilder createIndexRequestBuilder = client().admin().indices().prepareCreate(indexName) - .setSettings(Settings.builder() - .put("number_of_shards", 1) - .put("number_of_replicas", 0) - .put("index.knn", true) - .build() - ).addMapping( - "_doc", ImmutableMap.of( - "properties", ImmutableMap.of( - fieldName, ImmutableMap.of( - "type", "knn_vector", - "model_id", modelId - ) - ) - ) - ); - - client().admin().indices().create(createIndexRequestBuilder.request(), - ActionListener.wrap( - createIndexResponse -> { - assertTrue(createIndexResponse.isAcknowledged()); - inProgressLatch.countDown(); - }, e -> fail("Unable to create index: " + e.getMessage()) - ) - ); - - }, e ->fail("Unable to put model: " + e.getMessage()))); + CreateIndexRequestBuilder createIndexRequestBuilder = client().admin() + .indices() + .prepareCreate(indexName) + .setSettings(Settings.builder().put("number_of_shards", 1).put("number_of_replicas", 0).put("index.knn", true).build()) + .addMapping( + "_doc", + ImmutableMap.of("properties", ImmutableMap.of(fieldName, ImmutableMap.of("type", "knn_vector", "model_id", modelId))) + ); + + client().admin().indices().create(createIndexRequestBuilder.request(), ActionListener.wrap(createIndexResponse -> { + assertTrue(createIndexResponse.isAcknowledged()); + inProgressLatch.countDown(); + }, e -> fail("Unable to create index: " + e.getMessage()))); + + }, e -> fail("Unable to put model: " + e.getMessage()))); assertTrue(inProgressLatch.await(20, TimeUnit.SECONDS)); } diff --git a/src/test/java/org/opensearch/knn/index/KNNESSettingsTestIT.java b/src/test/java/org/opensearch/knn/index/KNNESSettingsTestIT.java index 5e71ceb2f..1319448c8 100644 --- a/src/test/java/org/opensearch/knn/index/KNNESSettingsTestIT.java +++ b/src/test/java/org/opensearch/knn/index/KNNESSettingsTestIT.java @@ -27,22 +27,21 @@ public class KNNESSettingsTestIT extends KNNRestTestCase { public void testIndexWritesPluginDisabled() throws Exception { createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); - Float[] vector = {6.0f, 6.0f}; + Float[] vector = { 6.0f, 6.0f }; addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector); - float[] qvector = {1.0f, 2.0f}; + float[] qvector = { 1.0f, 2.0f }; Response response = searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, qvector, 1), 1); assertEquals("knn query failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); - //disable plugin + // disable plugin updateClusterSettings(KNNSettings.KNN_PLUGIN_ENABLED, false); // indexing should be blocked - Exception ex = expectThrows(ResponseException.class, - () -> addKnnDoc(INDEX_NAME, "2", FIELD_NAME, vector)); + Exception ex = expectThrows(ResponseException.class, () -> addKnnDoc(INDEX_NAME, "2", FIELD_NAME, vector)); assertThat(ex.getMessage(), containsString("KNN plugin is disabled")); - //enable plugin + // enable plugin updateClusterSettings(KNNSettings.KNN_PLUGIN_ENABLED, true); addKnnDoc(INDEX_NAME, "3", FIELD_NAME, vector); } @@ -50,21 +49,23 @@ public void testIndexWritesPluginDisabled() throws Exception { public void testQueriesPluginDisabled() throws Exception { createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); - Float[] vector = {6.0f, 6.0f}; + Float[] vector = { 6.0f, 6.0f }; addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector); - float[] qvector = {1.0f, 2.0f}; + float[] qvector = { 1.0f, 2.0f }; Response response = searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, qvector, 1), 1); assertEquals("knn query failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); - //update settings + // update settings updateClusterSettings(KNNSettings.KNN_PLUGIN_ENABLED, false); // indexing should be blocked - Exception ex = expectThrows(ResponseException.class, - () -> searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, qvector, 1), 1)); + Exception ex = expectThrows( + ResponseException.class, + () -> searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, qvector, 1), 1) + ); assertThat(ex.getMessage(), containsString("KNN plugin is disabled")); - //enable plugin + // enable plugin updateClusterSettings(KNNSettings.KNN_PLUGIN_ENABLED, true); searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, qvector, 1), 1); } @@ -74,10 +75,10 @@ public void testItemRemovedFromCache_expiration() throws Exception { updateClusterSettings(KNNSettings.KNN_CACHE_ITEM_EXPIRY_ENABLED, true); updateClusterSettings(KNNSettings.KNN_CACHE_ITEM_EXPIRY_TIME_MINUTES, "1m"); - Float[] vector = {6.0f, 6.0f}; + Float[] vector = { 6.0f, 6.0f }; addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector); - float[] qvector = {1.0f, 2.0f}; + float[] qvector = { 1.0f, 2.0f }; Response response = searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, qvector, 1), 1); assertEquals("knn query failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); assertEquals(1, getTotalGraphsInCache()); @@ -97,36 +98,32 @@ public void testCreateIndexWithInvalidSpaceType() throws IOException { .put("index.knn", true) .put("index.knn.space_type", invalidSpaceType) .build(); - expectThrows(ResponseException.class, - () -> createKnnIndex(INDEX_NAME, invalidSettings, createKnnIndexMapping(FIELD_NAME, 2))); + expectThrows(ResponseException.class, () -> createKnnIndex(INDEX_NAME, invalidSettings, createKnnIndexMapping(FIELD_NAME, 2))); } public void testUpdateIndexSetting() throws IOException { - Settings settings = Settings.builder() - .put("index.knn", true) - .put(KNNSettings.KNN_ALGO_PARAM_EF_SEARCH, 512) - .build(); + Settings settings = Settings.builder().put("index.knn", true).put(KNNSettings.KNN_ALGO_PARAM_EF_SEARCH, 512).build(); createKnnIndex(INDEX_NAME, settings, createKnnIndexMapping(FIELD_NAME, 2)); assertEquals("512", getIndexSettingByName(INDEX_NAME, KNNSettings.KNN_ALGO_PARAM_EF_SEARCH)); updateIndexSettings(INDEX_NAME, Settings.builder().put(KNNSettings.KNN_ALGO_PARAM_EF_SEARCH, 400)); assertEquals("400", getIndexSettingByName(INDEX_NAME, KNNSettings.KNN_ALGO_PARAM_EF_SEARCH)); - Exception ex = expectThrows(ResponseException.class, - () -> updateIndexSettings(INDEX_NAME, - Settings.builder().put(KNNSettings.KNN_ALGO_PARAM_EF_SEARCH, 1))); - assertThat(ex.getMessage(), - containsString("Failed to parse value [1] for setting [index.knn.algo_param.ef_search] must be >= 2")); + Exception ex = expectThrows( + ResponseException.class, + () -> updateIndexSettings(INDEX_NAME, Settings.builder().put(KNNSettings.KNN_ALGO_PARAM_EF_SEARCH, 1)) + ); + assertThat(ex.getMessage(), containsString("Failed to parse value [1] for setting [index.knn.algo_param.ef_search] must be >= 2")); } @SuppressWarnings("unchecked") public void testCacheRebuiltAfterUpdateIndexSettings() throws IOException { createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); - Float[] vector = {6.0f, 6.0f}; + Float[] vector = { 6.0f, 6.0f }; addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector); - float[] qvector = {6.0f, 6.0f}; + float[] qvector = { 6.0f, 6.0f }; // First search to load graph into cache searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, qvector, 1), 1); @@ -149,4 +146,3 @@ public void testCacheRebuiltAfterUpdateIndexSettings() throws IOException { assertEquals(0, indicesInCache.size()); } } - diff --git a/src/test/java/org/opensearch/knn/index/KNNMapperSearcherIT.java b/src/test/java/org/opensearch/knn/index/KNNMapperSearcherIT.java index e3c21d3d6..b98b081ab 100644 --- a/src/test/java/org/opensearch/knn/index/KNNMapperSearcherIT.java +++ b/src/test/java/org/opensearch/knn/index/KNNMapperSearcherIT.java @@ -23,16 +23,16 @@ public class KNNMapperSearcherIT extends KNNRestTestCase { * Test Data set */ private void addTestData() throws Exception { - Float[] f1 = {6.0f, 6.0f}; + Float[] f1 = { 6.0f, 6.0f }; addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1); - Float[] f2 = {2.0f, 2.0f}; + Float[] f2 = { 2.0f, 2.0f }; addKnnDoc(INDEX_NAME, "2", FIELD_NAME, f2); - Float[] f3 = {4.0f, 4.0f}; + Float[] f3 = { 4.0f, 4.0f }; addKnnDoc(INDEX_NAME, "3", FIELD_NAME, f3); - Float[] f4 = {3.0f, 3.0f}; + Float[] f4 = { 3.0f, 3.0f }; addKnnDoc(INDEX_NAME, "4", FIELD_NAME, f4); } @@ -44,8 +44,8 @@ public void testKNNResultsWithForceMerge() throws Exception { /** * Query params */ - float[] queryVector = {1.0f, 1.0f}; // vector to be queried - int k = 1; // nearest 1 neighbor + float[] queryVector = { 1.0f, 1.0f }; // vector to be queried + int k = 1; // nearest 1 neighbor KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, k); @@ -53,22 +53,22 @@ public void testKNNResultsWithForceMerge() throws Exception { List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); assertEquals(k, results.size()); - for(KNNResult result : results) { + for (KNNResult result : results) { assertEquals("2", result.getDocId()); } } public void testKNNResultsUpdateDocAndForceMerge() throws Exception { createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); - addDocWithNumericField(INDEX_NAME, "1", "abc", 100 ); + addDocWithNumericField(INDEX_NAME, "1", "abc", 100); addTestData(); forceMergeKnnIndex(INDEX_NAME); /** * Query params */ - float[] queryVector = {1.0f, 1.0f}; // vector to be queried - int k = 1; // nearest 1 neighbor + float[] queryVector = { 1.0f, 1.0f }; // vector to be queried + int k = 1; // nearest 1 neighbor KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, k); @@ -76,7 +76,7 @@ public void testKNNResultsUpdateDocAndForceMerge() throws Exception { List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); assertEquals(k, results.size()); - for(KNNResult result : results) { + for (KNNResult result : results) { assertEquals("2", result.getDocId()); } } @@ -88,16 +88,16 @@ public void testKNNResultsWithoutForceMerge() throws Exception { /** * Query params */ - float[] queryVector = {2.0f, 2.0f}; // vector to be queried - int k = 3; //nearest 3 neighbors + float[] queryVector = { 2.0f, 2.0f }; // vector to be queried + int k = 3; // nearest 3 neighbors KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, k); - Response response = searchKNNIndex(INDEX_NAME, knnQueryBuilder,k); + Response response = searchKNNIndex(INDEX_NAME, knnQueryBuilder, k); List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); List expectedDocids = Arrays.asList("2", "4", "3"); List actualDocids = new ArrayList<>(); - for(KNNResult result : results) { + for (KNNResult result : results) { actualDocids.add(result.getDocId()); } @@ -109,42 +109,41 @@ public void testKNNResultsWithNewDoc() throws Exception { createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); addTestData(); - float[] queryVector = {1.0f, 1.0f}; // vector to be queried - int k = 1; // nearest 1 neighbor + float[] queryVector = { 1.0f, 1.0f }; // vector to be queried + int k = 1; // nearest 1 neighbor KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, k); - Response response = searchKNNIndex(INDEX_NAME, knnQueryBuilder,k); + Response response = searchKNNIndex(INDEX_NAME, knnQueryBuilder, k); List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); assertEquals(results.size(), k); - for(KNNResult result : results) { - assertEquals("2", result.getDocId()); //Vector of DocId 2 is closest to the query + for (KNNResult result : results) { + assertEquals("2", result.getDocId()); // Vector of DocId 2 is closest to the query } /** * Add new doc with vector not nearest than doc 2 */ - Float[] newVector = {6.0f, 6.0f}; + Float[] newVector = { 6.0f, 6.0f }; addKnnDoc(INDEX_NAME, "6", FIELD_NAME, newVector); - response = searchKNNIndex(INDEX_NAME, knnQueryBuilder,k); + response = searchKNNIndex(INDEX_NAME, knnQueryBuilder, k); results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); assertEquals(results.size(), k); - for(KNNResult result : results) { + for (KNNResult result : results) { assertEquals("2", result.getDocId()); } - /** * Add new doc with vector nearest than doc 2 to queryVector */ - Float[] newVector1 = {0.5f, 0.5f}; + Float[] newVector1 = { 0.5f, 0.5f }; addKnnDoc(INDEX_NAME, "7", FIELD_NAME, newVector1); - response = searchKNNIndex(INDEX_NAME, knnQueryBuilder,k); + response = searchKNNIndex(INDEX_NAME, knnQueryBuilder, k); results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); assertEquals(results.size(), k); - for(KNNResult result : results) { + for (KNNResult result : results) { assertEquals("7", result.getDocId()); } } @@ -153,28 +152,28 @@ public void testKNNResultsWithUpdateDoc() throws Exception { createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); addTestData(); - float[] queryVector = {1.0f, 1.0f}; // vector to be queried - int k = 1; // nearest 1 neighbor + float[] queryVector = { 1.0f, 1.0f }; // vector to be queried + int k = 1; // nearest 1 neighbor KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, k); - Response response = searchKNNIndex(INDEX_NAME, knnQueryBuilder,k); + Response response = searchKNNIndex(INDEX_NAME, knnQueryBuilder, k); List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); assertEquals(results.size(), k); - for(KNNResult result : results) { - assertEquals("2", result.getDocId()); //Vector of DocId 2 is closest to the query + for (KNNResult result : results) { + assertEquals("2", result.getDocId()); // Vector of DocId 2 is closest to the query } /** * update doc 3 to the nearest */ - Float[] updatedVector = {0.1f, 0.1f}; + Float[] updatedVector = { 0.1f, 0.1f }; updateKnnDoc(INDEX_NAME, "3", FIELD_NAME, updatedVector); - response = searchKNNIndex(INDEX_NAME, knnQueryBuilder,k); + response = searchKNNIndex(INDEX_NAME, knnQueryBuilder, k); results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); assertEquals(results.size(), k); - for(KNNResult result : results) { - assertEquals("3", result.getDocId()); //Vector of DocId 3 is closest to the query + for (KNNResult result : results) { + assertEquals("3", result.getDocId()); // Vector of DocId 3 is closest to the query } } @@ -182,30 +181,29 @@ public void testKNNResultsWithDeleteDoc() throws Exception { createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); addTestData(); - float[] queryVector = {1.0f, 1.0f}; // vector to be queried - int k = 1; // nearest 1 neighbor + float[] queryVector = { 1.0f, 1.0f }; // vector to be queried + int k = 1; // nearest 1 neighbor KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, k); Response response = searchKNNIndex(INDEX_NAME, knnQueryBuilder, k); List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); assertEquals(results.size(), k); - for(KNNResult result : results) { - assertEquals("2", result.getDocId()); //Vector of DocId 2 is closest to the query + for (KNNResult result : results) { + assertEquals("2", result.getDocId()); // Vector of DocId 2 is closest to the query } - /** * delete the nearest doc (doc2) */ deleteKnnDoc(INDEX_NAME, "2"); - knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, k+1); - response = searchKNNIndex(INDEX_NAME, knnQueryBuilder,k); + knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, k + 1); + response = searchKNNIndex(INDEX_NAME, knnQueryBuilder, k); results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); assertEquals(results.size(), k); - for(KNNResult result : results) { - assertEquals("4", result.getDocId()); //Vector of DocId 4 is closest to the query + for (KNNResult result : results) { + assertEquals("4", result.getDocId()); // Vector of DocId 4 is closest to the query } } @@ -213,7 +211,7 @@ public void testKNNResultsWithDeleteDoc() throws Exception { * For negative K, query builder should throw Exception */ public void testNegativeK() { - float[] vector = {1.0f, 2.0f}; + float[] vector = { 1.0f, 2.0f }; expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, vector, -1)); } @@ -221,7 +219,7 @@ public void testNegativeK() { * For zero K, query builder should throw Exception */ public void testZeroK() { - float[] vector = {1.0f, 2.0f}; + float[] vector = { 1.0f, 2.0f }; expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, vector, 0)); } @@ -232,8 +230,8 @@ public void testLargeK() throws Exception { createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); addTestData(); - float[] queryVector = {1.0f, 1.0f}; // vector to be queried - int k = KNNQueryBuilder.K_MAX; // nearest 1 neighbor + float[] queryVector = { 1.0f, 1.0f }; // vector to be queried + int k = KNNQueryBuilder.K_MAX; // nearest 1 neighbor KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, k); Response response = searchKNNIndex(INDEX_NAME, knnQueryBuilder, k); diff --git a/src/test/java/org/opensearch/knn/index/KNNMethodTests.java b/src/test/java/org/opensearch/knn/index/KNNMethodTests.java index ce5ff44aa..18aa90b46 100644 --- a/src/test/java/org/opensearch/knn/index/KNNMethodTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNMethodTests.java @@ -13,7 +13,6 @@ import com.google.common.collect.ImmutableMap; import org.opensearch.knn.KNNTestCase; -import org.opensearch.common.ValidationException; import org.opensearch.common.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.knn.common.KNNConstants; @@ -27,15 +26,13 @@ import static org.opensearch.knn.common.KNNConstants.PARAMETERS; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; - public class KNNMethodTests extends KNNTestCase { /** * Test KNNMethod method component getter */ public void testGetMethodComponent() { String name = "test"; - KNNMethod knnMethod = KNNMethod.Builder.builder(MethodComponent.Builder.builder(name).build()) - .build(); + KNNMethod knnMethod = KNNMethod.Builder.builder(MethodComponent.Builder.builder(name).build()).build(); assertEquals(name, knnMethod.getMethodComponent().getName()); } @@ -45,8 +42,8 @@ public void testGetMethodComponent() { public void testHasSpace() { String name = "test"; KNNMethod knnMethod = KNNMethod.Builder.builder(MethodComponent.Builder.builder(name).build()) - .addSpaces(SpaceType.L2, SpaceType.COSINESIMIL) - .build(); + .addSpaces(SpaceType.L2, SpaceType.COSINESIMIL) + .build(); assertTrue(knnMethod.containsSpace(SpaceType.L2)); assertTrue(knnMethod.containsSpace(SpaceType.COSINESIMIL)); assertFalse(knnMethod.containsSpace(SpaceType.INNER_PRODUCT)); @@ -58,36 +55,39 @@ public void testHasSpace() { public void testValidate() throws IOException { String methodName = "test-method"; KNNMethod knnMethod = KNNMethod.Builder.builder(MethodComponent.Builder.builder(methodName).build()) - .addSpaces(SpaceType.L2) - .build(); + .addSpaces(SpaceType.L2) + .build(); // Invalid space - XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject() - .field(NAME, methodName) - .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.INNER_PRODUCT.getValue()) - .endObject(); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, methodName) + .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.INNER_PRODUCT.getValue()) + .endObject(); Map in = xContentBuilderToMap(xContentBuilder); KNNMethodContext knnMethodContext1 = KNNMethodContext.parse(in); assertNotNull(knnMethod.validate(knnMethodContext1)); // Invalid methodComponent - xContentBuilder = XContentFactory.jsonBuilder().startObject() - .field(NAME, methodName) - .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue()) - .startObject(PARAMETERS) - .field("invalid", "invalid") - .endObject() - .endObject(); + xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, methodName) + .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue()) + .startObject(PARAMETERS) + .field("invalid", "invalid") + .endObject() + .endObject(); in = xContentBuilderToMap(xContentBuilder); KNNMethodContext knnMethodContext2 = KNNMethodContext.parse(in); assertNotNull(knnMethod.validate(knnMethodContext2)); // Valid everything - xContentBuilder = XContentFactory.jsonBuilder().startObject() - .field(NAME, methodName) - .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue()) - .endObject(); + xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, methodName) + .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue()) + .endObject(); in = xContentBuilderToMap(xContentBuilder); KNNMethodContext knnMethodContext3 = KNNMethodContext.parse(in); assertNull(knnMethod.validate(knnMethodContext3)); @@ -98,8 +98,8 @@ public void testGetAsMap() { String methodName = "test-method"; Map generatedMap = ImmutableMap.of("test-key", "test-value"); MethodComponent methodComponent = MethodComponent.Builder.builder(methodName) - .setMapGenerator(((methodComponent1, methodComponentContext) -> generatedMap)) - .build(); + .setMapGenerator(((methodComponent1, methodComponentContext) -> generatedMap)) + .build(); KNNMethod knnMethod = KNNMethod.Builder.builder(methodComponent).build(); diff --git a/src/test/java/org/opensearch/knn/index/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/KNNQueryBuilderTests.java index 96e23d28d..387f7e05d 100644 --- a/src/test/java/org/opensearch/knn/index/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNQueryBuilderTests.java @@ -24,25 +24,22 @@ public class KNNQueryBuilderTests extends KNNTestCase { public void testInvalidK() { - float[] queryVector = {1.0f, 1.0f}; + float[] queryVector = { 1.0f, 1.0f }; /** * -ve k */ - expectThrows(IllegalArgumentException.class, - () -> new KNNQueryBuilder("myvector", queryVector, -1)); + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder("myvector", queryVector, -1)); /** * zero k */ - expectThrows(IllegalArgumentException.class, - () -> new KNNQueryBuilder("myvector", queryVector, 0)); + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder("myvector", queryVector, 0)); /** * k > KNNQueryBuilder.K_MAX */ - expectThrows(IllegalArgumentException.class, - () -> new KNNQueryBuilder("myvector", queryVector, KNNQueryBuilder.K_MAX + 1)); + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder("myvector", queryVector, KNNQueryBuilder.K_MAX + 1)); } public void testEmptyVector() { @@ -50,19 +47,17 @@ public void testEmptyVector() { * null query vector */ float[] queryVector = null; - expectThrows(IllegalArgumentException.class, - () -> new KNNQueryBuilder("myvector", queryVector, 1)); + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder("myvector", queryVector, 1)); /** * empty query vector */ float[] queryVector1 = {}; - expectThrows(IllegalArgumentException.class, - () -> new KNNQueryBuilder("myvector", queryVector1, 1)); + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder("myvector", queryVector1, 1)); } public void testFromXcontent() throws Exception { - float[] queryVector = {1.0f, 2.0f, 3.0f, 4.0f}; + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1); XContentBuilder builder = XContentFactory.jsonBuilder(); builder.startObject(); @@ -78,7 +73,7 @@ public void testFromXcontent() throws Exception { } public void testDoToQuery_Normal() throws Exception { - float[] queryVector = {1.0f, 2.0f, 3.0f, 4.0f}; + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); @@ -86,14 +81,14 @@ public void testDoToQuery_Normal() throws Exception { when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); - KNNQuery query = (KNNQuery)knnQueryBuilder.doToQuery(mockQueryShardContext); + KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); assertEquals(knnQueryBuilder.getK(), query.getK()); assertEquals(knnQueryBuilder.fieldName(), query.getField()); assertEquals(knnQueryBuilder.vector(), query.getQueryVector()); } public void testDoToQuery_FromModel() throws Exception { - float[] queryVector = {1.0f, 2.0f, 3.0f, 4.0f}; + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); @@ -113,14 +108,14 @@ public void testDoToQuery_FromModel() throws Exception { KNNQueryBuilder.initialize(modelDao); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); - KNNQuery query = (KNNQuery)knnQueryBuilder.doToQuery(mockQueryShardContext); + KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); assertEquals(knnQueryBuilder.getK(), query.getK()); assertEquals(knnQueryBuilder.fieldName(), query.getField()); assertEquals(knnQueryBuilder.vector(), query.getQueryVector()); } public void testDoToQuery_InvalidDimensions() { - float[] queryVector = {1.0f, 2.0f, 3.0f, 4.0f}; + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); @@ -134,7 +129,7 @@ public void testDoToQuery_InvalidDimensions() { } public void testDoToQuery_InvalidFieldType() throws IOException { - float[] queryVector = {1.0f, 2.0f, 3.0f, 4.0f}; + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("mynumber", queryVector, 1); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); diff --git a/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java b/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java index 87d00e554..cdd031f4d 100644 --- a/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java @@ -43,9 +43,11 @@ private void createKNNVectorDocument(Directory directory) throws IOException { IndexWriter writer = new IndexWriter(directory, conf); Document knnDocument = new Document(); knnDocument.add( - new BinaryDocValuesField( - MOCK_INDEX_FIELD_NAME, - new VectorField(MOCK_INDEX_FIELD_NAME, new float[]{1.0f, 2.0f}, new FieldType()).binaryValue())); + new BinaryDocValuesField( + MOCK_INDEX_FIELD_NAME, + new VectorField(MOCK_INDEX_FIELD_NAME, new float[] { 1.0f, 2.0f }, new FieldType()).binaryValue() + ) + ); knnDocument.add(new NumericDocValuesField(MOCK_NUMERIC_INDEX_FIELD_NAME, 1000)); writer.addDocument(knnDocument); writer.commit(); @@ -67,16 +69,14 @@ public void testGetScriptValues() { } public void testGetScriptValuesWrongFieldName() { - KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData( - leafReaderContext.reader(), "invalid"); + KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData(leafReaderContext.reader(), "invalid"); ScriptDocValues scriptValues = leafFieldData.getScriptValues(); assertNotNull(scriptValues); } public void testGetScriptValuesWrongFieldType() { - KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData( - leafReaderContext.reader(), MOCK_NUMERIC_INDEX_FIELD_NAME); - expectThrows(IllegalStateException.class, ()->leafFieldData.getScriptValues()); + KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData(leafReaderContext.reader(), MOCK_NUMERIC_INDEX_FIELD_NAME); + expectThrows(IllegalStateException.class, () -> leafFieldData.getScriptValues()); } public void testRamBytesUsed() { @@ -86,7 +86,6 @@ public void testRamBytesUsed() { public void testGetBytesValues() { KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData(leafReaderContext.reader(), ""); - expectThrows(UnsupportedOperationException.class, - () -> leafFieldData.getBytesValues()); + expectThrows(UnsupportedOperationException.class, () -> leafFieldData.getBytesValues()); } } diff --git a/src/test/java/org/opensearch/knn/index/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/KNNVectorFieldMapperTests.java index 424fff49d..b6a65cce9 100644 --- a/src/test/java/org/opensearch/knn/index/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNVectorFieldMapperTests.java @@ -58,15 +58,22 @@ public void testBuilder_build_fromKnnMethodContext() { // Setup settings Settings settings = Settings.builder() - .put(settings(CURRENT).build()) - .put(KNNSettings.KNN_SPACE_TYPE, spaceType.getValue()) - .put(KNNSettings.KNN_ALGO_PARAM_M, m) - .put(KNNSettings.KNN_ALGO_PARAM_EF_CONSTRUCTION, efConstruction) - .build(); - - builder.knnMethodContext.setValue(new KNNMethodContext(KNNEngine.DEFAULT, spaceType, - new MethodComponentContext(METHOD_HNSW, ImmutableMap.of(METHOD_PARAMETER_M, m, - METHOD_PARAMETER_EF_CONSTRUCTION, efConstruction)))); + .put(settings(CURRENT).build()) + .put(KNNSettings.KNN_SPACE_TYPE, spaceType.getValue()) + .put(KNNSettings.KNN_ALGO_PARAM_M, m) + .put(KNNSettings.KNN_ALGO_PARAM_EF_CONSTRUCTION, efConstruction) + .build(); + + builder.knnMethodContext.setValue( + new KNNMethodContext( + KNNEngine.DEFAULT, + spaceType, + new MethodComponentContext( + METHOD_HNSW, + ImmutableMap.of(METHOD_PARAMETER_M, m, METHOD_PARAMETER_EF_CONSTRUCTION, efConstruction) + ) + ) + ); builder.modelId.setValue("Random modelId"); @@ -88,15 +95,22 @@ public void testBuilder_build_fromModel() { // Setup settings Settings settings = Settings.builder() - .put(settings(CURRENT).build()) - .put(KNNSettings.KNN_SPACE_TYPE, spaceType.getValue()) - .put(KNNSettings.KNN_ALGO_PARAM_M, m) - .put(KNNSettings.KNN_ALGO_PARAM_EF_CONSTRUCTION, efConstruction) - .build(); + .put(settings(CURRENT).build()) + .put(KNNSettings.KNN_SPACE_TYPE, spaceType.getValue()) + .put(KNNSettings.KNN_ALGO_PARAM_M, m) + .put(KNNSettings.KNN_ALGO_PARAM_EF_CONSTRUCTION, efConstruction) + .build(); String modelId = "Random modelId"; - ModelMetadata mockedModelMetadata = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 129, - ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""); + ModelMetadata mockedModelMetadata = new ModelMetadata( + KNNEngine.FAISS, + SpaceType.L2, + 129, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ); builder.modelId.setValue(modelId); Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); @@ -118,11 +132,11 @@ public void testBuilder_build_fromLegacy() { // Setup settings Settings settings = Settings.builder() - .put(settings(CURRENT).build()) - .put(KNNSettings.KNN_SPACE_TYPE, spaceType.getValue()) - .put(KNNSettings.KNN_ALGO_PARAM_M, m) - .put(KNNSettings.KNN_ALGO_PARAM_EF_CONSTRUCTION, efConstruction) - .build(); + .put(settings(CURRENT).build()) + .put(KNNSettings.KNN_SPACE_TYPE, spaceType.getValue()) + .put(KNNSettings.KNN_ALGO_PARAM_M, m) + .put(KNNSettings.KNN_ALGO_PARAM_EF_CONSTRUCTION, efConstruction) + .build(); Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext); @@ -137,82 +151,96 @@ public void testTypeParser_parse_fromKnnMethodContext() throws IOException { String fieldName = "test-field-name"; String indexName = "test-index-name"; - Settings settings = Settings.builder() - .put(settings(CURRENT).build()) - .build(); + Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); ModelDao modelDao = mock(ModelDao.class); KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao); int efConstruction = 321; int dimension = 133; - XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject() - .field("type", "knn_vector") - .field("dimension", dimension) - .startObject(KNN_METHOD) - .field(NAME, METHOD_HNSW) - .startObject(PARAMETERS) - .field(METHOD_PARAMETER_EF_CONSTRUCTION, efConstruction) - .endObject() - .endObject() - .endObject(); - - KNNVectorFieldMapper.Builder builder = (KNNVectorFieldMapper.Builder) typeParser.parse(fieldName, - xContentBuilderToMap(xContentBuilder), buildParserContext(indexName, settings)); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, efConstruction) + .endObject() + .endObject() + .endObject(); + + KNNVectorFieldMapper.Builder builder = (KNNVectorFieldMapper.Builder) typeParser.parse( + fieldName, + xContentBuilderToMap(xContentBuilder), + buildParserContext(indexName, settings) + ); assertEquals(METHOD_HNSW, builder.knnMethodContext.get().getMethodComponent().getName()); - assertEquals(efConstruction, builder.knnMethodContext.get().getMethodComponent().getParameters() - .get(METHOD_PARAMETER_EF_CONSTRUCTION)); + assertEquals( + efConstruction, + builder.knnMethodContext.get().getMethodComponent().getParameters().get(METHOD_PARAMETER_EF_CONSTRUCTION) + ); // Test invalid parameter - XContentBuilder xContentBuilder2 = XContentFactory.jsonBuilder().startObject() - .field("type", "knn_vector") - .field("dimension", dimension) - .startObject(KNN_METHOD) - .field(NAME, METHOD_HNSW) - .startObject(PARAMETERS) - .field("invalid", "invalid") - .endObject() - .endObject() - .endObject(); - - expectThrows(ValidationException.class, () -> typeParser.parse(fieldName, - xContentBuilderToMap(xContentBuilder2), buildParserContext(indexName, settings))); + XContentBuilder xContentBuilder2 = XContentFactory.jsonBuilder() + .startObject() + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .startObject(PARAMETERS) + .field("invalid", "invalid") + .endObject() + .endObject() + .endObject(); + + expectThrows( + ValidationException.class, + () -> typeParser.parse(fieldName, xContentBuilderToMap(xContentBuilder2), buildParserContext(indexName, settings)) + ); // Test invalid method - XContentBuilder xContentBuilder3 = XContentFactory.jsonBuilder().startObject() - .field("type", "knn_vector") - .field("dimension", dimension) - .startObject(KNN_METHOD) - .field(NAME, "invalid") - .endObject() - .endObject(); - - expectThrows(IllegalArgumentException.class, () -> typeParser.parse(fieldName, - xContentBuilderToMap(xContentBuilder3), buildParserContext(indexName, settings))); + XContentBuilder xContentBuilder3 = XContentFactory.jsonBuilder() + .startObject() + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNN_METHOD) + .field(NAME, "invalid") + .endObject() + .endObject(); + + expectThrows( + IllegalArgumentException.class, + () -> typeParser.parse(fieldName, xContentBuilderToMap(xContentBuilder3), buildParserContext(indexName, settings)) + ); // Test missing required parameter: dimension - XContentBuilder xContentBuilder4 = XContentFactory.jsonBuilder().startObject() - .field("type", "knn_vector").endObject(); + XContentBuilder xContentBuilder4 = XContentFactory.jsonBuilder().startObject().field("type", "knn_vector").endObject(); - expectThrows(IllegalArgumentException.class, () -> typeParser.parse(fieldName, - xContentBuilderToMap(xContentBuilder4), buildParserContext(indexName, settings))); + expectThrows( + IllegalArgumentException.class, + () -> typeParser.parse(fieldName, xContentBuilderToMap(xContentBuilder4), buildParserContext(indexName, settings)) + ); // Check that this fails if model id is also set - XContentBuilder xContentBuilder5 = XContentFactory.jsonBuilder().startObject() - .field("type", "knn_vector") - .field("dimension", dimension) - .field(MODEL_ID, "test-id") - .startObject(KNN_METHOD) - .field(NAME, METHOD_HNSW) - .startObject(PARAMETERS) - .field("invalid", "invalid") - .endObject() - .endObject() - .endObject(); - - expectThrows(IllegalArgumentException.class, () -> typeParser.parse(fieldName, - xContentBuilderToMap(xContentBuilder5), buildParserContext(indexName, settings))); + XContentBuilder xContentBuilder5 = XContentFactory.jsonBuilder() + .startObject() + .field("type", "knn_vector") + .field("dimension", dimension) + .field(MODEL_ID, "test-id") + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .startObject(PARAMETERS) + .field("invalid", "invalid") + .endObject() + .endObject() + .endObject(); + + expectThrows( + IllegalArgumentException.class, + () -> typeParser.parse(fieldName, xContentBuilderToMap(xContentBuilder5), buildParserContext(indexName, settings)) + ); } public void testTypeParser_parse_fromModel() throws IOException { @@ -220,21 +248,23 @@ public void testTypeParser_parse_fromModel() throws IOException { String fieldName = "test-field-name"; String indexName = "test-index-name"; - Settings settings = Settings.builder() - .put(settings(CURRENT).build()) - .build(); + Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); ModelDao modelDao = mock(ModelDao.class); KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao); String modelId = "test-id"; - XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject() - .field("type", "knn_vector") - .field(MODEL_ID, modelId) - .endObject(); - - KNNVectorFieldMapper.Builder builder = (KNNVectorFieldMapper.Builder) typeParser.parse(fieldName, - xContentBuilderToMap(xContentBuilder), buildParserContext(indexName, settings)); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field("type", "knn_vector") + .field(MODEL_ID, modelId) + .endObject(); + + KNNVectorFieldMapper.Builder builder = (KNNVectorFieldMapper.Builder) typeParser.parse( + fieldName, + xContentBuilderToMap(xContentBuilder), + buildParserContext(indexName, settings) + ); assertEquals(modelId, builder.modelId.get()); } @@ -248,23 +278,27 @@ public void testTypeParser_parse_fromLegacy() throws IOException { int efConstruction = 123; SpaceType spaceType = SpaceType.L2; Settings settings = Settings.builder() - .put(settings(CURRENT).build()) - .put(KNNSettings.KNN_SPACE_TYPE, spaceType.getValue()) - .put(KNNSettings.KNN_ALGO_PARAM_M, m) - .put(KNNSettings.KNN_ALGO_PARAM_EF_CONSTRUCTION, efConstruction) - .build(); + .put(settings(CURRENT).build()) + .put(KNNSettings.KNN_SPACE_TYPE, spaceType.getValue()) + .put(KNNSettings.KNN_ALGO_PARAM_M, m) + .put(KNNSettings.KNN_ALGO_PARAM_EF_CONSTRUCTION, efConstruction) + .build(); ModelDao modelDao = mock(ModelDao.class); KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao); int dimension = 122; - XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject() - .field("type", "knn_vector") - .field("dimension", dimension) - .endObject(); - - KNNVectorFieldMapper.Builder builder = (KNNVectorFieldMapper.Builder) typeParser.parse(fieldName, - xContentBuilderToMap(xContentBuilder), buildParserContext(indexName, settings)); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field("type", "knn_vector") + .field("dimension", dimension) + .endObject(); + + KNNVectorFieldMapper.Builder builder = (KNNVectorFieldMapper.Builder) typeParser.parse( + fieldName, + xContentBuilderToMap(xContentBuilder), + buildParserContext(indexName, settings) + ); assertNull(builder.modelId.get()); assertNull(builder.knnMethodContext.get()); @@ -274,33 +308,34 @@ public void testKNNVectorFieldMapper_merge_fromKnnMethodContext() throws IOExcep String fieldName = "test-field-name"; String indexName = "test-index-name"; - Settings settings = Settings.builder() - .put(settings(CURRENT).build()) - .build(); + Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); ModelDao modelDao = mock(ModelDao.class); KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao); int dimension = 133; int efConstruction = 321; - XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject() - .field("type", "knn_vector") - .field("dimension", dimension) - .startObject(KNN_METHOD) - .field(NAME, METHOD_HNSW) - .startObject(PARAMETERS) - .field(METHOD_PARAMETER_EF_CONSTRUCTION, efConstruction) - .endObject() - .endObject() - .endObject(); - - KNNVectorFieldMapper.Builder builder = (KNNVectorFieldMapper.Builder) typeParser.parse(fieldName, - xContentBuilderToMap(xContentBuilder), buildParserContext(indexName, settings)); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, efConstruction) + .endObject() + .endObject() + .endObject(); + + KNNVectorFieldMapper.Builder builder = (KNNVectorFieldMapper.Builder) typeParser.parse( + fieldName, + xContentBuilderToMap(xContentBuilder), + buildParserContext(indexName, settings) + ); Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); KNNVectorFieldMapper knnVectorFieldMapper1 = builder.build(builderContext); - // merge with itself - should be successful KNNVectorFieldMapper knnVectorFieldMapperMerge1 = (KNNVectorFieldMapper) knnVectorFieldMapper1.merge(knnVectorFieldMapper1); assertEquals(knnVectorFieldMapper1.knnMethod, knnVectorFieldMapperMerge1.knnMethod); @@ -311,16 +346,20 @@ public void testKNNVectorFieldMapper_merge_fromKnnMethodContext() throws IOExcep assertEquals(knnVectorFieldMapper1.knnMethod, knnVectorFieldMapperMerge2.knnMethod); // merge with another mapper of the same field with different context - xContentBuilder = XContentFactory.jsonBuilder().startObject() - .field("type", "knn_vector") - .field("dimension", dimension) - .startObject(KNN_METHOD) - .field(NAME, METHOD_HNSW) - .endObject() - .endObject(); - - builder = (KNNVectorFieldMapper.Builder) typeParser.parse(fieldName, xContentBuilderToMap(xContentBuilder), - buildParserContext(indexName, settings)); + xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .endObject() + .endObject(); + + builder = (KNNVectorFieldMapper.Builder) typeParser.parse( + fieldName, + xContentBuilderToMap(xContentBuilder), + buildParserContext(indexName, settings) + ); KNNVectorFieldMapper knnVectorFieldMapper3 = builder.build(builderContext); expectThrows(IllegalArgumentException.class, () -> knnVectorFieldMapper1.merge(knnVectorFieldMapper3)); } @@ -329,28 +368,36 @@ public void testKNNVectorFieldMapper_merge_fromModel() throws IOException { String fieldName = "test-field-name"; String indexName = "test-index-name"; - Settings settings = Settings.builder() - .put(settings(CURRENT).build()) - .build(); + Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); String modelId = "test-id"; int dimension = 133; ModelDao mockModelDao = mock(ModelDao.class); - ModelMetadata mockModelMetadata = new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, dimension, - ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""); + ModelMetadata mockModelMetadata = new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + dimension, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ); when(mockModelDao.getMetadata(modelId)).thenReturn(mockModelMetadata); KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> mockModelDao); - XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject() - .field("type", "knn_vector") - .field(MODEL_ID, modelId) - .endObject(); - - KNNVectorFieldMapper.Builder builder = (KNNVectorFieldMapper.Builder) typeParser.parse(fieldName, - xContentBuilderToMap(xContentBuilder), buildParserContext(indexName, settings)); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field("type", "knn_vector") + .field(MODEL_ID, modelId) + .endObject(); + KNNVectorFieldMapper.Builder builder = (KNNVectorFieldMapper.Builder) typeParser.parse( + fieldName, + xContentBuilderToMap(xContentBuilder), + buildParserContext(indexName, settings) + ); Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); KNNVectorFieldMapper knnVectorFieldMapper1 = builder.build(builderContext); @@ -365,43 +412,58 @@ public void testKNNVectorFieldMapper_merge_fromModel() throws IOException { assertEquals(knnVectorFieldMapper1.modelId, knnVectorFieldMapperMerge2.modelId); // merge with another mapper of the same field with different context - xContentBuilder = XContentFactory.jsonBuilder().startObject() - .field("type", "knn_vector") - .field("dimension", dimension) - .startObject(KNN_METHOD) - .field(NAME, METHOD_HNSW) - .endObject() - .endObject(); - - builder = (KNNVectorFieldMapper.Builder) typeParser.parse(fieldName, xContentBuilderToMap(xContentBuilder), - buildParserContext(indexName, settings)); + xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .endObject() + .endObject(); + + builder = (KNNVectorFieldMapper.Builder) typeParser.parse( + fieldName, + xContentBuilderToMap(xContentBuilder), + buildParserContext(indexName, settings) + ); KNNVectorFieldMapper knnVectorFieldMapper3 = builder.build(builderContext); expectThrows(IllegalArgumentException.class, () -> knnVectorFieldMapper1.merge(knnVectorFieldMapper3)); } public IndexMetadata buildIndexMetaData(String indexName, Settings settings) { - return IndexMetadata.builder(indexName).settings(settings) - .numberOfShards(1) - .numberOfReplicas(0) - .version(7) - .mappingVersion(0) - .settingsVersion(0) - .aliasesVersion(0) - .creationDate(0) - .build(); + return IndexMetadata.builder(indexName) + .settings(settings) + .numberOfShards(1) + .numberOfReplicas(0) + .version(7) + .mappingVersion(0) + .settingsVersion(0) + .aliasesVersion(0) + .creationDate(0) + .build(); } public Mapper.TypeParser.ParserContext buildParserContext(String indexName, Settings settings) { - IndexSettings indexSettings = new IndexSettings(buildIndexMetaData(indexName, settings), Settings.EMPTY, - new IndexScopedSettings(Settings.EMPTY, new HashSet<>(IndexScopedSettings.BUILT_IN_INDEX_SETTINGS))); + IndexSettings indexSettings = new IndexSettings( + buildIndexMetaData(indexName, settings), + Settings.EMPTY, + new IndexScopedSettings(Settings.EMPTY, new HashSet<>(IndexScopedSettings.BUILT_IN_INDEX_SETTINGS)) + ); MapperService mapperService = mock(MapperService.class); when(mapperService.getIndexSettings()).thenReturn(indexSettings); // Setup blank ModelDao mockModelDao = mock(ModelDao.class); - return new Mapper.TypeParser.ParserContext(null, mapperService, - type -> new KNNVectorFieldMapper.TypeParser(() -> mockModelDao), CURRENT, null, null, null); + return new Mapper.TypeParser.ParserContext( + null, + mapperService, + type -> new KNNVectorFieldMapper.TypeParser(() -> mockModelDao), + CURRENT, + null, + null, + null + ); } } diff --git a/src/test/java/org/opensearch/knn/index/KNNVectorIndexFieldDataTests.java b/src/test/java/org/opensearch/knn/index/KNNVectorIndexFieldDataTests.java index 54435db99..028cd4bc0 100644 --- a/src/test/java/org/opensearch/knn/index/KNNVectorIndexFieldDataTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNVectorIndexFieldDataTests.java @@ -72,13 +72,14 @@ public void testLoadDirect() throws IOException { public void testSortField() { - expectThrows(UnsupportedOperationException.class, - () -> indexFieldData.sortField(null, null, null, false)); + expectThrows(UnsupportedOperationException.class, () -> indexFieldData.sortField(null, null, null, false)); } public void testNewBucketedSort() { - expectThrows(UnsupportedOperationException.class, - () -> indexFieldData.newBucketedSort(null, null, null, null, null, null, 0, null)); + expectThrows( + UnsupportedOperationException.class, + () -> indexFieldData.newBucketedSort(null, null, null, null, null, null, 0, null) + ); } } diff --git a/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java b/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java index 8883bf4dd..b76184f4b 100644 --- a/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java @@ -23,7 +23,7 @@ public class KNNVectorScriptDocValuesTests extends KNNTestCase { private static final String MOCK_INDEX_FIELD_NAME = "test-index-field-name"; - private static final float[] SAMPLE_VECTOR_DATA = new float[]{1.0f, 2.0f}; + private static final float[] SAMPLE_VECTOR_DATA = new float[] { 1.0f, 2.0f }; private KNNVectorScriptDocValues scriptDocValues; private Directory directory; private DirectoryReader reader; @@ -36,7 +36,9 @@ public void setUp() throws Exception { reader = DirectoryReader.open(directory); LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); scriptDocValues = new KNNVectorScriptDocValues( - leafReaderContext.reader().getBinaryDocValues(MOCK_INDEX_FIELD_NAME), MOCK_INDEX_FIELD_NAME); + leafReaderContext.reader().getBinaryDocValues(MOCK_INDEX_FIELD_NAME), + MOCK_INDEX_FIELD_NAME + ); } private void createKNNVectorDocument(Directory directory) throws IOException { @@ -44,9 +46,11 @@ private void createKNNVectorDocument(Directory directory) throws IOException { IndexWriter writer = new IndexWriter(directory, conf); Document knnDocument = new Document(); knnDocument.add( - new BinaryDocValuesField( - MOCK_INDEX_FIELD_NAME, - new VectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_VECTOR_DATA, new FieldType()).binaryValue())); + new BinaryDocValuesField( + MOCK_INDEX_FIELD_NAME, + new VectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_VECTOR_DATA, new FieldType()).binaryValue() + ) + ); writer.addDocument(knnDocument); writer.commit(); writer.close(); @@ -64,8 +68,7 @@ public void testGetValue() throws IOException { Assert.assertArrayEquals(SAMPLE_VECTOR_DATA, scriptDocValues.getValue(), 0.1f); } - - //Test getValue without calling setNextDocId + // Test getValue without calling setNextDocId public void testGetValueFails() throws IOException { expectThrows(IllegalStateException.class, () -> scriptDocValues.getValue()); } diff --git a/src/test/java/org/opensearch/knn/index/MethodComponentTests.java b/src/test/java/org/opensearch/knn/index/MethodComponentTests.java index a2b810e31..b752764c3 100644 --- a/src/test/java/org/opensearch/knn/index/MethodComponentTests.java +++ b/src/test/java/org/opensearch/knn/index/MethodComponentTests.java @@ -39,8 +39,8 @@ public void testGetParameters() { String name = "test"; String paramKey = "key"; MethodComponent methodComponent = MethodComponent.Builder.builder(name) - .addParameter(paramKey, new Parameter.IntegerParameter(paramKey, 1, v -> v > 0)) - .build(); + .addParameter(paramKey, new Parameter.IntegerParameter(paramKey, 1, v -> v > 0)) + .build(); assertEquals(1, methodComponent.getParameters().size()); assertTrue(methodComponent.getParameters().containsKey(paramKey)); } @@ -51,12 +51,13 @@ public void testGetParameters() { public void testValidate() throws IOException { // Invalid parameter key String methodName = "test-method"; - XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject() - .field(NAME, methodName) - .startObject(PARAMETERS) - .field("invalid", "invalid") - .endObject() - .endObject(); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, methodName) + .startObject(PARAMETERS) + .field("invalid", "invalid") + .endObject() + .endObject(); Map in = xContentBuilderToMap(xContentBuilder); MethodComponentContext componentContext1 = MethodComponentContext.parse(in); @@ -64,48 +65,48 @@ public void testValidate() throws IOException { assertNotNull(methodComponent1.validate(componentContext1)); // Invalid parameter type - xContentBuilder = XContentFactory.jsonBuilder().startObject() - .field(NAME, methodName) - .startObject(PARAMETERS) - .field("valid", "invalid") - .endObject() - .endObject(); + xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, methodName) + .startObject(PARAMETERS) + .field("valid", "invalid") + .endObject() + .endObject(); in = xContentBuilderToMap(xContentBuilder); MethodComponentContext componentContext2 = MethodComponentContext.parse(in); MethodComponent methodComponent2 = MethodComponent.Builder.builder(methodName) - .addParameter("valid", new Parameter.IntegerParameter("valid", 1, v -> v > 0)) - .build(); + .addParameter("valid", new Parameter.IntegerParameter("valid", 1, v -> v > 0)) + .build(); assertNotNull(methodComponent2.validate(componentContext2)); // valid configuration - xContentBuilder = XContentFactory.jsonBuilder().startObject() - .field(NAME, methodName) - .startObject(PARAMETERS) - .field("valid1", 16) - .field("valid2", 128) - .endObject() - .endObject(); + xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, methodName) + .startObject(PARAMETERS) + .field("valid1", 16) + .field("valid2", 128) + .endObject() + .endObject(); in = xContentBuilderToMap(xContentBuilder); MethodComponentContext componentContext3 = MethodComponentContext.parse(in); MethodComponent methodComponent3 = MethodComponent.Builder.builder(methodName) - .addParameter("valid1", new Parameter.IntegerParameter("valid1", 1, v -> v > 0)) - .addParameter("valid2", new Parameter.IntegerParameter("valid2",1, v -> v > 0)) - .build(); + .addParameter("valid1", new Parameter.IntegerParameter("valid1", 1, v -> v > 0)) + .addParameter("valid2", new Parameter.IntegerParameter("valid2", 1, v -> v > 0)) + .build(); assertNull(methodComponent3.validate(componentContext3)); // valid configuration - empty parameters - xContentBuilder = XContentFactory.jsonBuilder().startObject() - .field(NAME, methodName) - .endObject(); + xContentBuilder = XContentFactory.jsonBuilder().startObject().field(NAME, methodName).endObject(); in = xContentBuilderToMap(xContentBuilder); MethodComponentContext componentContext4 = MethodComponentContext.parse(in); MethodComponent methodComponent4 = MethodComponent.Builder.builder(methodName) - .addParameter("valid1", new Parameter.IntegerParameter("valid1",1, v -> v > 0)) - .addParameter("valid2", new Parameter.IntegerParameter("valid2",1, v -> v > 0)) - .build(); + .addParameter("valid1", new Parameter.IntegerParameter("valid1", 1, v -> v > 0)) + .addParameter("valid2", new Parameter.IntegerParameter("valid2", 1, v -> v > 0)) + .build(); assertNull(methodComponent4.validate(componentContext4)); } @@ -118,25 +119,24 @@ public void testGetAsMap_withoutGenerator() throws IOException { int default2 = 5; MethodComponent methodComponent = MethodComponent.Builder.builder(methodName) - .addParameter(parameterName1, new Parameter.IntegerParameter(parameterName1, default1, v -> v > 0)) - .addParameter(parameterName2, new Parameter.IntegerParameter(parameterName2, default2, v -> v > 0)) - .build(); - - XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject() - .field(NAME, methodName) - .startObject(PARAMETERS) - .field(parameterName1, 16) - .field(parameterName2, 128) - .endObject() - .endObject(); + .addParameter(parameterName1, new Parameter.IntegerParameter(parameterName1, default1, v -> v > 0)) + .addParameter(parameterName2, new Parameter.IntegerParameter(parameterName2, default2, v -> v > 0)) + .build(); + + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, methodName) + .startObject(PARAMETERS) + .field(parameterName1, 16) + .field(parameterName2, 128) + .endObject() + .endObject(); Map in = xContentBuilderToMap(xContentBuilder); MethodComponentContext methodComponentContext = MethodComponentContext.parse(in); assertEquals(in, methodComponent.getAsMap(methodComponentContext)); - xContentBuilder = XContentFactory.jsonBuilder().startObject() - .field(NAME, methodName) - .endObject(); + xContentBuilder = XContentFactory.jsonBuilder().startObject().field(NAME, methodName).endObject(); in = xContentBuilderToMap(xContentBuilder); methodComponentContext = MethodComponentContext.parse(in); @@ -149,14 +149,12 @@ public void testGetAsMap_withGenerator() throws IOException { String methodName = "test-method"; Map generatedMap = ImmutableMap.of("test-key", "test-value"); MethodComponent methodComponent = MethodComponent.Builder.builder(methodName) - .addParameter("valid1", new Parameter.IntegerParameter("valid1",1, v -> v > 0)) - .addParameter("valid2", new Parameter.IntegerParameter("valid2",1, v -> v > 0)) - .setMapGenerator((methodComponent1, methodComponentContext) -> generatedMap) - .build(); - - XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject() - .field(NAME, methodName) - .endObject(); + .addParameter("valid1", new Parameter.IntegerParameter("valid1", 1, v -> v > 0)) + .addParameter("valid2", new Parameter.IntegerParameter("valid2", 1, v -> v > 0)) + .setMapGenerator((methodComponent1, methodComponentContext) -> generatedMap) + .build(); + + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject().field(NAME, methodName).endObject(); Map in = xContentBuilderToMap(xContentBuilder); MethodComponentContext methodComponentContext = MethodComponentContext.parse(in); @@ -171,7 +169,7 @@ public void testBuilder() { assertEquals(0, methodComponent.getParameters().size()); assertEquals(name, methodComponent.getName()); - builder.addParameter("test", new Parameter.IntegerParameter("test",1, v -> v > 0)); + builder.addParameter("test", new Parameter.IntegerParameter("test", 1, v -> v > 0)); methodComponent = builder.build(); assertEquals(1, methodComponent.getParameters().size()); diff --git a/src/test/java/org/opensearch/knn/index/NmslibIT.java b/src/test/java/org/opensearch/knn/index/NmslibIT.java index f0e98894a..1e1b1e3b5 100644 --- a/src/test/java/org/opensearch/knn/index/NmslibIT.java +++ b/src/test/java/org/opensearch/knn/index/NmslibIT.java @@ -9,7 +9,6 @@ * GitHub history for details. */ - package org.opensearch.knn.index; import com.google.common.collect.ImmutableList; @@ -39,7 +38,6 @@ import static org.hamcrest.Matchers.containsString; - public class NmslibIT extends KNNRestTestCase { static TestUtils.TestData testData; @@ -67,23 +65,23 @@ public void testEndToEnd() throws IOException, InterruptedException { // Create an index XContentBuilder builder = XContentFactory.jsonBuilder() - .startObject() - .startObject("properties") - .startObject(fieldName) - .field("type", "knn_vector") - .field("dimension", dimension) - .startObject(KNNConstants.KNN_METHOD) - .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) - .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) - .field(KNNConstants.KNN_ENGINE, KNNEngine.NMSLIB.getName()) - .startObject(KNNConstants.PARAMETERS) - .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) - .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) - .endObject() - .endObject() - .endObject() - .endObject() - .endObject(); + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNNConstants.KNN_ENGINE, KNNEngine.NMSLIB.getName()) + .startObject(KNNConstants.PARAMETERS) + .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); Map mappingMap = xContentBuilderToMap(builder); String mapping = Strings.toString(builder); @@ -93,8 +91,12 @@ public void testEndToEnd() throws IOException, InterruptedException { // Index the test data for (int i = 0; i < testData.indexData.docs.length; i++) { - addKnnDoc(indexName, Integer.toString(testData.indexData.docs[i]), fieldName, - Floats.asList(testData.indexData.vectors[i]).toArray()); + addKnnDoc( + indexName, + Integer.toString(testData.indexData.docs[i]), + fieldName, + Floats.asList(testData.indexData.vectors[i]).toArray() + ); } // Assert we have the right number of documents in the index @@ -111,15 +113,17 @@ public void testEndToEnd() throws IOException, InterruptedException { List actualScores = parseSearchResponseScore(responseBody, fieldName); for (int j = 0; j < k; j++) { float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList())); - assertEquals(KNNEngine.NMSLIB.score(KNNScoringUtil.l1Norm(testData.queries[i], primitiveArray), - spaceType), actualScores.get(j), 0.0001); + assertEquals( + KNNEngine.NMSLIB.score(KNNScoringUtil.l1Norm(testData.queries[i], primitiveArray), spaceType), + actualScores.get(j), + 0.0001 + ); } } // Delete index deleteKNNIndex(indexName); - // Search every 5 seconds 14 times to confirm graph gets evicted int intervals = 14; for (int i = 0; i < intervals; i++) { @@ -127,7 +131,7 @@ public void testEndToEnd() throws IOException, InterruptedException { return; } - Thread.sleep(5*1000); + Thread.sleep(5 * 1000); } fail("Graphs are not getting evicted"); @@ -135,24 +139,23 @@ public void testEndToEnd() throws IOException, InterruptedException { public void testAddDoc() throws Exception { createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); - Float[] vector = {6.0f, 6.0f}; + Float[] vector = { 6.0f, 6.0f }; addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector); } - public void testUpdateDoc() throws Exception { createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); - Float[] vector = {6.0f, 6.0f}; + Float[] vector = { 6.0f, 6.0f }; addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector); // update - Float[] updatedVector = {8.0f, 8.0f}; + Float[] updatedVector = { 8.0f, 8.0f }; updateKnnDoc(INDEX_NAME, "1", FIELD_NAME, updatedVector); } public void testDeleteDoc() throws Exception { createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); - Float[] vector = {6.0f, 6.0f}; + Float[] vector = { 6.0f, 6.0f }; addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector); // delete knn doc @@ -162,12 +165,12 @@ public void testDeleteDoc() throws Exception { public void testCreateIndexWithValidAlgoParams_settings() { try { Settings settings = Settings.builder() - .put(getKNNDefaultIndexSettings()) - .put("index.knn.algo_param.m", 32) - .put("index.knn.algo_param.ef_construction", 400) - .build(); + .put(getKNNDefaultIndexSettings()) + .put("index.knn.algo_param.m", 32) + .put("index.knn.algo_param.ef_construction", 400) + .build(); createKnnIndex(INDEX_NAME, settings, createKnnIndexMapping(FIELD_NAME, 2)); - Float[] vector = {6.0f, 6.0f}; + Float[] vector = { 6.0f, 6.0f }; addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector); } catch (Exception ex) { fail("Exception not expected as valid index arguements passed: " + ex); @@ -177,15 +180,15 @@ public void testCreateIndexWithValidAlgoParams_settings() { @SuppressWarnings("unchecked") public void testCreateIndexWithValidAlgoParams_mapping() { try { - Settings settings = Settings.builder() - .put(getKNNDefaultIndexSettings()) - .build(); + Settings settings = Settings.builder().put(getKNNDefaultIndexSettings()).build(); String spaceType = SpaceType.L1.getValue(); int efConstruction = 14; int m = 13; - String mapping = Strings.toString(XContentFactory.jsonBuilder().startObject() + String mapping = Strings.toString( + XContentFactory.jsonBuilder() + .startObject() .startObject("properties") .startObject(FIELD_NAME) .field("type", "knn_vector") @@ -200,17 +203,18 @@ public void testCreateIndexWithValidAlgoParams_mapping() { .endObject() .endObject() .endObject() - .endObject()); + .endObject() + ); createKnnIndex(INDEX_NAME, settings, mapping); Map fullMapping = getAsMap(INDEX_NAME + "/_mapping"); Map indexMapping = (Map) fullMapping.get(INDEX_NAME); - Map mappingsMapping = (Map) indexMapping.get("mappings"); - Map propertiesMapping = (Map) mappingsMapping.get("properties"); - Map fieldMapping = (Map) propertiesMapping.get(FIELD_NAME); - Map methodMapping = (Map) fieldMapping.get(KNNConstants.KNN_METHOD); - Map parametersMapping = (Map) methodMapping.get(KNNConstants.PARAMETERS); + Map mappingsMapping = (Map) indexMapping.get("mappings"); + Map propertiesMapping = (Map) mappingsMapping.get("properties"); + Map fieldMapping = (Map) propertiesMapping.get(FIELD_NAME); + Map methodMapping = (Map) fieldMapping.get(KNNConstants.KNN_METHOD); + Map parametersMapping = (Map) methodMapping.get(KNNConstants.PARAMETERS); String spaceTypeMapping = (String) methodMapping.get(KNNConstants.METHOD_PARAMETER_SPACE_TYPE); Integer mMapping = (Integer) parametersMapping.get(KNNConstants.METHOD_PARAMETER_M); Integer efConstructionMapping = (Integer) parametersMapping.get(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION); @@ -219,7 +223,7 @@ public void testCreateIndexWithValidAlgoParams_mapping() { assertEquals(m, mMapping.intValue()); assertEquals(efConstruction, efConstructionMapping.intValue()); - Float[] vector = {6.0f, 6.0f}; + Float[] vector = { 6.0f, 6.0f }; addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector); } catch (Exception ex) { fail("Exception not expected as valid index arguments passed: " + ex); @@ -233,12 +237,14 @@ public void testCreateIndexWithValidAlgoParams_mappingAndSettings() { int m1 = 13; Settings settings = Settings.builder() - .put(getKNNDefaultIndexSettings()) - .put("index.knn.algo_param.m", m1) - .put("index.knn.algo_param.ef_construction", efConstruction1) - .build(); + .put(getKNNDefaultIndexSettings()) + .put("index.knn.algo_param.m", m1) + .put("index.knn.algo_param.ef_construction", efConstruction1) + .build(); - String mapping = Strings.toString(XContentFactory.jsonBuilder().startObject() + String mapping = Strings.toString( + XContentFactory.jsonBuilder() + .startObject() .startObject("properties") .startObject(FIELD_NAME) .field("type", "knn_vector") @@ -253,17 +259,20 @@ public void testCreateIndexWithValidAlgoParams_mappingAndSettings() { .endObject() .endObject() .endObject() - .endObject()); + .endObject() + ); createKnnIndex(INDEX_NAME + "1", settings, mapping); - Float[] vector = {6.0f, 6.0f}; + Float[] vector = { 6.0f, 6.0f }; addKnnDoc(INDEX_NAME + "1", "1", FIELD_NAME, vector); String spaceType2 = SpaceType.COSINESIMIL.getValue(); int efConstruction2 = 114; int m2 = 113; - mapping = Strings.toString(XContentFactory.jsonBuilder().startObject() + mapping = Strings.toString( + XContentFactory.jsonBuilder() + .startObject() .startObject("properties") .startObject(FIELD_NAME + "1") .field("type", "knn_vector") @@ -290,7 +299,8 @@ public void testCreateIndexWithValidAlgoParams_mappingAndSettings() { .endObject() .endObject() .endObject() - .endObject()); + .endObject() + ); createKnnIndex(INDEX_NAME + "2", settings, mapping); addKnnDoc(INDEX_NAME + "2", "1", FIELD_NAME, vector); @@ -300,35 +310,28 @@ public void testCreateIndexWithValidAlgoParams_mappingAndSettings() { } public void testQueryIndexWithValidQueryAlgoParams() throws IOException { - Settings settings = Settings.builder() - .put(getKNNDefaultIndexSettings()) - .put("index.knn.algo_param.ef_search", 300) - .build(); + Settings settings = Settings.builder().put(getKNNDefaultIndexSettings()).put("index.knn.algo_param.ef_search", 300).build(); createKnnIndex(INDEX_NAME, settings, createKnnIndexMapping(FIELD_NAME, 2)); - Float[] vector = {6.0f, 6.0f}; + Float[] vector = { 6.0f, 6.0f }; addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector); - float[] queryVector = {1.0f, 1.0f}; // vector to be queried - int k = 1; // nearest 1 neighbor + float[] queryVector = { 1.0f, 1.0f }; // vector to be queried + int k = 1; // nearest 1 neighbor KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, k); searchKNNIndex(INDEX_NAME, knnQueryBuilder, k); } public void testInvalidIndexHnswAlgoParams_settings() { - Settings settings = Settings.builder() - .put(getKNNDefaultIndexSettings()) - .put("index.knn.algo_param.m", "-1") - .build(); - expectThrows(ResponseException.class, () -> createKnnIndex(INDEX_NAME, settings, - createKnnIndexMapping(FIELD_NAME, 2))); + Settings settings = Settings.builder().put(getKNNDefaultIndexSettings()).put("index.knn.algo_param.m", "-1").build(); + expectThrows(ResponseException.class, () -> createKnnIndex(INDEX_NAME, settings, createKnnIndexMapping(FIELD_NAME, 2))); } public void testInvalidIndexHnswAlgoParams_mapping() throws IOException { - Settings settings = Settings.builder() - .put(getKNNDefaultIndexSettings()) - .build(); + Settings settings = Settings.builder().put(getKNNDefaultIndexSettings()).build(); - String mapping = Strings.toString(XContentFactory.jsonBuilder().startObject() + String mapping = Strings.toString( + XContentFactory.jsonBuilder() + .startObject() .startObject("properties") .startObject(FIELD_NAME) .field("type", "knn_vector") @@ -341,19 +344,18 @@ public void testInvalidIndexHnswAlgoParams_mapping() throws IOException { .endObject() .endObject() .endObject() - .endObject()); + .endObject() + ); - expectThrows(ResponseException.class, () -> createKnnIndex(INDEX_NAME, settings, - mapping)); + expectThrows(ResponseException.class, () -> createKnnIndex(INDEX_NAME, settings, mapping)); } public void testInvalidIndexHnswAlgoParams_mappingAndSettings() throws IOException { - Settings settings = Settings.builder() - .put(getKNNDefaultIndexSettings()) - .put("index.knn.algo_param.m", "-1") - .build(); + Settings settings = Settings.builder().put(getKNNDefaultIndexSettings()).put("index.knn.algo_param.m", "-1").build(); - String mapping = Strings.toString(XContentFactory.jsonBuilder().startObject() + String mapping = Strings.toString( + XContentFactory.jsonBuilder() + .startObject() .startObject("properties") .startObject(FIELD_NAME) .field("type", "knn_vector") @@ -366,19 +368,18 @@ public void testInvalidIndexHnswAlgoParams_mappingAndSettings() throws IOExcepti .endObject() .endObject() .endObject() - .endObject()); + .endObject() + ); - expectThrows(ResponseException.class, () -> createKnnIndex(INDEX_NAME, settings, - mapping)); + expectThrows(ResponseException.class, () -> createKnnIndex(INDEX_NAME, settings, mapping)); } public void testInvalidQueryHnswAlgoParams() { - Settings settings = Settings.builder() - .put(getKNNDefaultIndexSettings()) - .put("index.knn.algo_param.ef_search", "-1") - .build(); - Exception ex = expectThrows(ResponseException.class, () -> createKnnIndex(INDEX_NAME, settings, - createKnnIndexMapping(FIELD_NAME, 2))); + Settings settings = Settings.builder().put(getKNNDefaultIndexSettings()).put("index.knn.algo_param.ef_search", "-1").build(); + Exception ex = expectThrows( + ResponseException.class, + () -> createKnnIndex(INDEX_NAME, settings, createKnnIndexMapping(FIELD_NAME, 2)) + ); assertThat(ex.getMessage(), containsString("Failed to parse value [-1] for setting [index.knn.algo_param.ef_search]")); } } diff --git a/src/test/java/org/opensearch/knn/index/OpenSearchIT.java b/src/test/java/org/opensearch/knn/index/OpenSearchIT.java index da2187c6a..6a06de14a 100644 --- a/src/test/java/org/opensearch/knn/index/OpenSearchIT.java +++ b/src/test/java/org/opensearch/knn/index/OpenSearchIT.java @@ -9,7 +9,6 @@ * GitHub history for details. */ - package org.opensearch.knn.index; import com.google.common.collect.ImmutableList; @@ -62,7 +61,7 @@ public void testEndToEnd() throws IOException, InterruptedException { String fieldName1 = "test-field-1"; String fieldName2 = "test-field-2"; - KNNMethod method1 =knnEngine1.getMethod(KNNConstants.METHOD_HNSW); + KNNMethod method1 = knnEngine1.getMethod(KNNConstants.METHOD_HNSW); KNNMethod method2 = knnEngine2.getMethod(KNNConstants.METHOD_HNSW); SpaceType spaceType1 = SpaceType.COSINESIMIL; SpaceType spaceType2 = SpaceType.L2; @@ -75,37 +74,37 @@ public void testEndToEnd() throws IOException, InterruptedException { // Create an index XContentBuilder builder = XContentFactory.jsonBuilder() - .startObject() - .startObject("properties") - .startObject(fieldName1) - .field("type", "knn_vector") - .field("dimension", dimension) - .startObject(KNNConstants.KNN_METHOD) - .field(KNNConstants.NAME, method1.getMethodComponent().getName()) - .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType1.getValue()) - .field(KNNConstants.KNN_ENGINE, knnEngine1.getName()) - .startObject(KNNConstants.PARAMETERS) - .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) - .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) - .endObject() - .endObject() - .endObject() - .startObject(fieldName2) - .field("type", "knn_vector") - .field("dimension", dimension) - .startObject(KNNConstants.KNN_METHOD) - .field(KNNConstants.NAME, method2.getMethodComponent().getName()) - .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType2.getValue()) - .field(KNNConstants.KNN_ENGINE, knnEngine2.getName()) - .startObject(KNNConstants.PARAMETERS) - .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) - .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) - .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) - .endObject() - .endObject() - .endObject() - .endObject() - .endObject(); + .startObject() + .startObject("properties") + .startObject(fieldName1) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, method1.getMethodComponent().getName()) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType1.getValue()) + .field(KNNConstants.KNN_ENGINE, knnEngine1.getName()) + .startObject(KNNConstants.PARAMETERS) + .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .endObject() + .endObject() + .endObject() + .startObject(fieldName2) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, method2.getMethodComponent().getName()) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType2.getValue()) + .field(KNNConstants.KNN_ENGINE, knnEngine2.getName()) + .startObject(KNNConstants.PARAMETERS) + .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); Map mappingMap = xContentBuilderToMap(builder); String mapping = Strings.toString(builder); @@ -114,9 +113,15 @@ public void testEndToEnd() throws IOException, InterruptedException { // Index the test data for (int i = 0; i < testData.indexData.docs.length; i++) { - addKnnDoc(indexName, Integer.toString(testData.indexData.docs[i]), ImmutableList.of(fieldName1, fieldName2), - ImmutableList.of(Floats.asList(testData.indexData.vectors[i]).toArray(), - Floats.asList(testData.indexData.vectors[i]).toArray())); + addKnnDoc( + indexName, + Integer.toString(testData.indexData.docs[i]), + ImmutableList.of(fieldName1, fieldName2), + ImmutableList.of( + Floats.asList(testData.indexData.vectors[i]).toArray(), + Floats.asList(testData.indexData.vectors[i]).toArray() + ) + ); } // Assert we have the right number of documents in the index @@ -134,8 +139,11 @@ public void testEndToEnd() throws IOException, InterruptedException { List actualScores = parseSearchResponseScore(responseBody, fieldName1); for (int j = 0; j < k; j++) { float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList())); - assertEquals(knnEngine1.score(1 - KNNScoringUtil.cosinesimil(testData.queries[i], primitiveArray), - spaceType1), actualScores.get(j), 0.0001); + assertEquals( + knnEngine1.score(1 - KNNScoringUtil.cosinesimil(testData.queries[i], primitiveArray), spaceType1), + actualScores.get(j), + 0.0001 + ); } // Search the second field @@ -147,15 +155,17 @@ public void testEndToEnd() throws IOException, InterruptedException { actualScores = parseSearchResponseScore(responseBody, fieldName2); for (int j = 0; j < k; j++) { float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList())); - assertEquals(knnEngine2.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), - spaceType2), actualScores.get(j), 0.0001); + assertEquals( + knnEngine2.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType2), + actualScores.get(j), + 0.0001 + ); } } // Delete index deleteKNNIndex(indexName); - // Search every 5 seconds 14 times to confirm graph gets evicted int intervals = 14; for (int i = 0; i < intervals; i++) { @@ -163,7 +173,7 @@ public void testEndToEnd() throws IOException, InterruptedException { return; } - Thread.sleep(5*1000); + Thread.sleep(5 * 1000); } fail("Graphs are not getting evicted"); @@ -173,11 +183,10 @@ public void testAddDoc_blockedWhenCbTrips() throws Exception { createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); updateClusterSettings("knn.circuit_breaker.triggered", "true"); - Float[] vector = {6.0f, 6.0f}; - ResponseException ex = expectThrows( - ResponseException.class, () -> addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector)); - String expMessage = "Indexing knn vector fields is rejected as circuit breaker triggered." + - " Check _opendistro/_knn/stats for detailed state"; + Float[] vector = { 6.0f, 6.0f }; + ResponseException ex = expectThrows(ResponseException.class, () -> addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector)); + String expMessage = "Indexing knn vector fields is rejected as circuit breaker triggered." + + " Check _opendistro/_knn/stats for detailed state"; assertThat(EntityUtils.toString(ex.getResponse().getEntity()), containsString(expMessage)); // reset @@ -185,19 +194,17 @@ public void testAddDoc_blockedWhenCbTrips() throws Exception { addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector); } - public void testUpdateDoc_blockedWhenCbTrips() throws Exception { createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); - Float[] vector = {6.0f, 6.0f}; + Float[] vector = { 6.0f, 6.0f }; addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector); // update updateClusterSettings("knn.circuit_breaker.triggered", "true"); - Float[] updatedVector = {8.0f, 8.0f}; - ResponseException ex = expectThrows( - ResponseException.class, () -> updateKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector)); - String expMessage = "Indexing knn vector fields is rejected as circuit breaker triggered." + - " Check _opendistro/_knn/stats for detailed state"; + Float[] updatedVector = { 8.0f, 8.0f }; + ResponseException ex = expectThrows(ResponseException.class, () -> updateKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector)); + String expMessage = "Indexing knn vector fields is rejected as circuit breaker triggered." + + " Check _opendistro/_knn/stats for detailed state"; assertThat(EntityUtils.toString(ex.getResponse().getEntity()), containsString(expMessage)); // reset @@ -207,13 +214,13 @@ public void testUpdateDoc_blockedWhenCbTrips() throws Exception { public void testAddAndSearchIndex_whenCBTrips() throws Exception { createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); - for (int i=1; i<=4; i++) { - Float[] vector = {(float)i, (float)(i+1)}; + for (int i = 1; i <= 4; i++) { + Float[] vector = { (float) i, (float) (i + 1) }; addKnnDoc(INDEX_NAME, Integer.toString(i), FIELD_NAME, vector); } - float[] queryVector = {1.0f, 1.0f}; // vector to be queried - int k = 10; // nearest 10 neighbor + float[] queryVector = { 1.0f, 1.0f }; // vector to be queried + int k = 10; // nearest 10 neighbor KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, k); Response response = searchKNNIndex(INDEX_NAME, knnQueryBuilder, k); List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); @@ -221,9 +228,8 @@ public void testAddAndSearchIndex_whenCBTrips() throws Exception { updateClusterSettings("knn.circuit_breaker.triggered", "true"); // Try add another doc - Float[] vector = {1.0f, 2.0f}; - ResponseException ex = expectThrows( - ResponseException.class, () -> addKnnDoc(INDEX_NAME, "5", FIELD_NAME, vector)); + Float[] vector = { 1.0f, 2.0f }; + ResponseException ex = expectThrows(ResponseException.class, () -> addKnnDoc(INDEX_NAME, "5", FIELD_NAME, vector)); // Still get 4 docs response = searchKNNIndex(INDEX_NAME, knnQueryBuilder, k); @@ -238,103 +244,93 @@ public void testAddAndSearchIndex_whenCBTrips() throws Exception { } public void testIndexingVectorValidation_differentSizes() throws Exception { - Settings settings = Settings.builder() - .put(getKNNDefaultIndexSettings()) - .build(); + Settings settings = Settings.builder().put(getKNNDefaultIndexSettings()).build(); createKnnIndex(INDEX_NAME, settings, createKnnIndexMapping(FIELD_NAME, 4)); // valid case with 4 dimension - Float[] vector = {6.0f, 7.0f, 8.0f, 9.0f}; + Float[] vector = { 6.0f, 7.0f, 8.0f, 9.0f }; addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector); - // invalid case with lesser dimension than original (3 < 4) - Float[] vector1 = {6.0f, 7.0f, 8.0f}; - ResponseException ex = expectThrows(ResponseException.class, () -> - addKnnDoc(INDEX_NAME, "2", FIELD_NAME, vector1)); - assertThat(EntityUtils.toString(ex.getResponse().getEntity()), - containsString("Vector dimension mismatch. Expected: 4, Given: 3")); + Float[] vector1 = { 6.0f, 7.0f, 8.0f }; + ResponseException ex = expectThrows(ResponseException.class, () -> addKnnDoc(INDEX_NAME, "2", FIELD_NAME, vector1)); + assertThat(EntityUtils.toString(ex.getResponse().getEntity()), containsString("Vector dimension mismatch. Expected: 4, Given: 3")); // invalid case with more dimension than original (5 > 4) - Float[] vector2 = {6.0f, 7.0f, 8.0f, 9.0f, 10.0f}; + Float[] vector2 = { 6.0f, 7.0f, 8.0f, 9.0f, 10.0f }; ex = expectThrows(ResponseException.class, () -> addKnnDoc(INDEX_NAME, "3", FIELD_NAME, vector2)); - assertThat(EntityUtils.toString(ex.getResponse().getEntity()), - containsString("Vector dimension mismatch. Expected: 4, Given: 5")); + assertThat(EntityUtils.toString(ex.getResponse().getEntity()), containsString("Vector dimension mismatch. Expected: 4, Given: 5")); } public void testVectorMappingValidation_noDimension() throws Exception { - Settings settings = Settings.builder() - .put(getKNNDefaultIndexSettings()) - .build(); + Settings settings = Settings.builder().put(getKNNDefaultIndexSettings()).build(); - String mapping = Strings.toString(XContentFactory.jsonBuilder().startObject() + String mapping = Strings.toString( + XContentFactory.jsonBuilder() + .startObject() .startObject("properties") .startObject(FIELD_NAME) .field("type", "knn_vector") .endObject() .endObject() - .endObject()); + .endObject() + ); Exception ex = expectThrows(ResponseException.class, () -> createKnnIndex(INDEX_NAME, settings, mapping)); assertThat(ex.getMessage(), containsString("Dimension value missing for vector: " + FIELD_NAME)); } public void testVectorMappingValidation_invalidDimension() { - Settings settings = Settings.builder() - .put(getKNNDefaultIndexSettings()) - .build(); - - Exception ex = expectThrows(ResponseException.class, () -> createKnnIndex(INDEX_NAME, settings, - createKnnIndexMapping(FIELD_NAME, KNNVectorFieldMapper.MAX_DIMENSION + 1))); - assertThat(ex.getMessage(), containsString("Dimension value cannot be greater than " + - KNNVectorFieldMapper.MAX_DIMENSION + " for vector: " + FIELD_NAME)); + Settings settings = Settings.builder().put(getKNNDefaultIndexSettings()).build(); + + Exception ex = expectThrows( + ResponseException.class, + () -> createKnnIndex(INDEX_NAME, settings, createKnnIndexMapping(FIELD_NAME, KNNVectorFieldMapper.MAX_DIMENSION + 1)) + ); + assertThat( + ex.getMessage(), + containsString("Dimension value cannot be greater than " + KNNVectorFieldMapper.MAX_DIMENSION + " for vector: " + FIELD_NAME) + ); } public void testVectorMappingValidation_invalidVectorNaN() throws IOException { - Settings settings = Settings.builder() - .put(getKNNDefaultIndexSettings()) - .build(); + Settings settings = Settings.builder().put(getKNNDefaultIndexSettings()).build(); createKnnIndex(INDEX_NAME, settings, createKnnIndexMapping(FIELD_NAME, 2)); - Float[] vector = {Float.NaN, Float.NaN}; + Float[] vector = { Float.NaN, Float.NaN }; Exception ex = expectThrows(ResponseException.class, () -> addKnnDoc(INDEX_NAME, "3", FIELD_NAME, vector)); assertThat(ex.getMessage(), containsString("KNN vector values cannot be NaN")); } public void testVectorMappingValidation_invalidVectorInfinity() throws IOException { - Settings settings = Settings.builder() - .put(getKNNDefaultIndexSettings()) - .build(); + Settings settings = Settings.builder().put(getKNNDefaultIndexSettings()).build(); createKnnIndex(INDEX_NAME, settings, createKnnIndexMapping(FIELD_NAME, 2)); - Float[] vector = {Float.POSITIVE_INFINITY, Float.POSITIVE_INFINITY}; + Float[] vector = { Float.POSITIVE_INFINITY, Float.POSITIVE_INFINITY }; Exception ex = expectThrows(ResponseException.class, () -> addKnnDoc(INDEX_NAME, "3", FIELD_NAME, vector)); assertThat(ex.getMessage(), containsString("KNN vector values cannot be infinity")); } public void testVectorMappingValidation_updateDimension() throws Exception { - Settings settings = Settings.builder() - .put(getKNNDefaultIndexSettings()) - .build(); + Settings settings = Settings.builder().put(getKNNDefaultIndexSettings()).build(); createKnnIndex(INDEX_NAME, settings, createKnnIndexMapping(FIELD_NAME, 4)); - Exception ex = expectThrows(ResponseException.class, () -> - putMappingRequest(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 5))); + Exception ex = expectThrows(ResponseException.class, () -> putMappingRequest(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 5))); assertThat(ex.getMessage(), containsString("Cannot update parameter [dimension] from [4] to [5]")); } public void testVectorMappingValidation_multiFieldsDifferentDimension() throws Exception { - Settings settings = Settings.builder() - .put(getKNNDefaultIndexSettings()) - .build(); + Settings settings = Settings.builder().put(getKNNDefaultIndexSettings()).build(); String f4 = FIELD_NAME + "-4"; String f5 = FIELD_NAME + "-5"; - String mapping = Strings.toString(XContentFactory.jsonBuilder().startObject() + String mapping = Strings.toString( + XContentFactory.jsonBuilder() + .startObject() .startObject("properties") .startObject(f4) .field("type", "knn_vector") @@ -345,17 +341,17 @@ public void testVectorMappingValidation_multiFieldsDifferentDimension() throws E .field("dimension", "5") .endObject() .endObject() - .endObject()); + .endObject() + ); createKnnIndex(INDEX_NAME, settings, mapping); - // valid case with 4 dimension - Float[] vector = {6.0f, 7.0f, 8.0f, 9.0f}; + Float[] vector = { 6.0f, 7.0f, 8.0f, 9.0f }; addKnnDoc(INDEX_NAME, "1", f4, vector); // valid case with 5 dimension - Float[] vector1 = {6.0f, 7.0f, 8.0f, 9.0f, 10.0f}; + Float[] vector1 = { 6.0f, 7.0f, 8.0f, 9.0f, 10.0f }; updateKnnDoc(INDEX_NAME, "1", f5, vector1); } @@ -364,7 +360,7 @@ public void testExistsQuery() throws IOException { String field2 = "field2"; createKnnIndex(INDEX_NAME, createKnnIndexMapping(Arrays.asList(field1, field2), Arrays.asList(2, 2))); - Float[] vector = {6.0f, 7.0f}; + Float[] vector = { 6.0f, 7.0f }; addKnnDoc(INDEX_NAME, "1", Arrays.asList(field1, field2), Arrays.asList(vector, vector)); addKnnDoc(INDEX_NAME, "2", field1, vector); @@ -374,20 +370,13 @@ public void testExistsQuery() throws IOException { addKnnDoc(INDEX_NAME, "5", field2, vector); addKnnDoc(INDEX_NAME, "6", field2, vector); - // Create document that does not have k-NN vector field - Request request = new Request( - "POST", - "/" + INDEX_NAME + "/_doc/7?refresh=true" - ); + Request request = new Request("POST", "/" + INDEX_NAME + "/_doc/7?refresh=true"); - XContentBuilder builder = XContentFactory.jsonBuilder().startObject() - .field("non-knn-field", "test") - .endObject(); + XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field("non-knn-field", "test").endObject(); request.setJsonEntity(Strings.toString(builder)); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.CREATED, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.CREATED, RestStatus.fromCode(response.getStatusLine().getStatusCode())); ExistsQueryBuilder existsQueryBuilder = new ExistsQueryBuilder(field1); response = searchExists(INDEX_NAME, existsQueryBuilder, 10); @@ -401,24 +390,22 @@ public void testExistsQuery() throws IOException { } public void testIndexingVectorValidation_updateVectorWithNull() throws Exception { - Settings settings = Settings.builder() - .put(getKNNDefaultIndexSettings()) - .build(); + Settings settings = Settings.builder().put(getKNNDefaultIndexSettings()).build(); createKnnIndex(INDEX_NAME, settings, createKnnIndexMapping(FIELD_NAME, 4)); // valid case with 4 dimension - final Float[] vectorForDocumentOne = {6.0f, 7.0f, 8.0f, 9.0f}; + final Float[] vectorForDocumentOne = { 6.0f, 7.0f, 8.0f, 9.0f }; final String docOneId = "1"; addKnnDoc(INDEX_NAME, docOneId, FIELD_NAME, vectorForDocumentOne); - final Float[] vectorForDocumentTwo = {2.0f, 1.0f, 3.8f, 2.5f}; + final Float[] vectorForDocumentTwo = { 2.0f, 1.0f, 3.8f, 2.5f }; final String docTwoId = "2"; addKnnDoc(INDEX_NAME, docTwoId, FIELD_NAME, vectorForDocumentTwo); - //checking that both documents are retrievable based on knn search query + // checking that both documents are retrievable based on knn search query int k = 2; - float[] queryVector = {5.0f, 6.0f, 7.0f, 10.0f}; + float[] queryVector = { 5.0f, 6.0f, 7.0f, 10.0f }; final KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, k); final Response response = searchKNNIndex(INDEX_NAME, knnQueryBuilder, k); final List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); @@ -429,32 +416,34 @@ public void testIndexingVectorValidation_updateVectorWithNull() throws Exception // update vector value to null updateKnnDoc(INDEX_NAME, docOneId, FIELD_NAME, null); - //retrieving updated document by id, vector should be null + // retrieving updated document by id, vector should be null final Map knnDocMapUpdated = getKnnDoc(INDEX_NAME, docOneId); assertNull(knnDocMapUpdated.get(FIELD_NAME)); - //checking that first document one is no longer returned by knn search + // checking that first document one is no longer returned by knn search final Response updatedResponse = searchKNNIndex(INDEX_NAME, knnQueryBuilder, k); - final List updatedResults = - parseSearchResponse(EntityUtils.toString(updatedResponse.getEntity()), FIELD_NAME); + final List updatedResults = parseSearchResponse(EntityUtils.toString(updatedResponse.getEntity()), FIELD_NAME); assertEquals(1, updatedResults.size()); assertEquals(docTwoId, updatedResults.get(0).getDocId()); // update vector back to original value updateKnnDoc(INDEX_NAME, docOneId, FIELD_NAME, vectorForDocumentOne); final Response restoreInitialVectorValueResponse = searchKNNIndex(INDEX_NAME, knnQueryBuilder, k); - final List restoreInitialVectorValueResults = - parseSearchResponse(EntityUtils.toString(restoreInitialVectorValueResponse.getEntity()), FIELD_NAME); + final List restoreInitialVectorValueResults = parseSearchResponse( + EntityUtils.toString(restoreInitialVectorValueResponse.getEntity()), + FIELD_NAME + ); assertEquals(2, restoreInitialVectorValueResults.size()); assertEquals(docOneId, results.get(0).getDocId()); assertEquals(docTwoId, results.get(1).getDocId()); - //retrieving updated document by id, vector should be not null but has the original value + // retrieving updated document by id, vector should be not null but has the original value final Map knnDocMapRestoreInitialVectorValue = getKnnDoc(INDEX_NAME, docOneId); assertNotNull(knnDocMapRestoreInitialVectorValue.get(FIELD_NAME)); final Float[] vectorRestoreInitialValue = ((List) knnDocMapRestoreInitialVectorValue.get(FIELD_NAME)).stream() - .map(Double::floatValue).toArray(Float[]::new); + .map(Double::floatValue) + .toArray(Float[]::new); assertArrayEquals(vectorForDocumentOne, vectorRestoreInitialValue); } -} \ No newline at end of file +} diff --git a/src/test/java/org/opensearch/knn/index/ParameterTests.java b/src/test/java/org/opensearch/knn/index/ParameterTests.java index 1add2b409..4e7adfc8c 100644 --- a/src/test/java/org/opensearch/knn/index/ParameterTests.java +++ b/src/test/java/org/opensearch/knn/index/ParameterTests.java @@ -39,8 +39,7 @@ public ValidationException validate(Object value) { * Test integer parameter validate */ public void testIntegerParameter_validate() { - final IntegerParameter parameter = new IntegerParameter("test",1, - v -> v > 0); + final IntegerParameter parameter = new IntegerParameter("test", 1, v -> v > 0); // Invalid type assertNotNull(parameter.validate("String")); @@ -58,18 +57,20 @@ public void testMethodComponentContextParameter_validate() { Integer parameterValue1 = 12; Map defaultParameterMap = ImmutableMap.of(parameterKey1, parameterValue1); - MethodComponentContext methodComponentContext = new MethodComponentContext(methodComponentName1, - defaultParameterMap); + MethodComponentContext methodComponentContext = new MethodComponentContext(methodComponentName1, defaultParameterMap); Map methodComponentMap = ImmutableMap.of( - methodComponentName1, - MethodComponent.Builder.builder(parameterKey1) - .addParameter(parameterKey1, new IntegerParameter(parameterKey1,1, v -> v > 0)) - .build() + methodComponentName1, + MethodComponent.Builder.builder(parameterKey1) + .addParameter(parameterKey1, new IntegerParameter(parameterKey1, 1, v -> v > 0)) + .build() ); - final MethodComponentContextParameter parameter = new MethodComponentContextParameter("test", - methodComponentContext, methodComponentMap); + final MethodComponentContextParameter parameter = new MethodComponentContextParameter( + "test", + methodComponentContext, + methodComponentMap + ); // Invalid type assertNotNull(parameter.validate(17)); @@ -77,20 +78,17 @@ public void testMethodComponentContextParameter_validate() { // Invalid value String invalidMethodComponentName = "invalid-method"; - MethodComponentContext invalidMethodComponentContext1 = new MethodComponentContext(invalidMethodComponentName, - defaultParameterMap); + MethodComponentContext invalidMethodComponentContext1 = new MethodComponentContext(invalidMethodComponentName, defaultParameterMap); assertNotNull(parameter.validate(invalidMethodComponentContext1)); String invalidParameterKey = "invalid-parameter"; Map invalidParameterMap1 = ImmutableMap.of(invalidParameterKey, parameterValue1); - MethodComponentContext invalidMethodComponentContext2 = new MethodComponentContext(methodComponentName1, - invalidParameterMap1); + MethodComponentContext invalidMethodComponentContext2 = new MethodComponentContext(methodComponentName1, invalidParameterMap1); assertNotNull(parameter.validate(invalidMethodComponentContext2)); String invalidParameterValue = "invalid-value"; Map invalidParameterMap2 = ImmutableMap.of(parameterKey1, invalidParameterValue); - MethodComponentContext invalidMethodComponentContext3 = new MethodComponentContext(methodComponentName1, - invalidParameterMap2); + MethodComponentContext invalidMethodComponentContext3 = new MethodComponentContext(methodComponentName1, invalidParameterMap2); assertNotNull(parameter.validate(invalidMethodComponentContext3)); // valid value @@ -103,18 +101,20 @@ public void testMethodComponentContextParameter_getMethodComponent() { Integer parameterValue1 = 12; Map defaultParameterMap = ImmutableMap.of(parameterKey1, parameterValue1); - MethodComponentContext methodComponentContext = new MethodComponentContext(methodComponentName1, - defaultParameterMap); + MethodComponentContext methodComponentContext = new MethodComponentContext(methodComponentName1, defaultParameterMap); Map methodComponentMap = ImmutableMap.of( - methodComponentName1, - MethodComponent.Builder.builder(parameterKey1) - .addParameter(parameterKey1, new IntegerParameter(parameterKey1, 1, v -> v > 0)) - .build() + methodComponentName1, + MethodComponent.Builder.builder(parameterKey1) + .addParameter(parameterKey1, new IntegerParameter(parameterKey1, 1, v -> v > 0)) + .build() ); - final MethodComponentContextParameter parameter = new MethodComponentContextParameter("test", - methodComponentContext, methodComponentMap); + final MethodComponentContextParameter parameter = new MethodComponentContextParameter( + "test", + methodComponentContext, + methodComponentMap + ); // Test when method component is available assertEquals(methodComponentMap.get(methodComponentName1), parameter.getMethodComponent(methodComponentName1)); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java index 0688f554e..fa811c2e4 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java @@ -65,7 +65,7 @@ /** * Test used for testing Codecs */ -public class KNNCodecTestCase extends KNNTestCase { +public class KNNCodecTestCase extends KNNTestCase { private static FieldType sampleFieldType; static { @@ -86,13 +86,8 @@ protected void setUpMockClusterService() { } protected ResourceWatcherService createDisabledResourceWatcherService() { - final Settings settings = Settings.builder() - .put("resource.reload.enabled", false) - .build(); - return new ResourceWatcherService( - settings, - null - ); + final Settings settings = Settings.builder().put("resource.reload.enabled", false).build(); + return new ResourceWatcherService(settings, null); } public void testFooter(Codec codec) throws Exception { @@ -102,7 +97,7 @@ public void testFooter(Codec codec) throws Exception { iwc.setMergeScheduler(new SerialMergeScheduler()); iwc.setCodec(codec); - float[] array = {1.0f, 2.0f, 3.0f}; + float[] array = { 1.0f, 2.0f, 3.0f }; VectorField vectorField = new VectorField("test_vector", array, sampleFieldType); RandomIndexWriter writer = new RandomIndexWriter(random(), dir, iwc); Document doc = new Document(); @@ -115,11 +110,14 @@ public void testFooter(Codec codec) throws Exception { LeafReaderContext lrc = reader.getContext().leaves().iterator().next(); // leaf reader context SegmentReader segmentReader = (SegmentReader) FilterLeafReader.unwrap(lrc.reader()); String hnswFileExtension = segmentReader.getSegmentInfo().info.getUseCompoundFile() - ? KNNEngine.NMSLIB.getCompoundExtension() : KNNEngine.NMSLIB.getExtension(); + ? KNNEngine.NMSLIB.getCompoundExtension() + : KNNEngine.NMSLIB.getExtension(); String hnswSuffix = "test_vector" + hnswFileExtension; - List hnswFiles = segmentReader.getSegmentInfo().files().stream() - .filter(fileName -> fileName.endsWith(hnswSuffix)) - .collect(Collectors.toList()); + List hnswFiles = segmentReader.getSegmentInfo() + .files() + .stream() + .filter(fileName -> fileName.endsWith(hnswSuffix)) + .collect(Collectors.toList()); assertTrue(!hnswFiles.isEmpty()); ChecksumIndexInput indexInput = dir.openChecksumInput(hnswFiles.get(0), IOContext.DEFAULT); indexInput.seek(indexInput.length() - CodecUtil.footerLength()); @@ -127,7 +125,7 @@ public void testFooter(Codec codec) throws Exception { indexInput.close(); IndexSearcher searcher = new IndexSearcher(reader); - assertEquals(1, searcher.count(new KNNQuery("test_vector", new float[] {1.0f, 2.5f}, 1, "myindex"))); + assertEquals(1, searcher.count(new KNNQuery("test_vector", new float[] { 1.0f, 2.5f }, 1, "myindex"))); reader.close(); writer.close(); @@ -146,7 +144,7 @@ public void testMultiFieldsKnnIndex(Codec codec) throws Exception { /** * Add doc with field "test_vector" */ - float[] array = {1.0f, 3.0f, 4.0f}; + float[] array = { 1.0f, 3.0f, 4.0f }; VectorField vectorField = new VectorField("test_vector", array, sampleFieldType); RandomIndexWriter writer = new RandomIndexWriter(random(), dir, iwc); Document doc = new Document(); @@ -161,7 +159,7 @@ public void testMultiFieldsKnnIndex(Codec codec) throws Exception { iwc1.setMergeScheduler(new SerialMergeScheduler()); iwc1.setCodec(new KNN87Codec()); writer = new RandomIndexWriter(random(), dir, iwc1); - float[] array1 = {6.0f, 14.0f}; + float[] array1 = { 6.0f, 14.0f }; VectorField vectorField1 = new VectorField("my_vector", array1, sampleFieldType); Document doc1 = new Document(); doc1.add(vectorField1); @@ -179,14 +177,14 @@ public void testMultiFieldsKnnIndex(Codec codec) throws Exception { // query to verify distance for each of the field IndexSearcher searcher = new IndexSearcher(reader); - float score = searcher.search(new KNNQuery("test_vector", new float[] {1.0f, 0.0f, 0.0f}, 1, "dummy"), 10).scoreDocs[0].score; - float score1 = searcher.search(new KNNQuery("my_vector", new float[] {1.0f, 2.0f}, 1, "dummy"), 10).scoreDocs[0].score; - assertEquals(1.0f/(1 + 25), score, 0.01f); - assertEquals(1.0f/(1 + 169), score1, 0.01f); + float score = searcher.search(new KNNQuery("test_vector", new float[] { 1.0f, 0.0f, 0.0f }, 1, "dummy"), 10).scoreDocs[0].score; + float score1 = searcher.search(new KNNQuery("my_vector", new float[] { 1.0f, 2.0f }, 1, "dummy"), 10).scoreDocs[0].score; + assertEquals(1.0f / (1 + 25), score, 0.01f); + assertEquals(1.0f / (1 + 169), score1, 0.01f); // query to determine the hits - assertEquals(1, searcher.count(new KNNQuery("test_vector", new float[] {1.0f, 0.0f, 0.0f}, 1, "dummy"))); - assertEquals(1, searcher.count(new KNNQuery("my_vector", new float[] {1.0f, 1.0f}, 1, "dummy"))); + assertEquals(1, searcher.count(new KNNQuery("test_vector", new float[] { 1.0f, 0.0f, 0.0f }, 1, "dummy"))); + assertEquals(1, searcher.count(new KNNQuery("my_vector", new float[] { 1.0f, 1.0f }, 1, "dummy"))); reader.close(); dir.close(); @@ -203,25 +201,33 @@ public void testBuildFromModelTemplate(Codec codec) throws IOException, Executio // "Train" a faiss flat index - this really just creates an empty index that does brute force k-NN long vectorsPointer = JNIService.transferVectors(0, new float[0][0]); - byte [] modelBlob = JNIService.trainIndex(ImmutableMap.of( - INDEX_DESCRIPTION_PARAMETER, "Flat", - SPACE_TYPE, spaceType.getValue()), dimension, vectorsPointer, - KNNEngine.FAISS.getName()); + byte[] modelBlob = JNIService.trainIndex( + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, "Flat", SPACE_TYPE, spaceType.getValue()), + dimension, + vectorsPointer, + KNNEngine.FAISS.getName() + ); // Setup model cache ModelDao modelDao = mock(ModelDao.class); // Set model state to created - ModelMetadata modelMetadata1 = new ModelMetadata(knnEngine, spaceType, dimension, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""); + ModelMetadata modelMetadata1 = new ModelMetadata( + knnEngine, + spaceType, + dimension, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ); Model mockModel = new Model(modelMetadata1, modelBlob, modelId); when(modelDao.get(modelId)).thenReturn(mockModel); when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata1); Settings settings = settings(CURRENT).put(MODEL_CACHE_SIZE_LIMIT_SETTING.getKey(), "10%").build(); - ClusterSettings clusterSettings = new ClusterSettings(settings, - ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING)); + ClusterSettings clusterSettings = new ClusterSettings(settings, ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING)); ClusterService clusterService = mock(ClusterService.class); when(clusterService.getSettings()).thenReturn(settings); @@ -242,12 +248,7 @@ public void testBuildFromModelTemplate(Codec codec) throws IOException, Executio fieldType.freeze(); // Add the documents to the index - float[][] arrays = { - {1.0f, 3.0f, 4.0f}, - {2.0f, 5.0f, 8.0f}, - {3.0f, 6.0f, 9.0f}, - {4.0f, 7.0f, 10.0f} - }; + float[][] arrays = { { 1.0f, 3.0f, 4.0f }, { 2.0f, 5.0f, 8.0f }, { 3.0f, 6.0f, 9.0f }, { 4.0f, 7.0f, 10.0f } }; RandomIndexWriter writer = new RandomIndexWriter(random(), dir, iwc); String fieldName = "test_vector"; @@ -265,7 +266,7 @@ public void testBuildFromModelTemplate(Codec codec) throws IOException, Executio KNNWeight.initialize(modelDao); ResourceWatcherService resourceWatcherService = createDisabledResourceWatcherService(); NativeMemoryLoadStrategy.IndexLoadStrategy.initialize(resourceWatcherService); - float [] query = {10.0f, 10.0f, 10.0f}; + float[] query = { 10.0f, 10.0f, 10.0f }; IndexSearcher searcher = new IndexSearcher(reader); TopDocs topDocs = searcher.search(new KNNQuery(fieldName, query, 4, "dummy"), 10); @@ -280,4 +281,3 @@ public void testBuildFromModelTemplate(Codec codec) throws IOException, Executio NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance().close(); } } - diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNVectorSerializerTests.java b/src/test/java/org/opensearch/knn/index/codec/KNNVectorSerializerTests.java index 4bc62bebf..00159677c 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNVectorSerializerTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNVectorSerializerTests.java @@ -22,8 +22,8 @@ public class KNNVectorSerializerTests extends KNNTestCase { Random random = new Random(); public void testVectorSerializerFactory() throws Exception { - //check that default serializer can work with array of floats - //setup + // check that default serializer can work with array of floats + // setup final float[] vector = getArrayOfRandomFloats(20); final ByteArrayOutputStream bas = new ByteArrayOutputStream(); final DataOutputStream ds = new DataOutputStream(bas); @@ -40,19 +40,18 @@ public void testVectorSerializerFactory() throws Exception { assertNotNull(actualDeserializedVector); assertArrayEquals(vector, actualDeserializedVector, 0.1f); - final KNNVectorSerializer arraySerializer = - KNNVectorSerializerFactory.getSerializerBySerializationMode(SerializationMode.ARRAY); + final KNNVectorSerializer arraySerializer = KNNVectorSerializerFactory.getSerializerBySerializationMode(SerializationMode.ARRAY); assertNotNull(arraySerializer); - final KNNVectorSerializer collectionOfFloatsSerializer = - KNNVectorSerializerFactory.getSerializerBySerializationMode(SerializationMode.COLLECTION_OF_FLOATS); + final KNNVectorSerializer collectionOfFloatsSerializer = KNNVectorSerializerFactory.getSerializerBySerializationMode( + SerializationMode.COLLECTION_OF_FLOATS + ); assertNotNull(collectionOfFloatsSerializer); } - public void testVectorSerializerFactory_throwExceptionForStreamWithUnsupportedDataType() throws Exception { - //prepare array of chars that is not supported by serializer factory. expected behavior is to fail - final char[] arrayOfChars = new char[] {'a', 'b', 'c'}; + // prepare array of chars that is not supported by serializer factory. expected behavior is to fail + final char[] arrayOfChars = new char[] { 'a', 'b', 'c' }; final ByteArrayOutputStream bas = new ByteArrayOutputStream(); final DataOutputStream ds = new DataOutputStream(bas); for (char ch : arrayOfChars) @@ -75,14 +74,14 @@ public void testVectorAsArraySerializer() throws Exception { final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(bais); - //testing serialization + // testing serialization bais.reset(); final byte[] actualSerializedVector = vectorSerializer.floatToByteArray(vector); assertNotNull(actualSerializedVector); assertArrayEquals(serializedVector, actualSerializedVector); - //testing deserialization + // testing deserialization bais.reset(); final float[] actualDeserializedVector = vectorSerializer.byteToFloatArray(bais); @@ -91,7 +90,7 @@ public void testVectorAsArraySerializer() throws Exception { } public void testVectorAsCollectionOfFloatsSerializer() throws Exception { - //setup + // setup final float[] vector = getArrayOfRandomFloats(20); final ByteArrayOutputStream bas = new ByteArrayOutputStream(); @@ -103,14 +102,14 @@ public void testVectorAsCollectionOfFloatsSerializer() throws Exception { final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(bais); - //testing serialization + // testing serialization bais.reset(); final byte[] actualSerializedVector = vectorSerializer.floatToByteArray(vector); assertNotNull(actualSerializedVector); assertArrayEquals(vectorAsCollectionOfFloats, actualSerializedVector); - //testing deserialization + // testing deserialization bais.reset(); final float[] actualDeserializedVector = vectorSerializer.byteToFloatArray(bais); diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java index cfecd0413..fbd519163 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java @@ -63,25 +63,25 @@ public void testIndexAllocation_close() throws InterruptedException { ExecutorService executorService = Executors.newSingleThreadExecutor(); NativeMemoryAllocation.IndexAllocation indexAllocation = new NativeMemoryAllocation.IndexAllocation( - executorService, - memoryAddress, - IndexUtil.getFileSizeInKB(path), - knnEngine, - path, - "test", - watcherHandle + executorService, + memoryAddress, + IndexUtil.getFileSizeInKB(path), + knnEngine, + path, + "test", + watcherHandle ); indexAllocation.close(); - Thread.sleep(1000*2); + Thread.sleep(1000 * 2); indexAllocation.writeLock(); assertTrue(indexAllocation.isClosed()); indexAllocation.writeUnlock(); indexAllocation.close(); - Thread.sleep(1000*2); + Thread.sleep(1000 * 2); indexAllocation.writeLock(); assertTrue(indexAllocation.isClosed()); indexAllocation.writeUnlock(); @@ -92,13 +92,13 @@ public void testIndexAllocation_close() throws InterruptedException { public void testIndexAllocation_getMemoryAddress() { long memoryAddress = 12; NativeMemoryAllocation.IndexAllocation indexAllocation = new NativeMemoryAllocation.IndexAllocation( - null, - memoryAddress, - 0, - null, - "test", - "test", - null + null, + memoryAddress, + 0, + null, + "test", + "test", + null ); assertEquals(memoryAddress, indexAllocation.getMemoryAddress()); @@ -108,13 +108,13 @@ public void testIndexAllocation_readLock() throws InterruptedException { // To test the readLock, we grab the readLock in the main thread and then start a thread that grabs the write // lock and updates testLockValue1. We ensure that the value is not updated until after we release the readLock NativeMemoryAllocation.IndexAllocation indexAllocation = new NativeMemoryAllocation.IndexAllocation( - null, - 0, - 0, - null, - "test", - "test", - null + null, + 0, + 0, + null, + "test", + "test", + null ); int initialValue = 10; @@ -144,13 +144,13 @@ public void testIndexAllocation_writeLock() throws InterruptedException { // grabs the readLock and asserts testLockValue2 has been updated. Next in the main thread, we update the value // and release the writeLock. NativeMemoryAllocation.IndexAllocation indexAllocation = new NativeMemoryAllocation.IndexAllocation( - null, - 0, - 0, - null, - "test", - "test", - null + null, + 0, + 0, + null, + "test", + "test", + null ); int initialValue = 10; @@ -177,13 +177,13 @@ public void testIndexAllocation_writeLock() throws InterruptedException { public void testIndexAllocation_getSize() { int size = 12; NativeMemoryAllocation.IndexAllocation indexAllocation = new NativeMemoryAllocation.IndexAllocation( - null, - 0, - size, - null, - "test", - "test", - null + null, + 0, + size, + null, + "test", + "test", + null ); assertEquals(size, indexAllocation.getSizeInKB()); @@ -192,13 +192,13 @@ public void testIndexAllocation_getSize() { public void testIndexAllocation_getKnnEngine() { KNNEngine knnEngine = KNNEngine.DEFAULT; NativeMemoryAllocation.IndexAllocation indexAllocation = new NativeMemoryAllocation.IndexAllocation( - null, - 0, - 0, - knnEngine, - "test", - "test", - null + null, + 0, + 0, + knnEngine, + "test", + "test", + null ); assertEquals(knnEngine, indexAllocation.getKnnEngine()); @@ -207,13 +207,13 @@ public void testIndexAllocation_getKnnEngine() { public void testIndexAllocation_getIndexPath() { String indexPath = "test-path"; NativeMemoryAllocation.IndexAllocation indexAllocation = new NativeMemoryAllocation.IndexAllocation( - null, - 0, - 0, - null, - indexPath, - "test", - null + null, + 0, + 0, + null, + indexPath, + "test", + null ); assertEquals(indexPath, indexAllocation.getIndexPath()); @@ -222,13 +222,13 @@ public void testIndexAllocation_getIndexPath() { public void testIndexAllocation_getOsIndexName() { String osIndexName = "test-index"; NativeMemoryAllocation.IndexAllocation indexAllocation = new NativeMemoryAllocation.IndexAllocation( - null, - 0, - 0, - null, - "test", - osIndexName, - null + null, + 0, + 0, + null, + "test", + osIndexName, + null ); assertEquals(osIndexName, indexAllocation.getOpenSearchIndexName()); @@ -246,22 +246,21 @@ public void testTrainingDataAllocation_close() throws InterruptedException { ExecutorService executorService = Executors.newSingleThreadExecutor(); NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation = new NativeMemoryAllocation.TrainingDataAllocation( - executorService, - memoryAddress, - 0 + executorService, + memoryAddress, + 0 ); - trainingDataAllocation.close(); - Thread.sleep(1000*2); + Thread.sleep(1000 * 2); trainingDataAllocation.writeLock(); assertTrue(trainingDataAllocation.isClosed()); trainingDataAllocation.writeUnlock(); trainingDataAllocation.close(); - Thread.sleep(1000*2); + Thread.sleep(1000 * 2); trainingDataAllocation.writeLock(); assertTrue(trainingDataAllocation.isClosed()); trainingDataAllocation.writeUnlock(); @@ -273,9 +272,9 @@ public void testTrainingDataAllocation_getMemoryAddress() { long memoryAddress = 12; NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation = new NativeMemoryAllocation.TrainingDataAllocation( - null, - memoryAddress, - 0 + null, + memoryAddress, + 0 ); assertEquals(memoryAddress, trainingDataAllocation.getMemoryAddress()); @@ -286,9 +285,9 @@ public void testTrainingDataAllocation_readLock() throws InterruptedException { // updates testLockValue3. We then assert that while we hold the readLock, the value is not updated. After we // release the readLock, the value should be updated. NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation = new NativeMemoryAllocation.TrainingDataAllocation( - null, - 0, - 0 + null, + 0, + 0 ); int initialValue = 10; @@ -319,9 +318,9 @@ public void testTrainingDataAllocation_writeLock() throws InterruptedException { // asserts that testLockValue4 is set to finalValue and then start another thread that updates testLockValue4 // and releases the writeLock. NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation = new NativeMemoryAllocation.TrainingDataAllocation( - null, - 0, - 0 + null, + 0, + 0 ); int initialValue = 10; @@ -354,9 +353,9 @@ public void testTrainingDataAllocation_getSize() { int size = 12; NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation = new NativeMemoryAllocation.TrainingDataAllocation( - null, - 0, - size + null, + 0, + size ); assertEquals(size, trainingDataAllocation.getSizeInKB()); @@ -366,9 +365,9 @@ public void testTrainingDataAllocation_setMemoryAddress() { long pointer = 12; NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation = new NativeMemoryAllocation.TrainingDataAllocation( - null, - pointer, - 0 + null, + pointer, + 0 ); assertEquals(pointer, trainingDataAllocation.getMemoryAddress()); @@ -381,14 +380,11 @@ public void testTrainingDataAllocation_setMemoryAddress() { public void testAnonymousAllocation_close() throws InterruptedException { ExecutorService executorService = Executors.newSingleThreadExecutor(); - NativeMemoryAllocation.AnonymousAllocation anonymousAllocation = new NativeMemoryAllocation.AnonymousAllocation( - executorService, - 0 - ); + NativeMemoryAllocation.AnonymousAllocation anonymousAllocation = new NativeMemoryAllocation.AnonymousAllocation(executorService, 0); anonymousAllocation.close(); - Thread.sleep(1000*2); + Thread.sleep(1000 * 2); anonymousAllocation.writeLock(); assertTrue(anonymousAllocation.isClosed()); anonymousAllocation.writeUnlock(); @@ -398,10 +394,7 @@ public void testAnonymousAllocation_close() throws InterruptedException { public void testAnonymousAllocation_getSize() { int size = 12; - NativeMemoryAllocation.AnonymousAllocation anonymousAllocation = new NativeMemoryAllocation.AnonymousAllocation( - null, - size - ); + NativeMemoryAllocation.AnonymousAllocation anonymousAllocation = new NativeMemoryAllocation.AnonymousAllocation(null, size); assertEquals(size, anonymousAllocation.getSizeInKB()); } diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java index 079e41e56..718df0b1f 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java @@ -37,9 +37,7 @@ public class NativeMemoryCacheManagerTests extends OpenSearchSingleNodeTestCase public void tearDown() throws Exception { // Clear out persistent metadata ClusterUpdateSettingsRequest clusterUpdateSettingsRequest = new ClusterUpdateSettingsRequest(); - Settings circuitBreakerSettings = Settings.builder() - .putNull(KNNSettings.KNN_CIRCUIT_BREAKER_TRIGGERED) - .build(); + Settings circuitBreakerSettings = Settings.builder().putNull(KNNSettings.KNN_CIRCUIT_BREAKER_TRIGGERED).build(); clusterUpdateSettingsRequest.persistentSettings(circuitBreakerSettings); client().admin().cluster().updateSettings(clusterUpdateSettingsRequest).get(); super.tearDown(); @@ -92,13 +90,11 @@ public void testGetCacheSizeAsPercentage() throws ExecutionException { long maxWeight = nativeMemoryCacheManager.getMaxCacheSizeInKilobytes(); int entryWeight = (int) (maxWeight / 3); - TestNativeMemoryEntryContent testNativeMemoryEntryContent = new TestNativeMemoryEntryContent( - "test-1", entryWeight, 0); + TestNativeMemoryEntryContent testNativeMemoryEntryContent = new TestNativeMemoryEntryContent("test-1", entryWeight, 0); nativeMemoryCacheManager.get(testNativeMemoryEntryContent, true); - assertEquals(100 * (float) entryWeight / (float) maxWeight, nativeMemoryCacheManager.getCacheSizeAsPercentage(), - 0.001); + assertEquals(100 * (float) entryWeight / (float) maxWeight, nativeMemoryCacheManager.getCacheSizeAsPercentage(), 0.001); nativeMemoryCacheManager.close(); } @@ -108,8 +104,7 @@ public void testGetIndexSizeInKilobytes() throws ExecutionException, IOException int genericEntryWeight = 100; int indexEntryWeight = 20; - TestNativeMemoryEntryContent testNativeMemoryEntryContent = new TestNativeMemoryEntryContent( - "test-1", genericEntryWeight, 0); + TestNativeMemoryEntryContent testNativeMemoryEntryContent = new TestNativeMemoryEntryContent("test-1", genericEntryWeight, 0); nativeMemoryCacheManager.get(testNativeMemoryEntryContent, true); @@ -117,13 +112,13 @@ public void testGetIndexSizeInKilobytes() throws ExecutionException, IOException String key = "test-key"; NativeMemoryAllocation.IndexAllocation indexAllocation = new NativeMemoryAllocation.IndexAllocation( - null, - 0, - indexEntryWeight, - null, - key, - indexName, - null + null, + 0, + indexEntryWeight, + null, + key, + indexName, + null ); NativeMemoryEntryContext.IndexEntryContext indexEntryContext = mock(NativeMemoryEntryContext.IndexEntryContext.class); @@ -143,8 +138,7 @@ public void testGetIndexSizeAsPercentage() throws ExecutionException, IOExceptio int genericEntryWeight = (int) (maxWeight / 3); int indexEntryWeight = (int) (maxWeight / 3); - TestNativeMemoryEntryContent testNativeMemoryEntryContent = new TestNativeMemoryEntryContent( - "test-1", genericEntryWeight, 0); + TestNativeMemoryEntryContent testNativeMemoryEntryContent = new TestNativeMemoryEntryContent("test-1", genericEntryWeight, 0); nativeMemoryCacheManager.get(testNativeMemoryEntryContent, true); @@ -152,13 +146,13 @@ public void testGetIndexSizeAsPercentage() throws ExecutionException, IOExceptio String key = "test-key"; NativeMemoryAllocation.IndexAllocation indexAllocation = new NativeMemoryAllocation.IndexAllocation( - null, - 0, - indexEntryWeight, - null, - key, - indexName, - null + null, + 0, + indexEntryWeight, + null, + key, + indexName, + null ); NativeMemoryEntryContext.IndexEntryContext indexEntryContext = mock(NativeMemoryEntryContext.IndexEntryContext.class); @@ -167,8 +161,11 @@ public void testGetIndexSizeAsPercentage() throws ExecutionException, IOExceptio nativeMemoryCacheManager.get(indexEntryContext, true); - assertEquals(100 * (float) indexEntryWeight / (float) maxWeight, - nativeMemoryCacheManager.getIndexSizeAsPercentage(indexName), 0.001); + assertEquals( + 100 * (float) indexEntryWeight / (float) maxWeight, + nativeMemoryCacheManager.getIndexSizeAsPercentage(indexName), + 0.001 + ); nativeMemoryCacheManager.close(); } @@ -179,8 +176,7 @@ public void testGetTrainingSize() throws ExecutionException { int genericEntryWeight = (int) (maxWeight / 3); int allocationEntryWeight = (int) (maxWeight / 3); - TestNativeMemoryEntryContent testNativeMemoryEntryContent = new TestNativeMemoryEntryContent( - "test-1", genericEntryWeight, 0); + TestNativeMemoryEntryContent testNativeMemoryEntryContent = new TestNativeMemoryEntryContent("test-1", genericEntryWeight, 0); nativeMemoryCacheManager.get(testNativeMemoryEntryContent, true); @@ -188,21 +184,25 @@ public void testGetTrainingSize() throws ExecutionException { String key = "test-key"; NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation = new NativeMemoryAllocation.TrainingDataAllocation( - null, - 0, - allocationEntryWeight + null, + 0, + allocationEntryWeight ); - NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext = - mock(NativeMemoryEntryContext.TrainingDataEntryContext.class); + NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext = mock( + NativeMemoryEntryContext.TrainingDataEntryContext.class + ); when(trainingDataEntryContext.load()).thenReturn(trainingDataAllocation); when(trainingDataEntryContext.getKey()).thenReturn(key); nativeMemoryCacheManager.get(trainingDataEntryContext, true); assertEquals((float) allocationEntryWeight, nativeMemoryCacheManager.getTrainingSizeInKilobytes(), 0.001); - assertEquals(100 * (float) allocationEntryWeight / (float) maxWeight, - nativeMemoryCacheManager.getTrainingSizeAsPercentage(), 0.001); + assertEquals( + 100 * (float) allocationEntryWeight / (float) maxWeight, + nativeMemoryCacheManager.getTrainingSizeAsPercentage(), + 0.001 + ); nativeMemoryCacheManager.close(); } @@ -213,8 +213,7 @@ public void testGetIndexGraphCount() throws ExecutionException, IOException { int genericEntryWeight = (int) (maxWeight / 3); int indexEntryWeight = (int) (maxWeight / 3); - TestNativeMemoryEntryContent testNativeMemoryEntryContent = new TestNativeMemoryEntryContent( - "test-1", genericEntryWeight, 0); + TestNativeMemoryEntryContent testNativeMemoryEntryContent = new TestNativeMemoryEntryContent("test-1", genericEntryWeight, 0); nativeMemoryCacheManager.get(testNativeMemoryEntryContent, true); @@ -225,13 +224,13 @@ public void testGetIndexGraphCount() throws ExecutionException, IOException { String key3 = "test-key-3"; NativeMemoryAllocation.IndexAllocation indexAllocation1 = new NativeMemoryAllocation.IndexAllocation( - null, - 0, - indexEntryWeight, - null, - key1, - indexName1, - null + null, + 0, + indexEntryWeight, + null, + key1, + indexName1, + null ); NativeMemoryEntryContext.IndexEntryContext indexEntryContext = mock(NativeMemoryEntryContext.IndexEntryContext.class); @@ -241,13 +240,13 @@ public void testGetIndexGraphCount() throws ExecutionException, IOException { nativeMemoryCacheManager.get(indexEntryContext, true); NativeMemoryAllocation.IndexAllocation indexAllocation2 = new NativeMemoryAllocation.IndexAllocation( - null, - 0, - indexEntryWeight, - null, - key2, - indexName1, - null + null, + 0, + indexEntryWeight, + null, + key2, + indexName1, + null ); indexEntryContext = mock(NativeMemoryEntryContext.IndexEntryContext.class); @@ -257,13 +256,13 @@ public void testGetIndexGraphCount() throws ExecutionException, IOException { nativeMemoryCacheManager.get(indexEntryContext, true); NativeMemoryAllocation.IndexAllocation indexAllocation3 = new NativeMemoryAllocation.IndexAllocation( - null, - 0, - indexEntryWeight, - null, - key3, - indexName2, - null + null, + 0, + indexEntryWeight, + null, + key3, + indexName2, + null ); indexEntryContext = mock(NativeMemoryEntryContext.IndexEntryContext.class); @@ -311,8 +310,7 @@ public void testGet_evictable() throws ExecutionException { TestNativeMemoryEntryContent testNativeMemoryEntryContent1 = new TestNativeMemoryEntryContent("test-1", size, pointer); - NativeMemoryAllocation testNativeMemoryAllocation = nativeMemoryCacheManager.get(testNativeMemoryEntryContent1, - true); + NativeMemoryAllocation testNativeMemoryAllocation = nativeMemoryCacheManager.get(testNativeMemoryEntryContent1, true); assertEquals(size, nativeMemoryCacheManager.getCacheSizeInKilobytes()); assertEquals(size, testNativeMemoryAllocation.getSizeInKB()); assertEquals(pointer, testNativeMemoryAllocation.getMemoryAddress()); @@ -325,12 +323,12 @@ public void testGet_unevictable() throws ExecutionException { NativeMemoryCacheManager nativeMemoryCacheManager = new NativeMemoryCacheManager(); int maxWeight = (int) nativeMemoryCacheManager.getMaxCacheSizeInKilobytes(); - TestNativeMemoryEntryContent testNativeMemoryEntryContent1 = new TestNativeMemoryEntryContent("test-1", maxWeight/2); + TestNativeMemoryEntryContent testNativeMemoryEntryContent1 = new TestNativeMemoryEntryContent("test-1", maxWeight / 2); nativeMemoryCacheManager.get(testNativeMemoryEntryContent1, true); // Then, add another entry that would overflow the cache TestNativeMemoryEntryContent testNativeMemoryEntryContent2 = new TestNativeMemoryEntryContent("test-2", maxWeight); - expectThrows(OutOfNativeMemoryException.class, () ->nativeMemoryCacheManager.get(testNativeMemoryEntryContent2, false)); + expectThrows(OutOfNativeMemoryException.class, () -> nativeMemoryCacheManager.get(testNativeMemoryEntryContent2, false)); nativeMemoryCacheManager.close(); } @@ -403,43 +401,43 @@ public void testGetIndicesCacheStats() throws IOException, ExecutionException { int size2 = 5; NativeMemoryAllocation.IndexAllocation indexAllocation1 = new NativeMemoryAllocation.IndexAllocation( - null, - 0, - size1, - null, - testKey1, - indexName1, - null + null, + 0, + size1, + null, + testKey1, + indexName1, + null ); NativeMemoryAllocation.IndexAllocation indexAllocation2 = new NativeMemoryAllocation.IndexAllocation( - null, - 0, - size2, - null, - testKey2, - indexName1, - null + null, + 0, + size2, + null, + testKey2, + indexName1, + null ); NativeMemoryAllocation.IndexAllocation indexAllocation3 = new NativeMemoryAllocation.IndexAllocation( - null, - 0, - size1, - null, - testKey3, - indexName2, - null + null, + 0, + size1, + null, + testKey3, + indexName2, + null ); NativeMemoryAllocation.IndexAllocation indexAllocation4 = new NativeMemoryAllocation.IndexAllocation( - null, - 0, - size2, - null, - testKey4, - indexName2, - null + null, + 0, + size2, + null, + testKey4, + indexName2, + null ); NativeMemoryEntryContext.IndexEntryContext indexEntryContext1 = mock(NativeMemoryEntryContext.IndexEntryContext.class); @@ -468,7 +466,7 @@ public void testGetIndicesCacheStats() throws IOException, ExecutionException { assertEquals(2, indicesStats.get(indexName1).get(GRAPH_COUNT)); assertEquals(2, indicesStats.get(indexName2).get(GRAPH_COUNT)); assertEquals((long) (size1 + size2), indicesStats.get(indexName1).get(GRAPH_MEMORY_USAGE.getName())); - assertEquals((long)size1 + size2, indicesStats.get(indexName2).get(GRAPH_MEMORY_USAGE.getName())); + assertEquals((long) size1 + size2, indicesStats.get(indexName2).get(GRAPH_MEMORY_USAGE.getName())); nativeMemoryCacheManager.close(); } diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryEntryContextTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryEntryContextTests.java index b8fd05d1e..495f20347 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryEntryContextTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryEntryContextTests.java @@ -13,7 +13,6 @@ import com.google.common.collect.ImmutableMap; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.indices.IndicesService; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.IndexUtil; import org.opensearch.knn.index.util.KNNEngine; @@ -43,20 +42,20 @@ public void testAbstract_getKey() { public void testIndexEntryContext_load() throws IOException { NativeMemoryLoadStrategy.IndexLoadStrategy indexLoadStrategy = mock(NativeMemoryLoadStrategy.IndexLoadStrategy.class); NativeMemoryEntryContext.IndexEntryContext indexEntryContext = new NativeMemoryEntryContext.IndexEntryContext( - "test", - indexLoadStrategy, - null, - "test" + "test", + indexLoadStrategy, + null, + "test" ); NativeMemoryAllocation.IndexAllocation indexAllocation = new NativeMemoryAllocation.IndexAllocation( - null, - 0, - 10, - KNNEngine.DEFAULT, - "test-path", - "test-name", - null + null, + 0, + 10, + KNNEngine.DEFAULT, + "test-path", + "test-name", + null ); when(indexLoadStrategy.load(indexEntryContext)).thenReturn(indexAllocation); @@ -67,11 +66,10 @@ public void testIndexEntryContext_load() throws IOException { public void testIndexEntryContext_calculateSize() throws IOException { // Create a file and write random bytes to it Path tmpFile = createTempFile(); - byte[] data = new byte[1024*3]; + byte[] data = new byte[1024 * 3]; Arrays.fill(data, (byte) 'c'); - try (OutputStream out = new BufferedOutputStream( - Files.newOutputStream(tmpFile, CREATE, APPEND))) { + try (OutputStream out = new BufferedOutputStream(Files.newOutputStream(tmpFile, CREATE, APPEND))) { out.write(data, 0, data.length); } catch (IOException x) { fail("Failed to write to file"); @@ -82,10 +80,10 @@ public void testIndexEntryContext_calculateSize() throws IOException { // Check that the indexEntryContext will return the same thing NativeMemoryEntryContext.IndexEntryContext indexEntryContext = new NativeMemoryEntryContext.IndexEntryContext( - tmpFile.toAbsolutePath().toString(), - null, - null, - "test" + tmpFile.toAbsolutePath().toString(), + null, + null, + "test" ); assertEquals(expectedSize, indexEntryContext.calculateSizeInKB().longValue()); @@ -94,10 +92,10 @@ public void testIndexEntryContext_calculateSize() throws IOException { public void testIndexEntryContext_getOpenSearchIndexName() { String openSearchIndexName = "test-index"; NativeMemoryEntryContext.IndexEntryContext indexEntryContext = new NativeMemoryEntryContext.IndexEntryContext( - "test", - null, - null, - openSearchIndexName + "test", + null, + null, + openSearchIndexName ); assertEquals(openSearchIndexName, indexEntryContext.getOpenSearchIndexName()); @@ -106,10 +104,10 @@ public void testIndexEntryContext_getOpenSearchIndexName() { public void testIndexEntryContext_getParameters() { Map parameters = ImmutableMap.of("test-1", 10); NativeMemoryEntryContext.IndexEntryContext indexEntryContext = new NativeMemoryEntryContext.IndexEntryContext( - "test", - null, - parameters, - "test" + "test", + null, + parameters, + "test" ); assertEquals(parameters, indexEntryContext.getParameters()); @@ -118,19 +116,19 @@ public void testIndexEntryContext_getParameters() { public void testTrainingDataEntryContext_load() { NativeMemoryLoadStrategy.TrainingLoadStrategy trainingLoadStrategy = mock(NativeMemoryLoadStrategy.TrainingLoadStrategy.class); NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext = new NativeMemoryEntryContext.TrainingDataEntryContext( - 0, - "test", - "test", - trainingLoadStrategy, - null, - 0, - 0 + 0, + "test", + "test", + trainingLoadStrategy, + null, + 0, + 0 ); NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation = new NativeMemoryAllocation.TrainingDataAllocation( - null, - 0, - 0 + null, + 0, + 0 ); when(trainingLoadStrategy.load(trainingDataEntryContext)).thenReturn(trainingDataAllocation); @@ -141,13 +139,13 @@ public void testTrainingDataEntryContext_load() { public void testTrainingDataEntryContext_getTrainIndexName() { String trainIndexName = "test-index"; NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext = new NativeMemoryEntryContext.TrainingDataEntryContext( - 0, - trainIndexName, - "test", - null, - null, - 0, - 0 + 0, + trainIndexName, + "test", + null, + null, + 0, + 0 ); assertEquals(trainIndexName, trainingDataEntryContext.getTrainIndexName()); @@ -156,13 +154,13 @@ public void testTrainingDataEntryContext_getTrainIndexName() { public void testTrainingDataEntryContext_getTrainFieldName() { String trainFieldName = "test-field"; NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext = new NativeMemoryEntryContext.TrainingDataEntryContext( - 0, - "test", - trainFieldName, - null, - null, - 0, - 0 + 0, + "test", + trainFieldName, + null, + null, + 0, + 0 ); assertEquals(trainFieldName, trainingDataEntryContext.getTrainFieldName()); @@ -171,13 +169,13 @@ public void testTrainingDataEntryContext_getTrainFieldName() { public void testTrainingDataEntryContext_getMaxVectorCount() { int maxVectorCount = 11; NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext = new NativeMemoryEntryContext.TrainingDataEntryContext( - 0, - "test", - "test", - null, - null, - maxVectorCount, - 0 + 0, + "test", + "test", + null, + null, + maxVectorCount, + 0 ); assertEquals(maxVectorCount, trainingDataEntryContext.getMaxVectorCount()); @@ -186,13 +184,13 @@ public void testTrainingDataEntryContext_getMaxVectorCount() { public void testTrainingDataEntryContext_getSearchSize() { int searchSize = 11; NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext = new NativeMemoryEntryContext.TrainingDataEntryContext( - 0, - "test", - "test", - null, - null, - 0, - searchSize + 0, + "test", + "test", + null, + null, + 0, + searchSize ); assertEquals(searchSize, trainingDataEntryContext.getSearchSize()); @@ -201,13 +199,13 @@ public void testTrainingDataEntryContext_getSearchSize() { public void testTrainingDataEntryContext_getIndicesService() { ClusterService clusterService = mock(ClusterService.class); NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext = new NativeMemoryEntryContext.TrainingDataEntryContext( - 0, - "test", - "test", - null, - clusterService, - 0, - 0 + 0, + "test", + "test", + null, + clusterService, + 0, + 0 ); assertEquals(clusterService, trainingDataEntryContext.getClusterService()); diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java index 0d623adff..f870333a4 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java @@ -61,14 +61,15 @@ public void testIndexLoadStrategy_load() throws IOException { NativeMemoryLoadStrategy.IndexLoadStrategy.initialize(resourceWatcherService); NativeMemoryEntryContext.IndexEntryContext indexEntryContext = new NativeMemoryEntryContext.IndexEntryContext( - path, - NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(), - parameters, - "test" + path, + NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(), + parameters, + "test" ); // Load - NativeMemoryAllocation.IndexAllocation indexAllocation = NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance().load(indexEntryContext); + NativeMemoryAllocation.IndexAllocation indexAllocation = NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance() + .load(indexEntryContext); // Confirm that the file was loaded by querying float[] query = new float[dimension]; @@ -83,7 +84,7 @@ public void testTrainingLoadStrategy_load() { // listener onResponse to release the write lock VectorReader vectorReader = mock(VectorReader.class); ArrayList vectors = new ArrayList<>(); - vectors.add(new Float[]{1.0F, 2.0F}); + vectors.add(new Float[] { 1.0F, 2.0F }); logger.info("J0"); doAnswer(invocationOnMock -> { logger.info("J1"); @@ -107,17 +108,19 @@ public void testTrainingLoadStrategy_load() { NativeMemoryLoadStrategy.TrainingLoadStrategy.initialize(vectorReader); NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext = new NativeMemoryEntryContext.TrainingDataEntryContext( - 0, - "test", - "test", - NativeMemoryLoadStrategy.TrainingLoadStrategy.getInstance(), - null, - 0, - 0); + 0, + "test", + "test", + NativeMemoryLoadStrategy.TrainingLoadStrategy.getInstance(), + null, + 0, + 0 + ); // Load the allocation. Initially, the memory address should be 0. However, after the readlock is obtained, // the memory address should not be 0. - NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation = NativeMemoryLoadStrategy.TrainingLoadStrategy.getInstance().load(trainingDataEntryContext); + NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation = NativeMemoryLoadStrategy.TrainingLoadStrategy.getInstance() + .load(trainingDataEntryContext); assertEquals(0, trainingDataAllocation.getMemoryAddress()); trainingDataAllocation.readLock(); assertNotEquals(0, trainingDataAllocation.getMemoryAddress()); diff --git a/src/test/java/org/opensearch/knn/index/util/KNNLibraryTests.java b/src/test/java/org/opensearch/knn/index/util/KNNLibraryTests.java index 3167e1971..4f99c6833 100644 --- a/src/test/java/org/opensearch/knn/index/util/KNNLibraryTests.java +++ b/src/test/java/org/opensearch/knn/index/util/KNNLibraryTests.java @@ -41,8 +41,13 @@ public class KNNLibraryTests extends KNNTestCase { */ public void testNativeLibrary_getLatestBuildVersion() { String latestBuildVersion = "test-build-version"; - TestNativeLibrary testNativeLibrary = new TestNativeLibrary(Collections.emptyMap(), Collections.emptyMap(), - latestBuildVersion, "", ""); + TestNativeLibrary testNativeLibrary = new TestNativeLibrary( + Collections.emptyMap(), + Collections.emptyMap(), + latestBuildVersion, + "", + "" + ); assertEquals(latestBuildVersion, testNativeLibrary.getLatestBuildVersion()); } @@ -51,8 +56,7 @@ public void testNativeLibrary_getLatestBuildVersion() { */ public void testNativeLibrary_getLatestLibVersion() { String latestVersion = "test-lib-version"; - TestNativeLibrary testNativeLibrary = new TestNativeLibrary(Collections.emptyMap(), Collections.emptyMap(), - "", latestVersion, ""); + TestNativeLibrary testNativeLibrary = new TestNativeLibrary(Collections.emptyMap(), Collections.emptyMap(), "", latestVersion, ""); assertEquals(latestVersion, testNativeLibrary.getLatestLibVersion()); } @@ -61,8 +65,7 @@ public void testNativeLibrary_getLatestLibVersion() { */ public void testNativeLibrary_getExtension() { String extension = ".extension"; - TestNativeLibrary testNativeLibrary = new TestNativeLibrary(Collections.emptyMap(), Collections.emptyMap(), - "", "", extension); + TestNativeLibrary testNativeLibrary = new TestNativeLibrary(Collections.emptyMap(), Collections.emptyMap(), "", "", extension); assertEquals(extension, testNativeLibrary.getExtension()); } @@ -71,8 +74,7 @@ public void testNativeLibrary_getExtension() { */ public void testNativeLibrary_getCompoundExtension() { String extension = ".extension"; - TestNativeLibrary testNativeLibrary = new TestNativeLibrary(Collections.emptyMap(), Collections.emptyMap(), - "", "", extension); + TestNativeLibrary testNativeLibrary = new TestNativeLibrary(Collections.emptyMap(), Collections.emptyMap(), "", "", extension); assertEquals(extension + "c", testNativeLibrary.getCompoundExtension()); } @@ -81,19 +83,14 @@ public void testNativeLibrary_getCompoundExtension() { */ public void testNativeLibrary_getMethod() { String methodName1 = "test-method-1"; - KNNMethod knnMethod1 = KNNMethod.Builder.builder(MethodComponent.Builder.builder(methodName1).build()) - .build(); + KNNMethod knnMethod1 = KNNMethod.Builder.builder(MethodComponent.Builder.builder(methodName1).build()).build(); String methodName2 = "test-method-2"; - KNNMethod knnMethod2 = KNNMethod.Builder.builder(MethodComponent.Builder.builder(methodName2).build()) - .build(); + KNNMethod knnMethod2 = KNNMethod.Builder.builder(MethodComponent.Builder.builder(methodName2).build()).build(); - Map knnMethodMap = ImmutableMap.of( - methodName1, knnMethod1, methodName2, knnMethod2 - ); + Map knnMethodMap = ImmutableMap.of(methodName1, knnMethod1, methodName2, knnMethod2); - TestNativeLibrary testNativeLibrary = new TestNativeLibrary(knnMethodMap, Collections.emptyMap(), - "", "", ""); + TestNativeLibrary testNativeLibrary = new TestNativeLibrary(knnMethodMap, Collections.emptyMap(), "", "", ""); assertEquals(knnMethod1, testNativeLibrary.getMethod(methodName1)); assertEquals(knnMethod2, testNativeLibrary.getMethod(methodName2)); expectThrows(IllegalArgumentException.class, () -> testNativeLibrary.getMethod("invalid")); @@ -103,9 +100,8 @@ public void testNativeLibrary_getMethod() { * Test native library scoring override */ public void testNativeLibrary_score() { - Map> translationMap = ImmutableMap.of(SpaceType.L2, s -> s*2); - TestNativeLibrary testNativeLibrary = new TestNativeLibrary(Collections.emptyMap(), translationMap, - "", "", ""); + Map> translationMap = ImmutableMap.of(SpaceType.L2, s -> s * 2); + TestNativeLibrary testNativeLibrary = new TestNativeLibrary(Collections.emptyMap(), translationMap, "", "", ""); // Test override assertEquals(2f, testNativeLibrary.score(1f, SpaceType.L2), 0.0001); @@ -119,24 +115,19 @@ public void testNativeLibrary_score() { public void testNativeLibrary_validateMethod() throws IOException { // Invalid - method not supported String methodName1 = "test-method-1"; - KNNMethod knnMethod1 = KNNMethod.Builder.builder(MethodComponent.Builder.builder(methodName1).build()) - .build(); + KNNMethod knnMethod1 = KNNMethod.Builder.builder(MethodComponent.Builder.builder(methodName1).build()).build(); Map methodMap = ImmutableMap.of(methodName1, knnMethod1); - TestNativeLibrary testNativeLibrary1 = new TestNativeLibrary(methodMap, Collections.emptyMap(), - "", "", ""); + TestNativeLibrary testNativeLibrary1 = new TestNativeLibrary(methodMap, Collections.emptyMap(), "", "", ""); - XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject() - .field(NAME, "invalid") - .endObject(); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject().field(NAME, "invalid").endObject(); Map in = xContentBuilderToMap(xContentBuilder); KNNMethodContext knnMethodContext1 = KNNMethodContext.parse(in); expectThrows(IllegalArgumentException.class, () -> testNativeLibrary1.validateMethod(knnMethodContext1)); // Invalid - method validation String methodName2 = "test-method-2"; - KNNMethod knnMethod2 = new KNNMethod(MethodComponent.Builder.builder(methodName2).build(), - Collections.emptySet()) { + KNNMethod knnMethod2 = new KNNMethod(MethodComponent.Builder.builder(methodName2).build(), Collections.emptySet()) { @Override public ValidationException validate(KNNMethodContext knnMethodContext) { return new ValidationException(); @@ -144,11 +135,8 @@ public ValidationException validate(KNNMethodContext knnMethodContext) { }; methodMap = ImmutableMap.of(methodName2, knnMethod2); - TestNativeLibrary testNativeLibrary2 = new TestNativeLibrary(methodMap, Collections.emptyMap(), - "", "", ""); - xContentBuilder = XContentFactory.jsonBuilder().startObject() - .field(NAME, methodName2) - .endObject(); + TestNativeLibrary testNativeLibrary2 = new TestNativeLibrary(methodMap, Collections.emptyMap(), "", "", ""); + xContentBuilder = XContentFactory.jsonBuilder().startObject().field(NAME, methodName2).endObject(); in = xContentBuilderToMap(xContentBuilder); KNNMethodContext knnMethodContext2 = KNNMethodContext.parse(in); assertNotNull(testNativeLibrary2.validateMethod(knnMethodContext2)); @@ -159,23 +147,34 @@ public void testNativeLibrary_getMethodAsMap() { SpaceType spaceType = SpaceType.DEFAULT; Map generatedMap = ImmutableMap.of("test-key", "test-param"); MethodComponent methodComponent = MethodComponent.Builder.builder(methodName) - .setMapGenerator(((methodComponent1, methodComponentContext) -> generatedMap)) - .build(); + .setMapGenerator(((methodComponent1, methodComponentContext) -> generatedMap)) + .build(); KNNMethod knnMethod = KNNMethod.Builder.builder(methodComponent).build(); - TestNativeLibrary testNativeLibrary = new TestNativeLibrary(ImmutableMap.of(methodName, knnMethod), - Collections.emptyMap(), "", "", ""); + TestNativeLibrary testNativeLibrary = new TestNativeLibrary( + ImmutableMap.of(methodName, knnMethod), + Collections.emptyMap(), + "", + "", + "" + ); // Check that map is expected Map expectedMap = new HashMap<>(generatedMap); expectedMap.put(KNNConstants.SPACE_TYPE, spaceType.getValue()); - KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.DEFAULT, spaceType, - new MethodComponentContext(methodName, Collections.emptyMap())); + KNNMethodContext knnMethodContext = new KNNMethodContext( + KNNEngine.DEFAULT, + spaceType, + new MethodComponentContext(methodName, Collections.emptyMap()) + ); assertEquals(expectedMap, testNativeLibrary.getMethodAsMap(knnMethodContext)); // Check when invalid method is passed in - KNNMethodContext invalidKnnMethodContext = new KNNMethodContext(KNNEngine.DEFAULT, spaceType, - new MethodComponentContext("invalid", Collections.emptyMap())); + KNNMethodContext invalidKnnMethodContext = new KNNMethodContext( + KNNEngine.DEFAULT, + spaceType, + new MethodComponentContext("invalid", Collections.emptyMap()) + ); expectThrows(IllegalArgumentException.class, () -> testNativeLibrary.getMethodAsMap(invalidKnnMethodContext)); } @@ -191,18 +190,19 @@ public void testFaiss_methodAsMapBuilder() throws IOException { String parameter3 = "test-parameter-3"; Integer defaultValue3 = 3; MethodComponent methodComponent = MethodComponent.Builder.builder(methodName) - .addParameter(parameter1, new Parameter.IntegerParameter(parameter1, defaultValue1, value -> value > 0)) - .addParameter(parameter2, new Parameter.IntegerParameter(parameter2, defaultValue2, value -> value > 0)) - .addParameter(parameter3, new Parameter.IntegerParameter(parameter3, defaultValue3, value -> value > 0)) - .build(); - - XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject() - .field(NAME, methodName) - .startObject(PARAMETERS) - .field(parameter1, value1) - .field(parameter2, value2) - .endObject() - .endObject(); + .addParameter(parameter1, new Parameter.IntegerParameter(parameter1, defaultValue1, value -> value > 0)) + .addParameter(parameter2, new Parameter.IntegerParameter(parameter2, defaultValue2, value -> value > 0)) + .addParameter(parameter3, new Parameter.IntegerParameter(parameter3, defaultValue3, value -> value > 0)) + .build(); + + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, methodName) + .startObject(PARAMETERS) + .field(parameter1, value1) + .field(parameter2, value2) + .endObject() + .endObject(); Map in = xContentBuilderToMap(xContentBuilder); MethodComponentContext methodComponentContext = MethodComponentContext.parse(in); @@ -214,10 +214,9 @@ public void testFaiss_methodAsMapBuilder() throws IOException { expectedMap.put(NAME, methodName); expectedMap.put(INDEX_DESCRIPTION_PARAMETER, methodDescription + value1); - Map methodAsMap = MethodAsMapBuilder - .builder(methodDescription, methodComponent, methodComponentContext) - .addParameter(parameter1, "", "") - .build(); + Map methodAsMap = MethodAsMapBuilder.builder(methodDescription, methodComponent, methodComponentContext) + .addParameter(parameter1, "", "") + .build(); assertEquals(expectedMap, methodAsMap); } @@ -232,9 +231,13 @@ static class TestNativeLibrary extends KNNLibrary.NativeLibrary { * @param latestLibraryVersion String representation of latest version of the library * @param extension String representing the extension that library files should use */ - public TestNativeLibrary(Map methods, - Map> scoreTranslation, - String latestLibraryBuildVersion, String latestLibraryVersion, String extension) { + public TestNativeLibrary( + Map methods, + Map> scoreTranslation, + String latestLibraryBuildVersion, + String latestLibraryVersion, + String extension + ) { super(methods, scoreTranslation, latestLibraryBuildVersion, latestLibraryVersion, extension); } } diff --git a/src/test/java/org/opensearch/knn/indices/ModelCacheTests.java b/src/test/java/org/opensearch/knn/indices/ModelCacheTests.java index d4648822b..fb810d969 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelCacheTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelCacheTests.java @@ -34,16 +34,26 @@ public class ModelCacheTests extends KNNTestCase { public void testGet_normal() throws ExecutionException, InterruptedException { String modelId = "test-model-id"; int dimension = 2; - Model mockModel = new Model(new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, dimension, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""), "hello".getBytes(), modelId); + Model mockModel = new Model( + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + dimension, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ), + "hello".getBytes(), + modelId + ); String cacheSize = "10%"; ModelDao modelDao = mock(ModelDao.class); when(modelDao.get(modelId)).thenReturn(mockModel); Settings settings = Settings.builder().put(MODEL_CACHE_SIZE_LIMIT_SETTING.getKey(), cacheSize).build(); - ClusterSettings clusterSettings = new ClusterSettings(settings, - ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING)); + ClusterSettings clusterSettings = new ClusterSettings(settings, ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING)); ClusterService clusterService = mock(ClusterService.class); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); when(clusterService.getSettings()).thenReturn(settings); @@ -59,16 +69,25 @@ public void testGet_modelDoesNotFitInCache() throws ExecutionException, Interrup int dimension = 2; String cacheSize = "1kb"; - Model mockModel = new Model(new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, dimension, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""), - new byte[BYTES_PER_KILOBYTES + 1], modelId); + Model mockModel = new Model( + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + dimension, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ), + new byte[BYTES_PER_KILOBYTES + 1], + modelId + ); ModelDao modelDao = mock(ModelDao.class); when(modelDao.get(modelId)).thenReturn(mockModel); Settings settings = Settings.builder().put(MODEL_CACHE_SIZE_LIMIT_SETTING.getKey(), cacheSize).build(); - ClusterSettings clusterSettings = new ClusterSettings(settings, - ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING)); + ClusterSettings clusterSettings = new ClusterSettings(settings, ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING)); ClusterService clusterService = mock(ClusterService.class); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); when(clusterService.getSettings()).thenReturn(settings); @@ -88,8 +107,7 @@ public void testGet_modelDoesNotExist() throws ExecutionException, InterruptedEx when(modelDao.get(modelId)).thenThrow(new IllegalArgumentException()); Settings settings = Settings.builder().put(MODEL_CACHE_SIZE_LIMIT_SETTING.getKey(), cacheSize).build(); - ClusterSettings clusterSettings = new ClusterSettings(settings, - ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING)); + ClusterSettings clusterSettings = new ClusterSettings(settings, ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING)); ClusterService clusterService = mock(ClusterService.class); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); when(clusterService.getSettings()).thenReturn(settings); @@ -107,20 +125,40 @@ public void testGetTotalWeight() throws ExecutionException, InterruptedException String cacheSize = "10%"; int size1 = BYTES_PER_KILOBYTES; - Model mockModel1 = new Model(new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, dimension, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""), new byte[size1], modelId1); + Model mockModel1 = new Model( + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + dimension, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ), + new byte[size1], + modelId1 + ); int size2 = BYTES_PER_KILOBYTES * 3; - Model mockModel2 = new Model(new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, dimension, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""), new byte[size2], modelId2); + Model mockModel2 = new Model( + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + dimension, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ), + new byte[size2], + modelId2 + ); ModelDao modelDao = mock(ModelDao.class); when(modelDao.get(modelId1)).thenReturn(mockModel1); when(modelDao.get(modelId2)).thenReturn(mockModel2); - Settings settings = Settings.builder().put(MODEL_CACHE_SIZE_LIMIT_SETTING.getKey(), cacheSize).build(); - ClusterSettings clusterSettings = new ClusterSettings(settings, - ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING)); + ClusterSettings clusterSettings = new ClusterSettings(settings, ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING)); ClusterService clusterService = mock(ClusterService.class); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); when(clusterService.getSettings()).thenReturn(settings); @@ -143,19 +181,40 @@ public void testRemove_normal() throws ExecutionException, InterruptedException String cacheSize = "10%"; int size1 = BYTES_PER_KILOBYTES; - Model mockModel1 = new Model(new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, dimension, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""), new byte[size1], modelId1); + Model mockModel1 = new Model( + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + dimension, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ), + new byte[size1], + modelId1 + ); int size2 = BYTES_PER_KILOBYTES * 3; - Model mockModel2 = new Model(new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, dimension, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""), new byte[size2], modelId2); + Model mockModel2 = new Model( + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + dimension, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ), + new byte[size2], + modelId2 + ); ModelDao modelDao = mock(ModelDao.class); when(modelDao.get(modelId1)).thenReturn(mockModel1); when(modelDao.get(modelId2)).thenReturn(mockModel2); Settings settings = Settings.builder().put(MODEL_CACHE_SIZE_LIMIT_SETTING.getKey(), cacheSize).build(); - ClusterSettings clusterSettings = new ClusterSettings(settings, - ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING)); + ClusterSettings clusterSettings = new ClusterSettings(settings, ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING)); ClusterService clusterService = mock(ClusterService.class); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); when(clusterService.getSettings()).thenReturn(settings); @@ -172,26 +231,36 @@ public void testRemove_normal() throws ExecutionException, InterruptedException modelCache.remove(modelId1); - assertEquals( (size2 / BYTES_PER_KILOBYTES) + 1, modelCache.getTotalWeightInKB()); + assertEquals((size2 / BYTES_PER_KILOBYTES) + 1, modelCache.getTotalWeightInKB()); modelCache.remove(modelId2); - assertEquals( 0, modelCache.getTotalWeightInKB()); + assertEquals(0, modelCache.getTotalWeightInKB()); } public void testRebuild_normal() throws ExecutionException, InterruptedException { String modelId = "test-model-id"; int dimension = 2; String cacheSize = "10%"; - Model mockModel = new Model(new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, dimension, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""), "hello".getBytes(), modelId); + Model mockModel = new Model( + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + dimension, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ), + "hello".getBytes(), + modelId + ); ModelDao modelDao = mock(ModelDao.class); when(modelDao.get(modelId)).thenReturn(mockModel); Settings settings = Settings.builder().put(MODEL_CACHE_SIZE_LIMIT_SETTING.getKey(), cacheSize).build(); - ClusterSettings clusterSettings = new ClusterSettings(settings, - ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING)); + ClusterSettings clusterSettings = new ClusterSettings(settings, ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING)); ClusterService clusterService = mock(ClusterService.class); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); when(clusterService.getSettings()).thenReturn(settings); @@ -217,8 +286,19 @@ public void testRebuild_afterSettingUpdate() throws ExecutionException, Interrup int dimension = 2; int modelSize = 2 * BYTES_PER_KILOBYTES; - Model mockModel = new Model(new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, dimension, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""), new byte[modelSize], modelId); + Model mockModel = new Model( + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + dimension, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ), + new byte[modelSize], + modelId + ); String cacheSize1 = "1kb"; String cacheSize2 = "4kb"; @@ -227,8 +307,7 @@ public void testRebuild_afterSettingUpdate() throws ExecutionException, Interrup when(modelDao.get(modelId)).thenReturn(mockModel); Settings settings = Settings.builder().put(MODEL_CACHE_SIZE_LIMIT_SETTING.getKey(), cacheSize1).build(); - ClusterSettings clusterSettings = new ClusterSettings(settings, - ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING)); + ClusterSettings clusterSettings = new ClusterSettings(settings, ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING)); ClusterService clusterService = mock(ClusterService.class); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); when(clusterService.getSettings()).thenReturn(settings); @@ -257,8 +336,7 @@ public void testRemove_modelNotInCache() { ModelDao modelDao = mock(ModelDao.class); Settings settings = Settings.builder().put(MODEL_CACHE_SIZE_LIMIT_SETTING.getKey(), cacheSize).build(); - ClusterSettings clusterSettings = new ClusterSettings(settings, - ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING)); + ClusterSettings clusterSettings = new ClusterSettings(settings, ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING)); ClusterService clusterService = mock(ClusterService.class); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); when(clusterService.getSettings()).thenReturn(settings); @@ -266,17 +344,28 @@ public void testRemove_modelNotInCache() { ModelCache.initialize(modelDao, clusterService); ModelCache modelCache = new ModelCache(); - assertEquals( 0, modelCache.getTotalWeightInKB()); + assertEquals(0, modelCache.getTotalWeightInKB()); modelCache.remove(modelId1); - assertEquals( 0, modelCache.getTotalWeightInKB()); + assertEquals(0, modelCache.getTotalWeightInKB()); } public void testContains() throws ExecutionException, InterruptedException { String modelId1 = "test-model-id-1"; int dimension = 2; int modelSize1 = 100; - Model mockModel1 = new Model(new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, dimension, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""), new byte[modelSize1], modelId1); + Model mockModel1 = new Model( + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + dimension, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ), + new byte[modelSize1], + modelId1 + ); String modelId2 = "test-model-id-2"; @@ -286,8 +375,7 @@ public void testContains() throws ExecutionException, InterruptedException { when(modelDao.get(modelId1)).thenReturn(mockModel1); Settings settings = Settings.builder().put(MODEL_CACHE_SIZE_LIMIT_SETTING.getKey(), cacheSize).build(); - ClusterSettings clusterSettings = new ClusterSettings(settings, - ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING)); + ClusterSettings clusterSettings = new ClusterSettings(settings, ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING)); ClusterService clusterService = mock(ClusterService.class); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); when(clusterService.getSettings()).thenReturn(settings); @@ -305,13 +393,35 @@ public void testRemoveAll() throws ExecutionException, InterruptedException { int dimension = 2; String modelId1 = "test-model-id-1"; int modelSize1 = BYTES_PER_KILOBYTES; - Model mockModel1 = new Model(new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, dimension, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""), new byte[modelSize1], modelId1); + Model mockModel1 = new Model( + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + dimension, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ), + new byte[modelSize1], + modelId1 + ); String modelId2 = "test-model-id-2"; - int modelSize2 = BYTES_PER_KILOBYTES*2; - Model mockModel2 = new Model(new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, dimension, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""), new byte[modelSize2], modelId2); + int modelSize2 = BYTES_PER_KILOBYTES * 2; + Model mockModel2 = new Model( + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + dimension, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ), + new byte[modelSize2], + modelId2 + ); String cacheSize = "10%"; @@ -320,8 +430,7 @@ public void testRemoveAll() throws ExecutionException, InterruptedException { when(modelDao.get(modelId2)).thenReturn(mockModel2); Settings settings = Settings.builder().put(MODEL_CACHE_SIZE_LIMIT_SETTING.getKey(), cacheSize).build(); - ClusterSettings clusterSettings = new ClusterSettings(settings, - ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING)); + ClusterSettings clusterSettings = new ClusterSettings(settings, ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING)); ClusterService clusterService = mock(ClusterService.class); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); when(clusterService.getSettings()).thenReturn(settings); @@ -332,9 +441,9 @@ public void testRemoveAll() throws ExecutionException, InterruptedException { modelCache.get(modelId1); modelCache.get(modelId2); - assertEquals( ((modelSize1 + modelSize2) / BYTES_PER_KILOBYTES) + 2, modelCache.getTotalWeightInKB()); + assertEquals(((modelSize1 + modelSize2) / BYTES_PER_KILOBYTES) + 2, modelCache.getTotalWeightInKB()); modelCache.removeAll(); - assertEquals( 0, modelCache.getTotalWeightInKB()); + assertEquals(0, modelCache.getTotalWeightInKB()); } public void testModelCacheEvictionDueToSize() throws ExecutionException, InterruptedException { @@ -342,17 +451,27 @@ public void testModelCacheEvictionDueToSize() throws ExecutionException, Interru int dimension = 2; int maxDocuments = 10; ModelDao modelDao = mock(ModelDao.class); - for(int i =0; i < maxDocuments; i++){ - String modelId = String.format(modelIdPattern,i); - Model mockModel = new Model(new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, dimension, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""), new byte[BYTES_PER_KILOBYTES*2], modelId); + for (int i = 0; i < maxDocuments; i++) { + String modelId = String.format(modelIdPattern, i); + Model mockModel = new Model( + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + dimension, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ), + new byte[BYTES_PER_KILOBYTES * 2], + modelId + ); when(modelDao.get(modelId)).thenReturn(mockModel); } String cacheSize = "10kb"; Settings settings = Settings.builder().put(MODEL_CACHE_SIZE_LIMIT_SETTING.getKey(), cacheSize).build(); - ClusterSettings clusterSettings = new ClusterSettings(settings, - ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING)); + ClusterSettings clusterSettings = new ClusterSettings(settings, ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING)); ClusterService clusterService = mock(ClusterService.class); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); when(clusterService.getSettings()).thenReturn(settings); @@ -360,9 +479,9 @@ public void testModelCacheEvictionDueToSize() throws ExecutionException, Interru ModelCache.initialize(modelDao, clusterService); ModelCache modelCache = new ModelCache(); assertNull(modelCache.getEvictedDueToSizeAt()); - for(int i =0; i < maxDocuments; i++){ - modelCache.get(String.format(modelIdPattern,i)); + for (int i = 0; i < maxDocuments; i++) { + modelCache.get(String.format(modelIdPattern, i)); } - assertNotNull(modelCache.getEvictedDueToSizeAt()); + assertNotNull(modelCache.getEvictedDueToSizeAt()); } } diff --git a/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java index 63f70d3b4..590df80c8 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java @@ -110,15 +110,37 @@ public void testModelIndexHealth() throws InterruptedException, ExecutionExcepti byte[] modelBlob = "hello".getBytes(); int dimension = 2; - Model model = new Model(new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, dimension, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""), modelBlob, modelId); + Model model = new Model( + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + dimension, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ), + modelBlob, + modelId + ); addDoc(model); assertEquals(model, modelDao.get(modelId)); assertNotNull(modelDao.getHealthStatus()); modelId = "failed-2"; - model = new Model(new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, dimension, ModelState.FAILED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""), modelBlob, modelId); + model = new Model( + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + dimension, + ModelState.FAILED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ), + modelBlob, + modelId + ); addDoc(model); assertEquals(model, modelDao.get(modelId)); assertNotNull(modelDao.getHealthStatus()); @@ -129,11 +151,22 @@ public void testPut_withId() throws InterruptedException, IOException { ModelDao modelDao = ModelDao.OpenSearchKNNModelDao.getInstance(); String modelId = "efbsdhcvbsd"; // User provided model id - byte [] modelBlob = "hello".getBytes(); + byte[] modelBlob = "hello".getBytes(); int dimension = 2; - Model model = new Model(new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, dimension, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""), modelBlob, modelId); + Model model = new Model( + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + dimension, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ), + modelBlob, + modelId + ); // Listener to confirm that everything was updated as expected final CountDownLatch inProgressLatch1 = new CountDownLatch(1); @@ -159,13 +192,14 @@ public void testPut_withId() throws InterruptedException, IOException { // User provided model id that already exists final CountDownLatch inProgressLatch2 = new CountDownLatch(1); ActionListener docCreationListenerDuplicateId = ActionListener.wrap( - response -> fail("Model already exists, but creation was successful"), - exception -> { - if (!(ExceptionsHelper.unwrapCause(exception) instanceof VersionConflictEngineException)) { - fail("Unable to put the model: " + exception); - } - inProgressLatch2.countDown(); - }); + response -> fail("Model already exists, but creation was successful"), + exception -> { + if (!(ExceptionsHelper.unwrapCause(exception) instanceof VersionConflictEngineException)) { + fail("Unable to put the model: " + exception); + } + inProgressLatch2.countDown(); + } + ); modelDao.put(model, docCreationListenerDuplicateId); assertTrue(inProgressLatch2.await(100, TimeUnit.SECONDS)); @@ -176,11 +210,22 @@ public void testPut_withoutModel() throws InterruptedException, IOException { ModelDao modelDao = ModelDao.OpenSearchKNNModelDao.getInstance(); String modelId = "efbsdhcvbsd"; // User provided model id - byte [] modelBlob = null; + byte[] modelBlob = null; int dimension = 2; - Model model = new Model(new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, dimension, ModelState.TRAINING, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""), modelBlob, modelId); + Model model = new Model( + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + dimension, + ModelState.TRAINING, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ), + modelBlob, + modelId + ); // Listener to confirm that everything was updated as expected final CountDownLatch inProgressLatch1 = new CountDownLatch(1); @@ -207,13 +252,14 @@ public void testPut_withoutModel() throws InterruptedException, IOException { // User provided model id that already exists final CountDownLatch inProgressLatch2 = new CountDownLatch(1); ActionListener docCreationListenerDuplicateId = ActionListener.wrap( - response -> fail("Model already exists, but creation was successful"), - exception -> { - if (!(ExceptionsHelper.unwrapCause(exception) instanceof VersionConflictEngineException)) { - fail("Unable to put the model: " + exception); - } - inProgressLatch2.countDown(); - }); + response -> fail("Model already exists, but creation was successful"), + exception -> { + if (!(ExceptionsHelper.unwrapCause(exception) instanceof VersionConflictEngineException)) { + fail("Unable to put the model: " + exception); + } + inProgressLatch2.countDown(); + } + ); modelDao.put(model, docCreationListenerDuplicateId); assertTrue(inProgressLatch2.await(100, TimeUnit.SECONDS)); @@ -221,19 +267,37 @@ public void testPut_withoutModel() throws InterruptedException, IOException { public void testPut_invalid_badState() { ModelDao modelDao = ModelDao.OpenSearchKNNModelDao.getInstance(); - byte [] modelBlob = null; + byte[] modelBlob = null; int dimension = 2; createIndex(MODEL_INDEX_NAME); // Model is in invalid state - Model model = new Model(new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, dimension, ModelState.TRAINING, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""), modelBlob, "any-id"); + Model model = new Model( + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + dimension, + ModelState.TRAINING, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ), + modelBlob, + "any-id" + ); model.getModelMetadata().setState(ModelState.CREATED); - expectThrows(IllegalArgumentException.class, () -> modelDao.put(model, ActionListener.wrap( - acknowledgedResponse -> fail("Should not get called."), - exception -> fail("Should not get to this call.")))); + expectThrows( + IllegalArgumentException.class, + () -> modelDao.put( + model, + ActionListener.wrap( + acknowledgedResponse -> fail("Should not get called."), + exception -> fail("Should not get to this call.") + ) + ) + ); } public void testUpdate() throws IOException, InterruptedException { @@ -241,11 +305,22 @@ public void testUpdate() throws IOException, InterruptedException { ModelDao modelDao = ModelDao.OpenSearchKNNModelDao.getInstance(); String modelId = "efbsdhcvbsd"; // User provided model id - byte [] modelBlob = "hello".getBytes(); + byte[] modelBlob = "hello".getBytes(); int dimension = 2; - Model model = new Model(new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, dimension, ModelState.TRAINING, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""), null, modelId); + Model model = new Model( + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + dimension, + ModelState.TRAINING, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ), + null, + modelId + ); // Listener to confirm that everything was updated as expected final CountDownLatch inProgressLatch1 = new CountDownLatch(1); @@ -270,8 +345,19 @@ public void testUpdate() throws IOException, InterruptedException { assertTrue(inProgressLatch1.await(100, TimeUnit.SECONDS)); // User provided model id that already exists - should be able to update - Model updatedModel = new Model(new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, dimension, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""), modelBlob, modelId); + Model updatedModel = new Model( + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + dimension, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ), + modelBlob, + modelId + ); final CountDownLatch inProgressLatch2 = new CountDownLatch(1); ActionListener updateListener = ActionListener.wrap(response -> { @@ -308,14 +394,36 @@ public void testGet() throws IOException, InterruptedException, ExecutionExcepti expectThrows(Exception.class, () -> modelDao.get(modelId)); // model id exists - Model model = new Model(new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, dimension, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""), modelBlob, modelId); + Model model = new Model( + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + dimension, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ), + modelBlob, + modelId + ); addDoc(model); assertEquals(model, modelDao.get(modelId)); // Get model during training - model = new Model(new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, dimension, ModelState.TRAINING, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""), null, modelId); + model = new Model( + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + dimension, + ModelState.TRAINING, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ), + null, + modelId + ); addDoc(model); assertEquals(model, modelDao.get(modelId)); } @@ -334,13 +442,20 @@ public void testGetMetadata() throws IOException, InterruptedException { assertNull(modelDao.getMetadata(modelId)); // Model exists - byte [] modelBlob = "hello".getBytes(); + byte[] modelBlob = "hello".getBytes(); KNNEngine knnEngine = KNNEngine.FAISS; SpaceType spaceType = SpaceType.INNER_PRODUCT; int dimension = 2; - ModelMetadata modelMetadata = new ModelMetadata(knnEngine, spaceType, dimension, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""); + ModelMetadata modelMetadata = new ModelMetadata( + knnEngine, + spaceType, + dimension, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ); Model model = new Model(modelMetadata, modelBlob, modelId); @@ -396,8 +511,19 @@ public void testDelete() throws IOException, InterruptedException { }, exception -> fail("Unable to delete model: " + exception)); // model id exists - Model model = new Model(new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, dimension, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""), modelBlob, modelId); + Model model = new Model( + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + dimension, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ), + modelBlob, + modelId + ); ActionListener docCreationListener = ActionListener.wrap(response -> { assertEquals(modelId, response.getId()); @@ -413,15 +539,16 @@ public void testDelete() throws IOException, InterruptedException { public void addDoc(Model model) throws IOException, ExecutionException, InterruptedException { ModelMetadata modelMetadata = model.getModelMetadata(); - XContentBuilder builder = XContentFactory.jsonBuilder().startObject() - .field(MODEL_ID, model.getModelID()) - .field(KNN_ENGINE, modelMetadata.getKnnEngine().getName()) - .field(METHOD_PARAMETER_SPACE_TYPE, modelMetadata.getSpaceType().getValue()) - .field(DIMENSION, modelMetadata.getDimension()) - .field(MODEL_STATE, modelMetadata.getState().getName()) - .field(MODEL_TIMESTAMP, modelMetadata.getTimestamp().toString()) - .field(MODEL_DESCRIPTION, modelMetadata.getDescription()) - .field(MODEL_ERROR, modelMetadata.getError()); + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .field(MODEL_ID, model.getModelID()) + .field(KNN_ENGINE, modelMetadata.getKnnEngine().getName()) + .field(METHOD_PARAMETER_SPACE_TYPE, modelMetadata.getSpaceType().getValue()) + .field(DIMENSION, modelMetadata.getDimension()) + .field(MODEL_STATE, modelMetadata.getState().getName()) + .field(MODEL_TIMESTAMP, modelMetadata.getTimestamp().toString()) + .field(MODEL_DESCRIPTION, modelMetadata.getDescription()) + .field(MODEL_ERROR, modelMetadata.getError()); if (model.getModelBlob() != null) { builder.field(MODEL_BLOB_PARAMETER, Base64.getEncoder().encodeToString(model.getModelBlob())); @@ -429,11 +556,10 @@ public void addDoc(Model model) throws IOException, ExecutionException, Interrup builder.endObject(); - IndexRequest indexRequest = new IndexRequest() - .index(MODEL_INDEX_NAME) - .id(model.getModelID()) - .source(builder) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + IndexRequest indexRequest = new IndexRequest().index(MODEL_INDEX_NAME) + .id(model.getModelID()) + .source(builder) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); IndexResponse response = client().index(indexRequest).get(); assertTrue(response.status() == RestStatus.CREATED || response.status() == RestStatus.OK); diff --git a/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java b/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java index 2d2b6757f..a2e5c6bbe 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java @@ -31,8 +31,15 @@ public void testStreams() throws IOException { SpaceType spaceType = SpaceType.L2; int dimension = 128; - ModelMetadata modelMetadata = new ModelMetadata(knnEngine, spaceType, dimension, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""); + ModelMetadata modelMetadata = new ModelMetadata( + knnEngine, + spaceType, + dimension, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ); BytesStreamOutput streamOutput = new BytesStreamOutput(); modelMetadata.writeTo(streamOutput); @@ -44,64 +51,112 @@ public void testStreams() throws IOException { public void testGetKnnEngine() { KNNEngine knnEngine = KNNEngine.DEFAULT; - ModelMetadata modelMetadata = new ModelMetadata(knnEngine, SpaceType.L2, 128, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""); + ModelMetadata modelMetadata = new ModelMetadata( + knnEngine, + SpaceType.L2, + 128, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ); assertEquals(knnEngine, modelMetadata.getKnnEngine()); } public void testGetSpaceType() { SpaceType spaceType = SpaceType.L2; - ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.DEFAULT, spaceType, 128, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""); + ModelMetadata modelMetadata = new ModelMetadata( + KNNEngine.DEFAULT, + spaceType, + 128, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ); assertEquals(spaceType, modelMetadata.getSpaceType()); } public void testGetDimension() { int dimension = 128; - ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, dimension, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""); + ModelMetadata modelMetadata = new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.L2, + dimension, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ); assertEquals(dimension, modelMetadata.getDimension()); } public void testGetState() { ModelState modelState = ModelState.FAILED; - ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 12, modelState, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""); + ModelMetadata modelMetadata = new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.L2, + 12, + modelState, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ); assertEquals(modelState, modelMetadata.getState()); } public void testGetTimestamp() { String timeValue = ZonedDateTime.now(ZoneOffset.UTC).toString(); - ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 12, ModelState.CREATED, - timeValue, "", ""); + ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 12, ModelState.CREATED, timeValue, "", ""); assertEquals(timeValue, modelMetadata.getTimestamp()); } public void testDescription() { String description = "test description"; - ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 12, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), description, ""); + ModelMetadata modelMetadata = new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.L2, + 12, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + description, + "" + ); assertEquals(description, modelMetadata.getDescription()); } public void testGetError() { String error = "test error"; - ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 12, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", error); + ModelMetadata modelMetadata = new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.L2, + 12, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + error + ); assertEquals(error, modelMetadata.getError()); } public void testSetState() { ModelState modelState = ModelState.FAILED; - ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 12, modelState, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""); + ModelMetadata modelMetadata = new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.L2, + 12, + modelState, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ); assertEquals(modelState, modelMetadata.getState()); @@ -112,8 +167,15 @@ public void testSetState() { public void testSetError() { String error = ""; - ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 12, ModelState.TRAINING, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", error); + ModelMetadata modelMetadata = new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.L2, + 12, + ModelState.TRAINING, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + error + ); assertEquals(error, modelMetadata.getError()); @@ -131,16 +193,21 @@ public void testToString() { String description = "test-description"; String error = "test-error"; - String expected = knnEngine.getName() + "," + - spaceType.getValue() + "," + - dimension + "," + - modelState.getName() + "," + - timestamp + "," + - description + "," + - error; - - ModelMetadata modelMetadata = new ModelMetadata(knnEngine, spaceType, dimension, modelState, - timestamp, description, error); + String expected = knnEngine.getName() + + "," + + spaceType.getValue() + + "," + + dimension + + "," + + modelState.getName() + + "," + + timestamp + + "," + + description + + "," + + error; + + ModelMetadata modelMetadata = new ModelMetadata(knnEngine, spaceType, dimension, modelState, timestamp, description, error); assertEquals(expected, modelMetadata.toString()); } @@ -148,28 +215,26 @@ public void testToString() { public void testEquals() { String time1 = ZonedDateTime.now(ZoneOffset.UTC).toString(); - String time2 = ZonedDateTime.of(2021, 9, 30,12, 20, 45, 1, - ZoneId.systemDefault()).toString(); - - ModelMetadata modelMetadata1 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, - time1, "", ""); - ModelMetadata modelMetadata2 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, - time1, "", ""); - - ModelMetadata modelMetadata3 = new ModelMetadata(KNNEngine.NMSLIB, SpaceType.L2, 128, ModelState.CREATED, - time1, "", ""); - ModelMetadata modelMetadata4 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L1, 128, ModelState.CREATED, - time1, "", ""); - ModelMetadata modelMetadata5 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 129, ModelState.CREATED, - time1, "", ""); - ModelMetadata modelMetadata6 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.TRAINING, - time1, "", ""); - ModelMetadata modelMetadata7 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, - time2, "", ""); - ModelMetadata modelMetadata8 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, - time1, "diff descript", ""); - ModelMetadata modelMetadata9 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, - time1, "", "diff error"); + String time2 = ZonedDateTime.of(2021, 9, 30, 12, 20, 45, 1, ZoneId.systemDefault()).toString(); + + ModelMetadata modelMetadata1 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, time1, "", ""); + ModelMetadata modelMetadata2 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, time1, "", ""); + + ModelMetadata modelMetadata3 = new ModelMetadata(KNNEngine.NMSLIB, SpaceType.L2, 128, ModelState.CREATED, time1, "", ""); + ModelMetadata modelMetadata4 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L1, 128, ModelState.CREATED, time1, "", ""); + ModelMetadata modelMetadata5 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 129, ModelState.CREATED, time1, "", ""); + ModelMetadata modelMetadata6 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.TRAINING, time1, "", ""); + ModelMetadata modelMetadata7 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, time2, "", ""); + ModelMetadata modelMetadata8 = new ModelMetadata( + KNNEngine.FAISS, + SpaceType.L2, + 128, + ModelState.CREATED, + time1, + "diff descript", + "" + ); + ModelMetadata modelMetadata9 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, time1, "", "diff error"); assertEquals(modelMetadata1, modelMetadata1); assertEquals(modelMetadata1, modelMetadata2); @@ -187,28 +252,26 @@ public void testEquals() { public void testHashCode() { String time1 = ZonedDateTime.now(ZoneOffset.UTC).toString(); - String time2 = ZonedDateTime.of(2021, 9, 30,12, 20, 45, 1, - ZoneId.systemDefault()).toString(); - - ModelMetadata modelMetadata1 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, - time1, "", ""); - ModelMetadata modelMetadata2 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, - time1, "", ""); - - ModelMetadata modelMetadata3 = new ModelMetadata(KNNEngine.NMSLIB, SpaceType.L2, 128, ModelState.CREATED, - time1, "", ""); - ModelMetadata modelMetadata4 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L1, 128, ModelState.CREATED, - time1, "", ""); - ModelMetadata modelMetadata5 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 129, ModelState.CREATED, - time1, "", ""); - ModelMetadata modelMetadata6 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.TRAINING, - time1, "", ""); - ModelMetadata modelMetadata7 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, - time2, "", ""); - ModelMetadata modelMetadata8 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, - time1, "diff descript", ""); - ModelMetadata modelMetadata9 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, - time1, "", "diff error"); + String time2 = ZonedDateTime.of(2021, 9, 30, 12, 20, 45, 1, ZoneId.systemDefault()).toString(); + + ModelMetadata modelMetadata1 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, time1, "", ""); + ModelMetadata modelMetadata2 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, time1, "", ""); + + ModelMetadata modelMetadata3 = new ModelMetadata(KNNEngine.NMSLIB, SpaceType.L2, 128, ModelState.CREATED, time1, "", ""); + ModelMetadata modelMetadata4 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L1, 128, ModelState.CREATED, time1, "", ""); + ModelMetadata modelMetadata5 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 129, ModelState.CREATED, time1, "", ""); + ModelMetadata modelMetadata6 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.TRAINING, time1, "", ""); + ModelMetadata modelMetadata7 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, time2, "", ""); + ModelMetadata modelMetadata8 = new ModelMetadata( + KNNEngine.FAISS, + SpaceType.L2, + 128, + ModelState.CREATED, + time1, + "diff descript", + "" + ); + ModelMetadata modelMetadata9 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, time1, "", "diff error"); assertEquals(modelMetadata1.hashCode(), modelMetadata1.hashCode()); assertEquals(modelMetadata1.hashCode(), modelMetadata2.hashCode()); @@ -232,17 +295,21 @@ public void testFromString() { String description = "test-description"; String error = "test-error"; - String stringRep1 = knnEngine.getName() + "," + - spaceType.getValue() + "," + - dimension + "," + - modelState.getName() + "," + - timestamp + "," + - description + "," + - error; - - - ModelMetadata expected = new ModelMetadata(knnEngine, spaceType, dimension, modelState, - timestamp, description, error); + String stringRep1 = knnEngine.getName() + + "," + + spaceType.getValue() + + "," + + dimension + + "," + + modelState.getName() + + "," + + timestamp + + "," + + description + + "," + + error; + + ModelMetadata expected = new ModelMetadata(knnEngine, spaceType, dimension, modelState, timestamp, description, error); ModelMetadata fromString1 = ModelMetadata.fromString(stringRep1); assertEquals(expected, fromString1); @@ -259,9 +326,8 @@ public void testFromResponseMap() { String description = "test-description"; String error = "test-error"; - ModelMetadata expected = new ModelMetadata(knnEngine, spaceType, dimension, modelState, - timestamp, description, error); - Map metadataAsMap = new HashMap<>(); + ModelMetadata expected = new ModelMetadata(knnEngine, spaceType, dimension, modelState, timestamp, description, error); + Map metadataAsMap = new HashMap<>(); metadataAsMap.put(KNNConstants.KNN_ENGINE, knnEngine.getName()); metadataAsMap.put(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()); metadataAsMap.put(KNNConstants.DIMENSION, dimension); diff --git a/src/test/java/org/opensearch/knn/indices/ModelTests.java b/src/test/java/org/opensearch/knn/indices/ModelTests.java index 7699bb509..7a01fd528 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelTests.java @@ -30,53 +30,148 @@ public void testNullConstructor() { } public void testInvalidConstructor() { - expectThrows(IllegalArgumentException.class, () -> new Model(new ModelMetadata(KNNEngine.DEFAULT, - SpaceType.DEFAULT, -1, ModelState.FAILED, ZonedDateTime.now(ZoneOffset.UTC).toString(), - "", ""), null, "test-model")); + expectThrows( + IllegalArgumentException.class, + () -> new Model( + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + -1, + ModelState.FAILED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ), + null, + "test-model" + ) + ); } public void testInvalidDimension() { - expectThrows(IllegalArgumentException.class, () -> new Model(new ModelMetadata(KNNEngine.DEFAULT, - SpaceType.DEFAULT, -1, ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), - "", ""), new byte[16], "test-model")); - expectThrows(IllegalArgumentException.class, () -> new Model(new ModelMetadata(KNNEngine.DEFAULT, - SpaceType.DEFAULT, 0, ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), - "", ""), new byte[16], "test-model")); - expectThrows(IllegalArgumentException.class, () -> new Model(new ModelMetadata(KNNEngine.DEFAULT, - SpaceType.DEFAULT, MAX_DIMENSION + 1, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""), new byte[16], "test-model")); + expectThrows( + IllegalArgumentException.class, + () -> new Model( + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + -1, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ), + new byte[16], + "test-model" + ) + ); + expectThrows( + IllegalArgumentException.class, + () -> new Model( + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + 0, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ), + new byte[16], + "test-model" + ) + ); + expectThrows( + IllegalArgumentException.class, + () -> new Model( + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + MAX_DIMENSION + 1, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ), + new byte[16], + "test-model" + ) + ); } public void testGetModelMetadata() { KNNEngine knnEngine = KNNEngine.DEFAULT; - ModelMetadata modelMetadata = new ModelMetadata(knnEngine, SpaceType.DEFAULT, 2, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""); + ModelMetadata modelMetadata = new ModelMetadata( + knnEngine, + SpaceType.DEFAULT, + 2, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ); Model model = new Model(modelMetadata, new byte[16], "test-model"); assertEquals(modelMetadata, model.getModelMetadata()); } public void testGetModelBlob() { byte[] modelBlob = "hello".getBytes(); - Model model = new Model(new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, 2, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""), modelBlob, "test-model"); + Model model = new Model( + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + 2, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ), + modelBlob, + "test-model" + ); assertArrayEquals(modelBlob, model.getModelBlob()); } public void testGetLength() { int size = 129; - Model model = new Model(new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, 2, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""), new byte[size], "test-model"); + Model model = new Model( + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + 2, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ), + new byte[size], + "test-model" + ); assertEquals(size, model.getLength()); - model = new Model(new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, 2, ModelState.TRAINING, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""), null, "test-model"); + model = new Model( + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + 2, + ModelState.TRAINING, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ), + null, + "test-model" + ); assertEquals(0, model.getLength()); } public void testSetModelBlob() { byte[] blob1 = "Hello blob 1".getBytes(); - Model model = new Model(new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""), blob1, "test-model"); + Model model = new Model( + new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""), + blob1, + "test-model" + ); assertEquals(blob1, model.getModelBlob()); byte[] blob2 = "Hello blob 2".getBytes(); @@ -88,12 +183,21 @@ public void testEquals() { String time = ZonedDateTime.now(ZoneOffset.UTC).toString(); - Model model1 = new Model(new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, - time, "", ""), new byte[16], "test-model-1"); - Model model2 = new Model(new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, - time, "", ""), new byte[16], "test-model-1"); - Model model3 = new Model(new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 2, ModelState.CREATED, - time, "", ""), new byte[16], "test-model-2"); + Model model1 = new Model( + new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", ""), + new byte[16], + "test-model-1" + ); + Model model2 = new Model( + new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", ""), + new byte[16], + "test-model-1" + ); + Model model3 = new Model( + new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 2, ModelState.CREATED, time, "", ""), + new byte[16], + "test-model-2" + ); assertEquals(model1, model1); assertEquals(model1, model2); @@ -104,12 +208,21 @@ public void testHashCode() { String time = ZonedDateTime.now(ZoneOffset.UTC).toString(); - Model model1 = new Model(new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, - time, "", ""), new byte[16], "test-model-1"); - Model model2 = new Model(new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, - time, "", ""), new byte[16], "test-model-1"); - Model model3 = new Model(new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, - time, "", ""), new byte[16], "test-model-2"); + Model model1 = new Model( + new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", ""), + new byte[16], + "test-model-1" + ); + Model model2 = new Model( + new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", ""), + new byte[16], + "test-model-1" + ); + Model model3 = new Model( + new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", ""), + new byte[16], + "test-model-2" + ); assertEquals(model1.hashCode(), model1.hashCode()); assertEquals(model1.hashCode(), model2.hashCode()); @@ -126,9 +239,8 @@ public void testModelFromSourceMap() { String description = "test-description"; String error = "test-error"; - ModelMetadata metadata = new ModelMetadata(knnEngine, spaceType, dimension, modelState, - timestamp, description, error); - Map modelAsMap = new HashMap<>(); + ModelMetadata metadata = new ModelMetadata(knnEngine, spaceType, dimension, modelState, timestamp, description, error); + Map modelAsMap = new HashMap<>(); modelAsMap.put(KNNConstants.MODEL_ID, modelID); modelAsMap.put(KNNConstants.KNN_ENGINE, knnEngine.getName()); modelAsMap.put(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()); @@ -137,7 +249,7 @@ public void testModelFromSourceMap() { modelAsMap.put(KNNConstants.MODEL_TIMESTAMP, timestamp); modelAsMap.put(KNNConstants.MODEL_DESCRIPTION, description); modelAsMap.put(KNNConstants.MODEL_ERROR, error); - modelAsMap.put(KNNConstants.MODEL_BLOB_PARAMETER,"aGVsbG8="); + modelAsMap.put(KNNConstants.MODEL_BLOB_PARAMETER, "aGVsbG8="); byte[] blob1 = "hello".getBytes(); Model expected = new Model(metadata, blob1, modelID); diff --git a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java index a848418b3..d85c53e08 100644 --- a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java +++ b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java @@ -55,91 +55,191 @@ public static void setUpClass() throws IOException { } public void testCreateIndex_invalid_engineNotSupported() { - expectThrows(IllegalArgumentException.class, () -> JNIService.createIndex(new int[]{}, new float[][]{}, - "test", ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), "invalid-engine")); + expectThrows( + IllegalArgumentException.class, + () -> JNIService.createIndex( + new int[] {}, + new float[][] {}, + "test", + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + "invalid-engine" + ) + ); } public void testCreateIndex_invalid_engineNull() { - expectThrows(Exception.class, () -> JNIService.createIndex(new int[]{}, new float[][]{}, - "test", ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), null)); + expectThrows( + Exception.class, + () -> JNIService.createIndex( + new int[] {}, + new float[][] {}, + "test", + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + null + ) + ); } public void testCreateIndex_nmslib_invalid_noSpaceType() { - expectThrows(Exception.class, () -> JNIService.createIndex(testData.indexData.docs, testData.indexData.vectors, - "something", Collections.emptyMap(), KNNEngine.NMSLIB.getName())); + expectThrows( + Exception.class, + () -> JNIService.createIndex( + testData.indexData.docs, + testData.indexData.vectors, + "something", + Collections.emptyMap(), + KNNEngine.NMSLIB.getName() + ) + ); } public void testCreateIndex_nmslib_invalid_vectorDocIDMismatch() throws IOException { - int[] docIds = new int[]{1, 2, 3}; - float[][] vectors1 = new float[][] {{1, 2}, {3, 4}}; + int[] docIds = new int[] { 1, 2, 3 }; + float[][] vectors1 = new float[][] { { 1, 2 }, { 3, 4 } }; Path tmpFile1 = createTempFile(); - expectThrows(Exception.class, () -> JNIService.createIndex(docIds, vectors1, - tmpFile1.toAbsolutePath().toString(), ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.NMSLIB.getName())); + expectThrows( + Exception.class, + () -> JNIService.createIndex( + docIds, + vectors1, + tmpFile1.toAbsolutePath().toString(), + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.NMSLIB.getName() + ) + ); - float[][] vectors2 = new float[][] {{1, 2}, {3, 4}, {4, 5}, {6, 7}, {8, 9}}; + float[][] vectors2 = new float[][] { { 1, 2 }, { 3, 4 }, { 4, 5 }, { 6, 7 }, { 8, 9 } }; Path tmpFile2 = createTempFile(); - expectThrows(Exception.class, () -> JNIService.createIndex(docIds, vectors2, - tmpFile2.toAbsolutePath().toString(), ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.NMSLIB.getName())); + expectThrows( + Exception.class, + () -> JNIService.createIndex( + docIds, + vectors2, + tmpFile2.toAbsolutePath().toString(), + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.NMSLIB.getName() + ) + ); } public void testCreateIndex_nmslib_invalid_nullArgument() throws IOException { - int[] docIds = new int[]{}; - float[][] vectors = new float[][]{}; + int[] docIds = new int[] {}; + float[][] vectors = new float[][] {}; Path tmpFile = createTempFile(); - expectThrows(Exception.class, () -> JNIService.createIndex(null, vectors, tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.NMSLIB.getName())); + expectThrows( + Exception.class, + () -> JNIService.createIndex( + null, + vectors, + tmpFile.toAbsolutePath().toString(), + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.NMSLIB.getName() + ) + ); - expectThrows(Exception.class, () -> JNIService.createIndex(docIds, null, tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.NMSLIB.getName())); + expectThrows( + Exception.class, + () -> JNIService.createIndex( + docIds, + null, + tmpFile.toAbsolutePath().toString(), + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.NMSLIB.getName() + ) + ); - expectThrows(Exception.class, () -> JNIService.createIndex(docIds, vectors, null, - ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.NMSLIB.getName())); + expectThrows( + Exception.class, + () -> JNIService.createIndex( + docIds, + vectors, + null, + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.NMSLIB.getName() + ) + ); - expectThrows(Exception.class, () -> JNIService.createIndex(docIds, vectors, tmpFile.toAbsolutePath().toString(), - null, KNNEngine.NMSLIB.getName())); + expectThrows( + Exception.class, + () -> JNIService.createIndex(docIds, vectors, tmpFile.toAbsolutePath().toString(), null, KNNEngine.NMSLIB.getName()) + ); - expectThrows(Exception.class, () -> JNIService.createIndex(docIds, vectors, tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), null)); + expectThrows( + Exception.class, + () -> JNIService.createIndex( + docIds, + vectors, + tmpFile.toAbsolutePath().toString(), + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + null + ) + ); } public void testCreateIndex_nmslib_invalid_badSpace() throws IOException { - int[] docIds = new int[]{1}; - float[][] vectors = new float[][]{{2, 3}}; + int[] docIds = new int[] { 1 }; + float[][] vectors = new float[][] { { 2, 3 } }; Path tmpFile = createTempFile(); - expectThrows(Exception.class, () -> JNIService.createIndex(docIds, vectors, tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(KNNConstants.SPACE_TYPE, "invalid"), KNNEngine.NMSLIB.getName())); + expectThrows( + Exception.class, + () -> JNIService.createIndex( + docIds, + vectors, + tmpFile.toAbsolutePath().toString(), + ImmutableMap.of(KNNConstants.SPACE_TYPE, "invalid"), + KNNEngine.NMSLIB.getName() + ) + ); } public void testCreateIndex_nmslib_invalid_inconsistentDimensions() throws IOException { - int[] docIds = new int[]{1, 2}; - float[][] vectors = new float[][]{{2, 3}, {2, 3, 4}}; + int[] docIds = new int[] { 1, 2 }; + float[][] vectors = new float[][] { { 2, 3 }, { 2, 3, 4 } }; Path tmpFile = createTempFile(); - expectThrows(Exception.class, () -> JNIService.createIndex(docIds, vectors, tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.NMSLIB.getName())); + expectThrows( + Exception.class, + () -> JNIService.createIndex( + docIds, + vectors, + tmpFile.toAbsolutePath().toString(), + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.NMSLIB.getName() + ) + ); } public void testCreateIndex_nmslib_invalid_badParameterType() throws IOException { - int[] docIds = new int[]{}; - float[][] vectors = new float[][]{}; + int[] docIds = new int[] {}; + float[][] vectors = new float[][] {}; - Map parametersMap = ImmutableMap.of(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION, "14", KNNConstants.METHOD_PARAMETER_M, "12"); + Map parametersMap = ImmutableMap.of( + KNNConstants.HNSW_ALGO_EF_CONSTRUCTION, + "14", + KNNConstants.METHOD_PARAMETER_M, + "12" + ); Path tmpFile = createTempFile(); - expectThrows(Exception.class, () -> JNIService.createIndex(docIds, vectors, tmpFile.toAbsolutePath().toString(), + expectThrows( + Exception.class, + () -> JNIService.createIndex( + docIds, + vectors, + tmpFile.toAbsolutePath().toString(), ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue(), KNNConstants.PARAMETERS, parametersMap), - KNNEngine.NMSLIB.getName())); + KNNEngine.NMSLIB.getName() + ) + ); } public void testCreateIndex_nmslib_valid() throws IOException { @@ -147,130 +247,232 @@ public void testCreateIndex_nmslib_valid() throws IOException { for (SpaceType spaceType : KNNEngine.NMSLIB.getMethod(KNNConstants.METHOD_HNSW).getSpaces()) { Path tmpFile = createTempFile(); - JNIService.createIndex(testData.indexData.docs, testData.indexData.vectors, - tmpFile.toAbsolutePath().toString(), ImmutableMap.of(KNNConstants.SPACE_TYPE, spaceType.getValue()), - KNNEngine.NMSLIB.getName()); + JNIService.createIndex( + testData.indexData.docs, + testData.indexData.vectors, + tmpFile.toAbsolutePath().toString(), + ImmutableMap.of(KNNConstants.SPACE_TYPE, spaceType.getValue()), + KNNEngine.NMSLIB.getName() + ); assertTrue(tmpFile.toFile().length() > 0); tmpFile = createTempFile(); - JNIService.createIndex(testData.indexData.docs, testData.indexData.vectors, - tmpFile.toAbsolutePath().toString(), ImmutableMap.of(KNNConstants.SPACE_TYPE, spaceType.getValue(), - KNNConstants.HNSW_ALGO_EF_CONSTRUCTION, 14, KNNConstants.METHOD_PARAMETER_M, 12), - KNNEngine.NMSLIB.getName()); + JNIService.createIndex( + testData.indexData.docs, + testData.indexData.vectors, + tmpFile.toAbsolutePath().toString(), + ImmutableMap.of( + KNNConstants.SPACE_TYPE, + spaceType.getValue(), + KNNConstants.HNSW_ALGO_EF_CONSTRUCTION, + 14, + KNNConstants.METHOD_PARAMETER_M, + 12 + ), + KNNEngine.NMSLIB.getName() + ); assertTrue(tmpFile.toFile().length() > 0); } } public void testCreateIndex_faiss_invalid_noSpaceType() { - int[] docIds = new int[]{}; - float[][] vectors = new float[][]{}; - - expectThrows(Exception.class, () -> JNIService.createIndex(docIds, vectors, "something", - ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod), FAISS_NAME)); + int[] docIds = new int[] {}; + float[][] vectors = new float[][] {}; + + expectThrows( + Exception.class, + () -> JNIService.createIndex( + docIds, + vectors, + "something", + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod), + FAISS_NAME + ) + ); } public void testCreateIndex_faiss_invalid_vectorDocIDMismatch() throws IOException { - int[] docIds = new int[]{1, 2, 3}; - float[][] vectors1 = new float[][] {{1, 2}, {3, 4}}; + int[] docIds = new int[] { 1, 2, 3 }; + float[][] vectors1 = new float[][] { { 1, 2 }, { 3, 4 } }; Path tmpFile1 = createTempFile(); - expectThrows(Exception.class, () -> JNIService.createIndex(docIds, vectors1, - tmpFile1.toAbsolutePath().toString(), ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, - faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - FAISS_NAME)); + expectThrows( + Exception.class, + () -> JNIService.createIndex( + docIds, + vectors1, + tmpFile1.toAbsolutePath().toString(), + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + FAISS_NAME + ) + ); - float[][] vectors2 = new float[][] {{1, 2}, {3, 4}, {4, 5}, {6, 7}, {8, 9}}; + float[][] vectors2 = new float[][] { { 1, 2 }, { 3, 4 }, { 4, 5 }, { 6, 7 }, { 8, 9 } }; Path tmpFile2 = createTempFile(); - expectThrows(Exception.class, () -> JNIService.createIndex(docIds, vectors2, - tmpFile2.toAbsolutePath().toString(), ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, - faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - FAISS_NAME)); + expectThrows( + Exception.class, + () -> JNIService.createIndex( + docIds, + vectors2, + tmpFile2.toAbsolutePath().toString(), + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + FAISS_NAME + ) + ); } public void testCreateIndex_faiss_invalid_null() throws IOException { - int[] docIds = new int[]{}; - float[][] vectors = new float[][]{}; + int[] docIds = new int[] {}; + float[][] vectors = new float[][] {}; Path tmpFile = createTempFile(); - expectThrows(Exception.class, () -> JNIService.createIndex(null, vectors, tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, - KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), FAISS_NAME)); - - expectThrows(Exception.class, () -> JNIService.createIndex(docIds, null, tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, - KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), FAISS_NAME)); + expectThrows( + Exception.class, + () -> JNIService.createIndex( + null, + vectors, + tmpFile.toAbsolutePath().toString(), + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + FAISS_NAME + ) + ); - expectThrows(Exception.class, () -> JNIService.createIndex(docIds, vectors, null, - ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, - KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), FAISS_NAME)); + expectThrows( + Exception.class, + () -> JNIService.createIndex( + docIds, + null, + tmpFile.toAbsolutePath().toString(), + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + FAISS_NAME + ) + ); - expectThrows(Exception.class, () -> JNIService.createIndex(docIds, vectors, tmpFile.toAbsolutePath().toString(), - null, FAISS_NAME)); + expectThrows( + Exception.class, + () -> JNIService.createIndex( + docIds, + vectors, + null, + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + FAISS_NAME + ) + ); - expectThrows(Exception.class, () -> JNIService.createIndex(docIds, vectors, tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, - KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), null)); + expectThrows(Exception.class, () -> JNIService.createIndex(docIds, vectors, tmpFile.toAbsolutePath().toString(), null, FAISS_NAME)); + + expectThrows( + Exception.class, + () -> JNIService.createIndex( + docIds, + vectors, + tmpFile.toAbsolutePath().toString(), + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + null + ) + ); } public void testCreateIndex_faiss_invalid_invalidSpace() throws IOException { - int[] docIds = new int[]{1}; - float[][] vectors = new float[][]{{2, 3}}; + int[] docIds = new int[] { 1 }; + float[][] vectors = new float[][] { { 2, 3 } }; Path tmpFile = createTempFile(); - expectThrows(Exception.class, () -> JNIService.createIndex(docIds, vectors, tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, - KNNConstants.SPACE_TYPE, "invalid"), FAISS_NAME)); + expectThrows( + Exception.class, + () -> JNIService.createIndex( + docIds, + vectors, + tmpFile.toAbsolutePath().toString(), + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, "invalid"), + FAISS_NAME + ) + ); } public void testCreateIndex_faiss_invalid_inconsistentDimensions() throws IOException { - int[] docIds = new int[]{1, 2}; - float[][] vectors = new float[][]{{2, 3}, {2, 3, 4}}; + int[] docIds = new int[] { 1, 2 }; + float[][] vectors = new float[][] { { 2, 3 }, { 2, 3, 4 } }; Path tmpFile = createTempFile(); - expectThrows(Exception.class, () -> JNIService.createIndex(docIds, vectors, tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, - KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), FAISS_NAME)); + expectThrows( + Exception.class, + () -> JNIService.createIndex( + docIds, + vectors, + tmpFile.toAbsolutePath().toString(), + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + FAISS_NAME + ) + ); } public void testCreateIndex_faiss_invalid_noIndexDescription() throws IOException { - int[] docIds = new int[]{1, 2}; - float[][] vectors = new float[][]{{2, 3}, {2, 3, 4}}; + int[] docIds = new int[] { 1, 2 }; + float[][] vectors = new float[][] { { 2, 3 }, { 2, 3, 4 } }; Path tmpFile = createTempFile(); - expectThrows(Exception.class, () -> JNIService.createIndex(docIds, vectors, tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), FAISS_NAME)); + expectThrows( + Exception.class, + () -> JNIService.createIndex( + docIds, + vectors, + tmpFile.toAbsolutePath().toString(), + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + FAISS_NAME + ) + ); } public void testCreateIndex_faiss_invalid_invalidIndexDescription() throws IOException { - int[] docIds = new int[]{1, 2}; - float[][] vectors = new float[][]{{2, 3}, {2, 3, 4}}; + int[] docIds = new int[] { 1, 2 }; + float[][] vectors = new float[][] { { 2, 3 }, { 2, 3, 4 } }; Path tmpFile = createTempFile(); - expectThrows(Exception.class, () -> JNIService.createIndex(docIds, vectors, tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, "invalid", - KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), FAISS_NAME)); + expectThrows( + Exception.class, + () -> JNIService.createIndex( + docIds, + vectors, + tmpFile.toAbsolutePath().toString(), + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, "invalid", KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + FAISS_NAME + ) + ); } public void testCreateIndex_faiss_invalid_invalidParameterType() throws IOException { - int[] docIds = new int[]{}; - float[][] vectors = new float[][]{}; + int[] docIds = new int[] {}; + float[][] vectors = new float[][] {}; Path tmpFile = createTempFile(); - expectThrows(Exception.class, () -> JNIService.createIndex(docIds, vectors, tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, "IVF13", - KNNConstants.SPACE_TYPE, SpaceType.L2.getValue(), KNNConstants.PARAMETERS, - ImmutableMap.of(KNNConstants.METHOD_PARAMETER_NPROBES, "14")), FAISS_NAME)); - + expectThrows( + Exception.class, + () -> JNIService.createIndex( + docIds, + vectors, + tmpFile.toAbsolutePath().toString(), + ImmutableMap.of( + INDEX_DESCRIPTION_PARAMETER, + "IVF13", + KNNConstants.SPACE_TYPE, + SpaceType.L2.getValue(), + KNNConstants.PARAMETERS, + ImmutableMap.of(KNNConstants.METHOD_PARAMETER_NPROBES, "14") + ), + FAISS_NAME + ) + ); } @@ -278,118 +480,138 @@ public void testCreateIndex_faiss_valid() throws IOException { List methods = ImmutableList.of(faissMethod); List spaces = ImmutableList.of(SpaceType.L2, SpaceType.INNER_PRODUCT); - for (String method: methods) { + for (String method : methods) { for (SpaceType spaceType : spaces) { Path tmpFile1 = createTempFile(); - JNIService.createIndex(testData.indexData.docs, testData.indexData.vectors, tmpFile1.toAbsolutePath().toString(), - ImmutableMap.of( - INDEX_DESCRIPTION_PARAMETER, method, - KNNConstants.SPACE_TYPE, spaceType.getValue() - ), - FAISS_NAME); + JNIService.createIndex( + testData.indexData.docs, + testData.indexData.vectors, + tmpFile1.toAbsolutePath().toString(), + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, method, KNNConstants.SPACE_TYPE, spaceType.getValue()), + FAISS_NAME + ); assertTrue(tmpFile1.toFile().length() > 0); } } } public void testLoadIndex_invalidEngine() { - expectThrows(IllegalArgumentException.class, () -> JNIService.loadIndex( - "test", Collections.emptyMap(), "invalid-engine")); + expectThrows(IllegalArgumentException.class, () -> JNIService.loadIndex("test", Collections.emptyMap(), "invalid-engine")); } public void testLoadIndex_nmslib_invalid_badSpaceType() { - expectThrows(Exception.class, () -> JNIService.loadIndex( - "test", ImmutableMap.of(KNNConstants.SPACE_TYPE, "invalid"), KNNEngine.NMSLIB.getName())); + expectThrows( + Exception.class, + () -> JNIService.loadIndex("test", ImmutableMap.of(KNNConstants.SPACE_TYPE, "invalid"), KNNEngine.NMSLIB.getName()) + ); } public void testLoadIndex_nmslib_invalid_noSpaceType() { - expectThrows(Exception.class, () -> JNIService.loadIndex( - "test", Collections.emptyMap(), KNNEngine.NMSLIB.getName())); + expectThrows(Exception.class, () -> JNIService.loadIndex("test", Collections.emptyMap(), KNNEngine.NMSLIB.getName())); } public void testLoadIndex_nmslib_invalid_fileDoesNotExist() { - expectThrows(Exception.class, () -> JNIService.loadIndex( - "invalid", ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.NMSLIB.getName())); + expectThrows( + Exception.class, + () -> JNIService.loadIndex( + "invalid", + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.NMSLIB.getName() + ) + ); } public void testLoadIndex_nmslib_invalid_badFile() throws IOException { Path tmpFile = createTempFile(); - expectThrows(Exception.class, () -> JNIService.loadIndex( - tmpFile.toAbsolutePath().toString(), ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.NMSLIB.getName())); + expectThrows( + Exception.class, + () -> JNIService.loadIndex( + tmpFile.toAbsolutePath().toString(), + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.NMSLIB.getName() + ) + ); } public void testLoadIndex_nmslib_valid() throws IOException { Path tmpFile = createTempFile(); - JNIService.createIndex(testData.indexData.docs, testData.indexData.vectors, - tmpFile.toAbsolutePath().toString(), ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.NMSLIB.getName()); + JNIService.createIndex( + testData.indexData.docs, + testData.indexData.vectors, + tmpFile.toAbsolutePath().toString(), + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.NMSLIB.getName() + ); assertTrue(tmpFile.toFile().length() > 0); - long pointer = JNIService.loadIndex(tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.NMSLIB.getName()); + long pointer = JNIService.loadIndex( + tmpFile.toAbsolutePath().toString(), + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.NMSLIB.getName() + ); assertNotEquals(0, pointer); } public void testLoadIndex_faiss_invalid_fileDoesNotExist() { - expectThrows(Exception.class, () -> JNIService.loadIndex( - "invalid", Collections.emptyMap(), FAISS_NAME)); + expectThrows(Exception.class, () -> JNIService.loadIndex("invalid", Collections.emptyMap(), FAISS_NAME)); } public void testLoadIndex_faiss_invalid_badFile() throws IOException { Path tmpFile = createTempFile(); - expectThrows(Exception.class, () -> JNIService.loadIndex( - tmpFile.toAbsolutePath().toString(), Collections.emptyMap(), FAISS_NAME)); + expectThrows(Exception.class, () -> JNIService.loadIndex(tmpFile.toAbsolutePath().toString(), Collections.emptyMap(), FAISS_NAME)); } public void testLoadIndex_faiss_valid() throws IOException { Path tmpFile = createTempFile(); - JNIService.createIndex(testData.indexData.docs, testData.indexData.vectors, tmpFile.toAbsolutePath().toString(), - ImmutableMap.of( - INDEX_DESCRIPTION_PARAMETER, faissMethod, - KNNConstants.SPACE_TYPE, SpaceType.L2.getValue() - ), - FAISS_NAME); + JNIService.createIndex( + testData.indexData.docs, + testData.indexData.vectors, + tmpFile.toAbsolutePath().toString(), + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + FAISS_NAME + ); assertTrue(tmpFile.toFile().length() > 0); - long pointer = JNIService.loadIndex(tmpFile.toAbsolutePath().toString(), Collections.emptyMap(), - FAISS_NAME); + long pointer = JNIService.loadIndex(tmpFile.toAbsolutePath().toString(), Collections.emptyMap(), FAISS_NAME); assertNotEquals(0, pointer); } public void testQueryIndex_invalidEngine() { - expectThrows(IllegalArgumentException.class, () -> JNIService.queryIndex(0L, - new float[]{}, 0, "invalid-engine")); + expectThrows(IllegalArgumentException.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, "invalid-engine")); } public void testQueryIndex_nmslib_invalid_badPointer() { - expectThrows(Exception.class, () -> JNIService.queryIndex(0L, - new float[]{}, 0, KNNEngine.NMSLIB.getName())); + expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, KNNEngine.NMSLIB.getName())); } public void testQueryIndex_nmslib_invalid_nullQueryVector() throws IOException { Path tmpFile = createTempFile(); - JNIService.createIndex(testData.indexData.docs, testData.indexData.vectors, - tmpFile.toAbsolutePath().toString(), ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.NMSLIB.getName()); + JNIService.createIndex( + testData.indexData.docs, + testData.indexData.vectors, + tmpFile.toAbsolutePath().toString(), + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.NMSLIB.getName() + ); assertTrue(tmpFile.toFile().length() > 0); - long pointer = JNIService.loadIndex(tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.NMSLIB.getName()); + long pointer = JNIService.loadIndex( + tmpFile.toAbsolutePath().toString(), + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.NMSLIB.getName() + ); assertNotEquals(0, pointer); - expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, - KNNEngine.NMSLIB.getName())); + expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, KNNEngine.NMSLIB.getName())); } public void testQueryIndex_nmslib_valid() throws IOException { @@ -398,13 +620,20 @@ public void testQueryIndex_nmslib_valid() throws IOException { for (SpaceType spaceType : KNNEngine.NMSLIB.getMethod(KNNConstants.METHOD_HNSW).getSpaces()) { Path tmpFile = createTempFile(); - JNIService.createIndex(testData.indexData.docs, testData.indexData.vectors, - tmpFile.toAbsolutePath().toString(), ImmutableMap.of(KNNConstants.SPACE_TYPE, spaceType.getValue()), - KNNEngine.NMSLIB.getName()); + JNIService.createIndex( + testData.indexData.docs, + testData.indexData.vectors, + tmpFile.toAbsolutePath().toString(), + ImmutableMap.of(KNNConstants.SPACE_TYPE, spaceType.getValue()), + KNNEngine.NMSLIB.getName() + ); assertTrue(tmpFile.toFile().length() > 0); - long pointer = JNIService.loadIndex(tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(KNNConstants.SPACE_TYPE, spaceType.getValue()), KNNEngine.NMSLIB.getName()); + long pointer = JNIService.loadIndex( + tmpFile.toAbsolutePath().toString(), + ImmutableMap.of(KNNConstants.SPACE_TYPE, spaceType.getValue()), + KNNEngine.NMSLIB.getName() + ); assertNotEquals(0, pointer); for (float[] query : testData.queries) { @@ -416,21 +645,23 @@ public void testQueryIndex_nmslib_valid() throws IOException { public void testQueryIndex_faiss_invalid_badPointer() { - expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[]{}, 0, FAISS_NAME)); + expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, FAISS_NAME)); } public void testQueryIndex_faiss_invalid_nullQueryVector() throws IOException { Path tmpFile = createTempFile(); - JNIService.createIndex(testData.indexData.docs, testData.indexData.vectors, - tmpFile.toAbsolutePath().toString(), ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, - faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - FAISS_NAME); + JNIService.createIndex( + testData.indexData.docs, + testData.indexData.vectors, + tmpFile.toAbsolutePath().toString(), + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + FAISS_NAME + ); assertTrue(tmpFile.toFile().length() > 0); - long pointer = JNIService.loadIndex(tmpFile.toAbsolutePath().toString(), - Collections.emptyMap(), FAISS_NAME); + long pointer = JNIService.loadIndex(tmpFile.toAbsolutePath().toString(), Collections.emptyMap(), FAISS_NAME); assertNotEquals(0, pointer); expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, FAISS_NAME)); @@ -442,20 +673,23 @@ public void testQueryIndex_faiss_valid() throws IOException { List methods = ImmutableList.of(faissMethod); List spaces = ImmutableList.of(SpaceType.L2, SpaceType.INNER_PRODUCT); - for (String method: methods) { + for (String method : methods) { for (SpaceType spaceType : spaces) { Path tmpFile = createTempFile(); - JNIService.createIndex(testData.indexData.docs, testData.indexData.vectors, - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of( - INDEX_DESCRIPTION_PARAMETER, method, - KNNConstants.SPACE_TYPE, spaceType.getValue() - ), - FAISS_NAME); + JNIService.createIndex( + testData.indexData.docs, + testData.indexData.vectors, + tmpFile.toAbsolutePath().toString(), + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, method, KNNConstants.SPACE_TYPE, spaceType.getValue()), + FAISS_NAME + ); assertTrue(tmpFile.toFile().length() > 0); - long pointer = JNIService.loadIndex(tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(KNNConstants.SPACE_TYPE, spaceType.getValue()), FAISS_NAME); + long pointer = JNIService.loadIndex( + tmpFile.toAbsolutePath().toString(), + ImmutableMap.of(KNNConstants.SPACE_TYPE, spaceType.getValue()), + FAISS_NAME + ); assertNotEquals(0, pointer); for (float[] query : testData.queries) { @@ -474,13 +708,20 @@ public void testFree_nmslib_valid() throws IOException { Path tmpFile = createTempFile(); - JNIService.createIndex(testData.indexData.docs, testData.indexData.vectors, - tmpFile.toAbsolutePath().toString(), ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.NMSLIB.getName()); + JNIService.createIndex( + testData.indexData.docs, + testData.indexData.vectors, + tmpFile.toAbsolutePath().toString(), + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.NMSLIB.getName() + ); assertTrue(tmpFile.toFile().length() > 0); - long pointer = JNIService.loadIndex(tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.NMSLIB.getName()); + long pointer = JNIService.loadIndex( + tmpFile.toAbsolutePath().toString(), + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.NMSLIB.getName() + ); assertNotEquals(0, pointer); JNIService.free(pointer, KNNEngine.NMSLIB.getName()); @@ -490,16 +731,16 @@ public void testFree_faiss_valid() throws IOException { Path tmpFile = createTempFile(); - JNIService.createIndex(testData.indexData.docs, testData.indexData.vectors, tmpFile.toAbsolutePath().toString(), - ImmutableMap.of( - INDEX_DESCRIPTION_PARAMETER, faissMethod, - KNNConstants.SPACE_TYPE, SpaceType.L2.getValue() - ), - FAISS_NAME); + JNIService.createIndex( + testData.indexData.docs, + testData.indexData.vectors, + tmpFile.toAbsolutePath().toString(), + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + FAISS_NAME + ); assertTrue(tmpFile.toFile().length() > 0); - long pointer = JNIService.loadIndex(tmpFile.toAbsolutePath().toString(), Collections.emptyMap(), - FAISS_NAME); + long pointer = JNIService.loadIndex(tmpFile.toAbsolutePath().toString(), Collections.emptyMap(), FAISS_NAME); assertNotEquals(0, pointer); JNIService.free(pointer, FAISS_NAME); @@ -510,7 +751,7 @@ public void testTransferVectors() { assertNotEquals(0, trainPointer1); long trainPointer2; - for (int i =0; i < 10; i++) { + for (int i = 0; i < 10; i++) { trainPointer2 = JNIService.transferVectors(trainPointer1, testData.indexData.vectors); assertEquals(trainPointer1, trainPointer2); } @@ -524,17 +765,18 @@ public void testTrain() { assertNotEquals(0, trainPointer1); long trainPointer2; - for (int i =0; i < 10; i++) { + for (int i = 0; i < 10; i++) { trainPointer2 = JNIService.transferVectors(trainPointer1, testData.indexData.vectors); assertEquals(trainPointer1, trainPointer2); } Map parameters = ImmutableMap.of( - INDEX_DESCRIPTION_PARAMETER, "IVF16,PQ4", - KNNConstants.SPACE_TYPE, SpaceType.L2.getValue() + INDEX_DESCRIPTION_PARAMETER, + "IVF16,PQ4", + KNNConstants.SPACE_TYPE, + SpaceType.L2.getValue() ); - byte[] faissIndex = JNIService.trainIndex(parameters, 128, trainPointer1, FAISS_NAME); assertNotEquals(0, faissIndex.length); @@ -547,28 +789,34 @@ public void testCreateIndexFromTemplate() throws IOException { assertNotEquals(0, trainPointer1); long trainPointer2; - for (int i =0; i < 10; i++) { + for (int i = 0; i < 10; i++) { trainPointer2 = JNIService.transferVectors(trainPointer1, testData.indexData.vectors); assertEquals(trainPointer1, trainPointer2); } SpaceType spaceType = SpaceType.L2; - KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, spaceType, - new MethodComponentContext(METHOD_IVF, - ImmutableMap.of( - METHOD_PARAMETER_NLIST, 16, - METHOD_ENCODER_PARAMETER, new MethodComponentContext(ENCODER_PQ, - ImmutableMap.of( - ENCODER_PARAMETER_PQ_M, 16, - ENCODER_PARAMETER_PQ_CODE_SIZE, 8 - ))))); + KNNMethodContext knnMethodContext = new KNNMethodContext( + KNNEngine.FAISS, + spaceType, + new MethodComponentContext( + METHOD_IVF, + ImmutableMap.of( + METHOD_PARAMETER_NLIST, + 16, + METHOD_ENCODER_PARAMETER, + new MethodComponentContext(ENCODER_PQ, ImmutableMap.of(ENCODER_PARAMETER_PQ_M, 16, ENCODER_PARAMETER_PQ_CODE_SIZE, 8)) + ) + ) + ); String description = knnMethodContext.getEngine().getMethodAsMap(knnMethodContext).get(INDEX_DESCRIPTION_PARAMETER).toString(); assertEquals("IVF16,PQ16x8", description); Map parameters = ImmutableMap.of( - INDEX_DESCRIPTION_PARAMETER, description, - KNNConstants.SPACE_TYPE, spaceType.getValue() + INDEX_DESCRIPTION_PARAMETER, + description, + KNNConstants.SPACE_TYPE, + spaceType.getValue() ); byte[] faissIndex = JNIService.trainIndex(parameters, 128, trainPointer1, FAISS_NAME); @@ -577,12 +825,17 @@ METHOD_ENCODER_PARAMETER, new MethodComponentContext(ENCODER_PQ, JNIService.freeVectors(trainPointer1); Path tmpFile1 = createTempFile(); - JNIService.createIndexFromTemplate(testData.indexData.docs, testData.indexData.vectors, - tmpFile1.toAbsolutePath().toString(), faissIndex, ImmutableMap.of(INDEX_THREAD_QTY, 1), FAISS_NAME); + JNIService.createIndexFromTemplate( + testData.indexData.docs, + testData.indexData.vectors, + tmpFile1.toAbsolutePath().toString(), + faissIndex, + ImmutableMap.of(INDEX_THREAD_QTY, 1), + FAISS_NAME + ); assertTrue(tmpFile1.toFile().length() > 0); - long pointer = JNIService.loadIndex(tmpFile1.toAbsolutePath().toString(), Collections.emptyMap(), - FAISS_NAME); + long pointer = JNIService.loadIndex(tmpFile1.toAbsolutePath().toString(), Collections.emptyMap(), FAISS_NAME); assertNotEquals(0, pointer); } } diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestDeleteModelHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestDeleteModelHandlerIT.java index 30224ce26..edd8d2106 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestDeleteModelHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestDeleteModelHandlerIT.java @@ -15,7 +15,6 @@ import org.opensearch.action.DocWriteResponse; import org.opensearch.client.Request; import org.opensearch.client.Response; -import org.opensearch.client.ResponseException; import org.opensearch.common.xcontent.XContentType; import org.opensearch.knn.KNNRestTestCase; import org.opensearch.knn.index.SpaceType; @@ -29,16 +28,9 @@ import java.io.IOException; import java.util.Map; -import static org.opensearch.knn.common.KNNConstants.DIMENSION; -import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; -import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; import static org.opensearch.knn.common.KNNConstants.MODELS; -import static org.opensearch.knn.common.KNNConstants.MODEL_DESCRIPTION; -import static org.opensearch.knn.common.KNNConstants.MODEL_ERROR; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME; -import static org.opensearch.knn.common.KNNConstants.MODEL_STATE; -import static org.opensearch.knn.common.KNNConstants.MODEL_TIMESTAMP; /** * Integration tests to check the correctness of {@link org.opensearch.knn.plugin.rest.RestDeleteModelHandler} @@ -47,8 +39,7 @@ public class RestDeleteModelHandlerIT extends KNNRestTestCase { private ModelMetadata getModelMetadata() { - return new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, 4, ModelState.CREATED, - "2021-03-27", "test model", ""); + return new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, 4, ModelState.CREATED, "2021-03-27", "test model", ""); } public void testDeleteModelExists() throws IOException { @@ -64,8 +55,7 @@ public void testDeleteModelExists() throws IOException { Request request = new Request("DELETE", restURI); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); assertEquals(getDocCount(MODEL_INDEX_NAME), 0); } @@ -76,15 +66,11 @@ public void testDeleteModelFailsInvalid() throws IOException { Request request = new Request("DELETE", restURI); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); String responseBody = EntityUtils.toString(response.getEntity()); assertNotNull(responseBody); - Map responseMap = createParser( - XContentType.JSON.xContent(), - responseBody - ).map(); + Map responseMap = createParser(XContentType.JSON.xContent(), responseBody).map(); assertEquals("invalid-model-id", responseMap.get(MODEL_ID)); assertEquals(DocWriteResponse.Result.NOT_FOUND.getLowercase(), responseMap.get(DeleteModelResponse.RESULT)); diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestGetModelHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestGetModelHandlerIT.java index daf9b86ef..b6853e8bb 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestGetModelHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestGetModelHandlerIT.java @@ -47,8 +47,7 @@ public class RestGetModelHandlerIT extends KNNRestTestCase { private ModelMetadata getModelMetadata() { - return new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, 4, ModelState.CREATED, - "2021-03-27", "test model", ""); + return new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, 4, ModelState.CREATED, "2021-03-27", "test model", ""); } public void testGetModelExists() throws IOException { @@ -68,10 +67,7 @@ public void testGetModelExists() throws IOException { String responseBody = EntityUtils.toString(response.getEntity()); assertNotNull(responseBody); - Map responseMap = createParser( - XContentType.JSON.xContent(), - responseBody - ).map(); + Map responseMap = createParser(XContentType.JSON.xContent(), responseBody).map(); assertEquals(testModelID, responseMap.get(MODEL_ID)); assertEquals(testModelMetadata.getDescription(), responseMap.get(MODEL_DESCRIPTION)); @@ -83,7 +79,6 @@ public void testGetModelExists() throws IOException { assertEquals(testModelMetadata.getTimestamp(), responseMap.get(MODEL_TIMESTAMP)); } - public void testGetModelExistsWithFilter() throws IOException { createModelSystemIndex(); String testModelID = "test-model-id"; @@ -104,10 +99,7 @@ public void testGetModelExistsWithFilter() throws IOException { String responseBody = EntityUtils.toString(response.getEntity()); assertNotNull(responseBody); - Map responseMap = createParser( - XContentType.JSON.xContent(), - responseBody - ).map(); + Map responseMap = createParser(XContentType.JSON.xContent(), responseBody).map(); assertTrue(responseMap.size() == filterdPath.size()); assertEquals(testModelID, responseMap.get(MODEL_ID)); @@ -125,8 +117,7 @@ public void testGetModelFailsInvalid() throws IOException { String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "invalid-model-id"); Request request = new Request("GET", restURI); - ResponseException ex = expectThrows(ResponseException.class, () -> - client().performRequest(request)); + ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request)); assertTrue(ex.getMessage().contains("\"invalid-model-id\"")); } diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java index 31fb8cc9b..1a879aa0c 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java @@ -82,11 +82,11 @@ public void testStatsValueCheck() throws IOException { createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); // Index test document - Float[] vector = {6.0f, 6.0f}; + Float[] vector = { 6.0f, 6.0f }; addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector); // First search: Ensure that misses=1 - float[] qvector = {6.0f, 6.0f}; + float[] qvector = { 6.0f, 6.0f }; searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, qvector, 1), 1); response = getKnnStats(Collections.emptyList(), Collections.emptyList()); @@ -137,8 +137,7 @@ public void testValidMetricsStats() throws IOException { * Test checks that handler correctly returns failure on an invalid metric */ public void testInvalidMetricsStats() { - expectThrows(ResponseException.class, () -> getKnnStats(Collections.emptyList(), - Collections.singletonList("invalid_metric"))); + expectThrows(ResponseException.class, () -> getKnnStats(Collections.emptyList(), Collections.singletonList("invalid_metric"))); } /** @@ -172,10 +171,13 @@ public void testScriptStats_singleShard() throws Exception { clearScriptCache(); // Get initial stats - Response response = getKnnStats(Collections.emptyList(), Arrays.asList( - StatNames.SCRIPT_COMPILATIONS.getName(), - StatNames.SCRIPT_QUERY_REQUESTS.getName(), - StatNames.SCRIPT_QUERY_ERRORS.getName()) + Response response = getKnnStats( + Collections.emptyList(), + Arrays.asList( + StatNames.SCRIPT_COMPILATIONS.getName(), + StatNames.SCRIPT_QUERY_REQUESTS.getName(), + StatNames.SCRIPT_QUERY_ERRORS.getName() + ) ); List> nodeStats = parseNodeStatsResponse(EntityUtils.toString(response.getEntity())); int initialScriptCompilations = (int) (nodeStats.get(0).get(StatNames.SCRIPT_COMPILATIONS.getName())); @@ -184,29 +186,27 @@ public void testScriptStats_singleShard() throws Exception { // Create an index with a single vector createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); - Float[] vector = {6.0f, 6.0f}; + Float[] vector = { 6.0f, 6.0f }; addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector); // Check l2 query and script compilation stats QueryBuilder qb = new MatchAllQueryBuilder(); Map params = new HashMap<>(); - float[] queryVector = {1.0f, 1.0f}; + float[] queryVector = { 1.0f, 1.0f }; params.put("field", FIELD_NAME); params.put("query_value", queryVector); params.put("space_type", SpaceType.L2.getValue()); Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params); response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); - response = getKnnStats(Collections.emptyList(), Arrays.asList( - StatNames.SCRIPT_COMPILATIONS.getName(), - StatNames.SCRIPT_QUERY_REQUESTS.getName()) + response = getKnnStats( + Collections.emptyList(), + Arrays.asList(StatNames.SCRIPT_COMPILATIONS.getName(), StatNames.SCRIPT_QUERY_REQUESTS.getName()) ); nodeStats = parseNodeStatsResponse(EntityUtils.toString(response.getEntity())); assertEquals((int) (nodeStats.get(0).get(StatNames.SCRIPT_COMPILATIONS.getName())), initialScriptCompilations + 1); - assertEquals(initialScriptQueryRequests + 1, - (int) (nodeStats.get(0).get(StatNames.SCRIPT_QUERY_REQUESTS.getName()))); + assertEquals(initialScriptQueryRequests + 1, (int) (nodeStats.get(0).get(StatNames.SCRIPT_QUERY_REQUESTS.getName()))); // Check query error stats params = new HashMap<>(); @@ -217,12 +217,9 @@ public void testScriptStats_singleShard() throws Exception { Request finalRequest = request; expectThrows(ResponseException.class, () -> client().performRequest(finalRequest)); - response = getKnnStats(Collections.emptyList(), Collections.singletonList( - StatNames.SCRIPT_QUERY_ERRORS.getName()) - ); + response = getKnnStats(Collections.emptyList(), Collections.singletonList(StatNames.SCRIPT_QUERY_ERRORS.getName())); nodeStats = parseNodeStatsResponse(EntityUtils.toString(response.getEntity())); - assertEquals(initialScriptQueryErrors + 1, - (int) (nodeStats.get(0).get(StatNames.SCRIPT_QUERY_ERRORS.getName()))); + assertEquals(initialScriptQueryErrors + 1, (int) (nodeStats.get(0).get(StatNames.SCRIPT_QUERY_ERRORS.getName()))); } /** @@ -232,10 +229,13 @@ public void testScriptStats_multipleShards() throws Exception { clearScriptCache(); // Get initial stats - Response response = getKnnStats(Collections.emptyList(), Arrays.asList( - StatNames.SCRIPT_COMPILATIONS.getName(), - StatNames.SCRIPT_QUERY_REQUESTS.getName(), - StatNames.SCRIPT_QUERY_ERRORS.getName()) + Response response = getKnnStats( + Collections.emptyList(), + Arrays.asList( + StatNames.SCRIPT_COMPILATIONS.getName(), + StatNames.SCRIPT_QUERY_REQUESTS.getName(), + StatNames.SCRIPT_QUERY_ERRORS.getName() + ) ); List> nodeStats = parseNodeStatsResponse(EntityUtils.toString(response.getEntity())); int initialScriptCompilations = (int) (nodeStats.get(0).get(StatNames.SCRIPT_COMPILATIONS.getName())); @@ -243,14 +243,13 @@ public void testScriptStats_multipleShards() throws Exception { int initialScriptQueryErrors = (int) (nodeStats.get(0).get(StatNames.SCRIPT_QUERY_ERRORS.getName())); // Create an index with a single vector - createKnnIndex(INDEX_NAME, Settings.builder() - .put("number_of_shards", 2) - .put("number_of_replicas", 0) - .put("index.knn", true) - .build(), - createKnnIndexMapping(FIELD_NAME, 2)); - - Float[] vector = {6.0f, 6.0f}; + createKnnIndex( + INDEX_NAME, + Settings.builder().put("number_of_shards", 2).put("number_of_replicas", 0).put("index.knn", true).build(), + createKnnIndexMapping(FIELD_NAME, 2) + ); + + Float[] vector = { 6.0f, 6.0f }; addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector); addKnnDoc(INDEX_NAME, "2", FIELD_NAME, vector); addKnnDoc(INDEX_NAME, "3", FIELD_NAME, vector); @@ -259,25 +258,23 @@ public void testScriptStats_multipleShards() throws Exception { // Check l2 query and script compilation stats QueryBuilder qb = new MatchAllQueryBuilder(); Map params = new HashMap<>(); - float[] queryVector = {1.0f, 1.0f}; + float[] queryVector = { 1.0f, 1.0f }; params.put("field", FIELD_NAME); params.put("query_value", queryVector); params.put("space_type", SpaceType.L2.getValue()); Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params); response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); - response = getKnnStats(Collections.emptyList(), Arrays.asList( - StatNames.SCRIPT_COMPILATIONS.getName(), - StatNames.SCRIPT_QUERY_REQUESTS.getName()) + response = getKnnStats( + Collections.emptyList(), + Arrays.asList(StatNames.SCRIPT_COMPILATIONS.getName(), StatNames.SCRIPT_QUERY_REQUESTS.getName()) ); nodeStats = parseNodeStatsResponse(EntityUtils.toString(response.getEntity())); assertEquals((int) (nodeStats.get(0).get(StatNames.SCRIPT_COMPILATIONS.getName())), initialScriptCompilations + 1); - //TODO fix the test case. For some reason request count is treated as 4. + // TODO fix the test case. For some reason request count is treated as 4. // https://github.com/opendistro-for-elasticsearch/k-NN/issues/272 - assertEquals(initialScriptQueryRequests + 4, - (int) (nodeStats.get(0).get(StatNames.SCRIPT_QUERY_REQUESTS.getName()))); + assertEquals(initialScriptQueryRequests + 4, (int) (nodeStats.get(0).get(StatNames.SCRIPT_QUERY_REQUESTS.getName()))); // Check query error stats params = new HashMap<>(); @@ -288,12 +285,9 @@ public void testScriptStats_multipleShards() throws Exception { Request finalRequest = request; expectThrows(ResponseException.class, () -> client().performRequest(finalRequest)); - response = getKnnStats(Collections.emptyList(), Collections.singletonList( - StatNames.SCRIPT_QUERY_ERRORS.getName()) - ); + response = getKnnStats(Collections.emptyList(), Collections.singletonList(StatNames.SCRIPT_QUERY_ERRORS.getName())); nodeStats = parseNodeStatsResponse(EntityUtils.toString(response.getEntity())); - assertEquals(initialScriptQueryErrors + 2, - (int) (nodeStats.get(0).get(StatNames.SCRIPT_QUERY_ERRORS.getName()))); + assertEquals(initialScriptQueryErrors + 2, (int) (nodeStats.get(0).get(StatNames.SCRIPT_QUERY_ERRORS.getName()))); } public void testModelIndexHealthMetricsStats() throws IOException { @@ -318,7 +312,7 @@ public void testModelIndexHealthMetricsStats() throws IOException { assertNotNull(statsMap.get(modelIndexStatusName)); // Check value is indeed part of ClusterHealthStatus - assertNotNull(ClusterHealthStatus.fromString((String)statsMap.get(modelIndexStatusName))); + assertNotNull(ClusterHealthStatus.fromString((String) statsMap.get(modelIndexStatusName))); } @@ -339,14 +333,11 @@ public void testModelIndexingDegradedMetricsStats() throws IOException { assertEquals(false, nodeStats.get(statName)); } - // Useful settings when debugging to prevent timeouts @Override protected Settings restClientSettings() { if (isDebuggingTest || isDebuggingRemoteCluster) { - return Settings.builder() - .put(CLIENT_SOCKET_TIMEOUT, TimeValue.timeValueMinutes(10)) - .build(); + return Settings.builder().put(CLIENT_SOCKET_TIMEOUT, TimeValue.timeValueMinutes(10)).build(); } else { return super.restClientSettings(); } diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestKNNWarmupHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestKNNWarmupHandlerIT.java index fbef632e8..a2078c291 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestKNNWarmupHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestKNNWarmupHandlerIT.java @@ -48,7 +48,7 @@ public void testEmptyIndex() throws IOException { public void testSingleIndex() throws IOException { int graphCountBefore = getTotalGraphsInCache(); createKnnIndex(testIndexName, getKNNDefaultIndexSettings(), createKnnIndexMapping(testFieldName, dimensions)); - addKnnDoc(testIndexName, "1", testFieldName, new Float[]{6.0f, 6.0f}); + addKnnDoc(testIndexName, "1", testFieldName, new Float[] { 6.0f, 6.0f }); knnWarmup(Collections.singletonList(testIndexName)); @@ -59,10 +59,10 @@ public void testMultipleIndices() throws IOException { int graphCountBefore = getTotalGraphsInCache(); createKnnIndex(testIndexName + "1", getKNNDefaultIndexSettings(), createKnnIndexMapping(testFieldName, dimensions)); - addKnnDoc(testIndexName + "1", "1", testFieldName, new Float[]{6.0f, 6.0f}); + addKnnDoc(testIndexName + "1", "1", testFieldName, new Float[] { 6.0f, 6.0f }); createKnnIndex(testIndexName + "2", getKNNDefaultIndexSettings(), createKnnIndexMapping(testFieldName, dimensions)); - addKnnDoc(testIndexName + "2", "1", testFieldName, new Float[]{6.0f, 6.0f}); + addKnnDoc(testIndexName + "2", "1", testFieldName, new Float[] { 6.0f, 6.0f }); knnWarmup(Arrays.asList(testIndexName + "1", testIndexName + "2")); diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestSearchModelHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestSearchModelHandlerIT.java index 025116cd8..8137f50a1 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestSearchModelHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestSearchModelHandlerIT.java @@ -37,7 +37,6 @@ import static org.opensearch.knn.common.KNNConstants.MODELS; - /** * Integration tests to check the correctness of {@link org.opensearch.knn.plugin.rest.RestSearchModelHandler} */ @@ -45,30 +44,24 @@ public class RestSearchModelHandlerIT extends KNNRestTestCase { private ModelMetadata getModelMetadata() { - return new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, 4, ModelState.CREATED, - "2021-03-27", "test model", ""); + return new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, 4, ModelState.CREATED, "2021-03-27", "test model", ""); } public void testNotSupportedParams() throws IOException { createModelSystemIndex(); String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "_search"); - Map invalidParams = new HashMap<>(); + Map invalidParams = new HashMap<>(); invalidParams.put("index", "index-name"); Request request = new Request("GET", restURI); request.addParameters(invalidParams); expectThrows(ResponseException.class, () -> client().performRequest(request)); } - public void testNoModelExists() throws IOException { createModelSystemIndex(); String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "_search"); Request request = new Request("GET", restURI); - request.setJsonEntity("{\n" + - " \"query\": {\n" + - " \"match_all\": {}\n" + - " }\n" + - "}"); + request.setJsonEntity("{\n" + " \"query\": {\n" + " \"match_all\": {}\n" + " }\n" + "}"); Response response = client().performRequest(request); assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); @@ -90,19 +83,15 @@ public void testSearchModelExists() throws IOException { List testModelID = Arrays.asList("test-modelid1", "test-modelid2"); byte[] testModelBlob = "hello".getBytes(); ModelMetadata testModelMetadata = getModelMetadata(); - for(String modelID: testModelID){ + for (String modelID : testModelID) { addModelToSystemIndex(modelID, testModelMetadata, testModelBlob); } String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "_search"); - for(String method: Arrays.asList("GET", "POST")){ + for (String method : Arrays.asList("GET", "POST")) { Request request = new Request(method, restURI); - request.setJsonEntity("{\n" + - " \"query\": {\n" + - " \"match_all\": {}\n" + - " }\n" + - "}"); + request.setJsonEntity("{\n" + " \"query\": {\n" + " \"match_all\": {}\n" + " }\n" + "}"); Response response = client().performRequest(request); assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); @@ -113,13 +102,13 @@ public void testSearchModelExists() throws IOException { SearchResponse searchResponse = SearchResponse.fromXContent(parser); assertNotNull(searchResponse); - //returns only model from ModelIndex + // returns only model from ModelIndex assertEquals(searchResponse.getHits().getHits().length, testModelID.size()); - for(SearchHit hit: searchResponse.getHits().getHits()){ + for (SearchHit hit : searchResponse.getHits().getHits()) { assertTrue(testModelID.contains(hit.getId())); Model model = Model.getModelFromSourceMap(hit.getSourceAsMap()); - assertEquals(getModelMetadata(),model.getModelMetadata()); + assertEquals(getModelMetadata(), model.getModelMetadata()); assertArrayEquals(testModelBlob, model.getModelBlob()); } } @@ -132,20 +121,17 @@ public void testSearchModelWithoutSource() throws IOException { List testModelID = Arrays.asList("test-modelid1", "test-modelid2"); byte[] testModelBlob = "hello".getBytes(); ModelMetadata testModelMetadata = getModelMetadata(); - for(String modelID: testModelID){ + for (String modelID : testModelID) { addModelToSystemIndex(modelID, testModelMetadata, testModelBlob); } String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "_search"); - for(String method: Arrays.asList("GET", "POST")){ + for (String method : Arrays.asList("GET", "POST")) { Request request = new Request(method, restURI); - request.setJsonEntity("{\n" + - " \"_source\" : false,\n" + - " \"query\": {\n" + - " \"match_all\": {}\n" + - " }\n" + - "}"); + request.setJsonEntity( + "{\n" + " \"_source\" : false,\n" + " \"query\": {\n" + " \"match_all\": {}\n" + " }\n" + "}" + ); Response response = client().performRequest(request); assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); @@ -156,10 +142,10 @@ public void testSearchModelWithoutSource() throws IOException { SearchResponse searchResponse = SearchResponse.fromXContent(parser); assertNotNull(searchResponse); - //returns only model from ModelIndex + // returns only model from ModelIndex assertEquals(searchResponse.getHits().getHits().length, testModelID.size()); - for(SearchHit hit: searchResponse.getHits().getHits()){ + for (SearchHit hit : searchResponse.getHits().getHits()) { assertTrue(testModelID.contains(hit.getId())); assertNull(hit.getSourceAsMap()); } @@ -173,22 +159,24 @@ public void testSearchModelWithSourceFilteringIncludes() throws IOException { List testModelID = Arrays.asList("test-modelid1", "test-modelid2"); byte[] testModelBlob = "hello".getBytes(); ModelMetadata testModelMetadata = getModelMetadata(); - for(String modelID: testModelID){ + for (String modelID : testModelID) { addModelToSystemIndex(modelID, testModelMetadata, testModelBlob); } String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "_search"); - for(String method: Arrays.asList("GET", "POST")){ + for (String method : Arrays.asList("GET", "POST")) { Request request = new Request(method, restURI); - request.setJsonEntity("{\n" + - " \"_source\": {\n" + - " \"includes\": [ \"state\", \"description\" ]\n"+ - " }, " + - " \"query\": {\n" + - " \"match_all\": {}\n" + - " }\n" + - "}"); + request.setJsonEntity( + "{\n" + + " \"_source\": {\n" + + " \"includes\": [ \"state\", \"description\" ]\n" + + " }, " + + " \"query\": {\n" + + " \"match_all\": {}\n" + + " }\n" + + "}" + ); Response response = client().performRequest(request); assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); @@ -199,10 +187,10 @@ public void testSearchModelWithSourceFilteringIncludes() throws IOException { SearchResponse searchResponse = SearchResponse.fromXContent(parser); assertNotNull(searchResponse); - //returns only model from ModelIndex + // returns only model from ModelIndex assertEquals(searchResponse.getHits().getHits().length, testModelID.size()); - for(SearchHit hit: searchResponse.getHits().getHits()){ + for (SearchHit hit : searchResponse.getHits().getHits()) { assertTrue(testModelID.contains(hit.getId())); Map sourceAsMap = hit.getSourceAsMap(); assertFalse(sourceAsMap.containsKey("model_blob")); @@ -220,22 +208,24 @@ public void testSearchModelWithSourceFilteringExcludes() throws IOException { List testModelID = Arrays.asList("test-modelid1", "test-modelid2"); byte[] testModelBlob = "hello".getBytes(); ModelMetadata testModelMetadata = getModelMetadata(); - for(String modelID: testModelID){ + for (String modelID : testModelID) { addModelToSystemIndex(modelID, testModelMetadata, testModelBlob); } String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "_search"); - for(String method: Arrays.asList("GET", "POST")){ + for (String method : Arrays.asList("GET", "POST")) { Request request = new Request(method, restURI); - request.setJsonEntity("{\n" + - " \"_source\": {\n" + - " \"excludes\": [\"model_blob\" ]\n"+ - " }, " + - " \"query\": {\n" + - " \"match_all\": {}\n" + - " }\n" + - "}"); + request.setJsonEntity( + "{\n" + + " \"_source\": {\n" + + " \"excludes\": [\"model_blob\" ]\n" + + " }, " + + " \"query\": {\n" + + " \"match_all\": {}\n" + + " }\n" + + "}" + ); Response response = client().performRequest(request); assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); @@ -246,10 +236,10 @@ public void testSearchModelWithSourceFilteringExcludes() throws IOException { SearchResponse searchResponse = SearchResponse.fromXContent(parser); assertNotNull(searchResponse); - //returns only model from ModelIndex + // returns only model from ModelIndex assertEquals(searchResponse.getHits().getHits().length, testModelID.size()); - for(SearchHit hit: searchResponse.getHits().getHits()){ + for (SearchHit hit : searchResponse.getHits().getHits()) { assertTrue(testModelID.contains(hit.getId())); Map sourceAsMap = hit.getSourceAsMap(); assertFalse(sourceAsMap.containsKey("model_blob")); diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestTrainModelHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestTrainModelHandlerIT.java index bed626a7b..4b72e22fa 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestTrainModelHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestTrainModelHandlerIT.java @@ -71,25 +71,25 @@ public void testTrainModel_fail_notEnoughData() throws IOException, InterruptedE } } */ - XContentBuilder builder = XContentFactory.jsonBuilder().startObject() - .field(NAME, "ivf") - .field(KNN_ENGINE, "faiss") - .field(METHOD_PARAMETER_SPACE_TYPE, "innerproduct") - .startObject(PARAMETERS) - .field(METHOD_PARAMETER_NLIST, 128) - .startObject(METHOD_ENCODER_PARAMETER) - .field(NAME, "pq") - .startObject(PARAMETERS) - .field(ENCODER_PARAMETER_PQ_CODE_SIZE, 2) - .field(ENCODER_PARAMETER_PQ_M, 2) - .endObject() - .endObject() - .endObject() - .endObject(); + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, "ivf") + .field(KNN_ENGINE, "faiss") + .field(METHOD_PARAMETER_SPACE_TYPE, "innerproduct") + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_NLIST, 128) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, "pq") + .startObject(PARAMETERS) + .field(ENCODER_PARAMETER_PQ_CODE_SIZE, 2) + .field(ENCODER_PARAMETER_PQ_M, 2) + .endObject() + .endObject() + .endObject() + .endObject(); Map method = xContentBuilderToMap(builder); - Response trainResponse = trainModel(null, trainingIndexName, trainingFieldName, dimension, method, - "dummy description"); + Response trainResponse = trainModel(null, trainingIndexName, trainingFieldName, dimension, method, "dummy description"); assertEquals(RestStatus.OK, RestStatus.fromCode(trainResponse.getStatusLine().getStatusCode())); @@ -97,10 +97,7 @@ public void testTrainModel_fail_notEnoughData() throws IOException, InterruptedE String trainResponseBody = EntityUtils.toString(trainResponse.getEntity()); assertNotNull(trainResponseBody); - Map trainResponseMap = createParser( - XContentType.JSON.xContent(), - trainResponseBody - ).map(); + Map trainResponseMap = createParser(XContentType.JSON.xContent(), trainResponseBody).map(); String modelId = (String) trainResponseMap.get(MODEL_ID); assertNotNull(modelId); @@ -109,10 +106,7 @@ public void testTrainModel_fail_notEnoughData() throws IOException, InterruptedE String responseBody = EntityUtils.toString(getResponse.getEntity()); assertNotNull(responseBody); - Map responseMap = createParser( - XContentType.JSON.xContent(), - responseBody - ).map(); + Map responseMap = createParser(XContentType.JSON.xContent(), responseBody).map(); assertEquals(modelId, responseMap.get(MODEL_ID)); @@ -157,25 +151,25 @@ public void testTrainModel_fail_tooMuchData() throws Exception { } } */ - XContentBuilder builder = XContentFactory.jsonBuilder().startObject() - .field(NAME, "ivf") - .field(KNN_ENGINE, "faiss") - .field(METHOD_PARAMETER_SPACE_TYPE, "innerproduct") - .startObject(PARAMETERS) - .field(METHOD_PARAMETER_NLIST, 128) - .startObject(METHOD_ENCODER_PARAMETER) - .field(NAME, "pq") - .startObject(PARAMETERS) - .field(ENCODER_PARAMETER_PQ_CODE_SIZE, 2) - .field(ENCODER_PARAMETER_PQ_M, 2) - .endObject() - .endObject() - .endObject() - .endObject(); + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, "ivf") + .field(KNN_ENGINE, "faiss") + .field(METHOD_PARAMETER_SPACE_TYPE, "innerproduct") + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_NLIST, 128) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, "pq") + .startObject(PARAMETERS) + .field(ENCODER_PARAMETER_PQ_CODE_SIZE, 2) + .field(ENCODER_PARAMETER_PQ_M, 2) + .endObject() + .endObject() + .endObject() + .endObject(); Map method = xContentBuilderToMap(builder); - Response trainResponse = trainModel(null, trainingIndexName, trainingFieldName, dimension, method, - "dummy description"); + Response trainResponse = trainModel(null, trainingIndexName, trainingFieldName, dimension, method, "dummy description"); assertEquals(RestStatus.OK, RestStatus.fromCode(trainResponse.getStatusLine().getStatusCode())); @@ -183,10 +177,7 @@ public void testTrainModel_fail_tooMuchData() throws Exception { String trainResponseBody = EntityUtils.toString(trainResponse.getEntity()); assertNotNull(trainResponseBody); - Map trainResponseMap = createParser( - XContentType.JSON.xContent(), - trainResponseBody - ).map(); + Map trainResponseMap = createParser(XContentType.JSON.xContent(), trainResponseBody).map(); String modelId = (String) trainResponseMap.get(MODEL_ID); assertNotNull(modelId); @@ -195,10 +186,7 @@ public void testTrainModel_fail_tooMuchData() throws Exception { String responseBody = EntityUtils.toString(getResponse.getEntity()); assertNotNull(responseBody); - Map responseMap = createParser( - XContentType.JSON.xContent(), - responseBody - ).map(); + Map responseMap = createParser(XContentType.JSON.xContent(), responseBody).map(); assertEquals(modelId, responseMap.get(MODEL_ID)); @@ -242,25 +230,25 @@ public void testTrainModel_success_withId() throws IOException, InterruptedExcep } } */ - XContentBuilder builder = XContentFactory.jsonBuilder().startObject() - .field(NAME, "ivf") - .field(KNN_ENGINE, "faiss") - .field(METHOD_PARAMETER_SPACE_TYPE, "l2") - .startObject(PARAMETERS) - .field(METHOD_PARAMETER_NLIST, 1) - .startObject(METHOD_ENCODER_PARAMETER) - .field(NAME, "pq") - .startObject(PARAMETERS) - .field(ENCODER_PARAMETER_PQ_CODE_SIZE, 2) - .field(ENCODER_PARAMETER_PQ_M, 2) - .endObject() - .endObject() - .endObject() - .endObject(); + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, "ivf") + .field(KNN_ENGINE, "faiss") + .field(METHOD_PARAMETER_SPACE_TYPE, "l2") + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_NLIST, 1) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, "pq") + .startObject(PARAMETERS) + .field(ENCODER_PARAMETER_PQ_CODE_SIZE, 2) + .field(ENCODER_PARAMETER_PQ_M, 2) + .endObject() + .endObject() + .endObject() + .endObject(); Map method = xContentBuilderToMap(builder); - Response trainResponse = trainModel(modelId, trainingIndexName, trainingFieldName, dimension, method, - "dummy description"); + Response trainResponse = trainModel(modelId, trainingIndexName, trainingFieldName, dimension, method, "dummy description"); assertEquals(RestStatus.OK, RestStatus.fromCode(trainResponse.getStatusLine().getStatusCode())); @@ -269,10 +257,7 @@ public void testTrainModel_success_withId() throws IOException, InterruptedExcep String responseBody = EntityUtils.toString(getResponse.getEntity()); assertNotNull(responseBody); - Map responseMap = createParser( - XContentType.JSON.xContent(), - responseBody - ).map(); + Map responseMap = createParser(XContentType.JSON.xContent(), responseBody).map(); assertEquals(modelId, responseMap.get(MODEL_ID)); @@ -316,25 +301,25 @@ public void testTrainModel_success_noId() throws IOException, InterruptedExcepti } } */ - XContentBuilder builder = XContentFactory.jsonBuilder().startObject() - .field(NAME, "ivf") - .field(KNN_ENGINE, "faiss") - .field(METHOD_PARAMETER_SPACE_TYPE, "innerproduct") - .startObject(PARAMETERS) - .field(METHOD_PARAMETER_NLIST, 2) - .startObject(METHOD_ENCODER_PARAMETER) - .field(NAME, "pq") - .startObject(PARAMETERS) - .field(ENCODER_PARAMETER_PQ_CODE_SIZE, 2) - .field(ENCODER_PARAMETER_PQ_M, 2) - .endObject() - .endObject() - .endObject() - .endObject(); + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, "ivf") + .field(KNN_ENGINE, "faiss") + .field(METHOD_PARAMETER_SPACE_TYPE, "innerproduct") + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_NLIST, 2) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, "pq") + .startObject(PARAMETERS) + .field(ENCODER_PARAMETER_PQ_CODE_SIZE, 2) + .field(ENCODER_PARAMETER_PQ_M, 2) + .endObject() + .endObject() + .endObject() + .endObject(); Map method = xContentBuilderToMap(builder); - Response trainResponse = trainModel(null, trainingIndexName, trainingFieldName, dimension, method, - "dummy description"); + Response trainResponse = trainModel(null, trainingIndexName, trainingFieldName, dimension, method, "dummy description"); assertEquals(RestStatus.OK, RestStatus.fromCode(trainResponse.getStatusLine().getStatusCode())); @@ -342,10 +327,7 @@ public void testTrainModel_success_noId() throws IOException, InterruptedExcepti String trainResponseBody = EntityUtils.toString(trainResponse.getEntity()); assertNotNull(trainResponseBody); - Map trainResponseMap = createParser( - XContentType.JSON.xContent(), - trainResponseBody - ).map(); + Map trainResponseMap = createParser(XContentType.JSON.xContent(), trainResponseBody).map(); String modelId = (String) trainResponseMap.get(MODEL_ID); assertNotNull(modelId); @@ -354,10 +336,7 @@ public void testTrainModel_success_noId() throws IOException, InterruptedExcepti String responseBody = EntityUtils.toString(getResponse.getEntity()); assertNotNull(responseBody); - Map responseMap = createParser( - XContentType.JSON.xContent(), - responseBody - ).map(); + Map responseMap = createParser(XContentType.JSON.xContent(), responseBody).map(); assertEquals(modelId, responseMap.get(MODEL_ID)); diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceFactoryTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceFactoryTests.java index 01908c645..70dbee248 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceFactoryTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceFactoryTests.java @@ -18,25 +18,41 @@ public class KNNScoringSpaceFactoryTests extends KNNTestCase { public void testValidSpaces() { - KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = - mock(KNNVectorFieldMapper.KNNVectorFieldType.class); - NumberFieldMapper.NumberFieldType numberFieldType = new NumberFieldMapper.NumberFieldType("field", - NumberFieldMapper.NumberType.LONG); + KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + NumberFieldMapper.NumberFieldType numberFieldType = new NumberFieldMapper.NumberFieldType( + "field", + NumberFieldMapper.NumberType.LONG + ); List floatQueryObject = new ArrayList<>(); Long longQueryObject = 0L; - assertTrue(KNNScoringSpaceFactory.create(SpaceType.L2.getValue(), floatQueryObject, knnVectorFieldType) - instanceof KNNScoringSpace.L2); - assertTrue(KNNScoringSpaceFactory.create(SpaceType.COSINESIMIL.getValue(), floatQueryObject, knnVectorFieldType) - instanceof KNNScoringSpace.CosineSimilarity); - assertTrue(KNNScoringSpaceFactory.create(SpaceType.INNER_PRODUCT.getValue(), floatQueryObject, knnVectorFieldType) - instanceof KNNScoringSpace.InnerProd); - assertTrue(KNNScoringSpaceFactory.create(SpaceType.HAMMING_BIT.getValue(), longQueryObject, numberFieldType) - instanceof KNNScoringSpace.HammingBit); + assertTrue( + KNNScoringSpaceFactory.create(SpaceType.L2.getValue(), floatQueryObject, knnVectorFieldType) instanceof KNNScoringSpace.L2 + ); + assertTrue( + KNNScoringSpaceFactory.create( + SpaceType.COSINESIMIL.getValue(), + floatQueryObject, + knnVectorFieldType + ) instanceof KNNScoringSpace.CosineSimilarity + ); + assertTrue( + KNNScoringSpaceFactory.create( + SpaceType.INNER_PRODUCT.getValue(), + floatQueryObject, + knnVectorFieldType + ) instanceof KNNScoringSpace.InnerProd + ); + assertTrue( + KNNScoringSpaceFactory.create( + SpaceType.HAMMING_BIT.getValue(), + longQueryObject, + numberFieldType + ) instanceof KNNScoringSpace.HammingBit + ); } public void testInvalidSpace() { - expectThrows(IllegalArgumentException.class, () -> KNNScoringSpaceFactory.create(SpaceType.L2.getValue(), - null, null)); + expectThrows(IllegalArgumentException.class, () -> KNNScoringSpaceFactory.create(SpaceType.L2.getValue(), null, null)); } } diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java index c9cdba7b8..090581970 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java @@ -23,86 +23,80 @@ public class KNNScoringSpaceTests extends KNNTestCase { public void testL2() { - float[] arrayFloat = new float[]{1.0f, 2.0f, 3.0f}; + float[] arrayFloat = new float[] { 1.0f, 2.0f, 3.0f }; List arrayListQueryObject = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0)); - KNNVectorFieldMapper.KNNVectorFieldType fieldType = new KNNVectorFieldMapper.KNNVectorFieldType("test", - Collections.emptyMap(), 3); + KNNVectorFieldMapper.KNNVectorFieldType fieldType = new KNNVectorFieldMapper.KNNVectorFieldType("test", Collections.emptyMap(), 3); KNNScoringSpace.L2 l2 = new KNNScoringSpace.L2(arrayListQueryObject, fieldType); assertEquals(1F, l2.scoringMethod.apply(arrayFloat, arrayFloat), 0.1F); - NumberFieldMapper.NumberFieldType invalidFieldType = new NumberFieldMapper.NumberFieldType("field", - NumberFieldMapper.NumberType.INTEGER); - expectThrows(IllegalArgumentException.class, () -> - new KNNScoringSpace.L2(arrayListQueryObject, invalidFieldType)); + NumberFieldMapper.NumberFieldType invalidFieldType = new NumberFieldMapper.NumberFieldType( + "field", + NumberFieldMapper.NumberType.INTEGER + ); + expectThrows(IllegalArgumentException.class, () -> new KNNScoringSpace.L2(arrayListQueryObject, invalidFieldType)); } public void testCosineSimilarity() { - float[] arrayFloat = new float[]{1.0f, 2.0f, 3.0f}; + float[] arrayFloat = new float[] { 1.0f, 2.0f, 3.0f }; List arrayListQueryObject = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0)); - float[] arrayFloat2 = new float[]{2.0f, 4.0f, 6.0f}; + float[] arrayFloat2 = new float[] { 2.0f, 4.0f, 6.0f }; - KNNVectorFieldMapper.KNNVectorFieldType fieldType = new KNNVectorFieldMapper.KNNVectorFieldType("test", - Collections.emptyMap(), 3); - KNNScoringSpace.CosineSimilarity cosineSimilarity = - new KNNScoringSpace.CosineSimilarity(arrayListQueryObject, fieldType); + KNNVectorFieldMapper.KNNVectorFieldType fieldType = new KNNVectorFieldMapper.KNNVectorFieldType("test", Collections.emptyMap(), 3); + KNNScoringSpace.CosineSimilarity cosineSimilarity = new KNNScoringSpace.CosineSimilarity(arrayListQueryObject, fieldType); assertEquals(3F, cosineSimilarity.scoringMethod.apply(arrayFloat2, arrayFloat), 0.1F); - NumberFieldMapper.NumberFieldType invalidFieldType = new NumberFieldMapper.NumberFieldType("field", - NumberFieldMapper.NumberType.INTEGER); - expectThrows(IllegalArgumentException.class, () -> - new KNNScoringSpace.CosineSimilarity(arrayListQueryObject, invalidFieldType)); + NumberFieldMapper.NumberFieldType invalidFieldType = new NumberFieldMapper.NumberFieldType( + "field", + NumberFieldMapper.NumberType.INTEGER + ); + expectThrows(IllegalArgumentException.class, () -> new KNNScoringSpace.CosineSimilarity(arrayListQueryObject, invalidFieldType)); } public void testInnerProdSimilarity() { - float[] arrayFloat_case1 = new float[]{1.0f, 2.0f, 3.0f}; + float[] arrayFloat_case1 = new float[] { 1.0f, 2.0f, 3.0f }; List arrayListQueryObject_case1 = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0)); - float[] arrayFloat2_case1 = new float[]{1.0f, 1.0f, 1.0f}; + float[] arrayFloat2_case1 = new float[] { 1.0f, 1.0f, 1.0f }; - KNNVectorFieldMapper.KNNVectorFieldType fieldType = new KNNVectorFieldMapper.KNNVectorFieldType("test", - Collections.emptyMap(), 3); - KNNScoringSpace.InnerProd innerProd = - new KNNScoringSpace.InnerProd(arrayListQueryObject_case1, fieldType); + KNNVectorFieldMapper.KNNVectorFieldType fieldType = new KNNVectorFieldMapper.KNNVectorFieldType("test", Collections.emptyMap(), 3); + KNNScoringSpace.InnerProd innerProd = new KNNScoringSpace.InnerProd(arrayListQueryObject_case1, fieldType); assertEquals(7.0F, innerProd.scoringMethod.apply(arrayFloat_case1, arrayFloat2_case1), 0.001F); - float[] arrayFloat_case2 = new float[]{100_000.0f, 200_000.0f, 300_000.0f}; + float[] arrayFloat_case2 = new float[] { 100_000.0f, 200_000.0f, 300_000.0f }; List arrayListQueryObject_case2 = new ArrayList<>(Arrays.asList(100_000.0, 200_000.0, 300_000.0)); - float[] arrayFloat2_case2 = new float[]{-100_000.0f, -200_000.0f, -300_000.0f}; + float[] arrayFloat2_case2 = new float[] { -100_000.0f, -200_000.0f, -300_000.0f }; innerProd = new KNNScoringSpace.InnerProd(arrayListQueryObject_case2, fieldType); - assertEquals(7.142857143E-12F, innerProd.scoringMethod.apply(arrayFloat_case2, arrayFloat2_case2), - 1.0E-11F); + assertEquals(7.142857143E-12F, innerProd.scoringMethod.apply(arrayFloat_case2, arrayFloat2_case2), 1.0E-11F); - float[] arrayFloat_case3 = new float[]{100_000.0f, 200_000.0f, 300_000.0f}; + float[] arrayFloat_case3 = new float[] { 100_000.0f, 200_000.0f, 300_000.0f }; List arrayListQueryObject_case3 = new ArrayList<>(Arrays.asList(100_000.0, 200_000.0, 300_000.0)); - float[] arrayFloat2_case3 = new float[]{100_000.0f, 200_000.0f, 300_000.0f}; + float[] arrayFloat2_case3 = new float[] { 100_000.0f, 200_000.0f, 300_000.0f }; innerProd = new KNNScoringSpace.InnerProd(arrayListQueryObject_case3, fieldType); assertEquals(140_000_000_001F, innerProd.scoringMethod.apply(arrayFloat_case3, arrayFloat2_case3), 0.01F); - NumberFieldMapper.NumberFieldType invalidFieldType = new NumberFieldMapper.NumberFieldType("field", - NumberFieldMapper.NumberType.INTEGER); - expectThrows(IllegalArgumentException.class, () -> - new KNNScoringSpace.InnerProd(arrayListQueryObject_case2, invalidFieldType)); + NumberFieldMapper.NumberFieldType invalidFieldType = new NumberFieldMapper.NumberFieldType( + "field", + NumberFieldMapper.NumberType.INTEGER + ); + expectThrows(IllegalArgumentException.class, () -> new KNNScoringSpace.InnerProd(arrayListQueryObject_case2, invalidFieldType)); } @SuppressWarnings("unchecked") public void testHammingBit_Long() { - NumberFieldMapper.NumberFieldType fieldType = new NumberFieldMapper.NumberFieldType("field", - NumberFieldMapper.NumberType.LONG); + NumberFieldMapper.NumberFieldType fieldType = new NumberFieldMapper.NumberFieldType("field", NumberFieldMapper.NumberType.LONG); Long longObject1 = 1234L; // ..._0000_0100_1101_0010 Long longObject2 = 2468L; // ..._0000_1001_1010_0100 KNNScoringSpace.HammingBit hammingBit = new KNNScoringSpace.HammingBit(longObject1, fieldType); - assertEquals(0.1111F, - ((BiFunction)hammingBit.scoringMethod).apply(longObject1, longObject2), 0.1F); + assertEquals(0.1111F, ((BiFunction) hammingBit.scoringMethod).apply(longObject1, longObject2), 0.1F); KNNVectorFieldMapper.KNNVectorFieldType invalidFieldType = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); - expectThrows(IllegalArgumentException.class, () -> - new KNNScoringSpace.HammingBit(longObject1, invalidFieldType)); + expectThrows(IllegalArgumentException.class, () -> new KNNScoringSpace.HammingBit(longObject1, invalidFieldType)); } @SuppressWarnings("unchecked") @@ -120,14 +114,16 @@ public void testHammingBit_Base64() { float expectedResult = 1F / (1 + 16); KNNScoringSpace.HammingBit hammingBit = new KNNScoringSpace.HammingBit(base64Object1, fieldType); - assertEquals(expectedResult, - ((BiFunction)hammingBit.scoringMethod).apply( - new BigInteger(Base64.getDecoder().decode(base64Object1)), - new BigInteger(Base64.getDecoder().decode(base64Object2)) - ), 0.1F); + assertEquals( + expectedResult, + ((BiFunction) hammingBit.scoringMethod).apply( + new BigInteger(Base64.getDecoder().decode(base64Object1)), + new BigInteger(Base64.getDecoder().decode(base64Object2)) + ), + 0.1F + ); KNNVectorFieldMapper.KNNVectorFieldType invalidFieldType = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); - expectThrows(IllegalArgumentException.class, () -> - new KNNScoringSpace.HammingBit(base64Object1, invalidFieldType)); + expectThrows(IllegalArgumentException.class, () -> new KNNScoringSpace.HammingBit(base64Object1, invalidFieldType)); } } diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java index 7323f958b..789432ec8 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java @@ -20,15 +20,16 @@ public class KNNScoringSpaceUtilTests extends KNNTestCase { public void testFieldTypeCheck() { - assertTrue(KNNScoringSpaceUtil.isLongFieldType(new NumberFieldMapper.NumberFieldType("field", - NumberFieldMapper.NumberType.LONG))); - assertFalse(KNNScoringSpaceUtil.isLongFieldType(new NumberFieldMapper.NumberFieldType("field", - NumberFieldMapper.NumberType.INTEGER))); + assertTrue(KNNScoringSpaceUtil.isLongFieldType(new NumberFieldMapper.NumberFieldType("field", NumberFieldMapper.NumberType.LONG))); + assertFalse( + KNNScoringSpaceUtil.isLongFieldType(new NumberFieldMapper.NumberFieldType("field", NumberFieldMapper.NumberType.INTEGER)) + ); assertFalse(KNNScoringSpaceUtil.isLongFieldType(new BinaryFieldMapper.BinaryFieldType("test"))); assertTrue(KNNScoringSpaceUtil.isBinaryFieldType(new BinaryFieldMapper.BinaryFieldType("test"))); - assertFalse(KNNScoringSpaceUtil.isBinaryFieldType(new NumberFieldMapper.NumberFieldType("field", - NumberFieldMapper.NumberType.INTEGER))); + assertFalse( + KNNScoringSpaceUtil.isBinaryFieldType(new NumberFieldMapper.NumberFieldType("field", NumberFieldMapper.NumberType.INTEGER)) + ); assertTrue(KNNScoringSpaceUtil.isKNNVectorFieldType(mock(KNNVectorFieldMapper.KNNVectorFieldType.class))); assertFalse(KNNScoringSpaceUtil.isKNNVectorFieldType(new BinaryFieldMapper.BinaryFieldType("test"))); @@ -57,7 +58,7 @@ public void testParseBinaryQuery() { } public void testParseKNNVectorQuery() { - float[] arrayFloat = new float[]{1.0f, 2.0f, 3.0f}; + float[] arrayFloat = new float[] { 1.0f, 2.0f, 3.0f }; List arrayListQueryObject = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0)); KNNVectorFieldMapper.KNNVectorFieldType fieldType = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java index 08cede77c..291924219 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java @@ -35,36 +35,35 @@ private List getTestQueryVector() { } public void testL2SquaredScoringFunction() { - float[] queryVector = {1.0f, 1.0f, 1.0f}; - float[] inputVector = {4.0f, 4.0f, 4.0f}; + float[] queryVector = { 1.0f, 1.0f, 1.0f }; + float[] inputVector = { 4.0f, 4.0f, 4.0f }; Float distance = KNNScoringUtil.l2Squared(queryVector, inputVector); assertTrue(distance == 27.0f); } public void testWrongDimensionL2SquaredScoringFunction() { - float[] queryVector = {1.0f, 1.0f}; - float[] inputVector = {4.0f, 4.0f, 4.0f}; + float[] queryVector = { 1.0f, 1.0f }; + float[] inputVector = { 4.0f, 4.0f, 4.0f }; expectThrows(IllegalArgumentException.class, () -> KNNScoringUtil.l2Squared(queryVector, inputVector)); } public void testCosineSimilScoringFunction() { - float[] queryVector = {1.0f, 1.0f, 1.0f}; - float[] inputVector = {4.0f, 4.0f, 4.0f}; + float[] queryVector = { 1.0f, 1.0f, 1.0f }; + float[] inputVector = { 4.0f, 4.0f, 4.0f }; float queryVectorMagnitude = KNNScoringSpaceUtil.getVectorMagnitudeSquared(queryVector); float inputVectorMagnitude = KNNScoringSpaceUtil.getVectorMagnitudeSquared(inputVector); float dotProduct = 12.0f; float expectedScore = (float) (dotProduct / (Math.sqrt(queryVectorMagnitude * inputVectorMagnitude))); - Float actualScore = KNNScoringUtil.cosinesimil(queryVector, inputVector); assertEquals(expectedScore, actualScore, 0.0001); } public void testCosineSimilOptimizedScoringFunction() { - float[] queryVector = {1.0f, 1.0f, 1.0f}; - float[] inputVector = {4.0f, 4.0f, 4.0f}; + float[] queryVector = { 1.0f, 1.0f, 1.0f }; + float[] inputVector = { 4.0f, 4.0f, 4.0f }; float queryVectorMagnitude = KNNScoringSpaceUtil.getVectorMagnitudeSquared(queryVector); float inputVectorMagnitude = KNNScoringSpaceUtil.getVectorMagnitudeSquared(inputVector); float dotProduct = 12.0f; @@ -86,26 +85,26 @@ public void testConvertInvalidVectorToPrimitive() { } public void testCosineSimilQueryVectorZeroMagnitude() { - float[] queryVector = {0, 0}; - float[] inputVector = {4.0f, 4.0f}; + float[] queryVector = { 0, 0 }; + float[] inputVector = { 4.0f, 4.0f }; assertEquals(0, KNNScoringUtil.cosinesimil(queryVector, inputVector), 0.00001); } public void testCosineSimilOptimizedQueryVectorZeroMagnitude() { - float[] inputVector = {4.0f, 4.0f}; - float[] queryVector = {0, 0}; + float[] inputVector = { 4.0f, 4.0f }; + float[] queryVector = { 0, 0 }; assertTrue(0 == KNNScoringUtil.cosinesimilOptimized(queryVector, inputVector, 0.0f)); } public void testWrongDimensionCosineSimilScoringFunction() { - float[] queryVector = {1.0f, 1.0f}; - float[] inputVector = {4.0f, 4.0f, 4.0f}; + float[] queryVector = { 1.0f, 1.0f }; + float[] inputVector = { 4.0f, 4.0f, 4.0f }; expectThrows(IllegalArgumentException.class, () -> KNNScoringUtil.cosinesimil(queryVector, inputVector)); } public void testWrongDimensionCosineSimilOPtimizedScoringFunction() { - float[] queryVector = {1.0f, 1.0f}; - float[] inputVector = {4.0f, 4.0f, 4.0f}; + float[] queryVector = { 1.0f, 1.0f }; + float[] inputVector = { 4.0f, 4.0f, 4.0f }; expectThrows(IllegalArgumentException.class, () -> KNNScoringUtil.cosinesimilOptimized(queryVector, inputVector, 1.0f)); } @@ -173,7 +172,7 @@ public void testBitHammingDistance_Long() { public void testL2SquaredWhitelistedScoringFunction() throws IOException { List queryVector = getTestQueryVector(); TestKNNScriptDocValues dataset = new TestKNNScriptDocValues(); - dataset.createKNNVectorDocument(new float[]{4.0f, 4.0f, 4.0f}, "test-index-field-name"); + dataset.createKNNVectorDocument(new float[] { 4.0f, 4.0f, 4.0f }, "test-index-field-name"); KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name"); scriptDocValues.setNextDocId(0); Float distance = KNNScoringUtil.l2Squared(queryVector, scriptDocValues); @@ -184,7 +183,7 @@ public void testL2SquaredWhitelistedScoringFunction() throws IOException { public void testScriptDocValuesFailsL2() throws IOException { List queryVector = getTestQueryVector(); TestKNNScriptDocValues dataset = new TestKNNScriptDocValues(); - dataset.createKNNVectorDocument(new float[]{4.0f, 4.0f, 4.0f}, "test-index-field-name"); + dataset.createKNNVectorDocument(new float[] { 4.0f, 4.0f, 4.0f }, "test-index-field-name"); KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name"); expectThrows(IllegalStateException.class, () -> KNNScoringUtil.l2Squared(queryVector, scriptDocValues)); dataset.close(); @@ -193,7 +192,7 @@ public void testScriptDocValuesFailsL2() throws IOException { public void testCosineSimilarityScoringFunction() throws IOException { List queryVector = getTestQueryVector(); TestKNNScriptDocValues dataset = new TestKNNScriptDocValues(); - dataset.createKNNVectorDocument(new float[]{4.0f, 4.0f, 4.0f}, "test-index-field-name"); + dataset.createKNNVectorDocument(new float[] { 4.0f, 4.0f, 4.0f }, "test-index-field-name"); KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name"); scriptDocValues.setNextDocId(0); @@ -205,7 +204,7 @@ public void testCosineSimilarityScoringFunction() throws IOException { public void testScriptDocValuesFailsCosineSimilarity() throws IOException { List queryVector = getTestQueryVector(); TestKNNScriptDocValues dataset = new TestKNNScriptDocValues(); - dataset.createKNNVectorDocument(new float[]{4.0f, 4.0f, 4.0f}, "test-index-field-name"); + dataset.createKNNVectorDocument(new float[] { 4.0f, 4.0f, 4.0f }, "test-index-field-name"); KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name"); expectThrows(IllegalStateException.class, () -> KNNScoringUtil.cosineSimilarity(queryVector, scriptDocValues)); dataset.close(); @@ -214,7 +213,7 @@ public void testScriptDocValuesFailsCosineSimilarity() throws IOException { public void testCosineSimilarityOptimizedScoringFunction() throws IOException { List queryVector = getTestQueryVector(); TestKNNScriptDocValues dataset = new TestKNNScriptDocValues(); - dataset.createKNNVectorDocument(new float[]{4.0f, 4.0f, 4.0f}, "test-index-field-name"); + dataset.createKNNVectorDocument(new float[] { 4.0f, 4.0f, 4.0f }, "test-index-field-name"); KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name"); scriptDocValues.setNextDocId(0); Float actualScore = KNNScoringUtil.cosineSimilarity(queryVector, scriptDocValues, 3.0f); @@ -225,7 +224,7 @@ public void testCosineSimilarityOptimizedScoringFunction() throws IOException { public void testScriptDocValuesFailsCosineSimilarityOptimized() throws IOException { List queryVector = getTestQueryVector(); TestKNNScriptDocValues dataset = new TestKNNScriptDocValues(); - dataset.createKNNVectorDocument(new float[]{4.0f, 4.0f, 4.0f}, "test-index-field-name"); + dataset.createKNNVectorDocument(new float[] { 4.0f, 4.0f, 4.0f }, "test-index-field-name"); KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name"); expectThrows(IllegalStateException.class, () -> KNNScoringUtil.cosineSimilarity(queryVector, scriptDocValues, 3.0f)); dataset.close(); @@ -244,16 +243,14 @@ public KNNVectorScriptDocValues getScriptDocValues(String fieldName) throws IOEx if (scriptDocValues == null) { reader = DirectoryReader.open(directory); LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - scriptDocValues = new KNNVectorScriptDocValues(leafReaderContext.reader().getBinaryDocValues(fieldName),fieldName ); + scriptDocValues = new KNNVectorScriptDocValues(leafReaderContext.reader().getBinaryDocValues(fieldName), fieldName); } return scriptDocValues; } public void close() throws IOException { - if (reader != null) - reader.close(); - if (directory != null) - directory.close(); + if (reader != null) reader.close(); + if (directory != null) directory.close(); } public void createKNNVectorDocument(final float[] content, final String fieldName) throws IOException { @@ -261,10 +258,7 @@ public void createKNNVectorDocument(final float[] content, final String fieldNam IndexWriter writer = new IndexWriter(directory, conf); conf.setMergePolicy(NoMergePolicy.INSTANCE); // prevent merges for this test Document knnDocument = new Document(); - knnDocument.add( - new BinaryDocValuesField( - fieldName, - new VectorField(fieldName, content, new FieldType()).binaryValue())); + knnDocument.add(new BinaryDocValuesField(fieldName, new VectorField(fieldName, content, new FieldType()).binaryValue())); writer.addDocument(knnDocument); writer.commit(); writer.close(); diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java index e888fcb37..16df35921 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java @@ -40,19 +40,18 @@ public void testKNNL2ScriptScore() throws Exception { * Create knn index and populate data */ createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); - Float[] f1 = {6.0f, 6.0f}; + Float[] f1 = { 6.0f, 6.0f }; addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1); - Float[] f2 = {2.0f, 2.0f}; + Float[] f2 = { 2.0f, 2.0f }; addKnnDoc(INDEX_NAME, "2", FIELD_NAME, f2); - Float[] f3 = {4.0f, 4.0f}; + Float[] f3 = { 4.0f, 4.0f }; addKnnDoc(INDEX_NAME, "3", FIELD_NAME, f3); - Float[] f4 = {3.0f, 3.0f}; + Float[] f4 = { 3.0f, 3.0f }; addKnnDoc(INDEX_NAME, "4", FIELD_NAME, f4); - /** * Construct Search Request */ @@ -64,20 +63,19 @@ public void testKNNL2ScriptScore() throws Exception { * "vector": [2.0, 2.0] * } */ - float[] queryVector = {1.0f, 1.0f}; + float[] queryVector = { 1.0f, 1.0f }; params.put("field", FIELD_NAME); params.put("query_value", queryVector); params.put("space_type", SpaceType.L2.getValue()); Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); List expectedDocids = Arrays.asList("2", "4", "3", "1"); List actualDocids = new ArrayList<>(); - for(KNNResult result : results) { + for (KNNResult result : results) { actualDocids.add(result.getDocId()); } @@ -95,19 +93,18 @@ public void testKNNL1ScriptScore() throws Exception { * Create knn index and populate data */ createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); - Float[] f1 = {6.0f, 6.0f}; + Float[] f1 = { 6.0f, 6.0f }; addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1); - Float[] f2 = {4.0f, 1.0f}; + Float[] f2 = { 4.0f, 1.0f }; addKnnDoc(INDEX_NAME, "2", FIELD_NAME, f2); - Float[] f3 = {3.0f, 3.0f}; + Float[] f3 = { 3.0f, 3.0f }; addKnnDoc(INDEX_NAME, "3", FIELD_NAME, f3); - Float[] f4 = {5.0f, 5.0f}; + Float[] f4 = { 5.0f, 5.0f }; addKnnDoc(INDEX_NAME, "4", FIELD_NAME, f4); - /** * Construct Search Request */ @@ -119,20 +116,19 @@ public void testKNNL1ScriptScore() throws Exception { * "vector": [1.0, 1.0] * } */ - float[] queryVector = {1.0f, 1.0f}; + float[] queryVector = { 1.0f, 1.0f }; params.put("field", FIELD_NAME); params.put("query_value", queryVector); params.put("space_type", SpaceType.L1); Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); List expectedDocids = Arrays.asList("2", "4", "3", "1"); List actualDocids = new ArrayList<>(); - for(KNNResult result : results) { + for (KNNResult result : results) { actualDocids.add(result.getDocId()); } @@ -150,19 +146,18 @@ public void testKNNLInfScriptScore() throws Exception { * Create knn index and populate data */ createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); - Float[] f1 = {6.0f, 6.0f}; + Float[] f1 = { 6.0f, 6.0f }; addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1); - Float[] f2 = {4.0f, 1.0f}; + Float[] f2 = { 4.0f, 1.0f }; addKnnDoc(INDEX_NAME, "2", FIELD_NAME, f2); - Float[] f3 = {3.0f, 3.0f}; + Float[] f3 = { 3.0f, 3.0f }; addKnnDoc(INDEX_NAME, "3", FIELD_NAME, f3); - Float[] f4 = {5.0f, 5.0f}; + Float[] f4 = { 5.0f, 5.0f }; addKnnDoc(INDEX_NAME, "4", FIELD_NAME, f4); - /** * Construct Search Request */ @@ -174,20 +169,19 @@ public void testKNNLInfScriptScore() throws Exception { * "vector": [1.0, 1.0] * } */ - float[] queryVector = {1.0f, 1.0f}; + float[] queryVector = { 1.0f, 1.0f }; params.put("field", FIELD_NAME); params.put("query_value", queryVector); params.put("space_type", SpaceType.LINF.getValue()); Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); List expectedDocids = Arrays.asList("3", "2", "4", "1"); List actualDocids = new ArrayList<>(); - for(KNNResult result : results) { + for (KNNResult result : results) { actualDocids.add(result.getDocId()); } @@ -205,13 +199,13 @@ public void testKNNCosineScriptScore() throws Exception { * Create knn index and populate data */ createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); - Float[] f1 = {1.0f, -1.0f}; + Float[] f1 = { 1.0f, -1.0f }; addKnnDoc(INDEX_NAME, "0", FIELD_NAME, f1); - Float[] f2 = {1.0f, 0.0f}; + Float[] f2 = { 1.0f, 0.0f }; addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f2); - Float[] f3 = {1.0f, 1.0f}; + Float[] f3 = { 1.0f, 1.0f }; addKnnDoc(INDEX_NAME, "2", FIELD_NAME, f3); /** @@ -228,20 +222,19 @@ public void testKNNCosineScriptScore() throws Exception { * * */ - float[] queryVector = {2.0f, -2.0f}; + float[] queryVector = { 2.0f, -2.0f }; params.put("field", FIELD_NAME); params.put("query_value", queryVector); params.put("space_type", SpaceType.COSINESIMIL.getValue()); Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); List expectedDocids = Arrays.asList("0", "1", "2"); List actualDocids = new ArrayList<>(); - for(KNNResult result : results) { + for (KNNResult result : results) { actualDocids.add(result.getDocId()); } @@ -271,7 +264,7 @@ public void testKNNInvalidSourceScript() throws Exception { * "space_type": "cosinesimil" * } */ - float[] queryVector = {2.0f, -2.0f}; + float[] queryVector = { 2.0f, -2.0f }; params.put("field", FIELD_NAME); params.put("query_value", queryVector); params.put("space_type", SpaceType.COSINESIMIL.getValue()); @@ -287,15 +280,11 @@ public void testKNNInvalidSourceScript() throws Exception { builder.endObject(); builder.endObject(); builder.endObject(); - Request request = new Request( - "POST", - "/" + INDEX_NAME + "/_search" - ); + Request request = new Request("POST", "/" + INDEX_NAME + "/_search"); request.setJsonEntity(Strings.toString(builder)); - ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request)); - assertThat(EntityUtils.toString(ex.getResponse().getEntity()), - containsString("Unknown script name Dummy_source")); + ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request)); + assertThat(EntityUtils.toString(ex.getResponse().getEntity()), containsString("Unknown script name Dummy_source")); } public void testInvalidSpace() throws Exception { @@ -310,14 +299,16 @@ public void testInvalidSpace() throws Exception { */ QueryBuilder qb = new MatchAllQueryBuilder(); Map params = new HashMap<>(); - float[] queryVector = {2.0f, -2.0f}; + float[] queryVector = { 2.0f, -2.0f }; params.put("field", FIELD_NAME); params.put("query_value", queryVector); params.put("space_type", INVALID_SPACE); Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params); - ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request)); - assertThat(EntityUtils.toString(ex.getResponse().getEntity()), - containsString("Invalid space type. Please refer to the available space types")); + ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request)); + assertThat( + EntityUtils.toString(ex.getResponse().getEntity()), + containsString("Invalid space type. Please refer to the available space types") + ); } public void testMissingParamsInScript() throws Exception { @@ -331,29 +322,26 @@ public void testMissingParamsInScript() throws Exception { */ QueryBuilder qb = new MatchAllQueryBuilder(); Map params = new HashMap<>(); - float[] queryVector = {2.0f, -2.0f}; + float[] queryVector = { 2.0f, -2.0f }; params.put("query_value", queryVector); params.put("space_type", SpaceType.COSINESIMIL.getValue()); Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params); - ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request)); - assertThat(EntityUtils.toString(ex.getResponse().getEntity()), - containsString("Missing parameter [field]")); + ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request)); + assertThat(EntityUtils.toString(ex.getResponse().getEntity()), containsString("Missing parameter [field]")); // Remove query vector parameter params.put("field", FIELD_NAME); params.remove("query_value"); Request vector_request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params); - ex = expectThrows(ResponseException.class, () -> client().performRequest(vector_request)); - assertThat(EntityUtils.toString(ex.getResponse().getEntity()), - containsString("Missing parameter [query_value]")); + ex = expectThrows(ResponseException.class, () -> client().performRequest(vector_request)); + assertThat(EntityUtils.toString(ex.getResponse().getEntity()), containsString("Missing parameter [query_value]")); // Remove space parameter params.put("query_value", queryVector); params.remove("space_type"); Request space_request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params); - ex = expectThrows(ResponseException.class, () -> client().performRequest(space_request)); - assertThat(EntityUtils.toString(ex.getResponse().getEntity()), - containsString("Missing parameter [space_type]")); + ex = expectThrows(ResponseException.class, () -> client().performRequest(space_request)); + assertThat(EntityUtils.toString(ex.getResponse().getEntity()), containsString("Missing parameter [space_type]")); } public void testUnequalDimensions() throws Exception { @@ -361,7 +349,7 @@ public void testUnequalDimensions() throws Exception { * Create knn index and populate data */ createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); - Float[] f1 = {1.0f, -1.0f}; + Float[] f1 = { 1.0f, -1.0f }; addKnnDoc(INDEX_NAME, "0", FIELD_NAME, f1); /** @@ -369,12 +357,12 @@ public void testUnequalDimensions() throws Exception { */ QueryBuilder qb = new MatchAllQueryBuilder(); Map params = new HashMap<>(); - float[] queryVector = {2.0f, -2.0f, -2.0f}; // query dimension and field dimension mismatch + float[] queryVector = { 2.0f, -2.0f, -2.0f }; // query dimension and field dimension mismatch params.put("field", FIELD_NAME); params.put("query_value", queryVector); params.put("space_type", SpaceType.COSINESIMIL.getValue()); Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params); - ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request)); + ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request)); assertThat(EntityUtils.toString(ex.getResponse().getEntity()), containsString("does not match")); } @@ -384,7 +372,7 @@ public void testKNNScoreforNonVectorDocument() throws Exception { * Create knn index and populate data */ createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); - Float[] f1 = {1.0f, 1.0f}; + Float[] f1 = { 1.0f, 1.0f }; addDocWithNumericField(INDEX_NAME, "0", "price", 10); addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1); forceMergeKnnIndex(INDEX_NAME); @@ -393,32 +381,31 @@ public void testKNNScoreforNonVectorDocument() throws Exception { */ QueryBuilder qb = new MatchAllQueryBuilder(); Map params = new HashMap<>(); - float[] queryVector = {2.0f, 2.0f}; // query dimension and field dimension mismatch + float[] queryVector = { 2.0f, 2.0f }; // query dimension and field dimension mismatch params.put("field", FIELD_NAME); params.put("query_value", queryVector); params.put("space_type", SpaceType.L2.getValue()); Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); String responseBody = EntityUtils.toString(response.getEntity()); - List hits = (List) ((Map)createParser(XContentType.JSON.xContent(), - responseBody).map().get("hits")).get("hits"); + List hits = (List) ((Map) createParser(XContentType.JSON.xContent(), responseBody).map() + .get("hits")).get("hits"); - List docIds = hits.stream().map(hit -> { - String id = ((String)((Map)hit).get("_id")); + List docIds = hits.stream().map(hit -> { + String id = ((String) ((Map) hit).get("_id")); return id; }).collect(Collectors.toList()); - //assert document order + // assert document order assertEquals("1", docIds.get(0)); assertEquals("0", docIds.get(1)); - List scores = hits.stream().map(hit -> { - Double score = ((Double)((Map)hit).get("_score")); + List scores = hits.stream().map(hit -> { + Double score = ((Double) ((Map) hit).get("_score")); return score; }).collect(Collectors.toList()); - //assert scores + // assert scores assertEquals(0.33333, scores.get(0), 0.001); assertEquals(Float.MIN_VALUE, scores.get(1), 0.001); } @@ -426,13 +413,16 @@ public void testKNNScoreforNonVectorDocument() throws Exception { @SuppressWarnings("unchecked") public void testHammingScriptScore_Long() throws Exception { createIndex(INDEX_NAME, Settings.EMPTY); - String longMapping = Strings.toString(XContentFactory.jsonBuilder().startObject() + String longMapping = Strings.toString( + XContentFactory.jsonBuilder() + .startObject() .startObject("properties") .startObject(FIELD_NAME) .field("type", "long") .endObject() .endObject() - .endObject()); + .endObject() + ); putMappingRequest(INDEX_NAME, longMapping); addDocWithNumericField(INDEX_NAME, "0", FIELD_NAME, 8L); @@ -463,18 +453,17 @@ public void testHammingScriptScore_Long() throws Exception { params1.put("space_type", SpaceType.HAMMING_BIT.getValue()); Request request1 = constructKNNScriptQueryRequest(INDEX_NAME, qb1, params1, 4); Response response1 = client().performRequest(request1); - assertEquals(request1.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response1.getStatusLine().getStatusCode())); + assertEquals(request1.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response1.getStatusLine().getStatusCode())); String responseBody1 = EntityUtils.toString(response1.getEntity()); - List hits1 = (List) ((Map)createParser(XContentType.JSON.xContent(), - responseBody1).map().get("hits")).get("hits"); + List hits1 = (List) ((Map) createParser(XContentType.JSON.xContent(), responseBody1).map() + .get("hits")).get("hits"); - List docIds1 = hits1.stream().map(hit -> - ((String)((Map)hit).get("_id"))).collect(Collectors.toList()); + List docIds1 = hits1.stream().map(hit -> ((String) ((Map) hit).get("_id"))).collect(Collectors.toList()); - List docScores1 = hits1.stream().map(hit -> - ((Double)((Map)hit).get("_score"))).collect(Collectors.toList()); + List docScores1 = hits1.stream() + .map(hit -> ((Double) ((Map) hit).get("_score"))) + .collect(Collectors.toList()); double[] scores1 = new double[docScores1.size()]; for (int i = 0; i < docScores1.size(); i++) { @@ -482,7 +471,7 @@ public void testHammingScriptScore_Long() throws Exception { } List correctIds1 = Arrays.asList("2", "0", "1", "3"); - double[] correctScores1 = new double[] {1.0/(1 + 3), 1.0/(1 + 9), 1.0/(1 + 9), 1.0/(1 + 30)}; + double[] correctScores1 = new double[] { 1.0 / (1 + 3), 1.0 / (1 + 9), 1.0 / (1 + 9), 1.0 / (1 + 30) }; assertEquals(4, correctIds1.size()); assertArrayEquals(correctIds1.toArray(), docIds1.toArray()); @@ -502,18 +491,17 @@ public void testHammingScriptScore_Long() throws Exception { params2.put("space_type", SpaceType.HAMMING_BIT.getValue()); Request request2 = constructKNNScriptQueryRequest(INDEX_NAME, qb2, params2, 4); Response response2 = client().performRequest(request2); - assertEquals(request2.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response2.getStatusLine().getStatusCode())); + assertEquals(request2.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response2.getStatusLine().getStatusCode())); String responseBody2 = EntityUtils.toString(response2.getEntity()); - List hits2 = (List) ((Map)createParser(XContentType.JSON.xContent(), - responseBody2).map().get("hits")).get("hits"); + List hits2 = (List) ((Map) createParser(XContentType.JSON.xContent(), responseBody2).map() + .get("hits")).get("hits"); - List docIds2 = hits2.stream().map(hit -> - ((String)((Map)hit).get("_id"))).collect(Collectors.toList()); + List docIds2 = hits2.stream().map(hit -> ((String) ((Map) hit).get("_id"))).collect(Collectors.toList()); - List docScores2 = hits2.stream().map(hit -> - ((Double)((Map)hit).get("_score"))).collect(Collectors.toList()); + List docScores2 = hits2.stream() + .map(hit -> ((Double) ((Map) hit).get("_score"))) + .collect(Collectors.toList()); double[] scores2 = new double[docScores2.size()]; for (int i = 0; i < docScores2.size(); i++) { @@ -521,7 +509,7 @@ public void testHammingScriptScore_Long() throws Exception { } List correctIds2 = Arrays.asList("0", "1", "2", "3"); - double[] correctScores2 = new double[] {1.0/(1 + 1), 1.0/(1 + 3), 1.0/(1 + 11), 1.0/(1 + 22)}; + double[] correctScores2 = new double[] { 1.0 / (1 + 1), 1.0 / (1 + 3), 1.0 / (1 + 11), 1.0 / (1 + 22) }; assertEquals(4, correctIds2.size()); assertArrayEquals(correctIds2.toArray(), docIds2.toArray()); @@ -529,16 +517,19 @@ public void testHammingScriptScore_Long() throws Exception { } @SuppressWarnings("unchecked") - public void testHammingScriptScore_Base64() throws Exception { + public void testHammingScriptScore_Base64() throws Exception { createIndex(INDEX_NAME, Settings.EMPTY); - String longMapping = Strings.toString(XContentFactory.jsonBuilder().startObject() + String longMapping = Strings.toString( + XContentFactory.jsonBuilder() + .startObject() .startObject("properties") .startObject(FIELD_NAME) .field("type", "binary") .field("doc_values", true) .endObject() .endObject() - .endObject()); + .endObject() + ); putMappingRequest(INDEX_NAME, longMapping); addDocWithBinaryField(INDEX_NAME, "0", FIELD_NAME, "AAAAAAAAAAk="); @@ -569,18 +560,17 @@ public void testHammingScriptScore_Base64() throws Exception { params1.put("space_type", SpaceType.HAMMING_BIT.getValue()); Request request1 = constructKNNScriptQueryRequest(INDEX_NAME, qb1, params1, 4); Response response1 = client().performRequest(request1); - assertEquals(request1.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response1.getStatusLine().getStatusCode())); + assertEquals(request1.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response1.getStatusLine().getStatusCode())); String responseBody1 = EntityUtils.toString(response1.getEntity()); - List hits1 = (List) ((Map)createParser(XContentType.JSON.xContent(), - responseBody1).map().get("hits")).get("hits"); + List hits1 = (List) ((Map) createParser(XContentType.JSON.xContent(), responseBody1).map() + .get("hits")).get("hits"); - List docIds1 = hits1.stream().map(hit -> - ((String)((Map)hit).get("_id"))).collect(Collectors.toList()); + List docIds1 = hits1.stream().map(hit -> ((String) ((Map) hit).get("_id"))).collect(Collectors.toList()); - List docScores1 = hits1.stream().map(hit -> - ((Double)((Map)hit).get("_score"))).collect(Collectors.toList()); + List docScores1 = hits1.stream() + .map(hit -> ((Double) ((Map) hit).get("_score"))) + .collect(Collectors.toList()); double[] scores1 = new double[docScores1.size()]; for (int i = 0; i < docScores1.size(); i++) { @@ -588,7 +578,7 @@ public void testHammingScriptScore_Base64() throws Exception { } List correctIds1 = Arrays.asList("2", "0", "1", "3"); - double[] correctScores1 = new double[] {1.0/(1 + 3), 1.0/(1 + 8), 1.0/(1 + 9), 1.0/(1 + 30)}; + double[] correctScores1 = new double[] { 1.0 / (1 + 3), 1.0 / (1 + 8), 1.0 / (1 + 9), 1.0 / (1 + 30) }; assertEquals(correctIds1.size(), docIds1.size()); assertArrayEquals(correctIds1.toArray(), docIds1.toArray()); @@ -608,18 +598,17 @@ public void testHammingScriptScore_Base64() throws Exception { params2.put("space_type", SpaceType.HAMMING_BIT.getValue()); Request request2 = constructKNNScriptQueryRequest(INDEX_NAME, qb2, params2, 4); Response response2 = client().performRequest(request2); - assertEquals(request2.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response2.getStatusLine().getStatusCode())); + assertEquals(request2.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response2.getStatusLine().getStatusCode())); String responseBody2 = EntityUtils.toString(response2.getEntity()); - List hits2 = (List) ((Map)createParser(XContentType.JSON.xContent(), - responseBody2).map().get("hits")).get("hits"); + List hits2 = (List) ((Map) createParser(XContentType.JSON.xContent(), responseBody2).map() + .get("hits")).get("hits"); - List docIds2 = hits2.stream().map(hit -> - ((String)((Map)hit).get("_id"))).collect(Collectors.toList()); + List docIds2 = hits2.stream().map(hit -> ((String) ((Map) hit).get("_id"))).collect(Collectors.toList()); - List docScores2 = hits2.stream().map(hit -> - ((Double)((Map)hit).get("_score"))).collect(Collectors.toList()); + List docScores2 = hits2.stream() + .map(hit -> ((Double) ((Map) hit).get("_score"))) + .collect(Collectors.toList()); double[] scores2 = new double[docScores2.size()]; for (int i = 0; i < docScores2.size(); i++) { @@ -627,7 +616,7 @@ public void testHammingScriptScore_Base64() throws Exception { } List correctIds2 = Arrays.asList("2", "0", "1", "3"); - double[] correctScores2 = new double[] {1.0/(1 + 4), 1.0/(1 + 7), 1.0/(1 + 8), 1.0/(1 + 29)}; + double[] correctScores2 = new double[] { 1.0 / (1 + 4), 1.0 / (1 + 7), 1.0 / (1 + 8), 1.0 / (1 + 29) }; assertEquals(correctIds2.size(), docIds2.size()); assertArrayEquals(correctIds2.toArray(), docIds2.toArray()); @@ -639,19 +628,18 @@ public void testKNNInnerProdScriptScore() throws Exception { * Create knn index and populate data */ createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); - Float[] f1 = {-2.0f, -2.0f}; + Float[] f1 = { -2.0f, -2.0f }; addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1); - Float[] f2 = {1.0f, 1.0f}; + Float[] f2 = { 1.0f, 1.0f }; addKnnDoc(INDEX_NAME, "2", FIELD_NAME, f2); - Float[] f3 = {2.0f, 2.0f}; + Float[] f3 = { 2.0f, 2.0f }; addKnnDoc(INDEX_NAME, "3", FIELD_NAME, f3); - Float[] f4 = {2.0f, -2.0f}; + Float[] f4 = { 2.0f, -2.0f }; addKnnDoc(INDEX_NAME, "4", FIELD_NAME, f4); - /** * Construct Search Request */ @@ -664,20 +652,19 @@ public void testKNNInnerProdScriptScore() throws Exception { * "space_type": "innerproduct", * } */ - float[] queryVector = {1.0f, 1.0f}; + float[] queryVector = { 1.0f, 1.0f }; params.put("field", FIELD_NAME); params.put("query_value", queryVector); params.put("space_type", SpaceType.INNER_PRODUCT.getValue()); Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); List expectedDocids = Arrays.asList("3", "2", "4", "1"); List actualDocids = new ArrayList<>(); - for(KNNResult result : results) { + for (KNNResult result : results) { actualDocids.add(result.getDocId()); } diff --git a/src/test/java/org/opensearch/knn/plugin/script/PainlessScriptScoringIT.java b/src/test/java/org/opensearch/knn/plugin/script/PainlessScriptScoringIT.java index d640eb97b..2c5a70895 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/PainlessScriptScoringIT.java +++ b/src/test/java/org/opensearch/knn/plugin/script/PainlessScriptScoringIT.java @@ -39,8 +39,7 @@ protected String createMapping(List properties) throws IOExcept Objects.requireNonNull(properties); XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject().startObject("properties"); for (MappingProperty property : properties) { - XContentBuilder builder = xContentBuilder.startObject(property.getName()) - .field("type", property.getType()); + XContentBuilder builder = xContentBuilder.startObject(property.getName()).field("type", property.getType()); if (property.getDimension() != null) { builder.field("dimension", property.getDimension()); } @@ -64,54 +63,54 @@ private void buildTestIndex(Map knnDocuments) throws Exception private Map getKnnVectorTestData() { Map data = new HashMap<>(); - data.put("1", new Float[]{100.0f, 1.0f}); - data.put("2", new Float[]{99.0f, 2.0f}); - data.put("3", new Float[]{97.0f, 3.0f}); - data.put("4", new Float[]{98.0f, 4.0f}); + data.put("1", new Float[] { 100.0f, 1.0f }); + data.put("2", new Float[] { 99.0f, 2.0f }); + data.put("3", new Float[] { 97.0f, 3.0f }); + data.put("4", new Float[] { 98.0f, 4.0f }); return data; } private Map getL2TestData() { Map data = new HashMap<>(); - data.put("1", new Float[]{6.0f, 6.0f}); - data.put("2", new Float[]{2.0f, 2.0f}); - data.put("3", new Float[]{4.0f, 4.0f}); - data.put("4", new Float[]{3.0f, 3.0f}); + data.put("1", new Float[] { 6.0f, 6.0f }); + data.put("2", new Float[] { 2.0f, 2.0f }); + data.put("3", new Float[] { 4.0f, 4.0f }); + data.put("4", new Float[] { 3.0f, 3.0f }); return data; } private Map getL1TestData() { Map data = new HashMap<>(); - data.put("1", new Float[]{6.0f, 6.0f}); - data.put("2", new Float[]{4.0f, 1.0f}); - data.put("3", new Float[]{3.0f, 3.0f}); - data.put("4", new Float[]{5.0f, 5.0f}); + data.put("1", new Float[] { 6.0f, 6.0f }); + data.put("2", new Float[] { 4.0f, 1.0f }); + data.put("3", new Float[] { 3.0f, 3.0f }); + data.put("4", new Float[] { 5.0f, 5.0f }); return data; } private Map getLInfTestData() { Map data = new HashMap<>(); - data.put("1", new Float[]{6.0f, 6.0f}); - data.put("2", new Float[]{4.0f, 1.0f}); - data.put("3", new Float[]{3.0f, 3.0f}); - data.put("4", new Float[]{5.0f, 5.0f}); + data.put("1", new Float[] { 6.0f, 6.0f }); + data.put("2", new Float[] { 4.0f, 1.0f }); + data.put("3", new Float[] { 3.0f, 3.0f }); + data.put("4", new Float[] { 5.0f, 5.0f }); return data; } private Map getInnerProdTestData() { Map data = new HashMap<>(); - data.put("1", new Float[]{-2.0f, -2.0f}); - data.put("2", new Float[]{1.0f, 1.0f}); - data.put("3", new Float[]{2.0f, 2.0f}); - data.put("4", new Float[]{2.0f, -2.0f}); + data.put("1", new Float[] { -2.0f, -2.0f }); + data.put("2", new Float[] { 1.0f, 1.0f }); + data.put("3", new Float[] { 2.0f, 2.0f }); + data.put("4", new Float[] { 2.0f, -2.0f }); return data; } private Map getCosineTestData() { Map data = new HashMap<>(); - data.put("0", new Float[]{1.0f, -1.0f}); - data.put("2", new Float[]{1.0f, 1.0f}); - data.put("1", new Float[]{1.0f, 0.0f}); + data.put("0", new Float[] { 1.0f, -1.0f }); + data.put("2", new Float[] { 1.0f, 1.0f }); + data.put("1", new Float[] { 1.0f, 0.0f }); return data; } @@ -133,12 +132,10 @@ public void testL2ScriptScoreFails() throws Exception { deleteKNNIndex(INDEX_NAME); } - private Request buildPainlessScriptRequest( - String source, int size, Map documents) throws Exception { + private Request buildPainlessScriptRequest(String source, int size, Map documents) throws Exception { buildTestIndex(documents); QueryBuilder qb = new MatchAllQueryBuilder(); - return constructScriptQueryRequest( - INDEX_NAME, qb, Collections.emptyMap(), Script.DEFAULT_SCRIPT_LANG, source, size); + return constructScriptQueryRequest(INDEX_NAME, qb, Collections.emptyMap(), Script.DEFAULT_SCRIPT_LANG, source, size); } public void testL2ScriptScore() throws Exception { @@ -147,14 +144,12 @@ public void testL2ScriptScore() throws Exception { Request request = buildPainlessScriptRequest(source, 3, getL2TestData()); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); assertEquals(3, results.size()); - - String[] expectedDocIDs = {"2", "4", "3", "1"}; + String[] expectedDocIDs = { "2", "4", "3", "1" }; for (int i = 0; i < results.size(); i++) { assertEquals(expectedDocIDs[i], results.get(i).getDocId()); } @@ -168,14 +163,12 @@ public void testGetValueReturnsDocValues() throws Exception { Request request = buildPainlessScriptRequest(source, testData.size(), testData); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); - assertEquals(testData.size(),results.size()); - + assertEquals(testData.size(), results.size()); - String[] expectedDocIDs = {"1", "2", "4", "3"}; + String[] expectedDocIDs = { "1", "2", "4", "3" }; for (int i = 0; i < results.size(); i++) { assertEquals(expectedDocIDs[i], results.get(i).getDocId()); } @@ -198,23 +191,19 @@ public void testGetValueScriptFailsWithOutOfBoundException() throws Exception { deleteKNNIndex(INDEX_NAME); } - public void testGetValueScriptScoreWithNumericField() throws Exception { - String source = String.format( - "doc['%s'].size() == 0 ? 0 : doc['%s'].value[0]", FIELD_NAME, FIELD_NAME); + String source = String.format("doc['%s'].size() == 0 ? 0 : doc['%s'].value[0]", FIELD_NAME, FIELD_NAME); Map testData = getKnnVectorTestData(); Request request = buildPainlessScriptRequest(source, testData.size(), testData); addDocWithNumericField(INDEX_NAME, "100", NUMERIC_INDEX_FIELD_NAME, 1000); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); assertEquals(testData.size(), results.size()); - - String[] expectedDocIDs = {"1", "2", "4", "3"}; + String[] expectedDocIDs = { "1", "2", "4", "3" }; for (int i = 0; i < results.size(); i++) { assertEquals(expectedDocIDs[i], results.get(i).getDocId()); } @@ -223,19 +212,16 @@ public void testGetValueScriptScoreWithNumericField() throws Exception { public void testL2ScriptScoreWithNumericField() throws Exception { - String source = String.format( - "doc['%s'].size() == 0 ? 0 : 1/(1 + l2Squared([1.0f, 1.0f], doc['%s']))", FIELD_NAME, FIELD_NAME); + String source = String.format("doc['%s'].size() == 0 ? 0 : 1/(1 + l2Squared([1.0f, 1.0f], doc['%s']))", FIELD_NAME, FIELD_NAME); Request request = buildPainlessScriptRequest(source, 3, getL2TestData()); addDocWithNumericField(INDEX_NAME, "100", NUMERIC_INDEX_FIELD_NAME, 1000); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); assertEquals(3, results.size()); - - String[] expectedDocIDs = {"2", "4", "3", "1"}; + String[] expectedDocIDs = { "2", "4", "3", "1" }; for (int i = 0; i < results.size(); i++) { assertEquals(expectedDocIDs[i], results.get(i).getDocId()); } @@ -254,13 +240,12 @@ public void testCosineSimilarityScriptScore() throws Exception { String source = String.format("1 + cosineSimilarity([2.0f, -2.0f], doc['%s'])", FIELD_NAME); Request request = buildPainlessScriptRequest(source, 3, getCosineTestData()); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); assertEquals(3, results.size()); - String[] expectedDocIDs = {"0", "1", "2"}; + String[] expectedDocIDs = { "0", "1", "2" }; for (int i = 0; i < results.size(); i++) { assertEquals(expectedDocIDs[i], results.get(i).getDocId()); } @@ -268,18 +253,16 @@ public void testCosineSimilarityScriptScore() throws Exception { } public void testCosineSimilarityScriptScoreWithNumericField() throws Exception { - String source = String.format( - "doc['%s'].size() == 0 ? 0 : 1 + cosineSimilarity([2.0f, -2.0f], doc['%s'])", FIELD_NAME, FIELD_NAME); + String source = String.format("doc['%s'].size() == 0 ? 0 : 1 + cosineSimilarity([2.0f, -2.0f], doc['%s'])", FIELD_NAME, FIELD_NAME); Request request = buildPainlessScriptRequest(source, 3, getCosineTestData()); addDocWithNumericField(INDEX_NAME, "100", NUMERIC_INDEX_FIELD_NAME, 1000); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); assertEquals(3, results.size()); - String[] expectedDocIDs = {"0", "1", "2"}; + String[] expectedDocIDs = { "0", "1", "2" }; for (int i = 0; i < results.size(); i++) { assertEquals(expectedDocIDs[i], results.get(i).getDocId()); } @@ -299,13 +282,12 @@ public void testCosineSimilarityNormalizedScriptScore() throws Exception { String source = String.format("1 + cosineSimilarity([2.0f, -2.0f], doc['%s'], 3.0f)", FIELD_NAME); Request request = buildPainlessScriptRequest(source, 3, getCosineTestData()); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); assertEquals(3, results.size()); - String[] expectedDocIDs = {"0", "1", "2"}; + String[] expectedDocIDs = { "0", "1", "2" }; for (int i = 0; i < results.size(); i++) { assertEquals(expectedDocIDs[i], results.get(i).getDocId()); } @@ -314,18 +296,19 @@ public void testCosineSimilarityNormalizedScriptScore() throws Exception { public void testCosineSimilarityNormalizedScriptScoreWithNumericField() throws Exception { String source = String.format( - "doc['%s'].size() == 0 ? 0 : 1 + cosineSimilarity([2.0f, -2.0f], doc['%s'], 3.0f)", - FIELD_NAME, FIELD_NAME); + "doc['%s'].size() == 0 ? 0 : 1 + cosineSimilarity([2.0f, -2.0f], doc['%s'], 3.0f)", + FIELD_NAME, + FIELD_NAME + ); Request request = buildPainlessScriptRequest(source, 3, getCosineTestData()); addDocWithNumericField(INDEX_NAME, "100", NUMERIC_INDEX_FIELD_NAME, 1000); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); assertEquals(3, results.size()); - String[] expectedDocIDs = {"0", "1", "2"}; + String[] expectedDocIDs = { "0", "1", "2" }; for (int i = 0; i < results.size(); i++) { assertEquals(expectedDocIDs[i], results.get(i).getDocId()); } @@ -340,20 +323,19 @@ public void testL1ScriptScoreFails() throws Exception { expectThrows(ResponseException.class, () -> client().performRequest(request)); deleteKNNIndex(INDEX_NAME); } + public void testL1ScriptScore() throws Exception { String source = String.format("1/(1 + l1Norm([1.0f, 1.0f], doc['%s']))", FIELD_NAME); Request request = buildPainlessScriptRequest(source, 3, getL1TestData()); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); assertEquals(3, results.size()); - - String[] expectedDocIDs = {"2", "3", "4", "1"}; + String[] expectedDocIDs = { "2", "3", "4", "1" }; for (int i = 0; i < results.size(); i++) { assertEquals(expectedDocIDs[i], results.get(i).getDocId()); } @@ -362,47 +344,43 @@ public void testL1ScriptScore() throws Exception { public void testL1ScriptScoreWithNumericField() throws Exception { - String source = String.format( - "doc['%s'].size() == 0 ? 0 : 1/(1 + l1Norm([1.0f, 1.0f], doc['%s']))", FIELD_NAME, FIELD_NAME); + String source = String.format("doc['%s'].size() == 0 ? 0 : 1/(1 + l1Norm([1.0f, 1.0f], doc['%s']))", FIELD_NAME, FIELD_NAME); Request request = buildPainlessScriptRequest(source, 3, getL1TestData()); addDocWithNumericField(INDEX_NAME, "100", NUMERIC_INDEX_FIELD_NAME, 1000); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); assertEquals(3, results.size()); - - String[] expectedDocIDs = {"2", "3", "4", "1"}; + String[] expectedDocIDs = { "2", "3", "4", "1" }; for (int i = 0; i < results.size(); i++) { assertEquals(expectedDocIDs[i], results.get(i).getDocId()); } deleteKNNIndex(INDEX_NAME); } - // L-inf tests - public void testLInfScriptScoreFails() throws Exception { + // L-inf tests + public void testLInfScriptScoreFails() throws Exception { String source = String.format("1/(1 + lInfNorm([1.0f, 1.0f], doc['%s']))", FIELD_NAME); Request request = buildPainlessScriptRequest(source, 3, getLInfTestData()); addDocWithNumericField(INDEX_NAME, "100", NUMERIC_INDEX_FIELD_NAME, 1000); expectThrows(ResponseException.class, () -> client().performRequest(request)); deleteKNNIndex(INDEX_NAME); } + public void testLInfScriptScore() throws Exception { String source = String.format("1/(1 + lInfNorm([1.0f, 1.0f], doc['%s']))", FIELD_NAME); Request request = buildPainlessScriptRequest(source, 3, getLInfTestData()); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); assertEquals(3, results.size()); - - String[] expectedDocIDs = {"3", "2", "4", "1"}; + String[] expectedDocIDs = { "3", "2", "4", "1" }; for (int i = 0; i < results.size(); i++) { assertEquals(expectedDocIDs[i], results.get(i).getDocId()); } @@ -411,19 +389,16 @@ public void testLInfScriptScore() throws Exception { public void testLInfScriptScoreWithNumericField() throws Exception { - String source = String.format( - "doc['%s'].size() == 0 ? 0 : 1/(1 + lInfNorm([1.0f, 1.0f], doc['%s']))", FIELD_NAME, FIELD_NAME); + String source = String.format("doc['%s'].size() == 0 ? 0 : 1/(1 + lInfNorm([1.0f, 1.0f], doc['%s']))", FIELD_NAME, FIELD_NAME); Request request = buildPainlessScriptRequest(source, 3, getLInfTestData()); addDocWithNumericField(INDEX_NAME, "100", NUMERIC_INDEX_FIELD_NAME, 1000); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); assertEquals(3, results.size()); - - String[] expectedDocIDs = {"3", "2", "4", "1"}; + String[] expectedDocIDs = { "3", "2", "4", "1" }; for (int i = 0; i < results.size(); i++) { assertEquals(expectedDocIDs[i], results.get(i).getDocId()); } @@ -431,8 +406,7 @@ public void testLInfScriptScoreWithNumericField() throws Exception { } public void testInnerProdScriptScoreFails() throws Exception { - String source = String.format( - "float x = innerProduct([1.0f, 1.0f], doc['%s']); return x >= 0? 2-1/(x+1):1/(1-x);", FIELD_NAME); + String source = String.format("float x = innerProduct([1.0f, 1.0f], doc['%s']); return x >= 0? 2-1/(x+1):1/(1-x);", FIELD_NAME); Request request = buildPainlessScriptRequest(source, 3, getInnerProdTestData()); addDocWithNumericField(INDEX_NAME, "100", NUMERIC_INDEX_FIELD_NAME, 1000); expectThrows(ResponseException.class, () -> client().performRequest(request)); @@ -441,19 +415,16 @@ public void testInnerProdScriptScoreFails() throws Exception { public void testInnerProdScriptScore() throws Exception { - String source = String.format( - "float x = innerProduct([1.0f, 1.0f], doc['%s']); return x >= 0? 2-1/(x+1):1/(1-x);", FIELD_NAME); + String source = String.format("float x = innerProduct([1.0f, 1.0f], doc['%s']); return x >= 0? 2-1/(x+1):1/(1-x);", FIELD_NAME); Request request = buildPainlessScriptRequest(source, 3, getInnerProdTestData()); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); assertEquals(3, results.size()); - - String[] expectedDocIDs = {"3", "2", "4", "1"}; + String[] expectedDocIDs = { "3", "2", "4", "1" }; for (int i = 0; i < results.size(); i++) { assertEquals(expectedDocIDs[i], results.get(i).getDocId()); } @@ -463,29 +434,28 @@ public void testInnerProdScriptScore() throws Exception { public void testInnerProdScriptScoreWithNumericField() throws Exception { String source = String.format( - "if (doc['%s'].size() == 0) " + - "{ return 0; } " + - "else " + - "{ float x = innerProduct([1.0f, 1.0f], doc['%s']); return x >= 0? 2-1/(x+1):1/(1-x); }", - FIELD_NAME, FIELD_NAME); + "if (doc['%s'].size() == 0) " + + "{ return 0; } " + + "else " + + "{ float x = innerProduct([1.0f, 1.0f], doc['%s']); return x >= 0? 2-1/(x+1):1/(1-x); }", + FIELD_NAME, + FIELD_NAME + ); Request request = buildPainlessScriptRequest(source, 3, getInnerProdTestData()); addDocWithNumericField(INDEX_NAME, "100", NUMERIC_INDEX_FIELD_NAME, 1000); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, - RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); assertEquals(3, results.size()); - - String[] expectedDocIDs = {"3", "2", "4", "1"}; + String[] expectedDocIDs = { "3", "2", "4", "1" }; for (int i = 0; i < results.size(); i++) { assertEquals(expectedDocIDs[i], results.get(i).getDocId()); } deleteKNNIndex(INDEX_NAME); } - class MappingProperty { private String name; @@ -515,4 +485,3 @@ String getType() { } } } - diff --git a/src/test/java/org/opensearch/knn/plugin/stats/KNNCounterTests.java b/src/test/java/org/opensearch/knn/plugin/stats/KNNCounterTests.java index ccdf97d25..ab7fba764 100644 --- a/src/test/java/org/opensearch/knn/plugin/stats/KNNCounterTests.java +++ b/src/test/java/org/opensearch/knn/plugin/stats/KNNCounterTests.java @@ -17,7 +17,7 @@ public void testCount() { for (long i = 0; i < 100; i++) { KNNCounter.GRAPH_QUERY_ERRORS.increment(); - assertEquals((Long) (i+1), KNNCounter.GRAPH_QUERY_ERRORS.getCount()); + assertEquals((Long) (i + 1), KNNCounter.GRAPH_QUERY_ERRORS.getCount()); } } -} \ No newline at end of file +} diff --git a/src/test/java/org/opensearch/knn/plugin/stats/suppliers/EventOccurredWithinThresholdSupplierTests.java b/src/test/java/org/opensearch/knn/plugin/stats/suppliers/EventOccurredWithinThresholdSupplierTests.java index af4f995c0..fe013b0da 100644 --- a/src/test/java/org/opensearch/knn/plugin/stats/suppliers/EventOccurredWithinThresholdSupplierTests.java +++ b/src/test/java/org/opensearch/knn/plugin/stats/suppliers/EventOccurredWithinThresholdSupplierTests.java @@ -21,18 +21,14 @@ public class EventOccurredWithinThresholdSupplierTests extends KNNTestCase { public void testOutsideThreshold() throws InterruptedException { Instant now = Instant.now(); long threshold = 2; - EventOccurredWithinThresholdSupplier supplier = new EventOccurredWithinThresholdSupplier( - ()->now, threshold, ChronoUnit.SECONDS - ); + EventOccurredWithinThresholdSupplier supplier = new EventOccurredWithinThresholdSupplier(() -> now, threshold, ChronoUnit.SECONDS); TimeUnit.SECONDS.sleep(threshold + 1); assertFalse(supplier.get()); } public void testEventNeverHappened() throws InterruptedException { long threshold = 2; - EventOccurredWithinThresholdSupplier supplier = new EventOccurredWithinThresholdSupplier( - () -> null, threshold, ChronoUnit.SECONDS - ); + EventOccurredWithinThresholdSupplier supplier = new EventOccurredWithinThresholdSupplier(() -> null, threshold, ChronoUnit.SECONDS); TimeUnit.SECONDS.sleep(threshold + 1); assertFalse(supplier.get()); } @@ -40,9 +36,7 @@ public void testEventNeverHappened() throws InterruptedException { public void testInsideThreshold() throws InterruptedException { Instant now = Instant.now(); long threshold = 2; - EventOccurredWithinThresholdSupplier supplier = new EventOccurredWithinThresholdSupplier( - ()->now, threshold, ChronoUnit.MINUTES - ); + EventOccurredWithinThresholdSupplier supplier = new EventOccurredWithinThresholdSupplier(() -> now, threshold, ChronoUnit.MINUTES); TimeUnit.SECONDS.sleep(threshold + 1); assertTrue(supplier.get()); } diff --git a/src/test/java/org/opensearch/knn/plugin/stats/suppliers/KNNCounterSupplierTests.java b/src/test/java/org/opensearch/knn/plugin/stats/suppliers/KNNCounterSupplierTests.java index 1e0fa5c2b..1803b7a3d 100644 --- a/src/test/java/org/opensearch/knn/plugin/stats/suppliers/KNNCounterSupplierTests.java +++ b/src/test/java/org/opensearch/knn/plugin/stats/suppliers/KNNCounterSupplierTests.java @@ -15,4 +15,4 @@ public void testNormal() { KNNCounter.GRAPH_QUERY_REQUESTS.increment(); assertEquals((Long) 1L, knnCounterSupplier.get()); } -} \ No newline at end of file +} diff --git a/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java b/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java index 1d942ee44..759f4dd5b 100644 --- a/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java +++ b/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java @@ -33,6 +33,7 @@ public void testEngineInitialized() { private class TestLibrary implements KNNLibrary { private Boolean initialized; + TestLibrary() { this.initialized = false; } diff --git a/src/test/java/org/opensearch/knn/plugin/transport/DeleteModelResponseTests.java b/src/test/java/org/opensearch/knn/plugin/transport/DeleteModelResponseTests.java index 3b984008a..d25c0929a 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/DeleteModelResponseTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/DeleteModelResponseTests.java @@ -17,11 +17,6 @@ import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; import org.opensearch.knn.KNNTestCase; -import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.util.KNNEngine; -import org.opensearch.knn.indices.Model; -import org.opensearch.knn.indices.ModelMetadata; -import org.opensearch.knn.indices.ModelState; import java.io.IOException; @@ -38,7 +33,6 @@ public void testStreams() throws IOException { assertEquals(deleteModelResponse.getErrorMessage(), deleteModelResponseCopy.getErrorMessage()); } - public void testXContentWithError() throws IOException { String modelId = "test-model"; DeleteModelResponse deleteModelResponse = new DeleteModelResponse(modelId, "not_found", "model id not found"); diff --git a/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java b/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java index dd797ccd5..04d3b419e 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java @@ -28,8 +28,7 @@ public class GetModelResponseTests extends KNNTestCase { private ModelMetadata getModelMetadata(ModelState state) { - return new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, 4, state, - "2021-03-27 10:15:30 AM +05:30", "test model", ""); + return new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, 4, state, "2021-03-27 10:15:30 AM +05:30", "test model", ""); } public void testStreams() throws IOException { @@ -43,13 +42,13 @@ public void testStreams() throws IOException { assertEquals(getModelResponse.getModel(), getModelResponseCopy.getModel()); } - public void testXContent() throws IOException { String modelId = "test-model"; byte[] testModelBlob = "hello".getBytes(); - Model model = new Model(getModelMetadata(ModelState.CREATED), testModelBlob,modelId); + Model model = new Model(getModelMetadata(ModelState.CREATED), testModelBlob, modelId); GetModelResponse getModelResponse = new GetModelResponse(model); - String expectedResponseString = "{\"model_id\":\"test-model\",\"model_blob\":\"aGVsbG8=\",\"state\":\"created\",\"timestamp\":\"2021-03-27 10:15:30 AM +05:30\",\"description\":\"test model\",\"error\":\"\",\"space_type\":\"l2\",\"dimension\":4,\"engine\":\"nmslib\"}"; + String expectedResponseString = + "{\"model_id\":\"test-model\",\"model_blob\":\"aGVsbG8=\",\"state\":\"created\",\"timestamp\":\"2021-03-27 10:15:30 AM +05:30\",\"description\":\"test model\",\"error\":\"\",\"space_type\":\"l2\",\"dimension\":4,\"engine\":\"nmslib\"}"; XContentBuilder xContentBuilder = XContentFactory.contentBuilder(XContentType.JSON); getModelResponse.toXContent(xContentBuilder, null); assertEquals(expectedResponseString, Strings.toString(xContentBuilder)); @@ -59,7 +58,8 @@ public void testXContentWithNoModelBlob() throws IOException { String modelId = "test-model"; Model model = new Model(getModelMetadata(ModelState.FAILED), null, modelId); GetModelResponse getModelResponse = new GetModelResponse(model); - String expectedResponseString = "{\"model_id\":\"test-model\",\"model_blob\":\"\",\"state\":\"failed\",\"timestamp\":\"2021-03-27 10:15:30 AM +05:30\",\"description\":\"test model\",\"error\":\"\",\"space_type\":\"l2\",\"dimension\":4,\"engine\":\"nmslib\"}"; + String expectedResponseString = + "{\"model_id\":\"test-model\",\"model_blob\":\"\",\"state\":\"failed\",\"timestamp\":\"2021-03-27 10:15:30 AM +05:30\",\"description\":\"test model\",\"error\":\"\",\"space_type\":\"l2\",\"dimension\":4,\"engine\":\"nmslib\"}"; XContentBuilder xContentBuilder = XContentFactory.contentBuilder(XContentType.JSON); getModelResponse.toXContent(xContentBuilder, null); assertEquals(expectedResponseString, Strings.toString(xContentBuilder)); diff --git a/src/test/java/org/opensearch/knn/plugin/transport/KNNWarmupTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/KNNWarmupTransportActionTests.java index 6f3e51409..da4c8d834 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/KNNWarmupTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/KNNWarmupTransportActionTests.java @@ -44,7 +44,7 @@ public void testShardOperation() throws IOException, ExecutionException, Interru knnWarmupTransportAction.shardOperation(knnWarmupRequest, shardRouting); assertEquals(0, NativeMemoryCacheManager.getInstance().getIndicesCacheStats().size()); - addKnnDoc(testIndexName, "1", testFieldName, new Long[] {0L, 1L}); + addKnnDoc(testIndexName, "1", testFieldName, new Long[] { 0L, 1L }); knnWarmupTransportAction.shardOperation(knnWarmupRequest, shardRouting); assertEquals(1, NativeMemoryCacheManager.getInstance().getIndicesCacheStats().size()); @@ -57,18 +57,27 @@ public void testShards() throws InterruptedException, ExecutionException, IOExce createKNNIndex(testIndexName); createKnnIndexMapping(testIndexName, testFieldName, dimensions); - addKnnDoc(testIndexName, "1", testFieldName, new Long[] {0L, 1L}); + addKnnDoc(testIndexName, "1", testFieldName, new Long[] { 0L, 1L }); - ShardsIterator shardsIterator = knnWarmupTransportAction.shards(clusterService.state(), knnWarmupRequest, - new String[] {testIndexName}); + ShardsIterator shardsIterator = knnWarmupTransportAction.shards( + clusterService.state(), + knnWarmupRequest, + new String[] { testIndexName } + ); assertEquals(1, shardsIterator.size()); } public void testCheckGlobalBlock() { ClusterService clusterService = mock(ClusterService.class); - ClusterBlock metaReadClusterBlock = new ClusterBlock(randomInt(), "test-meta-data-block", - false, false, false, RestStatus.FORBIDDEN, - EnumSet.of(ClusterBlockLevel.METADATA_READ)); + ClusterBlock metaReadClusterBlock = new ClusterBlock( + randomInt(), + "test-meta-data-block", + false, + false, + false, + RestStatus.FORBIDDEN, + EnumSet.of(ClusterBlockLevel.METADATA_READ) + ); ClusterBlocks clusterBlocks = ClusterBlocks.builder().addGlobalBlock(metaReadClusterBlock).build(); ClusterState state = ClusterState.builder(ClusterName.DEFAULT).blocks(clusterBlocks).build(); when(clusterService.state()).thenReturn(state); @@ -80,16 +89,21 @@ public void testCheckGlobalBlock() { public void testCheckRequestBlock() { ClusterService clusterService = mock(ClusterService.class); - ClusterBlock metaReadClusterBlock = new ClusterBlock(randomInt(), "test-meta-data-block", - false, false, false, RestStatus.FORBIDDEN, - EnumSet.of(ClusterBlockLevel.METADATA_READ)); + ClusterBlock metaReadClusterBlock = new ClusterBlock( + randomInt(), + "test-meta-data-block", + false, + false, + false, + RestStatus.FORBIDDEN, + EnumSet.of(ClusterBlockLevel.METADATA_READ) + ); ClusterBlocks clusterBlocks = ClusterBlocks.builder().addGlobalBlock(metaReadClusterBlock).build(); ClusterState state = ClusterState.builder(ClusterName.DEFAULT).blocks(clusterBlocks).build(); when(clusterService.state()).thenReturn(state); KNNWarmupTransportAction knnWarmupTransportAction = node().injector().getInstance(KNNWarmupTransportAction.class); KNNWarmupRequest knnWarmupRequest = new KNNWarmupRequest(testIndexName); - assertNotNull(knnWarmupTransportAction.checkRequestBlock(clusterService.state(), knnWarmupRequest, - new String[] {testIndexName})); + assertNotNull(knnWarmupTransportAction.checkRequestBlock(clusterService.state(), knnWarmupRequest, new String[] { testIndexName })); } } diff --git a/src/test/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheTransportActionTests.java index 59c6caaf9..ae89d83e1 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheTransportActionTests.java @@ -37,8 +37,7 @@ public class RemoveModelFromCacheTransportActionTests extends KNNSingleNodeTestC public void testNodeOperation_modelNotInCache() { ClusterService clusterService = mock(ClusterService.class); Settings settings = Settings.builder().put(MODEL_CACHE_SIZE_LIMIT_SETTING.getKey(), "10%").build(); - ClusterSettings clusterSettings = new ClusterSettings(settings, - ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING)); + ClusterSettings clusterSettings = new ClusterSettings(settings, ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING)); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); when(clusterService.getSettings()).thenReturn(settings); @@ -50,8 +49,7 @@ public void testNodeOperation_modelNotInCache() { assertEquals(0, modelCache.getTotalWeightInKB()); // Remove the model from the cache - RemoveModelFromCacheTransportAction action = node().injector() - .getInstance(RemoveModelFromCacheTransportAction.class); + RemoveModelFromCacheTransportAction action = node().injector().getInstance(RemoveModelFromCacheTransportAction.class); RemoveModelFromCacheNodeRequest request = new RemoveModelFromCacheNodeRequest("invalid-model"); action.nodeOperation(request); @@ -63,17 +61,16 @@ public void testNodeOperation_modelNotInCache() { public void testNodeOperation_modelInCache() throws ExecutionException, InterruptedException { ClusterService clusterService = mock(ClusterService.class); Settings settings = Settings.builder().put(MODEL_CACHE_SIZE_LIMIT_SETTING.getKey(), "10%").build(); - ClusterSettings clusterSettings = new ClusterSettings(settings, - ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING)); + ClusterSettings clusterSettings = new ClusterSettings(settings, ImmutableSet.of(MODEL_CACHE_SIZE_LIMIT_SETTING)); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); when(clusterService.getSettings()).thenReturn(settings); ModelDao modelDao = mock(ModelDao.class); String modelId = "test-model-id"; Model model = new Model( - new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 16, ModelState.CREATED, - "timestamp", "description", ""), - new byte[128], modelId + new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 16, ModelState.CREATED, "timestamp", "description", ""), + new byte[128], + modelId ); when(modelDao.get(modelId)).thenReturn(model); @@ -85,8 +82,7 @@ public void testNodeOperation_modelInCache() throws ExecutionException, Interrup modelCache.get(modelId); // Remove the model from the cache - RemoveModelFromCacheTransportAction action = node().injector() - .getInstance(RemoveModelFromCacheTransportAction.class); + RemoveModelFromCacheTransportAction action = node().injector().getInstance(RemoveModelFromCacheTransportAction.class); RemoveModelFromCacheNodeRequest request = new RemoveModelFromCacheNodeRequest(modelId); action.nodeOperation(request); diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouteDecisionInfoNodeResponseTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouteDecisionInfoNodeResponseTests.java index 1818e234f..8085b0170 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouteDecisionInfoNodeResponseTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouteDecisionInfoNodeResponseTests.java @@ -36,16 +36,13 @@ public void testStreams() throws IOException { int trainingJobCount = 13; InetAddress inetAddress = InetAddresses.fromInteger(randomInt()); - DiscoveryNode discoveryNode = new DiscoveryNode("id", new TransportAddress(inetAddress, 9200), - Version.CURRENT); + DiscoveryNode discoveryNode = new DiscoveryNode("id", new TransportAddress(inetAddress, 9200), Version.CURRENT); - TrainingJobRouteDecisionInfoNodeResponse original = - new TrainingJobRouteDecisionInfoNodeResponse(discoveryNode, trainingJobCount); + TrainingJobRouteDecisionInfoNodeResponse original = new TrainingJobRouteDecisionInfoNodeResponse(discoveryNode, trainingJobCount); original.writeTo(streamOutput); - TrainingJobRouteDecisionInfoNodeResponse copy = - new TrainingJobRouteDecisionInfoNodeResponse(streamOutput.bytes().streamInput()); + TrainingJobRouteDecisionInfoNodeResponse copy = new TrainingJobRouteDecisionInfoNodeResponse(streamOutput.bytes().streamInput()); assertEquals(original.getTrainingJobCount(), copy.getTrainingJobCount()); } @@ -55,8 +52,7 @@ public void testGetTrainingJobCount() { DiscoveryNode discoveryNode = mock(DiscoveryNode.class); - TrainingJobRouteDecisionInfoNodeResponse response = - new TrainingJobRouteDecisionInfoNodeResponse(discoveryNode, trainingJobCount); + TrainingJobRouteDecisionInfoNodeResponse response = new TrainingJobRouteDecisionInfoNodeResponse(discoveryNode, trainingJobCount); assertEquals(trainingJobCount, response.getTrainingJobCount().intValue()); } @@ -66,17 +62,17 @@ public void testToXContent() throws IOException { // We expect this: // { - // "training_job_count": 13 + // "training_job_count": 13 // } - XContentBuilder expectedXContentBuilder = XContentFactory.jsonBuilder().startObject() - .field(TRAINING_JOB_COUNT_FIELD_NAME, trainingJobCount) - .endObject(); + XContentBuilder expectedXContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(TRAINING_JOB_COUNT_FIELD_NAME, trainingJobCount) + .endObject(); Map expected = xContentBuilderToMap(expectedXContentBuilder); DiscoveryNode discoveryNode = mock(DiscoveryNode.class); - TrainingJobRouteDecisionInfoNodeResponse response = - new TrainingJobRouteDecisionInfoNodeResponse(discoveryNode, trainingJobCount); + TrainingJobRouteDecisionInfoNodeResponse response = new TrainingJobRouteDecisionInfoNodeResponse(discoveryNode, trainingJobCount); XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); builder = response.toXContent(builder, ToXContent.EMPTY_PARAMS).endObject(); diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouteDecisionInfoResponseTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouteDecisionInfoResponseTests.java index 93d56a3cf..87d0b7ee5 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouteDecisionInfoResponseTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouteDecisionInfoResponseTests.java @@ -43,54 +43,54 @@ public void testStreams() throws IOException { // Initialize nodes and data InetAddress inetAddress1 = InetAddresses.fromInteger(randomInt()); String node1Id = "node-1"; - DiscoveryNode discoveryNode1 = new DiscoveryNode(node1Id, new TransportAddress(inetAddress1, 9200), - Version.CURRENT); + DiscoveryNode discoveryNode1 = new DiscoveryNode(node1Id, new TransportAddress(inetAddress1, 9200), Version.CURRENT); Integer trainingJobCount1 = 1; - TrainingJobRouteDecisionInfoNodeResponse nodeResponse1 = - new TrainingJobRouteDecisionInfoNodeResponse(discoveryNode1, trainingJobCount1); + TrainingJobRouteDecisionInfoNodeResponse nodeResponse1 = new TrainingJobRouteDecisionInfoNodeResponse( + discoveryNode1, + trainingJobCount1 + ); InetAddress inetAddress2 = InetAddresses.fromInteger(randomInt()); String node2Id = "node-2"; - DiscoveryNode discoveryNode2 = new DiscoveryNode(node2Id, new TransportAddress(inetAddress2, 9200), - Version.CURRENT); + DiscoveryNode discoveryNode2 = new DiscoveryNode(node2Id, new TransportAddress(inetAddress2, 9200), Version.CURRENT); Integer trainingJobCount2 = 2; - TrainingJobRouteDecisionInfoNodeResponse nodeResponse2 = - new TrainingJobRouteDecisionInfoNodeResponse(discoveryNode2, trainingJobCount2); + TrainingJobRouteDecisionInfoNodeResponse nodeResponse2 = new TrainingJobRouteDecisionInfoNodeResponse( + discoveryNode2, + trainingJobCount2 + ); InetAddress inetAddress3 = InetAddresses.fromInteger(randomInt()); String node3Id = "node-3"; - DiscoveryNode discoveryNode3 = new DiscoveryNode(node3Id, new TransportAddress(inetAddress3, 9200), - Version.CURRENT); + DiscoveryNode discoveryNode3 = new DiscoveryNode(node3Id, new TransportAddress(inetAddress3, 9200), Version.CURRENT); Integer trainingJobCount3 = 3; - TrainingJobRouteDecisionInfoNodeResponse nodeResponse3 = - new TrainingJobRouteDecisionInfoNodeResponse(discoveryNode3, trainingJobCount3); - - List nodeResponses = ImmutableList.of( - nodeResponse1, - nodeResponse2, - nodeResponse3 + TrainingJobRouteDecisionInfoNodeResponse nodeResponse3 = new TrainingJobRouteDecisionInfoNodeResponse( + discoveryNode3, + trainingJobCount3 ); + List nodeResponses = ImmutableList.of(nodeResponse1, nodeResponse2, nodeResponse3); + List failedNodeExceptions = Collections.emptyList(); // Setup output BytesStreamOutput streamOutput = new BytesStreamOutput(); - TrainingJobRouteDecisionInfoResponse original = - new TrainingJobRouteDecisionInfoResponse(ClusterName.DEFAULT, nodeResponses, failedNodeExceptions); + TrainingJobRouteDecisionInfoResponse original = new TrainingJobRouteDecisionInfoResponse( + ClusterName.DEFAULT, + nodeResponses, + failedNodeExceptions + ); original.writeTo(streamOutput); // Read back streamed out into streamed in - TrainingJobRouteDecisionInfoResponse copy = - new TrainingJobRouteDecisionInfoResponse(streamOutput.bytes().streamInput()); + TrainingJobRouteDecisionInfoResponse copy = new TrainingJobRouteDecisionInfoResponse(streamOutput.bytes().streamInput()); Map originalNodeResponseMap = original.getNodesMap(); Map copyNodeResponseMap = copy.getNodesMap(); assertEquals(originalNodeResponseMap.keySet(), copyNodeResponseMap.keySet()); assertTrue(originalNodeResponseMap.containsKey(node2Id)); - assertEquals(originalNodeResponseMap.get(node2Id).getTrainingJobCount(), - copyNodeResponseMap.get(node2Id).getTrainingJobCount()); + assertEquals(originalNodeResponseMap.get(node2Id).getTrainingJobCount(), copyNodeResponseMap.get(node2Id).getTrainingJobCount()); } public void testToXContent() throws IOException { @@ -100,63 +100,69 @@ public void testToXContent() throws IOException { DiscoveryNode discoveryNode1 = mock(DiscoveryNode.class); when(discoveryNode1.getId()).thenReturn(id1); Integer trainingJobCount1 = 1; - TrainingJobRouteDecisionInfoNodeResponse nodeResponse1 = - new TrainingJobRouteDecisionInfoNodeResponse(discoveryNode1, trainingJobCount1); + TrainingJobRouteDecisionInfoNodeResponse nodeResponse1 = new TrainingJobRouteDecisionInfoNodeResponse( + discoveryNode1, + trainingJobCount1 + ); String id2 = "id_2"; DiscoveryNode discoveryNode2 = mock(DiscoveryNode.class); when(discoveryNode2.getId()).thenReturn(id2); Integer trainingJobCount2 = 2; - TrainingJobRouteDecisionInfoNodeResponse nodeResponse2 = - new TrainingJobRouteDecisionInfoNodeResponse(discoveryNode2, trainingJobCount2); + TrainingJobRouteDecisionInfoNodeResponse nodeResponse2 = new TrainingJobRouteDecisionInfoNodeResponse( + discoveryNode2, + trainingJobCount2 + ); String id3 = "id_3"; DiscoveryNode discoveryNode3 = mock(DiscoveryNode.class); when(discoveryNode3.getId()).thenReturn(id3); Integer trainingJobCount3 = 3; - TrainingJobRouteDecisionInfoNodeResponse nodeResponse3 = - new TrainingJobRouteDecisionInfoNodeResponse(discoveryNode3, trainingJobCount3); + TrainingJobRouteDecisionInfoNodeResponse nodeResponse3 = new TrainingJobRouteDecisionInfoNodeResponse( + discoveryNode3, + trainingJobCount3 + ); // We expect this: // { - // "nodes": { - // "id_1": { - // "training_job_count": 1 - // }, - // "id_2": { - // "training_job_count": 2 - // }, - // "id_3": { - // "training_job_count": 3 - // }, - // } + // "nodes": { + // "id_1": { + // "training_job_count": 1 + // }, + // "id_2": { + // "training_job_count": 2 + // }, + // "id_3": { + // "training_job_count": 3 + // }, + // } // } - XContentBuilder expectedXContentBuilder = XContentFactory.jsonBuilder().startObject() - .startObject(NODES_KEY) - .startObject(id1) - .field(TRAINING_JOB_COUNT_FIELD_NAME, trainingJobCount1) - .endObject() - .startObject(id2) - .field(TRAINING_JOB_COUNT_FIELD_NAME, trainingJobCount2) - .endObject() - .startObject(id3) - .field(TRAINING_JOB_COUNT_FIELD_NAME, trainingJobCount3) - .endObject() - .endObject() - .endObject(); + XContentBuilder expectedXContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(NODES_KEY) + .startObject(id1) + .field(TRAINING_JOB_COUNT_FIELD_NAME, trainingJobCount1) + .endObject() + .startObject(id2) + .field(TRAINING_JOB_COUNT_FIELD_NAME, trainingJobCount2) + .endObject() + .startObject(id3) + .field(TRAINING_JOB_COUNT_FIELD_NAME, trainingJobCount3) + .endObject() + .endObject() + .endObject(); Map expected = xContentBuilderToMap(expectedXContentBuilder); // Configure response - List nodeResponses = ImmutableList.of( - nodeResponse1, - nodeResponse2, - nodeResponse3 - ); + List nodeResponses = ImmutableList.of(nodeResponse1, nodeResponse2, nodeResponse3); List failedNodeExceptions = Collections.emptyList(); - TrainingJobRouteDecisionInfoResponse response = - new TrainingJobRouteDecisionInfoResponse(ClusterName.DEFAULT, nodeResponses, failedNodeExceptions); + TrainingJobRouteDecisionInfoResponse response = new TrainingJobRouteDecisionInfoResponse( + ClusterName.DEFAULT, + nodeResponses, + failedNodeExceptions + ); XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); builder = response.toXContent(builder, ToXContent.EMPTY_PARAMS).endObject(); diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouteDecisionInfoTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouteDecisionInfoTransportActionTests.java index 625c7e913..be58f2220 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouteDecisionInfoTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouteDecisionInfoTransportActionTests.java @@ -30,7 +30,6 @@ import java.util.concurrent.TimeUnit; import static org.mockito.Matchers.any; -import static org.mockito.Matchers.anyString; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -55,10 +54,9 @@ public void teardown() { public void testNodeOperation() throws IOException, InterruptedException { // Ensure initial value of train job count is 0 TrainingJobRouteDecisionInfoTransportAction action = node().injector() - .getInstance(TrainingJobRouteDecisionInfoTransportAction.class); + .getInstance(TrainingJobRouteDecisionInfoTransportAction.class); - TrainingJobRouteDecisionInfoNodeRequest request = - new TrainingJobRouteDecisionInfoNodeRequest(); + TrainingJobRouteDecisionInfoNodeRequest request = new TrainingJobRouteDecisionInfoNodeRequest(); TrainingJobRouteDecisionInfoNodeResponse response1 = action.nodeOperation(request); assertEquals(0, response1.getTrainingJobCount().intValue()); @@ -78,24 +76,17 @@ public void testNodeOperation() throws IOException, InterruptedException { TrainingJobRouteDecisionInfoNodeResponse response2 = action.nodeOperation(request); assertEquals(1, response2.getTrainingJobCount().intValue()); - IndexResponse indexResponse = new IndexResponse( - new ShardId(MODEL_INDEX_NAME, "uuid", 0), - "any-type", - modelId, - 0, - 0, - 0, - true - ); - ((ActionListener)invocationOnMock.getArguments()[1]).onResponse(indexResponse); + IndexResponse indexResponse = new IndexResponse(new ShardId(MODEL_INDEX_NAME, "uuid", 0), "any-type", modelId, 0, 0, 0, true); + ((ActionListener) invocationOnMock.getArguments()[1]).onResponse(indexResponse); return null; }).when(modelDao).put(any(Model.class), any(ActionListener.class)); // Set up the rest of the training logic final CountDownLatch inProgressLatch = new CountDownLatch(1); - ActionListener responseListener = ActionListener.wrap(indexResponse -> { - inProgressLatch.countDown(); - }, e -> fail("Failure should not have occurred")); + ActionListener responseListener = ActionListener.wrap( + indexResponse -> { inProgressLatch.countDown(); }, + e -> fail("Failure should not have occurred") + ); doAnswer(invocationOnMock -> { responseListener.onResponse(mock(IndexResponse.class)); diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java index 0c3206961..7b89488a0 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java @@ -43,24 +43,26 @@ public class TrainingJobRouterTransportActionTests extends KNNTestCase { public void testSingleNode_withCapacity() { // Mock datanodes in the cluster through mocking the cluster service - List nodeIds = ImmutableList.of( - "node-1" - ); + List nodeIds = ImmutableList.of("node-1"); ImmutableOpenMap discoveryNodesMap = generateDiscoveryNodes(nodeIds); ClusterService clusterService = generateMockedClusterService(discoveryNodesMap); // Create a response to be returned with job route decision info List responseList = new ArrayList<>(); - nodeIds.forEach(id -> responseList.add(new TrainingJobRouteDecisionInfoNodeResponse( - discoveryNodesMap.get(id), - 0 // node has capacity - ))); + nodeIds.forEach( + id -> responseList.add( + new TrainingJobRouteDecisionInfoNodeResponse( + discoveryNodesMap.get(id), + 0 // node has capacity + ) + ) + ); TrainingJobRouteDecisionInfoResponse infoResponse = new TrainingJobRouteDecisionInfoResponse( - ClusterName.DEFAULT, - responseList, - Collections.emptyList() + ClusterName.DEFAULT, + responseList, + Collections.emptyList() ); TransportService transportService = mock(TransportService.class); @@ -68,7 +70,11 @@ public void testSingleNode_withCapacity() { // Setup the action TrainingJobRouterTransportAction transportAction = new TrainingJobRouterTransportAction( - transportService, new ActionFilters(Collections.emptySet()), clusterService, client); + transportService, + new ActionFilters(Collections.emptySet()), + clusterService, + client + ); // Select the node DiscoveryNode selectedNode = transportAction.selectNode(null, infoResponse); @@ -77,24 +83,26 @@ public void testSingleNode_withCapacity() { public void testSingleNode_withoutCapacity() { // Mock datanodes in the cluster through mocking the cluster service - List nodeIds = ImmutableList.of( - "node-1" - ); + List nodeIds = ImmutableList.of("node-1"); ImmutableOpenMap discoveryNodesMap = generateDiscoveryNodes(nodeIds); ClusterService clusterService = generateMockedClusterService(discoveryNodesMap); // Create a response to be returned with job route decision info List responseList = new ArrayList<>(); - nodeIds.forEach(id -> responseList.add(new TrainingJobRouteDecisionInfoNodeResponse( - discoveryNodesMap.get(id), - 1 // node has no capacity - ))); + nodeIds.forEach( + id -> responseList.add( + new TrainingJobRouteDecisionInfoNodeResponse( + discoveryNodesMap.get(id), + 1 // node has no capacity + ) + ) + ); TrainingJobRouteDecisionInfoResponse infoResponse = new TrainingJobRouteDecisionInfoResponse( - ClusterName.DEFAULT, - responseList, - Collections.emptyList() + ClusterName.DEFAULT, + responseList, + Collections.emptyList() ); TransportService transportService = mock(TransportService.class); @@ -102,7 +110,11 @@ public void testSingleNode_withoutCapacity() { // Setup the action TrainingJobRouterTransportAction transportAction = new TrainingJobRouterTransportAction( - transportService, new ActionFilters(Collections.emptySet()), clusterService, client); + transportService, + new ActionFilters(Collections.emptySet()), + clusterService, + client + ); // Select the node DiscoveryNode selectedNode = transportAction.selectNode(null, infoResponse); @@ -111,11 +123,7 @@ public void testSingleNode_withoutCapacity() { public void testMultiNode_withCapacity() { // Mock datanodes in the cluster through mocking the cluster service - List nodeIds = ImmutableList.of( - "node-1", - "node-2", - "node-3" - ); + List nodeIds = ImmutableList.of("node-1", "node-2", "node-3"); ImmutableOpenMap discoveryNodesMap = generateDiscoveryNodes(nodeIds); ClusterService clusterService = generateMockedClusterService(discoveryNodesMap); @@ -133,9 +141,9 @@ public void testMultiNode_withCapacity() { responseList.add(new TrainingJobRouteDecisionInfoNodeResponse(discoveryNodesMap.get(nodeIds.get(1)), 1)); TrainingJobRouteDecisionInfoResponse infoResponse = new TrainingJobRouteDecisionInfoResponse( - ClusterName.DEFAULT, - responseList, - Collections.emptyList() + ClusterName.DEFAULT, + responseList, + Collections.emptyList() ); TransportService transportService = mock(TransportService.class); @@ -143,7 +151,11 @@ public void testMultiNode_withCapacity() { // Setup the action TrainingJobRouterTransportAction transportAction = new TrainingJobRouterTransportAction( - transportService, new ActionFilters(Collections.emptySet()), clusterService, client); + transportService, + new ActionFilters(Collections.emptySet()), + clusterService, + client + ); // Select the node DiscoveryNode selectedNode = transportAction.selectNode(null, infoResponse); @@ -152,11 +164,7 @@ public void testMultiNode_withCapacity() { public void testMultiNode_withCapacity_withPreferredAvailable() { // Mock datanodes in the cluster through mocking the cluster service - List nodeIds = ImmutableList.of( - "node-1", - "node-2", - "node-3" - ); + List nodeIds = ImmutableList.of("node-1", "node-2", "node-3"); String preferredNode = nodeIds.get(2); @@ -176,9 +184,9 @@ public void testMultiNode_withCapacity_withPreferredAvailable() { responseList.add(new TrainingJobRouteDecisionInfoNodeResponse(discoveryNodesMap.get(nodeIds.get(2)), 0)); TrainingJobRouteDecisionInfoResponse infoResponse = new TrainingJobRouteDecisionInfoResponse( - ClusterName.DEFAULT, - responseList, - Collections.emptyList() + ClusterName.DEFAULT, + responseList, + Collections.emptyList() ); TransportService transportService = mock(TransportService.class); @@ -186,7 +194,11 @@ public void testMultiNode_withCapacity_withPreferredAvailable() { // Setup the action TrainingJobRouterTransportAction transportAction = new TrainingJobRouterTransportAction( - transportService, new ActionFilters(Collections.emptySet()), clusterService, client); + transportService, + new ActionFilters(Collections.emptySet()), + clusterService, + client + ); // Select the node DiscoveryNode selectedNode = transportAction.selectNode(preferredNode, infoResponse); @@ -195,11 +207,7 @@ public void testMultiNode_withCapacity_withPreferredAvailable() { public void testMultiNode_withCapacity_withoutPreferredAvailable() { // Mock datanodes in the cluster through mocking the cluster service - List nodeIds = ImmutableList.of( - "node-1", - "node-2", - "node-3" - ); + List nodeIds = ImmutableList.of("node-1", "node-2", "node-3"); String preferredNode = nodeIds.get(2); @@ -219,9 +227,9 @@ public void testMultiNode_withCapacity_withoutPreferredAvailable() { responseList.add(new TrainingJobRouteDecisionInfoNodeResponse(discoveryNodesMap.get(nodeIds.get(1)), 1)); TrainingJobRouteDecisionInfoResponse infoResponse = new TrainingJobRouteDecisionInfoResponse( - ClusterName.DEFAULT, - responseList, - Collections.emptyList() + ClusterName.DEFAULT, + responseList, + Collections.emptyList() ); TransportService transportService = mock(TransportService.class); @@ -229,7 +237,11 @@ public void testMultiNode_withCapacity_withoutPreferredAvailable() { // Setup the action TrainingJobRouterTransportAction transportAction = new TrainingJobRouterTransportAction( - transportService, new ActionFilters(Collections.emptySet()), clusterService, client); + transportService, + new ActionFilters(Collections.emptySet()), + clusterService, + client + ); // Select the node DiscoveryNode selectedNode = transportAction.selectNode(preferredNode, infoResponse); @@ -239,11 +251,7 @@ public void testMultiNode_withCapacity_withoutPreferredAvailable() { public void testMultiNode_withoutCapacity() { // Mock datanodes in the cluster through mocking the cluster service - List nodeIds = ImmutableList.of( - "node-1", - "node-2", - "node-3" - ); + List nodeIds = ImmutableList.of("node-1", "node-2", "node-3"); ImmutableOpenMap discoveryNodesMap = generateDiscoveryNodes(nodeIds); ClusterService clusterService = generateMockedClusterService(discoveryNodesMap); @@ -261,9 +269,9 @@ public void testMultiNode_withoutCapacity() { responseList.add(new TrainingJobRouteDecisionInfoNodeResponse(discoveryNodesMap.get(nodeIds.get(1)), 1)); TrainingJobRouteDecisionInfoResponse infoResponse = new TrainingJobRouteDecisionInfoResponse( - ClusterName.DEFAULT, - responseList, - Collections.emptyList() + ClusterName.DEFAULT, + responseList, + Collections.emptyList() ); TransportService transportService = mock(TransportService.class); @@ -271,7 +279,11 @@ public void testMultiNode_withoutCapacity() { // Setup the action TrainingJobRouterTransportAction transportAction = new TrainingJobRouterTransportAction( - transportService, new ActionFilters(Collections.emptySet()), clusterService, client); + transportService, + new ActionFilters(Collections.emptySet()), + clusterService, + client + ); // Select the node DiscoveryNode selectedNode = transportAction.selectNode(null, infoResponse); @@ -284,17 +296,17 @@ public void testTrainingIndexSize() { String trainingIndexName = "training-index"; int dimension = 133; int vectorCount = 1000000; - int expectedSize = dimension * vectorCount * Float.BYTES / BYTES_PER_KILOBYTES + 1; // 519,531.25 KB ~= 520 MB + int expectedSize = dimension * vectorCount * Float.BYTES / BYTES_PER_KILOBYTES + 1; // 519,531.25 KB ~= 520 MB // Setup the request TrainingModelRequest trainingModelRequest = new TrainingModelRequest( - null, - KNNMethodContext.getDefault(), - dimension, - trainingIndexName, - "training-field", - null, - "description" + null, + KNNMethodContext.getDefault(), + dimension, + trainingIndexName, + "training-field", + null, + "description" ); // Mock client to return the right number of docs @@ -312,11 +324,15 @@ public void testTrainingIndexSize() { ClusterService clusterService = mock(ClusterService.class); TransportService transportService = mock(TransportService.class); TrainingJobRouterTransportAction transportAction = new TrainingJobRouterTransportAction( - transportService, new ActionFilters(Collections.emptySet()), clusterService, client); + transportService, + new ActionFilters(Collections.emptySet()), + clusterService, + client + ); ActionListener listener = ActionListener.wrap( - size -> assertEquals(expectedSize, size.intValue()), - e -> fail(e.getMessage()) + size -> assertEquals(expectedSize, size.intValue()), + e -> fail(e.getMessage()) ); transportAction.getTrainingIndexSizeInKB(trainingModelRequest, listener); diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java index 105b2a189..06b6f3d01 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java @@ -40,7 +40,6 @@ import java.util.List; import java.util.Map; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -56,13 +55,13 @@ public void testStreams() throws IOException { String description = "some test description"; TrainingModelRequest original1 = new TrainingModelRequest( - modelId, - knnMethodContext, - dimension, - trainingIndex, - trainingField, - preferredNode, - description + modelId, + knnMethodContext, + dimension, + trainingIndex, + trainingField, + preferredNode, + description ); BytesStreamOutput streamOutput = new BytesStreamOutput(); @@ -78,13 +77,13 @@ public void testStreams() throws IOException { // Also, check when preferred node and model id and description are null TrainingModelRequest original2 = new TrainingModelRequest( - null, - knnMethodContext, - dimension, - trainingIndex, - trainingField, - null, - null + null, + knnMethodContext, + dimension, + trainingIndex, + trainingField, + null, + null ); streamOutput = new BytesStreamOutput(); @@ -112,13 +111,13 @@ public void testGetters() { int trainingSetSizeInKB = 102; TrainingModelRequest trainingModelRequest = new TrainingModelRequest( - modelId, - knnMethodContext, - dimension, - trainingIndex, - trainingField, - preferredNode, - description + modelId, + knnMethodContext, + dimension, + trainingIndex, + trainingField, + preferredNode, + description ); trainingModelRequest.setMaximumVectorCount(maxVectorCount); @@ -151,19 +150,26 @@ public void testValidation_invalid_modelIdAlreadyExists() { String trainingField = "test-training-field"; TrainingModelRequest trainingModelRequest = new TrainingModelRequest( - modelId, - knnMethodContext, - dimension, - trainingIndex, - trainingField, - null, - null + modelId, + knnMethodContext, + dimension, + trainingIndex, + trainingField, + null, + null ); // Mock the model dao to return metadata for modelId to recognize it is a duplicate ModelDao modelDao = mock(ModelDao.class); - ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 128, - ModelState.CREATED, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""); + ModelMetadata modelMetadata = new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.L2, + 128, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ); when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); // This cluster service will result in no validation exceptions @@ -199,13 +205,13 @@ public void testValidation_invalid_invalidMethodContext() { String trainingField = "test-training-field"; TrainingModelRequest trainingModelRequest = new TrainingModelRequest( - modelId, - knnMethodContext, - dimension, - trainingIndex, - trainingField, - null, - null + modelId, + knnMethodContext, + dimension, + trainingIndex, + trainingField, + null, + null ); // Mock the model dao to return null so that no exception is produced @@ -242,13 +248,13 @@ public void testValidation_invalid_trainingIndexDoesNotExist() { String trainingField = "test-training-field"; TrainingModelRequest trainingModelRequest = new TrainingModelRequest( - modelId, - knnMethodContext, - dimension, - trainingIndex, - trainingField, - null, - null + modelId, + knnMethodContext, + dimension, + trainingIndex, + trainingField, + null, + null ); // Mock the model dao to return null so that no exception is produced @@ -288,13 +294,13 @@ public void testValidation_invalid_trainingFieldDoesNotExist() { String trainingField = "test-training-field"; TrainingModelRequest trainingModelRequest = new TrainingModelRequest( - modelId, - knnMethodContext, - dimension, - trainingIndex, - trainingField, - null, - null + modelId, + knnMethodContext, + dimension, + trainingIndex, + trainingField, + null, + null ); // Mock the model dao to return null so that no exception is produced @@ -339,27 +345,23 @@ public void testValidation_invalid_trainingFieldNotKnnVector() { String trainingField = "test-training-field"; TrainingModelRequest trainingModelRequest = new TrainingModelRequest( - modelId, - knnMethodContext, - dimension, - trainingIndex, - trainingField, - null, - null + modelId, + knnMethodContext, + dimension, + trainingIndex, + trainingField, + null, + null ); // Mock the model dao to return null so that no exception is produced ModelDao modelDao = mock(ModelDao.class); when(modelDao.getMetadata(modelId)).thenReturn(null); - // Return mapping with different type + // Return mapping with different type Map mappingMap = ImmutableMap.of( - "properties", ImmutableMap.of( - trainingField, ImmutableMap.of( - "type", "int", - KNNConstants.DIMENSION, dimension - ) - ) + "properties", + ImmutableMap.of(trainingField, ImmutableMap.of("type", "int", KNNConstants.DIMENSION, dimension)) ); MappingMetadata mappingMetadata = mock(MappingMetadata.class); when(mappingMetadata.getSourceAsMap()).thenReturn(mappingMap); @@ -398,27 +400,26 @@ public void testValidation_invalid_dimensionDoesNotMatch() { String trainingField = "test-training-field"; TrainingModelRequest trainingModelRequest = new TrainingModelRequest( - modelId, - knnMethodContext, - dimension, - trainingIndex, - trainingField, - null, - null + modelId, + knnMethodContext, + dimension, + trainingIndex, + trainingField, + null, + null ); // Mock the model dao to return null so that no exception is produced ModelDao modelDao = mock(ModelDao.class); when(modelDao.getMetadata(modelId)).thenReturn(null); - // Return mapping with different dimension + // Return mapping with different dimension Map mappingMap = ImmutableMap.of( - "properties", ImmutableMap.of( - trainingField, ImmutableMap.of( - "type", KNNVectorFieldMapper.CONTENT_TYPE, - KNNConstants.DIMENSION, dimension + 1 - ) - ) + "properties", + ImmutableMap.of( + trainingField, + ImmutableMap.of("type", KNNVectorFieldMapper.CONTENT_TYPE, KNNConstants.DIMENSION, dimension + 1) + ) ); MappingMetadata mappingMetadata = mock(MappingMetadata.class); when(mappingMetadata.getSourceAsMap()).thenReturn(mappingMap); @@ -456,13 +457,13 @@ public void testValidation_invalid_preferredNodeDoesNotExist() { String preferredNode = "preferred-node"; TrainingModelRequest trainingModelRequest = new TrainingModelRequest( - modelId, - knnMethodContext, - dimension, - trainingIndex, - trainingField, - preferredNode, - null + modelId, + knnMethodContext, + dimension, + trainingIndex, + trainingField, + preferredNode, + null ); // Mock the model dao to return metadata for modelId to recognize it is a duplicate @@ -471,12 +472,8 @@ public void testValidation_invalid_preferredNodeDoesNotExist() { // This cluster service mocking should not produce exception Map mappingMap = ImmutableMap.of( - "properties", ImmutableMap.of( - trainingField, ImmutableMap.of( - "type", KNNVectorFieldMapper.CONTENT_TYPE, - KNNConstants.DIMENSION, dimension - ) - ) + "properties", + ImmutableMap.of(trainingField, ImmutableMap.of("type", KNNVectorFieldMapper.CONTENT_TYPE, KNNConstants.DIMENSION, dimension)) ); MappingMetadata mappingMetadata = mock(MappingMetadata.class); @@ -525,13 +522,13 @@ public void testValidation_invalid_descriptionToLong() { String description = new String(chars); TrainingModelRequest trainingModelRequest = new TrainingModelRequest( - modelId, - knnMethodContext, - dimension, - trainingIndex, - trainingField, - null, - description + modelId, + knnMethodContext, + dimension, + trainingIndex, + trainingField, + null, + description ); // Mock the model dao to return metadata for modelId to recognize it is a duplicate @@ -569,13 +566,13 @@ public void testValidation_valid_trainingIndexBuiltFromMethod() { String trainingField = "test-training-field"; TrainingModelRequest trainingModelRequest = new TrainingModelRequest( - modelId, - knnMethodContext, - dimension, - trainingIndex, - trainingField, - null, - null + modelId, + knnMethodContext, + dimension, + trainingIndex, + trainingField, + null, + null ); // Mock the model dao to return metadata for modelId to recognize it is a duplicate @@ -606,13 +603,13 @@ public void testValidation_valid_trainingIndexBuiltFromModel() { String trainingFieldModeId = "training-field-model-id"; TrainingModelRequest trainingModelRequest = new TrainingModelRequest( - modelId, - knnMethodContext, - dimension, - trainingIndex, - trainingField, - null, - null + modelId, + knnMethodContext, + dimension, + trainingIndex, + trainingField, + null, + null ); // Mock the model dao to return metadata for modelId to recognize it is a duplicate @@ -622,15 +619,14 @@ public void testValidation_valid_trainingIndexBuiltFromModel() { ModelDao modelDao = mock(ModelDao.class); when(modelDao.getMetadata(modelId)).thenReturn(null); when(modelDao.getMetadata(trainingFieldModeId)).thenReturn(trainingFieldModelMetadata); - + // Return model id instead of dimension directly Map mappingMap = ImmutableMap.of( - "properties", ImmutableMap.of( - trainingField, ImmutableMap.of( - "type", KNNVectorFieldMapper.CONTENT_TYPE, - KNNConstants.MODEL_ID, trainingFieldModeId - ) - ) + "properties", + ImmutableMap.of( + trainingField, + ImmutableMap.of("type", KNNVectorFieldMapper.CONTENT_TYPE, KNNConstants.MODEL_ID, trainingFieldModeId) + ) ); MappingMetadata mappingMetadata = mock(MappingMetadata.class); when(mappingMetadata.getSourceAsMap()).thenReturn(mappingMap); @@ -656,7 +652,6 @@ public void testValidation_valid_trainingIndexBuiltFromModel() { assertNull(exception); } - /** * This method produces a cluster service that will mock so that there are no validation exceptions. * @@ -667,12 +662,8 @@ public void testValidation_valid_trainingIndexBuiltFromModel() { */ private ClusterService getClusterServiceForValidReturns(String trainingIndex, String trainingField, int dimension) { Map mappingMap = ImmutableMap.of( - "properties", ImmutableMap.of( - trainingField, ImmutableMap.of( - "type", KNNVectorFieldMapper.CONTENT_TYPE, - KNNConstants.DIMENSION, dimension - ) - ) + "properties", + ImmutableMap.of(trainingField, ImmutableMap.of("type", KNNVectorFieldMapper.CONTENT_TYPE, KNNConstants.DIMENSION, dimension)) ); MappingMetadata mappingMetadata = mock(MappingMetadata.class); when(mappingMetadata.getSourceAsMap()).thenReturn(mappingMap); diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelResponseTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelResponseTests.java index e82ac70bf..2fef2635a 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelResponseTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelResponseTests.java @@ -56,11 +56,12 @@ public void testToXContent() throws IOException { // We expect this: // { - // "model_id": "test-model-id" + // "model_id": "test-model-id" // } - XContentBuilder expectedXContentBuilder = XContentFactory.jsonBuilder().startObject() - .field(KNNConstants.MODEL_ID, modelId) - .endObject(); + XContentBuilder expectedXContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(KNNConstants.MODEL_ID, modelId) + .endObject(); Map expected = xContentBuilderToMap(expectedXContentBuilder); // Check responses are equal diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelTransportActionTests.java index b03dc1873..90f26aa59 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelTransportActionTests.java @@ -48,39 +48,41 @@ public void testDoExecute() throws InterruptedException, ExecutionException, IOE for (int i = 0; i < trainingDataCount; i++) { Float[] vector = new Float[dimension]; Arrays.fill(vector, Float.intBitsToFloat(i)); - addKnnDoc(trainingIndexName, Integer.toString(i+1), trainingFieldName, vector); + addKnnDoc(trainingIndexName, Integer.toString(i + 1), trainingFieldName, vector); } // Create train model request String modelId = "test-model-id"; - XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject() - .field(NAME, METHOD_IVF) - .field(KNN_ENGINE, KNNEngine.FAISS.getName()) - .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.INNER_PRODUCT.getValue()) - .startObject(PARAMETERS) - .field(METHOD_PARAMETER_NLIST, 4) - .endObject() - .endObject(); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, METHOD_IVF) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.INNER_PRODUCT.getValue()) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_NLIST, 4) + .endObject() + .endObject(); Map in = xContentBuilderToMap(xContentBuilder); KNNMethodContext knnMethodContext = KNNMethodContext.parse(in); TrainingModelRequest trainingModelRequest = new TrainingModelRequest( - modelId, - knnMethodContext, - dimension, - trainingIndexName, - trainingFieldName, - null, - "test-detector" + modelId, + knnMethodContext, + dimension, + trainingIndexName, + trainingFieldName, + null, + "test-detector" ); trainingModelRequest.setTrainingDataSizeInKB(estimateVectorSetSizeInKB(trainingDataCount, dimension)); // Create listener that ensures that the initial model put succeeds - ActionListener listener = ActionListener.wrap(response -> - assertEquals(modelId, response.getModelId()), e -> fail("Failure: " + e.getMessage())); + ActionListener listener = ActionListener.wrap( + response -> assertEquals(modelId, response.getModelId()), + e -> fail("Failure: " + e.getMessage()) + ); - TrainingModelTransportAction trainingModelTransportAction = node().injector() - .getInstance(TrainingModelTransportAction.class); + TrainingModelTransportAction trainingModelTransportAction = node().injector().getInstance(TrainingModelTransportAction.class); trainingModelTransportAction.doExecute(null, trainingModelRequest, listener); diff --git a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequestTests.java b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequestTests.java index 92f32ab70..f38273f15 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequestTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequestTests.java @@ -32,8 +32,15 @@ public void testStreams() throws IOException { String modelId = "test-model"; boolean isRemoveRequest = false; - ModelMetadata modelMetadata = new ModelMetadata(knnEngine, spaceType, dimension, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""); + ModelMetadata modelMetadata = new ModelMetadata( + knnEngine, + spaceType, + dimension, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ); UpdateModelMetadataRequest updateModelMetadataRequest = new UpdateModelMetadataRequest(modelId, isRemoveRequest, modelMetadata); BytesStreamOutput streamOutput = new BytesStreamOutput(); @@ -48,8 +55,15 @@ public void testStreams() throws IOException { public void testValidate() { - ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 128, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""); + ModelMetadata modelMetadata = new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.L2, + 128, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ); UpdateModelMetadataRequest updateModelMetadataRequest1 = new UpdateModelMetadataRequest("test", true, null); assertNull(updateModelMetadataRequest1.validate()); @@ -79,8 +93,15 @@ public void testIsRemoveRequest() { } public void testGetModelMetadata() { - ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 128, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""); + ModelMetadata modelMetadata = new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.L2, + 128, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ); UpdateModelMetadataRequest updateModelMetadataRequest = new UpdateModelMetadataRequest("test", true, modelMetadata); assertEquals(modelMetadata, updateModelMetadataRequest.getModelMetadata()); diff --git a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java index 932a3265a..c98745695 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java @@ -37,18 +37,17 @@ public class UpdateModelMetadataTransportActionTests extends KNNSingleNodeTestCa public void testExecutor() { UpdateModelMetadataTransportAction updateModelMetadataTransportAction = node().injector() - .getInstance(UpdateModelMetadataTransportAction.class); + .getInstance(UpdateModelMetadataTransportAction.class); assertEquals(ThreadPool.Names.SAME, updateModelMetadataTransportAction.executor()); } public void testRead() throws IOException { UpdateModelMetadataTransportAction updateModelMetadataTransportAction = node().injector() - .getInstance(UpdateModelMetadataTransportAction.class); + .getInstance(UpdateModelMetadataTransportAction.class); AcknowledgedResponse acknowledgedResponse = new AcknowledgedResponse(true); BytesStreamOutput streamOutput = new BytesStreamOutput(); acknowledgedResponse.writeTo(streamOutput); - AcknowledgedResponse acknowledgedResponse1 = updateModelMetadataTransportAction.read(streamOutput.bytes() - .streamInput()); + AcknowledgedResponse acknowledgedResponse1 = updateModelMetadataTransportAction.read(streamOutput.bytes().streamInput()); assertEquals(acknowledgedResponse, acknowledgedResponse1); } @@ -59,77 +58,82 @@ public void testMasterOperation() throws InterruptedException { // Setup the model String modelId = "test-model"; - ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 128, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""); - - // Get update transport action + ModelMetadata modelMetadata = new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.L2, + 128, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ); + + // Get update transport action UpdateModelMetadataTransportAction updateModelMetadataTransportAction = node().injector() - .getInstance(UpdateModelMetadataTransportAction.class); + .getInstance(UpdateModelMetadataTransportAction.class); // Generate update request - UpdateModelMetadataRequest updateModelMetadataRequest = new UpdateModelMetadataRequest(modelId, false, - modelMetadata); + UpdateModelMetadataRequest updateModelMetadataRequest = new UpdateModelMetadataRequest(modelId, false, modelMetadata); // Get cluster state, update metadata, check cluster state - all asynchronously final CountDownLatch inProgressLatch1 = new CountDownLatch(1); client().admin().cluster().prepareState().execute(ActionListener.wrap(stateResponse1 -> { ClusterState clusterState1 = stateResponse1.getState(); updateModelMetadataTransportAction.masterOperation( - updateModelMetadataRequest, - clusterState1, - ActionListener.wrap(acknowledgedResponse -> { - assertTrue(acknowledgedResponse.isAcknowledged()); + updateModelMetadataRequest, + clusterState1, + ActionListener.wrap(acknowledgedResponse -> { + assertTrue(acknowledgedResponse.isAcknowledged()); - client().admin().cluster().prepareState().execute(ActionListener.wrap(stateResponse2 -> { - ClusterState updatedClusterState = stateResponse2.getState(); - IndexMetadata indexMetadata = updatedClusterState.metadata().index(MODEL_INDEX_NAME); - assertNotNull(indexMetadata); + client().admin().cluster().prepareState().execute(ActionListener.wrap(stateResponse2 -> { + ClusterState updatedClusterState = stateResponse2.getState(); + IndexMetadata indexMetadata = updatedClusterState.metadata().index(MODEL_INDEX_NAME); + assertNotNull(indexMetadata); - Map modelMetadataMap = indexMetadata.getCustomData(MODEL_METADATA_FIELD); - assertNotNull(modelMetadataMap); + Map modelMetadataMap = indexMetadata.getCustomData(MODEL_METADATA_FIELD); + assertNotNull(modelMetadataMap); - String modelAsString = modelMetadataMap.get(modelId); - assertNotNull(modelAsString); + String modelAsString = modelMetadataMap.get(modelId); + assertNotNull(modelAsString); - ModelMetadata modelMetadataCopy = ModelMetadata.fromString(modelAsString); - assertEquals(modelMetadata, modelMetadataCopy); + ModelMetadata modelMetadataCopy = ModelMetadata.fromString(modelAsString); + assertEquals(modelMetadata, modelMetadataCopy); - inProgressLatch1.countDown(); + inProgressLatch1.countDown(); - }, e -> fail("Update failed:" + e))); - }, e -> fail("Update failed: " + e)) + }, e -> fail("Update failed:" + e))); + }, e -> fail("Update failed: " + e)) ); - }, e -> fail("Update failed: " + e))); + }, e -> fail("Update failed: " + e))); assertTrue(inProgressLatch1.await(60, TimeUnit.SECONDS)); // Generate remove request - UpdateModelMetadataRequest removeModelMetadataRequest = new UpdateModelMetadataRequest(modelId, true, - modelMetadata); + UpdateModelMetadataRequest removeModelMetadataRequest = new UpdateModelMetadataRequest(modelId, true, modelMetadata); final CountDownLatch inProgressLatch2 = new CountDownLatch(1); client().admin().cluster().prepareState().execute(ActionListener.wrap(stateResponse1 -> { ClusterState clusterState1 = stateResponse1.getState(); updateModelMetadataTransportAction.masterOperation( - removeModelMetadataRequest, - clusterState1, - ActionListener.wrap(acknowledgedResponse -> { - assertTrue(acknowledgedResponse.isAcknowledged()); + removeModelMetadataRequest, + clusterState1, + ActionListener.wrap(acknowledgedResponse -> { + assertTrue(acknowledgedResponse.isAcknowledged()); - client().admin().cluster().prepareState().execute(ActionListener.wrap(stateResponse2 -> { - ClusterState updatedClusterState = stateResponse2.getState(); - IndexMetadata indexMetadata = updatedClusterState.metadata().index(MODEL_INDEX_NAME); - assertNotNull(indexMetadata); + client().admin().cluster().prepareState().execute(ActionListener.wrap(stateResponse2 -> { + ClusterState updatedClusterState = stateResponse2.getState(); + IndexMetadata indexMetadata = updatedClusterState.metadata().index(MODEL_INDEX_NAME); + assertNotNull(indexMetadata); - Map modelMetadataMap = indexMetadata.getCustomData(MODEL_METADATA_FIELD); - assertNotNull(modelMetadataMap); + Map modelMetadataMap = indexMetadata.getCustomData(MODEL_METADATA_FIELD); + assertNotNull(modelMetadataMap); - String modelAsString = modelMetadataMap.get(modelId); - assertNull(modelAsString); + String modelAsString = modelMetadataMap.get(modelId); + assertNull(modelAsString); - inProgressLatch2.countDown(); - }, e -> fail("Update failed"))); - }, e -> fail("Update failed")) + inProgressLatch2.countDown(); + }, e -> fail("Update failed"))); + }, e -> fail("Update failed")) ); }, e -> fail("Update failed"))); @@ -138,7 +142,7 @@ public void testMasterOperation() throws InterruptedException { public void testCheckBlock() { UpdateModelMetadataTransportAction updateModelMetadataTransportAction = node().injector() - .getInstance(UpdateModelMetadataTransportAction.class); + .getInstance(UpdateModelMetadataTransportAction.class); assertNull(updateModelMetadataTransportAction.checkBlock(null, null)); } } diff --git a/src/test/java/org/opensearch/knn/recall/RecallTestsIT.java b/src/test/java/org/opensearch/knn/recall/RecallTestsIT.java index 678a9d2dc..889f3916f 100644 --- a/src/test/java/org/opensearch/knn/recall/RecallTestsIT.java +++ b/src/test/java/org/opensearch/knn/recall/RecallTestsIT.java @@ -11,15 +11,9 @@ package org.opensearch.knn.recall; -import org.apache.http.util.EntityUtils; -import org.opensearch.client.Response; import org.opensearch.knn.KNNRestTestCase; -import org.opensearch.knn.KNNResult; import org.opensearch.knn.TestUtils; -import org.opensearch.knn.index.KNNQueryBuilder; import org.opensearch.knn.index.SpaceType; -import java.io.IOException; -import java.util.ArrayList; import java.util.List; import java.util.Set; import static org.opensearch.knn.index.KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY; diff --git a/src/test/java/org/opensearch/knn/training/TrainingDataConsumerTests.java b/src/test/java/org/opensearch/knn/training/TrainingDataConsumerTests.java index 78f061969..d5a66c5b6 100644 --- a/src/test/java/org/opensearch/knn/training/TrainingDataConsumerTests.java +++ b/src/test/java/org/opensearch/knn/training/TrainingDataConsumerTests.java @@ -29,7 +29,10 @@ public void testAccept() { // Mock the training data allocation int dimension = 128; - NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation = mock(NativeMemoryAllocation.TrainingDataAllocation.class); // new NativeMemoryAllocation.TrainingDataAllocation(0, numVectors*dimension* Float.BYTES); + NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation = mock(NativeMemoryAllocation.TrainingDataAllocation.class); // new + // NativeMemoryAllocation.TrainingDataAllocation(0, + // numVectors*dimension* + // Float.BYTES); when(trainingDataAllocation.getMemoryAddress()).thenReturn(0L); // Capture argument passed to set pointer diff --git a/src/test/java/org/opensearch/knn/training/TrainingJobRunnerTests.java b/src/test/java/org/opensearch/knn/training/TrainingJobRunnerTests.java index 41d1ba03d..f1fd306a9 100644 --- a/src/test/java/org/opensearch/knn/training/TrainingJobRunnerTests.java +++ b/src/test/java/org/opensearch/knn/training/TrainingJobRunnerTests.java @@ -26,7 +26,6 @@ import java.util.concurrent.TimeUnit; import static org.mockito.Matchers.any; -import static org.mockito.Matchers.anyString; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -56,30 +55,23 @@ public void testExecute_success() throws IOException, InterruptedException { // This gets called right after the initial put, before training begins. Just check that the model id is // equal - ActionListener responseListener = ActionListener.wrap(indexResponse -> - assertEquals(modelId, indexResponse.getId()), e -> fail("Failure should not have occurred")); + ActionListener responseListener = ActionListener.wrap( + indexResponse -> assertEquals(modelId, indexResponse.getId()), + e -> fail("Failure should not have occurred") + ); // After put finishes, it should call the onResponse function that will call responseListener and then kickoff // training. ModelDao modelDao = mock(ModelDao.class); doAnswer(invocationOnMock -> { assertEquals(1, trainingJobRunner.getJobCount()); // Make sure job count is correct - IndexResponse indexResponse = new IndexResponse( - new ShardId(MODEL_INDEX_NAME, "uuid", 0), - "any-type", - modelId, - 0, - 0, - 0, - true - ); - ((ActionListener)invocationOnMock.getArguments()[1]).onResponse(indexResponse); + IndexResponse indexResponse = new IndexResponse(new ShardId(MODEL_INDEX_NAME, "uuid", 0), "any-type", modelId, 0, 0, 0, true); + ((ActionListener) invocationOnMock.getArguments()[1]).onResponse(indexResponse); return null; }).when(modelDao).put(any(Model.class), any(ActionListener.class)); // Function finishes when update is called - doAnswer(invocationOnMock -> null) - .when(modelDao).update(any(Model.class), any(ActionListener.class)); + doAnswer(invocationOnMock -> null).when(modelDao).update(any(Model.class), any(ActionListener.class)); // Finally, initialize the singleton runner, execute the job. TrainingJobRunner.initialize(threadPool, modelDao); @@ -116,33 +108,28 @@ public void testExecute_failure_rejected() throws IOException, InterruptedExcept // This gets called right after the initial put, before training begins. Just check that the model id is // equal ActionListener responseListener = ActionListener.wrap( - indexResponse -> assertEquals(modelId, indexResponse.getId()), - e -> fail("Should not reach this state") + indexResponse -> assertEquals(modelId, indexResponse.getId()), + e -> fail("Should not reach this state") ); // After put finishes, it should call the onResponse function that will call responseListener and then kickoff // training. ModelDao modelDao = mock(ModelDao.class); doAnswer(invocationOnMock -> { - IndexResponse indexResponse = new IndexResponse( - new ShardId(MODEL_INDEX_NAME, "uuid", 0), - "any-type", - modelId, - 0, - 0, - 0, - true - ); - ((ActionListener)invocationOnMock.getArguments()[1]).onResponse(indexResponse); + IndexResponse indexResponse = new IndexResponse(new ShardId(MODEL_INDEX_NAME, "uuid", 0), "any-type", modelId, 0, 0, 0, true); + ((ActionListener) invocationOnMock.getArguments()[1]).onResponse(indexResponse); return null; }).when(modelDao).put(any(Model.class), any(ActionListener.class)); // Once update is called, try to start another training job. This should fail because the calling thread // is running training TrainingJobRunner trainingJobRunner = TrainingJobRunner.getInstance(); - doAnswer(invocationOnMock -> expectThrows(RejectedExecutionException.class, - () -> trainingJobRunner.execute(trainingJob, responseListener))).when(modelDao) - .update(model, responseListener); + doAnswer( + invocationOnMock -> expectThrows( + RejectedExecutionException.class, + () -> trainingJobRunner.execute(trainingJob, responseListener) + ) + ).when(modelDao).update(model, responseListener); // Finally, initialize the singleton runner, execute the job. TrainingJobRunner.initialize(threadPool, modelDao); diff --git a/src/test/java/org/opensearch/knn/training/TrainingJobTests.java b/src/test/java/org/opensearch/knn/training/TrainingJobTests.java index f20e51cf0..a4a9bda98 100644 --- a/src/test/java/org/opensearch/knn/training/TrainingJobTests.java +++ b/src/test/java/org/opensearch/knn/training/TrainingJobTests.java @@ -46,13 +46,13 @@ public void testGetModelId() { when(knnMethodContext.getSpaceType()).thenReturn(SpaceType.DEFAULT); TrainingJob trainingJob = new TrainingJob( - modelId, - knnMethodContext, - mock(NativeMemoryCacheManager.class), - mock(NativeMemoryEntryContext.TrainingDataEntryContext.class), - mock(NativeMemoryEntryContext.AnonymousEntryContext.class), - 10, - "" + modelId, + knnMethodContext, + mock(NativeMemoryCacheManager.class), + mock(NativeMemoryEntryContext.TrainingDataEntryContext.class), + mock(NativeMemoryEntryContext.AnonymousEntryContext.class), + 10, + "" ); assertEquals(modelId, trainingJob.getModelId()); @@ -71,27 +71,27 @@ public void testGetModel() { String modelID = "test-model-id"; TrainingJob trainingJob = new TrainingJob( - modelID, - knnMethodContext, - mock(NativeMemoryCacheManager.class), - mock(NativeMemoryEntryContext.TrainingDataEntryContext.class), - mock(NativeMemoryEntryContext.AnonymousEntryContext.class), - dimension, - desciption + modelID, + knnMethodContext, + mock(NativeMemoryCacheManager.class), + mock(NativeMemoryEntryContext.TrainingDataEntryContext.class), + mock(NativeMemoryEntryContext.AnonymousEntryContext.class), + dimension, + desciption ); Model model = new Model( - new ModelMetadata( - knnEngine, - spaceType, - dimension, - ModelState.TRAINING, - trainingJob.getModel().getModelMetadata().getTimestamp(), - desciption, - error - ), - null, - modelID + new ModelMetadata( + knnEngine, + spaceType, + dimension, + ModelState.TRAINING, + trainingJob.getModel().getModelMetadata().getTimestamp(), + desciption, + error + ), + null, + modelID ); assertEquals(model, trainingJob.getModel()); @@ -105,8 +105,11 @@ public void testRun_success() throws IOException, ExecutionException { int nlists = 5; int dimension = 16; KNNEngine knnEngine = KNNEngine.FAISS; - KNNMethodContext knnMethodContext = new KNNMethodContext(knnEngine, SpaceType.INNER_PRODUCT, - new MethodComponentContext(METHOD_IVF, ImmutableMap.of(METHOD_PARAMETER_NLIST, nlists))); + KNNMethodContext knnMethodContext = new KNNMethodContext( + knnEngine, + SpaceType.INNER_PRODUCT, + new MethodComponentContext(METHOD_IVF, ImmutableMap.of(METHOD_PARAMETER_NLIST, nlists)) + ); // Set up training data int tdataPoints = 100; @@ -138,8 +141,9 @@ public void testRun_success() throws IOException, ExecutionException { when(nativeMemoryAllocation.getMemoryAddress()).thenReturn(memoryAddress); String tdataKey = "t-data-key"; - NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext = - mock(NativeMemoryEntryContext.TrainingDataEntryContext.class); + NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext = mock( + NativeMemoryEntryContext.TrainingDataEntryContext.class + ); when(trainingDataEntryContext.getKey()).thenReturn(tdataKey); when(nativeMemoryCacheManager.get(trainingDataEntryContext, false)).thenReturn(nativeMemoryAllocation); @@ -149,13 +153,13 @@ public void testRun_success() throws IOException, ExecutionException { }).when(nativeMemoryCacheManager).invalidate(tdataKey); TrainingJob trainingJob = new TrainingJob( - modelId, - knnMethodContext, - nativeMemoryCacheManager, - trainingDataEntryContext, - modelContext, - dimension, - "" + modelId, + knnMethodContext, + nativeMemoryCacheManager, + trainingDataEntryContext, + modelContext, + dimension, + "" ); trainingJob.run(); @@ -166,12 +170,19 @@ public void testRun_success() throws IOException, ExecutionException { assertEquals(ModelState.CREATED, model.getModelMetadata().getState()); // Simple test that creates the index from template and doesnt fail - int[] ids = { 1, 2, 3, 4}; + int[] ids = { 1, 2, 3, 4 }; float[][] vectors = new float[ids.length][dimension]; fillFloatArrayRandomly(vectors); Path indexPath = createTempFile(); - JNIService.createIndexFromTemplate(ids, vectors, indexPath.toString(), model.getModelBlob(), ImmutableMap.of(INDEX_THREAD_QTY, 1), knnEngine.getName()); + JNIService.createIndexFromTemplate( + ids, + vectors, + indexPath.toString(), + model.getModelBlob(), + ImmutableMap.of(INDEX_THREAD_QTY, 1), + knnEngine.getName() + ); assertNotEquals(0, new File(indexPath.toString()).length()); } @@ -184,8 +195,11 @@ public void testRun_failure_onGetTrainingDataAllocation() throws ExecutionExcept int nlists = 5; int dimension = 16; KNNEngine knnEngine = KNNEngine.FAISS; - KNNMethodContext knnMethodContext = new KNNMethodContext(knnEngine, SpaceType.INNER_PRODUCT, - new MethodComponentContext(METHOD_IVF, ImmutableMap.of(METHOD_PARAMETER_NLIST, nlists))); + KNNMethodContext knnMethodContext = new KNNMethodContext( + knnEngine, + SpaceType.INNER_PRODUCT, + new MethodComponentContext(METHOD_IVF, ImmutableMap.of(METHOD_PARAMETER_NLIST, nlists)) + ); // Setup model manager NativeMemoryCacheManager nativeMemoryCacheManager = mock(NativeMemoryCacheManager.class); @@ -205,23 +219,23 @@ public void testRun_failure_onGetTrainingDataAllocation() throws ExecutionExcept // Setup mock allocation for training data String tdataKey = "t-data-key"; - NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext = - mock(NativeMemoryEntryContext.TrainingDataEntryContext.class); + NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext = mock( + NativeMemoryEntryContext.TrainingDataEntryContext.class + ); when(trainingDataEntryContext.getKey()).thenReturn(tdataKey); // Throw error on getting data String testException = "test exception"; - when(nativeMemoryCacheManager.get(trainingDataEntryContext, false)) - .thenThrow(new RuntimeException(testException)); + when(nativeMemoryCacheManager.get(trainingDataEntryContext, false)).thenThrow(new RuntimeException(testException)); TrainingJob trainingJob = new TrainingJob( - modelId, - knnMethodContext, - nativeMemoryCacheManager, - trainingDataEntryContext, - modelContext, - dimension, - "" + modelId, + knnMethodContext, + nativeMemoryCacheManager, + trainingDataEntryContext, + modelContext, + dimension, + "" ); trainingJob.run(); @@ -241,8 +255,11 @@ public void testRun_failure_onGetModelAnonymousAllocation() throws ExecutionExce int nlists = 5; int dimension = 16; KNNEngine knnEngine = KNNEngine.FAISS; - KNNMethodContext knnMethodContext = new KNNMethodContext(knnEngine, SpaceType.INNER_PRODUCT, - new MethodComponentContext(METHOD_IVF, ImmutableMap.of(METHOD_PARAMETER_NLIST, nlists))); + KNNMethodContext knnMethodContext = new KNNMethodContext( + knnEngine, + SpaceType.INNER_PRODUCT, + new MethodComponentContext(METHOD_IVF, ImmutableMap.of(METHOD_PARAMETER_NLIST, nlists)) + ); // Setup model manager NativeMemoryCacheManager nativeMemoryCacheManager = mock(NativeMemoryCacheManager.class); @@ -255,8 +272,9 @@ public void testRun_failure_onGetModelAnonymousAllocation() throws ExecutionExce when(nativeMemoryAllocation.getMemoryAddress()).thenReturn((long) 0); String tdataKey = "t-data-key"; - NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext = - mock(NativeMemoryEntryContext.TrainingDataEntryContext.class); + NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext = mock( + NativeMemoryEntryContext.TrainingDataEntryContext.class + ); when(trainingDataEntryContext.getKey()).thenReturn(tdataKey); when(nativeMemoryCacheManager.get(trainingDataEntryContext, false)).thenReturn(nativeMemoryAllocation); @@ -274,17 +292,16 @@ public void testRun_failure_onGetModelAnonymousAllocation() throws ExecutionExce // Throw error on getting model alloc String testException = "test exception"; - when(nativeMemoryCacheManager.get(modelContext, false)) - .thenThrow(new RuntimeException(testException)); + when(nativeMemoryCacheManager.get(modelContext, false)).thenThrow(new RuntimeException(testException)); TrainingJob trainingJob = new TrainingJob( - modelId, - knnMethodContext, - nativeMemoryCacheManager, - trainingDataEntryContext, - modelContext, - dimension, - "" + modelId, + knnMethodContext, + nativeMemoryCacheManager, + trainingDataEntryContext, + modelContext, + dimension, + "" ); trainingJob.run(); @@ -304,12 +321,16 @@ public void testRun_failure_closedTrainingDataAllocation() throws ExecutionExcep int nlists = 5; int dimension = 16; KNNEngine knnEngine = KNNEngine.FAISS; - KNNMethodContext knnMethodContext = new KNNMethodContext(knnEngine, SpaceType.INNER_PRODUCT, - new MethodComponentContext(METHOD_IVF, ImmutableMap.of(METHOD_PARAMETER_NLIST, nlists))); + KNNMethodContext knnMethodContext = new KNNMethodContext( + knnEngine, + SpaceType.INNER_PRODUCT, + new MethodComponentContext(METHOD_IVF, ImmutableMap.of(METHOD_PARAMETER_NLIST, nlists)) + ); String tdataKey = "t-data-key"; - NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext = - mock(NativeMemoryEntryContext.TrainingDataEntryContext.class); + NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext = mock( + NativeMemoryEntryContext.TrainingDataEntryContext.class + ); when(trainingDataEntryContext.getKey()).thenReturn(tdataKey); // Setup model manager @@ -339,13 +360,13 @@ public void testRun_failure_closedTrainingDataAllocation() throws ExecutionExcep when(nativeMemoryCacheManager.get(trainingDataEntryContext, false)).thenReturn(nativeMemoryAllocation); TrainingJob trainingJob = new TrainingJob( - modelId, - knnMethodContext, - nativeMemoryCacheManager, - trainingDataEntryContext, - mock(NativeMemoryEntryContext.AnonymousEntryContext.class), - dimension, - "" + modelId, + knnMethodContext, + nativeMemoryCacheManager, + trainingDataEntryContext, + mock(NativeMemoryEntryContext.AnonymousEntryContext.class), + dimension, + "" ); trainingJob.run(); @@ -363,8 +384,11 @@ public void testRun_failure_notEnoughTrainingData() throws ExecutionException { int nlists = 1024; // setting this to 1024 will cause training to fail when there is only 2 data points int dimension = 16; KNNEngine knnEngine = KNNEngine.FAISS; - KNNMethodContext knnMethodContext = new KNNMethodContext(knnEngine, SpaceType.INNER_PRODUCT, - new MethodComponentContext(METHOD_IVF, ImmutableMap.of(METHOD_PARAMETER_NLIST, nlists))); + KNNMethodContext knnMethodContext = new KNNMethodContext( + knnEngine, + SpaceType.INNER_PRODUCT, + new MethodComponentContext(METHOD_IVF, ImmutableMap.of(METHOD_PARAMETER_NLIST, nlists)) + ); // Set up training data int tdataPoints = 2; @@ -396,8 +420,9 @@ public void testRun_failure_notEnoughTrainingData() throws ExecutionException { when(nativeMemoryAllocation.getMemoryAddress()).thenReturn(memoryAddress); String tdataKey = "t-data-key"; - NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext = - mock(NativeMemoryEntryContext.TrainingDataEntryContext.class); + NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext = mock( + NativeMemoryEntryContext.TrainingDataEntryContext.class + ); when(trainingDataEntryContext.getKey()).thenReturn(tdataKey); when(nativeMemoryCacheManager.get(trainingDataEntryContext, false)).thenReturn(nativeMemoryAllocation); @@ -407,13 +432,13 @@ public void testRun_failure_notEnoughTrainingData() throws ExecutionException { }).when(nativeMemoryCacheManager).invalidate(tdataKey); TrainingJob trainingJob = new TrainingJob( - modelId, - knnMethodContext, - nativeMemoryCacheManager, - trainingDataEntryContext, - modelContext, - dimension, - "" + modelId, + knnMethodContext, + nativeMemoryCacheManager, + trainingDataEntryContext, + modelContext, + dimension, + "" ); trainingJob.run(); @@ -424,7 +449,7 @@ public void testRun_failure_notEnoughTrainingData() throws ExecutionException { assertFalse(model.getModelMetadata().getError().isEmpty()); } - private void fillFloatArrayRandomly(float [][] vectors) { + private void fillFloatArrayRandomly(float[][] vectors) { for (int i = 0; i < vectors.length; i++) { for (int j = 0; j < vectors[i].length; j++) { vectors[i][j] = randomFloat();