diff --git a/opensearch_py_ml/ml_commons/ml_commons_client.py b/opensearch_py_ml/ml_commons/ml_commons_client.py index af4161b3..34529d05 100644 --- a/opensearch_py_ml/ml_commons/ml_commons_client.py +++ b/opensearch_py_ml/ml_commons/ml_commons_client.py @@ -24,8 +24,8 @@ from opensearch_py_ml.ml_commons.model_access_control import ModelAccessControl from opensearch_py_ml.ml_commons.model_connector import Connector from opensearch_py_ml.ml_commons.model_execute import ModelExecute -from opensearch_py_ml.ml_commons.model_uploader import ModelUploader from opensearch_py_ml.ml_commons.model_profile import ModelProfile +from opensearch_py_ml.ml_commons.model_uploader import ModelUploader class MLCommonClient: diff --git a/opensearch_py_ml/ml_commons/model_profile.py b/opensearch_py_ml/ml_commons/model_profile.py index 18eee298..fecdddab 100644 --- a/opensearch_py_ml/ml_commons/model_profile.py +++ b/opensearch_py_ml/ml_commons/model_profile.py @@ -5,47 +5,49 @@ # Any modifications Copyright OpenSearch Contributors. See # GitHub history for details. -from opensearchpy import OpenSearch from typing import Optional +from opensearchpy import OpenSearch + from opensearch_py_ml.ml_commons.ml_common_utils import ML_BASE_URI class ModelProfile: API_ENDPOINT = "profile" - + def __init__(self, os_client: OpenSearch): self.client = os_client - + def _validate_input(self, path_parameter, payload): if path_parameter is not None and not isinstance(path_parameter, str): raise ValueError("payload needs to be a dictionary or None") if payload is not None and not isinstance(payload, dict): raise ValueError("path_parameter needs to be a string or None") - + def get_profile(self, payload: Optional[dict] = None): if payload is not None and not isinstance(payload, dict): raise ValueError("payload needs to be a dictionary or None") return self.client.transport.perform_request( method="GET", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}", body=payload ) - - def get_models_profile(self, path_parameter: Optional[str]='', payload: Optional[dict] = None): - + + def get_models_profile( + self, path_parameter: Optional[str] = "", payload: Optional[dict] = None + ): self._validate_input(path_parameter, payload) - + url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/models/{path_parameter if path_parameter else ''}" return self.client.transport.perform_request( method="GET", url=url, body=payload ) - - def get_tasks_profile(self, path_parameter: Optional[str]='', payload: Optional[dict] = None): - + def get_tasks_profile( + self, path_parameter: Optional[str] = "", payload: Optional[dict] = None + ): self._validate_input(path_parameter, payload) - + url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/tasks/{path_parameter if path_parameter else ''}" return self.client.transport.perform_request( method="GET", url=url, body=payload - ) \ No newline at end of file + ) diff --git a/tests/ml_commons/test_ml_commons_client.py b/tests/ml_commons/test_ml_commons_client.py index 170a3ea6..20193a03 100644 --- a/tests/ml_commons/test_ml_commons_client.py +++ b/tests/ml_commons/test_ml_commons_client.py @@ -17,8 +17,8 @@ from sklearn.datasets import load_iris from opensearch_py_ml.ml_commons import MLCommonClient -from opensearch_py_ml.ml_commons.model_uploader import ModelUploader from opensearch_py_ml.ml_commons.model_profile import ModelProfile +from opensearch_py_ml.ml_commons.model_uploader import ModelUploader from opensearch_py_ml.ml_models.sentencetransformermodel import SentenceTransformerModel from tests import OPENSEARCH_TEST_CLIENT @@ -575,19 +575,21 @@ def test_search(): raised = True assert raised == False, "Raised Exception in searching model" + # Model Profile Tests. These will need some model train/predict data. Hence, need to be # at the end after the training/prediction tests are done. + @pytest.fixture def profile_client(): client = ModelProfile(OPENSEARCH_TEST_CLIENT) return client + def test_get_profile(profile_client): - with pytest.raises(ValueError): profile_client.get_profile("") - + result = profile_client.get_profile() assert isinstance(result, dict) if len(result) > 0: @@ -595,33 +597,29 @@ def test_get_profile(profile_client): def test_get_models_profile(profile_client): - with pytest.raises(ValueError): profile_client.get_models_profile(10) - + with pytest.raises(ValueError): profile_client.get_models_profile("", 10) - + result = profile_client.get_models_profile() assert isinstance(result, dict) if len(result) > 0: assert "nodes" in result - for _, node_val in result['nodes'].items(): + for _, node_val in result["nodes"].items(): assert "models" in node_val - def test_get_tasks_profile(profile_client): - with pytest.raises(ValueError): profile_client.get_tasks_profile(10) - + with pytest.raises(ValueError): profile_client.get_tasks_profile("", 10) - + result = profile_client.get_tasks_profile() if len(result) > 0: assert "nodes" in result - for _, node_val in result['nodes'].items(): + for _, node_val in result["nodes"].items(): assert "tasks" in node_val - \ No newline at end of file