diff --git a/opensearch_py_ml/ml_commons/ml_commons_client.py b/opensearch_py_ml/ml_commons/ml_commons_client.py index 9d228c1a0..af4161b35 100644 --- a/opensearch_py_ml/ml_commons/ml_commons_client.py +++ b/opensearch_py_ml/ml_commons/ml_commons_client.py @@ -40,7 +40,7 @@ def __init__(self, os_client: OpenSearch): self._model_execute = ModelExecute(os_client) self.model_access_control = ModelAccessControl(os_client) self.connector = Connector(os_client) - self.profile = ModelProfile(os_client) + self.model_profile = ModelProfile(os_client) def execute(self, algorithm_name: str, input_json: dict) -> dict: """ diff --git a/opensearch_py_ml/ml_commons/model_profile.py b/opensearch_py_ml/ml_commons/model_profile.py index ae0fc770d..18eee2989 100644 --- a/opensearch_py_ml/ml_commons/model_profile.py +++ b/opensearch_py_ml/ml_commons/model_profile.py @@ -6,6 +6,7 @@ # GitHub history for details. from opensearchpy import OpenSearch +from typing import Optional from opensearch_py_ml.ml_commons.ml_common_utils import ML_BASE_URI @@ -16,23 +17,35 @@ class ModelProfile: def __init__(self, os_client: OpenSearch): self.client = os_client - def get_profile(self, payload: dict): - if not isinstance(payload, dict): - raise ValueError("payload needs to be a dictionary") + def _validate_input(self, path_parameter, payload): + if path_parameter is not None and not isinstance(path_parameter, str): + raise ValueError("payload needs to be a dictionary or None") + + if payload is not None and not isinstance(payload, dict): + raise ValueError("path_parameter needs to be a string or None") + + def get_profile(self, payload: Optional[dict] = None): + if payload is not None and not isinstance(payload, dict): + raise ValueError("payload needs to be a dictionary or None") return self.client.transport.perform_request( method="GET", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}", body=payload ) - def get_models_profile(self, payload: dict): - if not isinstance(payload, dict): - raise ValueError("payload needs to be a dictionary") + def get_models_profile(self, path_parameter: Optional[str]='', payload: Optional[dict] = None): + + self._validate_input(path_parameter, payload) + + url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/models/{path_parameter if path_parameter else ''}" return self.client.transport.perform_request( - method="GET", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}/models", body=payload + method="GET", url=url, body=payload ) - - def get_tasks_profile(self, payload: dict): - if not isinstance(payload, dict): - raise ValueError("payload needs to be a dictionary") + + + def get_tasks_profile(self, path_parameter: Optional[str]='', payload: Optional[dict] = None): + + self._validate_input(path_parameter, payload) + + url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/tasks/{path_parameter if path_parameter else ''}" return self.client.transport.perform_request( - method="GET", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}/tasks", body=payload + method="GET", url=url, body=payload ) \ No newline at end of file diff --git a/tests/ml_commons/test_ml_commons_client.py b/tests/ml_commons/test_ml_commons_client.py index 27cd79dc9..4995ef276 100644 --- a/tests/ml_commons/test_ml_commons_client.py +++ b/tests/ml_commons/test_ml_commons_client.py @@ -18,6 +18,7 @@ from opensearch_py_ml.ml_commons import MLCommonClient from opensearch_py_ml.ml_commons.model_uploader import ModelUploader +from opensearch_py_ml.ml_commons.model_profile import ModelProfile from opensearch_py_ml.ml_models.sentencetransformermodel import SentenceTransformerModel from tests import OPENSEARCH_TEST_CLIENT @@ -573,3 +574,47 @@ def test_search(): except: # noqa: E722 raised = True assert raised == False, "Raised Exception in searching model" + +# Model Profile Tests. These will need some model train/predict data. Hence, need to be +# at the end after the training/prediction tests are done. + +def profile_client(): + client = ModelProfile(OPENSEARCH_TEST_CLIENT) + return client + +def test_get_profile(profile_client): + + with pytest.raises(ValueError): + profile_client.get_profile("") + + result = profile_client.get_profile() + assert isinstance(result, dict) + if len(result) > 0: + assert "nodes" in result + + +def test_get_models_profile(profile_client): + + with pytest.raises(ValueError): + profile_client.get_models_profile("") + + result = profile_client.get_models_profile() + assert isinstance(result, dict) + if len(result) > 0: + assert "nodes" in result + for node_id, node_val in result['nodes']: + assert "models" in node_val + + + +def test_get_tasks_profile(profile_client): + + with pytest.raises(ValueError): + profile_client.get_tasks_profile("") + + result = profile_client.get_tasks_profile() + if len(result) > 0: + assert "nodes" in result + for node_id, node_val in result['nodes']: + assert "tasks" in node_val + \ No newline at end of file