Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a version filter to enable bwc in 2.11 #1939

Merged
merged 4 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import lombok.Builder;
import lombok.Getter;
import lombok.experimental.FieldDefaults;
import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.ml.common.annotation.InputDataSet;
Expand All @@ -28,6 +29,8 @@ public class TextDocsInputDataSet extends MLInputDataset{

private List<String> docs;

private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MULTI_MODAL = Version.V_2_11_0;

@Builder(toBuilder = true)
public TextDocsInputDataSet(List<String> docs, ModelResultFilter resultFilter) {
super(MLInputDataType.TEXT_DOCS);
Expand All @@ -41,10 +44,15 @@ public TextDocsInputDataSet(List<String> docs, ModelResultFilter resultFilter) {

public TextDocsInputDataSet(StreamInput streamInput) throws IOException {
super(MLInputDataType.TEXT_DOCS);
docs = new ArrayList<>();
int size = streamInput.readInt();
for (int i=0; i<size; i++) {
docs.add(streamInput.readOptionalString());
Version version = streamInput.getVersion();
if (version.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_MULTI_MODAL)) {
docs = new ArrayList<>();
int size = streamInput.readInt();
for (int i=0; i<size; i++) {
docs.add(streamInput.readOptionalString());
}
} else {
docs = streamInput.readStringList();
}
if (streamInput.readBoolean()) {
resultFilter = new ModelResultFilter(streamInput);
Expand All @@ -56,9 +64,14 @@ public TextDocsInputDataSet(StreamInput streamInput) throws IOException {
@Override
public void writeTo(StreamOutput streamOutput) throws IOException {
super.writeTo(streamOutput);
streamOutput.writeInt(docs.size());
for (String doc : docs) {
streamOutput.writeOptionalString(doc);
Version version = streamOutput.getVersion();
if (version.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_MULTI_MODAL)) {
streamOutput.writeInt(docs.size());
for (String doc : docs) {
streamOutput.writeOptionalString(doc);
}
} else {
streamOutput.writeStringCollection(docs);
}
if (resultFilter != null) {
streamOutput.writeBoolean(true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import lombok.Builder;
import lombok.Data;
import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
Expand Down Expand Up @@ -54,6 +55,9 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable {
public static final String BACKEND_ROLES_FIELD = "backend_roles";
public static final String ADD_ALL_BACKEND_ROLES_FIELD = "add_all_backend_roles";
public static final String DOES_VERSION_CREATE_MODEL_GROUP = "does_version_create_model_group";

private static final Version MINIMAL_SUPPORTED_VERSION_FOR_DOES_VERSION_CREATE_MODEL_GROUP = Version.V_2_11_0;

private FunctionName functionName;
private String modelName;
private String modelGroupId;
Expand Down Expand Up @@ -130,6 +134,7 @@ public MLRegisterModelInput(FunctionName functionName,


public MLRegisterModelInput(StreamInput in) throws IOException {
Version streamInputVersion = in.getVersion();
this.functionName = in.readEnum(FunctionName.class);
this.modelName = in.readString();
this.modelGroupId = in.readOptionalString();
Expand Down Expand Up @@ -160,11 +165,14 @@ public MLRegisterModelInput(StreamInput in) throws IOException {
if (in.readBoolean()) {
this.accessMode = in.readEnum(AccessMode.class);
}
this.doesVersionCreateModelGroup = in.readOptionalBoolean();
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_DOES_VERSION_CREATE_MODEL_GROUP)) {
this.doesVersionCreateModelGroup = in.readOptionalBoolean();
}
}

@Override
public void writeTo(StreamOutput out) throws IOException {
Version streamOutputVersion = out.getVersion();
out.writeEnum(functionName);
out.writeString(modelName);
out.writeOptionalString(modelGroupId);
Expand Down Expand Up @@ -206,7 +214,9 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
out.writeOptionalBoolean(doesVersionCreateModelGroup);
if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_DOES_VERSION_CREATE_MODEL_GROUP)) {
out.writeOptionalBoolean(doesVersionCreateModelGroup);
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import lombok.Builder;
import lombok.Data;
import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
Expand Down Expand Up @@ -48,6 +49,7 @@ public class MLRegisterModelMetaInput implements ToXContentObject, Writeable{
public static final String ADD_ALL_BACKEND_ROLES = "add_all_backend_roles"; //optional
public static final String DOES_VERSION_CREATE_MODEL_GROUP = "does_version_create_model_group";

private static final Version MINIMAL_SUPPORTED_VERSION_FOR_DOES_VERSION_CREATE_MODEL_GROUP = Version.V_2_11_0;

private FunctionName functionName;
private String name;
Expand Down Expand Up @@ -111,6 +113,7 @@ public MLRegisterModelMetaInput(String name, FunctionName functionName, String m
}

public MLRegisterModelMetaInput(StreamInput in) throws IOException{
Version streamInputVersion = in.getVersion();
this.name = in.readString();
this.functionName = in.readEnum(FunctionName.class);
this.modelGroupId = in.readOptionalString();
Expand All @@ -133,11 +136,14 @@ public MLRegisterModelMetaInput(StreamInput in) throws IOException{
accessMode = in.readEnum(AccessMode.class);
}
this.isAddAllBackendRoles = in.readOptionalBoolean();
this.doesVersionCreateModelGroup = in.readOptionalBoolean();
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_DOES_VERSION_CREATE_MODEL_GROUP)) {
this.doesVersionCreateModelGroup = in.readOptionalBoolean();
}
}

@Override
public void writeTo(StreamOutput out) throws IOException {
Version streamOutputVersion = out.getVersion();
out.writeString(name);
out.writeEnum(functionName);
out.writeOptionalString(modelGroupId);
Expand Down Expand Up @@ -177,7 +183,9 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(false);
}
out.writeOptionalBoolean(isAddAllBackendRoles);
out.writeOptionalBoolean(doesVersionCreateModelGroup);
if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_DOES_VERSION_CREATE_MODEL_GROUP)) {
out.writeOptionalBoolean(doesVersionCreateModelGroup);
}
}

@Override
Expand Down
Loading