diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4eb78bf78..86f270110 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
 - Add support for train api functionality by @rawwar in ([#310](https://github.com/opensearch-project/opensearch-py-ml/pull/310))
 - Add support for Model Access Control - Register, Update, Search and Delete by @rawwar in ([#332](https://github.com/opensearch-project/opensearch-py-ml/pull/332))
 - Add support for model connectors by @rawwar in ([#345](https://github.com/opensearch-project/opensearch-py-ml/pull/345))
+- Add support for model profiles by @rawwar in ([#358](https://github.com/opensearch-project/opensearch-py-ml/pull/358))
 
 ### Changed
 - Modify ml-models.JenkinsFile so that it takes model format into account and can be triggered with generic webhook by @thanawan-atc in ([#211](https://github.com/opensearch-project/opensearch-py-ml/pull/211))
diff --git a/opensearch_py_ml/ml_commons/ml_commons_client.py b/opensearch_py_ml/ml_commons/ml_commons_client.py
index 2fde29294..1f81c967c 100644
--- a/opensearch_py_ml/ml_commons/ml_commons_client.py
+++ b/opensearch_py_ml/ml_commons/ml_commons_client.py
@@ -25,6 +25,7 @@ from opensearch_py_ml.ml_commons.model_connector import Connector
 from opensearch_py_ml.ml_commons.model_execute import ModelExecute
 from opensearch_py_ml.ml_commons.model_uploader import ModelUploader
+from opensearch_py_ml.ml_commons.validators.profile import validate_profile_input
 
 
 class MLCommonClient:
@@ -606,3 +607,111 @@ def delete_task(self, task_id: str) -> object:
             method="DELETE",
             url=API_URL,
         )
+
+    def _get_profile(self, payload: Optional[dict] = None):
+        """
+        Get the profile using the given payload.
+
+        :param payload: The payload to be used for getting the profile. Defaults to None.
+        :type payload: Optional[dict]
+        :return: The response from the server after performing the request.
+        :rtype: Any
+        """
+        validate_profile_input(None, payload)
+        return self._client.transport.perform_request(
+            method="GET", url=f"{ML_BASE_URI}/profile", body=payload
+        )
+
+    def _get_models_profile(
+        self, model_id: Optional[str] = "", payload: Optional[dict] = None
+    ):
+        """
+        Get the profile of a model.
+
+        Args:
+            model_id (str, optional): The ID of the model. Defaults to "".
+            payload (dict, optional): Additional payload for the request. Defaults to None.
+
+        Returns:
+            dict: The response from the API.
+        """
+        validate_profile_input(model_id, payload)
+
+        url = f"{ML_BASE_URI}/profile/models/{model_id if model_id else ''}"
+        return self._client.transport.perform_request(
+            method="GET", url=url, body=payload
+        )
+
+    def _get_tasks_profile(
+        self, task_id: Optional[str] = "", payload: Optional[dict] = None
+    ):
+        """
+        Retrieves the profile of a task from the API.
+
+        Parameters:
+            task_id (str, optional): The ID of the task to retrieve the profile for. Defaults to an empty string.
+            payload (dict, optional): Additional payload for the request. Defaults to None.
+
+        Returns:
+            dict: The profile of the task.
+
+        Raises:
+            ValueError: If the input validation fails.
+ + """ + validate_profile_input(task_id, payload) + + url = f"{ML_BASE_URI}/profile/tasks/{task_id if task_id else ''}" + return self._client.transport.perform_request( + method="GET", url=url, body=payload + ) + + def get_profile( + self, + profile_type: str = "all", + ids: Optional[Union[str, List[str]]] = None, + request_body: Optional[dict] = None, + ) -> dict: + """ + Get profile information based on the profile type. + + Args: + profile_type: The type of profile to retrieve. Valid values are 'all', 'model', or 'task'. Default is 'all'. + 'all': Retrieves all profiles available. + 'model': Retrieves the profile(s) of the specified model(s). The model(s) to retrieve are specified by the 'ids' parameter. + 'task': Retrieves the profile(s) of the specified task(s). The task(s) to retrieve are specified by the 'ids' parameter. + ids: Either a single profile ID as a string, or a list of profile IDs to retrieve. Default is None. + request_body: The request body containing additional information. Default is None. + + Returns: + The profile information. + + Raises: + ValueError: If the profile_type is not 'all', 'model', or 'task'. + + Example: + get_profile() + + get_profile(profile_type='model', ids='model1') + + get_profile(profile_type='model', ids=['model1', 'model2']) + + get_profile(profile_type='task', ids='task1', request_body={"node_ids": ["KzONM8c8T4Od-NoUANQNGg"],"return_all_tasks": true,"return_all_models": true}) + + get_profile(profile_type='task', ids=['task1', 'task2'], request_body={'additional': 'info'}) + """ + + if profile_type == "all": + return self._get_profile(request_body) + elif profile_type == "model": + if ids and isinstance(ids, list): + ids = ",".join(ids) + return self._get_models_profile(ids, request_body) + elif profile_type == "task": + if ids and isinstance(ids, list): + ids = ",".join(ids) + return self._get_tasks_profile(ids, request_body) + else: + raise ValueError( + "Invalid profile type. Profile type must be 'all', 'model' or 'task'." + ) diff --git a/opensearch_py_ml/ml_commons/validators/profile.py b/opensearch_py_ml/ml_commons/validators/profile.py new file mode 100644 index 000000000..602edb0f9 --- /dev/null +++ b/opensearch_py_ml/ml_commons/validators/profile.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: Apache-2.0 +# The OpenSearch Contributors require contributions made to +# this file be licensed under the Apache-2.0 license or a +# compatible open source license. +# Any modifications Copyright OpenSearch Contributors. See +# GitHub history for details. + +"""Module for validating Profile API parameters """ + + +def validate_profile_input(path_parameter, payload): + if path_parameter is not None and not isinstance(path_parameter, str): + raise ValueError("path_parameter needs to be a string or None") + + if payload is not None and not isinstance(payload, dict): + raise ValueError("payload needs to be a dictionary or None") diff --git a/tests/ml_commons/test_ml_commons_client.py b/tests/ml_commons/test_ml_commons_client.py index 27cd79dc9..bb1adcde1 100644 --- a/tests/ml_commons/test_ml_commons_client.py +++ b/tests/ml_commons/test_ml_commons_client.py @@ -573,3 +573,65 @@ def test_search(): except: # noqa: E722 raised = True assert raised == False, "Raised Exception in searching model" + + +# Model Profile Tests. These tests will need some model train/predict run data. Hence, need +# to be run at the end after the training/prediction tests are done. 
+
+
+def test_get_profile():
+    res = ml_client.get_profile()
+    assert isinstance(res, dict)
+    assert "nodes" in res
+    test_model_id = None
+    test_task_id = None
+    for node_id, val in res["nodes"].items():
+        if test_model_id is None and "models" in val:
+            for model_id, model_val in val["models"].items():
+                test_model_id = {"node_id": node_id, "model_id": model_id}
+                break
+        if test_task_id is None and "tasks" in val:
+            for task_id, task_val in val["tasks"].items():
+                test_task_id = {"node_id": node_id, "task_id": task_id}
+                break
+
+    res = ml_client.get_profile(profile_type="model")
+    assert isinstance(res, dict)
+    assert "nodes" in res
+    for node_id, node_val in res["nodes"].items():
+        assert "models" in node_val
+
+    res = ml_client.get_profile(profile_type="model", ids=[test_model_id["model_id"]])
+    assert isinstance(res, dict)
+    assert "nodes" in res
+    assert test_model_id["model_id"] in res["nodes"][test_model_id["node_id"]]["models"]
+
+    res = ml_client.get_profile(profile_type="model", ids=["randomid1", "random_id2"])
+    assert isinstance(res, dict)
+    assert len(res) == 0
+
+    res = ml_client.get_profile(profile_type="task")
+    assert isinstance(res, dict)
+    if len(res) > 0:
+        assert "nodes" in res
+        for node_id, node_val in res["nodes"].items():
+            assert "tasks" in node_val
+
+    res = ml_client.get_profile(profile_type="task", ids=["random1", "random2"])
+    assert isinstance(res, dict)
+    assert len(res) == 0
+
+    with pytest.raises(ValueError):
+        ml_client.get_profile(profile_type="test")
+
+    with pytest.raises(ValueError):
+        ml_client.get_profile(profile_type="model", ids=1)
+
+    with pytest.raises(ValueError):
+        ml_client.get_profile(profile_type="model", request_body=10)
+
+    with pytest.raises(ValueError):
+        ml_client.get_profile(profile_type="task", ids=1)
+
+    with pytest.raises(ValueError):
+        ml_client.get_profile(profile_type="task", request_body=10)
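
For reviewers, a minimal usage sketch of the `get_profile` API introduced in this diff. Only `MLCommonClient.get_profile` and its `profile_type`, `ids`, and `request_body` parameters come from the change above; the connection settings, node ID, and model IDs are placeholders for a local test cluster, not values from this PR.

```python
# Usage sketch (not part of this PR): exercising the new profile API.
# Connection settings, node IDs, and model IDs below are placeholders.
from opensearchpy import OpenSearch

from opensearch_py_ml.ml_commons import MLCommonClient

# Hypothetical local test cluster; adjust host and credentials as needed.
client = OpenSearch(
    hosts=[{"host": "localhost", "port": 9200}],
    http_auth=("admin", "admin"),
    use_ssl=True,
    verify_certs=False,
)
ml_client = MLCommonClient(client)

# Cluster-wide profile: models and tasks running on every node.
all_profiles = ml_client.get_profile()

# Model profiles: a list of IDs is joined into a comma-separated path parameter.
model_profiles = ml_client.get_profile(
    profile_type="model", ids=["<model-id-1>", "<model-id-2>"]
)

# Task profiles filtered through a request body, mirroring the docstring example.
task_profiles = ml_client.get_profile(
    profile_type="task",
    request_body={"node_ids": ["<node-id>"], "return_all_tasks": True},
)

# Any profile_type other than 'all', 'model', or 'task' raises ValueError.
```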