Skip to content

Commit

Permalink
fine tune predict API: read model from index directly (opensearch-pro…
Browse files Browse the repository at this point in the history
…ject#1557)

Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn authored Oct 27, 2023
1 parent 8c3e453 commit 0920ba7
Showing 1 changed file with 5 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,13 @@
import java.util.Optional;

import org.opensearch.client.node.NodeClient;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.transport.model.MLModelGetAction;
import org.opensearch.ml.common.transport.model.MLModelGetRequest;
import org.opensearch.ml.common.transport.model.MLModelGetResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.model.MLModelManager;
Expand Down Expand Up @@ -91,9 +89,7 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
}

return channel -> {
MLModelGetRequest getModelRequest = new MLModelGetRequest(modelId, false);
ActionListener<MLModelGetResponse> listener = ActionListener.wrap(r -> {
MLModel mlModel = r.getMlModel();
ActionListener<MLModel> listener = ActionListener.wrap(mlModel -> {
String algoName = mlModel.getAlgorithm().name();
client
.execute(
Expand All @@ -109,8 +105,9 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
log.error("Failed to send error response", ex);
}
});
client.execute(MLModelGetAction.INSTANCE, getModelRequest, listener);

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
modelManager.getModel(modelId, ActionListener.runBefore(listener, () -> context.restore()));
}
};
}

Expand Down

0 comments on commit 0920ba7

Please sign in to comment.