Skip to content

Commit

Permalink
Add a version filter to enable bwc in 2.12 (opensearch-project#1944)
Browse files Browse the repository at this point in the history
Signed-off-by: Sicheng Song <[email protected]>
  • Loading branch information
b4sjoo committed Feb 3, 2024
1 parent 34a9f8e commit 79b71da
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import lombok.Builder;
import lombok.Getter;
import lombok.experimental.FieldDefaults;
import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.ml.common.annotation.InputDataSet;
Expand All @@ -28,6 +29,8 @@ public class TextDocsInputDataSet extends MLInputDataset{

private List<String> docs;

private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MULTI_MODAL = Version.V_2_11_0;

@Builder(toBuilder = true)
public TextDocsInputDataSet(List<String> docs, ModelResultFilter resultFilter) {
super(MLInputDataType.TEXT_DOCS);
Expand All @@ -41,10 +44,15 @@ public TextDocsInputDataSet(List<String> docs, ModelResultFilter resultFilter) {

public TextDocsInputDataSet(StreamInput streamInput) throws IOException {
super(MLInputDataType.TEXT_DOCS);
docs = new ArrayList<>();
int size = streamInput.readInt();
for (int i=0; i<size; i++) {
docs.add(streamInput.readOptionalString());
Version version = streamInput.getVersion();
if (version.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_MULTI_MODAL)) {
docs = new ArrayList<>();
int size = streamInput.readInt();
for (int i=0; i<size; i++) {
docs.add(streamInput.readOptionalString());
}
} else {
docs = streamInput.readStringList();
}
if (streamInput.readBoolean()) {
resultFilter = new ModelResultFilter(streamInput);
Expand All @@ -56,9 +64,14 @@ public TextDocsInputDataSet(StreamInput streamInput) throws IOException {
@Override
public void writeTo(StreamOutput streamOutput) throws IOException {
super.writeTo(streamOutput);
streamOutput.writeInt(docs.size());
for (String doc : docs) {
streamOutput.writeOptionalString(doc);
Version version = streamOutput.getVersion();
if (version.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_MULTI_MODAL)) {
streamOutput.writeInt(docs.size());
for (String doc : docs) {
streamOutput.writeOptionalString(doc);
}
} else {
streamOutput.writeStringCollection(docs);
}
if (resultFilter != null) {
streamOutput.writeBoolean(true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import lombok.Builder;
import lombok.Data;
import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
Expand Down Expand Up @@ -57,6 +58,10 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable {
public static final String BACKEND_ROLES_FIELD = "backend_roles";
public static final String ADD_ALL_BACKEND_ROLES_FIELD = "add_all_backend_roles";
public static final String DOES_VERSION_CREATE_MODEL_GROUP = "does_version_create_model_group";

private static final Version MINIMAL_SUPPORTED_VERSION_FOR_DOES_VERSION_CREATE_MODEL_GROUP = Version.V_2_11_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_AGENT_FRAMEWORK = Version.V_2_12_0;

private FunctionName functionName;
private String modelName;
private String modelGroupId;
Expand Down Expand Up @@ -141,15 +146,12 @@ public MLRegisterModelInput(FunctionName functionName,
}

public MLRegisterModelInput(StreamInput in) throws IOException {
Version streamInputVersion = in.getVersion();
this.functionName = in.readEnum(FunctionName.class);
this.modelName = in.readString();
this.modelGroupId = in.readOptionalString();
this.version = in.readOptionalString();
this.description = in.readOptionalString();
this.isEnabled = in.readOptionalBoolean();
if (in.readBoolean()) {
this.rateLimiter = new MLRateLimiter(in);
}
this.url = in.readOptionalString();
this.hashValue = in.readOptionalString();
if (in.readBoolean()) {
Expand All @@ -175,24 +177,26 @@ public MLRegisterModelInput(StreamInput in) throws IOException {
if (in.readBoolean()) {
this.accessMode = in.readEnum(AccessMode.class);
}
this.doesVersionCreateModelGroup = in.readOptionalBoolean();
this.isHidden = in.readOptionalBoolean();
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_DOES_VERSION_CREATE_MODEL_GROUP)) {
this.doesVersionCreateModelGroup = in.readOptionalBoolean();
}
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_AGENT_FRAMEWORK)) {
this.isEnabled = in.readOptionalBoolean();
if (in.readBoolean()) {
this.rateLimiter = new MLRateLimiter(in);
}
this.isHidden = in.readOptionalBoolean();
}
}

@Override
public void writeTo(StreamOutput out) throws IOException {
Version streamOutputVersion = out.getVersion();
out.writeEnum(functionName);
out.writeString(modelName);
out.writeOptionalString(modelGroupId);
out.writeOptionalString(version);
out.writeOptionalString(description);
out.writeOptionalBoolean(isEnabled);
if (rateLimiter != null) {
out.writeBoolean(true);
rateLimiter.writeTo(out);
} else {
out.writeBoolean(false);
}
out.writeOptionalString(url);
out.writeOptionalString(hashValue);
if (modelFormat != null) {
Expand Down Expand Up @@ -229,8 +233,19 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
out.writeOptionalBoolean(doesVersionCreateModelGroup);
out.writeOptionalBoolean(isHidden);
if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_DOES_VERSION_CREATE_MODEL_GROUP)) {
out.writeOptionalBoolean(doesVersionCreateModelGroup);
}
if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_AGENT_FRAMEWORK)) {
out.writeOptionalBoolean(isEnabled);
if (rateLimiter != null) {
out.writeBoolean(true);
rateLimiter.writeTo(out);
} else {
out.writeBoolean(false);
}
out.writeOptionalBoolean(isHidden);
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import lombok.Builder;
import lombok.Data;
import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
Expand Down Expand Up @@ -51,6 +52,9 @@ public class MLRegisterModelMetaInput implements ToXContentObject, Writeable {
public static final String ADD_ALL_BACKEND_ROLES = "add_all_backend_roles"; // optional
public static final String DOES_VERSION_CREATE_MODEL_GROUP = "does_version_create_model_group";

private static final Version MINIMAL_SUPPORTED_VERSION_FOR_DOES_VERSION_CREATE_MODEL_GROUP = Version.V_2_11_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_AGENT_FRAMEWORK = Version.V_2_12_0;

private FunctionName functionName;
private String name;

Expand Down Expand Up @@ -125,15 +129,12 @@ public MLRegisterModelMetaInput(String name, FunctionName functionName, String m
}

public MLRegisterModelMetaInput(StreamInput in) throws IOException {
Version streamInputVersion = in.getVersion();
this.name = in.readString();
this.functionName = in.readEnum(FunctionName.class);
this.modelGroupId = in.readOptionalString();
this.version = in.readOptionalString();
this.description = in.readOptionalString();
this.isEnabled = in.readOptionalBoolean();
if (in.readBoolean()) {
rateLimiter = new MLRateLimiter(in);
}
if (in.readBoolean()) {
modelFormat = in.readEnum(MLModelFormat.class);
}
Expand All @@ -151,24 +152,26 @@ public MLRegisterModelMetaInput(StreamInput in) throws IOException {
accessMode = in.readEnum(AccessMode.class);
}
this.isAddAllBackendRoles = in.readOptionalBoolean();
this.doesVersionCreateModelGroup = in.readOptionalBoolean();
this.isHidden = in.readOptionalBoolean();
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_DOES_VERSION_CREATE_MODEL_GROUP)) {
this.doesVersionCreateModelGroup = in.readOptionalBoolean();
}
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_AGENT_FRAMEWORK)) {
this.isEnabled = in.readOptionalBoolean();
if (in.readBoolean()) {
this.rateLimiter = new MLRateLimiter(in);
}
this.isHidden = in.readOptionalBoolean();
}
}

@Override
public void writeTo(StreamOutput out) throws IOException {
Version streamOutputVersion = out.getVersion();
out.writeString(name);
out.writeEnum(functionName);
out.writeOptionalString(modelGroupId);
out.writeOptionalString(version);
out.writeOptionalString(description);
out.writeOptionalBoolean(isEnabled);
if (rateLimiter != null) {
out.writeBoolean(true);
rateLimiter.writeTo(out);
} else {
out.writeBoolean(false);
}
if (modelFormat != null) {
out.writeBoolean(true);
out.writeEnum(modelFormat);
Expand Down Expand Up @@ -203,8 +206,19 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(false);
}
out.writeOptionalBoolean(isAddAllBackendRoles);
out.writeOptionalBoolean(doesVersionCreateModelGroup);
out.writeOptionalBoolean(isHidden);
if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_DOES_VERSION_CREATE_MODEL_GROUP)) {
out.writeOptionalBoolean(doesVersionCreateModelGroup);
}
if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_AGENT_FRAMEWORK)) {
out.writeOptionalBoolean(isEnabled);
if (rateLimiter != null) {
out.writeBoolean(true);
rateLimiter.writeTo(out);
} else {
out.writeBoolean(false);
}
out.writeOptionalBoolean(isHidden);
}
}

@Override
Expand Down
10 changes: 5 additions & 5 deletions plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ private void uploadMLModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput
.build();
IndexRequest indexRequest = new IndexRequest(ML_MODEL_INDEX);

if (mlRegisterModelMetaInput.getIsHidden()) {
if (mlRegisterModelMetaInput.getIsHidden() != null && mlRegisterModelMetaInput.getIsHidden()) {
indexRequest.id(modelName);
}
indexRequest.source(mlModelMeta.toXContent(XContentBuilder.builder(JSON.xContent()), EMPTY_PARAMS));
Expand Down Expand Up @@ -530,7 +530,7 @@ private void indexRemoteModel(
.build();

IndexRequest indexModelMetaRequest = new IndexRequest(ML_MODEL_INDEX);
if (registerModelInput.getIsHidden()) {
if (registerModelInput.getIsHidden() != null && registerModelInput.getIsHidden()) {
indexModelMetaRequest.id(modelName);
}
indexModelMetaRequest.source(mlModelMeta.toXContent(XContentBuilder.builder(JSON.xContent()), EMPTY_PARAMS));
Expand Down Expand Up @@ -593,7 +593,7 @@ void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask mlTask, St
.isHidden(registerModelInput.getIsHidden())
.build();
IndexRequest indexModelMetaRequest = new IndexRequest(ML_MODEL_INDEX);
if (registerModelInput.getIsHidden()) {
if (registerModelInput.getIsHidden() != null && registerModelInput.getIsHidden()) {
indexModelMetaRequest.id(modelName);
}
indexModelMetaRequest.source(mlModelMeta.toXContent(XContentBuilder.builder(JSON.xContent()), EMPTY_PARAMS));
Expand Down Expand Up @@ -660,7 +660,7 @@ private void registerModelFromUrl(MLRegisterModelInput registerModelInput, MLTas
if (functionName == FunctionName.METRICS_CORRELATION) {
indexModelMetaRequest.id(functionName.name());
}
if (registerModelInput.getIsHidden()) {
if (registerModelInput.getIsHidden() != null && registerModelInput.getIsHidden()) {
indexModelMetaRequest.id(modelName);
}
indexModelMetaRequest.source(mlModelMeta.toXContent(XContentBuilder.builder(JSON.xContent()), EMPTY_PARAMS));
Expand Down Expand Up @@ -740,7 +740,7 @@ private void registerModel(
.isHidden(registerModelInput.getIsHidden())
.build();
IndexRequest indexRequest = new IndexRequest(ML_MODEL_INDEX);
if (registerModelInput.getIsHidden()) {
if (registerModelInput.getIsHidden() != null && registerModelInput.getIsHidden()) {
indexRequest.id(modelName);
}
String chunkId = getModelChunkId(modelId, chunkNum);
Expand Down

0 comments on commit 79b71da

Please sign in to comment.