Skip to content

Commit

Permalink
opensearch-project#285: Add support for Model profile (opensearch-pro…
Browse files Browse the repository at this point in the history
…ject#358)

* init

Signed-off-by: kalyanr <[email protected]>

* update changelog

Signed-off-by: kalyanr <[email protected]>

* update

Signed-off-by: kalyanr <[email protected]>

* fix

Signed-off-by: kalyanr <[email protected]>

* fix

Signed-off-by: kalyanr <[email protected]>

* lint fix

Signed-off-by: kalyanr <[email protected]>

* reuse validate input

Signed-off-by: kalyanr <[email protected]>

* update comment

Signed-off-by: kalyanr <[email protected]>

* change

Signed-off-by: kalyanr <[email protected]>

* fix

Signed-off-by: kalyanr <[email protected]>

* update changelog

Signed-off-by: kalyanr <[email protected]>

* fix

Signed-off-by: kalyanr <[email protected]>

* remove separate model profile module

Signed-off-by: kalyanr <[email protected]>

* fix tests

Signed-off-by: kalyanr <[email protected]>

* fix lint

Signed-off-by: kalyanr <[email protected]>

* fix lint

Signed-off-by: kalyanr <[email protected]>

* fix

Signed-off-by: kalyanr <[email protected]>

* Update ml_commons_client.py

Signed-off-by: Kalyan <[email protected]>

---------

Signed-off-by: kalyanr <[email protected]>
Signed-off-by: Kalyan <[email protected]>
  • Loading branch information
