Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
Signed-off-by: kalyanr <[email protected]>
  • Loading branch information
rawwar committed Dec 16, 2023
1 parent af4de7b commit b0191a8
Showing 1 changed file with 67 additions and 17 deletions.
84 changes: 67 additions & 17 deletions opensearch_py_ml/ml_commons/ml_commons_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,39 +611,89 @@ def delete_task(self, task_id: str) -> object:
)

def _get_profile(self, payload: Optional[dict] = None):
"""
Get the profile using the given payload.
:param payload: The payload to be used for getting the profile. Defaults to None.
:type payload: Optional[dict]
:return: The response from the server after performing the request.
:rtype: Any
"""
validate_profile_input(None, payload)
return self.client.transport.perform_request(
method="GET", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}", body=payload
)

def _get_models_profile(
self, path_parameter: Optional[str] = "", payload: Optional[dict] = None
self, model_id: Optional[str] = "", payload: Optional[dict] = None
):
self._validate_input(path_parameter, payload)
"""
Get the profile of a model.
Args:
model_id (str, optional): The ID of the model. Defaults to "".
payload (dict, optional): Additional payload for the request. Defaults to None.
url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/models/{path_parameter if path_parameter else ''}"
Returns:
dict: The response from the API.
"""
self._validate_input(model_id, payload)

url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/models/{model_id if model_id else ''}"
return self.client.transport.perform_request(
method="GET", url=url, body=payload
)

def _get_tasks_profile(
self, path_parameter: Optional[str] = "", payload: Optional[dict] = None
self, task_id: Optional[str] = "", payload: Optional[dict] = None
):
self._validate_input(path_parameter, payload)
"""
Retrieves the profile of a task from the API.
Parameters:
task_id (str, optional): The ID of the task to retrieve the profile for. Defaults to an empty string.
payload (dict, optional): Additional payload for the request. Defaults to None.
Returns:
dict: The profile of the task.
url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/tasks/{path_parameter if path_parameter else ''}"
Raises:
ValueError: If the input validation fails.
"""
self._validate_input(task_id, payload)

url = f"{ML_BASE_URI}/{self.API_ENDPOINT}/tasks/{task_id if task_id else ''}"
return self.client.transport.perform_request(
method="GET", url=url, body=payload
)

def get_profile(self, profile_type="all", id=None, request_body=None):
if profile_type == "all":
return self._get_profile(request_body)
elif profile_type == "model":
return self._get_models_profile(id, request_body)
elif profile_type == "task":
pass
else:
raise ValueError(
"Invalid profile type. Profile type must be 'all', 'model' or 'task'."
)
def get_profile(self, profile_type: str = "all", ids: Optional[List[str]] = None, request_body: Optional[dict] = None) -> dict:
"""
Get profile information based on the profile type.
Args:
profile_type: The type of profile to retrieve. Valid values are 'all', 'model', or 'task'. Default is 'all'.
ids: A list of profile IDs to retrieve. Default is None.
request_body: The request body containing additional information. Default is None.
Returns:
The profile information.
Raises:
ValueError: If the profile_type is not 'all', 'model', or 'task'.
"""
if profile_type == "all":
return self._get_profile(request_body)
elif profile_type == "model":
if ids:
ids = ",".join(ids)
return self._get_models_profile(ids, request_body)
elif profile_type == "task":
if ids:
ids = ",".join(ids)
return self._get_tasks_profile(ids, request_body)
else:
raise ValueError(
"Invalid profile type. Profile type must be 'all', 'model' or 'task'."
)

0 comments on commit b0191a8

Please sign in to comment.