From b0191a83aa718746a45fbdd7a4c351fedb24ab3e Mon Sep 17 00:00:00 2001 From: kalyanr Date: Sat, 16 Dec 2023 12:19:10 +0530 Subject: [PATCH] fix Signed-off-by: kalyanr --- .../ml_commons/ml_commons_client.py | 84 +++++++++++++++---- 1 file changed, 67 insertions(+), 17 deletions(-) diff --git a/opensearch_py_ml/ml_commons/ml_commons_client.py b/opensearch_py_ml/ml_commons/ml_commons_client.py index 138bd9c5..2ce4d7bd 100644 --- a/opensearch_py_ml/ml_commons/ml_commons_client.py +++ b/opensearch_py_ml/ml_commons/ml_commons_client.py @@ -611,39 +611,89 @@ def delete_task(self, task_id: str) -> object: ) def _get_profile(self, payload: Optional[dict] = None): + """ + Get the profile using the given payload. + + :param payload: The payload to be used for getting the profile. Defaults to None. + :type payload: Optional[dict] + :return: The response from the server after performing the request. + :rtype: Any + """ validate_profile_input(None, payload) return self.client.transport.perform_request( method="GET", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}", body=payload ) def _get_models_profile( - self, path_parameter: Optional[str] = "", payload: Optional[dict] = None + self, model_id: Optional[str] = "", payload: Optional[dict] = None ): - self._validate_input(path_parameter, payload) + """ + Get the profile of a model. + + Args: + model_id (str, optional): The ID of the model. Defaults to "". + payload (dict, optional): Additional payload for the request. Defaults to None. - url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/models/{path_parameter if path_parameter else ''}" + Returns: + dict: The response from the API. + """ + self._validate_input(model_id, payload) + + url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/models/{model_id if model_id else ''}" return self.client.transport.perform_request( method="GET", url=url, body=payload ) def _get_tasks_profile( - self, path_parameter: Optional[str] = "", payload: Optional[dict] = None + self, task_id: Optional[str] = "", payload: Optional[dict] = None ): - self._validate_input(path_parameter, payload) + """ + Retrieves the profile of a task from the API. + + Parameters: + task_id (str, optional): The ID of the task to retrieve the profile for. Defaults to an empty string. + payload (dict, optional): Additional payload for the request. Defaults to None. + + Returns: + dict: The profile of the task. - url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/tasks/{path_parameter if path_parameter else ''}" + Raises: + ValueError: If the input validation fails. + + """ + self._validate_input(task_id, payload) + + url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/tasks/{task_id if task_id else ''}" return self.client.transport.perform_request( method="GET", url=url, body=payload ) - def get_profile(self, profile_type="all", id=None, request_body=None): - if profile_type == "all": - return self._get_profile(request_body) - elif profile_type == "model": - return self._get_models_profile(id, request_body) - elif profile_type == "task": - pass - else: - raise ValueError( - "Invalid profile type. Profile type must be 'all', 'model' or 'task'." - ) +def get_profile(self, profile_type: str = "all", ids: Optional[List[str]] = None, request_body: Optional[dict] = None) -> dict: + """ + Get profile information based on the profile type. + + Args: + profile_type: The type of profile to retrieve. Valid values are 'all', 'model', or 'task'. Default is 'all'. + ids: A list of profile IDs to retrieve. Default is None. + request_body: The request body containing additional information. Default is None. + + Returns: + The profile information. + + Raises: + ValueError: If the profile_type is not 'all', 'model', or 'task'. + """ + if profile_type == "all": + return self._get_profile(request_body) + elif profile_type == "model": + if ids: + ids = ",".join(ids) + return self._get_models_profile(ids, request_body) + elif profile_type == "task": + if ids: + ids = ",".join(ids) + return self._get_tasks_profile(ids, request_body) + else: + raise ValueError( + "Invalid profile type. Profile type must be 'all', 'model' or 'task'." + )