rawwar authored Jan 3, 2024
1 parent 31ee5dc commit 9b2f686
Show file tree
Hide file tree
Showing 4 changed files with 188 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
- Add support for train api functionality by @rawwar in ([#310](https://github.com/opensearch-project/opensearch-py-ml/pull/310))
- Add support for Model Access Control - Register, Update, Search and Delete by @rawwar in ([#332](https://github.com/opensearch-project/opensearch-py-ml/pull/332))
- Add support for model connectors by @rawwar in ([#345](https://github.com/opensearch-project/opensearch-py-ml/pull/345))
- Add support for model profiles by @rawwar in ([#358](https://github.com/opensearch-project/opensearch-py-ml/pull/358))

### Changed
- Modify ml-models.JenkinsFile so that it takes model format into account and can be triggered with generic webhook by @thanawan-atc in ([#211](https://github.com/opensearch-project/opensearch-py-ml/pull/211))
Expand Down
109 changes: 109 additions & 0 deletions opensearch_py_ml/ml_commons/ml_commons_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from opensearch_py_ml.ml_commons.model_connector import Connector
from opensearch_py_ml.ml_commons.model_execute import ModelExecute
from opensearch_py_ml.ml_commons.model_uploader import ModelUploader
from opensearch_py_ml.ml_commons.validators.profile import validate_profile_input


class MLCommonClient:
Expand Down Expand Up @@ -606,3 +607,111 @@ def delete_task(self, task_id: str) -> object:
method="DELETE",
url=API_URL,
)

def _get_profile(self, payload: Optional[dict] = None):
"""
Get the profile using the given payload.
:param payload: The payload to be used for getting the profile. Defaults to None.
:type payload: Optional[dict]
:return: The response from the server after performing the request.
:rtype: Any
"""
validate_profile_input(None, payload)
return self._client.transport.perform_request(
method="GET", url=f"{ML_BASE_URI}/profile", body=payload
)

def _get_models_profile(
self, model_id: Optional[str] = "", payload: Optional[dict] = None
):
"""
Get the profile of a model.
Args:
model_id (str, optional): The ID of the model. Defaults to "".
payload (dict, optional): Additional payload for the request. Defaults to None.
Returns:
dict: The response from the API.
"""
validate_profile_input(model_id, payload)

url = f"{ML_BASE_URI}/profile/models/{model_id if model_id else ''}"
return self._client.transport.perform_request(
method="GET", url=url, body=payload
)

def _get_tasks_profile(
self, task_id: Optional[str] = "", payload: Optional[dict] = None
):
"""
Retrieves the profile of a task from the API.
Parameters:
task_id (str, optional): The ID of the task to retrieve the profile for. Defaults to an empty string.
payload (dict, optional): Additional payload for the request. Defaults to None.
Returns:
dict: The profile of the task.
Raises:
ValueError: If the input validation fails.
"""
validate_profile_input(task_id, payload)

url = f"{ML_BASE_URI}/profile/tasks/{task_id if task_id else ''}"
return self._client.transport.perform_request(
method="GET", url=url, body=payload
)

def get_profile(
self,
profile_type: str = "all",
ids: Optional[Union[str, List[str]]] = None,
request_body: Optional[dict] = None,
) -> dict:
"""
Get profile information based on the profile type.
Args:
profile_type: The type of profile to retrieve. Valid values are 'all', 'model', or 'task'. Default is 'all'.
'all': Retrieves all profiles available.
'model': Retrieves the profile(s) of the specified model(s). The model(s) to retrieve are specified by the 'ids' parameter.
'task': Retrieves the profile(s) of the specified task(s). The task(s) to retrieve are specified by the 'ids' parameter.
ids: Either a single profile ID as a string, or a list of profile IDs to retrieve. Default is None.
request_body: The request body containing additional information. Default is None.
Returns:
The profile information.
Raises:
ValueError: If the profile_type is not 'all', 'model', or 'task'.
Example:
get_profile()
get_profile(profile_type='model', ids='model1')
get_profile(profile_type='model', ids=['model1', 'model2'])
get_profile(profile_type='task', ids='task1', request_body={"node_ids": ["KzONM8c8T4Od-NoUANQNGg"],"return_all_tasks": true,"return_all_models": true})
get_profile(profile_type='task', ids=['task1', 'task2'], request_body={'additional': 'info'})
"""

if profile_type == "all":
return self._get_profile(request_body)
elif profile_type == "model":
if ids and isinstance(ids, list):
ids = ",".join(ids)
return self._get_models_profile(ids, request_body)
elif profile_type == "task":
if ids and isinstance(ids, list):
ids = ",".join(ids)
return self._get_tasks_profile(ids, request_body)
else:
raise ValueError(
"Invalid profile type. Profile type must be 'all', 'model' or 'task'."
)
16 changes: 16 additions & 0 deletions opensearch_py_ml/ml_commons/validators/profile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.
# Any modifications Copyright OpenSearch Contributors. See
# GitHub history for details.

"""Module for validating Profile API parameters """


def validate_profile_input(path_parameter, payload):
if path_parameter is not None and not isinstance(path_parameter, str):
raise ValueError("path_parameter needs to be a string or None")

if payload is not None and not isinstance(payload, dict):
raise ValueError("payload needs to be a dictionary or None")
62 changes: 62 additions & 0 deletions tests/ml_commons/test_ml_commons_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,3 +573,65 @@ def test_search():
except: # noqa: E722
raised = True
assert raised == False, "Raised Exception in searching model"


# Model Profile Tests. These tests will need some model train/predict run data. Hence, need
# to be run at the end after the training/prediction tests are done.


def test_get_profile():
res = ml_client.get_profile()
assert isinstance(res, dict)
assert "nodes" in res
test_model_id = None
test_task_id = None
for node_id, val in res["nodes"].items():
if test_model_id is None and "models" in val:
for model_id, model_val in val["models"].items():
test_model_id = {"node_id": node_id, "model_id": model_id}
break
if test_task_id is None and "tasks" in val:
for task_id, task_val in val["tasks"].items():
test_task_id = {"node_id": node_id, "task_id": task_id}
break

res = ml_client.get_profile(profile_type="model")
assert isinstance(res, dict)
assert "nodes" in res
for node_id, node_val in res["nodes"].items():
assert "models" in node_val

res = ml_client.get_profile(profile_type="model", ids=[test_model_id["model_id"]])
assert isinstance(res, dict)
assert "nodes" in res
assert test_model_id["model_id"] in res["nodes"][test_model_id["node_id"]]["models"]

res = ml_client.get_profile(profile_type="model", ids=["randomid1", "random_id2"])
assert isinstance(res, dict)
assert len(res) == 0

res = ml_client.get_profile(profile_type="task")
assert isinstance(res, dict)
if len(res) > 0:
assert "nodes" in res
for node_id, node_val in res["nodes"].items():
assert "tasks" in node_val

res = ml_client.get_profile(profile_type="task", ids=["random1", "random2"])
assert isinstance(res, dict)
assert len(res) == 0

with pytest.raises(ValueError):
ml_client.get_profile(profile_type="test")

with pytest.raises(ValueError):
ml_client.get_profile(profile_type="model", ids=1)

with pytest.raises(ValueError):
ml_client.get_profile(profile_type="model", request_body=10)

with pytest.raises(ValueError):
ml_client.get_profile(profile_type="task", ids=1)

with pytest.raises(ValueError):
ml_client.get_profile(profile_type="task", request_body=10)

0 comments on commit 9b2f686

Please sign in to comment.