Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#285: Add support for Model profile #358

Merged
merged 20 commits into from
Jan 3, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
- Add support for train api functionality by @rawwar in ([#310](https://github.com/opensearch-project/opensearch-py-ml/pull/310))
- Add support for Model Access Control - Register, Update, Search and Delete by @rawwar in ([#332](https://github.com/opensearch-project/opensearch-py-ml/pull/332))
- Add support for model connectors by @rawwar in ([#345](https://github.com/opensearch-project/opensearch-py-ml/pull/345))
- Add support for model profiles by @rawwar in ([#358](https://github.com/opensearch-project/opensearch-py-ml/pull/358))

### Changed
- Modify ml-models.JenkinsFile so that it takes model format into account and can be triggered with generic webhook by @thanawan-atc in ([#211](https://github.com/opensearch-project/opensearch-py-ml/pull/211))
Expand Down
106 changes: 106 additions & 0 deletions opensearch_py_ml/ml_commons/ml_commons_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from opensearch_py_ml.ml_commons.model_connector import Connector
from opensearch_py_ml.ml_commons.model_execute import ModelExecute
from opensearch_py_ml.ml_commons.model_uploader import ModelUploader
from opensearch_py_ml.ml_commons.validators.profile import validate_profile_input


class MLCommonClient:
Expand Down Expand Up @@ -606,3 +607,108 @@ def delete_task(self, task_id: str) -> object:
method="DELETE",
url=API_URL,
)

def _get_profile(self, payload: Optional[dict] = None):
"""
Get the profile using the given payload.

:param payload: The payload to be used for getting the profile. Defaults to None.
:type payload: Optional[dict]
:return: The response from the server after performing the request.
:rtype: Any
"""
validate_profile_input(None, payload)
return self._client.transport.perform_request(
method="GET", url=f"{ML_BASE_URI}/profile", body=payload
)

def _get_models_profile(
self, model_id: Optional[str] = "", payload: Optional[dict] = None
):
"""
Get the profile of a model.

Args:
model_id (str, optional): The ID of the model. Defaults to "".
payload (dict, optional): Additional payload for the request. Defaults to None.

Returns:
dict: The response from the API.
"""
validate_profile_input(model_id, payload)

url = f"{ML_BASE_URI}/profile/models/{model_id if model_id else ''}"
return self._client.transport.perform_request(
method="GET", url=url, body=payload
)

def _get_tasks_profile(
self, task_id: Optional[str] = "", payload: Optional[dict] = None
):
"""
Retrieves the profile of a task from the API.

Parameters:
task_id (str, optional): The ID of the task to retrieve the profile for. Defaults to an empty string.
payload (dict, optional): Additional payload for the request. Defaults to None.

Returns:
dict: The profile of the task.

Raises:
ValueError: If the input validation fails.

"""
validate_profile_input(task_id, payload)

url = f"{ML_BASE_URI}/profile/tasks/{task_id if task_id else ''}"
return self._client.transport.perform_request(
method="GET", url=url, body=payload
)

def get_profile(
self,
profile_type: str = "all",
ids: Optional[Union[str, List[str]]] = None,
request_body: Optional[dict] = None,
) -> dict:
"""
Get profile information based on the profile type.

Args:
profile_type: The type of profile to retrieve. Valid values are 'all', 'model', or 'task'. Default is 'all'.
rawwar marked this conversation as resolved.
Show resolved Hide resolved
ids: Either a single profile ID as a string, or a list of profile IDs to retrieve. Default is None.
request_body: The request body containing additional information. Default is None.

Returns:
The profile information.

Raises:
ValueError: If the profile_type is not 'all', 'model', or 'task'.

Example:
get_profile()

get_profile(profile_type='model', ids='model1')

get_profile(profile_type='model', ids=['model1', 'model2'])

get_profile(profile_type='task', ids='task1', request_body={"node_ids": ["KzONM8c8T4Od-NoUANQNGg"],"return_all_tasks": true,"return_all_models": true})

get_profile(profile_type='task', ids=['task1', 'task2'], request_body={'additional': 'info'})
"""

if profile_type == "all":
return self._get_profile(request_body)
elif profile_type == "model":
if ids and isinstance(ids, list):
ids = ",".join(ids)
return self._get_models_profile(ids, request_body)
elif profile_type == "task":
if ids and isinstance(ids, list):
ids = ",".join(ids)
return self._get_tasks_profile(ids, request_body)
else:
raise ValueError(
"Invalid profile type. Profile type must be 'all', 'model' or 'task'."
)
16 changes: 16 additions & 0 deletions opensearch_py_ml/ml_commons/validators/profile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.
# Any modifications Copyright OpenSearch Contributors. See
# GitHub history for details.

"""Module for validating Profile API parameters """


def validate_profile_input(path_parameter, payload):
if path_parameter is not None and not isinstance(path_parameter, str):
raise ValueError("path_parameter needs to be a string or None")

if payload is not None and not isinstance(payload, dict):
raise ValueError("payload needs to be a dictionary or None")
62 changes: 62 additions & 0 deletions tests/ml_commons/test_ml_commons_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,3 +573,65 @@ def test_search():
except: # noqa: E722
raised = True
assert raised == False, "Raised Exception in searching model"


# Model Profile Tests. These tests will need some model train/predict run data. Hence, need
# to be run at the end after the training/prediction tests are done.


def test_get_profile():
res = ml_client.get_profile()
assert isinstance(res, dict)
assert "nodes" in res
test_model_id = None
test_task_id = None
for node_id, val in res["nodes"].items():
if test_model_id is None and "models" in val:
for model_id, model_val in val["models"].items():
test_model_id = {"node_id": node_id, "model_id": model_id}
break
if test_task_id is None and "tasks" in val:
for task_id, task_val in val["tasks"].items():
test_task_id = {"node_id": node_id, "task_id": task_id}
break

res = ml_client.get_profile(profile_type="model")
assert isinstance(res, dict)
assert "nodes" in res
for node_id, node_val in res["nodes"].items():
assert "models" in node_val

res = ml_client.get_profile(profile_type="model", ids=[test_model_id["model_id"]])
assert isinstance(res, dict)
assert "nodes" in res
assert test_model_id["model_id"] in res["nodes"][test_model_id["node_id"]]["models"]

res = ml_client.get_profile(profile_type="model", ids=["randomid1", "random_id2"])
assert isinstance(res, dict)
assert len(res) == 0

res = ml_client.get_profile(profile_type="task")
assert isinstance(res, dict)
if len(res) > 0:
assert "nodes" in res
for node_id, node_val in res["nodes"].items():
assert "tasks" in node_val

res = ml_client.get_profile(profile_type="task", ids=["random1", "random2"])
assert isinstance(res, dict)
assert len(res) == 0

with pytest.raises(ValueError):
ml_client.get_profile(profile_type="test")

with pytest.raises(ValueError):
ml_client.get_profile(profile_type="model", ids=1)

with pytest.raises(ValueError):
ml_client.get_profile(profile_type="model", request_body=10)

with pytest.raises(ValueError):
ml_client.get_profile(profile_type="task", ids=1)

with pytest.raises(ValueError):
ml_client.get_profile(profile_type="task", request_body=10)
Loading