Skip to content

Commit

Permalink
Update Model Search Response
Browse files Browse the repository at this point in the history
Return model index search response as output for
model search api.

Signed-off-by: Vijayan Balasubramanian <[email protected]>
  • Loading branch information
VijayanB committed Oct 29, 2021
1 parent c9a6651 commit 9714763
Show file tree
Hide file tree
Showing 6 changed files with 248 additions and 95 deletions.
68 changes: 30 additions & 38 deletions src/main/java/org/opensearch/knn/indices/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,16 @@
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.io.stream.Writeable;
import org.opensearch.common.xcontent.ToXContent;
import org.opensearch.common.xcontent.ToXContentObject;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.util.KNNEngine;

import java.io.IOException;
import java.util.Base64;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;

import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.knn.common.KNNConstants.DIMENSION;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
Expand Down Expand Up @@ -164,28 +159,17 @@ public int hashCode() {
return new HashCodeBuilder().append(getModelMetadata()).append(getModelBlob()).toHashCode();
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
XContentBuilder xContentBuilder = builder.startObject();
if(Strings.hasText(modelID)){
builder.field(MODEL_ID, modelID);
}
builder.field(MODEL_STATE, getModelMetadata().getState().getName());
builder.field(MODEL_TIMESTAMP, getModelMetadata().getTimestamp());
builder.field(MODEL_DESCRIPTION, getModelMetadata().getDescription());
builder.field(MODEL_ERROR, getModelMetadata().getError());

String base64Model = "";
if(getModelBlob() != null){
base64Model = Base64.getEncoder().encodeToString(getModelBlob());
}
builder.field(MODEL_BLOB_PARAMETER, base64Model);

builder.field(METHOD_PARAMETER_SPACE_TYPE, getModelMetadata().getSpaceType().getValue());
builder.field(DIMENSION, getModelMetadata().getDimension());
builder.field(KNN_ENGINE, getModelMetadata().getKnnEngine().getName());

return xContentBuilder.endObject();
/**
* Parse source map content into {@link Model} instance.
*
* @param sourceMap source contents
* @param modelID model's identifier
* @return model instance
*/
public static Model getModelFromSourceMap(Map<String, Object> sourceMap, @NonNull String modelID) {
ModelMetadata modelMetadata = ModelMetadata.getMetadataFromSourceMap(sourceMap);
byte[] blob = getModelBlobFromResponse(sourceMap);
return new Model(modelMetadata, blob, modelID);
}

private void writeOptionalModelBlob(StreamOutput output) throws IOException {
Expand Down Expand Up @@ -219,16 +203,24 @@ private static byte[] getModelBlobFromResponse(Map<String, Object> responseMap){
return Base64.getDecoder().decode((String) blob);
}

/**
* Parse source map content into {@link Model} instance.
*
* @param sourceMap source contents
* @param modelID model's identifier
* @return model instance
*/
public static Model getModelFromSourceMap(Map<String,Object> sourceMap, @NonNull String modelID) {
ModelMetadata modelMetadata = ModelMetadata.getMetadataFromSourceMap(sourceMap);
byte[] blob = getModelBlobFromResponse(sourceMap);
return new Model(modelMetadata, blob, modelID);
private static void createFieldIfNotNull(XContentBuilder builder, String fieldName, Object value) throws IOException {
if (value == null)
return;
builder.field(fieldName, value);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
XContentBuilder xContentBuilder = builder.startObject();
if (Strings.hasText(modelID)) {
builder.field(MODEL_ID, modelID);
}
String base64Model = "";
if (getModelBlob() != null) {
base64Model = Base64.getEncoder().encodeToString(getModelBlob());
}
builder.field(MODEL_BLOB_PARAMETER, base64Model);
getModelMetadata().toXContent(builder, params);
return xContentBuilder.endObject();
}
}
10 changes: 1 addition & 9 deletions src/main/java/org/opensearch/knn/indices/ModelDao.java
Original file line number Diff line number Diff line change
Expand Up @@ -393,15 +393,7 @@ public void get(String modelId, ActionListener<GetModelResponse> actionListener)
@Override
public void search(SearchRequest request, ActionListener<SearchResponse> actionListener) {
request.indices(MODEL_INDEX_NAME);
client.search(request,ActionListener.wrap(response -> {
for (SearchHit hit : response.getHits()) {
ToXContentObject xContentObject = Model.getModelFromSourceMap(hit.getSourceAsMap(), hit.getId());
XContentBuilder builder = xContentObject.toXContent(jsonBuilder(), EMPTY_PARAMS);
hit.sourceRef(BytesReference.bytes(builder));
}
actionListener.onResponse(response);

}, actionListener::onFailure));
client.search(request,actionListener);
}

@Override
Expand Down
34 changes: 32 additions & 2 deletions src/main/java/org/opensearch/knn/indices/ModelMetadata.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,29 @@
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.io.stream.Writeable;
import org.opensearch.common.xcontent.ToXContentObject;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.util.KNNEngine;

import java.io.IOException;
import java.util.Base64;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;

import static org.opensearch.knn.common.KNNConstants.DIMENSION;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.MODEL_BLOB_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.MODEL_DESCRIPTION;
import static org.opensearch.knn.common.KNNConstants.MODEL_ERROR;
import static org.opensearch.knn.common.KNNConstants.MODEL_STATE;
import static org.opensearch.knn.common.KNNConstants.MODEL_TIMESTAMP;
import static org.opensearch.knn.index.KNNVectorFieldMapper.MAX_DIMENSION;

public class ModelMetadata implements Writeable {
public class ModelMetadata implements Writeable, ToXContentObject {

private static final String DELIMITER = ",";

Expand Down Expand Up @@ -229,6 +240,12 @@ private static String objectToString(Object value) {
return (String)value;
}

private static Integer objectToInteger(Object value) {
if(value == null)
return null;
return (Integer)value;
}

/**
* Returns ModelMetadata from Map representation
*
Expand All @@ -245,7 +262,7 @@ public static ModelMetadata getMetadataFromSourceMap(final Map<String, Object> m
Object error = modelSourceMap.get(KNNConstants.MODEL_ERROR);

ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.getEngine(objectToString(engine)),
SpaceType.getSpace(objectToString( space)), (Integer) dimension, ModelState.getModelState(objectToString(state)),
SpaceType.getSpace(objectToString( space)), objectToInteger(dimension), ModelState.getModelState(objectToString(state)),
objectToString(timestamp), objectToString(description), objectToString( error));
return modelMetadata;
}
Expand All @@ -260,4 +277,17 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(getDescription());
out.writeString(getError());
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(MODEL_STATE, getState().getName());
builder.field(MODEL_TIMESTAMP, getTimestamp());
builder.field(MODEL_DESCRIPTION, getDescription());
builder.field(MODEL_ERROR, getError());

builder.field(METHOD_PARAMETER_SPACE_TYPE, getSpaceType().getValue());
builder.field(DIMENSION, getDimension());
builder.field(KNN_ENGINE, getKnnEngine().getName());
return builder;
}
}
88 changes: 44 additions & 44 deletions src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public void testStreams() throws IOException {
int dimension = 128;

ModelMetadata modelMetadata = new ModelMetadata(knnEngine, spaceType, dimension, ModelState.CREATED,
ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "");
ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "");

BytesStreamOutput streamOutput = new BytesStreamOutput();
modelMetadata.writeTo(streamOutput);
Expand All @@ -45,63 +45,63 @@ public void testStreams() throws IOException {
public void testGetKnnEngine() {
KNNEngine knnEngine = KNNEngine.DEFAULT;
ModelMetadata modelMetadata = new ModelMetadata(knnEngine, SpaceType.L2, 128, ModelState.CREATED,
ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "");
ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "");

assertEquals(knnEngine, modelMetadata.getKnnEngine());
}

public void testGetSpaceType() {
SpaceType spaceType = SpaceType.L2;
ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.DEFAULT, spaceType, 128, ModelState.CREATED,
ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "");
ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "");

assertEquals(spaceType, modelMetadata.getSpaceType());
}

public void testGetDimension() {
int dimension = 128;
ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, dimension, ModelState.CREATED,
ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "");
ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "");

assertEquals(dimension, modelMetadata.getDimension());
}

public void testGetState() {
ModelState modelState = ModelState.FAILED;
ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 12, modelState,
ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "");
ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "");

assertEquals(modelState, modelMetadata.getState());
}

public void testGetTimestamp() {
String timeValue = ZonedDateTime.now(ZoneOffset.UTC).toString();
ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 12, ModelState.CREATED,
timeValue, "", "");
timeValue, "", "");

assertEquals(timeValue, modelMetadata.getTimestamp());
}

public void testDescription() {
String description = "test description";
ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 12, ModelState.CREATED,
ZonedDateTime.now(ZoneOffset.UTC).toString(), description, "");
ZonedDateTime.now(ZoneOffset.UTC).toString(), description, "");

assertEquals(description, modelMetadata.getDescription());
}

public void testGetError() {
String error = "test error";
ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 12, ModelState.CREATED,
ZonedDateTime.now(ZoneOffset.UTC).toString(), "", error);
ZonedDateTime.now(ZoneOffset.UTC).toString(), "", error);

assertEquals(error, modelMetadata.getError());
}

public void testSetState() {
ModelState modelState = ModelState.FAILED;
ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 12, modelState,
ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "");
ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "");

assertEquals(modelState, modelMetadata.getState());

Expand All @@ -113,7 +113,7 @@ public void testSetState() {
public void testSetError() {
String error = "";
ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 12, ModelState.TRAINING,
ZonedDateTime.now(ZoneOffset.UTC).toString(), "", error);
ZonedDateTime.now(ZoneOffset.UTC).toString(), "", error);

assertEquals(error, modelMetadata.getError());

Expand All @@ -132,15 +132,15 @@ public void testToString() {
String error = "test-error";

String expected = knnEngine.getName() + "," +
spaceType.getValue() + "," +
dimension + "," +
modelState.getName() + "," +
timestamp + "," +
description + "," +
error;
spaceType.getValue() + "," +
dimension + "," +
modelState.getName() + "," +
timestamp + "," +
description + "," +
error;

ModelMetadata modelMetadata = new ModelMetadata(knnEngine, spaceType, dimension, modelState,
timestamp, description, error);
timestamp, description, error);

assertEquals(expected, modelMetadata.toString());
}
Expand All @@ -149,27 +149,27 @@ public void testEquals() {

String time1 = ZonedDateTime.now(ZoneOffset.UTC).toString();
String time2 = ZonedDateTime.of(2021, 9, 30,12, 20, 45, 1,
ZoneId.systemDefault()).toString();
ZoneId.systemDefault()).toString();

ModelMetadata modelMetadata1 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED,
time1, "", "");
time1, "", "");
ModelMetadata modelMetadata2 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED,
time1, "", "");
time1, "", "");

ModelMetadata modelMetadata3 = new ModelMetadata(KNNEngine.NMSLIB, SpaceType.L2, 128, ModelState.CREATED,
time1, "", "");
time1, "", "");
ModelMetadata modelMetadata4 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L1, 128, ModelState.CREATED,
time1, "", "");
time1, "", "");
ModelMetadata modelMetadata5 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 129, ModelState.CREATED,
time1, "", "");
time1, "", "");
ModelMetadata modelMetadata6 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.TRAINING,
time1, "", "");
time1, "", "");
ModelMetadata modelMetadata7 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED,
time2, "", "");
time2, "", "");
ModelMetadata modelMetadata8 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED,
time1, "diff descript", "");
time1, "diff descript", "");
ModelMetadata modelMetadata9 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED,
time1, "", "diff error");
time1, "", "diff error");

assertEquals(modelMetadata1, modelMetadata1);
assertEquals(modelMetadata1, modelMetadata2);
Expand All @@ -188,27 +188,27 @@ public void testHashCode() {

String time1 = ZonedDateTime.now(ZoneOffset.UTC).toString();
String time2 = ZonedDateTime.of(2021, 9, 30,12, 20, 45, 1,
ZoneId.systemDefault()).toString();
ZoneId.systemDefault()).toString();

ModelMetadata modelMetadata1 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED,
time1, "", "");
time1, "", "");
ModelMetadata modelMetadata2 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED,
time1, "", "");
time1, "", "");

ModelMetadata modelMetadata3 = new ModelMetadata(KNNEngine.NMSLIB, SpaceType.L2, 128, ModelState.CREATED,
time1, "", "");
time1, "", "");
ModelMetadata modelMetadata4 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L1, 128, ModelState.CREATED,
time1, "", "");
time1, "", "");
ModelMetadata modelMetadata5 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 129, ModelState.CREATED,
time1, "", "");
time1, "", "");
ModelMetadata modelMetadata6 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.TRAINING,
time1, "", "");
time1, "", "");
ModelMetadata modelMetadata7 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED,
time2, "", "");
time2, "", "");
ModelMetadata modelMetadata8 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED,
time1, "diff descript", "");
time1, "diff descript", "");
ModelMetadata modelMetadata9 = new ModelMetadata(KNNEngine.FAISS, SpaceType.L2, 128, ModelState.CREATED,
time1, "", "diff error");
time1, "", "diff error");

assertEquals(modelMetadata1.hashCode(), modelMetadata1.hashCode());
assertEquals(modelMetadata1.hashCode(), modelMetadata2.hashCode());
Expand All @@ -233,16 +233,16 @@ public void testFromString() {
String error = "test-error";

String stringRep1 = knnEngine.getName() + "," +
spaceType.getValue() + "," +
dimension + "," +
modelState.getName() + "," +
timestamp + "," +
description + "," +
error;
spaceType.getValue() + "," +
dimension + "," +
modelState.getName() + "," +
timestamp + "," +
description + "," +
error;


ModelMetadata expected = new ModelMetadata(knnEngine, spaceType, dimension, modelState,
timestamp, description, error);
timestamp, description, error);
ModelMetadata fromString1 = ModelMetadata.fromString(stringRep1);

assertEquals(expected, fromString1);
Expand Down
Loading

0 comments on commit 9714763

Please sign in to comment.