diff --git a/opensearch_py_ml/ml_commons/ml_commons_client.py b/opensearch_py_ml/ml_commons/ml_commons_client.py index 87de79b1..f304b1bb 100644 --- a/opensearch_py_ml/ml_commons/ml_commons_client.py +++ b/opensearch_py_ml/ml_commons/ml_commons_client.py @@ -607,64 +607,29 @@ def delete_task(self, task_id: str) -> object: url=API_URL, ) - def get_profiles( - self, - node_ids: List[str] = [], - model_ids: List[str] = [], - task_ids: List[str] = [], - return_all_tasks: bool = True, - return_all_models: bool = True, - ) -> object: - """ - This method retrieves the profile of one or more machine learning nodes, models, or tasks from OpenSearch cluster (using ml commons api) - - :param node_ids: A list of node ids to retrieve profiles for (default: []) - :type node_ids: List[str] - :param model_ids: A list of model ids to retrieve profiles for (default: []) - :type model_ids: List[str] - :param task_ids: A list of task ids to retrieve profiles for (default: []) - :type task_ids: List[str] - :param return_all_tasks: a flag to indicate whether all tasks associated with the profiles should be returned (default: True) - :type return_all_tasks: bool - :param return_all_models: a flag to indicate whether all models associated with the profiles should be returned (default: True) - :type return_all_models: bool - :return: returns a json object containing the profile information of the specified nodes, models, or tasks from the OpenSearch cluster - :rtype: object - """ - - if model_ids and len(model_ids) == 1: - API_URL = f"{ML_BASE_URI}/profile/models/{model_ids[0]}" - return self._client.transport.perform_request( - method="GET", - url=API_URL, - ) - - if task_ids and len(model_ids) == 1: - API_URL = f"{ML_BASE_URI}/profile/models/{model_ids[0]}" - return self._client.transport.perform_request( - method="GET", - url=API_URL, - ) - - API_URL = f"{ML_BASE_URI}/profile" - - API_BODY = { - "return_all_tasks": return_all_tasks, - "return_all_models": return_all_models, - } - - if len(node_ids) > 0: - API_BODY["node_ids"] = node_ids - if len(model_ids) > 0: - API_BODY["model_ids"] = model_ids - if len(task_ids) > 0: - API_BODY["task_ids"] = task_ids - - try: - API_BODY = json.dumps(API_BODY) - except json.JSONDecodeError: - raise Exception("Invalid request body") + def get_profiles(self, model_id=None, tasks=None, request_body=None) -> object: + """ + """ + # if model id is given and tasks is none, then url = f"{ML_BASE_URI}/profile/models" + # if model id is None and tasks is given, then url = f"{ML_BASE_URI}/profile/tasks" + # if both model_id and tasks are given raise error + if model_id is None and tasks is None: + url = f"{ML_BASE_URI}/profile" + elif model_id is not None and tasks is None: + url = f"{ML_BASE_URI}/profile/models/{model_id}" + elif model_id is None and tasks is not None: + url = f"{ML_BASE_URI}/profile/tasks" + else: + raise ValueError("model_id and tasks cannot be given together") + return self._client.transport.perform_request( + method="POST", + url=url, + body=request_body + ) + + + return self._client.transport.perform_request( method="GET", url=API_URL,