Skip to content

Commit

Permalink
fix 2.12 backward compatibility issue
Browse files Browse the repository at this point in the history
Signed-off-by: Sicheng Song <[email protected]>
  • Loading branch information
b4sjoo committed Jan 29, 2024
1 parent ec1d7de commit b6f4bb0
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,12 @@ public TextDocsInputDataSet(StreamInput streamInput) throws IOException {
super(MLInputDataType.TEXT_DOCS);
Version version = streamInput.getVersion();
if (version.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_MULTI_MODAL)) {
System.out.println("seasonsg debug: read stream input shows bwc not working" );
docs = new ArrayList<>();
int size = streamInput.readInt();
for (int i=0; i<size; i++) {
docs.add(streamInput.readOptionalString());
}
} else {
System.out.println("seasonsg debug: read stream input shows bwc" );
docs = streamInput.readStringList();
}
if (streamInput.readBoolean()) {
Expand All @@ -70,13 +68,11 @@ public void writeTo(StreamOutput streamOutput) throws IOException {
super.writeTo(streamOutput);
Version version = streamOutput.getVersion();
if (version.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_MULTI_MODAL)) {
System.out.println("seasonsg debug: read stream output shows bwc not working" );
streamOutput.writeInt(docs.size());
for (String doc : docs) {
streamOutput.writeOptionalString(doc);
}
} else {
System.out.println("seasonsg debug: write stream output shows bwc" );
streamOutput.writeStringCollection(docs);
}
if (resultFilter != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable {
public static final String DOES_VERSION_CREATE_MODEL_GROUP = "does_version_create_model_group";

private static final Version MINIMAL_SUPPORTED_VERSION_FOR_DOES_VERSION_CREATE_MODEL_GROUP = Version.V_2_11_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_AGENT_FRAMEWORK = Version.CURRENT;

private FunctionName functionName;
private String modelName;
Expand Down Expand Up @@ -152,10 +153,6 @@ public MLRegisterModelInput(StreamInput in) throws IOException {
this.modelGroupId = in.readOptionalString();
this.version = in.readOptionalString();
this.description = in.readOptionalString();
this.isEnabled = in.readOptionalBoolean();
if (in.readBoolean()) {
this.rateLimiter = new MLRateLimiter(in);
}
this.url = in.readOptionalString();
this.hashValue = in.readOptionalString();
if (in.readBoolean()) {
Expand All @@ -181,10 +178,16 @@ public MLRegisterModelInput(StreamInput in) throws IOException {
if (in.readBoolean()) {
this.accessMode = in.readEnum(AccessMode.class);
}
this.isHidden = in.readOptionalBoolean();
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_DOES_VERSION_CREATE_MODEL_GROUP)) {
this.doesVersionCreateModelGroup = in.readOptionalBoolean();
}
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_AGENT_FRAMEWORK)) {
this.isEnabled = in.readOptionalBoolean();
if (in.readBoolean()) {
this.rateLimiter = new MLRateLimiter(in);
}
this.isHidden = in.readOptionalBoolean();
}
}

@Override
Expand All @@ -195,13 +198,6 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(modelGroupId);
out.writeOptionalString(version);
out.writeOptionalString(description);
out.writeOptionalBoolean(isEnabled);
if (rateLimiter != null) {
out.writeBoolean(true);
rateLimiter.writeTo(out);
} else {
out.writeBoolean(false);
}
out.writeOptionalString(url);
out.writeOptionalString(hashValue);
if (modelFormat != null) {
Expand Down Expand Up @@ -238,10 +234,19 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
out.writeOptionalBoolean(isHidden);
if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_DOES_VERSION_CREATE_MODEL_GROUP)) {
out.writeOptionalBoolean(doesVersionCreateModelGroup);
}
if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_AGENT_FRAMEWORK)) {
out.writeOptionalBoolean(isEnabled);
if (rateLimiter != null) {
out.writeBoolean(true);
rateLimiter.writeTo(out);
} else {
out.writeBoolean(false);
}
out.writeOptionalBoolean(isHidden);
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ public class MLRegisterModelMetaInput implements ToXContentObject, Writeable {
public static final String DOES_VERSION_CREATE_MODEL_GROUP = "does_version_create_model_group";

private static final Version MINIMAL_SUPPORTED_VERSION_FOR_DOES_VERSION_CREATE_MODEL_GROUP = Version.V_2_11_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_AGENT_FRAMEWORK = Version.CURRENT;

private FunctionName functionName;
private String name;
Expand Down Expand Up @@ -134,10 +135,6 @@ public MLRegisterModelMetaInput(StreamInput in) throws IOException {
this.modelGroupId = in.readOptionalString();
this.version = in.readOptionalString();
this.description = in.readOptionalString();
this.isEnabled = in.readOptionalBoolean();
if (in.readBoolean()) {
rateLimiter = new MLRateLimiter(in);
}
if (in.readBoolean()) {
modelFormat = in.readEnum(MLModelFormat.class);
}
Expand All @@ -155,10 +152,16 @@ public MLRegisterModelMetaInput(StreamInput in) throws IOException {
accessMode = in.readEnum(AccessMode.class);
}
this.isAddAllBackendRoles = in.readOptionalBoolean();
this.isHidden = in.readOptionalBoolean();
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_DOES_VERSION_CREATE_MODEL_GROUP)) {
this.doesVersionCreateModelGroup = in.readOptionalBoolean();
}
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_AGENT_FRAMEWORK)) {
this.isEnabled = in.readOptionalBoolean();
if (in.readBoolean()) {
this.rateLimiter = new MLRateLimiter(in);
}
this.isHidden = in.readOptionalBoolean();
}
}

