Skip to content

Commit

Permalink
changes
Browse files Browse the repository at this point in the history
Signed-off-by: kalyanr <[email protected]>
  • Loading branch information
rawwar committed Dec 12, 2023
1 parent dbe1e1b commit b511b41
Showing 1 changed file with 22 additions and 57 deletions.
79 changes: 22 additions & 57 deletions opensearch_py_ml/ml_commons/ml_commons_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,64 +607,29 @@ def delete_task(self, task_id: str) -> object:
url=API_URL,
)

def get_profiles(
self,
node_ids: List[str] = [],
model_ids: List[str] = [],
task_ids: List[str] = [],
return_all_tasks: bool = True,
return_all_models: bool = True,
) -> object:
"""
This method retrieves the profile of one or more machine learning nodes, models, or tasks from OpenSearch cluster (using ml commons api)
:param node_ids: A list of node ids to retrieve profiles for (default: [])
:type node_ids: List[str]
:param model_ids: A list of model ids to retrieve profiles for (default: [])
:type model_ids: List[str]
:param task_ids: A list of task ids to retrieve profiles for (default: [])
:type task_ids: List[str]
:param return_all_tasks: a flag to indicate whether all tasks associated with the profiles should be returned (default: True)
:type return_all_tasks: bool
:param return_all_models: a flag to indicate whether all models associated with the profiles should be returned (default: True)
:type return_all_models: bool
:return: returns a json object containing the profile information of the specified nodes, models, or tasks from the OpenSearch cluster
:rtype: object
"""

if model_ids and len(model_ids) == 1:
API_URL = f"{ML_BASE_URI}/profile/models/{model_ids[0]}"
return self._client.transport.perform_request(
method="GET",
url=API_URL,
)

if task_ids and len(model_ids) == 1:
API_URL = f"{ML_BASE_URI}/profile/models/{model_ids[0]}"
return self._client.transport.perform_request(
method="GET",
url=API_URL,
)

API_URL = f"{ML_BASE_URI}/profile"

API_BODY = {
"return_all_tasks": return_all_tasks,
"return_all_models": return_all_models,
}

if len(node_ids) > 0:
API_BODY["node_ids"] = node_ids
if len(model_ids) > 0:
API_BODY["model_ids"] = model_ids
if len(task_ids) > 0:
API_BODY["task_ids"] = task_ids

try:
API_BODY = json.dumps(API_BODY)
except json.JSONDecodeError:
raise Exception("Invalid request body")
def get_profiles(self, model_id=None, tasks=None, request_body=None) -> object:
"""
"""
# if model id is given and tasks is none, then url = f"{ML_BASE_URI}/profile/models"
# if model id is None and tasks is given, then url = f"{ML_BASE_URI}/profile/tasks"
# if both model_id and tasks are given raise error
if model_id is None and tasks is None:
url = f"{ML_BASE_URI}/profile"
elif model_id is not None and tasks is None:
url = f"{ML_BASE_URI}/profile/models/{model_id}"
elif model_id is None and tasks is not None:
url = f"{ML_BASE_URI}/profile/tasks"
else:
raise ValueError("model_id and tasks cannot be given together")

return self._client.transport.perform_request(
method="POST",
url=url,
body=request_body
)



return self._client.transport.perform_request(
method="GET",
url=API_URL,
Expand Down

0 comments on commit b511b41

Please sign in to comment.