Skip to content

Commit

Permalink
Fix field validation in VectorReader (opensearch-project#207)
Browse files Browse the repository at this point in the history
Signed-off-by: John Mazanec <[email protected]>
  • Loading branch information
jmazanec15 authored and martin-gaievski committed Mar 7, 2022
1 parent 653e19a commit a8a02c2
Show file tree
Hide file tree
Showing 9 changed files with 192 additions and 153 deletions.
122 changes: 122 additions & 0 deletions src/main/java/org/opensearch/knn/index/IndexUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,15 @@

package org.opensearch.knn.index;

import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.metadata.MappingMetadata;
import org.opensearch.common.ValidationException;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;

import java.io.File;
import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.BYTES_PER_KILOBYTES;

Expand All @@ -34,4 +42,118 @@ public static int getFileSizeInKB(String filePath) {

return Math.toIntExact((file.length() / BYTES_PER_KILOBYTES) + 1L); // Add one so that integer division rounds up
}

/**
 * Validate that a field is a k-NN vector field and, optionally, that it has the expected dimension.
 *
 * The field is looked up directly under the top-level "properties" entry of the index mapping. When the
 * field mapping carries no explicit dimension, the dimension is taken from the metadata of the model the
 * field references.
 *
 * @param indexMetadata metadata for index to validate
 * @param field field name to validate
 * @param expectedDimension expected dimension of the field. If this value is negative, dimension will not be
 *                          checked
 * @param modelDao used to look up dimension if field uses a model for initialization. Can be null if
 *                 expectedDimension is negative
 * @return null if the field passed validation; otherwise a ValidationException holding the validation error.
 *         The exception is returned, not thrown.
 * @throws IllegalArgumentException if indexMetadata is null, or if the field references a model and modelDao
 *                                  is null
 */
@SuppressWarnings("unchecked")
public static ValidationException validateKnnField(IndexMetadata indexMetadata, String field, int expectedDimension,
                                                   ModelDao modelDao) {
    // Index metadata should not be null
    if (indexMetadata == null) {
        throw new IllegalArgumentException("IndexMetadata should not be null");
    }

    ValidationException exception = new ValidationException();

    // Check the mapping
    MappingMetadata mappingMetadata = indexMetadata.mapping();
    if (mappingMetadata == null) {
        exception.addValidationError("Invalid index. Index does not contain a mapping");
        return exception;
    }

    // The mapping output *should* look like this:
    // "{properties={field={type=knn_vector, dimension=8}}}"
    Object propertiesMapping = mappingMetadata.getSourceAsMap().get("properties");

    // A missing or malformed "properties" entry is reported as a validation error instead of failing on the cast
    if (!(propertiesMapping instanceof Map)) {
        exception.addValidationError("Properties in map does not exists. This is unexpected");
        return exception;
    }

    Map<String, Object> properties = (Map<String, Object>) propertiesMapping;
    Object fieldMapping = properties.get(field);

    // Check field existence
    if (fieldMapping == null) {
        exception.addValidationError(String.format("Field \"%s\" does not exist.", field));
        return exception;
    }

    // Check if field is a map. If not, that is a problem
    if (!(fieldMapping instanceof Map)) {
        exception.addValidationError(String.format("Field info for \"%s\" is not a map.", field));
        return exception;
    }

    Map<String, Object> fieldMap = (Map<String, Object>) fieldMapping;

    // Check fields type is knn_vector
    Object type = fieldMap.get("type");

    if (!(type instanceof String) || !KNNVectorFieldMapper.CONTENT_TYPE.equals(type)) {
        exception.addValidationError(String.format("Field \"%s\" is not of type %s.", field,
                KNNVectorFieldMapper.CONTENT_TYPE));
        return exception;
    }

    // Return if dimension does not need to be checked
    if (expectedDimension < 0) {
        return null;
    }

    // Check that the dimension of the method passed in matches that of the model
    Object dimension = fieldMap.get(KNNConstants.DIMENSION);

    // If dimension is null, the training index/field could use a model. In this case, we need to get the model id
    // for the index and then fetch its dimension from the models metadata
    if (dimension == null) {
        Object modelId = fieldMap.get(KNNConstants.MODEL_ID);

        // Without a usable model id there is nowhere to take the dimension from
        if (!(modelId instanceof String)) {
            exception.addValidationError(String.format("Field \"%s\" does not have a dimension set.", field));
            return exception;
        }

        if (modelDao == null) {
            throw new IllegalArgumentException(String.format("Field \"%s\" uses model. modelDao cannot be null.",
                    field));
        }

        ModelMetadata modelMetadata = modelDao.getMetadata((String) modelId);
        if (modelMetadata == null) {
            exception.addValidationError(String.format("Model \"%s\" for field \"%s\" does not exist.", modelId,
                    field));
            return exception;
        }

        dimension = modelMetadata.getDimension();
    }

    // Whether the dimension came from the mapping or from the model, it must equal the expected dimension
    return checkKnnFieldDimension(field, dimension, expectedDimension, exception);
}

/**
 * Compare a resolved field dimension against the expected one.
 *
 * @param field field name, used in error messages only
 * @param dimension dimension value resolved from the field mapping or from model metadata
 * @param expectedDimension dimension the field is required to have
 * @param exception exception to add a validation error to on failure
 * @return null if the dimension matches; otherwise the passed in exception with one error added
 */
private static ValidationException checkKnnFieldDimension(String field, Object dimension, int expectedDimension,
                                                          ValidationException exception) {
    // Guard the cast below: a non-Integer value would otherwise surface as a ClassCastException
    if (!(dimension instanceof Integer)) {
        exception.addValidationError(String.format("Field \"%s\" has an invalid dimension value: %s", field,
                dimension));
        return exception;
    }

    if ((Integer) dimension != expectedDimension) {
        exception.addValidationError(String.format("Field \"%s\" has dimension %d, which is different from " +
                "dimension specified in the training request: %d", field, dimension, expectedDimension));
        return exception;
    }

    return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

package org.opensearch.knn.index.memory;

import org.opensearch.indices.IndicesService;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.knn.index.IndexUtil;

import java.io.IOException;
Expand Down Expand Up @@ -129,7 +129,7 @@ public static class TrainingDataEntryContext extends NativeMemoryEntryContext<Na

private final int size;
private final NativeMemoryLoadStrategy.TrainingLoadStrategy trainingLoadStrategy;
private final IndicesService indicesService;
private final ClusterService clusterService;
private final String trainIndexName;
private final String trainFieldName;
private final int maxVectorCount;
Expand All @@ -142,23 +142,23 @@ public static class TrainingDataEntryContext extends NativeMemoryEntryContext<Na
* @param trainIndexName name of index used to pull training data from
* @param trainFieldName name of field used to pull training data from
* @param trainingLoadStrategy strategy to load training data into memory
* @param indicesService service used to extract information about indices
* @param clusterService service used to extract information about indices
* @param maxVectorCount maximum number of vectors there can be
* @param searchSize size each search request should return during loading
*/
public TrainingDataEntryContext(int size,
String trainIndexName,
String trainFieldName,
NativeMemoryLoadStrategy.TrainingLoadStrategy trainingLoadStrategy,
IndicesService indicesService,
ClusterService clusterService,
int maxVectorCount,
int searchSize) {
super(generateKey(trainIndexName, trainFieldName));
this.size = size;
this.trainingLoadStrategy = trainingLoadStrategy;
this.trainIndexName = trainIndexName;
this.trainFieldName = trainFieldName;
this.indicesService = indicesService;
this.clusterService = clusterService;
this.maxVectorCount = maxVectorCount;
this.searchSize = searchSize;
}
Expand Down Expand Up @@ -210,12 +210,12 @@ public int getSearchSize() {
}

/**
* Getter for indices service.
* Getter for cluster service.
*
* @return indices service
* @return cluster service
*/
public IndicesService getIndicesService() {
return indicesService;
public ClusterService getClusterService() {
return clusterService;
}

private static String generateKey(String trainIndexName, String trainFieldName) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ public NativeMemoryAllocation.TrainingDataAllocation load(NativeMemoryEntryConte
trainingDataAllocation.writeLock();

vectorReader.read(
nativeMemoryEntryContext.getIndicesService(),
nativeMemoryEntryContext.getClusterService(),
nativeMemoryEntryContext.getTrainIndexName(),
nativeMemoryEntryContext.getTrainFieldName(),
nativeMemoryEntryContext.getMaxVectorCount(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,11 @@
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.IndexUtil;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.KNNVectorFieldMapper;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;

import java.io.IOException;
import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.KNN_METHOD;

/**
* Request to train and serialize a model
Expand Down Expand Up @@ -266,15 +262,6 @@ public ActionRequestValidationException validate() {
exception.addValidationError("Method does not require training.");
}

// Validate training data
IndexMetadata indexMetadata = clusterService.state().metadata().index(trainingIndex);
if (indexMetadata == null) {
exception = exception == null ? new ActionRequestValidationException() : exception;
exception.addValidationError("Index \"" + this.trainingIndex + "\" does not exist.");
} else {
exception = validateTrainingField(indexMetadata, exception);
}

// Check if preferred node is real
if (preferredNodeId != null && !clusterService.state().nodes().getDataNodes().containsKey(preferredNodeId)) {
exception = exception == null ? new ActionRequestValidationException() : exception;
Expand All @@ -288,6 +275,22 @@ public ActionRequestValidationException validate() {
" characters");
}

// Validate training index exists
IndexMetadata indexMetadata = clusterService.state().metadata().index(trainingIndex);
if (indexMetadata == null) {
exception = exception == null ? new ActionRequestValidationException() : exception;
exception.addValidationError("Index \"" + this.trainingIndex + "\" does not exist.");
return exception;
}

// Validate the training field
ValidationException fieldValidation = IndexUtil.validateKnnField(indexMetadata, this.trainingField,
this.dimension, modelDao);
if (fieldValidation != null) {
exception = exception == null ? new ActionRequestValidationException() : exception;
exception.addValidationErrors(fieldValidation.validationErrors());
}

return exception;
}

Expand All @@ -305,97 +308,4 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeInt(this.searchSize);
out.writeInt(this.trainingDataSizeInKB);
}

/**
 * Validate that the training field of this request is a k-NN vector field whose dimension matches the
 * dimension of this request. When the field mapping has no explicit dimension, the dimension is taken from
 * the metadata of the model referenced by the field.
 *
 * @param indexMetadata metadata of the training index; must not be null
 * @param exception exception to append validation errors to; may be null, in which case one is created
 *                  lazily when the first error is found
 * @return the passed in exception (possibly null) if validation succeeded; otherwise an exception that
 *         contains the validation error
 * @throws IllegalArgumentException if indexMetadata is null
 */
@SuppressWarnings("unchecked")
private ActionRequestValidationException validateTrainingField(IndexMetadata indexMetadata,
                                                               ActionRequestValidationException exception) {
    // Index metadata should not be null
    if (indexMetadata == null) {
        throw new IllegalArgumentException("IndexMetadata should not be null");
    }

    // An index without a mapping would otherwise cause a NullPointerException on the lookup below
    if (indexMetadata.mapping() == null) {
        exception = exception == null ? new ActionRequestValidationException() : exception;
        exception.addValidationError("Invalid index. Index does not contain a mapping");
        return exception;
    }

    // The mapping output *should* look like this:
    // "{properties={train_field={type=knn_vector, dimension=8}}}"
    Map<String, Object> properties = (Map<String, Object>)indexMetadata.mapping().getSourceAsMap()
            .get("properties");

    if (properties == null) {
        exception = exception == null ? new ActionRequestValidationException() : exception;
        exception.addValidationError("Properties in map does not exists. This is unexpected");
        return exception;
    }

    Object trainingFieldMapping = properties.get(trainingField);

    // Check field existence
    if (trainingFieldMapping == null) {
        exception = exception == null ? new ActionRequestValidationException() : exception;
        exception.addValidationError(String.format("Field \"%s\" does not exist.", this.trainingField));
        return exception;
    }

    // Check if field is a map. If not, that is a problem
    if (!(trainingFieldMapping instanceof Map)) {
        exception = exception == null ? new ActionRequestValidationException() : exception;
        exception.addValidationError(String.format("Field info for \"%s\" is not a map.", this.trainingField));
        return exception;
    }

    Map<String, Object> trainingFieldMap = (Map<String, Object>) trainingFieldMapping;

    // Check fields type is knn_vector
    Object type = trainingFieldMap.get("type");

    if (!(type instanceof String) || !KNNVectorFieldMapper.CONTENT_TYPE.equals(type)) {
        exception = exception == null ? new ActionRequestValidationException() : exception;
        exception.addValidationError(String.format("Field \"%s\" is not of type %s.", this.trainingField,
                KNNVectorFieldMapper.CONTENT_TYPE));
        return exception;
    }

    // Check that the dimension of the method passed in matches that of the model
    Object dimension = trainingFieldMap.get(KNNConstants.DIMENSION);

    // If dimension is null, the training index/field could use a model. In this case, we need to get the model id
    // for the index and then fetch its dimension from the models metadata
    if (dimension == null) {
        String modelId = (String) trainingFieldMap.get(KNNConstants.MODEL_ID);

        if (modelId == null) {
            exception = exception == null ? new ActionRequestValidationException() : exception;
            exception.addValidationError(String.format("Field \"%s\" does not have a dimension set.",
                    this.trainingField));
            return exception;
        }

        ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
        if (modelMetadata == null) {
            exception = exception == null ? new ActionRequestValidationException() : exception;
            exception.addValidationError(String.format("Model \"%s\" for field \"%s\" does not exist.", modelId,
                    this.trainingField));
            return exception;
        }

        dimension = modelMetadata.getDimension();
    }

    // Whether the dimension came from the mapping or from the model, it must equal the request's dimension
    if ((Integer) dimension != this.dimension) {
        exception = exception == null ? new ActionRequestValidationException() : exception;
        exception.addValidationError(String.format("Field \"%s\" has dimension %d, which is different from " +
                "dimension specified in the training request: %d", this.trainingField, dimension, this.dimension));
        return exception;
    }

    return exception;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
import org.opensearch.action.ActionListener;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.indices.IndicesService;
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
import org.opensearch.knn.index.memory.NativeMemoryEntryContext;
import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy;
Expand All @@ -32,14 +32,14 @@
*/
public class TrainingModelTransportAction extends HandledTransportAction<TrainingModelRequest, TrainingModelResponse> {

private final IndicesService indicesService;
private final ClusterService clusterService;

@Inject
public TrainingModelTransportAction(TransportService transportService,
ActionFilters actionFilters,
IndicesService indicesService) {
ClusterService clusterService) {
super(TrainingModelAction.NAME, transportService, actionFilters, TrainingModelRequest::new);
this.indicesService = indicesService;
this.clusterService = clusterService;
}

@Override
Expand All @@ -52,7 +52,7 @@ protected void doExecute(Task task, TrainingModelRequest request,
request.getTrainingIndex(),
request.getTrainingField(),
NativeMemoryLoadStrategy.TrainingLoadStrategy.getInstance(),
indicesService,
clusterService,
request.getMaximumVectorCount(),
request.getSearchSize()
);
Expand Down
Loading

0 comments on commit a8a02c2

Please sign in to comment.