@Override
Expand All @@ -169,13 +172,6 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(modelGroupId);
out.writeOptionalString(version);
out.writeOptionalString(description);
out.writeOptionalBoolean(isEnabled);
if (rateLimiter != null) {
out.writeBoolean(true);
rateLimiter.writeTo(out);
} else {
out.writeBoolean(false);
}
if (modelFormat != null) {
out.writeBoolean(true);
out.writeEnum(modelFormat);
Expand Down Expand Up @@ -210,10 +206,19 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(false);
}
out.writeOptionalBoolean(isAddAllBackendRoles);
out.writeOptionalBoolean(isHidden);
if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_DOES_VERSION_CREATE_MODEL_GROUP)) {
out.writeOptionalBoolean(doesVersionCreateModelGroup);
}
if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_AGENT_FRAMEWORK)) {
out.writeOptionalBoolean(isEnabled);
if (rateLimiter != null) {
out.writeBoolean(true);
rateLimiter.writeTo(out);
} else {
out.writeBoolean(false);
}
out.writeOptionalBoolean(isHidden);
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
public abstract class TextEmbeddingModel extends DLModel {
@Override
public ModelTensorOutput predict(String modelId, MLInput mlInput) throws TranslateException {
System.out.println("seasonsg debug: TextEmbeddingModel predict method called" );
MLInputDataset inputDataSet = mlInput.getInputDataset();
List<ModelTensors> tensorOutputs = new ArrayList<>();
Output output;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLTask
final User userInfo = user;

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
System.out.println("seasonsg debug: predict request send to transport layer." );
ActionListener<MLTaskResponse> wrappedListener = ActionListener.runBefore(listener, () -> context.restore());
MLModel cachedMlModel = modelCacheHelper.getModelInfo(modelId);
ActionListener<MLModel> modelActionListener = new ActionListener<>() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ private void indexRemoteModel(
.build();

IndexRequest indexModelMetaRequest = new IndexRequest(ML_MODEL_INDEX);
if (registerModelInput.getIsHidden()) {
if (registerModelInput.getIsHidden() != null && registerModelInput.getIsHidden()) {
indexModelMetaRequest.id(modelName);
}
indexModelMetaRequest.source(mlModelMeta.toXContent(XContentBuilder.builder(JSON.xContent()), EMPTY_PARAMS));
Expand Down Expand Up @@ -593,7 +593,7 @@ void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask mlTask, St
.isHidden(registerModelInput.getIsHidden())
.build();
IndexRequest indexModelMetaRequest = new IndexRequest(ML_MODEL_INDEX);
if (registerModelInput.getIsHidden()) {
if (registerModelInput.getIsHidden() != null && registerModelInput.getIsHidden()) {
indexModelMetaRequest.id(modelName);
}
indexModelMetaRequest.source(mlModelMeta.toXContent(XContentBuilder.builder(JSON.xContent()), EMPTY_PARAMS));
Expand Down Expand Up @@ -660,7 +660,7 @@ private void registerModelFromUrl(MLRegisterModelInput registerModelInput, MLTas
if (functionName == FunctionName.METRICS_CORRELATION) {
indexModelMetaRequest.id(functionName.name());
}
if (registerModelInput.getIsHidden()) {
if (registerModelInput.getIsHidden() != null && registerModelInput.getIsHidden()) {
indexModelMetaRequest.id(modelName);
}
indexModelMetaRequest.source(mlModelMeta.toXContent(XContentBuilder.builder(JSON.xContent()), EMPTY_PARAMS));
Expand Down Expand Up @@ -740,7 +740,7 @@ private void registerModel(
.isHidden(registerModelInput.getIsHidden())
.build();
IndexRequest indexRequest = new IndexRequest(ML_MODEL_INDEX);
if (registerModelInput.getIsHidden()) {
if (registerModelInput.getIsHidden() != null && registerModelInput.getIsHidden()) {
indexRequest.id(modelName);
}
String chunkId = getModelChunkId(modelId, chunkNum);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,10 @@
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.breaker.MLCircuitBreakerService;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
Expand Down Expand Up @@ -200,7 +197,6 @@ protected void executeTask(MLPredictionTaskRequest request, ActionListener<MLTas
}

private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListener<MLTaskResponse> listener) {
System.out.println("seasonsg debug: MLPredictTaskRunner predict method called" );
ActionListener<MLTaskResponse> internalListener = wrappedCleanupListener(listener, mlTask.getTaskId());
// track ML task count and add ML task into cache
mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment();
Expand All @@ -225,18 +221,12 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe
}
MLOutput output = mlModelManager.trackPredictDuration(modelId, () -> predictor.predict(mlInput));
if (output instanceof MLPredictionOutput) {
System.out.println("seasonsg debug: MLPredictTaskRunner predict method complete" );
((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name());
}

// Once prediction complete, reduce ML_EXECUTING_TASK_COUNT and update task state
handleAsyncMLTaskComplete(mlTask);
System.out.println("seasonsg debug: MLTaskResponse building start." );
MLTaskResponse response = MLTaskResponse.builder().output(output).build();
System.out.println("seasonsg debug: MLTaskResponse building complete." );
XContentBuilder builder = XContentFactory.jsonBuilder();
response.toXContent(builder, ToXContent.EMPTY_PARAMS);
System.out.println("seasonsg debug: " + builder.toString());
internalListener.onResponse(response);
return;
} catch (Exception e) {
Expand Down

0 comments on commit b6f4bb0

Please sign in to comment.