Add validation for pq m parameter before training starts #1713

Merged
merged 15 commits into from
May 30, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Add KnnCircuitBreakerException and modify exception message [#1688](https://github.com/opensearch-project/k-NN/pull/1688)
* Add stats for radial search [#1684](https://github.com/opensearch-project/k-NN/pull/1684)
* Support script score when doc value is disabled and fix misusing DISI [#1696](https://github.com/opensearch-project/k-NN/pull/1696)
* Add validation for pq m parameter before training starts [#1713](https://github.com/opensearch-project/k-NN/pull/1713)
### Bug Fixes
* Block commas in model description [#1692](https://github.com/opensearch-project/k-NN/pull/1692)
### Infrastructure
30 changes: 30 additions & 0 deletions src/main/java/org/opensearch/knn/index/KNNMethod.java
@@ -15,6 +15,7 @@
import lombok.Getter;
import org.opensearch.common.ValidationException;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.training.TrainingDataSpec;

import java.util.ArrayList;
import java.util.Arrays;
@@ -77,6 +78,35 @@ public ValidationException validate(KNNMethodContext knnMethodContext) {
return validationException;
}

public ValidationException validateWithData(KNNMethodContext knnMethodContext, TrainingDataSpec trainingDataSpec) {
List<String> errorMessages = new ArrayList<>();
if (!containsSpace(knnMethodContext.getSpaceType())) {
errorMessages.add(
String.format(
"\"%s\" configuration does not support space type: " + "\"%s\".",
this.methodComponent.getName(),
knnMethodContext.getSpaceType().getValue()
)
);
}

ValidationException methodValidation = methodComponent.validateWithData(
Review comment (Member): Do we have to validate further even if the space type is not supported?

Reply (Member Author): Even though validation will ultimately fail if the space type is not supported, continuing lets us add error messages for any potential problems with the method component context.

knnMethodContext.getMethodComponentContext(),
trainingDataSpec
);
if (methodValidation != null) {
errorMessages.addAll(methodValidation.validationErrors());
}

if (errorMessages.isEmpty()) {
return null;
}

ValidationException validationException = new ValidationException();
validationException.addValidationErrors(errorMessages);
return validationException;
}
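
The review exchange in this hunk is about collecting every validation error instead of returning at the first failure. A minimal sketch of that accumulate-then-report pattern follows. The class `AccumulatingValidation`, the method `collectErrors`, and the use of a plain `List<String>` in place of `ValidationException` are all assumptions made for illustration and are not part of the plugin.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

// Hypothetical stand-in for a method-level check; not the plugin's API.
final class AccumulatingValidation {

    // Returns every problem found rather than stopping at the first one.
    static List<String> collectErrors(Set<String> supportedSpaces, String spaceType, int dimension, int m) {
        List<String> errors = new ArrayList<>();
        if (!supportedSpaces.contains(spaceType)) {
            errors.add("space type not supported: " + spaceType);
        }
        // Keep checking even if the space type failed, so the caller sees all issues at once.
        if (m <= 0 || dimension % m != 0) {
            errors.add("dimension " + dimension + " is not divisible by m " + m);
        }
        return errors;
    }

    public static void main(String[] args) {
        // Both checks fail here, so two messages are printed.
        System.out.println(collectErrors(Set.of("l2", "innerproduct"), "hamming", 10, 4));
    }
}
```

The trade-off is a little extra work on requests that are already invalid, in exchange for a single response that lists everything the user needs to fix.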

/**
* returns whether training is required or not
*
@@ -30,6 +30,7 @@
import java.util.stream.Collectors;
import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;
import org.opensearch.knn.training.TrainingDataSpec;

import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
@@ -86,6 +87,10 @@ public ValidationException validate() {
return knnEngine.validateMethod(this);
}

public ValidationException validateWithData(TrainingDataSpec trainingDataSpec) {
return knnEngine.validateMethodWithData(this, trainingDataSpec);
}

/**
* This method returns whether training is required or not from knnEngine
*
31 changes: 31 additions & 0 deletions src/main/java/org/opensearch/knn/index/MethodComponent.java
@@ -17,6 +17,7 @@
import org.opensearch.common.ValidationException;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.util.IndexHyperParametersUtil;
import org.opensearch.knn.training.TrainingDataSpec;

import java.util.ArrayList;
import java.util.HashMap;
@@ -102,6 +103,36 @@ public ValidationException validate(MethodComponentContext methodComponentContex
return validationException;
}

public ValidationException validateWithData(MethodComponentContext methodComponentContext, TrainingDataSpec trainingDataSpec) {
Map<String, Object> providedParameters = methodComponentContext.getParameters();
List<String> errorMessages = new ArrayList<>();

if (providedParameters == null) {
return null;
}

ValidationException parameterValidation;
for (Map.Entry<String, Object> parameter : providedParameters.entrySet()) {
if (!parameters.containsKey(parameter.getKey())) {
errorMessages.add(String.format("Invalid parameter for method \"%s\".", getName()));
continue;
}

parameterValidation = parameters.get(parameter.getKey()).validateWithData(parameter.getValue(), trainingDataSpec);
if (parameterValidation != null) {
errorMessages.addAll(parameterValidation.validationErrors());
}
}

if (errorMessages.isEmpty()) {
return null;
}

ValidationException validationException = new ValidationException();
validationException.addValidationErrors(errorMessages);
return validationException;
}
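
The loop in this hunk rejects unknown parameter names and delegates known ones to a per-parameter check. A simplified sketch of that dispatch follows. `ComponentValidationSketch`, `checkM`, the `KNOWN` map, and the use of an `int` dimension in place of `TrainingDataSpec` are hypothetical and chosen only to keep the example self-contained.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;

// Hypothetical sketch of per-parameter dispatch; not the plugin's MethodComponent.
final class ComponentValidationSketch {

    // Each known parameter maps to a check that returns an error message, or null on success.
    private static final Map<String, BiFunction<Object, Integer, String>> KNOWN =
        Map.of("m", ComponentValidationSketch::checkM);

    static String checkM(Object value, Integer dimension) {
        if (!(value instanceof Integer)) {
            return "value not of type Integer for parameter [m].";
        }
        int m = (Integer) value;
        return (m > 0 && dimension % m == 0) ? null : "parameter validation failed for parameter [m].";
    }

    static List<String> validate(Map<String, ?> provided, int dimension) {
        List<String> errors = new ArrayList<>();
        if (provided == null) {
            return errors; // nothing supplied, nothing to check
        }
        for (Map.Entry<String, ?> entry : provided.entrySet()) {
            BiFunction<Object, Integer, String> check = KNOWN.get(entry.getKey());
            if (check == null) {
                errors.add("Invalid parameter \"" + entry.getKey() + "\".");
                continue; // unknown name: record it and move on to the next entry
            }
            String error = check.apply(entry.getValue(), dimension);
            if (error != null) {
                errors.add(error);
            }
        }
        return errors;
    }

    public static void main(String[] args) {
        System.out.println(validate(Map.of("m", 3, "bogus", 1), 128));
    }
}
```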

/**
* gets requiresTraining value
*
134 changes: 134 additions & 0 deletions src/main/java/org/opensearch/knn/index/Parameter.java
@@ -12,8 +12,10 @@
package org.opensearch.knn.index;

import org.opensearch.common.ValidationException;
import org.opensearch.knn.training.TrainingDataSpec;

import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Predicate;

/**
@@ -26,6 +28,7 @@ public abstract class Parameter<T> {
private String name;
private T defaultValue;
protected Predicate<T> validator;
protected BiFunction<T, TrainingDataSpec, Boolean> validatorWithData;

/**
* Constructor
@@ -38,6 +41,14 @@ public Parameter(String name, T defaultValue, Predicate<T> validator) {
this.name = name;
this.defaultValue = defaultValue;
this.validator = validator;
this.validatorWithData = null;
}

public Parameter(String name, T defaultValue, Predicate<T> validator, BiFunction<T, TrainingDataSpec, Boolean> validatorWithData) {
this.name = name;
this.defaultValue = defaultValue;
this.validator = validator;
this.validatorWithData = validatorWithData;
}

/**
@@ -66,6 +77,8 @@ public T getDefaultValue() {
*/
public abstract ValidationException validate(Object value);

public abstract ValidationException validateWithData(Object value, TrainingDataSpec trainingDataSpec);

/**
* Boolean method parameter
*/
@@ -74,6 +87,15 @@ public BooleanParameter(String name, Boolean defaultValue, Predicate<Boolean> va
super(name, defaultValue, validator);
}

public BooleanParameter(
String name,
Boolean defaultValue,
Predicate<Boolean> validator,
BiFunction<Boolean, TrainingDataSpec, Boolean> validatorWithData
) {
super(name, defaultValue, validator, validatorWithData);
}

@Override
public ValidationException validate(Object value) {
ValidationException validationException = null;
@@ -89,6 +111,27 @@ public ValidationException validate(Object value) {
}
return validationException;
}

@Override
public ValidationException validateWithData(Object value, TrainingDataSpec trainingDataSpec) {
ValidationException validationException = null;
if (!(value instanceof Boolean)) {
validationException = new ValidationException();
validationException.addValidationError(String.format("value not of type Boolean for Boolean parameter [%s].", getName()));
return validationException;
}

if (validatorWithData == null) {
return validationException;
}

if (!validatorWithData.apply((Boolean) value, trainingDataSpec)) {
validationException = new ValidationException();
validationException.addValidationError(String.format("parameter validation failed for Boolean parameter [%s].", getName()));
}

return validationException;
}
}

/**
@@ -99,6 +142,15 @@ public IntegerParameter(String name, Integer defaultValue, Predicate<Integer> va
super(name, defaultValue, validator);
}

public IntegerParameter(
String name,
Integer defaultValue,
Predicate<Integer> validator,
BiFunction<Integer, TrainingDataSpec, Boolean> validatorWithData
) {
super(name, defaultValue, validator, validatorWithData);
}

@Override
public ValidationException validate(Object value) {
ValidationException validationException = null;
@@ -118,6 +170,27 @@ public ValidationException validate(Object value) {
}
return validationException;
}

@Override
public ValidationException validateWithData(Object value, TrainingDataSpec trainingDataSpec) {
ValidationException validationException = null;
if (!(value instanceof Integer)) {
validationException = new ValidationException();
validationException.addValidationError(String.format("value not of type Integer for Integer parameter [%s].", getName()));
return validationException;
}

if (validatorWithData == null) {
return validationException;
}

if (!validatorWithData.apply((Integer) value, trainingDataSpec)) {
validationException = new ValidationException();
validationException.addValidationError(String.format("parameter validation failed for Integer parameter [%s].", getName()));
}

return validationException;
}
}
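
Each `validateWithData` override in this file follows the same three steps: a type check, an early return when no data-aware validator was supplied, then the data-aware check itself. A reduced sketch of that shape for an integer parameter follows. `IntParamSketch` is a hypothetical class, it returns a `String` message instead of a `ValidationException`, and it takes a bare `int` dimension where the real code takes a `TrainingDataSpec`.

```java
import java.util.function.BiFunction;

// Simplified, hypothetical analogue of an integer parameter with an optional data-aware check.
final class IntParamSketch {
    private final String name;
    // May be null when the parameter has no data-dependent constraint.
    private final BiFunction<Integer, Integer, Boolean> validatorWithData;

    IntParamSketch(String name, BiFunction<Integer, Integer, Boolean> validatorWithData) {
        this.name = name;
        this.validatorWithData = validatorWithData;
    }

    // Returns an error message, or null when the value passes.
    String validateWithData(Object value, int dimension) {
        if (!(value instanceof Integer)) {
            return String.format("value not of type Integer for Integer parameter [%s].", name);
        }
        if (validatorWithData == null) {
            return null; // no data-aware rule registered for this parameter
        }
        if (!validatorWithData.apply((Integer) value, dimension)) {
            return String.format("parameter validation failed for Integer parameter [%s].", name);
        }
        return null;
    }

    public static void main(String[] args) {
        IntParamSketch m = new IntParamSketch("m", (v, dimension) -> dimension % v == 0);
        System.out.println(m.validateWithData(8, 128));
        System.out.println(m.validateWithData(3, 128));
    }
}
```

Keeping the data-aware validator nullable means existing parameters built with the three-argument constructor need no changes, which matches how the diff leaves the original constructors in place.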

/**
@@ -136,6 +209,15 @@ public StringParameter(String name, String defaultValue, Predicate<String> valid
super(name, defaultValue, validator);
}

public StringParameter(
String name,
String defaultValue,
Predicate<String> validator,
BiFunction<String, TrainingDataSpec, Boolean> validatorWithData
) {
super(name, defaultValue, validator, validatorWithData);
}

/**
* Check if the value passed in is valid
*
@@ -161,6 +243,27 @@ public ValidationException validate(Object value) {
}
return validationException;
}

@Override
public ValidationException validateWithData(Object value, TrainingDataSpec trainingDataSpec) {
ValidationException validationException = null;
if (!(value instanceof String)) {
validationException = new ValidationException();
validationException.addValidationError(String.format("value not of type String for String parameter [%s].", getName()));
return validationException;
}

if (validatorWithData == null) {
return validationException;
}

if (!validatorWithData.apply((String) value, trainingDataSpec)) {
validationException = new ValidationException();
validationException.addValidationError(String.format("parameter validation failed for String parameter [%s].", getName()));
}

return validationException;
}
}

/**
@@ -190,6 +293,12 @@ public MethodComponentContextParameter(
}

return methodComponents.get(methodComponentContext.getName()).validate(methodComponentContext) == null;
}, (methodComponentContext, trainingDataSpec) -> {
if (!methodComponents.containsKey(methodComponentContext.getName())) {
return false;
}
return methodComponents.get(methodComponentContext.getName())
.validateWithData(methodComponentContext, trainingDataSpec) == null;
});
this.methodComponents = methodComponents;
}
@@ -216,6 +325,31 @@ public ValidationException validate(Object value) {
return validationException;
}

@Override
public ValidationException validateWithData(Object value, TrainingDataSpec trainingDataSpec) {
ValidationException validationException = null;
if (!(value instanceof MethodComponentContext)) {
validationException = new ValidationException();
validationException.addValidationError(
String.format("value not of type MethodComponentContext for MethodComponentContext parameter [%s].", getName())
);
return validationException;
}

if (validatorWithData == null) {
return validationException;
}

if (!validatorWithData.apply((MethodComponentContext) value, trainingDataSpec)) {
validationException = new ValidationException();
validationException.addValidationError(
String.format("parameter validation failed for MethodComponentContext parameter [%s].", getName())
);
}

return validationException;
}

/**
* Get method component by name
*
@@ -11,6 +11,7 @@
import org.opensearch.common.ValidationException;
import org.opensearch.knn.index.KNNMethod;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.training.TrainingDataSpec;

import java.util.Map;

@@ -39,6 +40,12 @@ public ValidationException validateMethod(KNNMethodContext knnMethodContext) {
return getMethod(methodName).validate(knnMethodContext);
}

@Override
public ValidationException validateMethodWithData(KNNMethodContext knnMethodContext, TrainingDataSpec trainingDataSpec) {
String methodName = knnMethodContext.getMethodComponentContext().getName();
return getMethod(methodName).validateWithData(knnMethodContext, trainingDataSpec);
}

@Override
public boolean isTrainingRequired(KNNMethodContext knnMethodContext) {
String methodName = knnMethodContext.getMethodComponentContext().getName();
9 changes: 4 additions & 5 deletions src/main/java/org/opensearch/knn/index/util/Faiss.java
@@ -109,9 +109,6 @@ class Faiss extends NativeLibrary {
.build()
);

// TODO: To think about in future: for PQ, if dimension is not divisible by code count, PQ will fail. Right now,
Review comment (Member): Nice, finally getting rid of it.

// we do not have a way to base validation off of dimension. Failure will happen during training in JNI.
// Define methods supported by faiss. See issue here: https://github.com/opensearch-project/k-NN/issues/1075
private final static Map<String, MethodComponent> HNSW_ENCODERS = ImmutableMap.<String, MethodComponent>builder()
.putAll(
ImmutableMap.of(
@@ -122,7 +119,8 @@ class Faiss extends NativeLibrary {
new Parameter.IntegerParameter(
ENCODER_PARAMETER_PQ_M,
ENCODER_PARAMETER_PQ_CODE_COUNT_DEFAULT,
v -> v > 0 && v < ENCODER_PARAMETER_PQ_CODE_COUNT_LIMIT
v -> v > 0 && v < ENCODER_PARAMETER_PQ_CODE_COUNT_LIMIT,
(v, trainingDataSpec) -> trainingDataSpec.getDimension() % v == 0
)
)
.addParameter(
@@ -161,7 +159,8 @@ class Faiss extends NativeLibrary {
new Parameter.IntegerParameter(
ENCODER_PARAMETER_PQ_M,
ENCODER_PARAMETER_PQ_CODE_COUNT_DEFAULT,
v -> v > 0 && v < ENCODER_PARAMETER_PQ_CODE_COUNT_LIMIT
v -> v > 0 && v < ENCODER_PARAMETER_PQ_CODE_COUNT_LIMIT,
(v, trainingDataSpec) -> trainingDataSpec.getDimension() % v == 0
)
)
.addParameter(
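
Product quantization splits each vector into `m` subvectors of equal length, so the vector dimension must be divisible by `m`. That is the condition the new lambda `trainingDataSpec.getDimension() % v == 0` expresses. A small sketch of the same check follows. `PqDimensionCheck`, `isValidM`, and `subvectorLength` are hypothetical names introduced here, not plugin code.

```java
// Hypothetical helper; the names are illustrative and not part of the k-NN plugin.
final class PqDimensionCheck {

    // True when a vector of the given dimension can be split into m equal subvectors.
    static boolean isValidM(int dimension, int m) {
        // Guard m > 0 first so the modulo below cannot divide by zero.
        return m > 0 && dimension % m == 0;
    }

    // Length of each subvector; rejects combinations that PQ training could not handle.
    static int subvectorLength(int dimension, int m) {
        if (!isValidM(dimension, m)) {
            throw new IllegalArgumentException("dimension " + dimension + " is not divisible by m " + m);
        }
        return dimension / m;
    }

    public static void main(String[] args) {
        System.out.println(isValidM(128, 8));
        System.out.println(isValidM(128, 3));
        System.out.println(subvectorLength(128, 8));
    }
}
```

The lambda in the diff does not itself guard against `v == 0`. It presumably relies on the existing range validator (`v > 0`) having run beforehand, though that ordering is an assumption and is not shown in this diff.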