diff --git a/src/main/java/org/opensearch/knn/indices/Model.java b/src/main/java/org/opensearch/knn/indices/Model.java index 2cff1911dd..a04bdf7b1a 100644 --- a/src/main/java/org/opensearch/knn/indices/Model.java +++ b/src/main/java/org/opensearch/knn/indices/Model.java @@ -19,13 +19,9 @@ import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.io.stream.Writeable; -import org.opensearch.common.xcontent.ToXContent; import org.opensearch.common.xcontent.ToXContentObject; import org.opensearch.common.xcontent.XContentBuilder; -import org.opensearch.common.xcontent.XContentParser; import org.opensearch.knn.common.KNNConstants; -import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.util.KNNEngine; import java.io.IOException; import java.util.Base64; @@ -33,7 +29,6 @@ import java.util.Objects; import java.util.concurrent.atomic.AtomicReference; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.knn.common.KNNConstants.DIMENSION; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; @@ -164,28 +159,17 @@ public int hashCode() { return new HashCodeBuilder().append(getModelMetadata()).append(getModelBlob()).toHashCode(); } - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - XContentBuilder xContentBuilder = builder.startObject(); - if(Strings.hasText(modelID)){ - builder.field(MODEL_ID, modelID); - } - builder.field(MODEL_STATE, getModelMetadata().getState().getName()); - builder.field(MODEL_TIMESTAMP, getModelMetadata().getTimestamp()); - builder.field(MODEL_DESCRIPTION, getModelMetadata().getDescription()); - builder.field(MODEL_ERROR, getModelMetadata().getError()); - - String base64Model = ""; - if(getModelBlob() != null){ - base64Model = Base64.getEncoder().encodeToString(getModelBlob()); - } - builder.field(MODEL_BLOB_PARAMETER, base64Model); - - builder.field(METHOD_PARAMETER_SPACE_TYPE, getModelMetadata().getSpaceType().getValue()); - builder.field(DIMENSION, getModelMetadata().getDimension()); - builder.field(KNN_ENGINE, getModelMetadata().getKnnEngine().getName()); - - return xContentBuilder.endObject(); + /** + * Parse source map content into {@link Model} instance. + * + * @param sourceMap source contents + * @param modelID model's identifier + * @return model instance + */ + public static Model getModelFromSourceMap(Map sourceMap, @NonNull String modelID) { + ModelMetadata modelMetadata = ModelMetadata.getMetadataFromSourceMap(sourceMap); + byte[] blob = getModelBlobFromResponse(sourceMap); + return new Model(modelMetadata, blob, modelID); } private void writeOptionalModelBlob(StreamOutput output) throws IOException { @@ -219,16 +203,24 @@ private static byte[] getModelBlobFromResponse(Map responseMap){ return Base64.getDecoder().decode((String) blob); } - /** - * Parse source map content into {@link Model} instance. - * - * @param sourceMap source contents - * @param modelID model's identifier - * @return model instance - */ - public static Model getModelFromSourceMap(Map sourceMap, @NonNull String modelID) { - ModelMetadata modelMetadata = ModelMetadata.getMetadataFromSourceMap(sourceMap); - byte[] blob = getModelBlobFromResponse(sourceMap); - return new Model(modelMetadata, blob, modelID); + private static void createFieldIfNotNull(XContentBuilder builder, String fieldName, Object value) throws IOException { + if (value == null) + return; + builder.field(fieldName, value); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject(); + if (Strings.hasText(modelID)) { + builder.field(MODEL_ID, modelID); + } + String base64Model = ""; + if (getModelBlob() != null) { + base64Model = Base64.getEncoder().encodeToString(getModelBlob()); + } + builder.field(MODEL_BLOB_PARAMETER, base64Model); + getModelMetadata().toXContent(builder, params); + return xContentBuilder.endObject(); } } diff --git a/src/main/java/org/opensearch/knn/indices/ModelDao.java b/src/main/java/org/opensearch/knn/indices/ModelDao.java index 11fc583719..d4aa5490c9 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelDao.java +++ b/src/main/java/org/opensearch/knn/indices/ModelDao.java @@ -393,15 +393,7 @@ public void get(String modelId, ActionListener actionListener) @Override public void search(SearchRequest request, ActionListener actionListener) { request.indices(MODEL_INDEX_NAME); - client.search(request,ActionListener.wrap(response -> { - for (SearchHit hit : response.getHits()) { - ToXContentObject xContentObject = Model.getModelFromSourceMap(hit.getSourceAsMap(), hit.getId()); - XContentBuilder builder = xContentObject.toXContent(jsonBuilder(), EMPTY_PARAMS); - hit.sourceRef(BytesReference.bytes(builder)); - } - actionListener.onResponse(response); - - }, actionListener::onFailure)); + client.search(request,actionListener); } @Override diff --git a/src/main/java/org/opensearch/knn/indices/ModelMetadata.java b/src/main/java/org/opensearch/knn/indices/ModelMetadata.java index 7330d0f365..b2c81aef8b 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelMetadata.java +++ b/src/main/java/org/opensearch/knn/indices/ModelMetadata.java @@ -16,18 +16,29 @@ import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.io.stream.Writeable; +import org.opensearch.common.xcontent.ToXContentObject; +import org.opensearch.common.xcontent.XContentBuilder; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.util.KNNEngine; import java.io.IOException; +import java.util.Base64; import java.util.Map; import java.util.Objects; import java.util.concurrent.atomic.AtomicReference; +import static org.opensearch.knn.common.KNNConstants.DIMENSION; +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; +import static org.opensearch.knn.common.KNNConstants.MODEL_BLOB_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.MODEL_DESCRIPTION; +import static org.opensearch.knn.common.KNNConstants.MODEL_ERROR; +import static org.opensearch.knn.common.KNNConstants.MODEL_STATE; +import static org.opensearch.knn.common.KNNConstants.MODEL_TIMESTAMP; import static org.opensearch.knn.index.KNNVectorFieldMapper.MAX_DIMENSION; -public class ModelMetadata implements Writeable { +public class ModelMetadata implements Writeable, ToXContentObject { private static final String DELIMITER = ","; @@ -229,6 +240,12 @@ private static String objectToString(Object value) { return (String)value; } + private static Integer objectToInteger(Object value) { + if(value == null) + return null; + return (Integer)value; + } + /** * Returns ModelMetadata from Map representation * @@ -245,7 +262,7 @@ public static ModelMetadata getMetadataFromSourceMap(final Map m Object error = modelSourceMap.get(KNNConstants.MODEL_ERROR); ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.getEngine(objectToString(engine)), - SpaceType.getSpace(objectToString( space)), (Integer) dimension, ModelState.getModelState(objectToString(state)), + SpaceType.getSpace(objectToString( space)), objectToInteger(dimension), ModelState.getModelState(objectToString(state)), objectToString(timestamp), objectToString(description), objectToString( error)); return modelMetadata; } @@ -260,4 +277,17 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(getDescription()); out.writeString(getError()); } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(MODEL_STATE, getState().getName()); + builder.field(MODEL_TIMESTAMP, getTimestamp()); + builder.field(MODEL_DESCRIPTION, getDescription()); + builder.field(MODEL_ERROR, getError()); + + builder.field(METHOD_PARAMETER_SPACE_TYPE, getSpaceType().getValue()); + builder.field(DIMENSION, getDimension()); + builder.field(KNN_ENGINE, getKnnEngine().getName()); + return builder; + } } diff --git a/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java b/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java index 2d2b6757f2..b53fcc374c 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java @@ -32,7 +32,7 @@ public void testStreams() throws IOException { int dimension = 128; ModelMetadata modelMetadata = new ModelMetadata(knnEngine, spaceType, dimension, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""); + ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""); BytesStreamOutput streamOutput = new BytesStreamOutput(); modelMetadata.writeTo(streamOutput); @@ -45,7 +45,7 @@ public void testStreams() throws IOException { public void testGetKnnEngine() { KNNEngine knnEngine = KNNEngine.DEFAULT; ModelMetadata modelMetadata = new ModelMetadata(knnEngine, SpaceType.L2, 128, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""); + ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""); assertEquals(knnEngine, modelMetadata.getKnnEngine()); } @@ -53,7 +53,7 @@ public void testGetKnnEngine() { public void testGetSpaceType() { SpaceType spaceType = SpaceType.L2; ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.DEFAULT, spaceType, 128, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""); + ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""); assertEquals(spaceType, modelMetadata.getSpaceType()); } @@ -61,7 +61,7 @@ public void testGetSpaceType() { public void testGetDimension() { int dimension = 128; ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, dimension, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""); + ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""); assertEquals(dimension, modelMetadata.getDimension()); } @@ -69,7 +69,7 @@ public void testGetDimension() { public void testGetState() { ModelState modelState = ModelState.FAILED; ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 12, modelState, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""); + ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""); assertEquals(modelState, modelMetadata.getState()); } @@ -77,7 +77,7 @@ public void testGetState() { public void testGetTimestamp() { String timeValue = ZonedDateTime.now(ZoneOffset.UTC).toString(); ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 12, ModelState.CREATED, - timeValue, "", ""); + timeValue, "", ""); assertEquals(timeValue, modelMetadata.getTimestamp()); } @@ -85,7 +85,7 @@ public void testGetTimestamp() { public void testDescription() { String description = "test description"; ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 12, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), description, ""); + ZonedDateTime.now(ZoneOffset.UTC).toString(), description, ""); assertEquals(description, modelMetadata.getDescription()); } @@ -93,7 +93,7 @@ public void testDescription() { public void testGetError() { String error = "test error"; ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 12, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", error); + ZonedDateTime.now(ZoneOffset.UTC).toString(), "", error); assertEquals(error, modelMetadata.getError()); } @@ -101,7 +101,7 @@ public void testGetError() { public void testSetState() { ModelState modelState = ModelState.FAILED; ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 12, modelState, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""); + ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""); assertEquals(modelState, modelMetadata.getState()); @@ -113,7 +113,7 @@ public void testSetState() { public void testSetError() { String error = ""; ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 12, ModelState.TRAINING, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", error); + ZonedDateTime.now(ZoneOffset.UTC).toString(), "", error); assertEquals(error, modelMetadata.getError()); @@ -132,15 +132,15 @@ public void testToString() { String error = "test-error"; String expected = knnEngine.getName() + "," + - spaceType.getValue() + "," + - dimension + "," + - modelState.getName() + "," + - timestamp + "," + - description + "," + - error; + spaceType.getValue() + "," + + dimension + "," + + modelState.getName() + "," + + timestamp + "," + + description + "," + + error; ModelMetadata modelMetadata = new ModelMetadata(knnEngine, spaceType, dimension, modelState, - timestamp, description, error); + timestamp, description, error); assertEquals(expected, modelMetadata.toString()); } @@ -149,27 +149,27 @@ public void testEquals() { String time1 = ZonedDateTime.now(ZoneOffset.UTC).toString(); String time2 = ZonedDateTime.of(2021, 9, 30,12, 20, 45, 1, - ZoneId.systemDefault()).toString(); + ZoneId.systemDefault()).toString(); ModelMetadata modelMetadata1 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, - time1, "", ""); + time1, "", ""); ModelMetadata modelMetadata2 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, - time1, "", ""); + time1, "", ""); ModelMetadata modelMetadata3 = new ModelMetadata(KNNEngine.NMSLIB, SpaceType.L2, 128, ModelState.CREATED, - time1, "", ""); + time1, "", ""); ModelMetadata modelMetadata4 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L1, 128, ModelState.CREATED, - time1, "", ""); + time1, "", ""); ModelMetadata modelMetadata5 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 129, ModelState.CREATED, - time1, "", ""); + time1, "", ""); ModelMetadata modelMetadata6 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.TRAINING, - time1, "", ""); + time1, "", ""); ModelMetadata modelMetadata7 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, - time2, "", ""); + time2, "", ""); ModelMetadata modelMetadata8 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, - time1, "diff descript", ""); + time1, "diff descript", ""); ModelMetadata modelMetadata9 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, - time1, "", "diff error"); + time1, "", "diff error"); assertEquals(modelMetadata1, modelMetadata1); assertEquals(modelMetadata1, modelMetadata2); @@ -188,27 +188,27 @@ public void testHashCode() { String time1 = ZonedDateTime.now(ZoneOffset.UTC).toString(); String time2 = ZonedDateTime.of(2021, 9, 30,12, 20, 45, 1, - ZoneId.systemDefault()).toString(); + ZoneId.systemDefault()).toString(); ModelMetadata modelMetadata1 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, - time1, "", ""); + time1, "", ""); ModelMetadata modelMetadata2 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, - time1, "", ""); + time1, "", ""); ModelMetadata modelMetadata3 = new ModelMetadata(KNNEngine.NMSLIB, SpaceType.L2, 128, ModelState.CREATED, - time1, "", ""); + time1, "", ""); ModelMetadata modelMetadata4 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L1, 128, ModelState.CREATED, - time1, "", ""); + time1, "", ""); ModelMetadata modelMetadata5 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 129, ModelState.CREATED, - time1, "", ""); + time1, "", ""); ModelMetadata modelMetadata6 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.TRAINING, - time1, "", ""); + time1, "", ""); ModelMetadata modelMetadata7 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, - time2, "", ""); + time2, "", ""); ModelMetadata modelMetadata8 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, - time1, "diff descript", ""); + time1, "diff descript", ""); ModelMetadata modelMetadata9 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED, - time1, "", "diff error"); + time1, "", "diff error"); assertEquals(modelMetadata1.hashCode(), modelMetadata1.hashCode()); assertEquals(modelMetadata1.hashCode(), modelMetadata2.hashCode()); @@ -233,16 +233,16 @@ public void testFromString() { String error = "test-error"; String stringRep1 = knnEngine.getName() + "," + - spaceType.getValue() + "," + - dimension + "," + - modelState.getName() + "," + - timestamp + "," + - description + "," + - error; + spaceType.getValue() + "," + + dimension + "," + + modelState.getName() + "," + + timestamp + "," + + description + "," + + error; ModelMetadata expected = new ModelMetadata(knnEngine, spaceType, dimension, modelState, - timestamp, description, error); + timestamp, description, error); ModelMetadata fromString1 = ModelMetadata.fromString(stringRep1); assertEquals(expected, fromString1); diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestSearchModelHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestSearchModelHandlerIT.java index a86d8fb769..806d435302 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestSearchModelHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestSearchModelHandlerIT.java @@ -22,6 +22,7 @@ import org.opensearch.knn.KNNRestTestCase; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.util.KNNEngine; +import org.opensearch.knn.indices.Model; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelState; import org.opensearch.knn.plugin.KNNPlugin; @@ -117,6 +118,144 @@ public void testSearchModelExists() throws IOException { for(SearchHit hit: searchResponse.getHits().getHits()){ assertTrue(testModelID.contains(hit.getId())); + Model model = Model.getModelFromSourceMap(hit.getSourceAsMap(), hit.getId()); + assertEquals(getModelMetadata(),model.getModelMetadata()); + assertArrayEquals(testModelBlob, model.getModelBlob()); + } + } + } + + public void testSearchModelWithoutSource() throws IOException { + createModelSystemIndex(); + createIndex("irrelevant-index", Settings.EMPTY); + addDocWithBinaryField("irrelevant-index", "id1", "field-name", "value"); + List testModelID = Arrays.asList("test-modelid1", "test-modelid2"); + byte[] testModelBlob = "hello".getBytes(); + ModelMetadata testModelMetadata = getModelMetadata(); + for(String modelID: testModelID){ + addModelToSystemIndex(modelID, testModelMetadata, testModelBlob); + } + + String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "_search"); + + for(String method: Arrays.asList("GET", "POST")){ + Request request = new Request(method, restURI); + request.setJsonEntity("{\n" + + " \"_source\" : false,\n" + + " \"query\": {\n" + + " \"match_all\": {}\n" + + " }\n" + + "}"); + Response response = client().performRequest(request); + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + String responseBody = EntityUtils.toString(response.getEntity()); + assertNotNull(responseBody); + + XContentParser parser = createParser(XContentType.JSON.xContent(), responseBody); + SearchResponse searchResponse = SearchResponse.fromXContent(parser); + assertNotNull(searchResponse); + + //returns only model from ModelIndex + assertEquals(searchResponse.getHits().getHits().length, testModelID.size()); + + for(SearchHit hit: searchResponse.getHits().getHits()){ + assertTrue(testModelID.contains(hit.getId())); + assertNull(hit.getSourceAsMap()); + } + } + } + + public void testSearchModelWithSourceFilteringIncludes() throws IOException { + createModelSystemIndex(); + createIndex("irrelevant-index", Settings.EMPTY); + addDocWithBinaryField("irrelevant-index", "id1", "field-name", "value"); + List testModelID = Arrays.asList("test-modelid1", "test-modelid2"); + byte[] testModelBlob = "hello".getBytes(); + ModelMetadata testModelMetadata = getModelMetadata(); + for(String modelID: testModelID){ + addModelToSystemIndex(modelID, testModelMetadata, testModelBlob); + } + + String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "_search"); + + for(String method: Arrays.asList("GET", "POST")){ + Request request = new Request(method, restURI); + request.setJsonEntity("{\n" + + " \"_source\": {\n" + + " \"includes\": [ \"state\", \"description\" ]\n"+ + " }, " + + " \"query\": {\n" + + " \"match_all\": {}\n" + + " }\n" + + "}"); + Response response = client().performRequest(request); + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + String responseBody = EntityUtils.toString(response.getEntity()); + assertNotNull(responseBody); + + XContentParser parser = createParser(XContentType.JSON.xContent(), responseBody); + SearchResponse searchResponse = SearchResponse.fromXContent(parser); + assertNotNull(searchResponse); + + //returns only model from ModelIndex + assertEquals(searchResponse.getHits().getHits().length, testModelID.size()); + + for(SearchHit hit: searchResponse.getHits().getHits()){ + assertTrue(testModelID.contains(hit.getId())); + Map sourceAsMap = hit.getSourceAsMap(); + assertFalse(sourceAsMap.containsKey("model_blob")); + assertTrue(sourceAsMap.containsKey("state")); + assertFalse(sourceAsMap.containsKey("timestamp")); + assertTrue(sourceAsMap.containsKey("description")); + } + } + } + + public void testSearchModelWithSourceFilteringExcludes() throws IOException { + createModelSystemIndex(); + createIndex("irrelevant-index", Settings.EMPTY); + addDocWithBinaryField("irrelevant-index", "id1", "field-name", "value"); + List testModelID = Arrays.asList("test-modelid1", "test-modelid2"); + byte[] testModelBlob = "hello".getBytes(); + ModelMetadata testModelMetadata = getModelMetadata(); + for(String modelID: testModelID){ + addModelToSystemIndex(modelID, testModelMetadata, testModelBlob); + } + + String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "_search"); + + for(String method: Arrays.asList("GET", "POST")){ + Request request = new Request(method, restURI); + request.setJsonEntity("{\n" + + " \"_source\": {\n" + + " \"excludes\": [\"model_blob\" ]\n"+ + " }, " + + " \"query\": {\n" + + " \"match_all\": {}\n" + + " }\n" + + "}"); + Response response = client().performRequest(request); + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + String responseBody = EntityUtils.toString(response.getEntity()); + assertNotNull(responseBody); + + XContentParser parser = createParser(XContentType.JSON.xContent(), responseBody); + SearchResponse searchResponse = SearchResponse.fromXContent(parser); + assertNotNull(searchResponse); + + //returns only model from ModelIndex + assertEquals(searchResponse.getHits().getHits().length, testModelID.size()); + + for(SearchHit hit: searchResponse.getHits().getHits()){ + assertTrue(testModelID.contains(hit.getId())); + Map sourceAsMap = hit.getSourceAsMap(); + assertFalse(sourceAsMap.containsKey("model_blob")); + assertTrue(sourceAsMap.containsKey("state")); + assertTrue(sourceAsMap.containsKey("timestamp")); + assertTrue(sourceAsMap.containsKey("description")); } } } diff --git a/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java b/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java index 5aad37ff28..dd797ccd52 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java @@ -49,7 +49,7 @@ public void testXContent() throws IOException { byte[] testModelBlob = "hello".getBytes(); Model model = new Model(getModelMetadata(ModelState.CREATED), testModelBlob,modelId); GetModelResponse getModelResponse = new GetModelResponse(model); - String expectedResponseString = "{\"model_id\":\"test-model\",\"state\":\"created\",\"timestamp\":\"2021-03-27 10:15:30 AM +05:30\",\"description\":\"test model\",\"error\":\"\",\"model_blob\":\"aGVsbG8=\",\"space_type\":\"l2\",\"dimension\":4,\"engine\":\"nmslib\"}"; + String expectedResponseString = "{\"model_id\":\"test-model\",\"model_blob\":\"aGVsbG8=\",\"state\":\"created\",\"timestamp\":\"2021-03-27 10:15:30 AM +05:30\",\"description\":\"test model\",\"error\":\"\",\"space_type\":\"l2\",\"dimension\":4,\"engine\":\"nmslib\"}"; XContentBuilder xContentBuilder = XContentFactory.contentBuilder(XContentType.JSON); getModelResponse.toXContent(xContentBuilder, null); assertEquals(expectedResponseString, Strings.toString(xContentBuilder)); @@ -59,7 +59,7 @@ public void testXContentWithNoModelBlob() throws IOException { String modelId = "test-model"; Model model = new Model(getModelMetadata(ModelState.FAILED), null, modelId); GetModelResponse getModelResponse = new GetModelResponse(model); - String expectedResponseString = "{\"model_id\":\"test-model\",\"state\":\"failed\",\"timestamp\":\"2021-03-27 10:15:30 AM +05:30\",\"description\":\"test model\",\"error\":\"\",\"model_blob\":\"\",\"space_type\":\"l2\",\"dimension\":4,\"engine\":\"nmslib\"}"; + String expectedResponseString = "{\"model_id\":\"test-model\",\"model_blob\":\"\",\"state\":\"failed\",\"timestamp\":\"2021-03-27 10:15:30 AM +05:30\",\"description\":\"test model\",\"error\":\"\",\"space_type\":\"l2\",\"dimension\":4,\"engine\":\"nmslib\"}"; XContentBuilder xContentBuilder = XContentFactory.contentBuilder(XContentType.JSON); getModelResponse.toXContent(xContentBuilder, null); assertEquals(expectedResponseString, Strings.toString(xContentBuilder));