Skip to content

Commit

Permalink
Support source filtering for model search (opensearch-project#162)
Browse files Browse the repository at this point in the history
Signed-off-by: Vijayan Balasubramanian <[email protected]>
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
VijayanB authored and martin-gaievski committed Mar 7, 2022
1 parent 8ac277f commit be820df
Show file tree
Hide file tree
Showing 6 changed files with 245 additions and 86 deletions.
68 changes: 30 additions & 38 deletions src/main/java/org/opensearch/knn/indices/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,16 @@
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.io.stream.Writeable;
import org.opensearch.common.xcontent.ToXContent;
import org.opensearch.common.xcontent.ToXContentObject;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.util.KNNEngine;

import java.io.IOException;
import java.util.Base64;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;

import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.knn.common.KNNConstants.DIMENSION;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
Expand Down Expand Up @@ -164,28 +159,17 @@ public int hashCode() {
return new HashCodeBuilder().append(getModelMetadata()).append(getModelBlob()).toHashCode();
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
XContentBuilder xContentBuilder = builder.startObject();
if(Strings.hasText(modelID)){
builder.field(MODEL_ID, modelID);
}
builder.field(MODEL_STATE, getModelMetadata().getState().getName());
builder.field(MODEL_TIMESTAMP, getModelMetadata().getTimestamp());
builder.field(MODEL_DESCRIPTION, getModelMetadata().getDescription());
builder.field(MODEL_ERROR, getModelMetadata().getError());

String base64Model = "";
if(getModelBlob() != null){
base64Model = Base64.getEncoder().encodeToString(getModelBlob());
}
builder.field(MODEL_BLOB_PARAMETER, base64Model);

builder.field(METHOD_PARAMETER_SPACE_TYPE, getModelMetadata().getSpaceType().getValue());
builder.field(DIMENSION, getModelMetadata().getDimension());
builder.field(KNN_ENGINE, getModelMetadata().getKnnEngine().getName());

return xContentBuilder.endObject();
/**
* Parse source map content into {@link Model} instance.
*
* @param sourceMap source contents
* @param modelID model's identifier
* @return model instance
*/
public static Model getModelFromSourceMap(Map<String, Object> sourceMap, @NonNull String modelID) {
ModelMetadata modelMetadata = ModelMetadata.getMetadataFromSourceMap(sourceMap);
byte[] blob = getModelBlobFromResponse(sourceMap);
return new Model(modelMetadata, blob, modelID);
}

private void writeOptionalModelBlob(StreamOutput output) throws IOException {
Expand Down Expand Up @@ -219,16 +203,24 @@ private static byte[] getModelBlobFromResponse(Map<String, Object> responseMap){
return Base64.getDecoder().decode((String) blob);
}

/**
* Parse source map content into {@link Model} instance.
*
* @param sourceMap source contents
* @param modelID model's identifier
* @return model instance
*/
public static Model getModelFromSourceMap(Map<String,Object> sourceMap, @NonNull String modelID) {
ModelMetadata modelMetadata = ModelMetadata.getMetadataFromSourceMap(sourceMap);
byte[] blob = getModelBlobFromResponse(sourceMap);
return new Model(modelMetadata, blob, modelID);
private static void createFieldIfNotNull(XContentBuilder builder, String fieldName, Object value) throws IOException {
if (value == null)
return;
builder.field(fieldName, value);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
XContentBuilder xContentBuilder = builder.startObject();
if (Strings.hasText(modelID)) {
builder.field(MODEL_ID, modelID);
}
String base64Model = "";
if (getModelBlob() != null) {
base64Model = Base64.getEncoder().encodeToString(getModelBlob());
}
builder.field(MODEL_BLOB_PARAMETER, base64Model);
getModelMetadata().toXContent(builder, params);
return xContentBuilder.endObject();
}
}
21 changes: 2 additions & 19 deletions src/main/java/org/opensearch/knn/indices/ModelDao.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,7 @@
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.Nullable;
import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.ToXContentObject;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.plugin.transport.DeleteModelResponse;
Expand All @@ -51,10 +48,6 @@
import org.opensearch.knn.plugin.transport.RemoveModelFromCacheResponse;
import org.opensearch.knn.plugin.transport.UpdateModelMetadataAction;
import org.opensearch.knn.plugin.transport.UpdateModelMetadataRequest;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestResponse;
import org.opensearch.rest.RestStatus;
import org.opensearch.search.SearchHit;

import java.io.IOException;
import java.net.URL;
Expand All @@ -63,13 +56,11 @@
import java.util.Map;
import java.util.concurrent.ExecutionException;

import static org.opensearch.common.xcontent.ToXContent.EMPTY_PARAMS;
import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder;
import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_MAPPING_PATH;
import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME;
import static org.opensearch.knn.common.KNNConstants.MODEL_METADATA_FIELD;
import static org.opensearch.knn.index.KNNSettings.MODEL_INDEX_NUMBER_OF_REPLICAS_SETTING;
import static org.opensearch.knn.index.KNNSettings.MODEL_INDEX_NUMBER_OF_SHARDS_SETTING;
import static org.opensearch.knn.common.KNNConstants.MODEL_METADATA_FIELD;

/**
* ModelDao is used to interface with the model persistence layer
Expand Down Expand Up @@ -393,15 +384,7 @@ public void get(String modelId, ActionListener<GetModelResponse> actionListener)
@Override
public void search(SearchRequest request, ActionListener<SearchResponse> actionListener) {
request.indices(MODEL_INDEX_NAME);
client.search(request,ActionListener.wrap(response -> {
for (SearchHit hit : response.getHits()) {
ToXContentObject xContentObject = Model.getModelFromSourceMap(hit.getSourceAsMap(), hit.getId());
XContentBuilder builder = xContentObject.toXContent(jsonBuilder(), EMPTY_PARAMS);
hit.sourceRef(BytesReference.bytes(builder));
}
actionListener.onResponse(response);

}, actionListener::onFailure));
client.search(request, actionListener);
}

@Override
Expand Down
34 changes: 32 additions & 2 deletions src/main/java/org/opensearch/knn/indices/ModelMetadata.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,29 @@
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.io.stream.Writeable;
import org.opensearch.common.xcontent.ToXContentObject;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.util.KNNEngine;

import java.io.IOException;
import java.util.Base64;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;

import static org.opensearch.knn.common.KNNConstants.DIMENSION;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.MODEL_BLOB_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.MODEL_DESCRIPTION;
import static org.opensearch.knn.common.KNNConstants.MODEL_ERROR;
import static org.opensearch.knn.common.KNNConstants.MODEL_STATE;
import static org.opensearch.knn.common.KNNConstants.MODEL_TIMESTAMP;
import static org.opensearch.knn.index.KNNVectorFieldMapper.MAX_DIMENSION;

public class ModelMetadata implements Writeable {
public class ModelMetadata implements Writeable, ToXContentObject {

private static final String DELIMITER = ",";

Expand Down Expand Up @@ -229,6 +240,12 @@ private static String objectToString(Object value) {
return (String)value;
}

private static Integer objectToInteger(Object value) {
if(value == null)
return null;
return (Integer)value;
}

/**
* Returns ModelMetadata from Map representation
*
Expand All @@ -245,7 +262,7 @@ public static ModelMetadata getMetadataFromSourceMap(final Map<String, Object> m
Object error = modelSourceMap.get(KNNConstants.MODEL_ERROR);

ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.getEngine(objectToString(engine)),
SpaceType.getSpace(objectToString( space)), (Integer) dimension, ModelState.getModelState(objectToString(state)),
SpaceType.getSpace(objectToString( space)), objectToInteger(dimension), ModelState.getModelState(objectToString(state)),
objectToString(timestamp), objectToString(description), objectToString( error));
return modelMetadata;
}
Expand All @@ -260,4 +277,17 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(getDescription());
out.writeString(getError());
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(MODEL_STATE, getState().getName());
builder.field(MODEL_TIMESTAMP, getTimestamp());
builder.field(MODEL_DESCRIPTION, getDescription());
builder.field(MODEL_ERROR, getError());

builder.field(METHOD_PARAMETER_SPACE_TYPE, getSpaceType().getValue());
builder.field(DIMENSION, getDimension());
builder.field(KNN_ENGINE, getKnnEngine().getName());
return builder;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,56 +11,34 @@

package org.opensearch.knn.plugin.action;

import joptsimple.internal.Strings;
import org.apache.http.util.EntityUtils;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.ActionFuture;
import org.opensearch.action.ActionListener;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.client.Request;
import org.opensearch.client.Response;
import org.opensearch.client.ResponseException;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.knn.KNNSingleNodeTestCase;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.Model;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelState;
import org.opensearch.knn.plugin.KNNPlugin;
import org.opensearch.knn.plugin.transport.GetModelAction;
import org.opensearch.knn.plugin.transport.GetModelRequest;
import org.opensearch.knn.plugin.transport.GetModelResponse;
import org.opensearch.rest.RestStatus;

import java.io.IOException;
import java.time.ZoneOffset;
import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

import static org.opensearch.knn.common.KNNConstants.DIMENSION;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.MODELS;
import static org.opensearch.knn.common.KNNConstants.MODEL_BLOB_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.MODEL_DESCRIPTION;
import static org.opensearch.knn.common.KNNConstants.MODEL_ERROR;
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME;
import static org.opensearch.knn.common.KNNConstants.MODEL_STATE;
import static org.opensearch.knn.common.KNNConstants.MODEL_TIMESTAMP;
import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE;

/**
* Integration tests to check the correctness of {@link org.opensearch.knn.plugin.rest.RestGetModelHandler}
Expand Down Expand Up @@ -105,6 +83,43 @@ public void testGetModelExists() throws IOException {
assertEquals(testModelMetadata.getTimestamp(), responseMap.get(MODEL_TIMESTAMP));
}


public void testGetModelExistsWithFilter() throws IOException {
createModelSystemIndex();
String testModelID = "test-model-id";
byte[] testModelBlob = "hello".getBytes();
ModelMetadata testModelMetadata = getModelMetadata();

addModelToSystemIndex(testModelID, testModelMetadata, testModelBlob);

String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, testModelID);
Request request = new Request("GET", restURI);

List<String> filterdPath = Arrays.asList(MODEL_ID, MODEL_DESCRIPTION, MODEL_TIMESTAMP, KNN_ENGINE);
request.addParameter("filter_path", Strings.join(filterdPath, ","));

Response response = client().performRequest(request);
assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));

String responseBody = EntityUtils.toString(response.getEntity());
assertNotNull(responseBody);

Map<String, Object> responseMap = createParser(
XContentType.JSON.xContent(),
responseBody
).map();

assertTrue(responseMap.size() == filterdPath.size());
assertEquals(testModelID, responseMap.get(MODEL_ID));
assertEquals(testModelMetadata.getDescription(), responseMap.get(MODEL_DESCRIPTION));
assertEquals(testModelMetadata.getTimestamp(), responseMap.get(MODEL_TIMESTAMP));
assertEquals(testModelMetadata.getKnnEngine().getName(), responseMap.get(KNN_ENGINE));
assertFalse(responseMap.containsKey(DIMENSION));
assertFalse(responseMap.containsKey(MODEL_ERROR));
assertFalse(responseMap.containsKey(METHOD_PARAMETER_SPACE_TYPE));
assertFalse(responseMap.containsKey(MODEL_STATE));
}

public void testGetModelFailsInvalid() throws IOException {
createModelSystemIndex();
String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "invalid-model-id");
Expand Down
Loading

0 comments on commit be820df

Please sign in to comment.