Skip to content

Commit

Permalink
Add a version filter to enable bwc in 2.11 (#1939)
Browse files Browse the repository at this point in the history
  • Loading branch information
b4sjoo authored Jan 31, 2024
1 parent a946231 commit a32d35a
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import lombok.Builder;
import lombok.Getter;
import lombok.experimental.FieldDefaults;
import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.ml.common.annotation.InputDataSet;
Expand All @@ -28,6 +29,8 @@ public class TextDocsInputDataSet extends MLInputDataset{

private List<String> docs;

private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MULTI_MODAL = Version.V_2_11_0;

@Builder(toBuilder = true)
public TextDocsInputDataSet(List<String> docs, ModelResultFilter resultFilter) {
super(MLInputDataType.TEXT_DOCS);
Expand All @@ -41,10 +44,15 @@ public TextDocsInputDataSet(List<String> docs, ModelResultFilter resultFilter) {

public TextDocsInputDataSet(StreamInput streamInput) throws IOException {
super(MLInputDataType.TEXT_DOCS);
docs = new ArrayList<>();
int size = streamInput.readInt();
for (int i=0; i<size; i++) {
docs.add(streamInput.readOptionalString());
Version version = streamInput.getVersion();
if (version.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_MULTI_MODAL)) {
docs = new ArrayList<>();
int size = streamInput.readInt();
for (int i=0; i<size; i++) {
docs.add(streamInput.readOptionalString());
}
} else {
docs = streamInput.readStringList();
}
if (streamInput.readBoolean()) {
resultFilter = new ModelResultFilter(streamInput);
Expand All @@ -56,9 +64,14 @@ public TextDocsInputDataSet(StreamInput streamInput) throws IOException {
@Override
public void writeTo(StreamOutput streamOutput) throws IOException {
super.writeTo(streamOutput);
streamOutput.writeInt(docs.size());
for (String doc : docs) {
streamOutput.writeOptionalString(doc);
Version version = streamOutput.getVersion();
if (version.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_MULTI_MODAL)) {
streamOutput.writeInt(docs.size());
for (String doc : docs) {
streamOutput.writeOptionalString(doc);
}
} else {
streamOutput.writeStringCollection(docs);
}
if (resultFilter != null) {
streamOutput.writeBoolean(true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import lombok.Builder;
import lombok.Data;
import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
Expand Down Expand Up @@ -54,6 +55,9 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable {
public static final String BACKEND_ROLES_FIELD = "backend_roles";
public static final String ADD_ALL_BACKEND_ROLES_FIELD = "add_all_backend_roles";
public static final String DOES_VERSION_CREATE_MODEL_GROUP = "does_version_create_model_group";

private static final Version MINIMAL_SUPPORTED_VERSION_FOR_DOES_VERSION_CREATE_MODEL_GROUP = Version.V_2_11_0;

private FunctionName functionName;
private String modelName;
private String modelGroupId;
Expand Down Expand Up @@ -130,6 +134,7 @@ public MLRegisterModelInput(FunctionName functionName,


public MLRegisterModelInput(StreamInput in) throws IOException {
Version streamInputVersion = in.getVersion();
this.functionName = in.readEnum(FunctionName.class);
this.modelName = in.readString();
this.modelGroupId = in.readOptionalString();
Expand Down Expand Up @@ -160,11 +165,14 @@ public MLRegisterModelInput(StreamInput in) throws IOException {
if (in.readBoolean()) {
this.accessMode = in.readEnum(AccessMode.class);
}
this.doesVersionCreateModelGroup = in.readOptionalBoolean();
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_DOES_VERSION_CREATE_MODEL_GROUP)) {
this.doesVersionCreateModelGroup = in.readOptionalBoolean();
}
}

@Override
public void writeTo(StreamOutput out) throws IOException {
Version streamOutputVersion = out.getVersion();
out.writeEnum(functionName);
out.writeString(modelName);
out.writeOptionalString(modelGroupId);
Expand Down Expand Up @@ -206,7 +214,9 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
out.writeOptionalBoolean(doesVersionCreateModelGroup);
if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_DOES_VERSION_CREATE_MODEL_GROUP)) {
out.writeOptionalBoolean(doesVersionCreateModelGroup);
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import lombok.Builder;
import lombok.Data;
import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
Expand Down Expand Up @@ -48,6 +49,7 @@ public class MLRegisterModelMetaInput implements ToXContentObject, Writeable{
public static final String ADD_ALL_BACKEND_ROLES = "add_all_backend_roles"; //optional
public static final String DOES_VERSION_CREATE_MODEL_GROUP = "does_version_create_model_group";

private static final Version MINIMAL_SUPPORTED_VERSION_FOR_DOES_VERSION_CREATE_MODEL_GROUP = Version.V_2_11_0;

private FunctionName functionName;
private String name;
Expand Down Expand Up @@ -111,6 +113,7 @@ public MLRegisterModelMetaInput(String name, FunctionName functionName, String m
}

public MLRegisterModelMetaInput(StreamInput in) throws IOException{
Version streamInputVersion = in.getVersion();
this.name = in.readString();
this.functionName = in.readEnum(FunctionName.class);
this.modelGroupId = in.readOptionalString();
Expand All @@ -133,11 +136,14 @@ public MLRegisterModelMetaInput(StreamInput in) throws IOException{
accessMode = in.readEnum(AccessMode.class);
}
this.isAddAllBackendRoles = in.readOptionalBoolean();
this.doesVersionCreateModelGroup = in.readOptionalBoolean();
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_DOES_VERSION_CREATE_MODEL_GROUP)) {
this.doesVersionCreateModelGroup = in.readOptionalBoolean();
}
}

@Override
public void writeTo(StreamOutput out) throws IOException {
Version streamOutputVersion = out.getVersion();
out.writeString(name);
out.writeEnum(functionName);
out.writeOptionalString(modelGroupId);
Expand Down Expand Up @@ -177,7 +183,9 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(false);
}
out.writeOptionalBoolean(isAddAllBackendRoles);
out.writeOptionalBoolean(doesVersionCreateModelGroup);
if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_DOES_VERSION_CREATE_MODEL_GROUP)) {
out.writeOptionalBoolean(doesVersionCreateModelGroup);
}
}

@Override
Expand Down

0 comments on commit a32d35a

Please sign in to comment.